In [None]:
import os
import numpy as np
from tqdm import tqdm
from statistics import mean
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import MNIST

In [None]:
def get_device():
    """Pick the best available torch device: CUDA first, then Apple MPS, then CPU.

    Returns:
        torch.device: 'cuda', 'mps', or 'cpu'.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps only exists in newer PyTorch releases (1.12+); look it up
    # defensively so older installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')

In [None]:
# --- Paths ---
DATA_DIR = 'GAN_data'  # where torchvision downloads/caches MNIST
GEN_MODEL_FILE_NAME = 'trained_models/basic_gan_generator.pt'
DISC_MODEL_FILE_NAME = 'trained_models/basic_gan_discriminator.pt'

# --- Training hyperparameters ---
EPOCHS = 100
BATCH_SIZE = 128
LR = 5e-5  # shared Adam learning rate for generator and discriminator
LOSS_FUNC = nn.BCEWithLogitsLoss()  # takes raw discriminator logits (no sigmoid in the model)
LATENT_DIM = 64  # length of each generator noise vector
HIDDEN_DIM = 128  # base width of the generator's hidden layers

# Overwritten by the training loop with each epoch's mean loss.
MEAN_GEN_LOSS = 0
MEAN_DISC_LOSS = 0

# Seed torch's RNGs before any weights, noise, or shuffled batches are drawn,
# so a Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)

DEVICE = get_device()

In [None]:
def show_examples(tensors, size=(28, 28), channels=1, num_images=25):
    """Plot the leading images of a batch as a grid with 5 images per row.

    Args:
        tensors: batch whose first dimension indexes images; each item must hold
            channels * size[0] * size[1] values (flattened or already shaped).
        size: (height, width) of one image.
        channels: number of channels per image.
        num_images: how many images from the front of the batch to draw.
    """
    batch_len = tensors.shape[0]
    images = tensors.detach().cpu().view(batch_len, channels, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid lays the grid out as (C, H, W); matplotlib wants (H, W, C).
    hwc_grid = grid.permute(1, 2, 0)
    plt.imshow(hwc_grid.clip(min=0, max=1))
    plt.axis('off')
    plt.show()

In [None]:
# MNIST train and test splits, downloaded into DATA_DIR on first use and
# converted from images to tensors with ToTensor.
train_data = MNIST(DATA_DIR, train=True, transform=transforms.ToTensor(), download=True)
test_data = MNIST(DATA_DIR, train=False, transform=transforms.ToTensor(), download=True)

In [None]:
# Training batches are reshuffled every epoch; the trailing partial batch is dropped
# so every optimisation step sees exactly BATCH_SIZE real images.
train_dl = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# Held-out split; the trailing partial batch is kept.
test_dl = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

In [None]:
class Generator(nn.Module):
    """MLP generator that maps latent noise vectors to flattened square images.

    Four Linear/BatchNorm1d/ReLU stages widen the representation through
    hidden_dim, 2*hidden_dim, 4*hidden_dim and 8*hidden_dim features. A final Linear
    projects to out_dim * out_dim values, and a Sigmoid squashes them into (0, 1).

    Args:
        latent_dim: length of each input noise vector.
        hidden_dim: width of the first hidden stage; later stages double it.
        out_dim: side length of the (square) output image.
    """

    def __init__(self, latent_dim, hidden_dim, out_dim):
        super().__init__()
        widths = [latent_dim] + [hidden_dim * (2 ** k) for k in range(4)]
        stages = [
            self.generator_block(w_in, w_out)
            for w_in, w_out in zip(widths[:-1], widths[1:])
        ]
        self.generator = nn.Sequential(
            *stages,
            nn.Linear(widths[-1], out_dim * out_dim),
            nn.Sigmoid(),
        )

    def forward(self, noise):
        """Turn a batch of noise rows into a batch of flattened images in (0, 1)."""
        images = self.generator(noise)
        return images

    def generator_block(self, in_features, out_features):
        """One hidden stage: Linear -> BatchNorm1d -> in-place ReLU."""
        layers = [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator that scores flattened images with one logit each.

    Three Linear/LeakyReLU stages narrow the input through 4*hidden_dim,
    2*hidden_dim and hidden_dim features, followed by a Linear layer to one output.
    No sigmoid is applied, so the output is a raw logit.

    Args:
        input_dim: number of values in one flattened input image.
        hidden_dim: width of the narrowest hidden stage.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        dims = [input_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim]
        layers = [
            self.discriminator_block(d_in, d_out)
            for d_in, d_out in zip(dims, dims[1:])
        ]
        # Raw logit head; pair with a logits-based criterion such as BCEWithLogitsLoss.
        layers.append(nn.Linear(dims[-1], 1))
        self.discriminator = nn.Sequential(*layers)

    def forward(self, input_):
        """Return one unnormalised real/fake score per input row."""
        scores = self.discriminator(input_)
        return scores

    def discriminator_block(self, in_features, out_features):
        """One hidden stage: Linear -> in-place LeakyReLU (default slope)."""
        layers = [
            nn.Linear(in_features, out_features),
            nn.LeakyReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

In [None]:
# 28 is the MNIST image side, so the generator emits 28*28 flattened pixels.
gen_model = Generator(latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM, out_dim=28)
gen_model

In [None]:
# Reads flattened 28x28 images; its narrowest hidden stage is 2 * HIDDEN_DIM wide.
disc_model = Discriminator(input_dim=28 * 28, hidden_dim=HIDDEN_DIM * 2)
disc_model

In [None]:
# Move both networks to the chosen device before building optimizers over their parameters.
# Module.to returns the module itself, so rebinding the names is equivalent to an in-place move.
gen_model = gen_model.to(DEVICE)
disc_model = disc_model.to(DEVICE)

# One independent Adam optimizer per network, sharing the same learning rate.
gen_opt = torch.optim.Adam(gen_model.parameters(), lr=LR)
disc_opt = torch.optim.Adam(disc_model.parameters(), lr=LR)

In [None]:
def generate_noise(num_vectors, latent_dim, device=None):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        num_vectors: number of vectors (rows) to draw.
        latent_dim: length of each vector.
        device: where to allocate the result; defaults to the notebook-wide DEVICE.

    Returns:
        Float tensor of shape (num_vectors, latent_dim) on `device`.
    """
    if device is None:
        device = DEVICE
    # Sample directly on the target device rather than on the CPU followed by a copy.
    return torch.randn(num_vectors, latent_dim, device=device)

In [None]:
sample_imgs, sample_labels = next(iter(train_dl))
print(sample_imgs.shape, sample_labels.shape)
# Summarise the first example instead of dumping every pixel value to the output.
print(
    f'first label: {sample_labels[0].item()} | '
    f'pixel range: [{sample_imgs[0].min().item():.3f}, {sample_imgs[0].max().item():.3f}]'
)

In [None]:
show_examples(sample_imgs, (28, 28))  # grid of the leading real MNIST digits in the sample batch

In [None]:
# Preview of the still-untrained generator. This pass is for display only, so
# no_grad skips building an autograd graph for it.
with torch.no_grad():
    noise = generate_noise(BATCH_SIZE, LATENT_DIM)
    gen_fake_images = gen_model(noise)
show_examples(tensors=gen_fake_images)

In [None]:
def calculate_generator_loss(generator_model, discriminator_model, loss_function, LATENT_DIM):
    """Generator loss on one freshly sampled batch of BATCH_SIZE fake images.

    The generator is rewarded when the discriminator labels its fakes as real,
    so every target is 1. Nothing is detached, so gradients reach both networks.

    Args:
        generator_model: maps noise rows of width LATENT_DIM to fake images.
        discriminator_model: returns one logit per image.
        loss_function: criterion called as loss_function(logits, targets).
        LATENT_DIM: noise vector length.

    Returns:
        Scalar loss tensor.
    """
    fake_logits = discriminator_model(generator_model(generate_noise(BATCH_SIZE, LATENT_DIM)))
    real_targets = torch.ones_like(fake_logits)
    return loss_function(fake_logits, real_targets)

def calculate_discriminator_loss(generator_model, discriminator_model, loss_function, LATENT_DIM, real_images):
    """Mean of the discriminator's loss on a fake batch (target 0) and a real batch (target 1).

    Args:
        generator_model: maps noise rows of width LATENT_DIM to fake images.
        discriminator_model: returns one logit per image.
        loss_function: criterion called as loss_function(logits, targets).
        LATENT_DIM: noise vector length (name kept for caller compatibility).
        real_images: real images, one row per image, in the layout the discriminator
            expects; should live on the same device as the models.

    Returns:
        Scalar loss tensor; gradients flow only into the discriminator.
    """
    # Draw as many fakes as there are real images so both halves stay balanced even
    # for a partial batch. The generator is not trained by this loss, so its forward
    # pass runs without autograd tracking. This gives the same gradients as
    # detaching afterwards, without building a graph that is then discarded.
    with torch.no_grad():
        fake_images = generator_model(generate_noise(real_images.shape[0], LATENT_DIM))
    fake_preds = discriminator_model(fake_images)
    fake_loss = loss_function(fake_preds, torch.zeros_like(fake_preds))

    real_preds = discriminator_model(real_images)
    real_loss = loss_function(real_preds, torch.ones_like(real_preds))
    return (fake_loss + real_loss) / 2

In [None]:
calculate_generator_loss(generator_model=gen_model, discriminator_model=disc_model, loss_function=LOSS_FUNC, LATENT_DIM=LATENT_DIM)

In [None]:
test_images = sample_imgs.flatten(start_dim=1).to(DEVICE)  # one row per image for the MLP discriminator
calculate_discriminator_loss(gen_model, disc_model, LOSS_FUNC, LATENT_DIM, real_images=test_images)

In [None]:
for epoch in range(EPOCHS):
    gen_losses = []
    disc_losses = []
    for real_imgs, _ in tqdm(train_dl):
        # Flatten each real image into one row for the MLP discriminator.
        real_flat = real_imgs.view(real_imgs.shape[0], -1).to(DEVICE)

        # 1) Discriminator step: score real images as 1 and fakes as 0.
        disc_opt.zero_grad()
        disc_loss = calculate_discriminator_loss(gen_model, disc_model, LOSS_FUNC, LATENT_DIM, real_flat)
        disc_loss.backward()
        disc_opt.step()

        # 2) Generator step: push the discriminator to score fakes as 1.
        # This backward also writes gradients into disc_model; disc_opt.zero_grad()
        # clears them before the next discriminator step.
        gen_opt.zero_grad()
        gen_loss = calculate_generator_loss(gen_model, disc_model, LOSS_FUNC, LATENT_DIM)
        gen_loss.backward()
        gen_opt.step()

        disc_losses.append(disc_loss.item())
        gen_losses.append(gen_loss.item())

    MEAN_GEN_LOSS = mean(gen_losses)
    MEAN_DISC_LOSS = mean(disc_losses)

    # End-of-epoch preview: fresh fakes next to the last real batch of the epoch.
    # Display only, so no autograd graph is needed.
    with torch.no_grad():
        noise = generate_noise(BATCH_SIZE, LATENT_DIM)
        fake_imgs = gen_model(noise)
    show_examples(fake_imgs)
    show_examples(real_imgs)
    print(f'EPOCH: {epoch} | Mean Gen Loss: {MEAN_GEN_LOSS:.4f} | Mean Disc Loss: {MEAN_DISC_LOSS:.4f}\n\n')

In [None]:
# torch.save does not create missing parent directories, so make sure 'trained_models/'
# (or whatever directory the paths point at) exists first; `or '.'` covers bare filenames.
os.makedirs(os.path.dirname(GEN_MODEL_FILE_NAME) or '.', exist_ok=True)
os.makedirs(os.path.dirname(DISC_MODEL_FILE_NAME) or '.', exist_ok=True)
torch.save(gen_model.state_dict(), GEN_MODEL_FILE_NAME)
torch.save(disc_model.state_dict(), DISC_MODEL_FILE_NAME)