# Applying **Generative Adversarial Nets (GANs)** to MNIST


## Introduction

I find the **Generative Adversarial Nets (GANs)** framework, [proposed](https://arxiv.org/pdf/1406.2661.pdf) by Ian Goodfellow and his colleagues at the University of Montreal, fascinating. This technique estimates generative models by **pitting two deep neural nets against each other**:

 - **Generator (G)**, which takes random noise as input and tries to produce output (fake data) close enough to the real data to fool the other network
 - **Discriminator (D)**, which takes both real data and generated (fake) data as input and tries to tell them apart

*Each network tries to outdo the other, so both keep getting better at their task.*
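For reference, the original paper formalizes this contest as a two-player minimax game over a value function $V(D, G)$, where $p_{\text{data}}$ is the data distribution and $p_z$ is the noise prior fed to the Generator:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

The Discriminator tries to push $D(x)$ towards 1 on real data and $D(G(z))$ towards 0 on fakes, while the Generator tries to push $D(G(z))$ towards 1.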

Since a picture is worth a thousand words, here is a diagram of the setup:

![GAN diagram](./images/GAN.png)

## Current Task

Here I'll apply this technique to the MNIST dataset and use the trained generator to produce handwritten digits of my own.

## Approach

### Architecture

We'll use DNNs made up mostly of **Fully Connected** layers. Later, we could improve the architecture by using convolutional (CNN) layers.

### Deep Learning Library

I'll use the [PyTorch library](http://pytorch.org/) in this project.

## Implementation

We start by importing the necessary libraries and packages.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
```

```python
# Check whether a GPU is available; we use this flag to train on the GPU when possible
has_gpu = torch.cuda.is_available()
```

Next, we set up the transforms we want to apply, the dataset, and the data loader.

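```python
# Image preprocessing: convert to a tensor in [0, 1], then normalize the single
# grayscale channel to [-1, 1] (MNIST has one channel, so mean/std have one entry)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
```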
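To see why `mean=0.5` and `std=0.5` are used, here is the per-pixel arithmetic that `transforms.Normalize` applies, written out in plain Python. The helper names `normalize` and `denormalize` are illustrative and are not torchvision functions. The transform maps `ToTensor`'s [0, 1] range onto [-1, 1], which matches the output range of the Generator's final `tanh`.

```python
# Per-pixel arithmetic performed by transforms.Normalize(mean=0.5, std=0.5) on one channel
def normalize(pixel, mean=0.5, std=0.5):
    """Map a pixel from ToTensor's [0, 1] range to [-1, 1]."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert normalize(): map a value in [-1, 1] back to [0, 1], e.g. before saving an image."""
    return value * std + mean


for p in (0.0, 0.25, 0.5, 1.0):
    print(p, "->", normalize(p))  # -1.0, -0.5, 0.0, 1.0 respectively
```

The inverse, `denormalize`, is what you would apply to Generator output before writing it out as an ordinary image.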

```python
# MNIST dataset
train_dataset = datasets.MNIST('./data/', train=True, transform=transform, download=True)
```

```python
# Data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
```
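MNIST's training split has 60,000 images. With `batch_size=100`, each epoch therefore consists of 600 mini-batches, which is where the `600` in the training log comes from. Here is the arithmetic in plain Python. In PyTorch, `len(train_loader)` should report the same number.

```python
import math

# MNIST's training split contains 60,000 images
num_train_images = 60000
batch_size = 100

# With the default drop_last=False, the loader yields ceil(N / batch_size) batches per epoch
steps_per_epoch = math.ceil(num_train_images / batch_size)
log_every = 300  # the training loop prints a line every 300 steps

print(steps_per_epoch)               # 600
print(steps_per_epoch // log_every)  # 2 log lines per epoch
```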

### Discriminator

For now, we use 3 fully connected layers to model the Discriminator. Depending on the results, we may make it more or less powerful.

```python
# Discriminator: maps a flattened 28x28 image (784 values) to the probability that it is real
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, inp):
        x = F.relu(self.fc1(inp))
        x = F.relu(self.fc2(x))
        # torch.sigmoid replaces the deprecated F.sigmoid
        out = torch.sigmoid(self.fc3(x))
        return out
```

### Generator

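Here again we use 3 fully connected layers to model the Generator. It maps a 128-dimensional noise vector to 784 values (a flattened 28×28 image), and the final `tanh` keeps each pixel in [-1, 1], the same range as the normalized real images.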

```python
# Generator: maps a 128-dimensional noise vector to a flattened 28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 784)

    def forward(self, inp):
        x = F.leaky_relu(self.fc1(inp))
        x = F.leaky_relu(self.fc2(x))
        # tanh keeps every pixel in [-1, 1]; torch.tanh replaces the deprecated F.tanh
        out = torch.tanh(self.fc3(x))
        return out
```

```python
discriminator = Discriminator()
generator = Generator()

# Move both networks to the GPU if one is available
if has_gpu:
    discriminator.cuda()
    generator.cuda()
```
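Both networks are small by modern standards. The sketch below counts the weights and biases of their `nn.Linear` layers by hand in plain Python. In PyTorch, `sum(p.numel() for p in model.parameters())` should give the same totals.

```python
# Each nn.Linear(in_features, out_features) holds an (out x in) weight matrix plus an out-sized bias
def linear_params(in_features, out_features):
    return in_features * out_features + out_features


# Layer sizes taken from the Discriminator and Generator classes defined earlier
discriminator_layers = [(784, 256), (256, 256), (256, 1)]
generator_layers = [(128, 256), (256, 256), (256, 784)]

d_params = sum(linear_params(i, o) for i, o in discriminator_layers)
g_params = sum(linear_params(i, o) for i, o in generator_layers)
print("Discriminator:", d_params)  # 267009
print("Generator:", g_params)      # 300304
```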

### Getting ready for the training phase

Before the training phase, we define the following:

1. The loss criterion: Binary Cross Entropy (BCE), because the target labels are binary (1 for real, 0 for fake)
2. The optimizer: Adam, my favourite, with a learning rate of 0.0005 for both networks
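For a predicted probability $p$ and a label $y \in \{0, 1\}$, binary cross-entropy is $-[y \log p + (1 - y)\log(1 - p)]$, averaged over the batch. Below is a plain-Python sketch of what `nn.BCELoss()` computes with its default settings. The function name `bce_loss` is illustrative. The clamping of each log term at -100 follows the behaviour documented for `nn.BCELoss`.

```python
import math


def bce_loss(predictions, targets):
    """Mean binary cross-entropy, mirroring nn.BCELoss's default reduction='mean'.

    Like nn.BCELoss, each log term is clamped to be >= -100, so a
    prediction of exactly 0 or 1 gives a large but finite loss.
    """
    def clamped_log(x):
        return max(math.log(x), -100.0) if x > 0 else -100.0

    total = 0.0
    for p, y in zip(predictions, targets):
        total += -(y * clamped_log(p) + (1 - y) * clamped_log(1 - p))
    return total / len(predictions)


# A confident, correct Discriminator has a small loss...
print(bce_loss([0.9, 0.8], [1, 1]))
# ...while a confident, wrong one is penalized heavily
print(bce_loss([0.1, 0.2], [1, 1]))
```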

```python
# Loss & optimizers
criterion = nn.BCELoss()
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0005)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0005)
```
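Adam keeps a running average of each parameter's gradient and of its squared gradient, and it scales each step by their ratio. Below is a plain-Python sketch of a single update for one scalar parameter. It assumes PyTorch's default `betas=(0.9, 0.999)` and `eps=1e-8`. The helper `adam_step` is illustrative and is not a PyTorch API.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.0005, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (t counts steps starting from 1)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias corrections for the zero-initialized averages
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# The first step has magnitude close to lr whether the gradient is tiny or huge
for g in (1e-3, 1.0, 1e3):
    new_param, _, _ = adam_step(0.0, g, m=0.0, v=0.0, t=1)
    print(g, new_param)
```

On the first step, the bias-corrected averages reduce to $g$ and $g^2$. The update is therefore about `lr` = 0.0005 in size, whatever the scale of the loss.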

### Let the games begin!

So now we start the training phase.

In each iteration,

1. **Training the Discriminator**:
    - We feed a batch of real images to the Discriminator and compute the **Real Loss** (target label 1, since the data is real)
    - We generate a batch of fake images with the Generator, feed them to the Discriminator, and compute the **Fake Loss** (target label 0, since they are fake). The fake images are detached from the Generator's graph, so this step does not update the Generator
    - We add the two to get the **Total Loss** and take an Adam step on the Discriminator's parameters

2. **Training the Generator**:
    - We generate a fresh batch of fake images and feed them to the Discriminator
    - We compute the **Generator Loss** with the target label set to 1, since the Generator's aim is to have its images classified as real
    - We take an Adam step on the Generator's parameters
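For a mini-batch of $m$ real images $x_i$ and noise vectors $z_i$, using binary cross-entropy with the labels described in the steps above, the two losses are:

$$
\mathcal{L}_D = -\frac{1}{m}\sum_{i=1}^{m}\Big[\log D(x_i) + \log\big(1 - D(G(z_i))\big)\Big], \qquad
\mathcal{L}_G = -\frac{1}{m}\sum_{i=1}^{m}\log D(G(z_i))
$$

Minimizing $\mathcal{L}_D$ is the Discriminator's side of the game. Labelling the fakes as real for the Generator gives $\mathcal{L}_G$, the alternative that the original paper recommends over minimizing $\log(1 - D(G(z)))$ directly. Early in training, the Discriminator rejects fakes easily, so $\log(1 - D(G(z)))$ saturates and gives weak gradients. $-\log D(G(z))$ does not saturate in that regime.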

```python
import os

# Training
device = torch.device("cuda" if has_gpu else "cpu")
num_epochs = 200
os.makedirs('./gen', exist_ok=True)  # save_image does not create missing directories

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # Build the mini-batch: flatten each image to a 784-vector
        batch_size = images.size(0)
        images = images.view(batch_size, -1).to(device)
        # Labels are shaped (batch_size, 1) to match the Discriminator's output, as BCELoss requires
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Train the Discriminator on real images (target 1)...
        discriminator.zero_grad()
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs

        # ...and on fake images (target 0); detach() keeps this loss from reaching the Generator
        noise = torch.randn(batch_size, 128, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach())
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train the Generator: fresh noise, target label 1 ("real")
        generator.zero_grad()
        noise = torch.randn(batch_size, 128, device=device)
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 300 == 0:
            print("Epoch [{}/{}], Step [{} / {}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}".format(
                epoch, num_epochs, i + 1, len(train_loader), d_loss.item(), g_loss.item(),
                real_score.mean().item(), fake_score.mean().item()))

    # Save the last batch of generated images, mapped from [-1, 1] back to [0, 1]
    fake_images = fake_images.detach().view(fake_images.size(0), 1, 28, 28)
    torchvision.utils.save_image((fake_images + 1) / 2, './gen/fake_samples_{}.png'.format(epoch + 1))
```

```
Epoch [0/5], Step [300 / 600], d_loss: 0.2022, g_loss: 3.9815, D(x): 0.98, D(G(z)): 0.15
Epoch [0/5], Step [600 / 600], d_loss: 1.0835, g_loss: 2.4868, D(x): 0.69, D(G(z)): 0.37
Epoch [1/5], Step [300 / 600], d_loss: 0.3280, g_loss: 3.4601, D(x): 0.86, D(G(z)): 0.10
Epoch [1/5], Step [600 / 600], d_loss: 0.8882, g_loss: 2.0701, D(x): 0.70, D(G(z)): 0.36
Epoch [2/5], Step [300 / 600], d_loss: 2.9380, g_loss: 1.1472, D(x): 0.36, D(G(z)): 0.62
Epoch [2/5], Step [600 / 600], d_loss: 0.5566, g_loss: 1.8488, D(x): 0.79, D(G(z)): 0.25
Epoch [3/5], Step [300 / 600], d_loss: 1.9433, g_loss: 0.6596, D(x): 0.54, D(G(z)): 0.65
Epoch [3/5], Step [600 / 600], d_loss: 0.5841, g_loss: 2.3619, D(x): 0.71, D(G(z)): 0.16
Epoch [4/5], Step [300 / 600], d_loss: 1.4625, g_loss: 1.8418, D(x): 0.62, D(G(z)): 0.42
Epoch [4/5], Step [600 / 600], d_loss: 1.0637, g_loss: 1.3700, D(x): 0.64, D(G(z)): 0.40
Epoch [5/5], Step [300 / 600], d_loss: 1.5017, g_loss: 1.1628, D(x): 0.53, D(G(z)): 0.43
...
Epoch [46/5], Step [300 / 600], d_loss: 0.9178, g_loss: 1.8897, D(x): 0.73, D(G(z)): 0.30
Epoch [46/5], Step [600 / 600], d_loss: 0.7758, g_loss: 1.5473, D(x): 0.82, D(G(z)): 0.34
Epoch [47/5], Step [300 / 600], d_loss: 0.8377, g_loss: 1.7893, D(x): 0.74, D(G(z)): 0.31
Epoch [47/5], Step [600 / 600], d_loss: 0.8023, g_loss: 1.7132, D(x): 0.71, D(G(z)): 0.24
Epoch [48/5], Step [300 / 600], d_loss: 0.8281, g_loss: 1.7613, D(x): 0.75, D(G(z)): 0.29
Epoch [48/5], Step [600 / 600], d_loss: 0.8203, g_loss: 1.5359, D(x): 0.75, D(G(z)): 0.31
Epoch [49/5], Step [300 / 600], d_loss: 0.8987, g_loss: 1.6966, D(x): 0.76, D(G(z)): 0.32
Epoch [49/5], Step [600 / 600], d_loss: 0.8090, g_loss: 1.8267, D(x): 0.73, D(G(z)): 0.28
Epoch [50/5], Step [300 / 600], d_loss: 1.0110, g_loss: 1.4934, D(x): 0.70, D(G(z)): 0.37
Epoch [50/5], Step [600 / 600], d_loss: 1.0230, g_loss: 1.6370, D(x): 0.64, D(G(z)): 0.28
Epoch [51/5], Step [300 / 600], d_loss: 1.0921, g_loss: 1.7343, D(x): 0.65, D(G(z)): 0.29
...
Epoch [92/5], Step [300 / 600], d_loss: 1.1118, g_loss: 1.1675, D(x): 0.64, D(G(z)): 0.39
Epoch [92/5], Step [600 / 600], d_loss: 1.1132, g_loss: 1.2288, D(x): 0.64, D(G(z)): 0.41
Epoch [93/5], Step [300 / 600], d_loss: 1.1609, g_loss: 1.2391, D(x): 0.57, D(G(z)): 0.34
Epoch [93/5], Step [600 / 600], d_loss: 1.1668, g_loss: 1.2965, D(x): 0.61, D(G(z)): 0.39
Epoch [94/5], Step [300 / 600], d_loss: 1.0974, g_loss: 1.2885, D(x): 0.58, D(G(z)): 0.33
Epoch [94/5], Step [600 / 600], d_loss: 1.1334, g_loss: 1.4101, D(x): 0.60, D(G(z)): 0.39
Epoch [95/5], Step [300 / 600], d_loss: 1.2594, g_loss: 1.3788, D(x): 0.53, D(G(z)): 0.32
Epoch [95/5], Step [600 / 600], d_loss: 1.0774, g_loss: 1.0359, D(x): 0.63, D(G(z)): 0.38
Epoch [96/5], Step [300 / 600], d_loss: 0.9847, g_loss: 1.1312, D(x): 0.69, D(G(z)): 0.38
Epoch [96/5], Step [600 / 600], d_loss: 1.1621, g_loss: 1.3665, D(x): 0.63, D(G(z)): 0.42
Epoch [97/5], Step [300 / 600], d_loss: 1.1383, g_loss: 1.2375, D(x): 0.58, D(G(z)): 0.38
...
Epoch [137/5], Step [600 / 600], d_loss: 1.1095, g_loss: 1.2872, D(x): 0.59, D(G(z)): 0.35
Epoch [138/5], Step [300 / 600], d_loss: 1.1475, g_loss: 1.2265, D(x): 0.61, D(G(z)): 0.37
Epoch [138/5], Step [600 / 600], d_loss: 1.2015, g_loss: 1.2632, D(x): 0.58, D(G(z)): 0.38
Epoch [139/5], Step [300 / 600], d_loss: 1.0819, g_loss: 1.0834, D(x): 0.59, D(G(z)): 0.34
Epoch [139/5], Step [600 / 600], d_loss: 1.0844, g_loss: 1.2339, D(x): 0.64, D(G(z)): 0.38
Epoch [140/5], Step [300 / 600], d_loss: 1.0369, g_loss: 1.0390, D(x): 0.67, D(G(z)): 0.39
Epoch [140/5], Step [600 / 600], d_loss: 1.0806, g_loss: 1.1119, D(x): 0.60, D(G(z)): 0.36
Epoch [141/5], Step [300 / 600], d_loss: 1.1824, g_loss: 1.1233, D(x): 0.60, D(G(z)): 0.41
Epoch [141/5], Step [600 / 600], d_loss: 1.0818, g_loss: 1.0752, D(x): 0.63, D(G(z)): 0.37
Epoch [142/5], Step [300 / 600], d_loss: 1.1400, g_loss: 1.1554, D(x): 0.60, D(G(z)): 0.40
Epoch [142/5], Step [600 / 600], d_loss: 1.1607, g_loss: 1.0108, D(x): 0.63, D(G(z)): 0.42
...
Epoch [183/5], Step [300 / 600], d_loss: 1.1504, g_loss: 1.1024, D(x): 0.63, D(G(z)): 0.43
Epoch [183/5], Step [600 / 600], d_loss: 1.1382, g_loss: 1.1329, D(x): 0.59, D(G(z)): 0.39
Epoch [184/5], Step [300 / 600], d_loss: 1.1150, g_loss: 1.1487, D(x): 0.60, D(G(z)): 0.37
Epoch [184/5], Step [600 / 600], d_loss: 1.1593, g_loss: 0.9694, D(x): 0.62, D(G(z)): 0.41
Epoch [185/5], Step [300 / 600], d_loss: 1.2260, g_loss: 1.2829, D(x): 0.59, D(G(z)): 0.40
Epoch [185/5], Step [600 / 600], d_loss: 1.2560, g_loss: 1.1566, D(x): 0.55, D(G(z)): 0.39
Epoch [186/5], Step [300 / 600], d_loss: 1.1718, g_loss: 0.9140, D(x): 0.63, D(G(z)): 0.44
Epoch [186/5], Step [600 / 600], d_loss: 1.0827, g_loss: 1.0664, D(x): 0.62, D(G(z)): 0.37
Epoch [187/5], Step [300 / 600], d_loss: 1.1391, g_loss: 0.9520, D(x): 0.61, D(G(z)): 0.40
Epoch [187/5], Step [600 / 600], d_loss: 1.2663, g_loss: 1.1947, D(x): 0.59, D(G(z)): 0.43
Epoch [188/5], Step [300 / 600], d_loss: 1.2351, g_loss: 1.1382, D(x): 0.57, D(G(z)): 0.42
```
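A rough way to read these numbers: suppose the Discriminator output the same probability for every image. Its loss would then be $-\log D(x) - \log(1 - D(G(z)))$, and the Generator's loss would be $-\log D(G(z))$. The original paper shows that the optimal Discriminator is $p_{\text{data}} / (p_{\text{data}} + p_g)$, which equals 0.5 everywhere once the Generator matches the data. That equilibrium gives $2\ln 2 \approx 1.386$ and $\ln 2 \approx 0.693$. The later epochs in the log, with `D(x)` around 0.6 and `D(G(z))` around 0.4, are heading in that direction. The sketch below is only a simplification, because the logged values average the loss over a batch of differing outputs and will not match these figures exactly. The helper name `constant_output_losses` is illustrative.

```python
import math


def constant_output_losses(d_real, d_fake):
    """Losses if the Discriminator output d_real on every real image and d_fake on every fake."""
    d_loss = -math.log(d_real) - math.log(1 - d_fake)
    g_loss = -math.log(d_fake)
    return d_loss, g_loss


# Theoretical equilibrium: D outputs 0.5 for everything
print(constant_output_losses(0.5, 0.5))  # about (1.386, 0.693)
# Roughly where the later epochs in the log sit
print(constant_output_losses(0.6, 0.4))
```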

## Saving the model

Now we save both models' weights so that we can resume training later or use the Generator to produce new images.

```python
torch.save(generator.state_dict(), './gen.pkl')
torch.save(discriminator.state_dict(), './dis.pkl')
```
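Here is a minimal sketch of using the saved Generator to draw new digits. It assumes the same architecture and the same 128-dimensional Gaussian noise used during training. The helper `sample_digits` is hypothetical, and the `Generator` class is repeated so that the cell runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# Same architecture as the Generator defined earlier (repeated so this cell stands alone)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 784)

    def forward(self, inp):
        x = F.leaky_relu(self.fc1(inp))
        x = F.leaky_relu(self.fc2(x))
        return torch.tanh(self.fc3(x))


def sample_digits(generator, num_samples=64, noise_dim=128):
    """Draw fresh noise and return images shaped (N, 1, 28, 28) with pixels in [0, 1]."""
    generator.eval()
    device = next(generator.parameters()).device
    with torch.no_grad():  # no gradients needed when only sampling
        noise = torch.randn(num_samples, noise_dim, device=device)
        images = generator(noise).view(num_samples, 1, 28, 28)
    return (images + 1) / 2  # undo the [-1, 1] normalization
```

To use the sketch with the trained weights, create a `Generator()` and load them with `generator.load_state_dict(torch.load('./gen.pkl', map_location='cpu'))`. The `map_location` argument matters if the weights were saved from the GPU. Then pass `sample_digits(generator)` to `torchvision.utils.save_image`.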

## A peek at the Generator's output

Here is a sample of the Generator's output:

![Generated digits](./images/fake_samples_mnist.png)

## Conclusion

Although the generated digits look **awesome**, they could be improved further by:

1. Tweaking hyperparameters such as the batch size, and adding dropout and batch-norm layers
2. Using convolutional layers instead of fully connected layers (DCGANs)
3. Training for longer. Most of the impressive results I have seen so far come from training for days, if not weeks
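For point 2, here is a hypothetical sketch of what a convolutional (DCGAN-style) generator for 28×28 MNIST digits could look like. The layer sizes are illustrative assumptions, not a tuned architecture. The sketch keeps the same 128-dimensional noise input and `tanh` output as the fully connected Generator, but it returns images already shaped (N, 1, 28, 28).

```python
import torch
import torch.nn as nn


class ConvGenerator(nn.Module):
    """Hypothetical DCGAN-style generator: 128-d noise -> 1x28x28 image in [-1, 1]."""

    def __init__(self, noise_dim=128):
        super(ConvGenerator, self).__init__()
        # Project the noise vector to a 128-channel 7x7 feature map
        self.project = nn.Sequential(
            nn.Linear(noise_dim, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )
        # Each transposed convolution (kernel 4, stride 2, padding 1) doubles the spatial size
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),
        )

    def forward(self, noise):
        x = self.project(noise).view(-1, 128, 7, 7)
        return self.upsample(x)
```

Using this generator in the existing training loop would also require a Discriminator that accepts (N, 1, 28, 28) inputs, or flattening the fakes with `.view(N, -1)` before passing them to the fully connected Discriminator.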