In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as neural_network
import torch.nn.functional as functions
import torchvision.transforms as transforms
import torch.nn as model
import torch.nn.functional as functions
import torch.optim as optim

from torchvision import datasets

# Generating Digits Using GANs

A generative adversarial network (GAN) is a class of machine learning frameworks designed by Ian Goodfellow and his colleagues in 2014. Two neural networks contest with each other in a game (in the form of a zero-sum game, where one agent's gain is another agent's loss). 

- More about GAN: https://en.wikipedia.org/wiki/Generative_adversarial_network
- More about PyTorch: https://pytorch.org/


## About this notebook
This notebook was created to help you understand more about machine learning. I intend to write tutorials covering several machine learning algorithms, from basic to advanced. I hope it helps you along your data science journey. If you have any questions, you can contact me through the link below.

**Contact me: https://www.linkedin.com/in/vitorgamalemos/**

Other notebooks about neural networks:

- Simple Perceptron: https://www.kaggle.com/vitorgamalemos/neural-network-01-simple-perceptron
- Multilayer Perceptron: https://www.kaggle.com/vitorgamalemos/neural-network-02-multilayer-perceptron
- Convolutional neural network: https://www.kaggle.com/vitorgamalemos/object-recognition-using-convolutional-network

## Download Data Digits and Transform

In [None]:
def download_data_digits():
    """Return the MNIST training split, downloading it into the 'data' directory if needed.

    Each image is converted to a tensor by ``transforms.ToTensor()``.
    """
    mnist_options = dict(root='data', train=True, download=True)
    return datasets.MNIST(transform=transforms.ToTensor(), **mnist_options)


train_data = download_data_digits()

## Train Data Loader

In [None]:
# Training mini-batches of 64. shuffle=True re-orders the dataset every epoch so
# the optimisers do not see the exact same batch sequence on every pass.
train = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)

## Visualizing Sample Training Images

In [None]:
def show_samples_test(length=10, rows=5, loader=None):
    """Plot images from one training batch, ``length`` images per row, one figure per row.

    Parameters
    ----------
    length : int
        Number of images per row (default 10).
    rows : int
        Number of rows (figures) to draw (default 5).
    loader : iterable of (images, labels), optional
        Source of the batch; defaults to the global ``train`` DataLoader.

    Raises
    ------
    ValueError
        If the first batch holds fewer than ``rows * length`` images.
    """
    if loader is None:
        loader = train
    # next(iter(...)) instead of iter(...).next(): DataLoader iterators no longer
    # provide a .next() method.
    images, labels = next(iter(loader))
    images = images.numpy()

    needed = rows * length
    if needed > len(images):
        raise ValueError(f'batch has {len(images)} images, need {needed}')

    for row in range(rows):
        start = row * length  # advance by the row width rather than a fixed 10
        # squeeze=False keeps `axs` 2-D so indexing also works when length == 1.
        fig, axs = plt.subplots(1, length, figsize=(10, 10), squeeze=False)
        for col, ax in enumerate(axs[0]):
            idx = start + col
            ax.imshow(np.squeeze(images[idx]), cmap='gray')
            # int(...) so the title reads "Y: 5" rather than "Y: tensor(5)".
            ax.set_title('Y: {}'.format(int(labels[idx])))
            ax.axis('off')


In [None]:
show_samples_test(length=10)

## Discriminator Network

### Discriminator Training Data

The discriminator's training data comes from two sources:

- Real data instances, such as real pictures of people (in this notebook, real MNIST digit images). The discriminator uses these instances as positive examples during training.
- Fake data instances created by the generator. The discriminator uses these instances as negative examples during training.


**Reference: https://developers.google.com/machine-learning/gan/discriminator**

<img src="https://sthalles.github.io/assets/dcgan/GANs.png">

In [None]:
class Discriminator(neural_network.Module):
    """MLP that maps each flattened 28x28 image to a single real/fake logit.

    Layer widths are 784 -> 128 -> 64 -> 32 -> 1. Each hidden layer is followed
    by LeakyReLU (negative slope 0.2) and Dropout(p=0.3). The output is a raw
    logit; no sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so seeded initialisation and
        # state_dict key order stay the same.
        self.input_layer = neural_network.Linear(784, 128)
        self.hidden_layer_1 = neural_network.Linear(128, 64)
        self.hidden_layer_2 = neural_network.Linear(64, 32)
        self.output_layer = neural_network.Linear(32, 1)
        self.dropout = neural_network.Dropout(0.3)

    def forward(self, img_input):
        """Flatten the input into rows of 784 values and return an (N, 1) logit tensor."""
        hidden = img_input.view(-1, 28 * 28)
        for layer in (self.input_layer, self.hidden_layer_1, self.hidden_layer_2):
            hidden = self.dropout(functions.leaky_relu(layer(hidden), 0.2))
        return self.output_layer(hidden)


## Generator Network

The generator part of a GAN learns to create fake data by incorporating feedback from the discriminator. It learns to make the discriminator classify its output as real.

Generator training requires tighter integration between the generator and the discriminator than discriminator training requires. The portion of the GAN that trains the generator includes:

- random input
- generator network, which transforms the random input into a data instance
- discriminator network, which classifies the generated data
- discriminator output
- generator loss, which penalizes the generator for failing to fool the discriminator


**Reference: https://developers.google.com/machine-learning/gan/generator**

<img src="https://www.researchgate.net/publication/310610555/figure/fig4/AS:668380134662151@1536365650654/Architecture-of-the-generator-a-and-discriminator-b-of-our-cGAN-model-The-generator.ppm" alt="Architecture of a GAN generator and discriminator">

In [None]:
class Generator(neural_network.Module):
    """MLP that maps latent vectors (last dimension 100) to 784-value outputs in [-1, 1].

    Layer widths are 100 -> 32 -> 64 -> 128 -> 784. Each hidden layer is followed
    by LeakyReLU (negative slope 0.2) and Dropout(p=0.3); the output goes through tanh.
    """

    def __init__(self):
        super().__init__()
        # Same creation order as before, so seeded initialisation is unchanged.
        self.input_layer = neural_network.Linear(100, 32)
        self.hidden_layer_1 = neural_network.Linear(32, 64)
        self.hidden_layer_2 = neural_network.Linear(64, 128)
        self.output_layer = neural_network.Linear(128, 784)
        self.dropout = neural_network.Dropout(0.3)

    def forward(self, img_input):
        """Run the latent batch through the MLP and squash the result with tanh."""
        hidden = img_input
        for layer in (self.input_layer, self.hidden_layer_1, self.hidden_layer_2):
            hidden = self.dropout(functions.leaky_relu(layer(hidden), 0.2))
        return functions.tanh(self.output_layer(hidden))


In [None]:
# Seed the RNGs this notebook draws from (torch: weight init, DataLoader
# shuffling, randn noise; numpy: uniform latent noise) so that
# Restart & Run All reproduces the same training run.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

D = Discriminator()

In [None]:
# Generator network; its input layer expects 100-dimensional latent vectors.
G = Generator()

In [None]:
def real_loss(D_out, smooth=0.9):
    """Mean BCE-with-logits loss of discriminator logits against a "real" target.

    Parameters
    ----------
    D_out : torch.Tensor
        Discriminator logits, e.g. shape (N, 1) or (N,).
    smooth : float, optional
        Target value used for every sample. The default 0.9 keeps the one-sided
        label smoothing this notebook applies to real images; pass 1.0 for
        unsmoothed targets.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    # reshape(-1) instead of squeeze(): squeeze turns a batch of one into a 0-d
    # tensor, which no longer matches the (1,) target and raises.
    logits = D_out.reshape(-1)
    # full_like keeps the target on the same device/dtype as the logits.
    targets = torch.full_like(logits, smooth)
    return functions.binary_cross_entropy_with_logits(logits, targets)


def fake_loss(D_out):
    """Mean BCE-with-logits loss of discriminator logits against the "fake" target 0.

    Parameters
    ----------
    D_out : torch.Tensor
        Discriminator logits, e.g. shape (N, 1) or (N,).

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    logits = D_out.reshape(-1)
    return functions.binary_cross_entropy_with_logits(logits, torch.zeros_like(logits))

In [None]:
# One Adam optimiser per network (discriminator first, then generator), both at lr 0.002.
d_optimizer, g_optimizer = (optim.Adam(net.parameters(), lr=0.002) for net in (D, G))

# GAN Training
Because a GAN contains two separately trained networks, its training algorithm must address two complications:

- GANs must juggle two different kinds of training (generator and discriminator).
- GAN convergence is hard to identify.

### Alternating Training

The generator and the discriminator have different training processes. So how do we train the GAN as a whole?

GAN training proceeds in alternating periods:

1. The discriminator trains for one or more epochs.
2. The generator trains for one or more epochs.
3. Repeat steps 1 and 2 to continue to train the generator and discriminator networks.

In the training loop of this notebook, the alternation happens at mini-batch granularity: every batch performs one discriminator update followed by one generator update.

**Reference: https://developers.google.com/machine-learning/gan/training**

## Train Discriminator

<img src="https://www.researchgate.net/publication/333831200/figure/fig5/AS:782113389412353@1563481771648/GAN-framework-in-which-the-generator-and-discriminator-are-learned-during-the-training.png">

In [None]:
def train_discriminator(G, D, optimizer, real_data, batch_size, size):
    """Run one discriminator update on a batch of real images plus freshly generated fakes.

    Parameters
    ----------
    G, D : torch.nn.Module
        Generator and discriminator networks.
    optimizer : torch.optim.Optimizer
        Optimizer over the discriminator's parameters.
    real_data : torch.Tensor
        Batch of real images; assumed to have pixel values in [0, 1] (as produced
        by ``ToTensor``) — TODO confirm for other callers.
    batch_size : int
        Number of fake samples to generate.
    size : int
        Length of each latent vector fed to ``G``.

    Returns
    -------
    torch.Tensor
        Scalar sum of the real and fake discriminator losses.
    """
    # Flatten per image (using the actual batch length, so a short final batch
    # works) and rescale [0, 1] -> [-1, 1] to match the generator's tanh range.
    real_data = real_data.reshape(real_data.size(0), -1) * 2 - 1

    D.train()
    optimizer.zero_grad()

    # detach(): this step only updates D, so no graph is built back through G.
    fake_data = G(random_vector(batch_size, size)).detach()
    total_loss = real_loss(D(real_data)) + fake_loss(D(fake_data))
    total_loss.backward()
    optimizer.step()

    return total_loss

## Train Generator

In [None]:

def train_generator(G, D, optimizer, batch_size, size):
    """Run one generator update: push D's logits on fresh fakes toward "real".

    Parameters
    ----------
    G, D : torch.nn.Module
        Generator and discriminator networks.
    optimizer : torch.optim.Optimizer
        Optimizer over the generator's parameters.
    batch_size : int
        Number of fake samples to generate.
    size : int
        Length of each latent vector fed to ``G``.

    Returns
    -------
    torch.Tensor
        Scalar generator loss.
    """
    optimizer.zero_grad()
    G.train()

    fake_logits = D(G(random_vector(batch_size, size))).reshape(-1)
    # Target 1.0 rather than the 0.9 smoothed label used for D's real images:
    # label smoothing is meant to regularise the discriminator, not the generator.
    loss = functions.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
    loss.backward()
    optimizer.step()
    return loss

In [None]:
def random_vector(batch_size, lenx):
    """Sample a (batch_size, lenx) float32 tensor of standard-normal noise."""
    noise = torch.randn(batch_size, lenx)
    return noise.to(torch.float32)

In [None]:
def show_training_sample(data):
    """Plot up to 16 images from the most recent entry of ``data`` on a 4x4 grid.

    Parameters
    ----------
    data : list
        History of generator outputs; only ``data[-1]`` is drawn, and each of its
        items must hold 784 values (reshaped to 28x28). Nothing is drawn if empty.
    """
    if not data:
        return
    fig, axes = plt.subplots(figsize=(7, 7), nrows=4, ncols=4)
    for ax, img in zip(axes.flatten(), data[-1]):
        # Detach and move to CPU before handing the pixels to matplotlib.
        ax.imshow(img.detach().cpu().reshape(28, 28).numpy(), cmap='gray')
        ax.axis('off')
    plt.show()


def show_training_loss(d_loss_, g_loss):
    """Plot the discriminator and generator loss histories that are passed in.

    Parameters
    ----------
    d_loss_ : sequence of float
        Discriminator loss history.
    g_loss : sequence of float
        Generator loss history.
    """
    # Plot the arguments themselves rather than module-level lists.
    plt.plot(d_loss_, label="N.Discriminator", alpha=0.5)
    plt.plot(g_loss, label="N.Generator", alpha=0.5)
    plt.title("Trainings loss")
    plt.xlabel("batch")
    plt.ylabel("loss")
    plt.legend()
    plt.show()
    

In [None]:
import pickle as pkl

max_epochs = 100
z_size = 100       # latent vector length; must match the Generator's input layer
print_every = 200  # batches between progress lines (every 10 flooded the output)

samples = list()
losses = list()
d_loss_arr = list()
g_loss_arr = list()


def _sample_latent(n):
    """Return an (n, z_size) float32 batch of Uniform(-1, 1) latent vectors."""
    return torch.from_numpy(np.random.uniform(-1, 1, size=(n, z_size))).float()


# Fixed latent batch so the per-epoch samples are comparable across training.
Z = _sample_latent(16)

D.train()
G.train()

# max_epochs + 1 so that exactly `max_epochs` epochs run (1..max_epochs).
for epoch in range(1, max_epochs + 1):
    for index, (img, _) in enumerate(train):
        img = img * 2 - 1  # [0, 1] -> [-1, 1] to match the generator's tanh output

        # ---- discriminator step ----
        d_optimizer.zero_grad()
        # detach(): the discriminator loss must not backpropagate into G.
        fake_images = G(_sample_latent(img.size(0))).detach()
        d_loss = real_loss(D(img)) + fake_loss(D(fake_images))
        d_loss.backward()
        d_optimizer.step()

        # ---- generator step ----
        g_optimizer.zero_grad()
        fake_logits = D(G(_sample_latent(img.size(0)))).reshape(-1)
        # Unsmoothed target 1.0: label smoothing is for the discriminator's real
        # images only, not for the generator's objective.
        g_loss = functions.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
        g_loss.backward()
        g_optimizer.step()

        d_loss_arr.append(d_loss.item())
        g_loss_arr.append(g_loss.item())

        if index % print_every == 0:
            print(f'[{epoch}] -- discriminator_loss: {d_loss.item()}  -- generator_loss: {g_loss.item()}')

    losses.append((d_loss.item(), g_loss.item()))

    # Sample from the fixed Z with dropout off and without building an autograd
    # graph, so stored samples do not keep per-epoch graphs alive.
    G.eval()
    with torch.no_grad():
        samples.append(G(Z))
    G.train()

    show_training_sample(samples)
    show_training_loss(d_loss_arr, g_loss_arr)