In [31]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm.notebook import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
dataset = torch.load("../datasets/64skulldataset.pt")

# Drop placeholder entries. `is not None` is an identity check and never invokes
# the sample's own `__ne__` (which for a tensor would be an element-wise compare).
dataset = [sample for sample in dataset if sample is not None]

# Split by label: label 0 -> NC, any other label -> AD.
NC = []
AD = []
for data in dataset:
    (NC if data[1] == 0 else AD).append(data)
        
def process_gan(dataset, s):
    """Extract one slice per sample and scale it so its maximum is 1.

    Args:
        dataset: sequence of samples; ``sample[0]`` is an indexable image tensor and
            ``sample[0][s][0]`` is the slice that is taken (presumably a 2-D (H, W)
            slice — TODO confirm against the dataset builder).
        s: which view/channel to take from each image (0, 1 or 2 in this notebook).

    Returns:
        A list with one tensor per sample: ``slice / slice.max()`` with a new leading
        axis of size 1 (the channel axis the single-channel networks expect).

    The input tensors are never modified; the result is always a fresh tensor. A
    slice whose maximum is 0 is returned unscaled instead of becoming NaN/inf
    through division by zero.
    """
    output = []
    for sample in dataset:
        image_slice = sample[0][s][0]
        peak = torch.max(image_slice)
        # Out-of-place division: `/=` on this view would write back into `dataset`.
        image_slice = image_slice / peak if peak != 0 else image_slice.clone()
        output.append(torch.unsqueeze(image_slice, 0))
    return output

        
# Per-view, peak-normalised slices for each class (views 0, 1 and 2).
NCgan1 = process_gan(NC, 0)
NCgan2 = process_gan(NC, 1)
NCgan3 = process_gan(NC, 2)

ADgan1 = process_gan(AD, 0)
ADgan2 = process_gan(AD, 1)
ADgan3 = process_gan(AD, 2)


def pair_domains(domain_a, domain_b):
    """Pair two sample lists index-by-index into {"A": a, "B": b} dicts.

    Pairing stops at the shorter list, so unequal NC/AD counts no longer raise
    IndexError. Surplus samples of the longer list are dropped.
    """
    return [{"A": a, "B": b} for a, b in zip(domain_a, domain_b)]


# Domain A = NC slices, domain B = AD slices, one paired list per view.
gan1 = pair_domains(NCgan1, ADgan1)
gan2 = pair_domains(NCgan2, ADgan2)
gan3 = pair_domains(NCgan3, ADgan3)

batch_size = 1
# One shuffled loader per view; each batch is a dict {"A": NC slices, "B": AD slices}.
dataloader1, dataloader2, dataloader3 = (
    DataLoader(pairs, batch_size=batch_size, shuffle=True, num_workers=4)
    for pairs in (gan1, gan2, gan3)
)

In [2]:
import random
import time
import datetime
import sys

from torch.autograd import Variable
import torch
import numpy as np

from torchvision.utils import save_image


class ReplayBuffer:
    """History pool of previously generated images for discriminator updates.

    Until ``max_size`` images have been stored, every incoming image is kept and
    returned unchanged. After that, each incoming image is either returned as-is
    (probability ~0.5) or swapped with a randomly chosen stored image; in that
    case a copy of the old image is returned and the new one takes its place.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch into the pool and return a same-sized batch drawn from it.

        Args:
            data: batch tensor; iterated along dim 0 (its ``.data`` is used, so no
                autograd history is kept).

        Returns:
            The concatenation, along a new dim 0, of the selected images.
        """
        picked = []
        for item in data.data:
            item = item.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Pool not full yet: remember the sample and return it directly.
                self.data.append(item)
                picked.append(item)
                continue
            if random.uniform(0, 1) <= 0.5:
                # Keep the fresh sample; the pool is left as it was.
                picked.append(item)
                continue
            # Return an old sample and replace it with the fresh one.
            slot = random.randint(0, self.max_size - 1)
            picked.append(self.data[slot].clone())
            self.data[slot] = item
        return Variable(torch.cat(picked))


class LambdaLR:
    """Learning-rate multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    The multiplier stays at 1.0 until ``epoch + offset`` reaches
    ``decay_start_epoch``, then falls linearly to 0.0 at ``n_epochs``.
    ``offset`` is the epoch a resumed run starts from.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the LR multiplier for scheduler step ``epoch``."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_length = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_length


def weights_init_normal(m):
    """Re-initialise one module in place; meant to be passed to ``nn.Module.apply``.

    - class name contains "Conv": weight ~ N(0, 0.02), bias (if any) set to 0.
    - class name contains "BatchNorm2d": weight ~ N(1, 0.02), bias set to 0.
    - anything else (e.g. InstanceNorm2d, ReLU) is left untouched.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif "BatchNorm2d" in kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           RESNET
##############################


class ResidualBlock(nn.Module):
    """Residual unit: x + (pad-conv3x3-IN-ReLU-pad-conv3x3-IN)(x).

    Channel count and spatial size are preserved. The layer order inside
    ``self.block`` is kept fixed so saved state_dict keys remain loadable.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def conv_stage():
            # Reflection padding of 1 keeps the 3x3 convolution size-preserving.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                nn.InstanceNorm2d(in_features),
            ]

        layers = conv_stage() + [nn.ReLU(inplace=True)] + conv_stage()
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem (64 filters) -> two stride-2 3x3 convs (128, 256 filters)
    -> ``num_residual_blocks`` ResidualBlocks -> two (upsample x2 + 3x3 conv)
    stages (128, 64 filters) -> 7x7 conv back to the input channel count -> tanh.
    Height and width are preserved when they are multiples of 4. The layer order
    inside ``self.model`` is fixed so saved state_dict keys remain loadable.

    Args:
        input_shape: (channels, height, width); only ``channels`` is used here.
        num_residual_blocks: number of residual blocks in the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()

        channels = input_shape[0]
        width = 64  # feature maps after the stem

        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, width, 7),
            nn.InstanceNorm2d(width),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: double the filters and halve H/W, twice.
        for _ in range(2):
            layers.extend([
                nn.Conv2d(width, width * 2, 3, stride=2, padding=1),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(inplace=True),
            ])
            width *= 2

        # Bottleneck of residual blocks at the lowest resolution.
        layers.extend(ResidualBlock(width) for _ in range(num_residual_blocks))

        # Upsampling: halve the filters and double H/W, twice.
        for _ in range(2):
            layers.extend([
                nn.Upsample(scale_factor=2),
                nn.Conv2d(width, width // 2, 3, stride=1, padding=1),
                nn.InstanceNorm2d(width // 2),
                nn.ReLU(inplace=True),
            ])
            width //= 2

        # Output head: back to the input channel count, squashed into [-1, 1].
        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(width, channels, 7), nn.Tanh()])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


##############################
#        Discriminator
##############################


class Discriminator(nn.Module):
    """PatchGAN discriminator producing a grid of real/fake scores per image.

    Four stride-2 4x4 conv blocks (64, 128, 256, 512 filters; the first block has
    no InstanceNorm), each followed by LeakyReLU(0.2), then a one-pixel top/left
    zero pad and a 1-channel 4x4 conv.

    ``output_shape`` is (1, height // 16, width // 16). The training loop uses it
    to build its target tensors, and it matches the network output when height and
    width are multiples of 16. The layer order inside ``self.model`` is fixed so
    saved state_dict keys remain loadable.

    Args:
        input_shape: (channels, height, width) of the input images.
    """

    def __init__(self, input_shape):
        super().__init__()

        channels, height, width = input_shape
        self.output_shape = (1, height // 16, width // 16)

        filters = (channels, 64, 128, 256, 512)
        layers = []
        for depth, (c_in, c_out) in enumerate(zip(filters[:-1], filters[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if depth:  # the first block is deliberately left un-normalised
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [53]:
import os
import numpy as np
import math
import itertools
import datetime
import time

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable


import torch.nn as nn
import torch.nn.functional as F
import torch


# ---- Training configuration ----
epoch = 0                    # starting epoch; a non-zero value resumes from saved checkpoints
n_epochs = 100               # training runs epochs epoch .. n_epochs - 1
dataset_name = 'CycleGAN2'   # run name: outputs go to images/<name>/ and saved_models/<name>/
batch_size = 1
lr = 0.0002                  # Adam learning rate
b1 = 0.5                     # Adam beta1
b2 = 0.999                   # Adam beta2
# Epoch at which the linear learning-rate decay begins. The LambdaLR helper asserts
# n_epochs - decay_start_epoch > 0, so this must be strictly below n_epochs;
# a value equal to n_epochs makes the scheduler constructors raise AssertionError.
decay_epoch = n_epochs // 2
n_cpu = 8                    # NOTE(review): not referenced in the visible code (DataLoaders hard-code num_workers=4)
img_height = 64
img_width = 64
channels = 1                 # single-channel (grayscale) images
sample_interval = 500        # save a sample grid every this many batches
checkpoint_interval = 25     # save checkpoints at epochs divisible by this; -1 disables
n_residual_blocks = 9        # residual blocks in each generator
lambda_cyc = 10              # weight of the cycle-consistency loss
lambda_id = 5                # weight of the identity loss


def sample_images(batches_done, dataloader, images=None):
    """Save a grid of real and translated images for one batch.

    Takes the first batch from a fresh iterator over ``dataloader`` (whatever
    loader the caller passes; no separate test set is involved). It translates
    A->B with ``G_AB`` and B->A with ``G_BA``, and writes the grids in the order
    real_A, fake_B, real_B, fake_A to ``images/<dataset_name>/<batches_done>.png``.

    The generators run in eval mode under ``torch.no_grad()`` (no autograd graph is
    built), and their previous train/eval mode is restored afterwards.

    Args:
        batches_done: global batch counter, used as the output file name.
        dataloader: loader yielding {"A": tensor, "B": tensor} batches.
        images: unused; kept only so existing three-argument calls keep working.
    """
    imgs = next(iter(dataloader))
    was_training = (G_AB.training, G_BA.training)
    G_AB.eval()
    G_BA.eval()
    try:
        with torch.no_grad():
            real_A = imgs["A"].type(Tensor)
            fake_B = G_AB(real_A)
            real_B = imgs["B"].type(Tensor)
            fake_A = G_BA(real_B)
    finally:
        G_AB.train(was_training[0])
        G_BA.train(was_training[1])

    grids = [make_grid(t, nrow=4, normalize=True) for t in (real_A, fake_B, real_B, fake_A)]
    image_grid = torch.stack(grids, 0)
    save_image(image_grid, "images/%s/%s.png" % (dataset_name, batches_done), normalize=False)



In [51]:
G_losses = []  # loss_G for every training iteration, appended by train_gan
D_losses = []  # mean of loss_D_A and loss_D_B for every iteration, appended by train_gan


def train_gan(dataloader, epoch):
    """Train the CycleGAN networks G_AB, G_BA, D_A and D_B on ``dataloader``.

    Runs epochs ``epoch`` .. ``n_epochs - 1``. Relies on module-level globals from
    other cells: the four networks, ``optimizer_G``/``optimizer_D_A``/``optimizer_D_B``,
    the three LR schedulers, ``fake_A_buffer``/``fake_B_buffer``, the loss criteria,
    ``Tensor`` and the hyper-parameters (``n_epochs``, ``lambda_cyc``, ``lambda_id``,
    ``sample_interval``, ``checkpoint_interval``, ``dataset_name``).

    Side effects: appends to ``G_losses``/``D_losses`` every iteration, prints
    progress every 100 batches, and saves a sample grid whenever the global batch
    counter is a multiple of ``sample_interval``. At epochs divisible by
    ``checkpoint_interval`` (unless it is -1) it writes all four state_dicts to
    ``saved_models/<dataset_name>/``.

    Args:
        dataloader: DataLoader yielding {"A": tensor, "B": tensor} batches.
        epoch: first epoch to run (non-zero when resuming).
    """
    prev_time = time.time()
    for epoch in range(epoch, n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = Variable(batch["A"].type(Tensor))
            real_B = Variable(batch["B"].type(Tensor))

            # Adversarial ground truths (one score per PatchGAN cell)
            valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image already in its target domain
            # should return it unchanged.
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # --------------
            #  Log Progress
            # --------------

            # Approximate time left, extrapolated from the duration of the last batch.
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            if i % 100 == 0:
                print(
                    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                    % (
                        epoch,
                        n_epochs,
                        i,
                        len(dataloader),
                        loss_D.item(),
                        loss_G.item(),
                        loss_GAN.item(),
                        loss_cycle.item(),
                        loss_identity.item(),
                        time_left,
                    )
                )

            # If at sample interval save image. sample_images ignores its third
            # argument; pass None rather than the undefined global `images`, which
            # raised NameError on a fresh kernel.
            if batches_done % sample_interval == 0:
                sample_images(batches_done, dataloader, None)

            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())

        # Update learning rates once per epoch
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (dataset_name, epoch))


In [4]:
input_shape = (channels, img_height, img_width)

# Use the GPU when present and fall back to CPU otherwise, consistent with the
# `device` selection in the first cell. Hard-coded `.cuda()` fails on CPU-only hosts.
cuda = torch.cuda.is_available()
model_device = torch.device("cuda:0" if cuda else "cpu")

# Initialize generator and discriminator
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(model_device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(model_device)
D_A = Discriminator(input_shape).to(model_device)
D_B = Discriminator(input_shape).to(model_device)

# Optimizers: one Adam over both generators, one per discriminator.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers: multiplier 1.0 until `decay_epoch`, then linear
# decay to 0 at `n_epochs`; `epoch` offsets the schedule for resumed runs.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

# Tensor type the training loop casts batches to; lives on the same device as the models.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Losses
criterion_GAN = torch.nn.MSELoss().to(model_device)
criterion_cycle = torch.nn.L1Loss().to(model_device)
criterion_identity = torch.nn.L1Loss().to(model_device)

if epoch != 0:
    # Resume from the checkpoints train_gan wrote for this epoch; map_location lets
    # GPU-saved weights load on a CPU-only machine.
    for net, tag in ((G_AB, "G_AB"), (G_BA, "G_BA"), (D_A, "D_A"), (D_B, "D_B")):
        net.load_state_dict(
            torch.load("saved_models/%s/%s_%d.pth" % (dataset_name, tag, epoch), map_location=model_device)
        )
else:
    for net in (G_AB, G_BA, D_A, D_B):
        net.apply(weights_init_normal)



In [None]:
epoch = 0
dataset_name = 'CycleGAN3'

# Output folders used by sample_images / checkpointing for this run.
os.makedirs("images/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

# Train on view 3, then plot the per-iteration losses collected by train_gan.
train_gan(dataloader3, epoch)

fig, ax = plt.subplots(figsize=(10,5))
ax.set_title("Generator and Discriminator Loss During Training")
ax.plot(G_losses, label="G")
ax.plot(D_losses, label="D")
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

[Epoch 0/200] [Batch 0/476] [D loss: 0.453004] [G loss: 2.251709, adv: 1.445777, cycle: 0.059891, identity: 0.041404] ETA: 4:04:21.734772
[Epoch 0/200] [Batch 100/476] [D loss: 0.590417] [G loss: 2.922217, adv: 1.870370, cycle: 0.073319, identity: 0.063732] ETA: 1:48:32.766838
[Epoch 0/200] [Batch 200/476] [D loss: 0.428210] [G loss: 1.888819, adv: 1.187153, cycle: 0.048301, identity: 0.043732] ETA: 1:40:32.062769
[Epoch 0/200] [Batch 300/476] [D loss: 0.508606] [G loss: 2.310069, adv: 1.484732, cycle: 0.057868, identity: 0.049332] ETA: 1:52:17.253261
[Epoch 0/200] [Batch 400/476] [D loss: 0.550286] [G loss: 2.099478, adv: 1.475994, cycle: 0.042640, identity: 0.039417] ETA: 1:41:39.013424
[Epoch 1/200] [Batch 0/476] [D loss: 0.514739] [G loss: 2.200931, adv: 1.498567, cycle: 0.046734, identity: 0.047005] ETA: 6:12:26.310460
[Epoch 1/200] [Batch 100/476] [D loss: 0.453647] [G loss: 2.516351, adv: 1.583667, cycle: 0.058232, identity: 0.070074] ETA: 1:52:44.561630
[Epoch 1/200] [Batch 200

[Epoch 11/200] [Batch 400/476] [D loss: 0.423341] [G loss: 1.869408, adv: 1.205441, cycle: 0.049268, identity: 0.034257] ETA: 1:46:09.388084
[Epoch 12/200] [Batch 0/476] [D loss: 0.397314] [G loss: 2.362360, adv: 1.563859, cycle: 0.053647, identity: 0.052407] ETA: 4:37:15.603420
[Epoch 12/200] [Batch 100/476] [D loss: 0.498287] [G loss: 2.436456, adv: 1.630832, cycle: 0.054671, identity: 0.051782] ETA: 1:31:27.245393
[Epoch 12/200] [Batch 200/476] [D loss: 0.459675] [G loss: 2.342281, adv: 1.328782, cycle: 0.064769, identity: 0.073162] ETA: 1:36:37.764492
[Epoch 12/200] [Batch 300/476] [D loss: 0.436274] [G loss: 2.448567, adv: 1.519252, cycle: 0.067952, identity: 0.049958] ETA: 1:32:36.834722
[Epoch 12/200] [Batch 400/476] [D loss: 0.544095] [G loss: 2.589240, adv: 1.566826, cycle: 0.071475, identity: 0.061532] ETA: 1:30:05.915771
[Epoch 13/200] [Batch 0/476] [D loss: 0.543753] [G loss: 3.367979, adv: 2.084400, cycle: 0.093810, identity: 0.069096] ETA: 4:52:14.857046
[Epoch 13/200] [B

[Epoch 23/200] [Batch 300/476] [D loss: 0.477721] [G loss: 2.617160, adv: 1.645809, cycle: 0.063164, identity: 0.067942] ETA: 1:23:11.099064
[Epoch 23/200] [Batch 400/476] [D loss: 0.503884] [G loss: 2.935109, adv: 1.830579, cycle: 0.075980, identity: 0.068947] ETA: 1:28:59.489869
[Epoch 24/200] [Batch 0/476] [D loss: 0.461209] [G loss: 2.046082, adv: 1.279859, cycle: 0.050284, identity: 0.052676] ETA: 4:11:23.081131
[Epoch 24/200] [Batch 100/476] [D loss: 0.448921] [G loss: 2.094182, adv: 1.338816, cycle: 0.051251, identity: 0.048572] ETA: 1:28:02.118522
[Epoch 24/200] [Batch 200/476] [D loss: 0.502411] [G loss: 2.086637, adv: 1.237586, cycle: 0.054390, identity: 0.061030] ETA: 1:31:52.706997
[Epoch 24/200] [Batch 300/476] [D loss: 0.426006] [G loss: 2.239276, adv: 1.256577, cycle: 0.068099, identity: 0.060343] ETA: 1:32:04.281692
[Epoch 24/200] [Batch 400/476] [D loss: 0.474662] [G loss: 2.013623, adv: 1.293030, cycle: 0.045923, identity: 0.052273] ETA: 1:33:32.682575
[Epoch 25/200] 

[Epoch 35/200] [Batch 200/476] [D loss: 0.520717] [G loss: 1.944502, adv: 1.281647, cycle: 0.042490, identity: 0.047590] ETA: 1:19:24.964366
[Epoch 35/200] [Batch 300/476] [D loss: 0.530924] [G loss: 2.504402, adv: 1.565632, cycle: 0.065958, identity: 0.055838] ETA: 1:21:55.350609
[Epoch 35/200] [Batch 400/476] [D loss: 0.511608] [G loss: 2.564733, adv: 1.535026, cycle: 0.076718, identity: 0.052505] ETA: 1:23:43.568358
[Epoch 36/200] [Batch 0/476] [D loss: 0.458950] [G loss: 2.109517, adv: 1.368830, cycle: 0.049977, identity: 0.048183] ETA: 3:55:27.443184
[Epoch 36/200] [Batch 100/476] [D loss: 0.437432] [G loss: 1.869408, adv: 1.205441, cycle: 0.049268, identity: 0.034257] ETA: 1:21:19.311587
[Epoch 36/200] [Batch 200/476] [D loss: 0.530694] [G loss: 2.257644, adv: 1.337681, cycle: 0.063516, identity: 0.056961] ETA: 1:22:10.732216
[Epoch 36/200] [Batch 300/476] [D loss: 0.457076] [G loss: 2.957239, adv: 1.580116, cycle: 0.101603, identity: 0.072218] ETA: 1:16:53.570212
[Epoch 36/200] 

[Epoch 47/200] [Batch 100/476] [D loss: 0.472666] [G loss: 2.353031, adv: 1.625188, cycle: 0.050178, identity: 0.045212] ETA: 1:19:17.391212
[Epoch 47/200] [Batch 200/476] [D loss: 0.434577] [G loss: 2.364021, adv: 1.398483, cycle: 0.067953, identity: 0.057202] ETA: 1:25:23.002426
[Epoch 47/200] [Batch 300/476] [D loss: 0.472487] [G loss: 2.352597, adv: 1.506616, cycle: 0.060196, identity: 0.048804] ETA: 1:19:08.562355
[Epoch 47/200] [Batch 400/476] [D loss: 0.488543] [G loss: 2.368036, adv: 1.407190, cycle: 0.067905, identity: 0.056360] ETA: 1:15:41.859658
[Epoch 48/200] [Batch 0/476] [D loss: 0.450627] [G loss: 2.694856, adv: 1.512193, cycle: 0.085152, identity: 0.066229] ETA: 3:50:49.832504
[Epoch 48/200] [Batch 100/476] [D loss: 0.442542] [G loss: 1.953194, adv: 1.062913, cycle: 0.059989, identity: 0.058079] ETA: 1:18:29.476061
[Epoch 48/200] [Batch 200/476] [D loss: 0.444010] [G loss: 1.780614, adv: 1.074807, cycle: 0.043017, identity: 0.055128] ETA: 1:20:53.943201
[Epoch 48/200] 

[Epoch 59/200] [Batch 0/476] [D loss: 0.406364] [G loss: 1.996421, adv: 1.268855, cycle: 0.053459, identity: 0.038596] ETA: 3:24:53.243059
[Epoch 59/200] [Batch 100/476] [D loss: 0.508097] [G loss: 2.002154, adv: 1.168095, cycle: 0.057724, identity: 0.051363] ETA: 1:07:16.582512
[Epoch 59/200] [Batch 200/476] [D loss: 0.432177] [G loss: 2.142757, adv: 1.293353, cycle: 0.057772, identity: 0.054336] ETA: 1:04:06.960365
[Epoch 59/200] [Batch 300/476] [D loss: 0.483189] [G loss: 2.018765, adv: 1.168164, cycle: 0.050804, identity: 0.068513] ETA: 1:11:21.218811
[Epoch 59/200] [Batch 400/476] [D loss: 0.575558] [G loss: 2.775151, adv: 1.590730, cycle: 0.083823, identity: 0.069239] ETA: 1:06:23.725425
[Epoch 60/200] [Batch 0/476] [D loss: 0.422798] [G loss: 2.230638, adv: 1.338627, cycle: 0.061377, identity: 0.055649] ETA: 3:26:45.025043
[Epoch 60/200] [Batch 100/476] [D loss: 0.496897] [G loss: 1.797628, adv: 1.095371, cycle: 0.043381, identity: 0.053689] ETA: 1:07:26.620617
[Epoch 60/200] [B

[Epoch 70/200] [Batch 400/476] [D loss: 0.444049] [G loss: 2.528434, adv: 1.495501, cycle: 0.074188, identity: 0.058211] ETA: 1:06:43.107424
[Epoch 71/200] [Batch 0/476] [D loss: 0.536739] [G loss: 2.456072, adv: 1.622720, cycle: 0.059636, identity: 0.047398] ETA: 3:14:06.326277
[Epoch 71/200] [Batch 100/476] [D loss: 0.442771] [G loss: 2.310849, adv: 1.476599, cycle: 0.055479, identity: 0.055891] ETA: 1:06:41.571922
[Epoch 71/200] [Batch 200/476] [D loss: 0.540985] [G loss: 2.198469, adv: 1.262367, cycle: 0.065502, identity: 0.056217] ETA: 1:05:46.540123
[Epoch 71/200] [Batch 300/476] [D loss: 0.499525] [G loss: 2.775151, adv: 1.590730, cycle: 0.083823, identity: 0.069239] ETA: 1:01:33.566689
[Epoch 71/200] [Batch 400/476] [D loss: 0.396071] [G loss: 2.152011, adv: 1.320026, cycle: 0.062931, identity: 0.040536] ETA: 1:02:27.008919
[Epoch 72/200] [Batch 0/476] [D loss: 0.559914] [G loss: 3.367979, adv: 2.084400, cycle: 0.093810, identity: 0.069096] ETA: 3:26:02.519531
[Epoch 72/200] [B

[Epoch 82/200] [Batch 300/476] [D loss: 0.464028] [G loss: 1.940802, adv: 1.163587, cycle: 0.048516, identity: 0.058410] ETA: 1:11:12.872952
[Epoch 82/200] [Batch 400/476] [D loss: 0.443429] [G loss: 2.274064, adv: 1.301475, cycle: 0.065743, identity: 0.063032] ETA: 0:58:40.495407
[Epoch 83/200] [Batch 0/476] [D loss: 0.554077] [G loss: 2.257644, adv: 1.337681, cycle: 0.063516, identity: 0.056961] ETA: 3:08:09.612605
[Epoch 83/200] [Batch 100/476] [D loss: 0.399049] [G loss: 2.328885, adv: 1.326483, cycle: 0.071826, identity: 0.056829] ETA: 1:03:53.091452
[Epoch 83/200] [Batch 200/476] [D loss: 0.449289] [G loss: 2.567541, adv: 1.563485, cycle: 0.075203, identity: 0.050406] ETA: 1:06:31.085936
[Epoch 83/200] [Batch 300/476] [D loss: 0.513862] [G loss: 3.367979, adv: 2.084400, cycle: 0.093810, identity: 0.069096] ETA: 1:07:18.727066
[Epoch 83/200] [Batch 400/476] [D loss: 0.471028] [G loss: 1.983853, adv: 1.202811, cycle: 0.055378, identity: 0.045453] ETA: 1:08:17.955494
[Epoch 84/200] 

[Epoch 94/200] [Batch 200/476] [D loss: 0.555611] [G loss: 2.285561, adv: 1.324404, cycle: 0.061952, identity: 0.068327] ETA: 0:59:52.001129
[Epoch 94/200] [Batch 300/476] [D loss: 0.590813] [G loss: 2.780258, adv: 1.841271, cycle: 0.063669, identity: 0.060459] ETA: 0:53:46.181803
[Epoch 94/200] [Batch 400/476] [D loss: 0.498287] [G loss: 2.436456, adv: 1.630832, cycle: 0.054671, identity: 0.051782] ETA: 0:56:17.962263
[Epoch 95/200] [Batch 0/476] [D loss: 0.388284] [G loss: 2.441398, adv: 1.515830, cycle: 0.061607, identity: 0.061899] ETA: 2:36:18.435445
[Epoch 95/200] [Batch 100/476] [D loss: 0.514386] [G loss: 2.340749, adv: 1.413212, cycle: 0.063346, identity: 0.058815] ETA: 0:51:38.044167
[Epoch 95/200] [Batch 200/476] [D loss: 0.541244] [G loss: 3.193427, adv: 2.039233, cycle: 0.085453, identity: 0.059933] ETA: 0:54:49.894314
[Epoch 95/200] [Batch 300/476] [D loss: 0.533777] [G loss: 2.466260, adv: 1.542658, cycle: 0.060142, identity: 0.064437] ETA: 0:48:35.402927
[Epoch 95/200] 

[Epoch 106/200] [Batch 100/476] [D loss: 0.452207] [G loss: 2.021915, adv: 1.325482, cycle: 0.046784, identity: 0.045719] ETA: 0:55:58.594837
[Epoch 106/200] [Batch 200/476] [D loss: 0.468238] [G loss: 2.374851, adv: 1.493929, cycle: 0.059964, identity: 0.056256] ETA: 0:56:16.538818
[Epoch 106/200] [Batch 300/476] [D loss: 0.489644] [G loss: 1.841742, adv: 1.205466, cycle: 0.043214, identity: 0.040827] ETA: 0:59:07.421055
[Epoch 106/200] [Batch 400/476] [D loss: 0.512845] [G loss: 2.065351, adv: 1.299897, cycle: 0.049964, identity: 0.053162] ETA: 0:50:39.965803
[Epoch 107/200] [Batch 0/476] [D loss: 0.479084] [G loss: 2.496989, adv: 1.594411, cycle: 0.058681, identity: 0.063153] ETA: 2:28:12.093693
[Epoch 107/200] [Batch 100/476] [D loss: 0.494778] [G loss: 1.973574, adv: 1.264193, cycle: 0.048858, identity: 0.044160] ETA: 0:57:01.087027
[Epoch 107/200] [Batch 200/476] [D loss: 0.392010] [G loss: 2.359776, adv: 1.458809, cycle: 0.064439, identity: 0.051316] ETA: 0:53:44.621316
[Epoch 1

In [48]:
dataset_name = "CycleGAN"  # run whose checkpoints (saved_models/CycleGAN/) are reloaded for inference

In [49]:
# Accumulators for translated images; only the view-1 lists are filled below.
ADgan1 = []
NCgan1 = []

ADgan2 = []
NCgan2 = []

ADgan3 = []
NCgan3 = []

# Reload the networks saved at `epoch` by the training run named `dataset_name`.
# Pick the GPU when present; map_location lets GPU-saved checkpoints load on CPU.
epoch = 25
infer_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(infer_device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(infer_device)
D_A = Discriminator(input_shape).to(infer_device)
D_B = Discriminator(input_shape).to(infer_device)
G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch), map_location=infer_device))
G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch), map_location=infer_device))
D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (dataset_name, epoch), map_location=infer_device))
D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch), map_location=infer_device))
# NOTE(review): D_A/D_B are reloaded but not used for the translation below.

# Inference only: switch to eval mode once and skip autograd bookkeeping.
G_AB.eval()
G_BA.eval()

with torch.no_grad():
    for imgs in tqdm(dataloader1):
        # Domain A (NC) -> synthetic B (AD), labelled 1.
        real_A = imgs["A"].to(infer_device, dtype=torch.float32)
        fake_B = torch.squeeze(G_AB(real_A)).cpu()

        # Domain B (AD) -> synthetic A (NC), labelled 0.
        real_B = imgs["B"].to(infer_device, dtype=torch.float32)
        fake_A = torch.squeeze(G_BA(real_B)).cpu()

        # Replicate the single generated channel three times -> (3, 64, 64) float64.
        B = torch.stack([fake_B, fake_B, fake_B], 0)
        ADgan1.append((B.reshape(3, 64, 64).double(), torch.tensor(1).long()))

        A = torch.stack([fake_A, fake_A, fake_A], 0)
        NCgan1.append((A.reshape(3, 64, 64).double(), torch.tensor(0).long()))

torch.save(ADgan1, '../datasets/64ADgan1.pt')
torch.save(NCgan1, '../datasets/64NCgan1.pt')

    
# for imgs in tqdm(dataloader2):

#     G_AB.eval()
#     G_BA.eval()
    
#     real_A = Variable(imgs["A"].type(Tensor))
#     fake_B = torch.squeeze(G_AB(real_A))
#     fake_B = fake_B.detach().cpu()

#     real_B = Variable(imgs["B"].type(Tensor))
#     fake_A = torch.squeeze(G_BA(real_B))
#     fake_A = fake_A.detach().cpu()
    
#     ADgan2.append(fake_B)
#     NCgan2.append(fake_A)
    
# for imgs in tqdm(dataloader3):

#     G_AB.eval()
#     G_BA.eval()

#     real_A = Variable(imgs["A"].type(Tensor))
#     fake_B = torch.squeeze(G_AB(real_A))
#     fake_B = fake_B.detach().cpu()

#     real_B = Variable(imgs["B"].type(Tensor))
#     fake_A = torch.squeeze(G_BA(real_B))
#     fake_A = fake_A.detach().cpu()
    
#     ADgan3.append(fake_B)
#     NCgan3.append(fake_A)

HBox(children=(FloatProgress(value=0.0, max=476.0), HTML(value='')))


