In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import datetime

from __future__ import print_function

%matplotlib inline

In [None]:
# Read the MNIST data sets from the local "data/MNIST" directory, asking for
# one-hot encoded labels.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets(train_dir="data/MNIST", one_hot=True)

In [None]:
# Size of the second axis of the training-image array; used below as the
# input width of the encoder and the output width of the decoder.
image_size = np.shape(mnist.train.images)[1]

In [None]:
# Hyperparameters shared by the cells below:
#   batch_size  - number of training images requested per optimisation step
#   noise_size  - width of the latent (noise) vector
#   hidden_size - width of the hidden layer in both encoder and decoder
batch_size, noise_size, hidden_size = 128, 100, 128

In [None]:
# Trainable parameters of the encoder, all created under the "E" scope.
# Weight matrices use Xavier initialisation with a fixed seed; biases start at
# zero.
with tf.variable_scope("E"):
    # image -> hidden layer
    with tf.variable_scope("Layer_1"):
        W = tf.get_variable(
            name="W",
            shape=[image_size, hidden_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(seed=1))
        b = tf.get_variable(
            name="b",
            shape=[hidden_size],
            dtype=tf.float32,
            initializer=tf.zeros_initializer())

    # hidden layer -> latent mean
    with tf.variable_scope("Layer_mean"):
        W_mean = tf.get_variable(
            name="W",
            shape=[hidden_size, noise_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(seed=1))
        b_mean = tf.get_variable(
            name="b",
            shape=[noise_size],
            dtype=tf.float32,
            initializer=tf.zeros_initializer())

    # hidden layer -> latent log-variance
    # NOTE(review): same seed and shape as the Layer_mean weights, so the two
    # matrices presumably start from identical values - confirm this is intended.
    with tf.variable_scope("Layer_logvar"):
        W_logvar = tf.get_variable(
            name="W",
            shape=[hidden_size, noise_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(seed=1))
        b_logvar = tf.get_variable(
            name="b",
            shape=[noise_size],
            dtype=tf.float32,
            initializer=tf.zeros_initializer())

def encoder(images):
    """Build the encoder ops for a batch of images.

    A ReLU hidden layer (weights W, b) feeds two separate linear layers, one
    producing the latent mean and one producing the latent log-variance.

    Args:
        images: tensor that is matrix-multiplied with W, so its last
            dimension has to match the first dimension of W.

    Returns:
        A (mean, logvar) pair of tensors.
    """
    with tf.variable_scope("E"):
        with tf.variable_scope("Layer_1"):
            projected = tf.matmul(images, W, name="matmul")
            pre_activation = tf.add(projected, b, name="Z")
            hidden = tf.nn.relu(pre_activation, name="A")

        with tf.variable_scope("Layer_mean"):
            mean_projection = tf.matmul(hidden, W_mean, name="matmul")
            latent_mean = tf.add(mean_projection, b_mean, name="mean")

        with tf.variable_scope("Layer_logvar"):
            logvar_projection = tf.matmul(hidden, W_logvar, name="matmul")
            latent_logvar = tf.add(logvar_projection, b_logvar, name="logvar")

    return latent_mean, latent_logvar

In [None]:
def reparametrize(mean, logvar):
    """Sample latent vectors as ``mean + eps * exp(logvar / 2)``.

    ``eps`` is standard-normal noise drawn with the runtime shape of ``mean``,
    so the op works for any batch size instead of only for batches of exactly
    ``batch_size`` rows.

    Args:
        mean: tensor holding the mean of the latent distribution.
        logvar: tensor holding the log-variance of the latent distribution;
            must be broadcast-compatible with ``mean``.

    Returns:
        A tensor of sampled latent vectors with the same shape as ``mean``.
    """
    with tf.variable_scope("Z"):
        # tf.shape() is evaluated when the graph runs, so the noise always
        # matches the batch that was actually fed.
        eps = tf.random_normal(tf.shape(mean), name="noise")
        # exp(logvar / 2) is the standard deviation.
        return mean + eps * tf.exp(logvar / 2)

In [None]:
# Trainable parameters of the decoder, all created under the "D" scope.
# Weight matrices use Xavier initialisation with a fixed seed; biases start at
# zero.
with tf.variable_scope("D"):
    # latent vector -> hidden layer
    with tf.variable_scope("Layer_1"):
        W1 = tf.get_variable(
            name="W",
            shape=[noise_size, hidden_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(seed=1))
        b1 = tf.get_variable(
            name="b",
            shape=[hidden_size],
            dtype=tf.float32,
            initializer=tf.zeros_initializer())

    # hidden layer -> image
    with tf.variable_scope("Layer_2"):
        W2 = tf.get_variable(
            name="W",
            shape=[hidden_size, image_size],
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(seed=1))
        b2 = tf.get_variable(
            name="b",
            shape=[image_size],
            dtype=tf.float32,
            initializer=tf.zeros_initializer())

def decoder(noise):
    """Build the decoder ops that turn latent vectors into images.

    A ReLU hidden layer (weights W1, b1) is followed by a linear output layer
    (weights W2, b2) whose result is also passed through a sigmoid.

    Args:
        noise: tensor that is matrix-multiplied with W1, so its last
            dimension has to match the first dimension of W1.

    Returns:
        A (logits, probabilities) pair: the output layer before and after the
        sigmoid.
    """
    with tf.variable_scope("D"):
        with tf.variable_scope("Layer_1"):
            hidden_projection = tf.matmul(noise, W1, name="matmul")
            hidden_pre = tf.add(hidden_projection, b1, name="Z")
            hidden = tf.nn.relu(hidden_pre, name="A")

        with tf.variable_scope("Layer_2"):
            output_projection = tf.matmul(hidden, W2, name="matmul")
            output_logits = tf.add(output_projection, b2, name="Z")
            output_probabilities = tf.nn.sigmoid(output_logits, name="A")

    return output_logits, output_probabilities

In [None]:
# Number of images decoded each time samples are plotted.
sample_size = 16
# Side of the smallest square grid that has a cell for every sample.
sample_grid = int(np.ceil(np.sqrt(float(sample_size))))

In [None]:
# Graph inputs: a batch of training images and, separately, latent vectors
# that are decoded when drawing samples.
images = tf.placeholder(tf.float32, shape=[None, image_size], name="images")
sample_noise = tf.placeholder(tf.float32, shape=[None, noise_size], name="sample_noise")

# Training path: encode -> sample a latent vector -> decode.
mean, logvar = encoder(images)
noise = reparametrize(mean, logvar)
logits, probabilities = decoder(noise)

# Sampling path: run the decoder (same weights) on externally supplied noise.
with tf.variable_scope("sample"):
    sample_logits, sample_probabilities = decoder(sample_noise)

with tf.variable_scope("optimizer"):
    # Per-example sigmoid cross-entropy between the decoder logits and the
    # input image, summed over axis 1. The misspelt name is kept as is: it
    # names both a graph scope and a notebook-level variable.
    with tf.variable_scope("recontruction_loss"):
        recontruction_loss = tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=images, logits=logits),
            axis=1)

    # Per-example KL term 0.5 * sum(exp(logvar) + mean^2 - 1 - logvar) over
    # axis 1.
    with tf.variable_scope("kl_loss"):
        kl_loss = 0.5 * tf.reduce_sum(
            tf.exp(logvar) + mean ** 2 - 1. - logvar,
            axis=1)

    # Objective: batch mean of the two per-example terms.
    with tf.variable_scope("loss"):
        loss = tf.reduce_mean(recontruction_loss + kl_loss)

    # Training op: Adam with its default settings.
    optimizer = tf.train.AdamOptimizer().minimize(loss)

In [None]:
def plot_samples(samples):
    """Show a batch of flattened square images as a grid of panels.

    The grid is ``sample_grid`` x ``sample_grid`` whenever that is large
    enough; if more samples are passed than it can hold, the grid grows
    instead of failing with an IndexError. The side of each image is inferred
    from the number of values in the sample rather than assumed to be 28.

    Args:
        samples: iterable of array-likes, each holding one image whose number
            of values is a perfect square (e.g. 784 for a 28x28 image).

    Returns:
        The matplotlib figure the images were drawn on.

    Raises:
        ValueError: if a sample cannot be reshaped into a square image.
    """
    pictures = [np.asarray(sample) for sample in samples]

    # Never shrink below the notebook-wide grid, so the layout is unchanged
    # for every input that already fitted.
    needed = int(np.ceil(np.sqrt(len(pictures))))
    grid = max(sample_grid, needed)

    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(grid, grid)
    gs.update(wspace=0.05, hspace=0.05)

    for i, picture in enumerate(pictures):
        # Side length of the square image this sample encodes.
        side = int(round(np.sqrt(picture.size)))

        ax = plt.subplot(gs[i])
        plt.axis("off")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect("equal")
        plt.imshow(picture.reshape(side, side), cmap="Greys_r")

    plt.show()
    plt.pause(0.001)

    return fig

In [None]:
# Training schedule: total number of optimisation steps, and how often (in
# steps) to print the step count and plot decoded samples.
iterations, log_every = 500000, 10000

plt.ion()

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    for iteration in range(iterations):
        # The second value returned by next_batch is not used here.
        batch_images, _ = mnist.train.next_batch(batch_size)
        session.run(optimizer, feed_dict={images: batch_images})

        # Only report on every `log_every`-th step, counting steps from 1.
        if (iteration + 1) % log_every != 0:
            continue

        print(iteration + 1)
        # Decode freshly drawn standard-normal noise into images.
        samples = session.run(
            sample_probabilities,
            feed_dict={sample_noise: np.random.normal(size=(sample_size, noise_size))})
        plot_samples(samples)