# Imports

In [1]:
from __future__ import print_function

import os
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Make GPU 0 the current CUDA device when a GPU is present; on CPU-only
# machines nothing needs to be selected.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# Seed every random number generator this notebook imports (Python, NumPy,
# PyTorch) so that each fresh run starts from the same random state.
# torch.manual_seed stays last so the cell's displayed value is unchanged.
seed = 999
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f56013eec70>

# Dataloader

In [None]:
class Dataloader(object):
    """Shuffled batch loader over an image folder, plus a quick preview helper.

    Args:
        path: root directory handed to ``torchvision.datasets.ImageFolder``.
        image_size: size passed to both ``Resize`` and ``CenterCrop``.
        batch_size: number of images per batch.

    Attributes:
        dataloader: the underlying ``torch.utils.data.DataLoader``
            (shuffled, 4 worker processes).
    """

    def __init__(self, path, image_size, batch_size):
        # Resize -> centre crop -> tensor -> per-channel (x - 0.5) / 0.5.
        preprocessing = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        image_folder = datasets.ImageFolder(root=path, transform=preprocessing)

        self.dataloader = torch.utils.data.DataLoader(
            image_folder,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
        )

    def show(self):
        """Plot a grid of (at most) the first 64 images of one batch."""
        first_batch = next(iter(self.dataloader))
        # Index 0 of a batch holds the images; the rest of the batch is ignored.
        preview = first_batch[0][:64]
        grid = make_grid(preview, padding=2, normalize=True)

        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Real Images")
        # The grid is channel-first; imshow wants the channel axis last.
        plt.imshow(np.transpose(grid, (1, 2, 0)))

In [None]:
# Build the loader (64-pixel images, batches of 512) and preview one batch.
dataloader = Dataloader(path="./data/celeba", image_size=64, batch_size=512)
dataloader.show()

# Build Models

## Weight initialization

In [None]:
def init_weights(model):
    """Initialise one layer in place; intended for ``nn.Module.apply``.

    Layers are recognised by class name:

    * names containing ``Conv``: weights drawn from N(0, 0.02)
    * names containing ``BatchNorm``: weights drawn from N(1, 0.02), bias set to 0

    Any other module is left untouched. Modules that match by name but have no
    ``weight`` / ``bias`` tensor (e.g. ``BatchNorm2d(affine=False)`` or a
    user-defined container called ``ConvBlock``) are skipped instead of
    raising ``AttributeError``.

    Args:
        model: the module to initialise.
    """
    layer_name = model.__class__.__name__
    # getattr with a default covers both "attribute missing" and "set to None".
    weight = getattr(model, 'weight', None)
    bias = getattr(model, 'bias', None)

    # The nn.init functions already run without gradient tracking, so the
    # parameters are passed directly rather than through the legacy `.data`.
    if 'Conv' in layer_name:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

## Generator

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator: five transposed convolutions ending in ``Tanh``.

    The first layer (kernel 4, stride 1, no padding) turns a 1x1 latent map
    into a 4x4 map; each of the four following layers (kernel 4, stride 2,
    padding 1) doubles the spatial size, giving 64x64 for a 1x1 input.

    Args:
        n_input: number of channels of the latent input.
        n_kernels: base channel width; hidden layers use 8x, 4x, 2x and 1x of it.
        im_size: image shape tuple; only ``im_size[0]`` (the number of output
            channels) is read.
    """

    def __init__(self, n_input, n_kernels, im_size=(3, 64, 64)):
        super(Generator, self).__init__()

        hidden = [n_kernels * 8, n_kernels * 4, n_kernels * 2, n_kernels]

        # Projection of the latent vector onto the widest feature map.
        layers = [
            nn.ConvTranspose2d(n_input, hidden[0], kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden[0]),
            nn.ReLU(True),
        ]
        # Upsampling stages: halve the channels, double the resolution.
        for c_in, c_out in zip(hidden[:-1], hidden[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Final upsampling to image channels, squashed by Tanh.
        layers.extend([
            nn.ConvTranspose2d(hidden[-1], im_size[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ])

        self.main = nn.Sequential(*layers)
        self.apply(init_weights)

    def forward(self, X):
        """Run ``X`` through the layer stack and return the result."""
        return self.main(X)

    def print(self):
        """Print the module's structure."""
        print(self)


## Discriminator

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator: five convolutions ending in ``Sigmoid``.

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size; the last convolution (kernel 4, stride 1, no padding) reduces a 4x4
    map to a single value per image, so a 64x64 input yields a 1x1 output.

    Args:
        n_kernels: base channel width; layers use 1x, 2x, 4x and 8x of it.
        im_size: image shape tuple; only ``im_size[0]`` (the number of input
            channels) is read.
    """

    def __init__(self, n_kernels, im_size=(3, 64, 64)):
        super(Discriminator, self).__init__()

        widths = [n_kernels, n_kernels * 2, n_kernels * 4, n_kernels * 8]

        # First stage has no batch normalisation.
        layers = [
            nn.Conv2d(im_size[0], widths[0], kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsampling stages: double the channels, halve the resolution.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Collapse to one score per image and map it into (0, 1).
        layers.extend([
            nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ])

        self.main = nn.Sequential(*layers)
        self.apply(init_weights)

    def forward(self, X):
        """Run ``X`` through the layer stack and return the result."""
        return self.main(X)

    def print(self):
        """Print the module's structure."""
        print(self)

# Trainer

In [None]:
class Trainer(object):
    """Trains a generator/discriminator pair with alternating GAN updates.

    For every batch the discriminator is updated first (real batch labelled 1,
    generated batch labelled 0), then the generator is updated so that the
    freshly updated discriminator scores its images as real. Both networks are
    optimised with Adam (betas ``(0.5, 0.999)``) against ``nn.BCELoss``, so
    ``modelD`` is expected to output probabilities.

    Args:
        dataloader: sized iterable of batches; the image tensor is read from
            index 0 of each batch.
        in_size: number of channels of the ``(N, in_size, 1, 1)`` noise fed to
            ``modelG``.
        modelG: generator network.
        modelD: discriminator network.
        lr: learning rate used for both Adam optimisers.
        im_size: accepted for interface compatibility; not used by the trainer.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``; models and batches
            are moved there.
        fixed_noise_vector: 64 fixed noise vectors used for progress snapshots.
        img_list: image grids generated from ``fixed_noise_vector`` during training.
        G_losses / D_losses: per-iteration loss values (Python floats).
    """

    def __init__(self, dataloader, in_size, modelG, modelD, lr=0.0002, im_size=(3, 64, 64)):
        self.dataloader = dataloader
        self.modelG = modelG
        self.modelD = modelD
        self.criterion = nn.BCELoss()
        self.in_size = in_size

        # One device decision for the whole trainer, so the CPU and GPU paths
        # run exactly the same code.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Fixed noise for the evaluation images produced during training.
        # Sampled on the CPU and then moved, so the seeded CPU generator is used.
        self.fixed_noise_vector = torch.randn(64, in_size, 1, 1).to(self.device)

        # Label convention from the original GAN paper
        self.real_label = 1
        self.fake_label = 0

        # Optimizers
        self.optimizerG = optim.Adam(modelG.parameters(), lr=lr, betas=(0.5, 0.999))
        self.optimizerD = optim.Adam(modelD.parameters(), lr=lr, betas=(0.5, 0.999))

        # Training history; created here so print_losses() works before train().
        self.img_list = []
        self.G_losses = []
        self.D_losses = []

    def _labels(self, batch_size, value):
        """Return a float label vector; BCELoss requires floating-point targets."""
        return torch.full((batch_size,), value, dtype=torch.float, device=self.device)

    def _discriminator_step(self, real_images):
        """Update the discriminator on one real batch and one generated batch.

        Objective: log(D(x)) + log(1 - D(G(z))).

        Returns:
            ``(lossD, D_x, D_G_z1, fakes)``: the summed loss value, the mean
            discriminator output on the real and generated batches, and the
            generated batch (still attached to the generator's graph).
        """
        batch_size = real_images.size(0)
        self.modelD.zero_grad()

        # Real batch
        y = self.modelD(real_images).view(-1)
        lossD_real = self.criterion(y, self._labels(batch_size, self.real_label))
        lossD_real.backward()
        D_x = y.mean().item()

        # Generated batch
        noise_vectors = torch.randn(batch_size, self.in_size, 1, 1).to(self.device)
        fakes = self.modelG(noise_vectors)
        # detach(): this step only needs discriminator gradients, so do not
        # backpropagate into the generator (and no retain_graph is required).
        y = self.modelD(fakes.detach()).view(-1)
        lossD_fake = self.criterion(y, self._labels(batch_size, self.fake_label))
        lossD_fake.backward()
        D_G_z1 = y.mean().item()

        self.optimizerD.step()
        lossD = (lossD_real + lossD_fake).item()
        return lossD, D_x, D_G_z1, fakes

    def _generator_step(self, fakes):
        """Update the generator so the discriminator scores ``fakes`` as real.

        Objective: log(D(G(z))).

        Returns:
            ``(lossG, D_G_z2)``: the loss value and the mean output of the
            already-updated discriminator on ``fakes``.
        """
        self.modelG.zero_grad()
        # The generator loss treats the generated images as real.
        labels = self._labels(fakes.size(0), self.real_label)
        # Re-evaluate with the discriminator as updated in this iteration.
        y = self.modelD(fakes).view(-1)
        lossG = self.criterion(y, labels)
        lossG.backward()
        self.optimizerG.step()
        return lossG.item(), y.mean().item()

    def _record_samples(self):
        """Generate images from the fixed noise, store the grid and display it."""
        with torch.no_grad():
            samples = self.modelG(self.fixed_noise_vector).detach().cpu()
        grid = make_grid(samples, padding=2, normalize=True)
        self.img_list.append(grid)
        plt.figure(figsize=(8,8))
        plt.axis("off")
        # The grid is channel-first; imshow wants the channel axis last.
        plt.imshow(grid.permute(1, 2, 0).numpy())
        # Render now instead of leaving one more open figure per snapshot.
        plt.show()

    def train(self, epochs):
        """Run ``epochs`` passes over the dataloader.

        Resets ``img_list``, ``G_losses`` and ``D_losses``, logs every 100th
        batch, and records a sample grid every 1000 iterations as well as on
        the very last batch.

        Args:
            epochs: number of passes over ``self.dataloader``.
        """
        self.modelG.to(self.device)
        self.modelD.to(self.device)

        # Structures to store the training progress
        self.img_list = []
        self.G_losses = []
        self.D_losses = []

        count = 0
        n_batches = len(self.dataloader)

        for epoch in tqdm(range(epochs), desc='Epochs'):
            for i, real_data in enumerate(tqdm(self.dataloader, desc='Batches'), 0):
                # Index 0 holds the images; move them whichever device is in use.
                real_images = real_data[0].to(self.device)

                lossD, D_x, D_G_z1, fakes = self._discriminator_step(real_images)
                lossG, D_G_z2 = self._generator_step(fakes)

                if i % 100 == 0:
                    tqdm.write("[%d/%d][%d/%d] LossD: %.4f LossG: %.4f D(x): %.4f D(G(z)): %.4f -> %.4f" % (epoch, epochs, i, n_batches, lossD, lossG, D_x, D_G_z1, D_G_z2))

                self.G_losses.append(lossG)
                self.D_losses.append(lossD)

                if count % 1000 == 0 or (epoch == epochs - 1 and i == n_batches - 1):
                    self._record_samples()

                count += 1

    def print_losses(self):
        """Plot the recorded generator and discriminator losses per iteration."""
        plt.figure(figsize=(10,5))
        plt.plot(self.G_losses, label="Generator")
        plt.plot(self.D_losses, label="Discriminator")
        plt.xlabel("iterations")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

# Training

In [None]:
# 100-channel latent input and 64 base feature maps for both networks.
gen = Generator(n_input=100, n_kernels=64)
dis = Discriminator(n_kernels=64)
# The trainer iterates the underlying torch DataLoader, not the wrapper object.
trainer = Trainer(dataloader=dataloader.dataloader, in_size=100, modelG=gen, modelD=dis)

In [None]:
# Run the full training: 100 passes over the dataset.
trainer.train(epochs=100)