In [2]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


# Load MNIST into MNIST_data/ (downloaded on first run); labels are one-hot encoded.
# NOTE(review): tensorflow.examples.tutorials.mnist only ships with TensorFlow 1.x.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

X_dim = mnist.train.images.shape[1]   # features per flattened image (the plots reshape these to 28x28)
y_dim = mnist.train.labels.shape[1]   # one-hot label width; not used by the unconditional WGAN below
z_dim = 10                            # latent noise size fed to G; was previously undefined (NameError in the model cell)
h_dim = 128                           # hidden-layer width shared by G and D
cnt = 0                               # running index for saved sample grids: out/000.png, out/001.png, ...

# Seed the RNGs so Restart & Run All reproduces the same run:
# torch drives the latent noise; numpy presumably drives TF's minibatch shuffling (TODO confirm).
torch.manual_seed(0)
np.random.seed(0)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
# Generator: a z_dim noise vector -> one hidden ReLU layer -> an X_dim image,
# with a final sigmoid so every pixel lies in (0, 1).
G = torch.nn.Sequential(
    torch.nn.Linear(in_features=z_dim, out_features=h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=h_dim, out_features=X_dim),
    torch.nn.Sigmoid(),
)


# Critic: an X_dim image -> one hidden ReLU layer -> a single unbounded score.
# There is deliberately no output sigmoid: the training loop uses the WGAN loss
# (mean score difference), not a probability.
D = torch.nn.Sequential(
    torch.nn.Linear(in_features=X_dim, out_features=h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=h_dim, out_features=1),
)

In [5]:
def reset_grad():
    """Clear the accumulated gradients of both the generator G and the critic D."""
    for net in (G, D):
        net.zero_grad()


# One RMSprop optimizer per network, both with learning rate 1e-4 (created G first, then D).
G_solver, D_solver = (optim.RMSprop(net.parameters(), lr=1e-4) for net in (G, D))

In [6]:
# WGAN training with weight clipping: n_critic critic (D) updates per generator (G) update.
mb_size = 32        # minibatch size for both real and generated samples
n_critic = 5        # critic steps per generator step
clip_value = 0.01   # critic weights are clamped to [-clip_value, clip_value] after every step
n_iters = 1000000   # generator steps; interrupt the kernel to stop early


def plot_samples(samples, path):
    """Save up to 16 flattened images as a 4x4 grayscale grid PNG at `path`.

    Each sample is reshaped to 28x28, so this assumes 784-pixel MNIST-sized images.
    """
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    plt.savefig(path, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures over a long run


# Create the output directory once, not on every plotting iteration.
if not os.path.exists('out/'):
    os.makedirs('out/')

for it in range(n_iters):
    for _ in range(n_critic):
        # Sample noise (sized by z_dim, the same value G was built with) and a real minibatch
        z = Variable(torch.randn(mb_size, z_dim))
        X, _labels = mnist.train.next_batch(mb_size)
        X = Variable(torch.from_numpy(X))

        # Critic forward-loss-backward-update.
        # detach(): the critic step must not backprop into G; those gradients were discarded anyway.
        G_sample = G(z).detach()
        D_real = D(X)
        D_fake = D(G_sample)

        # Critic maximizes E[D(real)] - E[D(fake)], so minimize its negation
        D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

        D_loss.backward()
        D_solver.step()

        # Weight clipping keeps the critic (approximately) Lipschitz-bounded
        for p in D.parameters():
            p.data.clamp_(-clip_value, clip_value)

        reset_grad()

    # Generator forward-loss-backward-update (only noise is needed, no real batch)
    z = Variable(torch.randn(mb_size, z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = -torch.mean(D_fake)

    G_loss.backward()
    G_solver.step()

    # Clears the G gradients and the D gradients produced as a side effect of G_loss.backward()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'
              .format(it, D_loss.data.numpy(), G_loss.data.numpy()))

        samples = G(z).data.numpy()[:16]
        plot_samples(samples, 'out/{}.png'.format(str(cnt).zfill(3)))
        cnt += 1

Iter-0; D_loss: [-0.13205393]; G_loss: [ 0.17916088]
Iter-1000; D_loss: [-0.03855303]; G_loss: [ 0.0265324]
Iter-2000; D_loss: [ 0.00039473]; G_loss: [ 0.02957562]
Iter-3000; D_loss: [-0.03051957]; G_loss: [ 0.03755067]
Iter-4000; D_loss: [-0.03814043]; G_loss: [-0.00855665]
Iter-5000; D_loss: [-0.03380208]; G_loss: [-0.0035083]
Iter-6000; D_loss: [-0.03416292]; G_loss: [ 0.00056723]
Iter-7000; D_loss: [-0.03160179]; G_loss: [-0.00885035]
Iter-8000; D_loss: [-0.02553104]; G_loss: [-0.01521913]
Iter-9000; D_loss: [-0.02716193]; G_loss: [-0.01076826]
Iter-10000; D_loss: [-0.02417421]; G_loss: [-0.00784746]
Iter-11000; D_loss: [-0.02789254]; G_loss: [-0.00908536]
Iter-12000; D_loss: [-0.0254624]; G_loss: [-0.00729543]
Iter-13000; D_loss: [-0.02547303]; G_loss: [-0.00636355]
Iter-14000; D_loss: [-0.02813479]; G_loss: [-0.00762533]
Iter-15000; D_loss: [-0.0196582]; G_loss: [-0.00554279]
Iter-16000; D_loss: [-0.01845284]; G_loss: [-0.00166655]
Iter-17000; D_loss: [-0.01922079]; G_loss: [-0.0

KeyboardInterrupt: 