## Import Packages and Define Helper Functions

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Plot the first `num_images` images of a batch as a grid with 5 images per row.

    Parameters:
      image_tensor: batch of images, assumed to hold values in [-1, 1] (the range of the
        generator's Tanh output and of the Normalize((0.5,), (0.5,)) data); it is
        rescaled to [0, 1] before plotting
      num_images: how many images from the front of the batch to show
      size: not used by the function; kept so existing calls stay valid
    '''
    rescaled = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(rescaled[:num_images], nrow=5)
    # make_grid returns a (C, H, W) tensor; imshow wants channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

def make_grad_hook():
    '''
    Build a gradient collector for visualization.

    Returns:
      (grads, grad_hook): `grads` is an initially empty list; passing `grad_hook` to
      `model.apply(grad_hook)` appends the current `.weight.grad` of every Conv2d and
      ConvTranspose2d submodule to that list (None if no backward pass has run yet).
    '''
    collected = []
    tracked_types = (nn.Conv2d, nn.ConvTranspose2d)

    def grad_hook(m):
        if isinstance(m, tracked_types):
            collected.append(m.weight.grad)

    return collected, grad_hook

def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Sample standard-normal noise vectors for the generator.

    Parameters:
      n_samples: the number of samples to generate, a scalar
      z_dim: the dimension of the noise vector, a scalar
      device: the device on which the tensor is created
    Returns:
      a tensor of shape (n_samples, z_dim, 1, 1): each noise vector laid out as a
      z_dim-channel 1x1 map, ready for the generator's transposed convolutions
    '''
    flat_noise = torch.randn(n_samples, z_dim, device=device)
    # Append two singleton spatial dimensions (a view; no data is copied).
    return flat_noise[:, :, None, None]

## Download and Preprocess the Dataset

In [None]:
batch_size = 128

# ToTensor yields pixels in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to
# [-1, 1], matching the range of the generator's Tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST training split, downloaded into the current directory on first use.
dataloader = DataLoader(
    MNIST(root='.', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 17410818.45it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 481289.26it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 4469454.84it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5622942.38it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw






## Create Models (Generator and Critic)

In [None]:
class Generator(nn.Module):
  '''
  DCGAN-style generator mapping noise vectors to images.

  Parameters:
    z_dim: length of the input noise vector
    image_channel: number of channels of the generated images
    hidden_dim: base number of filters; the hidden blocks use 4x, 2x and 1x this width
  For noise shaped (batch, z_dim, 1, 1) the four transposed convolutions grow the
  spatial size 1 -> 3 -> 6 -> 13 -> 28, and the final Tanh keeps outputs in [-1, 1].
  '''
  def __init__(self, z_dim=10, image_channel=1, hidden_dim=64):
    super().__init__()
    # (in_channels, out_channels, extra gen_block keyword arguments), in forward order.
    block_specs = (
        (z_dim, hidden_dim * 4, {}),
        (hidden_dim * 4, hidden_dim * 2, {'kernel_size': 4, 'stride': 1}),
        (hidden_dim * 2, hidden_dim, {}),
        (hidden_dim, image_channel, {'kernel_size': 4, 'final_layer': True}),
    )
    self.gen_layers = nn.Sequential(
        *[self.gen_block(c_in, c_out, **opts) for c_in, c_out, opts in block_specs]
    )

  def gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
    '''
    One upsampling stage: ConvTranspose2d followed by Tanh when `final_layer` is set,
    otherwise by BatchNorm2d and an in-place ReLU.
    '''
    upsample = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
    if final_layer:
      return nn.Sequential(upsample, nn.Tanh())
    return nn.Sequential(upsample, nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True))

  def forward(self, noise):
    '''Generate a batch of images from a batch of noise tensors.'''
    images = self.gen_layers(noise)
    return images

In [None]:
class Critic(nn.Module):
  '''
  WGAN-GP critic: maps each image to one unbounded score (no output activation).

  Parameters:
    image_channel: number of channels of the input images
    hidden_dim: base number of filters (the second block uses twice as many)
  For 28x28 inputs the three 4x4 / stride-2 convolutions reduce the spatial size
  28 -> 13 -> 5 -> 1, so forward() returns scores of shape (batch, 1).
  '''
  def __init__(self, image_channel=1, hidden_dim=64):
    super(Critic, self).__init__()
    # Attribute name (spelling included) is kept so existing references keep working.
    self.critc_layers = nn.Sequential(
        self.critic_block(image_channel, hidden_dim),
        self.critic_block(hidden_dim, hidden_dim*2),
        self.critic_block(hidden_dim*2, 1, final_layer=True)
    )

  def critic_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
    '''
    One downsampling stage: Conv2d, then (unless final_layer) a per-sample
    normalization and LeakyReLU(0.2).

    The normalization is GroupNorm with a single group, i.e. layer normalization over
    (C, H, W) with a learnable per-channel scale and shift. BatchNorm is deliberately
    avoided: it makes each sample's score depend on the other samples in the batch,
    while the WGAN-GP gradient penalty constrains the critic's gradient for each input
    independently (Gulrajani et al., 2017, Sec. 4, recommend layer normalization).
    '''
    conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
    if final_layer:
      return nn.Sequential(conv)
    return nn.Sequential(
        conv,
        nn.GroupNorm(1, output_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )

  def forward(self, image):
    '''Score a batch of images; the result is flattened to (batch, -1).'''
    critc_pred = self.critc_layers(image)
    # (batch, 1, 1, 1) for 28x28 inputs -> (batch, 1)
    return critc_pred.view(len(critc_pred), -1)

In [None]:
# Seed torch's global RNG so weight init, noise sampling and DataLoader shuffling
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Fall back to the CPU when no CUDA device is present instead of crashing on .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Weight Initialization of Models

In [None]:
z_dim = 64
gen = Generator(z_dim).to(device)
crit = Critic().to(device)

def weights_init(m):
    '''
    DCGAN-style initialization, meant to be passed to Module.apply:
    Conv2d / ConvTranspose2d weights ~ N(0, 0.02); BatchNorm2d scale ~ N(1, 0.02), shift = 0.
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        # Centre the scale at 1, as in the reference DCGAN implementation. Centring it at 0
        # would shrink every normalized activation to ~0.02 of its size at initialization.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Module.apply initializes in place and returns the module itself, so `critic` is simply
# a second name for `crit` (the optimizer cell uses `crit`, the training loop `critic`).
gen.apply(weights_init)
crit.apply(weights_init)
critic = crit

## Training

Define loss functions

In [None]:
def get_gen_loss(gen, critic, curr_batch_size, z_dim, device):
  '''
  WGAN generator loss: the negated mean critic score of a fresh batch of fakes.

  Parameters:
    gen, critic: the generator and critic modules
    curr_batch_size: number of fake images to generate
    z_dim: noise dimension passed to get_noise
    device: device on which the noise is created
  Returns: a scalar tensor; minimizing it pushes the critic's scores on fakes up.
  '''
  fake_scores = critic(gen(get_noise(curr_batch_size, z_dim, device)))
  return -fake_scores.mean()

In [None]:
def _gradient_penalty(critic, real_image, fake_image, device):
  '''
  WGAN-GP gradient penalty on random real/fake interpolates.

  Returns the batch mean of (||d critic(x_hat) / d x_hat||_2 - 1)^2, where
  x_hat = eps * real + (1 - eps) * fake and eps ~ U[0, 1) is drawn once per sample.
  The gradient is computed with create_graph=True so the penalty itself can be
  backpropagated into the critic's parameters.
  '''
  # One mixing weight per sample, broadcast over the remaining image dimensions.
  epsilon = torch.rand(len(real_image), 1, 1, 1, device=device)
  # The interpolates are the differentiation target, so they must require grad
  # even though neither endpoint does.
  mixed_images = (real_image * epsilon + fake_image * (1 - epsilon)).requires_grad_(True)
  mixed_scores = critic(mixed_images)

  gradient = torch.autograd.grad(
      inputs=mixed_images,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
  )[0]
  gradient_norm = gradient.reshape(len(gradient), -1).norm(2, dim=1)
  return torch.mean((gradient_norm - 1) ** 2)


def get_critic_loss(gen, critic, real_image, c_lambda, curr_batch_size, z_dim, device):
  '''
  WGAN-GP critic loss: mean(critic(fake)) - mean(critic(real)) + c_lambda * gradient penalty.

  Parameters:
    gen, critic: the generator and critic modules
    real_image: batch of real images
    c_lambda: weight of the gradient penalty
    curr_batch_size: number of fake images to generate
    z_dim: noise dimension passed to get_noise
    device: device for the noise and the mixing weights
  Returns: a scalar tensor whose gradients reach only the critic's parameters.
  '''
  noise = get_noise(curr_batch_size, z_dim, device)
  # The critic step must not build a graph through (or write grads into) the generator.
  with torch.no_grad():
    fake_image = gen(noise)
  critic_fake_pred = critic(fake_image)
  critic_real_pred = critic(real_image)

  penalty = _gradient_penalty(critic, real_image, fake_image, device)

  critic_loss = torch.mean(critic_fake_pred) - torch.mean(critic_real_pred) + c_lambda * penalty
  return critic_loss

Define Optimizers

In [None]:
# Adam hyper-parameters shared by the generator and critic optimizers.
lr = 2e-4
beta_1, beta_2 = 0.5, 0.999
adam_kwargs = dict(lr=lr, betas=(beta_1, beta_2))

gen_opt = torch.optim.Adam(gen.parameters(), **adam_kwargs)
critic_opt = torch.optim.Adam(crit.parameters(), **adam_kwargs)

Training

In [None]:
curr_step = 0
epochs = 200
crit_repeats = 5      # critic updates per generator update
c_lambda = 10         # gradient-penalty weight
generator_losses = []
critic_losses = []
display_step = 500
step_bins = 20        # loss curves are averaged over bins of this many steps


def _plot_loss_curves(generator_losses, critic_losses, step_bins):
  '''Plot generator and critic losses, each averaged over consecutive bins of `step_bins` steps.'''
  num_examples = (len(generator_losses) // step_bins) * step_bins
  bins = range(num_examples // step_bins)
  for losses, label in ((generator_losses, "Generator Loss"), (critic_losses, "Critic Loss")):
    binned = torch.Tensor(losses[:num_examples]).view(-1, step_bins).mean(1)
    plt.plot(bins, binned, label=label)
  plt.xlabel(f"step / {step_bins}")
  plt.ylabel("loss")
  plt.legend()
  plt.show()


for epoch in range(epochs):
  for real, _ in tqdm(dataloader):
    curr_batch_size = len(real)
    real = real.to(device)

    # Update the critic crit_repeats times on this batch, recording its mean loss.
    mean_iteration_critic_loss = 0
    for _ in range(crit_repeats):
      critic_opt.zero_grad()
      critic_loss = get_critic_loss(gen, critic, real, c_lambda, curr_batch_size, z_dim, device)
      critic_loss.backward()
      critic_opt.step()
      mean_iteration_critic_loss += critic_loss.item() / crit_repeats
    critic_losses.append(mean_iteration_critic_loss)

    # Update the generator once per batch.
    gen_opt.zero_grad()
    gen_loss = get_gen_loss(gen, critic, curr_batch_size, z_dim, device)
    gen_loss.backward()
    gen_opt.step()
    generator_losses.append(gen_loss.item())

    if curr_step % display_step == 0 and curr_step > 0:
      gen_mean = sum(generator_losses[-display_step:]) / display_step
      crit_mean = sum(critic_losses[-display_step:]) / display_step
      print(f"Epoch {epoch}, step {curr_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
      # Samples are only displayed, so no autograd graph is needed.
      with torch.no_grad():
        fake = gen(get_noise(curr_batch_size, z_dim, device))
      show_tensor_images(fake)
      show_tensor_images(real)
      _plot_loss_curves(generator_losses, critic_losses, step_bins)

    curr_step += 1







