# Cycle-GAN implementations

## Model and Dataset Selection

Select a GAN model and a dataset below, then execute all cells of the notebook. Images generated during training are saved to `./images/cycle-gan/<model-name>/<dataset-name>/`.

**Model Selection**

In [None]:
# Which generator/discriminator pair to train. The model-import cell accepts
# "fdcgan" and "resnet" and raises "Unknown model" for anything else.
model = "fdcgan"
# model = "resnet"

**Dataset Selection**

In [None]:
# Dataset identifier. "apple2orange64" is the only value the dataset
# specification cell handles; any other value raises "Unknown dataset".
dataset = "apple2orange64"

## Imports

In [None]:
import sys
import os
import math
import itertools
import datetime
import time
import glob
import random

import numpy as np

from PIL import Image

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch

## Hyperparameters and Dataset-Specific Configuration

**Global Configuration**

In [None]:
# Global training hyperparameters shared by every model/dataset combination.
# NOTE: a non-zero start_epoch makes the model-initialisation cell load the
# checkpoints saved for that epoch instead of initialising fresh weights.
start_epoch = 17         # epoch to start (resume) training from
n_epochs = 200           # number of epochs of training
decay_epoch = 100        # epoch from which to start lr decay
n_cpu = 8                # number of cpu threads to use during batch generation
lr = 0.0002              # adam: learning rate
b1 = 0.5                 # adam: decay of first order momentum of gradient
b2 = 0.999               # adam: decay of second order momentum of gradient
sample_interval = 200    # interval (in batches) between image samples
checkpoint_interval = 1  # interval (in epochs) between model checkpoints; -1 disables saving

# Loss weights: defaults first, then the smaller weights used for "fdcgan".
lambda_id = 5.0          # identity loss weight
lambda_cyc = 10.0        # cycle loss weight
if model == "fdcgan":
    lambda_id = 1.0
    lambda_cyc = 2.0

**Datasets Specifications**

In [None]:
# Per-dataset batch and image geometry. Only "apple2orange64" is configured,
# so anything else is rejected up front.
if dataset != "apple2orange64":
    raise Exception("Unknown dataset")

batch_size = 4      # size of the batches
img_size = 64       # size of each image dimension
channels = 3        # rgb

# Echo the effective configuration so it is recorded in the notebook output.
print(dict(
    n_epochs=n_epochs,
    batch_size=batch_size,
    lr=lr,
    b1=b1,
    b2=b2,
    n_cpu=n_cpu,
    img_size=img_size,
    channels=channels,
    sample_interval=sample_interval,
))


**Importing Chosen Models**

In [None]:
# Bind Generator / Discriminator / weights_init_normal to the implementation
# selected by `model`; both modules expose the same three names.
if model == "fdcgan":
    from fdcgan import Generator, Discriminator, weights_init_normal
elif model == "resnet":
    from resnet import Generator, Discriminator, weights_init_normal
    # Only the resnet generator takes a residual-block count.
    n_residual_blocks = 9
else:
    raise Exception("Unknown model")

# (channels, height, width) of the square images fed to the networks.
img_shape = (channels, img_size, img_size)

**Create Progress and Checkpoint Directories**

In [None]:
# Folder that receives the sample grids written during training.
image_progress_folder = "images/cycle-gan/%s/%s/" % (model, dataset)

# Create the sample folder and the checkpoint folder; existing folders are kept.
for folder in (image_progress_folder, "./saved_models/cycle-gan/%s/%s" % (model, dataset)):
    os.makedirs(folder, exist_ok=True)

**Setting Up Cuda**

In [None]:
# `cuda` is the boolean used throughout the notebook to pick GPU code paths;
# `device` is the matching torch.device.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")
print("Device use for training: ", device)

## Initialising Models

In [None]:
# -- Loss functions -- #

# The adversarial criterion depends on the model: BCE for "fdcgan",
# MSE for "resnet". Identity and cycle terms are both L1.
if model == "fdcgan":
    adversarial_loss = torch.nn.BCELoss()
elif model == "resnet":
    adversarial_loss = torch.nn.MSELoss()

identity_loss = torch.nn.L1Loss()
criterion_cycle = torch.nn.L1Loss()

# -- Initialize generator and discriminator -- #

# One generator per translation direction (A->B and B->A); only the resnet
# generator takes the residual-block count.
if model == "fdcgan":
    G_AB = Generator(img_shape)
    G_BA = Generator(img_shape)
elif model == "resnet":
    G_AB = Generator(img_shape, n_residual_blocks)
    G_BA = Generator(img_shape, n_residual_blocks)

# One discriminator per image domain.
D_A = Discriminator(img_shape)
D_B = Discriminator(img_shape)

if cuda:
    G_AB.cuda()
    G_BA.cuda()
    D_A.cuda()
    D_B.cuda()
    adversarial_loss.cuda()
    identity_loss.cuda()
    criterion_cycle.cuda()

if start_epoch != 0:
    # Resume from the checkpoints written for `start_epoch`.
    # map_location=device lets checkpoints saved on a GPU be loaded on a
    # CPU-only machine (and vice versa) instead of failing inside torch.load.
    for net_name, net in (("G_AB", G_AB), ("G_BA", G_BA), ("D_A", D_A), ("D_B", D_B)):
        checkpoint_path = "saved_models/cycle-gan/%s/%s/%s_%d.pth" % (model, dataset, net_name, start_epoch)
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    # Fresh run: initialise weights. No model check is needed here, because
    # the networks above only exist for "fdcgan" and "resnet".
    for net in (G_AB, G_BA, D_A, D_B):
        net.apply(weights_init_normal)

## Setting up Optimizers

In [None]:
class LambdaLR:
    """Linear learning-rate decay factor for use as an `lr_lambda` callable.

    The factor stays at 1.0 until `decay_start_epoch` and then falls linearly,
    reaching 0.0 at `n_epochs`. `offset` is added to the epoch passed to
    `step`, so a resumed run continues the schedule where it left off.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        decay_span = n_epochs - decay_start_epoch
        assert decay_span > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the multiplicative lr factor for `epoch` (before offset)."""
        epochs_into_decay = max(0, epoch + self.offset - self.decay_start_epoch)
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - epochs_into_decay / decay_span

# Both generators are updated by a single Adam optimizer over their chained
# parameters; each discriminator has its own optimizer.
generator_params = itertools.chain(G_AB.parameters(), G_BA.parameters())
optimizer_G = torch.optim.Adam(generator_params, lr=lr, betas=(b1, b2))
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr, betas=(b1, b2))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr, betas=(b1, b2))

# Learning rate update schedulers: one per optimizer, each with its own
# LambdaLR instance offset by start_epoch.
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step
    )
    for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

# Tensor constructor matching the device the models live on.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

**Buffers of previously generated samples:**

In [None]:
class ReplayBuffer:
    """Bounded pool of previously generated samples.

    While the pool is not full, every pushed sample is stored and returned
    unchanged. Once it is full, each pushed sample is, with probability one
    half, swapped with a random stored sample (the stored one is returned);
    otherwise the pushed sample is returned and the pool is left untouched.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push every sample of `data` and return a batch of the same size."""
        batch = []
        for sample in data.data:
            # Re-add the batch dimension so samples can be concatenated later.
            sample = torch.unsqueeze(sample, 0)

            if len(self.data) < self.max_size:
                # Pool still filling up: keep the sample and pass it through.
                self.data.append(sample)
                batch.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # Return an older sample and store the new one in its slot.
                slot = random.randint(0, self.max_size - 1)
                batch.append(self.data[slot].clone())
                self.data[slot] = sample
                continue

            batch.append(sample)
        return Variable(torch.cat(batch))


# One independent pool per generated domain.
fake_A_buffer, fake_B_buffer = ReplayBuffer(), ReplayBuffer()

## Importing Dataset

In [None]:
def to_rgb(image):
    """Return a new RGB image of the same size with `image` pasted onto it."""
    canvas = Image.new(mode="RGB", size=image.size)
    canvas.paste(image)
    return canvas

class ImageDataset(Dataset):
    """Paired-by-index or unaligned image dataset over two folders.

    Files are read from `<root>/<mode>/A` and `<root>/<mode>/B`.

    Args:
        root: dataset root directory.
        transforms_: list of transforms applied to both images; `None`
            (the default) means no transforms.
        unaligned: when true, the B image is drawn at random instead of
            being matched to the A image by index.
        mode: sub-folder name under `root` (e.g. "train").
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) would construct but fail on first use, so the default
        # is mapped to an empty (identity) pipeline.
        self.transform = transforms.Compose(transforms_ if transforms_ is not None else [])
        self.unaligned = unaligned

        # Sorted so that index -> file is stable between runs.
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    @staticmethod
    def _load_rgb(path):
        """Open the image at `path`, converting non-RGB images to RGB."""
        image = Image.open(path)
        if image.mode != "RGB":
            image = to_rgb(image)
        return image

    def __getitem__(self, index):
        """Return {"A": ..., "B": ...} with both images transformed."""
        # Indices wrap around, because __len__ is the size of the larger folder.
        path_A = self.files_A[index % len(self.files_A)]

        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]

        item_A = self.transform(self._load_rgb(path_A))
        item_B = self.transform(self._load_rgb(path_B))
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Size of the larger of the two folders."""
        return max(len(self.files_A), len(self.files_B))

# Augmentation pipeline shared by the training and validation loaders:
# upscale to 112% of the target size, take a random img_size crop, flip at
# random, then convert to a tensor normalised with mean 0.5 / std 0.5.
upscaled_size = int(img_size * 1.12)
transforms_ = [
    transforms.Resize(upscaled_size, Image.BICUBIC),
    transforms.RandomCrop((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

if dataset == "apple2orange64":

    import subprocess

    # Run the download script before the datasets are constructed.
    command = "bash ./datasets/download_cyclegan_dataset.sh"
    subprocess.run(command, shell=True)

    train_dataset = ImageDataset(
        "./datasets/apple2orange64", transforms_=transforms_, unaligned=True, mode="train"
    )
    dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_cpu,
    )

    # Validation data loader (used for the periodic sample grids)
    val_dataset = ImageDataset(
        "./datasets/apple2orange64", transforms_=transforms_, unaligned=True, mode="validation"
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

def sample_images(batches_done):
    """Save a grid of real and translated images from the validation loader.

    One batch is drawn from `val_dataloader`, pushed through both generators,
    and written as a four-row grid (real A, fake B, real B, fake A) to
    `images/cycle-gan/<model>/<dataset>/<batches_done>.png`.

    The generators are left in eval mode; the training loop switches them
    back to train mode at the start of every iteration.

    Args:
        batches_done: global batch counter, used as the output file name.
    """
    imgs = next(iter(val_dataloader))
    G_AB.eval()
    G_BA.eval()
    # Sampling is inference only: no_grad avoids building an autograd graph
    # for the generator forward passes.
    with torch.no_grad():
        real_A = Variable(imgs["A"].type(Tensor))
        fake_B = G_AB(real_A)
        real_B = Variable(imgs["B"].type(Tensor))
        fake_A = G_BA(real_B)
    # Arrange images along x-axis
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Arrange images along y-axis
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    print("saving to ", batches_done)
    save_image(image_grid, "images/cycle-gan/%s/%s/%s.png" % (model, dataset, batches_done), normalize=False)


## Training

Unfortunately, depending on how the dataset is imported, we need separate training loops for mnist and apple2orange, but the procedure is exactly the same.

In [None]:
# Announce the selected configuration before training starts.
print("Using '%s' with '%s', saving progress to '%s'." % (model, dataset, image_progress_folder))

In [None]:
# Main CycleGAN training loop. Each batch performs three optimizer steps in
# order: both generators (jointly), then discriminator A, then discriminator B.
# Learning rates are stepped and checkpoints written once per epoch.
prev_time = time.time()
for epoch in range(start_epoch, n_epochs):
    for i, batch in enumerate(dataloader):

        # Set model input
        real_A = Variable(batch["A"].type(Tensor))
        real_B = Variable(batch["B"].type(Tensor))

        # Adversarial ground truths: all-ones / all-zeros targets shaped
        # (batch, *D_A.output_shape); the same targets are used for D_B.
        valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        # sample_images() leaves the generators in eval mode, so train mode
        # is restored at the start of every iteration.
        G_AB.train()
        G_BA.train()

        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its target
        # domain is compared against that same image.
        loss_id_A = identity_loss(G_BA(real_A), real_A)
        loss_id_B = identity_loss(G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss: translated images are scored against the "valid" target.
        fake_B = G_AB(real_A)
        loss_GAN_AB = adversarial_loss(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = adversarial_loss(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translate back and compare with the original image.
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss: adversarial term plus weighted cycle and identity terms.
        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

        loss_G.backward()
        optimizer_G.step()

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = adversarial_loss(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples); the batch
        # comes from the replay buffer and is detached before scoring.
        fake_A_ = fake_A_buffer.push_and_pop(fake_A)
        loss_fake = adversarial_loss(D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()

        # Real loss
        loss_real = adversarial_loss(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = fake_B_buffer.push_and_pop(fake_B)
        loss_fake = adversarial_loss(D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        # Combined discriminator loss; used for logging only.
        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left: remaining batches multiplied by
        # the wall-clock duration of the batch that just finished.
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log ("\r" rewrites the same console line on every batch)
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_GAN.item(),
                loss_cycle.item(),
                loss_identity.item(),
                time_left,
            )
        )

        # If at sample interval save image
        if batches_done % sample_interval == 0:
            sample_images(batches_done)

    # Update learning rates (once per epoch)
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    # checkpoint_interval is counted in epochs; -1 disables checkpointing.
    if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
        # Save model checkpoints; files are suffixed with the epoch index,
        # which is the same naming the resume path loads by start_epoch.
        torch.save(G_AB.state_dict(), "saved_models/cycle-gan/%s/%s/G_AB_%d.pth" % (model, dataset, epoch))
        torch.save(G_BA.state_dict(), "saved_models/cycle-gan/%s/%s/G_BA_%d.pth" % (model, dataset, epoch))
        torch.save(D_A.state_dict(), "saved_models/cycle-gan/%s/%s/D_A_%d.pth" % (model, dataset, epoch))
        torch.save(D_B.state_dict(), "saved_models/cycle-gan/%s/%s/D_B_%d.pth" % (model, dataset, epoch))