# Lab 03 : GAN with CNN - exercise

The goal is to implement a GAN architecture with CNNs to generate new MNIST images.<br>

In [None]:
# For Google Colaboratory
import sys, os
if 'google.colab' in sys.modules:
    # mount google drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    path_to_file = '/content/gdrive/My Drive/CS4243_codes/codes/labs_lecture15/lab03_GAN_CNN'
    print(path_to_file)
    # move to Google Drive directory
    os.chdir(path_to_file)
    !pwd

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import utils
import time

### A GPU is required to train the GAN

In [None]:
# Prefer the GPU. When CUDA is not available, fall back to the CPU with an
# explicit notice instead of crashing later at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    print("CUDA is not available: falling back to CPU, GAN training will be very slow")
print(device)

In [None]:
# Libraries
import matplotlib.pyplot as plt
import logging
# Raise the root logger threshold to CRITICAL so that lower-severity
# log records (e.g. WARNING) are no longer emitted.
root_logger = logging.getLogger()
root_logger.setLevel(logging.CRITICAL)

### MNIST dataset 

In [None]:
from utils import check_mnist_dataset_exists
# check_mnist_dataset_exists() comes from utils; its return value is used
# below as the path prefix of the MNIST tensor files.
data_path = check_mnist_dataset_exists()
mnist_dir = data_path + 'mnist/'

# Training images and their labels, stored as serialized tensors.
train_data = torch.load(mnist_dir + 'train_data.pt')
train_label = torch.load(mnist_dir + 'train_label.pt')
print(train_data.size())

### Network architecture

In [None]:
# Global sizes shared by the cells below
n = train_data.shape[1]  # n  : nb of pixels along each spatial dimension
dz = n                   # dz : latent dimension (set equal to n)
d = 128                  # d  : hidden dimension
b = 64                   # b  : batch size


In [None]:
# Define the generator and discriminator networks
# Generator network of the GAN (exercise skeleton).
# The assignments below are intentionally left without a right-hand side:
# this cell raises a SyntaxError until every placeholder is filled in.
# At the call sites in this notebook, forward() receives a latent batch z
# annotated as [b, dz] and its result is annotated as [b, 1, n, n].
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        # COMPLETE HERE
        # NOTE(review): the attribute names suggest fully connected layers
        # (linear*), batch normalisations (bn*) and transposed convolutions
        # (tconv*) -- presumably built from dz, d and n; confirm against the
        # lab instructions.
        self.linear1 = 
        self.bn1 = 
        self.linear2 = 
        self.bn2 = 
        self.tconv1 = 
        self.bn3 = 
        self.tconv2 = 
    def forward(self, z): 
        # COMPLETE HERE
        # g_z is the generated batch returned to the caller; it is passed
        # directly to the discriminator and later reshaped with view(b, n, n).
        g_z = 
        return g_z

# Discriminator network of the GAN (exercise skeleton).
# The assignments below are intentionally left without a right-hand side:
# this cell raises a SyntaxError until every placeholder is filled in.
# At the call sites in this notebook, forward() receives image batches
# annotated as [b, 1, n, n] and its result is annotated as [b, 1].
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        # COMPLETE HERE
        # NOTE(review): the attribute names suggest convolutions (conv*),
        # batch normalisations (bn*) and fully connected layers (linear*) --
        # confirm against the lab instructions.
        self.conv1 = 
        self.bn1 = 
        self.conv2 = 
        self.bn2 = 
        self.linear1 = 
        self.bn3 = 
        self.linear2 = 
    def forward(self, h): 
        # COMPLETE HERE
        # d_h is fed to nn.BCELoss against all-ones / all-zeros targets of
        # shape [b, 1], so it has to hold probabilities in [0, 1].
        d_h = 
        return d_h

# Instantiate the networks
net_g = generator()
net_g = net_g.to(device)
print(net_g)
utils.display_num_param(net_g) 
net_d = discriminator()
net_d = net_d.to(device)
print(net_d)
utils.display_num_param(net_d) 

# Smoke test: one forward pass, backward pass and optimizer step per network
# on a single small batch, to catch shape errors before the long training run.
init_lr = 0.001
optimizer_g = torch.optim.Adam(net_g.parameters(), lr=init_lr)
optimizer_d = torch.optim.Adam(net_d.parameters(), lr=init_lr)
bce_loss = nn.BCELoss()  # stateless criterion: build it once and reuse it

b = 10  # small batch for this test only (the training cell sets b again)
# Sample indices over the actual dataset size instead of a hard-coded 60000.
idx = torch.randint(0, train_data.size(0), (b,))
x_real = train_data[idx,:,:].view(b,-1).to(device) # [b, n**2]
print(x_real.size())

z = torch.rand(b, dz, device=device) # [b, dz]
print(z.size())

# Targets for the binary cross-entropy: 1 = "real", 0 = "fake".
p_one = torch.ones(b, 1, device=device)
p_zero = torch.zeros(b, 1, device=device)

# update g: the generator is trained to make the discriminator answer "real"
optimizer_g.zero_grad()
x_fake = net_g(z) # [b, 1, n, n]
p_fake = net_d(x_fake) # [b, 1]
print(x_fake.size(), p_fake.size())
assert x_fake.size() == (b, 1, n, n), "generator output must be [b, 1, n, n]"
assert p_fake.size() == (b, 1), "discriminator output must be [b, 1]"
loss_fake = bce_loss(p_fake, p_one)
loss = loss_fake
loss.backward()
optimizer_g.step()

# update d: real images are labelled 1, generated images are labelled 0
optimizer_d.zero_grad()
# The generator is not updated in this step, so its forward pass does not
# need an autograd graph; this avoids back-propagating through net_g.
with torch.no_grad():
    x_fake = net_g(z) # [b, 1, n, n]
p_fake = net_d(x_fake) # [b, 1]
p_real = net_d(x_real.view(-1,n,n).unsqueeze(1)) # [b, 1]
print(x_fake.size(), p_fake.size(), p_real.size())
assert p_real.size() == (b, 1), "discriminator output must be [b, 1]"
loss_real = bce_loss(p_real, p_one)
loss_fake = bce_loss(p_fake, p_zero)
loss = loss_real + loss_fake
loss.backward()
optimizer_d.step()


In [None]:
# Training loop
# Start from freshly initialised networks.
net_g = generator()
net_g = net_g.to(device)
print(net_g)
utils.display_num_param(net_g) 
net_d = discriminator()
net_d = net_d.to(device)
print(net_d)
utils.display_num_param(net_d) 

# Optimizer
init_lr = 0.0002
optimizer_g = torch.optim.Adam(net_g.parameters(), lr=init_lr, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(net_d.parameters(), lr=init_lr, betas=(0.5, 0.999))
bce_loss = nn.BCELoss()  # stateless criterion: build it once, not four times per iteration

nb_epochs = 50  # Nb of epochs
nb_batch = 200 # GPU # Nb of mini-batches per epoch
b = 64  # Batch size
nb_train = train_data.size(0)  # Nb of training images (instead of a hard-coded 60000)

# Targets for the binary cross-entropy: 1 = "real", 0 = "fake".
p_one = torch.ones(b, 1, device=device)
p_zero = torch.zeros(b, 1, device=device)

start=time.time()
for epoch in range(nb_epochs):

    running_loss_d = 0.0
    running_loss_g = 0.0
    num_batches = 0

    for _ in range(nb_batch):

        # FORWARD AND BACKWARD PASS
        # Random mini-batch of real images and a batch of latent vectors.
        idx = torch.randint(0, nb_train, (b,))
        x_real = train_data[idx,:,:].view(b,-1).to(device) # [b, n**2]
        z = torch.rand(b, dz, device=device) # Uniform distribution # [b, dz]

        # update d: real images are labelled 1, generated images are labelled 0
        optimizer_d.zero_grad()
        # The generator is not updated in this step, so its forward pass does
        # not need an autograd graph; this avoids back-propagating through
        # net_g only to throw those gradients away.
        with torch.no_grad():
            x_fake = net_g(z) # [b, 1, n, n]
        p_fake = net_d(x_fake) # [b, 1]
        p_real = net_d(x_real.view(-1,n,n).unsqueeze(1)) # [b, 1]
        loss_real = bce_loss(p_real, p_one)
        loss_fake = bce_loss(p_fake, p_zero)
        loss = loss_real + loss_fake
        loss_d = loss.item()
        loss.backward()
        optimizer_d.step()

        # update g: the generator is trained to make the discriminator answer "real"
        optimizer_g.zero_grad()
        x_fake = net_g(z) # [b, 1, n, n]
        p_fake = net_d(x_fake) # [b, 1]
        loss_fake = bce_loss(p_fake, p_one)
        loss = loss_fake
        loss_g = loss.item()
        loss.backward()
        optimizer_g.step()

        # COMPUTE STATS
        running_loss_d += loss_d
        running_loss_g += loss_g
        num_batches += 1

    # AVERAGE STATS THEN DISPLAY
    total_loss_d = running_loss_d/ num_batches
    total_loss_g = running_loss_g/ num_batches
    elapsed = (time.time()-start)/60
    print('epoch=',epoch, '\t time=', elapsed,'min', '\t lr=', init_lr  ,'\t loss_d=', total_loss_d ,'\t loss_g=', total_loss_g )

    # Every 5 epochs, show the first image of the last generated batch.
    if epoch % 5 == 0:
        plt.imshow(x_fake.detach().view(b,n,n)[0].cpu(), cmap='gray'); plt.show() 
    

In [None]:
# Generate a few images
b = 10
z = torch.rand(b, dz, device=device) # Uniform distribution, as during training
# Sampling only: no gradients are needed, so skip building the autograd graph.
# NOTE(review): net_g is left in training mode, as before. If it contains
# batch-norm layers, the samples depend on the statistics of this batch --
# confirm whether net_g.eval() is wanted here.
with torch.no_grad():
    x_new = net_g(z).view(b,n,n).cpu()
for k in range(b):
    plt.imshow(x_new[k], cmap='gray'); plt.show() 
  