# Simple GAN implementation

Imports

In [None]:
import sys
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch


Dataset Selection

In [None]:
# Name of the dataset to train on. Later cells branch on this value and only
# recognise "mnist" and "apple2orange"; swap the active line to switch.
# dataset = "mnist"
dataset = "apple2orange"

Configure Hyperparameters

In [None]:
# Per-dataset hyperparameters. Every name assigned in a branch is a module-level
# variable that later cells read (optimizers, model construction, training loop).
if dataset == "mnist":
    n_epochs = 200  # number of epochs of training
    batch_size = 64  # size of the batches
    lr = 0.0002  # adam: learning rate
    b1 = 0.5  # adam: decay rate of the first-moment (mean) gradient estimate
    b2 = 0.999  # adam: decay rate of the second-moment (squared-gradient) estimate
    n_cpu = 8  # intended number of cpu threads for batch generation
    latent_dim = 100  # dimensionality of the latent space
    img_size = 28  # size of each image dimension (the dcgan model branch overrides this)
    sample_interval = 400  # interval (in batches) between image samples
    channels = 1  # grayscale

elif dataset == "apple2orange":  # TODO: Doesn't work it's not converging
    n_epochs = 400  # number of epochs of training
    batch_size = 20  # size of the batches
    lr = 0.0002  # adam: learning rate
    b1 = 0.5  # adam: decay rate of the first-moment (mean) gradient estimate
    b2 = 0.999  # adam: decay rate of the second-moment (squared-gradient) estimate
    n_cpu = 8  # intended number of cpu threads for batch generation
    latent_dim = 300  # dimensionality of the latent space
    img_size = 28  # size of each image dimension (the dcgan model branch overrides this)
    sample_interval = 400  # interval (in batches) between image samples
    channels = 3  # rgb

else:
    # ValueError is still an Exception subclass, so broad handlers keep working;
    # including the offending value makes a typo in `dataset` obvious.
    raise ValueError(f"Unknown dataset: {dataset!r}")

# Echo the effective configuration so it is recorded in the cell output.
print({
    "n_epochs": n_epochs,
    "batch_size": batch_size,
    "lr": lr,
    "b1": b1,
    "b2": b2,
    "n_cpu": n_cpu,
    "latent_dim": latent_dim,
    "img_size": img_size,
    "channels": channels,
    "sample_interval": sample_interval,
})


Model selection

In [None]:
# Architecture to train. Each branch imports the matching Generator and
# Discriminator classes and picks the folder that receives sample images.
# model = "gan"
model = "dcgan"

if model == "gan":
    from gan import Generator, Discriminator
    image_progress_folder = "images_gan_" + dataset

elif model == "dcgan":
    from dcgan import Generator, Discriminator, weights_init_normal
    image_progress_folder = "images_dcgan_" + dataset
    # Plain assignment is enough here: this code runs at module scope, where a
    # `global` declaration has no effect.
    # NOTE(review): 32 presumably matches what the DCGAN layers in dcgan.py
    # expect (original note: "Filter is 4x4") — confirm against that module.
    img_size = 32  # Filter is 4x4

else:
    # ValueError is still an Exception subclass; the message names the bad value.
    raise ValueError(f"Unknown model: {model!r}")

In [None]:
# Create the folder for sample images; exist_ok=True makes re-running this
# cell harmless when the folder is already there.
os.makedirs(image_progress_folder, exist_ok=True)

In [None]:
# Shape of one (square) image, channels first; passed to both model constructors.
img_shape = (channels, img_size, img_size)

Setting up cuda

In [None]:
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
# `torch.cuda.is_available()` already returns a bool, so no ternary is needed.
cuda = torch.cuda.is_available()
print(cuda)
device = torch.device("cuda" if cuda else "cpu")

Initialise models

In [None]:
# Loss function: binary cross-entropy between discriminator outputs and 0/1 targets.
# NOTE(review): BCELoss expects probabilities, so the discriminator presumably
# ends in a sigmoid — confirm in gan.py / dcgan.py.
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator(img_shape, latent_dim)
discriminator = Discriminator(img_shape)

# Place everything on the device chosen in the CUDA setup cell. `.to(device)`
# moves the module in place and leaves it untouched when it is already there,
# so one code path covers both the CPU and the GPU case.
generator.to(device)
discriminator.to(device)
adversarial_loss.to(device)

if model == "dcgan":
    # Initialize weights (the helper is only imported by the dcgan branch).
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

Optimizers

In [None]:
# Generator and discriminator each get their own Adam optimizer, configured
# with the same learning rate and beta coefficients.
adam_settings = {"lr": lr, "betas": (b1, b2)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

# Float tensor constructor matching the selected device; the training loop
# uses it to build targets, noise and the real-image batch.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
# Preprocessing shared by both datasets: resize, convert to a tensor, then
# normalize with mean 0.5 / std 0.5.
# Resize gets an explicit (height, width) pair: a bare int only fixes the
# shorter edge, which yields img_size x img_size for square inputs only,
# while the models are built for a square `img_shape`.
# NOTE(review): the single-value mean/std is presumably broadcast over all
# three channels for RGB input — verify with the installed torchvision.
image_transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

if dataset == "mnist":

    os.makedirs("./datasets/mnist", exist_ok=True)
    dataloader = DataLoader(
        datasets.MNIST(
            "./datasets/mnist",
            train=True,
            download=True,
            transform=image_transform,
        ),
        batch_size=batch_size,
        shuffle=True,
    )

elif dataset == "apple2orange":  # TODO: Doesn't work it's not converging

    import subprocess

    # Run the download script with an argument list instead of a shell string;
    # no shell features are needed. As before, its exit status is not checked.
    subprocess.run(["bash", "./datasets/download_cyclegan_dataset.sh"])

    dataloader = DataLoader(
        datasets.ImageFolder(
            root="./datasets/apple2orange/TrainA",  # apples
            transform=image_transform,
        ),
        batch_size=batch_size,
        shuffle=True,
    )

Training

In [None]:
# Announce the selected model, dataset and sample-image folder before training.
print(f"Using '{model}' with '{dataset}', saving progress to '{image_progress_folder}'.")

In [None]:
# Adversarial training: for every mini-batch, take one generator step followed
# by one discriminator step, log both losses, and periodically save a grid of
# generated samples to `image_progress_folder`.
batches_per_epoch = len(dataloader)  # loop-invariant; used for logging and the global step

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Size targets and noise from the batch itself rather than `batch_size`,
        # so a batch that is not full is handled correctly.
        n_samples = imgs.shape[0]

        # Adversarial ground truths: 1 for real images, 0 for generated ones.
        # Plain tensors do not require grad, so no Variable wrapper is needed.
        valid = torch.ones(n_samples, 1, device=device)
        fake = torch.zeros(n_samples, 1, device=device)

        # Configure input: float32 copy of the real batch on the training device.
        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample standard-normal noise as generator input, directly on the
        # device (avoids building a float64 numpy array and copying it over).
        z = torch.randn(n_samples, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator:
        # generated images are scored against the "real" target.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients left by g_loss.backward().
        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # detach() keeps this backward pass from reaching the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Print log; the leading "\r" rewrites the same console line each batch.
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, batches_per_epoch, d_loss.item(), g_loss.item())
        )

        # Global step counter across epochs; it names the saved sample grids.
        batches_done = epoch * batches_per_epoch + i
        if batches_done % sample_interval == 0:
            # Save up to 25 generated images as a 5-column grid.
            save_image(
                gen_imgs.detach()[:25],
                image_progress_folder + "/%d.png" % batches_done,
                nrow=5,
                normalize=True,
            )