# Example of an MNIST GAN with PyTorch

In [None]:
# Mount Google Drive at /content/drive (Colab only).
from google.colab import drive
drive.mount('/content/drive')
# Print the status of the GPU attached to this runtime, if any.
!nvidia-smi

# Project directory on the mounted drive; `mkdir -p` creates it when missing.
project = '/content/drive/MyDrive/mnist-gan'
%mkdir -p {project}

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
bs = 128  # batch size

# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1] so real images share the value range of the generator's tanh
# output (the visualisation code undoes this with (x + 1) / 2).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# The MNIST training split contains 60000 images.
train_dataset = datasets.MNIST(root='.', train=True, transform=transform, download=True)
#test_dataset = datasets.MNIST(root='.', train=False, transform=transform, download=False)

# drop_last=True discards the final short batch so every batch has exactly `bs` samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True, drop_last=True)
#test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False, drop_last=True)

In [None]:
# Length of the latent (noise) vector fed to the generator.
z_dim = 100
# Number of pixels in one flattened image (height * width of a dataset sample).
mnist_dim = train_dataset.data[0].numel()

print(z_dim, mnist_dim, sep='\n')

## Generator

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flattened image.

    The hidden widths are 256 -> 512 -> 1024; the output layer is followed by
    tanh, so every generated value lies in [-1, 1].
    """

    def __init__(self, g_input_dim, g_output_dim):
        """Build the four linear layers.

        Args:
            g_input_dim: size of the latent input vector.
            g_output_dim: size of the flattened output image.
        """
        super().__init__()
        base_width = 256
        self.fc1 = nn.Linear(g_input_dim, base_width)
        self.fc2 = nn.Linear(base_width, base_width * 2)
        self.fc3 = nn.Linear(base_width * 2, base_width * 4)
        self.fc4 = nn.Linear(base_width * 4, g_output_dim)

    def forward(self, x):
        """Return the generated batch for latent batch ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))
    
# Instantiate the generator and move its parameters to the selected device.
G = Generator(z_dim, mnist_dim).to(device)
print(G)

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images as real or fake.

    The hidden widths are 1024 -> 512 -> 256, each followed by LeakyReLU(0.2)
    and dropout(0.3). The single output unit passes through a sigmoid, so the
    result is a probability in [0, 1] suitable for ``nn.BCELoss``.
    """

    def __init__(self, d_input_dim):
        """Build the four linear layers.

        Args:
            d_input_dim: size of one flattened input image.
        """
        super().__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of real-image probabilities for ``x``."""
        # F.dropout defaults to training=True; pass self.training explicitly so
        # dropout is switched off after D.eval() and inference is deterministic.
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

# Instantiate the discriminator and move its parameters to the selected device.
D = Discriminator(d_input_dim=mnist_dim).to(device)
print(D)

## Train

In [None]:
# Binary cross-entropy between predicted probabilities and 0/1 labels.
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the same learning rate.
lr = 1e-5
D_optimizer = optim.Adam(D.parameters(), lr=lr)
G_optimizer = optim.Adam(G.parameters(), lr=lr)

In [None]:
def D_train(x):
    """Run one discriminator update and return its loss.

    The discriminator is trained to output 1 for the real batch ``x`` and 0
    for an equally sized batch of freshly generated images.

    Args:
        x: batch of real images; it is flattened to ``(batch, mnist_dim)``.

    Returns:
        The summed real + fake BCE loss for this step as a Python float.
    """
    D.zero_grad()

    # Size labels and noise from the actual batch rather than the global `bs`,
    # so a short final batch cannot cause a shape mismatch in the loss.
    x_real = x.view(-1, mnist_dim).to(device)
    n_samples = x_real.size(0)
    y_real = torch.ones(n_samples, 1, device=device)
    y_fake = torch.zeros(n_samples, 1, device=device)

    # train discriminator on real
    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake; only D is updated here, so skip building
    # (and later back-propagating through) the generator's graph
    z = torch.randn(n_samples, z_dim, device=device)
    with torch.no_grad():
        x_fake = G(z)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
def G_train(x):
    """Run one generator update and return its loss.

    A batch of ``bs`` latent vectors is pushed through G and then D; the
    generator is rewarded when D labels its output as real (target 1).

    Args:
        x: unused; accepted so the call mirrors ``D_train``.

    Returns:
        The generator's BCE loss for this step as a Python float.
    """
    G.zero_grad()

    noise = torch.randn(bs, z_dim).to(device)
    real_labels = torch.ones(bs, 1).to(device)

    verdict = D(G(noise))
    G_loss = criterion(verdict, real_labels)

    # Back-propagate through D into G, but step ONLY G's optimizer.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [None]:
# Load the TensorBoard notebook extension, which provides the %tensorboard magic.
%load_ext tensorboard

In [None]:
# Make sure the log directory exists, then open TensorBoard on it inside the notebook.
%mkdir -p {project}/runs
%tensorboard --logdir={project}/runs

In [None]:
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import trange

from pathlib import Path
import numpy as np
import re

# Pick the next free experiment number by scanning existing run directories
# whose names start with ``exp<digits>``.
dirpath = Path(f'{project}/runs')
dirpath.mkdir(parents=True, exist_ok=True)  # iterdir() raises on a missing directory
experiments = [0]
for exp in dirpath.iterdir():
    if exp.is_dir():
        # Match on the directory name only: matching the full path could pick
        # up an ``/exp<N>`` component belonging to a parent directory.
        m = re.match(r'exp([0-9]+)', exp.name)
        if m:
            experiments.append(int(m.group(1)))

exp_n = max(experiments) + 1

n_epoch = 300
n_preview = 25  # images logged per epoch, shown as a 5x5 grid
img_h, img_w = train_dataset.data.shape[1:]

# The context manager closes (and thereby flushes) the writer even when
# training is interrupted or raises part-way through.
with SummaryWriter(f'{project}/runs/exp{exp_n:03d}_{bs}_{lr}') as writer:
    for epoch in trange(1, n_epoch + 1):
        D_losses, G_losses = [], []

        # One discriminator step followed by one generator step per batch.
        for x, _ in train_loader:
            D_losses.append(D_train(x))
            G_losses.append(G_train(x))

        # Log the mean loss of each network over the epoch.
        writer.add_scalar('Loss/Generator', float(np.mean(G_losses)), epoch)
        writer.add_scalar('Loss/Discriminator', float(np.mean(D_losses)), epoch)

        # Generate only as many samples as are displayed.
        with torch.no_grad():
            fake_z = torch.randn(n_preview, z_dim).to(device)
            # Map the tanh output from [-1, 1] to [0, 1] for display.
            generated = (G(fake_z) + 1) / 2
        fake_image = make_grid(generated.view(n_preview, 1, img_h, img_w), nrow=5)
        writer.add_image('Fake Image', fake_image, epoch)