In [None]:
import torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [2]:
def display_gen_pred(imgs, size=(1, 28, 28), value_range=(-1.0, 1.0)):
    """Show up to the first 25 images of ``imgs`` in a grid with 5 images per row.

    Args:
        imgs: image tensor; every ``prod(size)`` elements form one image
            (it is reshaped with ``view(-1, *size)``).
        size: per-image ``(channels, height, width)``.
        value_range: ``(low, high)`` pixel range of ``imgs``. The default matches
            the generator's Tanh output. Pixels are mapped linearly to ``[0, 1]``
            before plotting.
    """
    low, high = value_range
    img_unflat = imgs.detach().cpu().view(-1, *size)
    # make_grid returns float RGB, which matplotlib clips to [0, 1]; without this
    # rescale every negative (Tanh) pixel would be rendered as pure black.
    img_unflat = ((img_unflat - low) / (high - low)).clamp(0.0, 1.0)
    img_grid = make_grid(img_unflat[:25], nrow=5)
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())
    plt.show()
    
def display_loss_plot(gen_loss, disc_loss):
    """Plot generator and critic loss histories on one labelled figure.

    Args:
        gen_loss: sequence of generator loss values, plotted against its index.
        disc_loss: sequence of critic loss values, plotted against its index
            (parameter name kept for backward compatibility).
    """
    plt.plot(range(len(gen_loss)), gen_loss, label='Generator loss')
    plt.plot(range(len(disc_loss)), disc_loss, label='Critic loss')
    # Presumably one value per training batch -- TODO confirm against the caller.
    plt.xlabel('Training step')
    plt.ylabel('Loss')
    plt.title('Generator / critic losses')
    plt.legend()
    plt.show()

# Generator

In [3]:
class Generator(nn.Module):
    """Conditional generator: (noise ++ one-hot label) -> image.

    The input vector is treated as a 1x1 feature map and upsampled by four
    transposed convolutions. With the default kernel/stride schedule the
    spatial size grows 1 -> 3 -> 6 -> 13 -> 28.

    Args:
        z_chan: length of the noise vector.
        n_classes: number of label classes (length of the one-hot code).
        img_chan: channels of the produced image.
        h_chan: base width of the hidden layers.
        device: device on which ``get_noise`` allocates noise.
    """

    def __init__(self, z_chan=10, n_classes=10, img_chan=1, h_chan=64, device='cpu'):
        super().__init__()

        self.z_chan = z_chan
        self.n_classes = n_classes
        self.inp_chan = z_chan + n_classes
        self.device = device

        blocks = [
            self.layer(self.inp_chan, 4 * h_chan),
            self.layer(4 * h_chan, 2 * h_chan, kernel_size=4, stride=1),
            self.layer(2 * h_chan, h_chan),
            self.layer(h_chan, img_chan, kernel_size=4, last_layer=True),
        ]
        self.gen = nn.Sequential(*blocks)

    def layer(self, in_chan, out_chan, kernel_size=3, stride=2, last_layer=False):
        """One upsampling stage: ConvTranspose2d, then BatchNorm2d + ReLU for
        hidden stages or Tanh for the output stage."""
        conv = nn.ConvTranspose2d(in_chan, out_chan, kernel_size, stride)
        if last_layer:
            return nn.Sequential(conv, nn.Tanh())
        return nn.Sequential(conv, nn.BatchNorm2d(out_chan), nn.ReLU())

    def forward(self, inp):
        """Reshape each sample of ``inp`` to ``(inp_chan, 1, 1)`` and decode it."""
        as_map = inp.view(len(inp), self.inp_chan, 1, 1)
        return self.gen(as_map)

    def get_imgs(self, labels):
        """Generate one image per entry of ``labels`` (integer class ids)."""
        return self(self.add_labels_to_noise(self.get_noise(len(labels)), labels))

    def add_labels_to_noise(self, noise, labels):
        """Append the one-hot encoding of ``labels`` to ``noise`` along dim 1."""
        one_hot = nn.functional.one_hot(labels, self.n_classes).float()
        return torch.cat([noise, one_hot], dim=1)

    def get_noise(self, n_samples):
        """Draw ``n_samples`` standard-normal noise vectors of length ``z_chan``."""
        return torch.randn((n_samples, self.z_chan), device=self.device)

# Critic

In [4]:
class Critic(nn.Module):
    """Conditional WGAN-GP critic.

    ``forward`` expects images that already carry one constant channel per class
    (see ``add_labels_to_imgs``), i.e. ``img_chan + n_classes`` input channels.
    Three strided convolutions shrink a 28x28 input to 13x13 -> 5x5 -> 1x1, and
    the last layer has no activation, so each sample gets one unbounded score.

    Args:
        img_chan: channels of the raw image.
        n_classes: number of label classes.
        h_chan: base width of the hidden layers.
        device: device on which the gradient-penalty mixing weights are drawn.
        norm_layer: callable ``norm_layer(num_channels) -> nn.Module`` applied after
            each hidden convolution. Defaults to ``nn.BatchNorm2d`` (original
            behaviour). NOTE(review): the WGAN-GP paper advises against batch norm
            in the critic, because it couples samples in a batch while the penalty
            assumes per-sample gradients. ``lambda c: nn.InstanceNorm2d(c, affine=True)``
            is a common replacement.
    """

    def __init__(self, img_chan=1, n_classes=10, h_chan=16, device='cpu', norm_layer=nn.BatchNorm2d):
        super().__init__()

        self.device = device
        self.n_classes = n_classes
        self.inp_chan = img_chan + n_classes
        # Must be set before ``layer`` is called below.
        self.norm_layer = norm_layer

        self.crit = nn.Sequential(
            self.layer(self.inp_chan, h_chan),
            self.layer(h_chan, h_chan*2),
            self.layer(h_chan*2, 1, last_layer=True),
        )

    def get_gradient(self, real_img, fake_img, epsilon):
        """Gradient of the critic score w.r.t. an interpolation of real and fake.

        Args:
            real_img, fake_img: critic inputs of identical shape.
            epsilon: mixing weights broadcastable to the inputs
                (``epsilon * real + (1 - epsilon) * fake``).

        Returns:
            Tensor shaped like the inputs. It is built with ``create_graph=True``,
            so a penalty on it can be back-propagated into the critic's parameters.
        """
        mixed_img = epsilon*real_img + (1-epsilon)*fake_img
        # The gradient is taken w.r.t. the interpolated input itself, so it must be
        # in the autograd graph even when real/fake/epsilon are all constants.
        if not mixed_img.requires_grad:
            mixed_img.requires_grad_(True)
        score = self(mixed_img)

        gradient = torch.autograd.grad(
            inputs=mixed_img,
            outputs=score,
            grad_outputs=torch.ones_like(score),
            create_graph=True,
            retain_graph=True,
        )[0]

        return gradient

    def get_gradient_penalty(self, real_img, fake_img):
        """WGAN-GP penalty: batch mean of ``(||grad||_2 - 1) ** 2``.

        One mixing weight is drawn per sample, uniform in [0, 1).
        """
        # Shape (N, 1, ..., 1) so epsilon broadcasts over every non-batch dim.
        eps_shape = (len(real_img),) + (1,) * (real_img.dim() - 1)
        # No requires_grad here: get_gradient attaches the mixed input to the graph.
        epsilon = torch.rand(eps_shape, device=self.device)

        gradient = self.get_gradient(real_img, fake_img, epsilon)
        gradient_norm = gradient.reshape(len(gradient), -1).norm(2, dim=1)
        return torch.mean((gradient_norm - 1) ** 2)

    # Backward-compatible alias for the original (misspelled) method name.
    get_gradient_panelty = get_gradient_penalty

    def layer(self, in_chan, out_chan, kernel_size=4, stride=2, last_layer=False):
        """One downsampling stage: Conv2d, then norm + LeakyReLU(0.2) for hidden
        stages. The output stage is a bare Conv2d (unbounded score)."""
        if not last_layer:
            norm = getattr(self, 'norm_layer', nn.BatchNorm2d)
            return nn.Sequential(
                nn.Conv2d(in_chan, out_chan, kernel_size, stride),
                norm(out_chan),
                nn.LeakyReLU(negative_slope=0.2),
            )
        return nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel_size, stride),
        )

    def forward(self, img):
        """Score each sample; returns shape ``(N, -1)`` (``(N, 1)`` for 28x28 input)."""
        crit_pred = self.crit(img)
        return crit_pred.view(len(crit_pred), -1)

    def add_labels_to_imgs(self, imgs, labels):
        """Append one constant (0/1) feature map per class to 4-D ``imgs``.

        Args:
            imgs: tensor indexed as ``(N, C, H, W)``.
            labels: integer class ids, one per image.

        Returns:
            Tensor of shape ``(N, C + n_classes, H, W)``.
        """
        enc_labels = nn.functional.one_hot(labels, self.n_classes).float()
        # (N, n_classes) -> (N, n_classes, H, W); expand is a view, torch.cat copies it.
        enc_labels = enc_labels[:, :, None, None].expand(-1, -1, imgs.shape[2], imgs.shape[3])
        return torch.cat((imgs, enc_labels), dim=1)

# Data Loading & Initialization

In [5]:
# Hyper-parameters and runtime configuration, collected in one place.
Z_DIM=64             # length of the generator's noise vector
BATCH_SIZE=128
BETAS=(0.5, 0.999)   # Adam betas, shared by both optimizers
LR=0.0002
C_LAMBDA=10          # weight of the gradient penalty in the critic loss
N_EPOCHS=50
CRIT_REPEAT = 5      # critic updates per generator update
DISPLAY_STEP=5       # epochs between progress plots
DEVICE='cpu'         # e.g. 'cuda' to train on a GPU
SEED=0

# Seed torch's RNGs so later weight initialisation, noise sampling and
# DataLoader shuffling are reproducible across Restart & Run All.
torch.manual_seed(SEED)

In [6]:
# MNIST pixels: ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

dataloader = DataLoader(
    dataset=MNIST(root='.', download=True, transform=transform),
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [7]:
# Build both networks on DEVICE (generator first), each with its own Adam optimizer.
gen, crit = (
    Generator(z_chan=Z_DIM, device=DEVICE).to(DEVICE),
    Critic(device=DEVICE).to(DEVICE),
)
gen_opt, crit_opt = (
    torch.optim.Adam(params=gen.parameters(), lr=LR, betas=BETAS),
    torch.optim.Adam(params=crit.parameters(), lr=LR, betas=BETAS),
)

def weights_init(m):
    """DCGAN-style parameter initialisation, meant to be passed to ``Module.apply``.

    - Conv2d / ConvTranspose2d weights ~ N(0, 0.02); their biases are left untouched.
    - BatchNorm2d scale ~ N(1, 0.02) and shift = 0, so each BN layer starts out
      passing its normalised activations through roughly unchanged. The previous
      N(0, 0.02) scale shrank BN outputs to about 0 and flipped the sign of about
      half the channels.
    - Any other module is left unchanged.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

# Re-initialise both networks in place (Module.apply returns the module itself).
gen, crit = gen.apply(weights_init), crit.apply(weights_init)

# Training

In [None]:
# Loss histories that train() appends to; re-running this cell resets them.
crit_losses, gen_losses = [], []

In [9]:
def train(dataloader, n_epochs, crit_repeat, c_lambda, display_step, crit_repeat_epochs=10):
    """Train the conditional WGAN-GP held in the notebook globals.

    For each batch the critic is updated ``n_crit`` times with
    ``mean(D(fake)) - mean(D(real)) + c_lambda * GP``, then the generator is
    updated once with ``-mean(D(fake))``. The mean critic loss and the generator
    loss of each batch are appended to ``crit_losses`` / ``gen_losses``.

    Args:
        dataloader: iterable of ``(real_imgs, labels)`` batches.
        n_epochs: number of passes over ``dataloader``.
        crit_repeat: critic updates per generator update during the first
            ``crit_repeat_epochs`` epochs.
        c_lambda: weight of the gradient penalty.
        display_step: show loss curves and samples after every ``display_step``
            completed epochs, and after the final epoch.
        crit_repeat_epochs: from this epoch index on the critic is updated only
            once per generator update. ``None`` keeps ``crit_repeat`` throughout.
            The default 10 reproduces the previously hard-coded schedule.

    Note:
        Reads and mutates the globals ``gen``, ``crit``, ``gen_opt``, ``crit_opt``,
        ``DEVICE``, ``crit_losses`` and ``gen_losses``.
    """
    tqdm_obj = tqdm(range(n_epochs))
    no_of_batches = len(dataloader)

    for epoch in tqdm_obj:
        n_crit = crit_repeat
        if crit_repeat_epochs is not None and epoch >= crit_repeat_epochs:
            n_crit = 1

        for i, (real_imgs, labels) in enumerate(dataloader):
            tqdm_obj.set_postfix({'Batch': f'{i}/{no_of_batches}'})
            real_imgs = real_imgs.to(DEVICE)
            labels = labels.to(DEVICE)
            # Does not depend on critic parameters, so build it once per batch.
            real_labeled_imgs = crit.add_labels_to_imgs(real_imgs, labels)

            # --- Critic training ---
            crit_loss_sum = 0.0
            for _ in range(n_crit):
                crit_opt.zero_grad()
                # The generator is not updated in this step, so don't build its graph.
                with torch.no_grad():
                    fake_imgs = gen.get_imgs(labels)
                fake_labeled_imgs = crit.add_labels_to_imgs(fake_imgs, labels)
                fake_pred = crit(fake_labeled_imgs)
                real_pred = crit(real_labeled_imgs)
                gp = crit.get_gradient_panelty(real_labeled_imgs, fake_labeled_imgs)
                crit_loss = (torch.mean(fake_pred) - torch.mean(real_pred)) + c_lambda*gp
                # Every iteration builds a fresh graph, so there is nothing to retain.
                crit_loss.backward()
                crit_opt.step()
                crit_loss_sum += crit_loss.item()
            crit_losses.append(crit_loss_sum / n_crit)

            # --- Generator training ---
            gen_opt.zero_grad()
            fake_imgs = gen.get_imgs(labels)
            crit_pred = crit(crit.add_labels_to_imgs(fake_imgs, labels))
            gen_loss = -torch.mean(crit_pred)
            gen_loss.backward()
            gen_opt.step()
            gen_losses.append(gen_loss.item())

        epochs_done = epoch + 1
        if epochs_done % display_step == 0 or epochs_done == n_epochs:
            print(f'After {epochs_done} epochs')
            display_loss_plot(gen_losses, crit_losses)
            with torch.no_grad():
                display_gen_pred(gen.get_imgs(labels))

In [10]:
train(
    dataloader,
    n_epochs=N_EPOCHS,
    crit_repeat=CRIT_REPEAT,
    c_lambda=C_LAMBDA,
    display_step=DISPLAY_STEP,
)

# Testing

In [None]:
# Generate one sample for every digit class 0-9 with the trained generator.
gen = gen.eval()  # BatchNorm now uses running statistics instead of batch statistics
with torch.no_grad():
    digit_labels = torch.arange(10, device=DEVICE)  # int64 class ids, as one_hot requires
    imgs = gen.get_imgs(digit_labels)
    display_gen_pred(imgs)