# Different GAN implementations

## Model and Dataset Selection

Select a model and a dataset, then run all cells of the notebook. Images generated during training are saved to `./images/<model-name>/<dataset-name>/`.

**Model Selection**

In [None]:
# Choose which GAN variant to train by setting `model` to one of:
#
#   noise-vector input (generator samples from a latent vector):
#       "gan", "dcgan"
#   image input (generator translates an input image):
#       "fdcgan", "resnet"
model = "resnet"

**Dataset Selection**

In [None]:
# Choose the training data by setting `dataset` to one of:
#
#   "mnist"           - only usable with the noise-vector models
#   "apple2orange64"  - usable with every model
#   "orange2apple64"  - usable with every model (domains swapped)
dataset = "orange2apple64"

## Imports

In [None]:
import sys
import os
import math
import itertools
import datetime
import time
import glob
import random

import numpy as np

from PIL import Image

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch

## Hyperparameters and Dataset-Specific Configuration

**Global Configuration**

In [None]:
start_epoch = 0         # epoch to start training from (> 0 resumes from the checkpoints saved for that epoch)
n_epochs = 200          # number of epochs of training
decay_epoch = 100       # epoch from which the learning rate starts decaying linearly to 0
n_cpu = 8               # number of DataLoader worker processes used during batch generation
lr = 0.0002             # adam: learning rate
b1 = 0.5                # adam: decay of first order momentum of gradient
b2 = 0.999              # adam: decay of second order momentum of gradient
sample_interval = 400   # interval (in batches) between image samples
checkpoint_interval = 1 # interval (in epochs) between model checkpoints; -1 disables checkpointing

**Dataset Specifications**

In [None]:
if dataset == "mnist":
    batch_size = 64     # size of the batches
    img_size = 28       # size of each image dimension
    latent_dim = 100    # dimensionality of the latent space
    channels = 1        # grayscale
    
elif dataset == "apple2orange64" or dataset == "orange2apple64":
    batch_size = 4      # size of the batches
    img_size = 64       # size of each image dimension
    latent_dim = 300    # dimensionality of the latent space
    channels = 3        # rgb

else:
    raise Exception("Unknown dataset")

print({
    "n_epochs": n_epochs,
    "batch_size": batch_size,
    "lr": lr,
    "b1": b1,
    "b2": b2,
    "n_cpu": n_cpu,
    "latent_dim": latent_dim,
    "img_size": img_size,
    "channels": channels,
    "sample_interval": sample_interval
})


**Importing Chosen Models**

In [None]:
if model == "gan":
    from gan import Generator, Discriminator
    
elif model == "dcgan":
    from dcgan import Generator, Discriminator, weights_init_normal
    if dataset == "mnist":
        img_size = 32

elif model == "fdcgan" or model == "resnet":
    lambda_id = 5.0
    if dataset not in ["apple2orange64", "orange2apple64"]:
        raise Exception(f"Dataset {dataset} has no input image for the generator")
    if model == "fdcgan":
        from fdcgan import Generator, Discriminator, weights_init_normal
    elif model == "resnet":
        from resnet import Generator, Discriminator, weights_init_normal
        n_residual_blocks = 9

else:
    raise Exception("Unknown model")

img_shape = (channels, img_size, img_size)

**Create Progress and Checkpoint Directories**

In [None]:
image_progress_folder = "images/%s/%s/" % (model, dataset)

os.makedirs(image_progress_folder, exist_ok=True)
os.makedirs("./saved_models/%s/%s" % (model, dataset), exist_ok=True)

**Setting Up CUDA**

In [None]:
# Train on the GPU whenever one is visible to PyTorch.
cuda = torch.cuda.is_available()
device = torch.device("cuda") if cuda else torch.device("cpu")
print("Device use for training: ", device)

## Initialising Models

In [None]:
# -- Loss functions -- #

if model == "gan" or model == "dcgan" or model == "fdcgan":
    # Binary cross-entropy on the discriminator's real/fake prediction.
    adversarial_loss = torch.nn.BCELoss()
elif model == "resnet":
    # Least-squares (MSE) adversarial objective.
    adversarial_loss = torch.nn.MSELoss()

# L1 distance, used as the identity loss by the image-input training loop.
identity_loss = torch.nn.L1Loss()

# -- Initialize generator and discriminator -- #

if model == "gan" or model == "dcgan":
    generator = Generator(img_shape, latent_dim)
elif model == "fdcgan":
    generator = Generator(img_shape)
elif model == "resnet":
    generator = Generator(img_shape, n_residual_blocks)

discriminator = Discriminator(img_shape)

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()
    identity_loss.cuda()

if start_epoch != 0:
    # Load pretrained models. map_location remaps the stored tensors onto the
    # current device, so a checkpoint saved on a GPU can be resumed on a
    # CPU-only machine (without it torch.load fails there).
    generator.load_state_dict(torch.load(
        "saved_models/%s/%s/generator_%d.pth" % (model, dataset, start_epoch),
        map_location=device,
    ))
    discriminator.load_state_dict(torch.load(
        "saved_models/%s/%s/discriminator_%d.pth" % (model, dataset, start_epoch),
        map_location=device,
    ))
else:
    # Initialize weights (the plain "gan" module provides no initialiser).
    if model == "dcgan" or model == "fdcgan" or model == "resnet":
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

## Setting up Optimizers

In [None]:
class LambdaLR:
    """Linear learning-rate decay factor for ``torch.optim.lr_scheduler.LambdaLR``.

    The returned multiplier stays at 1.0 until ``decay_start_epoch`` and then
    falls linearly, reaching 0.0 at ``n_epochs``. ``offset`` is added to the
    scheduler's epoch counter so a resumed run continues the same schedule.
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        decay_span = n_epochs - decay_start_epoch
        assert decay_span > 0, "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        """Return the lr multiplier for scheduler epoch ``epoch``."""
        epochs_into_decay = epoch + self.offset - self.decay_start_epoch
        decay_span = self.n_epochs - self.decay_start_epoch
        return 1.0 - max(0, epochs_into_decay) / decay_span
    
# One Adam optimizer per network, both using the same hyperparameters.
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(b1, b2))
    for net in (generator, discriminator)
)

# Learning rate update schedulers: each gets its own linear-decay factor,
# offset by start_epoch so resumed runs continue the schedule.
lr_scheduler_G, lr_scheduler_D = (
    torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=LambdaLR(n_epochs, start_epoch, decay_epoch).step
    )
    for opt in (optimizer_G, optimizer_D)
)

# Tensor constructor matching the training device.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

## Importing Dataset

In [None]:
def to_rgb(image):
    """Return a new RGB image of the same size with ``image`` pasted into it.

    Used to turn non-RGB (e.g. grayscale) inputs into 3-channel images.
    """
    converted = Image.new(mode="RGB", size=image.size)
    converted.paste(image)
    return converted

class ImageDataset(Dataset):
    """Two-domain image dataset read from ``<root>/<mode>/A`` and ``<root>/<mode>/B``.

    Each item is a dict ``{"A": tensor, "B": tensor}`` with both images converted
    to RGB and passed through the same transform pipeline.

    Args:
        root: Dataset root folder.
        transforms_: List of torchvision transforms; ``None`` means no transform.
        unaligned: If True, the B image is drawn at random instead of by index,
            so A/B pairs are unpaired.
        mode: Sub-folder to read ("train", "validation", ...).
    """

    def __init__(self, root, transforms_=None, unaligned=False, mode="train"):
        # Compose(None) would fail on the first item; fall back to an empty pipeline.
        self.transform = transforms.Compose(transforms_ or [])
        self.unaligned = unaligned

        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` fully into memory as an RGB image and close the file."""
        # Image.open is lazy and keeps the file handle open; reading inside a
        # `with` block and returning an independent copy releases it right away
        # instead of leaking descriptors in the DataLoader workers.
        with Image.open(path) as image:
            if image.mode != "RGB":
                # Convert grayscale images to rgb
                return to_rgb(image)
            return image.copy()

    def __getitem__(self, index):
        path_A = self.files_A[index % len(self.files_A)]

        if self.unaligned:
            path_B = self.files_B[random.randint(0, len(self.files_B) - 1)]
        else:
            path_B = self.files_B[index % len(self.files_B)]

        item_A = self.transform(self._load_rgb(path_A))
        item_B = self.transform(self._load_rgb(path_B))
        return {"A": item_A, "B": item_B}

    def __len__(self):
        # The larger domain defines an epoch; the smaller one wraps around.
        return max(len(self.files_A), len(self.files_B))

# Augmentation pipeline for the apple/orange images: upscale by ~12%, take a
# random img_size x img_size crop, randomly mirror, convert to a tensor and
# normalise every channel with mean 0.5 / std 0.5.
transforms_ = [
    transforms.Resize(int(img_size * 1.12), interpolation=Image.BICUBIC),
    transforms.RandomCrop((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]

if dataset == "mnist":
    
    os.makedirs("./datasets/mnist", exist_ok=True)

    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            "./datasets/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_cpu,
    )

if dataset == "apple2orange64":
    my_class_A = "A"
    my_class_B = "B"

if dataset == "orange2apple64":
    my_class_A = "B"
    my_class_B = "A"

if dataset == "apple2orange64" or dataset == "orange2apple64":

    import subprocess
    command = "bash ./datasets/download_cyclegan_dataset.sh"
    subprocess.run(command, shell=True)

    dataloader = torch.utils.data.DataLoader(
        ImageDataset("./datasets/apple2orange64", transforms_=transforms_, unaligned=True, mode="train"),
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_cpu,
    )

    # Test data loader
    val_dataloader = DataLoader(
        ImageDataset("./datasets/apple2orange64", transforms_=transforms_, unaligned=True, mode="validation"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )

def sample_images(batches_done, my_class_A):
    """Save validation inputs (top row) and their generated translations (bottom row).

    Draws one batch from ``val_dataloader``, runs ``generator`` on the images of
    domain ``my_class_A`` and writes the grid to
    ``images/<model>/<dataset>/<batches_done>.png``.

    Args:
        batches_done: Global batch counter, used as the output file name.
        my_class_A: Key ("A" or "B") of the source domain in the batch dict.
    """
    imgs = next(iter(val_dataloader))
    was_training = generator.training
    generator.eval()
    try:
        # Inference only: skip building an autograd graph for the samples.
        with torch.no_grad():
            real = imgs[my_class_A].type(Tensor)
            fake = generator(real)
    finally:
        # Restore the caller's mode instead of leaving the generator in eval mode.
        generator.train(was_training)
    # Arrange images along x-axis
    real = make_grid(real, nrow=5, normalize=True)
    fake = make_grid(fake, nrow=5, normalize=True)
    # Arrange images along y-axis (real row above fake row)
    image_grid = torch.cat((real, fake), 1)
    save_image(image_grid, "images/%s/%s/%s.png" % (model, dataset, batches_done), normalize=False)

## Training

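Unfortunately, because MNIST (torchvision's `MNIST` dataset) and apple2orange (the custom `ImageDataset`) are loaded differently, and because some generators take noise while others take images, we need separate training loops. The training procedure itself is the same in each.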

In [None]:
print(f"Using '{model}' with '{dataset}', saving progress to '{image_progress_folder}'.")

### Random vector noise input

**mnist:**

In [None]:
if dataset == "mnist":
    prev_time = time.time()
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):

            if (imgs.shape[0] != batch_size):
                continue

            # Adversarial ground truths
            valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
            fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)

            # Configure input
            real_imgs = Variable(imgs.type(Tensor))

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            loss_identity = identity_loss(gen_imgs, real_imgs)

            loss_GAN = adversarial_loss(discriminator(gen_imgs), valid)

            # Total loss
            g_loss = loss_GAN + lambda_id * loss_identity

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()
            
            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), time_left))

            batches_done = epoch * len(dataloader) + i
            if batches_done % sample_interval == 0:
                save_image(gen_imgs.data[:5], image_progress_folder + "/%d.png" % batches_done, nrow=5, normalize=True)
            
        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
            torch.save(discriminator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))

**orange2apple or apple2orange with gan or dcgan:**

In [None]:
if (dataset == "orange2apple64" or dataset == "apple2orange64") and (model == "gan" or model == "dcgan"):
    prev_time = time.time()
    for epoch in range(start_epoch, n_epochs):
        for i, batch in enumerate(dataloader):
            # Set model input: only the target domain (my_class_B) is used; these
            # models learn to generate its images from noise.
            real_imgs = Variable(batch[my_class_B].type(Tensor))

            # Skip incomplete batches so every step sees batch_size samples.
            if real_imgs.shape[0] != batch_size:
                continue

            # Adversarial ground truths
            valid = Variable(Tensor(np.ones((real_imgs.size(0), 1))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_imgs.size(0), 1))), requires_grad=False)

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = Variable(Tensor(np.random.normal(0, 1, (real_imgs.shape[0], latent_dim))))

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator.
            # No identity loss: these generators take a latent vector, not an
            # image, so feeding real images to them is invalid (and lambda_id
            # is only defined for the image-input models).
            loss_G = adversarial_loss(discriminator(gen_imgs), valid)

            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left (last batch duration x batches left)
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] ETA: %s"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), loss_G.item(), time_left))

            if batches_done % sample_interval == 0:
                save_image(gen_imgs.data[:batch_size], os.path.join(image_progress_folder, "%d.png" % batches_done), nrow=5, normalize=True)

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints (each network to its own file)
            torch.save(generator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
            torch.save(discriminator.state_dict(), "saved_models/%s/%s/discriminator_%d.pth" % (model, dataset, epoch))
                    

### Image input

**orange2apple or apple2orange with fdcgan or resnet (improved fdcgan):**

In [None]:
if (dataset == "orange2apple64" or dataset == "apple2orange64") and (model == "fdcgan" or model == "resnet"):
    prev_time = time.time()
    for epoch in range(start_epoch, n_epochs):
        for i, batch in enumerate(dataloader):

            # Set model input: translate source-domain images (A) into the target domain (B)
            real_A = Variable(batch[my_class_A].type(Tensor))
            real_B = Variable(batch[my_class_B].type(Tensor))

            # Adversarial ground truths, shaped like the discriminator output
            valid = Variable(Tensor(np.ones((real_A.size(0), *discriminator.output_shape))), requires_grad=False)
            fake = Variable(Tensor(np.zeros((real_A.size(0), *discriminator.output_shape))), requires_grad=False)

            # ------------------
            #  Train Generators
            # ------------------

            # sample_images may have switched the generator to eval mode.
            generator.train()

            optimizer_G.zero_grad()

            # Identity loss: a target-domain image fed to the generator shouldn't be modified
            loss_identity = identity_loss(generator(real_B), real_B)

            # GAN loss
            fake_B = generator(real_A)
            loss_GAN = adversarial_loss(discriminator(fake_B), valid)

            # Total loss
            loss_G = loss_GAN + lambda_id * loss_identity

            loss_G.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Real loss
            loss_real = adversarial_loss(discriminator(real_B), valid)
            # Fake loss; detach so no gradient reaches the generator.
            # TODO: consider a buffer of previously generated images here.
            loss_fake = adversarial_loss(discriminator(fake_B.detach()), fake)
            # Total loss
            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()

            # --------------
            #  Log Progress
            # --------------

            # Determine approximate time left (last batch duration x batches left)
            batches_done = epoch * len(dataloader) + i
            batches_left = n_epochs * len(dataloader) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, identity: %f] ETA: %s"
                % (
                    epoch,
                    n_epochs,
                    i,
                    len(dataloader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_GAN.item(),
                    loss_identity.item(),
                    time_left,
                )
            )

            # If at sample interval save image
            if batches_done % sample_interval == 0:
                sample_images(batches_done, my_class_A)

        # Update learning rates (once per epoch)
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        if checkpoint_interval != -1 and epoch % checkpoint_interval == 0:
            # Save model checkpoints (each network to its own file)
            torch.save(generator.state_dict(), "saved_models/%s/%s/generator_%d.pth" % (model, dataset, epoch))
            torch.save(discriminator.state_dict(), "saved_models/%s/%s/discriminator_%d.pth" % (model, dataset, epoch))
