In [None]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim

from torch.autograd import Variable

import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from torchvision import datasets, transforms

%matplotlib inline

In [None]:
# Train on the GPU whenever PyTorch can see one; otherwise everything stays on the CPU.
use_cuda = bool(torch.cuda.is_available())

In [None]:
# Seed PyTorch's global RNG so weight initialisation, noise draws and
# DataLoader shuffling are reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)

batch_size = 16       # images per mini-batch
noise_size = 100      # length of the generator's input noise vector
hidden_size = 128     # width of the single hidden layer in G and D
learning_rate = 1e-3  # Adam step size for both networks

In [None]:
# MNIST training split, downloaded into data/MNIST if it is not already there;
# ToTensor() converts each image to a float tensor.
mnist_train_set = datasets.MNIST(
    'data/MNIST',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batches of `batch_size` (image, label) pairs.
mnist = torch.utils.data.DataLoader(mnist_train_set, batch_size=batch_size, shuffle=True)

In [None]:
# Read the image dimensions off one transformed sample rather than the
# `train_data` attribute, which torchvision deprecated (renamed to `data`).
# After ToTensor() a sample is (channels, height, width), so the last two
# dimensions are the spatial size.
sample_image = mnist.dataset[0][0]
image_height, image_width = sample_image.shape[-2:]
image_size = image_height * image_width  # length of a flattened image vector
print(image_size)

In [None]:
def initialize_weights(m):
    """Initialise a single module in place; meant to be passed to ``Module.apply``.

    Every ``nn.Linear`` (including subclasses) gets Xavier-normal weights and,
    if it has a bias, a zero bias. All other modules are left untouched.

    Args:
        m: the module being visited by ``Module.apply``.
    """
    if isinstance(m, nn.Linear):
        # The trailing-underscore initialisers are the in-place, non-deprecated API.
        nn.init.xavier_normal_(m.weight)
        # Linear layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

In [None]:
class Generator(nn.Module):
    """Two-layer MLP that maps noise vectors to flattened images.

    Layers: Linear(noise_size -> hidden_size) -> ReLU
            -> Linear(hidden_size -> image_size) -> Sigmoid.
    The final sigmoid keeps every generated pixel in (0, 1), matching the
    [0, 1] range of the real images loaded with transforms.ToTensor().
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Attribute names (h1, g1, h2, g2) are kept as-is: they are the
        # state_dict keys, and their registration order fixes parameters() order.
        self.h1 = nn.Linear(noise_size, hidden_size)
        self.g1 = nn.ReLU()
        self.h2 = nn.Linear(hidden_size, image_size)
        self.g2 = nn.Sigmoid()

        # Re-initialise every Linear layer with the notebook's scheme.
        self.apply(initialize_weights)

    def forward(self, A0):
        # Push the input through each stage in turn.
        out = A0
        for stage in (self.h1, self.g1, self.h2, self.g2):
            out = stage(out)
        return out

In [None]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores flattened images.

    Layers: Linear(image_size -> hidden_size) -> ReLU
            -> Linear(hidden_size -> 1) -> Sigmoid.
    The sigmoid output is used as the probability that an input is a real
    image (it is trained with binary cross-entropy against 1 for real and
    0 for generated samples).
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Attribute names (h1, g1, h2, g2) are kept as-is: they are the
        # state_dict keys, and their registration order fixes parameters() order.
        self.h1 = nn.Linear(image_size, hidden_size)
        self.g1 = nn.ReLU()
        self.h2 = nn.Linear(hidden_size, 1)
        self.g2 = nn.Sigmoid()

        # Re-initialise every Linear layer with the notebook's scheme.
        self.apply(initialize_weights)

    def forward(self, A0):
        hidden = self.g1(self.h1(A0))
        logits = self.h2(hidden)
        return self.g2(logits)

In [None]:
# Build both networks. When a GPU is available, move them there *before*
# constructing their optimizers, as the torch.optim documentation recommends.
G = Generator()
D = Discriminator()
if use_cuda:
    G = G.cuda()
    D = D.cuda()

G_solver = optim.Adam(G.parameters(), lr=learning_rate)
D_solver = optim.Adam(D.parameters(), lr=learning_rate)

# Full-batch BCE targets: 1 marks "real", 0 marks "generated".
# Plain tensors suffice: torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4.
ones_label = torch.ones(batch_size, 1)
zeros_label = torch.zeros(batch_size, 1)

In [None]:
def generate_noise():
    """Draw a fresh batch of generator inputs.

    Returns:
        A tensor of shape (batch_size, noise_size) with i.i.d. standard-normal
        entries, created on PyTorch's default device; the training loop moves
        it to the GPU itself when ``use_cuda`` is set.
    """
    noise_shape = (batch_size, noise_size)
    return torch.randn(*noise_shape)

In [None]:
# Put both networks and the fixed BCE target tensors on the GPU when one is present.
# Calling .cuda() on something already on the GPU returns it unchanged, so re-running is harmless.
if use_cuda:
    G, D = G.cuda(), D.cuda()
    ones_label, zeros_label = ones_label.cuda(), zeros_label.cuda()

In [None]:
epochs = 10

for epoch in range(epochs):
    for X, _ in mnist:
        # Flatten each real image so it matches the discriminator's input layer.
        X = X.view(-1, image_size)
        Z = generate_noise()
        if use_cuda:
            Z = Z.cuda()
            X = X.cuda()

        # --- Discriminator step: push D(real) towards 1 and D(fake) towards 0. ---
        D_solver.zero_grad()

        G_Z = G(Z)
        D_X = D(X)
        # detach(): this step only updates D, so don't backpropagate into G.
        D_G = D(G_Z.detach())

        # Targets take their shape/device from the outputs, so a final partial
        # batch (dataset size not divisible by batch_size) still works.
        D_X_loss = F.binary_cross_entropy(D_X, torch.ones_like(D_X))
        D_G_loss = F.binary_cross_entropy(D_G, torch.zeros_like(D_G))
        D_loss = D_X_loss + D_G_loss

        D_loss.backward()
        D_solver.step()

        # --- Generator step: make D classify fresh fakes as real. ---
        G_solver.zero_grad()

        Z = generate_noise()
        if use_cuda:
            Z = Z.cuda()

        D_G = D(G(Z))
        G_loss = F.binary_cross_entropy(D_G, torch.ones_like(D_G))

        # This also fills D's .grad, but D_solver.zero_grad() clears it before D's next step.
        G_loss.backward()
        G_solver.step()

    # Reports the losses of the epoch's last mini-batch; .item() works for CPU and GPU tensors alike.
    print('Epoch {}; D loss: {}; G loss: {}'.format(epoch + 1, D_loss.item(), G_loss.item()))