In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

In [None]:
# Defining the generator block consisting of a Linear layer followed by a batch normalization and ReLU activation function
def generator_block(input_dim, output_dim):
  """Build one hidden stage of the generator.

  The stage applies, in order: a Linear layer mapping input_dim features to
  output_dim features, BatchNorm1d over those output_dim features, and an
  in-place ReLU.

  Args:
    input_dim: number of input features of the Linear layer.
    output_dim: number of output features of the Linear layer (and the
      number of features normalized by BatchNorm1d).

  Returns:
    An nn.Sequential containing the three layers in that order.
  """
  layers = (
    nn.Linear(input_dim, output_dim),
    nn.BatchNorm1d(output_dim),
    nn.ReLU(inplace=True),
  )
  return nn.Sequential(*layers)

In [None]:
# Defining the generator model consisting of 4 generator blocks followed by the last Linear layer with Sigmoid activation function
class Generator(nn.Module):
  """Fully connected generator network.

  Four generator blocks widen the features from input_dim to hidden_dim,
  hidden_dim * 2, hidden_dim * 4 and hidden_dim * 8; a final Linear layer
  projects to output_dim and a Sigmoid is applied to the result.

  Args:
    input_dim: size of each input noise vector.
    output_dim: size of each generated output vector.
    hidden_dim: base width of the hidden layers.
  """

  def __init__(self, input_dim, output_dim, hidden_dim):
    super().__init__()

    # Feature widths at the boundaries of the four hidden blocks.
    widths = [input_dim, hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8]
    hidden_blocks = [
      generator_block(n_in, n_out)
      for n_in, n_out in zip(widths[:-1], widths[1:])
    ]

    self.gen = nn.Sequential(
      *hidden_blocks,
      nn.Linear(widths[-1], output_dim),
      nn.Sigmoid(),
    )

  def forward(self, noise):
    """Run the noise batch through the network and return its output."""
    return self.gen(noise)

In [None]:
# Generating num_samples noise vectors of size z_dim
def get_noise(num_samples, z_dim, device='cpu'):
  """Sample a batch of standard-normal noise vectors.

  Args:
    num_samples: number of noise vectors to draw (rows of the result).
    z_dim: length of each noise vector (columns of the result).
    device: device on which the tensor is created; defaults to 'cpu'.

  Returns:
    A tensor of shape (num_samples, z_dim) filled with samples from N(0, 1),
    located on `device`.
  """
  # Sample directly on the target device instead of sampling on the CPU and
  # copying the result over afterwards.
  return torch.randn(num_samples, z_dim, device=device)

In [None]:
# Defining the discriminator block consisting of a Linear layer followed by a LeakyReLU activation function
def discriminator_block(input_dim, output_dim):
  """Build one hidden stage of the discriminator.

  The stage is a Linear layer mapping input_dim features to output_dim
  features, followed by an in-place LeakyReLU with negative slope 0.2.

  Args:
    input_dim: number of input features of the Linear layer.
    output_dim: number of output features of the Linear layer.

  Returns:
    An nn.Sequential containing the Linear layer and the activation.
  """
  linear = nn.Linear(input_dim, output_dim)
  activation = nn.LeakyReLU(0.2, inplace=True)
  return nn.Sequential(linear, activation)

In [None]:
# Defining the discriminator model consisting of 3 discriminator blocks followed by the last Linear layer
class Discriminator(nn.Module):
  """Fully connected discriminator network.

  Three discriminator blocks narrow the features from input_dim to
  hidden_dim * 4, hidden_dim * 2 and hidden_dim; a final Linear layer maps to
  a single output per example. No activation is applied to that output.

  Args:
    input_dim: size of each (flattened) input example.
    hidden_dim: base width of the hidden layers.
  """

  def __init__(self, input_dim, hidden_dim):
    super().__init__()

    # Feature widths at the boundaries of the three hidden blocks.
    widths = (input_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim)
    hidden_blocks = [
      discriminator_block(n_in, n_out)
      for n_in, n_out in zip(widths, widths[1:])
    ]

    self.disc = nn.Sequential(
      *hidden_blocks,
      nn.Linear(hidden_dim, 1),
    )

  def forward(self, image):
    """Run the input batch through the network and return one value per example."""
    return self.disc(image)

  def get_disc(self):
    """Return the underlying nn.Sequential stack."""
    return self.disc

In [None]:
# Setup training parameters
# Seed the RNG first so that weight initialisation, data shuffling and noise
# sampling start from the same state on every Restart & Run All.
seed = 0
torch.manual_seed(seed)

criterion = nn.BCEWithLogitsLoss() # the discriminator ends in a plain Linear layer, so the loss applies the sigmoid itself
n_epochs = 10
z_dim = 64 # length of each noise vector fed to the generator
hidden_dim = 128 # base width of the hidden layers of both networks
lr = 0.00001
batch_size = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Loading the MNIST dataset (downloaded into the current directory if missing)
dataloader = DataLoader(MNIST(".", download=True, transform=transforms.ToTensor()), batch_size=batch_size, shuffle=True)

gen = Generator(z_dim, 784, hidden_dim).to(device) # 784 = 28 x 28
gen_opt = optim.Adam(gen.parameters(), lr=lr)

disc = Discriminator(784, hidden_dim).to(device)
disc_opt = optim.Adam(disc.parameters(), lr=lr)

In [None]:
# Calculates the loss of the discriminator model given the generator and discriminator models, the BCE loss function,
# a batch of real examples, the number of images the generator should produce, the noise dimension, and the device
def calculate_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
  """Compute the discriminator loss for one batch.

  The generator produces num_images fake examples from fresh noise. The
  discriminator is scored on labelling those as 0 (fake) and the `real` batch
  as 1 (real); the two losses are averaged.

  Args:
    gen: generator model, called on a (num_images, z_dim) noise tensor.
    disc: discriminator model, called on the fake and on the real batch.
    criterion: loss function taking (predictions, targets).
    real: batch of real examples, in the form `disc` accepts.
    num_images: number of fake examples to generate.
    z_dim: length of each noise vector.
    device: device on which the noise is created.

  Returns:
    The mean of the fake-batch loss and the real-batch loss.
  """
  noise = get_noise(num_samples=num_images, z_dim=z_dim, device=device) # generate noise vectors

  # This loss must not backpropagate into the generator, so run the generator
  # without recording an autograd graph; its output then carries no gradient
  # history and needs no separate detach().
  with torch.no_grad():
    gan_fake_output = gen(noise) # produce outputs from the generator

  disc_fake_output = disc(gan_fake_output) # discriminator predictions from the generator's outputs
  disc_fake_loss = criterion(disc_fake_output, torch.zeros_like(disc_fake_output)) # calculating loss of discriminator (fake = 0)

  disc_real_output = disc(real) # discriminator predictions from the real examples
  disc_real_loss = criterion(disc_real_output, torch.ones_like(disc_real_output)) # calculating loss of discriminator (real = 1)

  disc_loss = (disc_fake_loss + disc_real_loss) / 2 # average of the 2 losses
  return disc_loss

In [None]:
def calculate_gen_loss(gen, disc, criterion, num_images, z_dim, device):
  """Compute the generator loss for one batch.

  The generator produces num_images fake examples from fresh noise, the
  discriminator scores them, and the loss compares those scores against
  targets of all ones (the label used for real examples).

  Args:
    gen: generator model, called on a (num_images, z_dim) noise tensor.
    disc: discriminator model, called on the generator's outputs.
    criterion: loss function taking (predictions, targets).
    num_images: number of fake examples to generate.
    z_dim: length of each noise vector.
    device: device on which the noise is created.

  Returns:
    The loss of the discriminator's predictions on the fakes against all-ones targets.
  """
  # No detach here: the gradient has to flow through the discriminator back into the generator.
  fake = gen(get_noise(num_samples=num_images, z_dim=z_dim, device=device))
  disc_pred = disc(fake)
  target = torch.ones_like(disc_pred)
  return criterion(disc_pred, target)

In [None]:
# GAN Training
# Each iteration performs one discriminator update followed by one generator update.
for epoch in range(n_epochs):
  for real, _ in dataloader:
    cur_batch_size = len(real) # the last batch of an epoch may be smaller than batch_size
    real = real.view(cur_batch_size, -1).to(device) # flatten every image into a single vector

    # Discriminator step
    disc_opt.zero_grad()
    disc_loss = calculate_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
    disc_loss.backward() # the graph is not reused afterwards, so it does not need to be retained
    disc_opt.step()

    # Generator step: generate as many fakes as there are real examples in this
    # batch (cur_batch_size, not batch_size) so both updates see the same batch size.
    gen_opt.zero_grad()
    gen_loss = calculate_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
    # This also accumulates gradients in the discriminator's parameters; they are
    # cleared by disc_opt.zero_grad() at the start of the next iteration.
    gen_loss.backward()
    gen_opt.step()

