# 1. Importing Libraries

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import imageio

import torch
import torch.nn as nn

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard

# 2. Coding the model

1. Generator And Discriminator:


In [None]:
# Single-channel 28x28 images (the MNIST format used by this notebook).
channels, img_size = 1, 28
# Channel-first shape tuple: (channels, img_size, img_size).
img_shape = (channels, img_size, img_size)

In [None]:
# Shape of one image as (channels, img_size, img_size); this literal equals
# the tuple obtained with channels = 1 and img_size = 28.
img_shape = (1, 28, 28)

In [None]:
class Generator(nn.Module):
  """Fully connected generator that maps latent vectors to images.

  Layer stack: Linear -> LeakyReLU -> Linear -> BatchNorm1d -> LeakyReLU
  -> Linear -> Tanh. The flat output is reshaped to ``img_shape``.

  Args:
    latent_dim: Size of each input latent vector (often called ``z_dim``).
    img_shape: Shape of a single generated image, e.g. ``(1, 28, 28)``.
  """

  def __init__(self, latent_dim, img_shape):
    super().__init__()
    # Keep the shape on the instance so forward() uses the constructor
    # argument rather than whatever module-level ``img_shape`` happens to be.
    self.img_shape = tuple(img_shape)

    def layer_block(input_size, output_size, normalize=True):
      """Return [Linear, optional BatchNorm1d, LeakyReLU] as a list."""
      layers = [nn.Linear(input_size, output_size)]
      if normalize:
        # NOTE(review): the second positional argument of BatchNorm1d is
        # ``eps`` (not ``momentum``). The original passed 0.8 positionally;
        # it is kept as eps=0.8 to preserve behaviour - confirm whether
        # momentum=0.8 was intended.
        layers.append(nn.BatchNorm1d(output_size, eps=0.8))
      layers.append(nn.LeakyReLU(0.2, inplace=True))
      return layers

    self.model = nn.Sequential(
        *layer_block(latent_dim, 128, normalize=False),
        *layer_block(128, 256),
        # One output unit per pixel of the target image.
        nn.Linear(256, int(np.prod(self.img_shape))),
        nn.Tanh(),  # squashes outputs into [-1, 1]
    )

  def forward(self, z):
    """Generate a batch of images from latent vectors.

    Args:
      z: Tensor of shape ``(batch, latent_dim)``.

    Returns:
      Tensor of shape ``(batch, *img_shape)`` with values in [-1, 1].
    """
    img = self.model(z)
    return img.view(img.size(0), *self.img_shape)


class Discriminator(nn.Module):
  """Fully connected discriminator that scores images as real or fake.

  Layer stack: Linear(512) -> LeakyReLU -> Linear(256) -> LeakyReLU
  -> Linear(1) -> Sigmoid.

  Args:
    img_shape: Shape of a single input image. Defaults to ``(1, 28, 28)``,
      the value this notebook assigns to its module-level ``img_shape``, so
      ``Discriminator()`` keeps working without arguments.
  """

  def __init__(self, img_shape=(1, 28, 28)):
    super().__init__()
    self.img_shape = tuple(img_shape)
    # np.prod multiplies the dimensions together, giving the number of
    # values in one flattened image (1 * 28 * 28 = 784 by default).
    self.flattened_size = int(np.prod(self.img_shape))
    self.model = nn.Sequential(
        nn.Linear(self.flattened_size, 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, 1),
        nn.Sigmoid(),  # output in [0, 1], as required by BCELoss
    )

  def forward(self, img):
    """Score a batch of images.

    Args:
      img: Tensor whose first dimension is the batch; the remaining
        dimensions are flattened and must total ``flattened_size`` values.

    Returns:
      Tensor of shape ``(batch, 1)`` with values in [0, 1].
    """
    img_flat = img.view(img.size(0), -1)
    return self.model(img_flat)


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# Build both networks and the adversarial (binary cross-entropy) criterion,
# then move each of them to the selected device.
generator = Generator(latent_dim=128, img_shape=(1, 28, 28))
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)
adv_loss = torch.nn.BCELoss()
adv_loss = adv_loss.to(device)


# 3. Getting datasets

In [None]:
batch_size = 32
os.makedirs('data/mnist', exist_ok=True)

# Convert images to tensors, then normalise with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into data/mnist when not already present.
mnist_train = datasets.MNIST(
    'data/mnist',
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled mini-batches of `batch_size` images.
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# 4. Training GAN:

In [None]:
# Training hyperparameters.
lr, latent_dim = 3e-4, 128   # latent_dim options: 64, 128, 256
batchSize = 32               # Batch size
epochs, logStep = 200, 625   # Change both as per your need

In [None]:
# Create the directory for saved sample images; exist_ok avoids an error when
# the directory is already there (e.g. when the cell is re-run).
os.makedirs('output_dir/images', exist_ok = True)

In [None]:
# One Adam optimiser per network, both using the same learning rate `lr`.
optim_G, optim_D = (
    torch.optim.Adam(network.parameters(), lr=lr)
    for network in (generator, discriminator)
)

In [None]:
# A single batch of latent vectors drawn once and kept at module level.
noise_shape = (batch_size, latent_dim)
fixedNoise = torch.randn(noise_shape)
fixedNoise = fixedNoise.to(device)

# Separate TensorBoard writers for generated and real image grids.
writerFake, writerReal = SummaryWriter(f"logs/fake"), SummaryWriter(f"logs/real")
def prepareVisualisation(epochs,i,loaderlen,lossD, lossG, writerFake, writerReal, step):
  """Log one grid of generated and one grid of real images to TensorBoard.

  Relies on the module-level ``generator``, ``fixedNoise`` and ``real``
  (the current real batch set by the training loop).

  Args:
    epochs, i, loaderlen, lossD, lossG: Accepted for call compatibility;
      not used by this function.
    writerFake: SummaryWriter that receives the generated-image grid.
    writerReal: SummaryWriter that receives the real-image grid.
    step: Global step under which both grids are logged.

  Returns:
    ``step + 1``, to be passed as ``step`` on the next call.
  """
  # Imported locally because the file never binds the name ``torchvision``.
  from torchvision.utils import make_grid

  with torch.no_grad():
    # Run the generator on the latent batch first, THEN reshape its output
    # into images; the latent vectors themselves cannot be reshaped to 28x28.
    fake = generator(fixedNoise).reshape(-1,1,28,28)
    data = real.reshape(-1,1,28,28)

    imgGridFake = make_grid(fake, normalize = True)
    imgGridReal = make_grid(data, normalize = True)
    writerFake.add_image("Mnist Fake Images",
                            imgGridFake,
                            global_step=step)
    writerReal.add_image("Mnist Real Images",
                          imgGridReal,
                          global_step=step)
    # increment step
    step += 1
    return step





In [None]:
losses = []          # one (generator loss, discriminator loss) pair per epoch
step = 0
images_for_gif = []  # one saved sample grid per epoch
epochs = 5
for epoch in range(epochs):
  # nn.Module instances are in training mode by default, so no explicit
  # .train() call is needed unless .eval() was called on them earlier.
  for i, (real, _) in enumerate(dataloader):
    real = real.view(-1, 784).to(device)  # flatten each 28x28 image

    ## Training discriminator: real images are labelled 1, generated ones 0.
    noise = torch.randn(batch_size, latent_dim).to(device)
    fake = generator(noise)
    discReal = discriminator(real).view(-1)
    lossDreal = adv_loss(discReal, torch.ones_like(discReal))
    # Detach so this backward pass stops at the discriminator; the generator
    # graph stays untouched for the generator update below, which removes
    # the need for retain_graph=True.
    discFake = discriminator(fake.detach()).view(-1)
    lossDfake = adv_loss(discFake, torch.zeros_like(discFake))

    lossD = (lossDreal + lossDfake) / 2
    discriminator.zero_grad()
    lossD.backward()
    optim_D.step()

    ### Training generator: it is rewarded when the (just updated)
    ### discriminator labels its images as real.
    output = discriminator(fake).view(-1)
    lossG = adv_loss(output, torch.ones_like(output))
    generator.zero_grad()
    lossG.backward()
    optim_G.step()

  # Report the losses of the last batch of the epoch.
  print(
      f"Epoch [{epoch}/{epochs}] Batch {i}/{len(dataloader)} \
                            Loss DISC: {lossD:.4f}, loss GEN: {lossG:.4f}"
  )

  losses.append((lossG.item(), lossD.item()))
  # Save a 5x5 grid from the last generated batch and keep it as a GIF frame.
  image_filename = f'output_dir/images/{epoch}.png'
  save_image(fake.detach()[:25], image_filename, nrow = 5, normalize = True)
  images_for_gif.append(imageio.imread(image_filename))


# Assemble the per-epoch grids into an animation.
imageio.mimwrite('output_dir/genimage.gif', images_for_gif, fps = len(fake)/5)




Epoch [0/5] Batch 1874/1875                             Loss DISC: 0.2194, loss GEN: 3.1103


  images_for_gif.append(imageio.imread(image_filename))


Epoch [1/5] Batch 1874/1875                             Loss DISC: 0.2035, loss GEN: 3.1787
Epoch [2/5] Batch 1874/1875                             Loss DISC: 0.1654, loss GEN: 4.0779
Epoch [3/5] Batch 1874/1875                             Loss DISC: 0.4142, loss GEN: 2.3229
Epoch [4/5] Batch 1874/1875                             Loss DISC: 0.2253, loss GEN: 2.9356


Resources:


Code:
https://www.baeldung.com/cs/pytorch-generative-adversarial-networks#:~:text=In%20this%20article%2C%20we%20showed,but%20are%20challenging%20to%20train.

Theory + code:

https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f

https://github.com/vamsi3/simple-GAN/blob/master/src/PyTorch/gan-mnist-pytorch.py


https://medium.com/@wasuratme96/building-a-simple-gan-model-9bfea22c651f


YT:
GAN explanation:

https://youtu.be/OXWvrRLzEaU?si=s7mw0UYlhrKvIiYR

Simple GAN explanation:

https://youtu.be/OXWvrRLzEaU?si=s7mw0UYlhrKvIiYR


To try:
1. Increase the size of the generator and discriminator networks.
2. Increase the number of epochs to 200.