In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm.notebook import tqdm
from CycleGAN_utils import *
from CycleGAN_models import *
import os
import numpy as np
import math
import itertools
import datetime
import time
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn.functional as F


# ---- Training configuration ----
epoch = 0               # epoch to start training from (0 = train from scratch)
n_epochs = 51           # train_gan runs epochs epoch .. n_epochs - 1
batch_size = 1
lr = 0.0002             # Adam learning rate (all optimizers)
b1 = 0.5                # Adam beta1
b2 = 0.999              # Adam beta2
decay_epoch = 50        # passed to the project LambdaLR helper; presumably the epoch LR decay starts — confirm in CycleGAN_utils
n_cpu = 8               # NOTE(review): currently unused; the DataLoaders hard-code num_workers=4
img_height = 224
img_width = 224
channels = 1            # images are single-channel (see input_shape)
sample_interval = 100   # save a sample image grid every N batches
checkpoint_interval = 25  # save model checkpoints every N epochs; -1 disables
n_residual_blocks = 9   # residual blocks per GeneratorResNet
lambda_cyc = 10         # weight of the cycle-consistency loss in the generator objective
lambda_id = 5           # weight of the identity loss in the generator objective
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook touches (DataLoader shuffling, weight init,
# replay buffer) so training runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [2]:
# Load the pre-built dataset. torch.load unpickles the file: only load trusted files.
# Assumed layout (TODO confirm): a list whose entries are (images, label) pairs or None.
dataset = torch.load("../datasets/SSdataset.pt")

# Drop missing entries. `is not None` is the identity test; `!= None` would
# dispatch to an element's (possibly overloaded) __ne__.
dataset = [sample for sample in dataset if sample is not None]

# Split by label: label 0 -> NC (used as domain "A"), any other label -> AD (domain "B").
NC = []
AD = []
for data in dataset:
    if data[1] == 0:
        NC.append(data)
    else:
        AD.append(data)
        
def process_gan(dataset, s):
    """Extract and max-normalize one image from every sample of ``dataset``.

    For each entry ``item`` the image ``item[0][s][0]`` is taken, divided by its
    maximum value and given a leading channel dimension via ``unsqueeze(0)``.

    The source tensors are never modified: normalization is out of place, so
    ``dataset`` (and whatever lists share its tensors) keep their raw values.
    All-zero images are returned as zeros instead of NaN (0/0), and integer
    tensors are converted to floating point before dividing.

    Args:
        dataset: sequence of indexable samples; ``sample[0]`` holds the images.
        s: index of the image to extract from ``sample[0]``.

    Returns:
        list of tensors, one per sample, each with a new leading dimension of size 1.
    """
    output = []
    for item in dataset:
        image = item[0][s][0]
        if not torch.is_floating_point(image):
            image = image.float()
        peak = torch.max(image)
        if peak != 0:
            # Out-of-place division: creates a new tensor instead of writing
            # through the indexing view into the caller's data.
            image = image / peak
        else:
            # Nothing to normalize; copy so the output never aliases the input.
            image = image.clone()
        output.append(torch.unsqueeze(image, 0))
    return output

        
# One CycleGAN per image index s = 0, 1, 2: domain A is NC, domain B is AD.
NCgan1 = process_gan(NC, 0)
NCgan2 = process_gan(NC, 1)
NCgan3 = process_gan(NC, 2)

ADgan1 = process_gan(AD, 0)
ADgan2 = process_gan(AD, 1)
ADgan3 = process_gan(AD, 2)

# Pair NC and AD images positionally. zip stops at the shorter list, so a class
# imbalance in either direction is handled. Indexing NC by range(len(AD)) raised
# IndexError whenever there were fewer NC samples than AD.
gan1 = [{"A": nc_img, "B": ad_img} for nc_img, ad_img in zip(NCgan1, ADgan1)]
gan2 = [{"A": nc_img, "B": ad_img} for nc_img, ad_img in zip(NCgan2, ADgan2)]
gan3 = [{"A": nc_img, "B": ad_img} for nc_img, ad_img in zip(NCgan3, ADgan3)]

# batch_size comes from the configuration cell.
dataloader1 = DataLoader(gan1, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader2 = DataLoader(gan2, batch_size=batch_size, shuffle=True, num_workers=4)
dataloader3 = DataLoader(gan3, batch_size=batch_size, shuffle=True, num_workers=4)

In [3]:
# Per-iteration generator / discriminator losses, appended by train_gan.
# These lists persist for the whole kernel session and grow with every train_gan call.
G_losses, D_losses = [], []
        
def sample_images(batches_done, dataloader):
    """Save real and translated images for one batch drawn from ``dataloader``.

    Writes ``ganimages/<dataset_name>/<batches_done>.png`` containing four panels
    in order: real A, G_AB(real A), real B, G_BA(real B). Reads the module-level
    generators ``G_AB``/``G_BA``, ``Tensor`` and ``dataset_name``.

    Both generators are left in eval mode; train_gan sets them back to train
    mode at the start of its next step.
    """
    imgs = next(iter(dataloader))
    G_AB.eval()
    G_BA.eval()
    # Inference only: no_grad skips building an autograd graph through both generators.
    with torch.no_grad():
        real_A = imgs["A"].type(Tensor)
        real_B = imgs["B"].type(Tensor)
        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

    real_A = make_grid(real_A, nrow=4, normalize=True)
    real_B = make_grid(real_B, nrow=4, normalize=True)
    fake_A = make_grid(fake_A, nrow=4, normalize=True)
    fake_B = make_grid(fake_B, nrow=4, normalize=True)

    image_grid = torch.stack((real_A, fake_B, real_B, fake_A), 0)
    save_image(image_grid, "ganimages/%s/%s.png" % (dataset_name, batches_done), normalize=False)
    
def _plot_losses(g_losses, d_losses):
    """Plot per-iteration generator and discriminator losses of one training run."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.set_title("Generator and Discriminator Loss During Training")
    ax.plot(g_losses, label="G")
    ax.plot(d_losses, label="D")
    ax.set_xlabel("iterations")
    ax.set_ylabel("Loss")
    ax.legend()
    plt.show()


def train_gan(dataloader, epoch):
    """Train the current module-level CycleGAN on ``dataloader``.

    Runs epochs ``epoch`` .. ``n_epochs - 1``. For each batch it performs:
      * one generator update on the adversarial + ``lambda_cyc`` * cycle
        + ``lambda_id`` * identity losses;
      * one update each for D_A and D_B, fed fakes through the replay buffers.

    It also does the following bookkeeping:
      * prints progress every 100 batches;
      * saves a sample grid every ``sample_interval`` batches;
      * steps the LR schedulers once per epoch;
      * writes checkpoints every ``checkpoint_interval`` epochs;
      * plots this run's loss curves at the end.

    Uses the globals prepared by the setup cell: G_AB, G_BA, D_A, D_B, their
    optimizers and schedulers, fake_A_buffer/fake_B_buffer, the three criteria,
    ``Tensor`` and ``dataset_name``. Losses are also appended to the
    module-level ``G_losses``/``D_losses``.

    Args:
        dataloader: yields dicts with "A" (NC) and "B" (AD) image batches.
        epoch: first epoch to run (0 to train from scratch).
    """
    # The module-level G_losses/D_losses are shared by every train_gan call, so
    # plotting them would overlay earlier GANs. Keep this run's history separately.
    run_G_losses = []
    run_D_losses = []
    prev_time = time.time()
    for epoch in range(epoch, n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input
            real_A = batch["A"].type(Tensor)
            real_B = batch["B"].type(Tensor)

            # Adversarial ground truths, shaped like the discriminator output.
            # D_A and D_B are built with the same input_shape, so D_A's shape serves both.
            valid = Tensor(np.ones((real_A.size(0), *D_A.output_shape)))
            fake = Tensor(np.zeros((real_A.size(0), *D_A.output_shape)))

            # ------------------
            #  Train Generators
            # ------------------

            G_AB.train()
            G_BA.train()

            optimizer_G.zero_grad()

            # Identity loss: each generator should leave images of its target domain unchanged
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)

            loss_identity = (loss_id_A + loss_id_B) / 2

            # GAN loss
            fake_B = G_AB(real_A)
            loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
            fake_A = G_BA(real_B)
            loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)

            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle loss: A -> B -> A and B -> A -> B should reconstruct the input
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)

            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            # Total loss
            loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------

            # zero_grad also clears the D_A gradients left over from the generator backward pass
            optimizer_D_A.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_A(real_A), valid)
            # Fake loss (on batch of previously generated samples)
            fake_A_ = fake_A_buffer.push_and_pop(fake_A)
            loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
            # Total loss
            loss_D_A = (loss_real + loss_fake) / 2

            loss_D_A.backward()
            optimizer_D_A.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------

            optimizer_D_B.zero_grad()

            # Real loss
            loss_real = criterion_GAN(D_B(real_B), valid)
            # Fake loss (on batch of previously generated samples)
            fake_B_ = fake_B_buffer.push_and_pop(fake_B)
            loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
            # Total loss
            loss_D_B = (loss_real + loss_fake) / 2

            loss_D_B.backward()
            optimizer_D_B.step()

            loss_D = (loss_D_A + loss_D_B) / 2

            # Each .item() forces a device sync, so read each scalar once
            loss_G_value = loss_G.item()
            loss_D_value = loss_D.item()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left (last batch duration * batches remaining)
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            if i % 100 == 0:
                print(
                    "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
                    % (
                        epoch,
                        n_epochs,
                        i,
                        len(dataloader),
                        loss_D_value,
                        loss_G_value,
                        loss_GAN.item(),
                        loss_cycle.item(),
                        loss_identity.item(),
                        time_left,
                    )
                )

            # If at sample interval save image
            if batches_done % sample_interval == 0:
                sample_images(batches_done, dataloader)

            G_losses.append(loss_G_value)
            D_losses.append(loss_D_value)
            run_G_losses.append(loss_G_value)
            run_D_losses.append(loss_D_value)

        # Update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D_A.step()
        lr_scheduler_D_B.step()

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(G_AB.state_dict(), "saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch))
            torch.save(G_BA.state_dict(), "saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch))
            torch.save(D_A.state_dict(), "saved_models/%s/D_A_%d.pth" % (dataset_name, epoch))
            torch.save(D_B.state_dict(), "saved_models/%s/D_B_%d.pth" % (dataset_name, epoch))

    _plot_losses(run_G_losses, run_D_losses)

In [None]:
# ---- CycleGAN 1: NC (A) <-> AD (B) on dataloader1 ----

# Per-run settings. Assigned first because the LR schedulers, the
# checkpoint-resume branch and train_gan below all read them.
dataset_name = 'CycleGAN1SS'
epoch = 0  # set > 0 to resume from saved_models/<dataset_name>/*_<epoch>.pth
os.makedirs("ganimages/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

input_shape = (channels, img_height, img_width)

# Initialize generators and discriminators on the configured device
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

# Optimizers: one Adam over both generators, one per discriminator
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers (multiplier supplied by the project LambdaLR helper)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

# Legacy tensor type that train_gan/sample_images use to cast batches; must match `device`.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Losses: MSE (least-squares) adversarial loss, L1 for cycle and identity terms
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_cycle = torch.nn.L1Loss().to(device)
criterion_identity = torch.nn.L1Loss().to(device)

if epoch != 0:
    # Resume this GAN from its own checkpoints saved at `epoch`
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch), map_location=device))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch), map_location=device))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (dataset_name, epoch), map_location=device))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch), map_location=device))
else:
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

train_gan(dataloader1, epoch)

[Epoch 0/51] [Batch 0/476] [D loss: 2.168109] [G loss: 6.816429, adv: 1.470208, cycle: 0.383003, identity: 0.303239] ETA: 5:32:35.700311
[Epoch 0/51] [Batch 100/476] [D loss: 0.264667] [G loss: 1.121777, adv: 0.322583, cycle: 0.054034, identity: 0.051771] ETA: 2:04:17.726608
[Epoch 0/51] [Batch 200/476] [D loss: 0.272156] [G loss: 1.007270, adv: 0.233533, cycle: 0.052532, identity: 0.049683] ETA: 2:04:43.689384
[Epoch 0/51] [Batch 300/476] [D loss: 0.239394] [G loss: 0.828983, adv: 0.240991, cycle: 0.040091, identity: 0.037417] ETA: 2:04:24.164183
[Epoch 0/51] [Batch 400/476] [D loss: 0.257638] [G loss: 0.626128, adv: 0.215599, cycle: 0.028597, identity: 0.024912] ETA: 2:04:00.187817
[Epoch 1/51] [Batch 0/476] [D loss: 0.300588] [G loss: 0.736618, adv: 0.301798, cycle: 0.030062, identity: 0.026840] ETA: 3:55:22.494459
[Epoch 1/51] [Batch 100/476] [D loss: 0.283368] [G loss: 0.689448, adv: 0.271335, cycle: 0.027733, identity: 0.028156] ETA: 2:04:34.982929
[Epoch 1/51] [Batch 200/476] [D

[Epoch 12/51] [Batch 0/476] [D loss: 0.259171] [G loss: 0.421254, adv: 0.247987, cycle: 0.011151, identity: 0.012351] ETA: 2:47:47.796910
[Epoch 12/51] [Batch 100/476] [D loss: 0.269078] [G loss: 0.513436, adv: 0.306977, cycle: 0.013379, identity: 0.014533] ETA: 1:38:50.700874
[Epoch 12/51] [Batch 200/476] [D loss: 0.264578] [G loss: 0.532459, adv: 0.324038, cycle: 0.013962, identity: 0.013760] ETA: 1:38:04.517367
[Epoch 12/51] [Batch 300/476] [D loss: 0.250814] [G loss: 0.804241, adv: 0.399250, cycle: 0.027677, identity: 0.025644] ETA: 1:37:06.938948
[Epoch 12/51] [Batch 400/476] [D loss: 0.220536] [G loss: 0.806691, adv: 0.256897, cycle: 0.036010, identity: 0.037940] ETA: 1:36:10.545248
[Epoch 13/51] [Batch 0/476] [D loss: 0.247118] [G loss: 0.656712, adv: 0.244928, cycle: 0.027622, identity: 0.027114] ETA: 3:02:35.207575
[Epoch 13/51] [Batch 100/476] [D loss: 0.277555] [G loss: 0.779757, adv: 0.285987, cycle: 0.034668, identity: 0.029417] ETA: 1:55:14.992661
[Epoch 13/51] [Batch 200

[Epoch 23/51] [Batch 400/476] [D loss: 0.278130] [G loss: 0.701702, adv: 0.433667, cycle: 0.018476, identity: 0.016656] ETA: 1:08:42.839233
[Epoch 24/51] [Batch 0/476] [D loss: 0.178935] [G loss: 1.154909, adv: 0.399584, cycle: 0.050506, identity: 0.050053] ETA: 1:54:06.745331
[Epoch 24/51] [Batch 100/476] [D loss: 0.192469] [G loss: 1.003663, adv: 0.341423, cycle: 0.044880, identity: 0.042689] ETA: 1:08:04.138550
[Epoch 24/51] [Batch 200/476] [D loss: 0.214090] [G loss: 0.595402, adv: 0.289697, cycle: 0.022233, identity: 0.016674] ETA: 1:06:58.694639
[Epoch 24/51] [Batch 300/476] [D loss: 0.128401] [G loss: 1.025875, adv: 0.667261, cycle: 0.025365, identity: 0.020994] ETA: 1:06:37.106266
[Epoch 24/51] [Batch 400/476] [D loss: 0.135564] [G loss: 0.908972, adv: 0.474736, cycle: 0.029295, identity: 0.028258] ETA: 1:05:53.897367
[Epoch 25/51] [Batch 0/476] [D loss: 0.109155] [G loss: 1.201258, adv: 0.774051, cycle: 0.029441, identity: 0.026560] ETA: 1:50:05.616941
[Epoch 25/51] [Batch 100

[Epoch 35/51] [Batch 300/476] [D loss: 0.236965] [G loss: 1.026210, adv: 0.770922, cycle: 0.017207, identity: 0.016644] ETA: 0:42:26.460247
[Epoch 35/51] [Batch 400/476] [D loss: 0.213565] [G loss: 0.819545, adv: 0.385913, cycle: 0.029199, identity: 0.028328] ETA: 0:40:30.657467
[Epoch 36/51] [Batch 0/476] [D loss: 0.158442] [G loss: 0.832348, adv: 0.460518, cycle: 0.025964, identity: 0.022438] ETA: 1:07:18.516197
[Epoch 36/51] [Batch 100/476] [D loss: 0.185027] [G loss: 0.758593, adv: 0.303116, cycle: 0.030052, identity: 0.030992] ETA: 0:40:29.203033
[Epoch 36/51] [Batch 200/476] [D loss: 0.229853] [G loss: 0.618881, adv: 0.303711, cycle: 0.021259, identity: 0.020515] ETA: 0:39:11.928563
[Epoch 36/51] [Batch 300/476] [D loss: 0.249178] [G loss: 0.559196, adv: 0.309163, cycle: 0.017322, identity: 0.015362] ETA: 0:39:16.383705
[Epoch 36/51] [Batch 400/476] [D loss: 0.234963] [G loss: 0.571385, adv: 0.249305, cycle: 0.022257, identity: 0.019903] ETA: 0:38:48.562284
[Epoch 37/51] [Batch 0

[Epoch 47/51] [Batch 200/476] [D loss: 0.191741] [G loss: 0.975357, adv: 0.284430, cycle: 0.045072, identity: 0.048041] ETA: 0:09:51.537249
[Epoch 47/51] [Batch 300/476] [D loss: 0.222476] [G loss: 0.607172, adv: 0.327046, cycle: 0.018713, identity: 0.018600] ETA: 0:09:11.544796
[Epoch 47/51] [Batch 400/476] [D loss: 0.182929] [G loss: 0.760731, adv: 0.403233, cycle: 0.024323, identity: 0.022854] ETA: 0:09:03.422081
[Epoch 48/51] [Batch 0/476] [D loss: 0.199048] [G loss: 0.751035, adv: 0.442426, cycle: 0.021007, identity: 0.019708] ETA: 0:11:26.801193
[Epoch 48/51] [Batch 100/476] [D loss: 0.154052] [G loss: 0.812317, adv: 0.432929, cycle: 0.025693, identity: 0.024492] ETA: 0:07:41.002968
[Epoch 48/51] [Batch 200/476] [D loss: 0.199017] [G loss: 0.790458, adv: 0.438326, cycle: 0.022916, identity: 0.024595] ETA: 0:07:00.350180
[Epoch 48/51] [Batch 300/476] [D loss: 0.199589] [G loss: 0.691289, adv: 0.388708, cycle: 0.020012, identity: 0.020492] ETA: 0:06:44.217768
[Epoch 48/51] [Batch 4

In [None]:
# ---- CycleGAN 2: NC (A) <-> AD (B) on dataloader2 ----

# Per-run settings. Assigned first: the resume branch below must read THIS
# GAN's dataset_name, not the one left over from the previous cell.
dataset_name = 'CycleGAN2SS'
epoch = 0  # set > 0 to resume from saved_models/<dataset_name>/*_<epoch>.pth
os.makedirs("ganimages/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

input_shape = (channels, img_height, img_width)

# Initialize generators and discriminators on the configured device
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

# Optimizers: one Adam over both generators, one per discriminator
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers (multiplier supplied by the project LambdaLR helper)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

# Legacy tensor type that train_gan/sample_images use to cast batches; must match `device`.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Losses: MSE (least-squares) adversarial loss, L1 for cycle and identity terms
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_cycle = torch.nn.L1Loss().to(device)
criterion_identity = torch.nn.L1Loss().to(device)

if epoch != 0:
    # Resume this GAN from its own checkpoints saved at `epoch`
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch), map_location=device))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch), map_location=device))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (dataset_name, epoch), map_location=device))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch), map_location=device))
else:
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

train_gan(dataloader2, epoch)

In [None]:
# ---- CycleGAN 3: NC (A) <-> AD (B) on dataloader3 ----

# Per-run settings. Assigned first: the resume branch below must read THIS
# GAN's dataset_name, not the one left over from the previous cell.
dataset_name = 'CycleGAN3SS'
epoch = 0  # set > 0 to resume from saved_models/<dataset_name>/*_<epoch>.pth
os.makedirs("ganimages/%s" % dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % dataset_name, exist_ok=True)

input_shape = (channels, img_height, img_width)

# Initialize generators and discriminators on the configured device
G_AB = GeneratorResNet(input_shape, n_residual_blocks).to(device)
G_BA = GeneratorResNet(input_shape, n_residual_blocks).to(device)
D_A = Discriminator(input_shape).to(device)
D_B = Discriminator(input_shape).to(device)

# Optimizers: one Adam over both generators, one per discriminator
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers (multiplier supplied by the project LambdaLR helper)
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optimizer_G, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_A, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D_B, lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).step
)

# Legacy tensor type that train_gan/sample_images use to cast batches; must match `device`.
Tensor = torch.cuda.FloatTensor if device.type == "cuda" else torch.FloatTensor

# Buffers of previously generated samples
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Losses: MSE (least-squares) adversarial loss, L1 for cycle and identity terms
criterion_GAN = torch.nn.MSELoss().to(device)
criterion_cycle = torch.nn.L1Loss().to(device)
criterion_identity = torch.nn.L1Loss().to(device)

if epoch != 0:
    # Resume this GAN from its own checkpoints saved at `epoch`
    G_AB.load_state_dict(torch.load("saved_models/%s/G_AB_%d.pth" % (dataset_name, epoch), map_location=device))
    G_BA.load_state_dict(torch.load("saved_models/%s/G_BA_%d.pth" % (dataset_name, epoch), map_location=device))
    D_A.load_state_dict(torch.load("saved_models/%s/D_A_%d.pth" % (dataset_name, epoch), map_location=device))
    D_B.load_state_dict(torch.load("saved_models/%s/D_B_%d.pth" % (dataset_name, epoch), map_location=device))
else:
    G_AB.apply(weights_init_normal)
    G_BA.apply(weights_init_normal)
    D_A.apply(weights_init_normal)
    D_B.apply(weights_init_normal)

train_gan(dataloader3, epoch)