# CSC52087EP lab1a by Vicky Kalogeiton
# Ecole Polytechnique
# Basic GAN Notebook

In [None]:
# import the libraries
import torch, pdb
from torch.utils.data import DataLoader
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from typing import Callable
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

## Setup

In [None]:
# visualization function
def show(tensor, ch=1, size=(28,28), num=16):
  """Plot the first `num` images of a batch as a grid with 4 images per row.

  tensor: batch of flattened images, e.g. 128 x 784 (batch size = 128, 28*28 = 784)
  ch:     number of channels every image is reshaped to
  size:   spatial size every image is reshaped to
  num:    how many images, counted from the start of the batch, are drawn
  """
  # cut the autograd graph, move to host memory and un-flatten:
  # 128 x 784 --> 128 x 1 x 28 x 28
  images = tensor.detach().cpu().view(-1, ch, *size)
  tiled = make_grid(images[:num], nrow=4)
  # pytorch keeps channels first (C, H, W); matplotlib expects them last (H, W, C)
  plt.imshow(tiled.permute(1, 2, 0))
  plt.show()


In [None]:
# setup of the main parameters and hyperparameters
epochs = 200
cur_step = 0
# every how many steps we want to show information on the screen
info_step = 300
# losses accumulated over the current reporting window (each term is divided by info_step)
mean_gen_loss = 0
mean_disc_loss = 0

# dimensionality of the noise vector that is the input of the generator
z_dim = 64
# learning rate (other values tried: 0.0001, 0.00001)
lr = 0.0002

# Binary Cross Entropy with Logits: applies the sigmoid itself, so it is fed raw (unbounded) scores
loss_func = nn.BCEWithLogitsLoss()

# batch size
bs = 128
# use the GPU when one is available, otherwise fall back to the CPU so the notebook still runs
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# 1. where to store the data (.), 2. download the data = True, 3. transform the images into tensors
# 4. shuffle: at every epoch we shuffle the data, 5. batch size
dataloader = DataLoader(MNIST('.', download=True, transform=transforms.ToTensor()),shuffle=True, batch_size=bs)

# number of steps in every epoch:
# 60000 / 128 = 468.75, i.e. 469 steps (the last batch is smaller than 128)

## Declare the models
#### Fill in the missing blanks

In [None]:
# Generator
def generatorBlock(input, output):
  """Hidden block of the generator: Linear -> BatchNorm1d -> ReLU.

  input:  number of input features
  output: number of output features
  """
  return nn.Sequential(
    nn.Linear(input, output),   # Linear Layer (use torch.nn)
    nn.BatchNorm1d(output),   # Batch Norm 1d
    nn.ReLU(),   # ReLU
  )

class Generator(nn.Module):
  """Fully connected generator that maps a noise vector to a flattened image.

  z_dim:     dimensionality of the latent space (size of the input noise vector)
  image_dim: number of values in one generated image (784 = 28x28)
  h_dim:     width of the first hidden layer; the next ones are 2x, 4x and 8x wider
  """
  def __init__(self, z_dim=64, image_dim=784, h_dim=128): # z_dim: latent space dimensionality
    super().__init__()
    self.generator = nn.Sequential(
        generatorBlock(z_dim, h_dim), # 64 --> 128
        generatorBlock(h_dim, h_dim * 2), # 128 --> 256
        generatorBlock(h_dim * 2, h_dim * 4), # 256 --> 512
        generatorBlock(h_dim * 4, h_dim * 8), # 512 --> 1024
        # The output layer is a plain Linear layer: a ReLU in front of the Sigmoid
        # would clamp its input to >= 0 and therefore every pixel to >= 0.5.
        nn.Linear(h_dim * 8, image_dim), # 1024 --> 784 (28x28)
        nn.Sigmoid(), # to make the values between 0 and 1
    )

  def forward(self, noise):
    """Return the generated batch (number of samples x image_dim) for a batch of noise vectors."""
    return self.generator(noise)

# function that generates noise
def gen_noise(number, z_dim):
  """Draw `number` noise vectors of size `z_dim` from a standard normal and move them to `device`."""
  noise = torch.randn(number, z_dim)
  return noise.to(device)

In [None]:
## Discriminator
def discriminatorBlock(input, output):
  """Hidden block of the discriminator: Linear -> LeakyReLU.

  input:  number of input features
  output: number of output features
  """
  linear = nn.Linear(input, output)
  # negative slope of 0.2 (PyTorch's default would be 0.01)
  activation = nn.LeakyReLU(0.2)
  return nn.Sequential(linear, activation)

class Discriminator(nn.Module):
  """Fully connected discriminator that maps a flattened image to one raw score.

  image_dim: number of values in one input image (784 = 28x28)
  h_dim:     width of the last hidden layer; the earlier ones are 4x and 2x wider
  """
  def __init__(self, image_dim=784, h_dim=256):
    super().__init__()
    # layer widths of the hidden stack: 784 --> 1024 --> 512 --> 256
    widths = [image_dim, h_dim * 4, h_dim * 2, h_dim]
    hidden = [discriminatorBlock(fan_in, fan_out)
              for fan_in, fan_out in zip(widths, widths[1:])]
    # final layer: 256 --> 1, no activation (the output is a raw score)
    self.discriminator = nn.Sequential(*hidden, nn.Linear(h_dim, 1))

  def forward(self, image):
    """Return one score per image of the batch."""
    return self.discriminator(image)

## Main code

In [None]:
import torch.optim as optim

gen = Generator(z_dim).to(device)
# optimizer of the generator; lr is passed explicitly, otherwise Adam uses its own default (1e-3)
gen_opt = optim.Adam(gen.parameters(), lr=lr) # Adam optimizer
disc = Discriminator().to(device)
# optimizer of the discriminator, with the same learning rate
disc_opt = optim.Adam(disc.parameters(), lr=lr)  # Adam optimizer

In [None]:
# check your generator: a bare expression at the end of a notebook cell displays the module's repr
gen

In [None]:
# check your discriminator: a bare expression at the end of a notebook cell displays the module's repr
disc

In [None]:
# fetch a single batch from the dataloader to inspect it (x: images, y: labels of the MNIST dataset)
x,y=next(iter(dataloader))
print(x.shape, y.shape)
# first 10 entries of y
print(y[:10])

In [None]:
# sanity check: pass one batch of noise through the generator and plot the result
noise = gen_noise(bs, z_dim)
fake = gen(noise)
show(fake)

# Here we see the initial output of passing the noise through the generator
# Since the generator has not started learning yet, it produces a very noisy output

## Compute the loss

In [None]:
# generator loss
def calc_gen_loss(loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                  gen: nn.Module,
                  disc: nn.Module,
                  number: int,
                  z_dim: int) -> torch.Tensor:
    """
    Compute the generator loss for one training step.

    A fresh batch of noise is pushed through the generator, the result is scored
    by the discriminator, and the scores are compared against the label "real".

    Args:
        loss_func: Callable taking (predictions, targets) and returning the loss
            (e.g. BCEWithLogitsLoss).
        gen: The generator model.
        disc: The discriminator model.
        number: Number of samples to generate, i.e. the batch size.
        z_dim: Dimensionality of the noise vector (latent space).

    Returns:
        torch.Tensor: The generator loss.
    """
    # no detach here: the gradient has to flow through the discriminator back into the generator
    scores = disc(gen(gen_noise(number, z_dim)))
    # the generator wants its samples to be classified as real (1: real, 0: fake)
    return loss_func(scores, torch.ones_like(scores))


def calc_disc_loss(loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                  gen: nn.Module,
                  disc: nn.Module,
                  number: int,
                  real: torch.Tensor,
                  z_dim: int) -> torch.Tensor:
    """
    Compute the discriminator loss for one training step.

    The loss is the average of two terms: generated samples scored against the
    label "fake" (0) and real samples scored against the label "real" (1).

    Args:
        loss_func: Callable taking (predictions, targets) and returning the loss
            (e.g. BCEWithLogitsLoss).
        gen: The generator model.
        disc: The discriminator model.
        number: Number of fake samples to generate, i.e. the batch size.
        real: A batch of real images.
        z_dim: Dimensionality of the noise vector (latent space).

    Returns:
        torch.Tensor: The discriminator loss.
    """
    # fake half: detach the samples so that this loss does not reach the generator
    generated = gen(gen_noise(number, z_dim)).detach()
    fake_scores = disc(generated)
    fake_term = loss_func(fake_scores, torch.zeros_like(fake_scores))

    # real half
    real_scores = disc(real)
    real_term = loss_func(real_scores, torch.ones_like(real_scores))

    return (fake_term + real_term) / 2





**GANs are known for their training instability and difficulty in achieving convergence. Discuss the potential causes of these issues.**

### Causes of Training Instability and Difficulty in GANs:
- ```Non-convergence of Loss Functions:``` GANs aim for a Nash equilibrium where the generator and discriminator are perfectly balanced. This is often hard to achieve as the optimization of two competing objectives may not converge smoothly.

- ```Mode Collapse:``` The generator may produce limited variations of data (or even a single mode), leading to a lack of diversity in generated outputs.

- ```Vanishing Gradients:``` If the discriminator becomes too strong, the generator's gradients diminish, leading to slower learning.

- ```Imbalance in Generator & Discriminator Training:``` If one model (generator or discriminator) significantly outpaces the other during training, it can cause instability. For example, a highly capable discriminator might overpower the generator and lead to no meaningful updates.

- ```Lack of Proper Regularization:``` Overfitting in the discriminator or improper handling of the generator's outputs can lead to instability.

- ```Sensitivity to Hyperparameters:``` GANs are highly sensitive to learning rates, batch sizes, and other hyperparameters. Slight deviations may lead to poor convergence or collapse.

- ```Noisy Gradient Updates:``` The stochasticity inherent in gradient-based optimization can lead to oscillations or divergence in training.

## Training loop

In [None]:
### batch size = 128
### 60000 / 128 = 468.75  = 469 steps in each epoch
### Each step is going to process 128 images = size of the batch (except the last step)

# Alternate one discriminator update and one generator update per batch and,
# every info_step steps, plot generated vs. real samples and print the mean losses.
for epoch in range(epochs):
  for real, _ in tqdm(dataloader):
    current_batch_size = len(real) # real: 128 x 1 x 28 x 28
    real = real.view(current_batch_size, -1) # 128 x 784
    real = real.to(device)

    ### discriminator
    disc_opt.zero_grad() # set the gradients to zero
    disc_loss = calc_disc_loss(loss_func, gen, disc, current_batch_size, real, z_dim)
    disc_loss.backward() # Backpropagation
    # Optimizer step: step must be *called*; a bare `disc_opt.step` only names
    # the method and leaves the discriminator untrained
    disc_opt.step()

    ### generator
    gen_opt.zero_grad() # set the gradients to zero
    gen_loss = calc_gen_loss(loss_func, gen, disc, current_batch_size, z_dim)
    gen_loss.backward() # Backpropagation
    gen_opt.step() # Optimizer step

    ### statistics + visualization

    # adding the values into the losses
    mean_disc_loss += disc_loss.item() / info_step # .item() transforms the tensor value into a standalone value
    mean_gen_loss += gen_loss.item() / info_step
    # count the step before the check so that every window holds exactly info_step losses
    cur_step += 1

    if cur_step % info_step == 0:
      fake_noise = gen_noise(current_batch_size, z_dim)
      fake = gen(fake_noise)
      show(fake)
      show(real)
      print(f"{epoch}: step {cur_step} / Gen loss: {mean_gen_loss} / disc_loss: {mean_disc_loss}")
      mean_gen_loss, mean_disc_loss = 0, 0


Observed Score:

```lr: 0.0001:``` Gen_loss = 0.587, disc_loss = 0.705

```lr: 0.00002:``` Gen_loss = 0.595, disc_loss = 0.750

**In the quantitative assessment of GANs, especially for complex image datasets, which metrics are suitable for evaluating the quality and diversity of the generated images?**

### Suitable Metrics for Evaluating GANs on Complex Datasets:

```Fréchet Inception Distance (FID):```

- Measures the Fréchet distance between Gaussian distributions fitted to the features of real and generated images, where the features are extracted with a pre-trained Inception-v3 network.
- Lower FID scores indicate higher similarity between generated and real images.

```Inception Score (IS): ```
- Evaluates the quality and diversity of generated images.
- A higher IS indicates that images are meaningful and belong to a variety of categories.

```Precision and Recall for Distributions:```

- Measures the quality (precision) and diversity (recall) of generated samples relative to real samples.

```Kernel Inception Distance (KID):```

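- Similar to FID, but computes the squared Maximum Mean Discrepancy (MMD) between Inception features with a polynomial kernel; its estimator is unbiased, which makes the comparison more reliable when only few samples are available.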


```Diversity Score:```

- Evaluates the variance across the generated images, highlighting the generator's ability to produce diverse samples.