## Author: Francisco Carrillo-Perez

# 3. Generative Adversarial Networks

In this notebook we are going to:
   - Learn to code our first Generative Adversarial Network.
   - Learn to train and test our model.

### 3.1. Utility functions

First we are going to write some helper functions. Two of them convert images to vectors and vectors back to images, because both the generator and the discriminator work with MNIST images flattened into 784-dimensional vectors. Another returns n samples of Gaussian noise to use as the generator's input. We also add helpers that build the "real" and "fake" label tensors and save generated samples to disk.
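As a plain-Python illustration of the image/vector conversion (no PyTorch needed), the sketch below flattens a 28x28 grid row by row. This is the element order `view(N, 784)` uses for a contiguous batch of shape (N, 1, 28, 28). The helpers `flatten_image` and `unflatten_vector` are hypothetical stand-ins for the tensor versions in the next cell.

```python
# Plain-Python sketch of row-major flattening, the same element order that
# images.view(N, 784) produces for a contiguous (N, 1, 28, 28) tensor.
# flatten_image / unflatten_vector are hypothetical helpers for illustration.

def flatten_image(image):
    """Flatten a list of rows (H x W) into a single list of H*W values."""
    return [pixel for row in image for pixel in row]


def unflatten_vector(vector, height=28, width=28):
    """Rebuild H rows of W values from a flat list (inverse of flatten_image)."""
    if len(vector) != height * width:
        raise ValueError(f"expected {height * width} values, got {len(vector)}")
    return [vector[r * width:(r + 1) * width] for r in range(height)]


# A 28x28 "image" whose pixel value encodes its position: row * 28 + column.
image = [[r * 28 + c for c in range(28)] for r in range(28)]
vector = flatten_image(image)

print(len(vector))                        # 784
print(vector[28])                         # first pixel of the second row: 28
print(unflatten_vector(vector) == image)  # True
```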

```python
import os

import torch
from torchvision import datasets, transforms
from torchvision.utils import save_image


def mnist_dataset():
    """Download (if needed) and return the MNIST training set, scaled to [-1, 1]."""
    compose = transforms.Compose([
        transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
        transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
    ])
    out_dir = "./dataset"
    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)


def images_to_vectors(images):
    """Flatten a batch of 1x28x28 images into a batch of 784-dimensional vectors."""
    return images.view(images.size(0), 784)


def vectors_to_images(vectors):
    """Reshape a batch of 784-dimensional vectors into 1x28x28 images."""
    return vectors.view(vectors.size(0), 1, 28, 28)


def noise(size):
    """Return a (size, 100) tensor of standard Gaussian noise, one latent vector per row."""
    # Plain tensors track gradients directly; the old Variable wrapper is no longer needed.
    return torch.randn(size, 100)


def ones_target(size):
    """Return a (size, 1) tensor of ones: the label for 'real' samples."""
    return torch.ones(size, 1)


def zeros_target(size):
    """Return a (size, 1) tensor of zeros: the label for 'fake' samples."""
    return torch.zeros(size, 1)


def save_figures(images, path="./generated/"):
    """Save a batch of images as one grid, generated.png, inside `path`."""
    # Create the directory if it does not exist yet
    os.makedirs(path, exist_ok=True)
    # save_image expects values in [0, 1]; map the Tanh output back from [-1, 1]
    save_image((images + 1) / 2, os.path.join(path, "generated.png"))
```
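`mnist_dataset()` scales pixels so that real images live in the same range as the generator's `Tanh` output. A plain-Python sketch of that mapping and its inverse follows; the names `normalize` and `denormalize` are hypothetical and only mirror what `transforms.Normalize((0.5,), (0.5,))` computes.

```python
# Plain-Python sketch of the pixel scaling used in this notebook.
# ToTensor() yields values in [0, 1]; Normalize((0.5,), (0.5,)) then applies
# (x - mean) / std = (x - 0.5) / 0.5, which maps [0, 1] onto [-1, 1],
# the same range as the generator's Tanh output.

MEAN, STD = 0.5, 0.5


def normalize(x):
    """Map a pixel from [0, 1] to [-1, 1], as transforms.Normalize does here."""
    return (x - MEAN) / STD


def denormalize(y):
    """Inverse mapping from [-1, 1] back to [0, 1], e.g. before saving an image."""
    return y * STD + MEAN


for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```

The inverse matters when writing samples to disk: `torchvision.utils.save_image` converts values to 8-bit pixels assuming a [0, 1] range and clips anything outside it, so raw generator outputs should go through `(x + 1) / 2` first.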

### 3.2. Generator

The generator is a model that aims to produce new data resembling the training data. We write it as G(z, θ₁), where θ₁ are its parameters. Its role is to map input noise variables z to the desired data space x (here, 28x28 MNIST images).
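Written out with the sizes used in this notebook (a 100-dimensional Gaussian noise vector, three LeakyReLU hidden layers of widths 256, 512 and 1024, and a `Tanh` output of 784 pixels), the generator is a composition of affine maps and nonlinearities. The symbols $W_i$, $b_i$ and $\phi$ are notation introduced here for illustration:

$$
G(z;\theta_1) = \tanh\!\big(W_4\,\phi(W_3\,\phi(W_2\,\phi(W_1 z + b_1) + b_2) + b_3) + b_4\big),
\qquad z \sim \mathcal{N}(0, I_{100})
$$

$$
\phi(a) = \max(a,\ 0.2a), \qquad
\theta_1 = \{W_1, b_1, \dots, W_4, b_4\}, \qquad
W_1 \in \mathbb{R}^{256\times 100},\ W_2 \in \mathbb{R}^{512\times 256},\ W_3 \in \mathbb{R}^{1024\times 512},\ W_4 \in \mathbb{R}^{784\times 1024}
$$

so $G$ maps $\mathbb{R}^{100}$ into $(-1, 1)^{784}$, the same range as the normalized MNIST pixels.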

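```python
import torch
from torch import nn, optim


class GeneratorNet(nn.Module):
    """
    Fully connected generator with three hidden layers:
    100-d noise -> 256 -> 512 -> 1024 -> 784 pixels in (-1, 1).
    """

    def __init__(self):
        super().__init__()
        n_features = 100  # size of the latent noise vector z
        n_out = 784       # a flattened 28x28 image

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.LeakyReLU(0.2)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2)
        )
        self.out = nn.Sequential(
            nn.Linear(1024, n_out),
            nn.Tanh()  # match the [-1, 1] range of the normalized images
        )

        self.optimizer = optim.Adam(self.parameters(), lr=0.00002)

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

    # Note: defining train() here shadows nn.Module.train(mode), so calling
    # .eval() on this model raises a TypeError; a name like train_step avoids that.
    def train(self, discriminator, fake_data):
        """One generator update: push D(fake_data) towards the 'real' label."""
        loss = nn.BCELoss()
        N = fake_data.size(0)

        # Reset the generator's gradients
        self.optimizer.zero_grad()

        # Score the generated samples with the discriminator
        prediction = discriminator(fake_data)

        # Label the fakes as real (1) so the generator learns to fool the discriminator.
        # backward() also fills the discriminator's .grad buffers; those are cleared
        # by the discriminator's own zero_grad() at its next step.
        error = loss(prediction, ones_target(N))
        error.backward()

        # Update only the generator's weights
        self.optimizer.step()

        return error
```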
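`GeneratorNet.train` compares D(G(z)) against a target of ones, so with `nn.BCELoss` its loss is -log D(G(z)). This is the non-saturating generator loss, used in place of the log(1 - D(G(z))) term of the minimax objective because it gives much larger gradients when the discriminator confidently rejects fakes. The plain-Python sketch below compares the two at an assumed early-training value D(G(z)) = 0.01; `bce` is a hypothetical helper mirroring `nn.BCELoss` for a single prediction.

```python
import math


def bce(p, target):
    """Binary cross-entropy for one prediction p in (0, 1) and a 0/1 target."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))


p = 0.01  # assumed D(G(z)) early in training: the discriminator easily spots fakes

# Loss minimized in GeneratorNet.train: BCE against a target of 1, i.e. -log(p).
non_saturating = bce(p, 1)
# Generator term of the original minimax objective: log(1 - p).
saturating = math.log(1 - p)

# Analytic derivatives with respect to p. Both are negative (both losses fall
# as p rises), but the non-saturating one is about 100x larger in magnitude.
grad_non_saturating = -1 / p       # d/dp [-log p]
grad_saturating = -1 / (1 - p)     # d/dp [log(1 - p)]

print(round(non_saturating, 3))    # 4.605
print(grad_non_saturating)         # -100.0
print(round(grad_saturating, 3))   # -1.01
```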

### 3.3. Discriminator

The discriminator, in contrast, aims to recognize whether an input is 'real' (it belongs to the original dataset) or 'fake' (it was produced by the generator, the forger). D(x, θ₂) models the discriminator and outputs the probability, in the range (0, 1), that x came from the real dataset.
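Together, the two networks play the standard GAN minimax game, written here with θ₁ for the generator and θ₂ for the discriminator:

$$
\min_{\theta_1}\max_{\theta_2}\; V(\theta_1,\theta_2) =
\mathbb{E}_{x\sim p_{\text{data}}}\big[\log D(x;\theta_2)\big] +
\mathbb{E}_{z\sim \mathcal{N}(0,I)}\big[\log\big(1 - D(G(z;\theta_1);\theta_2)\big)\big]
$$

Maximizing $V$ over θ₂ is the same as minimizing binary cross-entropy with label 1 on real samples and label 0 on generated ones, which is what `DiscriminatorNet.train` does. For a fixed generator whose samples have density $p_g$, the best possible discriminator is

$$
D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)},
$$

which equals $1/2$ everywhere once $p_g = p_{\text{data}}$.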

```python
from torch import nn, optim


class DiscriminatorNet(nn.Module):
    """
    Fully connected discriminator with three hidden layers:
    784 pixels -> 1024 -> 512 -> 256 -> probability that the input is real.
    """

    def __init__(self):
        super().__init__()
        n_features = 784  # a flattened 28x28 image
        n_out = 1         # probability of "real"

        self.hidden0 = nn.Sequential(
            nn.Linear(n_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out = nn.Sequential(
            nn.Linear(256, n_out),
            nn.Sigmoid()  # squash the score into (0, 1) for BCELoss
        )

        self.optimizer = optim.Adam(self.parameters(), lr=0.00002)

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

    # Note: defining train() here shadows nn.Module.train(mode), so calling
    # .eval() on this model raises a TypeError; a name like train_step avoids that.
    def train(self, real_data, fake_data):
        """One discriminator update on a batch of real and a batch of fake samples."""
        loss = nn.BCELoss()
        N = real_data.size(0)

        # Reset gradients (also clears any left over from the generator's step)
        self.optimizer.zero_grad()

        # 1.1 Train on real data: target label 1
        prediction_real = self(real_data)
        error_real = loss(prediction_real, ones_target(N))
        error_real.backward()

        # 1.2 Train on fake data: target label 0
        prediction_fake = self(fake_data)
        error_fake = loss(prediction_fake, zeros_target(N))
        error_fake.backward()

        # 1.3 Update weights with the accumulated gradients of both terms
        self.optimizer.step()

        # Return the total error and the predictions for real and fake inputs
        return error_real + error_fake, prediction_real, prediction_fake
```
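`DiscriminatorNet.train` returns the sum of two mean binary cross-entropies, which gives its value a useful reference point. The plain-Python sketch below computes it for a discriminator that always answers 0.5 and for a confident, correct one; the helpers `bce_mean` and `discriminator_error` are hypothetical, written to mirror `nn.BCELoss()` with its default mean reduction.

```python
import math


def bce_mean(predictions, target):
    """Mean binary cross-entropy over a batch, like nn.BCELoss() with reduction='mean'."""
    losses = [-(target * math.log(p) + (1 - target) * math.log(1 - p)) for p in predictions]
    return sum(losses) / len(losses)


def discriminator_error(pred_real, pred_fake):
    """Sum of the two terms in DiscriminatorNet.train: real vs label 1, fake vs label 0."""
    return bce_mean(pred_real, 1) + bce_mean(pred_fake, 0)


# A discriminator that outputs 0.5 for everything (pure guessing).
chance = discriminator_error([0.5] * 100, [0.5] * 100)
print(round(chance, 4))  # 1.3863, i.e. 2 * ln 2

# A confident, correct discriminator gets a much lower error.
confident = discriminator_error([0.9] * 100, [0.1] * 100)
print(round(confident, 4))  # 0.2107, i.e. -2 * ln 0.9
```

With outputs of 0.5 everywhere, the generator's error would likewise be ln 2 ≈ 0.693. The first logged values in the training output of this notebook (about 1.36 for the discriminator and 0.68 for the generator) sit close to these chance levels, as one would expect from freshly initialized networks.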

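### 3.4. Train our model
Let's train both networks at the same time. For every mini-batch we first update the discriminator on a batch of real images and a batch of generated ones, then update the generator so that its new samples are more likely to be classified as real.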
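A quick back-of-the-envelope check of how much work this loop does, in plain Python. The only number not taken from the training cell is the size of the MNIST training split, 60,000 images; `DataLoader` keeps a last partial batch by default, hence the ceiling division.

```python
import math

# Assumption: torchvision's MNIST training split (train=True) has 60,000 images.
train_images = 60_000
# Values below come from the training cell.
batch_size = 100
n_epochs = 200
print_every = 100  # batches between log entries
save_every = 20    # epochs between saved sample grids

num_batches = math.ceil(train_images / batch_size)  # len(data_loader)
updates_per_network = num_batches * n_epochs        # one D step and one G step per batch
logs_per_epoch = len(range(0, num_batches, print_every))
saved_epochs = list(range(0, n_epochs, save_every))

print(num_batches)          # 600
print(updates_per_network)  # 120000
print(logs_per_epoch)       # 6
print(saved_epochs)         # [0, 20, 40, 60, 80, 100, 120, 140, 160, 180]
```

Six log entries per epoch (at batches 0, 100, 200, 300, 400 and 500) is exactly what the recorded output shows before the first `Saving samples` line.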

```python
import torch

# Load data
data = mnist_dataset()

# Create a loader so that we can iterate over shuffled mini-batches
data_loader = torch.utils.data.DataLoader(data, batch_size=100, shuffle=True)

# Number of batches per epoch
num_batches = len(data_loader)

# Total number of epochs
n_epochs = 200

# Fixed noise, reused at every checkpoint so saved samples are comparable across epochs
num_test_samples = 16
test_noise = noise(num_test_samples)

generator = GeneratorNet()
discriminator = DiscriminatorNet()

print("Starting training")
for epoch in range(n_epochs):
    for n_batch, (real_batch, _) in enumerate(data_loader):
        N = real_batch.size(0)

        # 1. Train the discriminator
        real_data = images_to_vectors(real_batch)

        # Generate fake data and detach it, so no gradients flow back into the generator
        fake_data = generator(noise(N)).detach()

        d_error, d_pred_real, d_pred_fake = discriminator.train(real_data, fake_data)

        # 2. Train the generator on a fresh, non-detached batch of fake data
        fake_data = generator(noise(N))
        g_error = generator.train(discriminator, fake_data)

        # Log progress every 100 batches
        if n_batch % 100 == 0:
            print("-------")
            print("Discriminator error: {}".format(d_error.item()))
            print("\n")
            print("Generator error: {}".format(g_error.item()))
            print("-------")

    # Save a grid of samples generated from the fixed noise every 20 epochs
    if epoch % 20 == 0:
        print("Saving samples")
        with torch.no_grad():
            test_images = vectors_to_images(generator(test_noise))
        save_figures(test_images, path="./generated/epoch-" + str(epoch) + "/")
```

```text
Starting training
-------
Discriminator error: 1.360549807548523


Generator error: 0.6761853098869324
-------
-------
Discriminator error: 1.035473108291626


Generator error: 0.5511122941970825
-------
-------
Discriminator error: 1.0761303901672363


Generator error: 0.6966161131858826
-------
-------
Discriminator error: 0.93159019947052


Generator error: 0.9237263202667236
-------
-------
Discriminator error: 0.694118082523346


Generator error: 1.2150769233703613
-------
-------
Discriminator error: 0.6711529493331909


Generator error: 1.195939540863037
-------
Saving samples
-------
Discriminator error: 0.5739818215370178


Generator error: 1.3262486457824707
-------
-------
Discriminator error: 1.3923016786575317


Generator error: 0.6596801280975342
-------
-------
Discriminator error: 1.1216189861297607


Generator error: 0.855226457118988
-------
-------
Discriminator error: 1.108191967010498


Generator error: 0.9839785695075989
-------
-------
Discriminator error: 1.0717

KeyboardInterrupt:
```