In [1]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets

Variables

In [2]:
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR = 1e-3
LAT_DIM = 64
IMAGE_SIZE = 64
CHANNELS = 1
LOG_INTERVAL = 100

Generator and Discriminator models

In [3]:
class GANGenerator(nn.Module):
    def __init__(self, lat_dim, image_size, channels):
        super(GANGenerator, self).__init__()
        self.input_size = image_size // 4
        
        self.lin = nn.Linear(lat_dim, 128 * self.input_size ** 2)
        self.bn1 = nn.BatchNorm2d(128, 0.8)
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.cn1 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        
        self.bn2 = nn.BatchNorm2d(128, 0.8)
        self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.cn2 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
        
        self.bn3 = nn.BatchNorm2d(64, 0.8)
        self.cn3 = nn.Conv2d(64, channels, 3, stride=1, padding=1)
        self.act = nn.Tanh()

    def forward(self, x):
        x = self.lin(x)
        x = x.view(x.shape[0], 128, self.input_size, self.input_size)
        x = self.bn1(x)
        x = self.up1(x)
        x = self.cn1(x)
        x = self.bn2(x)
        x = F.leaky_relu(x, inplace=True)
        x = self.up2(x)
        x = self.cn2(x)
        x = self.bn3(x)
        x = F.leaky_relu(x, inplace=True)
        x = self.cn3(x)
        out = self.act(x)
        return out

In [4]:
class GANDiscriminator(nn.Module):
    def __init__(self, image_size, channels):
        super(GANDiscriminator, self).__init__()
        
        self.disc_model = nn.Sequential(
            nn.Conv2d(channels, 16, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(32, 0.8),
            
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(64, 0.8),
            
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.25),
            nn.BatchNorm2d(128, 0.8),
        )
        
        
        ds_size = image_size // 2 ** 4
        self.adverse_lyr = nn.Linear(128 * ds_size ** 2, 1)
        
    def forward(self, x):
        x = self.disc_model(x)
        x = x.view(x.shape[0], -1)
        out = torch.sigmoid(self.adverse_lyr(x))
        return out  

Model parameters and data

In [5]:
# instantiate the discriminator and generator models
gen = GANGenerator(LAT_DIM, IMAGE_SIZE, CHANNELS)
disc = GANDiscriminator(IMAGE_SIZE, CHANNELS)

gen = torch.compile(gen)
disc = torch.compile(disc)

# define the loss metric
def gen_loss_func(pred_probs, good_img):
    return F.binary_cross_entropy(pred_probs, good_img)

def disc_loss_func(real_probs, fake_probs, good_img, bad_img):
    real_loss = F.binary_cross_entropy(real_probs, good_img)
    fake_loss = F.binary_cross_entropy(fake_probs, bad_img)
    return (real_loss + fake_loss) / 2

In [6]:
# define the dataset and corresponding dataloader
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist/",
        download=True,
        transform=transforms.Compose(
            [transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), 
             transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# define the optimization schedule for both G and D
opt_gen = torch.optim.Adam(gen.parameters(), lr=LR)
opt_disc = torch.optim.Adam(disc.parameters(), lr=LR)

Training loop

In [None]:
os.makedirs("./images_mnist", exist_ok=True)

for epoch in range(NUM_EPOCHS):
    for idx, (images, _) in enumerate(dataloader):

        # get a batch of real images
        good_img = torch.ones(images.shape[0], 1)
        bad_img = torch.zeros(images.shape[0], 1)

        # get a real image
        actual_images = images.type(torch.FloatTensor)

        # train the generator model
        opt_gen.zero_grad()

        # generate a batch of images based on random noise as input
        noise = torch.randn(images.shape[0], LAT_DIM)
        gen_images = gen(noise)

        # generator model optimization - how well can it fool the discriminator
        generator_loss = gen_loss_func(
            disc(gen_images),
            good_img
        )
        generator_loss.backward()
        opt_gen.step()

        # train the discriminator model
        opt_disc.zero_grad()

        # calculate discriminator loss as average of mistakes(losses) in confusing real images as fake and vice versa
        discriminator_loss = disc_loss_func(
            disc(actual_images),
            disc(gen_images.detach()),
            good_img,
            bad_img
        )

        # discriminator model optimization
        discriminator_loss.backward()
        opt_disc.step()

        batches_completed = epoch * len(dataloader) + idx
        if batches_completed % LOG_INTERVAL == 0:
            print(f"epoch number {epoch:3}"
                  f" | batch number {idx:4}"
                  f" | generator loss = {generator_loss.item():.6f}"
                  f" | discriminator loss = {discriminator_loss.item():.6f}"
            )
            save_image(gen_images.data[:25],
                       f"images_mnist/{epoch}_{batches_completed}.png",
                       nrow=5,
                       normalize=True
            )

epoch number   0 | batch number    0 | generator loss = 0.682919 | discriminator loss = 0.692570
epoch number   0 | batch number  100 | generator loss = 1.353329 | discriminator loss = 0.392550
epoch number   0 | batch number  200 | generator loss = 1.608291 | discriminator loss = 0.389186
epoch number   0 | batch number  300 | generator loss = 1.080813 | discriminator loss = 0.581661
epoch number   0 | batch number  400 | generator loss = 0.915626 | discriminator loss = 0.338195
epoch number   0 | batch number  500 | generator loss = 1.142022 | discriminator loss = 0.273801
epoch number   0 | batch number  600 | generator loss = 0.633800 | discriminator loss = 0.784687
epoch number   0 | batch number  700 | generator loss = 2.952504 | discriminator loss = 0.064790
epoch number   0 | batch number  800 | generator loss = 1.005011 | discriminator loss = 0.796017
epoch number   0 | batch number  900 | generator loss = 0.354879 | discriminator loss = 1.133571
epoch number   0 | batch numbe