In [None]:
# @title <p>Essential Import
import os, shutil, json
from PIL import Image
from zipfile import ZipFile
import matplotlib.pyplot as plt
import numpy as np, pandas as pd, random as rd
import warnings
warnings.filterwarnings("ignore")

In [None]:
# @title <p>Torch Essential Import
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets
from torch.utils.tensorboard import SummaryWriter
# Seed PyTorch's global RNG so weight initialisation and torch.randn noise are
# reproducible across Restart & Run All.
torch.manual_seed(0)
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
class Discriminator(nn.Module):
  """Fully connected GAN discriminator that scores flattened images.

  Args:
    img_dim: length of each flattened input vector (last dim of `x`).

  forward(x) returns one score per row, squashed into (0, 1) by a final
  Sigmoid; the training loop treats it as the probability that `x` is real
  and pairs it with nn.BCELoss.
  """

  def __init__(self, img_dim):
    super().__init__()
    layers = [
        nn.Linear(img_dim, 128),  # single hidden layer, width 128
        nn.LeakyReLU(0.1),
        nn.Linear(128, 1),        # one score per sample
        nn.Sigmoid(),             # map the score into (0, 1) for BCELoss
    ]
    # Attribute name `disc` is kept so state_dict keys remain "disc.<i>.*".
    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    scores = self.disc(x)
    return scores

class Generator(nn.Module):
  """Fully connected GAN generator mapping latent noise to flattened images.

  Args:
    z_dim: size of each latent noise vector (last dim of the input).
    img_dim: length of each generated, flattened image.

  forward(x) returns a tensor whose last dim is img_dim, with values in
  [-1, 1] because of the final Tanh.
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    width = 256
    layers = [
        nn.Linear(z_dim, width),
        nn.LeakyReLU(0.1),
        nn.Linear(width, img_dim),
        nn.Tanh(),  # bounds outputs to [-1, 1]; real data must be scaled to match
    ]
    # Attribute name `gen` is kept so state_dict keys remain "gen.<i>.*".
    self.gen = nn.Sequential(*layers)

  def forward(self, x):
    images = self.gen(x)
    return images

In [None]:
# Hyperparamters etc.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 3e-4
z_dim = 64
img_dim = 28 * 28 *1
bs = 2
num_epochs = 50

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)
fixed_noise = torch.randn((bs, z_dim)).to(device)
transforms = T.Compose(
    [T.ToTensor(), T.Normalize((0.1307,), (0.3061,))]
)

dataset = datasets.MNIST(root="dataset", transform = transforms, download=True)
loader = DataLoader(dataset, batch_size=bs, shuffle=True)
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
opt_disc = torch.optim.Adam(disc.parameters() ,lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(F"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(F"runs/GAN_MNIST/real")

step = 0

In [None]:
for epoch in range(num_epochs):
  for idx, (real, _) in enumerate(loader):
    # Flatten each image to an img_dim vector and move it to the models'
    # device (leaving it on the CPU crashes as soon as CUDA is used).
    real = real.view(-1, img_dim).to(device)
    batch_size = real.shape[0]

    ## Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
    noise = torch.randn(batch_size, z_dim, device=device)
    fake = gen(noise)

    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))

    # detach(): the discriminator update must not backprop into the generator,
    # which also removes the need for retain_graph=True.
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    lossD = (lossD_real + lossD_fake) / 2

    opt_disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    ## Train Generator: min log(1 - D(G(z))) <--> max log(D(G(z)))
    # Re-score the same fakes with the just-updated discriminator; this graph
    # does reach the generator's parameters.
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    opt_gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if idx == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] Batch {idx}/{len(loader)} "
          f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
      )

      with torch.no_grad():
        fake_imgs = gen(fixed_noise).reshape(-1, 1, 28, 28)
        real_imgs = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
        img_grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)

        writer_fake.add_image(
            "MNIST Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
            "MNIST real Images", img_grid_real, global_step=step
        )

      step += 1

Epoch [0/50] Batch 0/30000         Loss D: 0.6191, loss G: 0.7407
Epoch [1/50] Batch 0/30000         Loss D: 0.0895, loss G: 2.8929
Epoch [2/50] Batch 0/30000         Loss D: 0.0506, loss G: 5.9136
Epoch [3/50] Batch 0/30000         Loss D: 0.0155, loss G: 5.1268


KeyboardInterrupt: ignored