## Import Packages and define helper functions

In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a grid with 5 per row.

    The batch is rescaled with ``(x + 1) / 2`` before plotting, detached from
    autograd and moved to the CPU. ``size`` is accepted for interface
    compatibility but is not used by this function.
    """
    rescaled = (image_tensor + 1) / 2
    cpu_images = rescaled.detach().cpu()
    grid = make_grid(cpu_images[:num_images], nrow=5)
    # Move the first axis of the grid to the end before handing it to imshow.
    channels_last = grid.permute(1, 2, 0)
    plt.imshow(channels_last.squeeze())
    plt.show()

In [None]:
def get_noise(n_samples, z_dim, device='cpu'):
  """
  Sample standard-normal noise vectors for the generator.

  Inputs:
    n_samples: number of noise vectors to draw
    z_dim: length of each noise vector
    device: device on which the tensor is created
  Output:
    tensor of shape (n_samples, z_dim, 1, 1)
  """
  flat_noise = torch.randn(n_samples, z_dim, device=device)
  # Append two singleton spatial axes so the result is 4-D.
  return flat_noise.view(n_samples, z_dim, 1, 1)

## Download dataset and preprocess it

## Create Models(Generator and Discriminator)

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize
# its single channel with mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# MNIST is downloaded into the current directory if needed, then served
# in shuffled mini-batches of 128 images.
mnist_dataset = MNIST('.', download=True, transform=transform)
dataloader = DataLoader(mnist_dataset, batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5454696.97it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 160205.49it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1513099.68it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5115609.23it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw






Generator

In [None]:
def gen_block(input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
  """
  Build one generator stage around a transposed convolution.

  Inputs:
    input_channels: channels of the incoming feature map
    output_channels: channels produced by the transposed convolution
    kernel_size: kernel size of the transposed convolution
    stride: stride of the transposed convolution
    final_layer: if True the stage ends with Tanh; otherwise it is
      followed by BatchNorm2d and an in-place ReLU
  Output:
    nn.Sequential holding the layers of the stage
  """
  deconv = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride)
  if final_layer:
    return nn.Sequential(deconv, nn.Tanh())
  return nn.Sequential(
      deconv,
      nn.BatchNorm2d(output_channels),
      nn.ReLU(inplace=True),
  )


In [None]:
class Generator(nn.Module):
  """
  Generator made of four stacked gen_block stages.

  Inputs to the constructor:
    z_dim: number of channels of the input noise
    image_channel: number of channels of the generated image
    hidden_dim: base channel width of the hidden stages
  """

  def __init__(self, z_dim=10, image_channel=1, hidden_dim=64):
    super().__init__()
    wide, middle, narrow = hidden_dim * 4, hidden_dim * 2, hidden_dim
    # Spatial size noted per stage for a 1x1 noise input.
    self.gen_layers = nn.Sequential(
        gen_block(z_dim, wide, kernel_size=3, stride=2),       # 1x1 -> 3x3
        gen_block(wide, middle, kernel_size=4, stride=1),      # 3x3 -> 6x6
        gen_block(middle, narrow, kernel_size=3, stride=2),    # 6x6 -> 13x13
        gen_block(narrow, image_channel, kernel_size=4, stride=2, final_layer=True),  # 13x13 -> 28x28
    )

  def forward(self, noise):
    """
    Input:
      noise: (None, z_dim, 1, 1)
    Output:
      fake: (None, image_channel, 28, 28)
    """
    return self.gen_layers(noise)

Discriminator

In [None]:
def disc_block(input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
  """
  Build one discriminator stage around a strided convolution.

  Inputs:
    input_channels: channels of the incoming feature map
    output_channels: channels produced by the convolution
    kernel_size: kernel size of the convolution
    stride: stride of the convolution
    final_layer: if True the stage is the bare convolution; otherwise it
      is followed by BatchNorm2d and an in-place LeakyReLU(0.2)
  Output:
    nn.Sequential holding the layers of the stage
  """
  conv = nn.Conv2d(input_channels, output_channels, kernel_size, stride)
  if final_layer:
    return nn.Sequential(conv)
  return nn.Sequential(
      conv,
      nn.BatchNorm2d(output_channels),
      nn.LeakyReLU(0.2, inplace=True),
  )

In [None]:
class Discriminator(nn.Module):
  """
  Discriminator made of three stacked disc_block stages.

  Inputs to the constructor:
    image_channel: number of channels of the input image
    hidden_dim: base channel width of the hidden stages
  """

  def __init__(self, image_channel=1, hidden_dim=16):
    super().__init__()
    first_width, second_width = hidden_dim, hidden_dim * 2
    self.disc_layers = nn.Sequential(
        disc_block(image_channel, first_width),
        disc_block(first_width, second_width),
        disc_block(second_width, 1, final_layer=True),
    )

  def forward(self, image):
    """
    Inputs:
      image: (None, 1, 28, 28)
    Outputs:
      disc_pred: (None, 1)
    """
    scores = self.disc_layers(image)
    # Flatten everything after the batch axis.
    return scores.view(len(scores), -1)


## Training

In [None]:
# Training hyperparameters.
epochs = 200          # number of passes over the dataset
device = 'cpu'        # device used for the models and every batch
# z_dim was previously assigned twice (10, then 64); only the later value
# ever took effect, so the dead first assignment has been removed.
z_dim = 64            # number of channels of the generator's input noise
beta_1 = 0.5          # first Adam beta
beta_2 = 0.999        # second Adam beta
lr = 0.0002           # learning rate for both optimizers
display_step = 500    # number of batches between loss reports / image previews

Define loss function - BCE

In [None]:
# Binary cross-entropy that takes raw logits: the sigmoid is applied inside
# the loss, so predictions passed to it must not already be sigmoid outputs.
criterion = nn.BCEWithLogitsLoss()

Define optimizers for Generator and Discriminator

In [None]:
# Build both networks on the target device, then give each its own Adam
# optimizer sharing the same learning rate and betas.
adam_betas = (beta_1, beta_2)
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=adam_betas)

Initialize weights for Generator and Discriminator

In [None]:
def weights_init(m):
  """
  Re-initialize a single module in place; meant for use with Module.apply.

  Conv2d / ConvTranspose2d: weight drawn from a normal with mean 0.0, std 0.02.
  BatchNorm2d: weight drawn from the same normal, bias set to 0.
  Any other module type is left untouched.
  """
  conv_types = (nn.Conv2d, nn.ConvTranspose2d)
  if isinstance(m, conv_types):
    # 0.02 is the standard deviation (not the variance) of the normal.
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
  if isinstance(m, nn.BatchNorm2d):
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    nn.init.constant_(m.bias, 0)

# Module.apply calls weights_init on every submodule and on the module
# itself, modifying parameters in place, so no reassignment is needed.
for network in (gen, disc):
  network.apply(weights_init)

In [None]:
# Running means of the losses over the current display window; each batch
# contributes loss / display_step, so after display_step batches the
# accumulators hold the window average.
mean_generator_loss = 0
mean_discriminator_loss = 0
curr_step = 0  # number of batches processed so far, across all epochs

for epoch in range(epochs):
  for real, _ in tqdm(dataloader):
    curr_batch_size = len(real)
    real = real.to(device)

    # ---- update Discriminator ----
    disc_opt.zero_grad()
    noise = get_noise(curr_batch_size, z_dim, device=device)
    fake = gen(noise)

    # detach() keeps the discriminator loss from back-propagating into the generator.
    disc_fake_pred = disc(fake.detach())
    disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))

    disc_real_pred = disc(real)
    disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))

    disc_loss = (disc_fake_loss + disc_real_loss) / 2

    disc_loss.backward()
    disc_opt.step()

    # ---- update Generator ----
    gen_opt.zero_grad()

    noise = get_noise(curr_batch_size, z_dim, device=device)
    fake2 = gen(noise)
    disc_fake2_pred = disc(fake2)

    # The generator is rewarded when the discriminator labels its fakes as real.
    gen_loss = criterion(disc_fake2_pred, torch.ones_like(disc_fake2_pred))

    gen_loss.backward()
    gen_opt.step()

    # ---- bookkeeping / display ----
    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    # Count the batch before checking, so every window holds exactly
    # display_step batches (the first window used to hold one extra).
    curr_step += 1

    if curr_step % display_step == 0:
      print(f"Epoch: {epoch}, Step: {curr_step}: Generator Loss: {mean_generator_loss}, Discriminator Loss: {mean_discriminator_loss}")
      # Preview samples only; no gradients are needed for this forward pass.
      with torch.no_grad():
        noise = get_noise(curr_batch_size, z_dim, device=device)
        fake = gen(noise)
      show_tensor_images(fake)
      show_tensor_images(real)
      mean_generator_loss = 0
      mean_discriminator_loss = 0