# Generative Adversarial Networks

Paper: [Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014.](https://papers.nips.cc/paper/5423-generative-adversarial-nets)

In [None]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim

from torch.autograd import Variable

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from torchvision import datasets, transforms

from tqdm import tqdm

%matplotlib inline

## Load the dataset

In [None]:
# True when PyTorch reports a usable CUDA device; later cells check this
# flag before calling ``.cuda()`` on models and tensors.
use_cuda = torch.cuda.is_available()

In [None]:
# Shared hyper-parameters for data loading, both networks and both optimizers.
batch_size = 16        # images per DataLoader minibatch
noise_size = 100       # length of each latent vector fed to the generator
hidden_size = 128      # width of the single hidden layer in G and in D
learning_rate = 0.001  # Adam step size used for both G and D

In [None]:
# Shuffled MNIST training set, one ``batch_size`` minibatch per iteration.
# drop_last=True guarantees every minibatch has exactly ``batch_size`` rows,
# which the fixed-size ``ones_label`` / ``zeros_label`` BCE targets rely on;
# without it, a batch_size that does not divide the dataset size would crash
# on the shorter final batch.
mnist = torch.utils.data.DataLoader(
    datasets.MNIST('data/mnist/raw/', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Derive the image geometry from one transformed sample instead of the
# ``train_data`` attribute, which newer torchvision releases deprecate
# (renamed to ``data``). ToTensor yields C x H x W, so the last two
# dimensions are height and width.
first_image, _ = mnist.dataset[0]
image_height, image_width = (int(d) for d in first_image.shape[-2:])
# Both networks operate on flattened images of this many features.
image_size = image_height * image_width
print("Number of features:", image_size)

## Models

In [None]:
def initialize_weights(m):
    """Initialise ``m`` in place if it is a linear layer.

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    submodule. For ``nn.Linear`` instances (including subclasses) the weight
    receives Xavier/Glorot-normal initialisation and the bias, when the layer
    has one, is set to zero. Any other module is left untouched.

    Args:
        m: the module to (possibly) initialise.
    """
    # isinstance rather than ``type(m) ==`` so subclasses of nn.Linear count too.
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        # nn.Linear(..., bias=False) stores ``bias = None``; constant_ would fail on it.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

In [None]:
class Generator(nn.Module):
    """Two-layer MLP mapping latent noise vectors to flattened images.

    Architecture: ``Linear(noise_size, hidden_size) -> ReLU ->
    Linear(hidden_size, image_size) -> Sigmoid``, so every output value
    lies in (0, 1). Linear layers are initialised by ``initialize_weights``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Attribute names (h*/g*) are kept stable so state_dict keys match.
        self.h1 = nn.Linear(noise_size, hidden_size)
        self.g1 = nn.ReLU()
        self.h2 = nn.Linear(hidden_size, image_size)
        self.g2 = nn.Sigmoid()

        self.apply(initialize_weights)

    def forward(self, A0):
        """Return ``sigmoid(h2(relu(h1(A0))))`` for a batch of noise ``A0``."""
        hidden = self.g1(self.h1(A0))
        return self.g2(self.h2(hidden))

In [None]:
class Discriminator(nn.Module):
    """Two-layer MLP scoring flattened images with a single value in (0, 1).

    Architecture: ``Linear(image_size, hidden_size) -> ReLU ->
    Linear(hidden_size, 1) -> Sigmoid``. The training loop pushes this score
    towards 1 for real images and 0 for generated ones. Linear layers are
    initialised by ``initialize_weights``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Attribute names (h*/g*) are kept stable so state_dict keys match.
        self.h1 = nn.Linear(image_size, hidden_size)
        self.g1 = nn.ReLU()
        self.h2 = nn.Linear(hidden_size, 1)
        self.g2 = nn.Sigmoid()

        self.apply(initialize_weights)

    def forward(self, A0):
        """Return ``sigmoid(h2(relu(h1(A0))))`` for a batch of flattened images ``A0``."""
        hidden = self.g1(self.h1(A0))
        return self.g2(self.h2(hidden))

## Train the networks

In [None]:
G = Generator()
D = Discriminator()

# Move the networks to the GPU *before* building their optimizers, as the
# torch.optim documentation requires, so each optimizer holds the final
# parameter objects.
if use_cuda:
    G = G.cuda()
    D = D.cuda()

G_solver = optim.Adam(G.parameters(), lr=learning_rate)
D_solver = optim.Adam(D.parameters(), lr=learning_rate)

# Constant BCE targets for one full minibatch: 1 marks "real", 0 marks "fake".
# (torch.autograd.Variable is no longer needed; plain tensors are used.)
ones_label = torch.ones(batch_size, 1)
zeros_label = torch.zeros(batch_size, 1)
if use_cuda:
    ones_label = ones_label.cuda()
    zeros_label = zeros_label.cuda()

In [None]:
def generate_noise(size):
    """Draw ``size`` latent vectors from a standard normal distribution.

    Returns a tensor of shape ``(size, noise_size)``, created on the default
    device; callers move it to the GPU themselves when needed.
    """
    shape = (size, noise_size)
    return torch.randn(shape)

In [None]:
def generate_images(size):
    """Sample ``size`` flattened images from the global generator ``G``.

    Runs without gradient tracking. The result always lives on the CPU, even
    when ``G`` runs on the GPU, so it can be handed straight to matplotlib.

    Args:
        size: number of images to generate.

    Returns:
        A CPU tensor of shape ``(size, image_size)``.
    """
    with torch.no_grad():
        Z = generate_noise(size)
        if use_cuda:
            Z = Z.cuda()
        G_Z = G(Z)
    # matplotlib cannot read CUDA tensors; .cpu() is a no-op for CPU tensors.
    return G_Z.cpu()

In [None]:
def plot_images(images):
    """Show ``images`` side by side in one row, one small panel per image.

    Args:
        images: a sequence of flattened grayscale images (e.g. a 2-D torch
            tensor or NumPy array), each holding ``image_height * image_width``
            values. Torch tensors may live on any device; they are copied to
            the CPU before plotting.
    """
    if len(images) == 0:
        # GridSpec needs at least one column, and there is nothing to draw.
        return

    fig = plt.figure(figsize=(len(images), 1))
    gs = gridspec.GridSpec(1, len(images))
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        if torch.is_tensor(image):
            # matplotlib cannot read CUDA tensors (or tensors requiring grad).
            image = image.detach().cpu().numpy()

        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_aspect('equal')

        # imshow expects (rows, cols) = (height, width); the original order
        # was swapped, which only worked because MNIST images are square.
        ax.imshow(image.reshape(image_height, image_width),
                  cmap='Greys_r',
                  interpolation='None')

    plt.show()

In [None]:
if use_cuda:
    # Put both networks and the constant BCE targets on the GPU.
    G, D = G.cuda(), D.cuda()
    ones_label, zeros_label = ones_label.cuda(), zeros_label.cuda()

In [None]:
# Show a row of samples before training, then again after every epoch.
plot_size = 10
plot_images(generate_images(plot_size))

epochs = 10

for epoch in range(epochs):
    with tqdm(total=len(mnist)) as progress_bar:
        progress_bar.set_description("Epoch {:d}/{:d}".format(epoch + 1, epochs))
        for X, _ in mnist:
            # Size noise and targets from the real batch, so a shorter final
            # batch cannot cause a shape mismatch in the BCE losses.
            current_batch = X.size(0)

            # ---- Discriminator step: push D(x) -> 1 and D(G(z)) -> 0. ----
            D_solver.zero_grad()

            Z = generate_noise(current_batch)
            X = X.view(-1, image_size)
            if use_cuda:
                Z = Z.cuda()
                X = X.cuda()

            # detach(): D's loss must not backpropagate through G. This skips
            # wasted work; G's gradients were cleared before its step anyway.
            G_Z = G(Z).detach()
            D_X = D(X)
            D_G = D(G_Z)

            D_X_loss = F.binary_cross_entropy(D_X, torch.ones_like(D_X))
            D_G_loss = F.binary_cross_entropy(D_G, torch.zeros_like(D_G))
            D_loss = D_X_loss + D_G_loss

            D_loss.backward()
            D_solver.step()

            # ---- Generator step: push D(G(z)) -> 1 on fresh noise. ----
            G_solver.zero_grad()

            Z = generate_noise(current_batch)
            if use_cuda:
                Z = Z.cuda()

            D_G = D(G(Z))
            G_loss = F.binary_cross_entropy(D_G, torch.ones_like(D_G))

            G_loss.backward()
            G_solver.step()

            # .item() copies the scalar to the host; no explicit .cpu() needed.
            progress_bar.set_postfix(D_loss="{:.03f}".format(D_loss.item()),
                                     G_loss="{:.03f}".format(G_loss.item()))
            progress_bar.update()

    plot_images(generate_images(plot_size))