In [0]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data
from itertools import *


mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 32
z_dim = 10
X_dim = mnist.train.images.shape[1]
y_dim = mnist.train.labels.shape[1]
h_dim = 128
cnt = 0
lr = 1e-3


def log(x):
    return torch.log(x + 1e-8)


# Inference net (Encoder) Q(z|X)
Q = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim)
)

# Generator net (Decoder) P(X|z)
P = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid()
)

D_ = torch.nn.Sequential(
    torch.nn.Linear(X_dim + z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
    torch.nn.Sigmoid()
)


def D(X, z):
    return D_(torch.cat([X, z], 1))


def reset_grad():
    Q.zero_grad()
    P.zero_grad()
    D_.zero_grad()


G_solver = optim.Adam(chain(Q.parameters(), P.parameters()), lr=lr)
D_solver = optim.Adam(D_.parameters(), lr=lr)


for it in range(1000000):
    # Sample data
    z = Variable(torch.randn(mb_size, z_dim))
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))

    # Discriminator
    z_hat = Q(X)
    X_hat = P(z)

    D_enc = D(X, z_hat)
    D_gen = D(X_hat, z)

    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))

    D_loss.backward()
    D_solver.step()
    G_solver.step()
    reset_grad()

    # Autoencoder Q, P
    z_hat = Q(X)
    X_hat = P(z)

    D_enc = D(X, z_hat)
    D_gen = D(X_hat, z)

    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))

    G_loss.backward()
    G_solver.step()
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}'
              .format(it, D_loss.data, G_loss.data))

        samples = P(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
plt.close(fig)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz
Iter-0; D_loss: 1.422; G_loss: 1.834
Iter-1000; D_loss: 0.4682; G_loss: 6.367
Iter-2000; D_loss: 1.048; G_loss: 3.732
Iter-3000; D_loss: 0.2675; G_loss: 8.972
Iter-4000; D_loss: 0.8141; G_loss: 5.205
Iter-5000; D_loss: 0.6362; G_loss: 4.389
Iter-6000; D_loss: 0.663; G_loss: 5.814
Iter-7000; D_loss: 0.4118; G_loss: 7.871
Iter-8000; D_loss: 0.9557; G_loss: 6.43
Iter-9000; D_loss: 0.804; G_loss: 6.321
Iter-10000; D_loss: 0.2697; G_loss: 11.7
Iter-11000; D_loss: 1.126; G_loss: 6.42
Iter-12000; D_loss: 1.042; G_loss: 5.864
Iter-13000; D_loss: 0.9941; G_loss: 4.802
Iter-14000; D_loss: 1.169; G_loss: 3.814
Iter-15000; D_loss: 0.7733; G_loss: 5.571
Iter-16000; D_loss: 0.7557; G_loss: 5.368
Iter-17000; D_loss: 1.023; G_loss: 4.883
Iter-18000; D_loss: 0.804; G_loss: 6.01
Iter-19000



Iter-21000; D_loss: 1.042; G_loss: 4.385
Iter-22000; D_loss: 0.8814; G_loss: 6.471
Iter-23000; D_loss: 0.587; G_loss: 6.394
Iter-24000; D_loss: 0.733; G_loss: 4.752
Iter-25000; D_loss: 0.6253; G_loss: 5.496
Iter-26000; D_loss: 0.915; G_loss: 4.988
Iter-27000; D_loss: 0.7199; G_loss: 6.139
Iter-28000; D_loss: 0.8066; G_loss: 4.493
Iter-29000; D_loss: 0.6729; G_loss: 5.889
Iter-30000; D_loss: 0.8404; G_loss: 4.299
Iter-31000; D_loss: 0.7865; G_loss: 4.888
Iter-32000; D_loss: 0.7477; G_loss: 4.867
Iter-33000; D_loss: 0.6541; G_loss: 5.676
Iter-34000; D_loss: 0.5396; G_loss: 5.979
Iter-35000; D_loss: 0.5718; G_loss: 6.842
Iter-36000; D_loss: 0.7233; G_loss: 6.454
Iter-37000; D_loss: 0.6916; G_loss: 7.283
Iter-38000; D_loss: 0.5481; G_loss: 6.289
Iter-39000; D_loss: 0.946; G_loss: 5.962
Iter-40000; D_loss: 0.6878; G_loss: 6.95
Iter-41000; D_loss: 0.8118; G_loss: 6.661
Iter-42000; D_loss: 0.6892; G_loss: 5.203
Iter-43000; D_loss: 0.8961; G_loss: 5.451
Iter-44000; D_loss: 0.8511; G_loss: 4.09