<a href="https://colab.research.google.com/github/taznica/ComputerVision_Assignments/blob/main/assignment3_mnist_gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import torchvision

# Training a GAN takes a long time, so snapshots and the final model are saved
# to Google Drive. Mount the drive first.
from google.colab import drive
drive.mount('/content/gdrive/')

# Uncomment to check that the mount worked:
# !ls /content/gdrive/'My Drive'

# --- Hyper-parameters ---------------------------------------------------
epo_size = 200       # number of training epochs
bch_size = 100       # mini-batch size
base_lr = 0.0001     # Adam learning rate used for both networks
mnist_dim = 28 * 28  # flattened MNIST image size (= 784; images are 28x28)
z_dim = 100          # length of the random latent vector fed to the Generator
save_root = "/content/gdrive/My Drive/Computer Vision/"  # folder for saved models

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1], the same range as the Generator's tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# --- Data -----------------------------------------------------------------
mnist_root = './mnist_data/'
train_dataset = datasets.MNIST(root=mnist_root, train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=mnist_root, train=False, transform=transform, download=False)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bch_size, shuffle=False)

In [None]:
# Define the two networks
class Generator(nn.Module):
    """MLP generator mapping a latent vector to a flattened image.

    Layer widths are g_input_dim -> 256 -> 512 -> 1024 -> g_output_dim, with
    LeakyReLU(0.2) after each hidden layer and tanh on the output, so every
    output value lies in [-1, 1].
    """

    def __init__(self, g_input_dim=100, g_output_dim=784):
        super().__init__()
        width = 256
        # The attribute names fc1..fc4 are the state_dict keys used when saving/loading.
        self.fc1 = nn.Linear(g_input_dim, width)
        self.fc2 = nn.Linear(width, width * 2)
        self.fc3 = nn.Linear(width * 2, width * 4)
        self.fc4 = nn.Linear(width * 4, g_output_dim)

    def forward(self, x):
        """Map latent input ``x`` to tanh-activated outputs of size g_output_dim."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden_layer(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    """MLP discriminator giving the probability that a flattened image is real.

    Layer widths are d_input_dim -> 1024 -> 512 -> 256 -> 1, with LeakyReLU(0.2)
    and dropout after each hidden layer and a sigmoid output, so the result is a
    probability suitable for ``nn.BCELoss``.

    Args:
        d_input_dim: size of each flattened input image.
        dropout_p: dropout probability after each hidden layer (default 0.3).
    """

    def __init__(self, d_input_dim=784, dropout_p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features // 2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features // 2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
        self.dropout_p = dropout_p

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities in (0, 1)."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden_layer(x), 0.2)
            # F.dropout defaults to training=True; tie it to the module mode so
            # dropout is disabled after .eval() and inference is deterministic.
            x = F.dropout(x, self.dropout_p, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:
# Initialize a Generator and a Discriminator. 
# Build both networks on the GPU when one is available, otherwise on the CPU,
# instead of unconditionally calling .cuda() (which fails on CPU-only runtimes).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

In [None]:
# Loss func. BCELoss means Binary Cross Entropy Loss.
# Binary cross-entropy loss: D outputs a sigmoid probability and the targets
# are 1 (real) / 0 (fake).
criterion = nn.BCELoss()

# One Adam optimizer per network, so each training step only moves its own weights.
G_optimizer = optim.Adam(params=G.parameters(), lr=base_lr)
D_optimizer = optim.Adam(params=D.parameters(), lr=base_lr)

In [None]:
# Code for training the discriminator.
def D_train(x, D_optimizer):
    """Run one discriminator update on a real batch plus an equal-size fake batch.

    Uses the module-level ``D``, ``G``, ``criterion``, ``mnist_dim`` and ``z_dim``.
    Real images are labelled 1 and G's samples 0; only the parameters held by
    ``D_optimizer`` are stepped.

    Args:
        x: batch of real images; its first dimension is the batch size and it
            is flattened to (batch, mnist_dim).
        D_optimizer: optimizer over D's parameters.

    Returns:
        The summed real + fake BCE loss as a Python float.
    """
    D_optimizer.zero_grad()
    device = next(D.parameters()).device  # follow D's device instead of hard-coding .cuda()
    b = x.size(0)

    # Real images -> target 1.
    x_real = x.view(-1, mnist_dim).to(device)
    y_real = torch.ones(b, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    # Fake images -> target 0. G's autograd graph is not needed for D's step.
    z = torch.randn(b, z_dim, device=device)
    with torch.no_grad():
        x_fake = G(z)
    y_fake = torch.zeros(b, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # Only D's parameters are updated here.
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [None]:
# Code for training the generator
def G_train(bch_size, z_dim, G_optimizer):
    """Run one generator update that pushes D to label fresh samples as real.

    Uses the module-level ``G``, ``D`` and ``criterion``.

    Args:
        bch_size: number of latent vectors to sample.
        z_dim: length of each latent vector.
        G_optimizer: optimizer over G's parameters; the only one stepped here.

    Returns:
        (loss, samples): the generator BCE loss as a Python float and the
        generated batch (still attached to the autograd graph).
    """
    G_optimizer.zero_grad()
    device = next(G.parameters()).device  # follow G's device instead of hard-coding .cuda()

    z = torch.randn(bch_size, z_dim, device=device)
    y = torch.ones(bch_size, 1, device=device)  # "real" targets: fool the discriminator

    G_output = G(z)
    G_loss = criterion(D(G_output), y)

    # backward() also writes gradients into D's parameters, but only G is
    # stepped here; D_train zeroes D's gradients before its own update.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item(), G_output

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display

def Logging(images, G_loss, D_loss):
    """Redraw the notebook output with the loss curves and a grid of samples.

    Args:
        images: batch of generated images to show, 10 per row. Expected to be
            (N, 1, 28, 28) in the [-1, 1] range of G's tanh output (that is how
            the training loop prepares them) — TODO confirm for other callers.
        G_loss: per-epoch generator losses to plot.
        D_loss: per-epoch discriminator losses to plot.
    """
    clear_output(wait=True)

    epochs = range(1, len(G_loss) + 1)  # epochs are numbered from 1
    fig, ax = plt.subplots()
    ax.plot(epochs, G_loss, label='G_loss')
    ax.plot(epochs, D_loss, label='D_loss')
    ax.legend(loc='upper right', shadow=True, fontsize='x-large')
    ax.grid(linestyle='-')
    ax.set_title("Training loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    plt.show()

    # Show the images passed in, mapped from [-1, 1] to [0, 1] so imshow does
    # not clip the negative half of the range.
    images = (images.detach().cpu() + 1) / 2
    show_imgs = torchvision.utils.make_grid(images, nrow=10).numpy().transpose((1, 2, 0))
    plt.imshow(show_imgs)
    plt.show()
        
    

In [None]:
D_epoch_losses, G_epoch_losses = [], []   # average loss of each epoch, for plotting

# The checkpoint paths never change, so build them once. Defining them before
# the loop also keeps them available to later cells when epo_size == 0
# (e.g. to skip training and only load a saved model).
save_pth_G = save_root + 'G_model.pt'
save_pth_D = save_root + 'D_model.pt'

for epoch in range(1, epo_size + 1):
    # A later cell calls G.eval(); put both networks back in training mode in
    # case this cell is rerun afterwards.
    G.train()
    D.train()

    D_losses, G_losses = [], []
    for x, _ in train_loader:
        # One discriminator step, then one generator step, per batch.
        D_losses.append(D_train(x, D_optimizer))
        G_loss, G_output = G_train(bch_size, z_dim, G_optimizer)
        G_losses.append(G_loss)

    # Record plain-float epoch averages for logging.
    D_epoch_loss = sum(D_losses) / len(D_losses)
    G_epoch_loss = sum(G_losses) / len(G_losses)
    D_epoch_losses.append(D_epoch_loss)
    G_epoch_losses.append(G_epoch_loss)

    # Turn the last generated batch into (N, 1, 28, 28) images for display.
    G_output = G_output.detach().cpu().view(-1, 1, 28, 28)

    Logging(G_output, G_epoch_losses, D_epoch_losses)
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, epo_size, D_epoch_loss, G_epoch_loss))

    # Overwrite the checkpoints every epoch so a snapshot exists if training
    # is interrupted.
    torch.save(G.state_dict(), save_pth_G)
    torch.save(D.state_dict(), save_pth_D)
print("Training is finished.")

In [None]:
# Let's evaluate (have fun with) the generative model!
import os

# Load the pretrained Generator weights. The path is rebuilt here so this cell
# also works in a fresh session where the training cell was not run.
save_pth_G = save_root + 'G_model.pt'
if not os.path.isfile(save_pth_G):
    raise FileNotFoundError(f"No saved Generator found at {save_pth_G!r}")

# map_location puts the loaded tensors on whatever device G currently lives on.
G.load_state_dict(torch.load(save_pth_G, map_location=next(G.parameters()).device))
G.eval()

In [None]:
# Generate digit images from fixed points in the Generator's latent space.
# Have fun!
#
# Fresh random sampling (below) is commented out; the cell uses the hard-coded
# latent vectors z0, z1, z6 and z9 instead, so results are reproducible and can
# be interpolated. Each list holds 100 values, matching z_dim.
# NOTE(review): they were presumably copied from `print(torch.randn(1, z_dim))`
# output, and the names presumably refer to the digit each one produced —
# TODO confirm.
#
# To sample a random latent vector instead, uncomment:
# z = Variable(torch.randn(1, z_dim)).cuda()
# print(torch.randn(1, z_dim))

# Fixed latent vector z0 (first endpoint of the interpolation below).
z0 = [ 0.0670,  0.4246,  0.8163,  0.0165,  0.7616,  2.2575,  0.3409, -2.2294,
         -1.4312,  1.2157, -0.3054,  1.6559,  1.1967, -1.2935, -0.7403, -2.4972,
         -1.3574, -0.4469, -0.2014, -1.4115, -0.1373,  0.4235, -0.6672,  0.1750,
          0.5466, -0.1026,  0.1705, -0.7856, -0.8666, -1.4930,  0.3237,  2.3630,
         -1.1341, -0.5896, -1.5767,  0.0465,  1.3691, -1.2543, -1.3945, -1.2385,
          2.4788, -1.0771,  1.5651,  0.3856, -0.1555,  0.1183,  1.0716,  0.9131,
         -0.5365,  1.1387,  1.3006,  1.1782, -0.1185, -0.1885,  0.6651, -0.0319,
          2.6813, -1.9735,  0.1165, -1.9463,  0.0412, -0.8545,  0.7815,  0.8114,
          0.8538,  0.4223, -0.2408, -1.1178, -1.1417,  2.0656,  0.9076,  0.6768,
          0.2702, -0.2271,  0.8220, -0.7479, -0.2487, -2.0722,  1.3153, -0.0199,
          1.6692, -0.4006, -0.9921, -0.3804, -0.8372,  1.0963,  0.2151,  1.0713,
          0.4092, -0.7751,  0.6467,  0.0751,  0.5303, -1.8843,  1.0279,  0.2468,
         -0.3287,  1.1260, -1.1273, -0.2366]

# Fixed latent vector z1 (second endpoint of the interpolation below).
z1 = [ 0.0765, -0.4623, -1.4586, -0.2161, -0.3548,  1.4146, -1.0551, -0.7890,
         -0.5789, -0.4960, -0.5829, -2.1166,  0.2552, -1.9895,  0.9629,  0.6168,
          0.3311,  0.1782, -0.3973, -0.9166, -0.5615, -0.1346,  0.3775,  0.4320,
         -0.0350,  0.1914,  0.8589, -0.9403, -0.1992,  0.1560, -0.1421, -2.5100,
         -0.3334,  0.8744, -0.4826,  1.0336,  0.3423,  2.0692,  1.0377, -0.3209,
          0.2176,  0.6150, -1.1691, -0.3906,  1.4453,  0.7056,  1.9323, -0.5974,
         -1.5143, -0.3021,  1.4453,  2.8274, -1.5002,  0.2371, -0.1654,  0.3807,
          0.6270,  0.0617, -0.7433, -0.3248,  1.5076,  0.2362,  0.7952,  0.4909,
         -1.6238,  0.2728, -0.5981,  1.1826, -0.4085, -0.6560, -0.4504, -1.1603,
         -0.3098, -0.6432, -2.3294,  0.2107, -0.1927,  0.5924, -0.1989,  0.8946,
          0.3909, -1.3076, -0.7229, -0.6188, -0.2572, -0.5444, -0.1203,  1.0342,
         -0.4790, -0.1645,  1.3566,  0.1919,  0.2689, -0.6553,  1.3844, -0.2718,
         -0.0985, -0.1386, -0.4254, -1.1643]

# Fixed latent vector z6 (not used by the code in this cell).
z6 = [-0.5175, -0.1774, -1.7525,  1.5123, -0.7054,  1.1917, -1.5286, -1.0867,
          0.9565,  0.6768,  0.1136,  0.9018,  0.3761, -1.0810,  0.5448, -1.5261,
          0.9193,  1.3252,  0.5495,  1.0456,  1.2668, -1.2595, -0.8136, -0.7019,
          0.8347, -1.2693,  1.5457,  1.6564, -0.7748, -0.3196,  2.1886,  1.5403,
         -0.4817, -1.5949,  2.9807,  0.3579, -1.0033,  0.2182,  2.3000, -0.9246,
          0.2465, -0.9418,  1.1661, -0.4932,  0.2754, -0.5000, -0.2833, -0.4199,
          0.1732,  0.9455, -0.3853,  0.0548,  0.7936,  0.2142,  1.8438, -0.1598,
          1.1431,  1.2672, -0.2559,  1.0985, -0.3706,  0.6408, -0.0689,  0.2470,
         -0.3579,  0.3868,  1.7309,  0.4774,  0.5839,  0.6602, -0.7842, -1.2809,
          1.4804, -0.6044, -0.0341,  0.0410, -0.3965,  1.2379,  0.3320,  0.4810,
          1.0857, -0.8209,  0.6416, -0.4204, -0.3886,  0.5497,  0.1795, -0.9070,
         -0.0452, -1.0870, -0.3837, -0.5121,  0.8246, -0.1282,  0.0889, -1.1181,
          2.0394,  0.0564, -1.6174, -1.0564]

# Fixed latent vector z9 (not used by the code in this cell).
z9 = [ 1.2116,  0.2228,  1.5141,  0.5140,  0.0359,  0.7169,  0.4445, -0.9991,
          0.2119,  1.2045, -0.8397, -1.0614,  0.9099, -0.5794, -1.7643, -0.0369,
         -1.5591, -1.8883, -1.4656, -0.8106, -0.1735, -2.6018, -0.1823,  0.4802,
         -1.6327,  0.2586, -0.5053,  0.1902, -0.7573, -0.0097, -1.7001,  0.0591,
          0.6444, -0.0860, -0.1687, -1.3006, -0.8792,  0.7566,  1.6040, -1.6940,
          0.9833,  0.1014,  0.2752,  0.2458, -2.2526,  1.3864, -0.2813, -0.8472,
         -0.5591,  0.1651,  1.6453, -0.8115, -0.5205, -0.3648, -0.3865,  1.6940,
          0.3970,  0.4219,  0.2787, -0.3909,  0.6798,  1.5773,  0.8854, -0.3179,
          1.8259,  0.1343,  0.2036,  0.1482, -1.5066, -1.1446, -0.0480, -0.7576,
         -0.1442,  1.0321, -0.6639,  0.5685, -0.5995, -0.0343,  0.8576, -0.1236,
         -0.0276,  2.7488,  0.9188,  0.0185, -1.5115,  1.1500, -0.0425,  1.1283,
         -0.1181, -0.8480,  0.8114,  0.5064, -1.0460,  1.5380, -1.5904,  0.0841,
         -0.5427,  2.6939, -1.1816,  1.5358]

# Interpolation weight: z = a * z0 + (1 - a) * z1, so a = 1 gives z0 and a = 0 gives z1.
a = 0.3


def _interpolate(z_start, z_end, weight):
    """Return ``weight * z_start + (1 - weight) * z_end`` as a float tensor on G's device.

    The inputs are left untouched, so the same endpoints can be reused for
    several weights.
    """
    mixed = np.multiply(z_start, weight) + np.multiply(z_end, 1 - weight)
    return torch.FloatTensor(mixed).to(next(G.parameters()).device)


def _latent_to_image(z_vec):
    """Decode one latent vector with G into an (H, W, 3) array ready for imshow.

    Runs under ``torch.no_grad()`` because this is inference only. G ends in
    tanh, so pixels are mapped from [-1, 1] to [0, 1] to keep imshow from
    clipping them.
    """
    with torch.no_grad():
        img_vec = G(z_vec)
    img = (img_vec.view(1, 28, 28).cpu() + 1) / 2
    return torchvision.utils.make_grid(img, nrow=1).numpy().transpose((1, 2, 0))


z = _interpolate(z0, z1, a)

# see z, you should see a random vector
print(z)
plt.hist(z.cpu().numpy().flat, bins=20)
plt.show()

# Generate an image from z.
img = _latent_to_image(z)
plt.imshow(img)

# Set to True to sweep the interpolation weight and plot one image per step.
# Every step interpolates from the unmodified z0/z1, so the scaling does not
# compound from one step to the next.
show_sweep = False
if show_sweep:
    for a_step in np.arange(0.1, 0.9, 0.1):
        print('a: ' + str(a_step))
        plt.figure()
        plt.imshow(_latent_to_image(_interpolate(z0, z1, a_step)))