<a href="https://colab.research.google.com/github/vkjdinesh/Reseacrh/blob/main/Intro_to_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Digit generation using Deep Convolutional GANs (DCGAN)



---



Instructions: 
* Make sure the runtime is set to GPU.
* To clear all outputs, go to Edit -> Clear all outputs.

## 1. Import libraries and load data

### 1.0 Installing required libraries

In [None]:
!pip install tensorboardX

### 1.1 Importing the required libraries

We download and extract the Assets folder, which contains the `utils.py` helper module and some GIF files used for visualization.

In [None]:
import gdown

# Google Drive direct-download link for the Assets archive
# (utils.py plus the GIFs shown at the end of the notebook).
url = 'https://drive.google.com/uc?export=download&id=1NueLFReJ0BClJzjPoi3vVkVlB1PfkJJn'
output = 'Assets.zip'
gdown.download(url=url, output=output, quiet=False)

In [None]:
!unzip /content/Assets.zip -d /content/

In [None]:
from IPython import display
from utils import Logger


import torch
from torch import nn
from torch.optim import Adam
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn as sns

from torchvision import transforms, datasets

### 1.2 Loading and transforming the data

In [None]:
DATA_FOLDER = './torch_data/DCGAN/MNIST'
'''
The normalize function will normalize the image in the range [-1,1]. 
For example, the minimum value 0 will be converted to (0-0.5)/0.5=-1, and 
the maximum value of 1 will be converted to (1-0.5)/0.5=1.
'''
def mnist_data():
    """Return the MNIST training split, resized to 64x64 and scaled to [-1, 1].

    The files are stored under ``<DATA_FOLDER>/dataset`` (``download=True``
    fetches them if they are missing). ``Normalize(mean=0.5, std=0.5)`` maps
    the [0, 1] pixels produced by ``ToTensor`` to [-1, 1] via ``(x - 0.5) / 0.5``,
    the same range as the generator's Tanh output.
    More on Normalize: https://discuss.pytorch.org/t/understanding-transform-normalize/21730
    """
    pipeline = [
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((.5,), (.5,)),
    ]
    root = '{}/dataset'.format(DATA_FOLDER)
    return datasets.MNIST(
        root=root,
        train=True,
        transform=transforms.Compose(pipeline),
        download=True,
    )

In [None]:
data = mnist_data()
batch_size = 100  # lower this if memory is tight (default: 100)
data_loader = torch.utils.data.DataLoader(
    data,
    batch_size=batch_size,
    shuffle=True,
)
# Mini-batches per epoch; passed to the logger for progress reporting.
num_batches = len(data_loader)

MNIST Data Sample:

![Train Data MNIST Sample](https://drive.google.com/uc?export=view&id=1iy7j2gIRaoFcL-nQV99PtQOeoPOo-sxN)



---



## 2. Create (define) the Generative Adversarial Network

![GAN Structure](https://drive.google.com/uc?export=view&id=1iCSX8g4EqB2dHe-_rlsBBl8FSKt5IufY)

### 2.1 Generator Class

In [None]:
class GenerativeNet(torch.nn.Module):
    """DCGAN generator: maps a 100-d noise vector to a 1x64x64 image in (-1, 1).

    A linear layer projects the noise to a 1024x4x4 feature map; four stride-2
    transposed convolutions then double the spatial size at each stage
    (4 -> 8 -> 16 -> 32 -> 64) while reducing the channel count to 1.
    """

    def __init__(self):
        super(GenerativeNet, self).__init__()
        # Projection of the noise vector, reshaped to 1024x4x4 in forward().
        self.linear = torch.nn.Linear(100, 1024 * 4 * 4)
        # Three up-sampling stages: ConvTranspose -> BatchNorm -> ReLU.
        self.conv1 = self._upsample(1024, 512)
        self.conv2 = self._upsample(512, 256)
        self.conv3 = self._upsample(256, 128)
        # Last stage has no BatchNorm/activation; Tanh is applied by self.out.
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False)
        )
        # Tanh squashes pixels into (-1, 1), the range of the normalized data.
        self.out = torch.nn.Tanh()

    @staticmethod
    def _upsample(in_ch, out_ch):
        """Build one x2 up-sampling stage: ConvTranspose2d -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        h = self.linear(x)
        h = h.view(h.shape[0], 1024, 4, 4)
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = stage(h)
        return self.out(h)
    

In [None]:
def noise(size, dim=100):
    """Sample a batch of latent vectors from a standard normal distribution.

    Args:
        size: number of noise vectors (the batch size).
        dim: length of each vector; defaults to 100, the input width of the
            generator's first linear layer.

    Returns:
        A ``(size, dim)`` float tensor, moved to the GPU when CUDA is available
        (where the networks are placed) and left on the CPU otherwise.
    """
    # torch.autograd.Variable has been a deprecated no-op wrapper since
    # PyTorch 0.4, so a plain tensor is returned. Sampling still happens on the
    # CPU before the move, which keeps the same RNG stream as before.
    n = torch.randn(size, dim)
    return n.cuda() if torch.cuda.is_available() else n

### 2.2 Discriminator Class

In [None]:
class DiscriminativeNet(torch.nn.Module):
    """DCGAN discriminator: scores a 1x64x64 image with a probability of being real.

    Four stride-2 convolutions halve the spatial size at each stage
    (64 -> 32 -> 16 -> 8 -> 4) while growing the channels to 1024. The
    resulting 1024x4x4 map is flattened and fed to a linear layer plus sigmoid.
    """

    def __init__(self):
        super(DiscriminativeNet, self).__init__()
        # First stage: convolution + LeakyReLU only (no BatchNorm).
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Remaining down-sampling stages: Conv -> BatchNorm -> LeakyReLU.
        self.conv2 = self._downsample(128, 256)
        self.conv3 = self._downsample(256, 512)
        self.conv4 = self._downsample(512, 1024)
        # Linear head; the sigmoid turns the logit into a probability in (0, 1).
        self.out = nn.Sequential(
            nn.Linear(1024 * 4 * 4, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _downsample(in_ch, out_ch):
        """Build one /2 down-sampling stage: Conv2d -> BatchNorm2d -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        # Flatten the 1024x4x4 feature map into one row per image.
        return self.out(x.view(-1, 1024 * 4 * 4))

### 2.3 Initialize weights

In [None]:
def init_weights(m):
    """DCGAN-style weight initialisation, intended for ``Module.apply``.

    * Conv / ConvTranspose layers: weights ~ N(0, 0.02).
    * BatchNorm layers with learnable affine params: scale ~ N(1, 0.02), shift = 0.

    All other modules (Linear, activations, containers) are left untouched.

    Args:
        m: the module visited by ``Module.apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif 'BatchNorm' in classname and m.weight is not None:
        # The scale must be centred on 1. Drawing it from N(0, 0.02) multiplies
        # every normalized activation by roughly 0.02, which throttles the signal.
        # `m.weight is None` when BatchNorm is built with affine=False.
        torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)
        torch.nn.init.zeros_(m.bias)

### 2.4 Create the instance of Network Class

In [None]:
# Build each network and apply the custom initialisation right away.
# Module.apply returns the module itself, so it can be chained onto construction.
generator = GenerativeNet().apply(init_weights)
discriminator = DiscriminativeNet().apply(init_weights)

# Train on the GPU when one is present (Module.cuda moves parameters in place).
if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()

### 2.5 Visualize the model using summary

In [None]:
!pip install torchsummary
from torchsummary import summary
#Summary for generator
summary(generator, input_size=(100,))

In [None]:
#Summary for discriminator
summary(discriminator, input_size=(1, 64, 64))



---



# 3. Training

### 3.1 Optimization

In [None]:
# Adam for both networks, with identical hyperparameters.
d_optimizer = Adam(params=discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = Adam(params=generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Binary cross-entropy on the discriminator's sigmoid probabilities.
loss = nn.BCELoss()

# 1 epoch is enough for a quick hands-on run; raise it (original default: 200)
# for proper training.
num_epochs = 1

### 3.2 Create target data (labels)

In [None]:
#Real data's label will have value = 1 (for discriminator to show it's a true image)
def real_data_target(size):
    #Tensor containing ones, with shape = size [1,1,1,....,1]
    data = Variable(torch.ones(size, 1))
    if torch.cuda.is_available(): return data.cuda()
    return data
    
def fake_data_target(size):
    """Return BCE targets meaning "generated image": a ``(size, 1)`` tensor of zeros.

    The tensor is moved to the GPU when CUDA is available, so it sits on the
    same device as the discriminator's output.
    """
    # Variable is a deprecated no-op wrapper since PyTorch 0.4; a plain tensor suffices.
    target = torch.zeros(size, 1)
    return target.cuda() if torch.cuda.is_available() else target

### 3.3 Encapsulating train functions for generator and discriminator

In [None]:
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a generated batch.

    Uses the module-level ``discriminator`` and ``loss`` (BCE). Gradients of the
    real-vs-1 and fake-vs-0 losses are accumulated, then one optimizer step is taken.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real images.
        fake_data: batch of generated images (the training loop passes it
            detached, so no gradient flows back into the generator).

    Returns:
        ``(total_error, prediction_real, prediction_fake)``, where ``total_error``
        is the sum of the two BCE losses.
    """
    optimizer.zero_grad()

    # 1. Real images should be classified as 1.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_data_target(prediction_real.size(0)))
    error_real.backward()

    # 2. Fake images should be classified as 0. Size the targets from the fake
    #    predictions themselves, so the real and fake batches need not match in size.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, fake_data_target(prediction_fake.size(0)))
    error_fake.backward()

    # Single step using the accumulated real + fake gradients.
    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake


def train_generator(optimizer, fake_data):
    """Run one generator update.

    The module-level ``discriminator`` scores ``fake_data``. The BCE ``loss``
    against all-ones ("real") targets is back-propagated, and only the
    parameters held by ``optimizer`` are stepped.

    Args:
        optimizer: optimizer over the generator's parameters.
        fake_data: generator output that is still attached to the generator's
            autograd graph, so gradients can reach the generator.

    Returns:
        The generator's BCE loss tensor.
    """
    optimizer.zero_grad()
    scores = discriminator(fake_data)
    # The generator improves when its samples are labelled as real (1).
    g_loss = loss(scores, real_data_target(scores.size(0)))
    g_loss.backward()
    optimizer.step()
    return g_loss

### 3.4 Generate samples for testing using noise function

In [None]:
# A fixed batch of latent vectors, sampled once and reused at every progress
# report, so successive image grids show the same inputs as training advances.
num_test_samples = 16
test_noise = noise(size=num_test_samples)

### 3.5 Train the model and log the results after each epoch

In [None]:
logger = Logger(model_name='DCGAN', data_name='MNIST')
cond = False  # set True after the first progress report; the loss plot is skipped until then
gen_loss_list = []
disc_loss_list = []


def _plot_losses(gen_losses, disc_losses, step_bins=20, window=100):
    """Plot the generator and discriminator loss curves.

    Losses are grouped into bins of ``step_bins`` consecutive batches (each x
    value is repeated ``step_bins`` times, so seaborn aggregates every bin).
    The y-range is fitted to the last ``window`` values of both curves, so
    late-training detail stays visible.
    """
    x_axis = sorted([i * step_bins for i in range(len(gen_losses) // step_bins)] * step_bins)
    # seaborn >= 0.12 accepts x / y only as keyword arguments.
    sns.lineplot(x=x_axis, y=gen_losses[:len(x_axis)], label="Generator's Loss")
    sns.lineplot(x=x_axis, y=disc_losses[:len(x_axis)], label="Discriminator's Loss")
    plt.legend()
    recent = disc_losses[-window:] + gen_losses[-window:]
    plt.ylim(min(recent), max(recent))
    plt.show()


for epoch in range(num_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):

        # 1. Train the discriminator on a real batch and a detached fake batch.
        real_data = real_batch.cuda() if torch.cuda.is_available() else real_batch
        fake_data = generator(noise(real_data.size(0))).detach()
        d_error, d_pred_real, d_pred_fake = train_discriminator(d_optimizer, real_data, fake_data)
        disc_loss_list.append(d_error.item())

        # 2. Train the generator on a fresh fake batch (kept attached so gradients reach it).
        fake_data = generator(noise(real_batch.size(0)))
        g_error = train_generator(g_optimizer, fake_data)
        gen_loss_list.append(g_error.item())

        # Log error
        logger.log(d_error, g_error, epoch, n_batch, num_batches)

        # Display progress every 100 batches.
        if n_batch % 100 == 0:
            # Uncomment the line below to only see the latest iteration output
            # display.clear_output(True)
            # The preview only needs pixel values, so skip building an autograd graph.
            with torch.no_grad():
                test_images = generator(test_noise).cpu()
            logger.log_images(test_images, num_test_samples, epoch, n_batch, num_batches)
            logger.display_status(
                epoch, num_epochs, n_batch, num_batches,
                d_error, g_error, d_pred_real, d_pred_fake,
            )
            if cond:
                _plot_losses(gen_loss_list, disc_loss_list)
            cond = True

    # Checkpoint once per epoch rather than after every batch. save_models only
    # receives the epoch number, so per-batch calls would presumably rewrite the
    # same (large) files hundreds of times per epoch.
    logger.save_models(generator, discriminator, epoch)



---



# 4. Results after training

Epoch 0:

![Epoch_0](https://drive.google.com/uc?export=view&id=1bAoQhSMhKpx2Dr1FW0PpfjZ92mqaoulx)

Epoch 10:

![Epoch_10](https://drive.google.com/uc?export=view&id=1K4Su633DlbqYBMEHffReVa0kt12Csj_p)

Epoch 50:

![Epoch_50](https://drive.google.com/uc?export=view&id=1LBv1ke4APw8Xh3k6kRcbgXunktfm6W_2)

In [None]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython import display
from pathlib import Path

In [None]:
gifPath = Path("./transitions.gif")
# Display 30 images compiled into a GIF (nearly 90 epochs).
# display() is called explicitly: an expression nested inside a `with` block is
# only auto-rendered when InteractiveShell.ast_node_interactivity == "all".
with open(gifPath, 'rb') as f:
    gif_bytes = f.read()
# NOTE(review): the bytes are a GIF but are labelled 'png', as in the original;
# confirm the installed IPython accepts format='gif' before changing it.
display.display(display.Image(data=gif_bytes, format='png'))

In [None]:
gifPath2 = Path("./Interpolation.gif")
# Cool interpolation gif.
# display() is called explicitly: an expression nested inside a `with` block is
# only auto-rendered when InteractiveShell.ast_node_interactivity == "all".
with open(gifPath2, 'rb') as f:
    gif_bytes = f.read()
# NOTE(review): the bytes are a GIF but are labelled 'png', as in the original;
# confirm the installed IPython accepts format='gif' before changing it.
display.display(display.Image(data=gif_bytes, format='png'))

# 5. Credits and references:

* Thanks to diegoalejogm for his public code for all types of GANs. Available [here](https://github.com/diegoalejogm/gans)

*   GAN structure diagram by Garima Nishad. Available [here](https://medium.com/intel-student-ambassadors/mnist-gan-detailed-step-by-step-explanation-implementation-in-code-ecc93b22dc60)
*   Interpolation GIF by Nikesh Bajaj

