# Dog image generation using GANs

### Setup and imports

In [None]:
import os
from glob import glob

In [None]:
import numpy as np

import lightning as L
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import matplotlib as mpl
# Default size (in inches) for figures created in this notebook; individual plots may override it.
mpl.rcParams.update({'figure.figsize': (10, 10)})

from tqdm.notebook import tqdm

import torch.optim as optim

### Load and define dataset

In [None]:
# Path to the dog images (Stanford Dogs dataset, available at
# http://vision.stanford.edu/aditya86/ImageNetDogs/).
# torchvision's ImageFolder expects the images to sit inside at least one
# sub-directory of this path (root/<class_name>/<image>).
# Set the DATASET_PATH environment variable to point elsewhere without editing
# the notebook; the default below is the original author's local path.
DATASET_PATH = os.environ.get(
    'DATASET_PATH',
    '/media/tguy/Records/timothee/workspace/research/datasets/generative-dog-images/all_dogs_imgs/',
)

In [None]:
# Raw dataset: images are only converted to [0, 1] float tensors (no resizing or
# normalization); used below to preview a few samples.
original_ds = ImageFolder(root=DATASET_PATH, transform=transforms.ToTensor())
original_ds

In [None]:
# Preview the first 16 raw dataset images in a 4x4 grid (row-major order).
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(6, 6))

for idx, ax in enumerate(axes.flat):
    img_tensor, _label = original_ds[idx]
    # ToTensor yields (C, H, W) while imshow expects (H, W, C).
    ax.imshow(img_tensor.permute(1, 2, 0))
    ax.axis('off')

In [None]:
batch_size = 32
img_target_size = 64

# Training-time preprocessing:
#   - scale the shorter side to img_target_size, then center-crop to a square;
#   - randomly mirror left/right;
#   - convert to a [0, 1] tensor and shift to [-1, 1], the same range as the
#     generator's Tanh output (the single mean/std value is applied to every channel).
preprocess = transforms.Compose(
    [
        transforms.Resize(size=img_target_size),
        transforms.CenterCrop(size=img_target_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
train_ds = ImageFolder(root=DATASET_PATH, transform=preprocess)

# drop_last keeps every training batch at exactly batch_size samples.
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)


### Simple GAN model with PyTorch Lightning

In [None]:
def compute_deconvolution_out_size(
    in_size: int,
    kernel_size: int,
    stride: int = 1,
    padding: int = 0,
    output_padding: int = 0,
    dilation: int = 1,
) -> int:
    """Return the output length of a transposed convolution along one spatial axis.

    Uses the shape formula documented for ``torch.nn.ConvTranspose2d``::

        (in_size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1

    Example: kernel 4, stride 2, padding 1 doubles the size (32 -> 64).
    """
    upsampled = (in_size - 1) * stride
    effective_kernel = dilation * (kernel_size - 1)
    size = upsampled - 2 * padding + effective_kernel + output_padding + 1
    return int(size)

In [None]:
# Sanity check: a transposed conv with kernel 4, stride 2, padding 1 maps 32 -> 64.
compute_deconvolution_out_size(in_size=32, kernel_size=4, stride=2, padding=1)

In [None]:
# Define generator
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to 64x64 RGB images.

    Input:  noise of shape (N, latent_dim, 1, 1).
    Output: images of shape (N, 3, 64, 64) with values in [-1, 1] (Tanh).

    The first transposed conv (kernel 4, stride 1, padding 0) maps 1 -> 4; each
    following one (kernel 4, stride 2, padding 1) doubles the spatial size:
    4 -> 8 -> 16 -> 32 -> 64.

    Args:
        latent_dim: number of channels of the input noise. Defaults to 100,
            the value that used to be hard-coded here, so ``Generator()`` is
            unchanged.
    """

    def __init__(self, latent_dim: int = 100) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return generated images for noise ``x`` of shape (N, latent_dim, 1, 1)."""
        generated_img = self.model(x)
        return generated_img


# define discriminator
class Discriminator(nn.Module):
    """DCGAN-style discriminator producing one real/fake logit per image.

    Expects images of shape (N, 3, 64, 64). Four stride-2 conv blocks
    (Conv -> BatchNorm -> LeakyReLU(0.2)) shrink 64 -> 32 -> 16 -> 8 -> 4, and a
    final 4x4 convolution collapses that to a single raw logit (no sigmoid).
    """

    def __init__(self) -> None:
        super().__init__()
        layers = []
        in_channels = 3
        for out_channels in (64, 128, 256, 512):
            layers += [
                nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_channels = out_channels
        layers.append(nn.Conv2d(in_channels, 1, 4, 1, 0, bias=False))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw logits flattened to shape (N,)."""
        return self.model(x).view(-1)

In [None]:
# Define GAN class with LightningModule
from typing import Any
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler


class GAN_for_dogs(L.LightningModule):
    """GAN for dog images built from ``Generator`` and ``Discriminator``.

    Training uses Lightning manual optimization: every training step performs
    one generator update followed by one discriminator update, both scored with
    ``BCEWithLogitsLoss`` on the discriminator's raw logits.

    Args:
        latent_dim: number of channels of the (N, latent_dim, 1, 1) noise fed to
            the generator. Must match the generator's first layer.
        lr_g: AdamW learning rate for the generator.
        lr_d: AdamW learning rate for the discriminator.
        b1: AdamW beta1, shared by both optimizers.
        b2: AdamW beta2, shared by both optimizers.
        target_img_size: stored (and saved as a hyperparameter) but not
            otherwise used by this class.

    Raises:
        ValueError: if ``latent_dim`` differs from the number of input channels
            of the generator's first layer.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 lr_g: float = 0.001,
                 lr_d: float = 0.0005,
                 b1: float = 0.5,
                 b2: float = 0.999,
                 target_img_size: tuple[int, int] = (64, 64)) -> None:
        super().__init__()
        self.save_hyperparameters()

        self.latent_dim = latent_dim
        self.lr_g = lr_g
        self.lr_d = lr_d
        self.b1 = b1
        self.b2 = b2
        self.target_img_size = target_img_size

        # manual optimization: both optimizers are stepped by hand in training_step
        self.automatic_optimization = False

        self.generator = Generator()
        self.discriminator = Discriminator()
        # Losses are computed on raw discriminator logits (no sigmoid in the model).
        self.criterion = nn.BCEWithLogitsLoss()

        # The generator's first layer fixes how many noise channels it accepts.
        # Fail here with a clear message instead of a shape error at the first
        # forward pass (e.g. when Lightning runs example_input_array).
        expected_latent_dim = self.generator.model[0].in_channels
        if latent_dim != expected_latent_dim:
            raise ValueError(
                f'latent_dim={latent_dim} does not match the generator input '
                f'channels ({expected_latent_dim})'
            )

        # Fixed noise vectors; not used by the code in this class.
        self.validation_z = torch.randn(8, self.latent_dim, 1, 1)
        # Example input Lightning can use to report layer shapes in the model summary.
        self.example_input_array = torch.zeros(2, self.latent_dim, 1, 1)

    def forward(self, z):
        """Generate images from noise ``z`` of shape (N, latent_dim, 1, 1)."""
        return self.generator(z)

    def _log_generated_images(self, generated: torch.Tensor) -> None:
        """Log a batch of generated images as one grid image, when the logger supports it."""
        # Trainer(logger=False) leaves self.logger as None, and some loggers'
        # experiment objects have no add_image method; skip logging in both cases.
        if self.logger is None:
            return
        experiment = self.logger.experiment
        if not hasattr(experiment, 'add_image'):
            return
        # Generator output is in [-1, 1] (Tanh); map it back to [0, 1] for display.
        grid = make_grid(generated.detach() * 0.5 + 0.5)
        experiment.add_image('generated', grid, self.global_step)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Run one generator update followed by one discriminator update.

        ``batch`` is an ``(images, labels)`` pair; the labels are ignored.
        Real images are assumed to be in the generator's [-1, 1] output range —
        NOTE(review): confirm against the DataLoader passed to ``fit``.
        """
        imgs, _ = batch
        batch_size = imgs.size(0)

        opt_g, opt_d = self.optimizers()

        # Allocate directly on the module's device instead of creating on CPU and copying.
        real_targets = torch.ones(batch_size, dtype=imgs.dtype, device=self.device)
        fake_targets = torch.zeros(batch_size, dtype=imgs.dtype, device=self.device)

        # noise sampling
        z = torch.randn(
            batch_size, self.latent_dim, 1, 1, dtype=imgs.dtype, device=self.device
        )

        # --- generator update: push D to classify generated images as real ---
        generated = self.generator(z)

        if batch_idx % 500 == 0:
            self._log_generated_images(generated)

        loss_g = self.criterion(self.discriminator(generated), real_targets)
        opt_g.zero_grad()
        self.manual_backward(loss_g)
        # loss_g also leaves gradients on the discriminator's parameters; they are
        # cleared by opt_d.zero_grad() below before the discriminator's own backward.
        opt_g.step()

        # --- discriminator update: real images vs. detached generated images ---
        loss_d = (
            self.criterion(self.discriminator(imgs), real_targets)
            + self.criterion(self.discriminator(generated.detach()), fake_targets)
        ) / 2
        opt_d.zero_grad()
        self.manual_backward(loss_d)
        opt_d.step()

        self.log_dict(
            {
                'loss_g': loss_g,
                'loss_d': loss_d,
            },
            prog_bar=True,
        )

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Create one AdamW optimizer per network, in (generator, discriminator) order.

        training_step unpacks ``self.optimizers()`` in this same order.
        """
        # when multiple optimizers are used, optimization is done manually in lightning
        opt_g = optim.AdamW(
            self.generator.parameters(),
            lr=self.lr_g,
            betas=(self.b1, self.b2)
        )
        opt_d = optim.AdamW(
            self.discriminator.parameters(),
            lr=self.lr_d,
            betas=(self.b1, self.b2)
        )

        # TODO: test with ReduceLROnPlateau
        return [opt_g, opt_d], []

### Training

In [None]:
# Trainer(deterministic=True) only forces deterministic algorithms. Without a
# seed, weight initialization, noise sampling and data shuffling still differ
# between runs, so seed every RNG first. workers=True also derives
# per-DataLoader-worker seeds. Run this cell before creating the model.
SEED = 42
L.seed_everything(SEED, workers=True)

trainer = L.Trainer(
    accelerator='gpu',
    max_epochs=500,
    precision="16-mixed",
    deterministic=True,
)

In [None]:
# Build the GAN with its default hyperparameters and train it on the preprocessed dog images.
model = GAN_for_dogs()
trainer.fit(model, train_dataloaders=train_dl)