In [1]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


# Seed every RNG the notebook draws from before anything random happens
# (network init, noise sampling, and - presumably, via numpy - the MNIST
# minibatch shuffling), so Restart & Run All reproduces the same run.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# MNIST with one-hot labels, extracted under MNIST_data/.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
X_dim = mnist.train.images.shape[1]  # flattened image length (reshaped to 28x28 for plotting)
y_dim = mnist.train.labels.shape[1]  # width of the one-hot label vectors
h_dim = 128  # hidden-layer width used by both G and D
z_dim = 10   # length of the generator's noise input
cnt = 0      # running index for the sample grids saved to out/

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
def _mlp(in_dim, hidden_dim, out_dim, *tail):
    """Return Sequential(Linear -> ReLU -> Linear), followed by any `tail` modules."""
    return torch.nn.Sequential(
        torch.nn.Linear(in_dim, hidden_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_dim, out_dim),
        *tail
    )


# Generator: z_dim noise -> X_dim values squashed into (0, 1) by a Sigmoid.
G = _mlp(z_dim, h_dim, X_dim, torch.nn.Sigmoid())

# Critic: X_dim input -> a single unbounded score (no output activation).
D = _mlp(X_dim, h_dim, 1)

In [None]:
def reset_grad():
    """Clear the accumulated parameter gradients of the generator and the critic."""
    for net in (G, D):
        net.zero_grad()


# One RMSprop optimizer per network, both with learning rate 1e-4.
G_solver, D_solver = (optim.RMSprop(net.parameters(), lr=0.0001) for net in (G, D))

In [None]:
# WGAN training loop (weight-clipping variant): n_critic critic updates per
# generator update; every log_every iterations, print the losses and save a
# 4x4 grid of generated samples to out_dir.
mb_size = 32       # minibatch size for both real images and noise
n_critic = 5       # critic (D) updates per generator (G) update
clip_value = 0.01  # D's weights are clamped to [-clip_value, clip_value]
n_iters = 1000000  # number of generator updates
log_every = 1000   # iterations between log lines / saved sample grids
out_dir = 'out/'


def save_sample_grid(samples, path):
    """Save at most 16 flattened 28x28 images as a borderless 4x4 grid PNG at `path`."""
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    plt.savefig(path, bbox_inches='tight')
    # Close explicitly so figures don't pile up in memory over the long run.
    plt.close(fig)


# Loop-invariant: make sure the output directory exists once, up front.
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for it in range(n_iters):
    for _ in range(n_critic):
        # Sample noise and a real minibatch (labels are not needed).
        z = Variable(torch.randn(mb_size, z_dim))
        X, _labels = mnist.train.next_batch(mb_size)
        X = Variable(torch.from_numpy(X))

        # Discriminator forward-loss-backward-update.
        # detach(): the critic loss must not back-propagate into G - those
        # gradients would only be computed to be discarded by reset_grad().
        G_sample = G(z).detach()
        D_real = D(X)
        D_fake = D(G_sample)

        # The critic maximizes mean(D(x)) - mean(D(G(z))); minimize the negation.
        D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

        D_loss.backward()
        D_solver.step()

        # Weight clipping, applied after every critic update.
        for p in D.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # Housekeeping - reset gradient
        reset_grad()

    # Generator forward-loss-backward-update. Only noise is needed here,
    # so no real minibatch is drawn.
    z = Variable(torch.randn(mb_size, z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = -torch.mean(D_fake)

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient (this also clears the gradients that
    # G_loss.backward() accumulated in D's parameters).
    reset_grad()

    # Print and plot every now and then
    if it % log_every == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'
              .format(it, D_loss.data.numpy(), G_loss.data.numpy()))

        samples = G(z).data.numpy()[:16]
        save_sample_grid(samples, os.path.join(out_dir, '{}.png'.format(str(cnt).zfill(3))))
        cnt += 1

Iter-0; D_loss: [-0.12624831]; G_loss: [ 0.16157164]
Iter-1000; D_loss: [-0.03480838]; G_loss: [-0.00120416]
Iter-2000; D_loss: [-0.00738919]; G_loss: [ 0.0535388]
Iter-3000; D_loss: [-0.05886087]; G_loss: [ 0.00255544]
Iter-4000; D_loss: [-0.04449232]; G_loss: [ 0.00304415]
Iter-5000; D_loss: [-0.03686142]; G_loss: [-0.00836258]
Iter-6000; D_loss: [-0.03010887]; G_loss: [-0.00741883]
Iter-7000; D_loss: [-0.03416763]; G_loss: [-0.0095391]
Iter-8000; D_loss: [-0.03482541]; G_loss: [-0.00438419]
Iter-9000; D_loss: [-0.0277044]; G_loss: [-0.0063392]
Iter-10000; D_loss: [-0.02745331]; G_loss: [-0.00485104]
Iter-11000; D_loss: [-0.0272183]; G_loss: [-0.00546759]
Iter-12000; D_loss: [-0.02287278]; G_loss: [-0.00616492]
Iter-13000; D_loss: [-0.02330181]; G_loss: [-0.00553512]
Iter-14000; D_loss: [-0.02414522]; G_loss: [-0.00600764]
Iter-15000; D_loss: [-0.02072017]; G_loss: [-0.00366222]
Iter-16000; D_loss: [-0.02135508]; G_loss: [-0.00627525]
Iter-17000; D_loss: [-0.01786141]; G_loss: [-0.00