# Generate numbers using the MNIST dataset

In [None]:
import torch

# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
if dev != "cpu":
    print("Using GPU")
device = torch.device(dev)

In [None]:
from torchvision import datasets, transforms

# Convert each image to a tensor so batches can be fed to the networks
transform = transforms.ToTensor()
# download = True fetches MNIST into ./data when it is missing and reuses the
# files already present otherwise, so this cell also works on a fresh machine
# (download = False raises if ./data has not been populated beforehand).
dataset = datasets.MNIST(root = './data', train = True, download = True, transform = transform)

In [None]:
from torch.utils.data import DataLoader

# Batches of 64 images. drop_last discards the last incomplete batch when the
# dataset size is not divisible by the batch size.
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)

In [None]:
# Turn the loader into an iterator and take its first batch for inspection
iterator = iter(dataloader)
images, labels = next(iterator)

# Batch shape on the first line, the digit labels of that batch on the next
print(images.shape, labels, sep="\n")

In [None]:
# Move the sampled batch onto the selected device
images, labels = images.to(device), labels.to(device)

In [None]:
from torchvision.utils import make_grid
from matplotlib.pyplot import *


def show_image_grid(images: torch.Tensor, ncol: int):
    """Display a batch of images as one grid figure without axis ticks.

    Args:
        images: Batch of images to tile (passed straight to ``make_grid``).
        ncol: Number of images placed in each row of the grid, i.e. the
            number of grid columns.
    """
    # make_grid calls "images per row" `nrow`; pass it by keyword to avoid
    # confusion with the number of rows.
    image_grid = make_grid(images, nrow=ncol)
    image_grid = image_grid.permute(1, 2, 0)  # Channel-last layout for imshow
    # detach() first: .numpy() refuses tensors that still carry autograd
    # history, which is the case for images coming straight out of a network.
    image_grid = image_grid.detach().cpu().numpy()
    subplots(figsize=(8, 8))
    imshow(image_grid)
    xticks([])
    yticks([])

# Preview the sampled MNIST batch, eight images per row
show_image_grid(images, ncol = 8)

We need to decide the size of the random value vector. More random values may add variation to the generated images. However, training the generator would take more time and become more difficult, as there would be many more network parameters to adjust during optimization. Also, MNIST images are not very complex, so we probably do not need many random values in the input vectors.

## Generator

In [None]:
from torch import nn

# 100 random values -> 128 hidden units -> 784 pixel values squeezed by a sigmoid
generator = nn.Sequential(
    nn.Linear(100, 128, device=dev),
    nn.LeakyReLU(0.01),
    nn.Linear(128, 784, device=dev),
    nn.Sigmoid(),
)

* Converts 100 random values into 128 numeric values via the fully-connected linear layer.
* Applies non-linearity via leaky ReLU
* Converts 128 numeric values into 784 numeric values (784 = 28x28, which is the image size).
* Applies the sigmoid to squeeze output values into the 0 to 1 range (the value range in greyscale images).

In [None]:
def generate_images(batch_size: int = 64) -> torch.Tensor:
    """Sample random vectors and turn them into images with the generator.

    Args:
        batch_size: Number of images to generate. Defaults to 64, the batch
            size used by the DataLoader in this notebook.

    Returns:
        Generator output reshaped to (batch_size, 1, 28, 28).
    """
    # One 100-value random vector per image (100 is the generator's input width)
    z = torch.randn(batch_size, 100, device = dev)
    # Generator network output: one row of 784 values per image
    output = generator(z)
    # Each row becomes a single-channel 28x28 image
    return output.reshape(batch_size, 1, 28, 28)

In [None]:
# Generate one batch with the generator (no training has happened yet) and
# display it, eight images per row
generated_images = generate_images()
show_image_grid(generated_images, ncol = 8)

## Discriminator

In [None]:
# 784 flattened pixel values -> 128 hidden units -> 1 raw score per image
discriminator = nn.Sequential(
    nn.Linear(784, 128, device=dev),
    nn.LeakyReLU(0.01),
    nn.Linear(128, 1, device=dev),
)

* Converts a flattened input image into 128 values via the fully-connected linear layer.
* It applies non-linearity via leaky ReLU.
* Converts 128 values into one value.

In [None]:
# Score the generated batch with the discriminator as it is right now.
# Inference only: eval mode, and no gradient tracking.
discriminator.eval()
with torch.no_grad():
    # Collapse each (1, 28, 28) image into a row of 784 values
    prediction = discriminator(generated_images.flatten(start_dim=1))

# print(prediction)

We haven't trained the discriminator yet, so it has no idea whether an image is real or fake and outputs meaningless values. It behaves the same way even when given real MNIST images, so at this point the discriminator cannot distinguish between real and fake images.

We need to train the discriminator as a classifier (supervised learning), which needs labels. We treat MNIST images as real images and generated images as fake images. In other words, when the discriminator classifies an MNIST image, the label is real (1), and when the discriminator classifies a generated image, the label is fake (0).

## Train the Discriminator

In [None]:
# One label per image in a batch of 64: 1 marks real, 0 marks fake
real_targets = torch.ones((64, 1), device=dev)
fake_targets = torch.zeros_like(real_targets)

When we feed a batch of MNIST images into the discriminator, we use `real_targets`; when we feed a batch of generated images into the discriminator, we use `fake_targets`. Each time, we calculate the binary cross-entropy loss for optimization.

In [None]:
from torch.nn import functional as F

def calculate_loss(images: torch.Tensor, targets: torch.Tensor):
    """Return the discriminator's binary cross-entropy loss for one batch.

    ``images`` is flattened into rows of 784 values and scored by the global
    ``discriminator``; the raw scores are then compared against ``targets``.
    """
    flat_images = images.reshape(-1, 784)
    logits = discriminator(flat_images)
    return F.binary_cross_entropy_with_logits(logits, targets)

In [None]:
# Adam optimizer over the discriminator's parameters only
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr = 1.0e-4)

During training, the discriminator predicts whether an input image is real or fake. We calculate the loss value and apply back-propagation so that the optimizer can adjust network parameters (weights and biases).

In [None]:
from tqdm import tqdm

In [None]:
# The discriminator is the only network optimized in this loop
discriminator.train()

# The generator only supplies fake samples here
generator.eval()

# Training loop
for epoch in tqdm(range(100)):
    for images, labels in dataloader:  # labels (digit classes) are not used
        # Loss with MNIST image inputs and real_targets as labels.
        # real_targets was created on `dev`, so only the images need moving.
        d_loss = calculate_loss(images.to(device), real_targets)

        # Loss with generated image inputs and fake_targets as labels.
        # The fakes are plain inputs for the discriminator: generating them
        # without autograd avoids backpropagating through the generator and
        # piling up gradients on its parameters that nothing here uses.
        with torch.no_grad():
            generated_images = generate_images()
        d_loss += calculate_loss(generated_images, fake_targets)

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

This is ordinary supervised training for a classification model. The discriminator will learn to distinguish between real (MNIST) images and fake (generated) images. However, the generator network learns nothing in the above training, because only the discriminator's parameters are updated.

## Train Generative Network

In [None]:
import numpy as np
from matplotlib.pyplot import *
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Configuration
epochs      = 100    # Number of passes over the training set
batch_size  = 64     # Images per batch, for both MNIST and generated batches
sample_size = 100    # Number of random values to sample
g_lr        = 1.0e-4 # Generator's learning rate
d_lr        = 1.0e-4 # Discriminator's learning rate

# DataLoader for MNIST
transform = transforms.ToTensor()
dataset = datasets.MNIST(root = './data', train = True, download = True, transform = transform)
# shuffle = True reshuffles the data every epoch, so the networks do not see
# the same batches in the same order each time. drop_last keeps every batch at
# exactly batch_size, matching the fixed-size target tensors used in training.
dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True, drop_last = True)

# Generator Network
class Generator(nn.Sequential):
    """Fully connected generator that turns random vectors into 28x28 images.

    Layers: Linear(sample_size, 128) -> LeakyReLU(0.01) -> Linear(128, 784)
    -> Sigmoid.

    Args:
        sample_size: Length of the random vector fed to the first layer.
        device: Device on which the layers are created. ``None`` (default)
            uses the module-level ``dev`` setting.
    """

    def __init__(self, sample_size: int, device: "torch.device | str | None" = None):
        if device is None:
            device = dev
        super().__init__(
            nn.Linear(sample_size, 128, device = device),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784, device = device),
            nn.Sigmoid())

        # Random value vector size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        """Generate ``batch_size`` images from freshly sampled random vectors.

        Returns:
            Tensor of shape (batch_size, 1, 28, 28) holding the sigmoid
            outputs of the last layer.
        """
        # Sample the random values where the weights actually live (device and
        # dtype), so the model keeps working after .to() / .cuda() / .double()
        first_weight = self[0].weight
        z = torch.randn(batch_size, self.sample_size,
                        device = first_weight.device, dtype = first_weight.dtype)

        # Generator output
        output = super().forward(z)

        # Convert the output into greyscale images (1x28x28)
        generated_images = output.reshape(batch_size, 1, 28, 28)
        return generated_images


# Discriminator Network
class Discriminator(nn.Sequential):
    """Fully connected discriminator whose forward pass returns a loss.

    Layers: Linear(784, 128) -> LeakyReLU(0.01) -> Linear(128, 1).

    Args:
        device: Device on which the layers are created. ``None`` (default)
            uses the module-level ``dev`` setting.
    """

    def __init__(self, device: "torch.device | str | None" = None):
        if device is None:
            device = dev
        super().__init__(
            nn.Linear(784, 128, device = device),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1, device = device))

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        """Score ``images`` and return the loss against ``targets``.

        Args:
            images: Batch of images; flattened here into rows of 784 values.
            targets: Labels with one value per image (1 = real, 0 = fake).

        Returns:
            Binary cross-entropy loss computed on the raw scores (logits).
        """
        prediction = super().forward(images.reshape(-1, 784))
        loss = F.binary_cross_entropy_with_logits(prediction, targets)
        return loss


# To save images in grid layout
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    """Render a batch of images as a grid figure, then close the figure.

    The ``savefig`` call below is currently commented out, so nothing is
    written to disk and ``epoch`` (only needed for the file name) is unused.

    Args:
        epoch: Epoch number used in the output file name.
        images: Batch of images to tile (passed straight to ``make_grid``).
        ncol: Number of images placed in each row of the grid.
    """
    # make_grid calls "images per row" `nrow`; pass it by keyword for clarity
    image_grid = make_grid(images, nrow=ncol)  # Images in a grid
    image_grid = image_grid.permute(1, 2, 0)   # Move channel last
    # detach() first: .numpy() refuses tensors that still carry autograd
    # history, which is the case for images coming straight out of a network.
    image_grid = image_grid.detach().cpu().numpy()

    subplots()
    imshow(image_grid)
    xticks([])
    yticks([])
    # savefig(f'generated_{epoch:03d}.jpg')
    close()


# Real (1) and fake (0) labels, one per image in a batch
real_targets = torch.ones((batch_size, 1), device=dev)
fake_targets = torch.zeros_like(real_targets)


# Generator and Discriminator networks
# (the generator's input width is the configured sample_size)
generator = Generator(sample_size)
discriminator = Discriminator()


# Optimizers: one Adam instance per network, each updating only its own
# network's parameters with its own learning rate
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr = d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr = g_lr)


# Training loop: for every MNIST batch, update the discriminator once and then
# the generator once
for epoch in tqdm(range(epochs)):

    # Per-batch losses of this epoch, averaged for the log line below
    d_losses = []
    g_losses = []

    for images, labels in dataloader:
        #===============================
        # Discriminator Network Training
        #===============================

        # Loss with MNIST image inputs and real_targets as labels
        discriminator.train()
        d_loss = discriminator(images.to(device), real_targets.to(device))

        # Generate images in eval mode
        # (no_grad: the fakes are only inputs here, the generator is not updated)
        generator.eval()
        with torch.no_grad():
            generated_images = generator(batch_size)

        # Loss with generated image inputs and fake_targets as labels
        d_loss += discriminator(generated_images, fake_targets)

        # Optimizer updates the discriminator parameters
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # Generator Network Training
        #===============================

        # Generate images in train mode
        # (gradients enabled this time so the loss can reach the generator)
        generator.train()
        generated_images = generator(batch_size)

        # Loss with generated image inputs and real_targets as labels:
        # the loss is low when the discriminator scores the fakes as real
        discriminator.eval() # eval but we still need gradients
        g_loss = discriminator(generated_images, real_targets)

        # Optimizer updates the generator parameters
        # (only g_optimizer steps here; the discriminator is left unchanged)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print average losses
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save images (currently disabled)
    # save_image_grid(epoch, generator(batch_size), ncol = 8)