# Variational Autoencoders

<img src="VAE.JPG">

## Download MNIST and load it

In [None]:
import os
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
import shutil
from keras.datasets import mnist


def img_tile(imgs, aspect_ratio=1.0, tile_shape=None, border=1,
             border_color=0):
    """Arrange a stack of equally sized images on a single tiled canvas.

    Args:
        imgs: array-like of images, either [n, height, width] or
            [n, height, width, channels].
        aspect_ratio: requested width/height ratio of the grid; only used
            when `tile_shape` is None.
        tile_shape: optional (rows, cols) of the grid. When None, the grid
            is derived from the number of images and `aspect_ratio`.
        border: number of pixels separating neighbouring tiles.
        border_color: value used for the borders and for unused grid cells.

    Returns:
        A float array with the images laid out row by row.

    Raises:
        ValueError: if `imgs` is neither 3- nor 4-dimensional.
    """
    imgs = np.array(imgs)
    if imgs.ndim not in (3, 4):
        raise ValueError('imgs has wrong number of dimensions.')
    n_imgs = imgs.shape[0]
    img_shape = np.array(imgs.shape[1:3])  # (height, width) of one image

    # Work out how many rows and columns the grid has.
    if tile_shape is not None:
        assert len(tile_shape) == 2
        grid_shape = np.array(tile_shape)
    else:
        aspect_ratio = aspect_ratio * (img_shape[1] / float(img_shape[0]))
        n_rows = int(np.ceil(np.sqrt(n_imgs * aspect_ratio)))
        n_cols = int(np.ceil(np.sqrt(n_imgs / aspect_ratio)))
        grid_shape = np.array((n_rows, n_cols))

    # Canvas size: every tile plus a border between tiles (none at the edge).
    canvas_shape = np.array(imgs.shape[1:])
    canvas_shape[:2] = (img_shape[:2] + border) * grid_shape[:2] - border

    # Start from a canvas that is entirely border colour, then paste images.
    canvas = np.full(tuple(canvas_shape), border_color, dtype=np.float64)
    n_cells = grid_shape[0] * grid_shape[1]
    for img_idx in range(min(n_imgs, n_cells)):
        row, col = divmod(img_idx, grid_shape[1])
        top = (img_shape[0] + border) * row
        left = (img_shape[1] + border) * col
        canvas[top:top + img_shape[0], left:left + img_shape[1], ...] = imgs[img_idx]

    return canvas


def plot_network_output(data, reconst_data, generated, step):
    """Save a three-row figure comparing generated, real and reconstructed images.

    Args:
        data: batch of input images.
        reconst_data: reconstructions of `data`, in the same order.
        generated: batch of images sampled from the decoder.
        step: training step; zero-padded to 6 digits and used as file name.

    The figure is written to IMAGE_DIR/<step>.png and closed afterwards;
    nothing is displayed interactively.
    """
    # Show up to 8 columns, but never index past the shortest batch.
    num = min(8, len(data), len(reconst_data), len(generated))

    # squeeze=False keeps `ax` two-dimensional even for a single column.
    fig, ax = plt.subplots(nrows=3, ncols=num, figsize=(18, 6), squeeze=False)
    for row, batch in enumerate((generated, data, reconst_data)):
        for col in range(num):
            ax[row, col].imshow(np.squeeze(batch[col]), cmap=plt.cm.gray)
            ax[row, col].axis('off')

    fig.suptitle('Top: generated | Middle: data | Bottom: reconstructed')
    # Save and close this specific figure rather than "the current one".
    fig.savefig(IMAGE_DIR + '/{}.png'.format(str(step).zfill(6)))
    plt.close(fig)
    
    
# Load MNIST as (images, labels) pairs for the train and test splits,
# then rescale both image arrays by 1/255.
(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data, test_data = (images / 255. for images in (train_data, test_data))

## Show MNIST

In [None]:
# Print the array shapes of each split: images first, then labels.
for split_images, split_labels in ((train_data, train_label),
                                   (test_data, test_label)):
    print(split_images.shape)
    print(split_labels.shape)

In [None]:
# Show one random training image next to the histogram of its pixel values.
idx = np.random.randint(0, train_data.shape[0])
_, (ax1, ax2) = plt.subplots(1, 2)
sample_data = train_data[idx]
ax1.imshow(sample_data, cmap='gray');
# Flatten first: hist() treats each column of a 2-D input as a separate
# dataset, which would draw one histogram per image column instead of a
# single histogram over all pixels.
ax2.hist(sample_data.ravel(), bins=20, range=[0, 1]);

## Delete the summary folder and recreate it

In [None]:
# Root folder for all run artefacts, with one sub-folder each for the
# train summaries, the test summaries and the saved images.
SUMMARY_DIR = './vae_summary'
TRAIN_DIR, TEST_DIR, IMAGE_DIR = (
    '/'.join((SUMMARY_DIR, sub_dir)) for sub_dir in ('train', 'test', 'image'))

# Remove whatever a previous run left behind ...
if os.path.exists(SUMMARY_DIR):
    shutil.rmtree(SUMMARY_DIR)
# ... and create the folder tree again.
if not os.path.exists(SUMMARY_DIR):
    for new_dir in (SUMMARY_DIR, TRAIN_DIR, TEST_DIR, IMAGE_DIR):
        os.makedirs(new_dir)

## Define the TensorFlow graph

In [None]:
def fully_connected(inputs, out_channel, name='fc'):
    """
    Minimal dense layer computing `inputs * weights + biases`.

    Args:
        inputs: a batch of input vectors [batch_size, n],
                where n is the number of input features
        out_channel: number of output features
        name: variable scope under which 'weights' and 'biases' are created

    Returns:
        the affine transform of `inputs` [batch_size, out_channel]
    """
    # Number of input features, read from the static shape of `inputs`.
    n_in = inputs.get_shape().as_list()[1]

    with tf.variable_scope(name):
        # 'tf.get_variable' (rather than 'tf.Variable') is what allows the
        # variables to be shared when the enclosing scope is reused.
        weights = tf.get_variable(
            'weights', shape=[n_in, out_channel],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        biases = tf.get_variable(
            'biases', shape=[out_channel],
            initializer=tf.constant_initializer(0.0))

        return tf.nn.bias_add(tf.matmul(inputs, weights), biases)


def encoder(x, z_dim):
    """
    Build the encoder.

    Args:
        x: a batch of input to the network [batch_size, 28, 28, 1]
        z_dim: dimension of the latent variable z

    Returns:
        z_mean: mean of the latent variable [batch_size, z_dim]
        z_log_sigma_sq: log sigma square of the latent variable [batch_size, z_dim]
    """
    with tf.variable_scope('encoder'):
        # Vectorize the input x. The feature count is taken from the static
        # shape so that fully_connected() can read it back.
        n_features = int(np.prod(x.get_shape().as_list()[1:]))
        net = tf.reshape(x, [-1, n_features])

        # Two hidden layers with 256 units each and relu non-linearity
        net = tf.nn.relu(fully_connected(net, 256, name='fc1'))
        net = tf.nn.relu(fully_connected(net, 256, name='fc2'))

        # Linear output layer with z_dim * 2 units: the first half is the
        # mean, the second half the log sigma square of the latent variable
        net = fully_connected(net, z_dim * 2, name='fc3')
        z_mean, z_log_sigma_sq = tf.split(net, num_or_size_splits=2, axis=1)

        return z_mean, z_log_sigma_sq


def decoder(z, reuse=False):
    """
    Build the decoder.

    Args:
        z: a batch of input to the network [batch_size, z_dim]
        reuse: when True, reuse the variables already created in the
               'generator' scope instead of creating new ones

    Returns:
        net: output of the generator [batch_size, 28, 28, 1]
    """
    with tf.variable_scope('generator') as scope:
        if reuse:
            scope.reuse_variables()

        # Two hidden layers with 256 units each and relu non-linearity
        net = tf.nn.relu(fully_connected(z, 256, name='fc1'))
        net = tf.nn.relu(fully_connected(net, 256, name='fc2'))

        # Output layer with 784 (= 28 * 28) units; the sigmoid keeps every
        # output value in (0, 1)
        net = tf.nn.sigmoid(fully_connected(net, 784, name='fc3'))

        # Reshape the flat output back into images [28, 28, 1]
        net = tf.reshape(net, [-1, 28, 28, 1])

        return net

### reconstruction_loss = $$-\sum_{i=1}^{D} \left( x_{i}\log y_{i} + (1-x_{i})\log(1-y_{i}) \right)$$
### KL_loss = $$-\frac{1}{2}\sum_{j=1}^{J}\left(1+\log \sigma_{j}^2 - \mu_{j}^2 -\sigma_{j}^2\right)$$
    
### Loss = reconstruction_loss + KL_loss

In [None]:
def get_loss(x, reconst_x, z_mean, z_log_sigma_sq, eps=1e-8):
    """
    Get the two loss terms of the VAE.

    Args:
        x: input tensor [batch_size, 28, 28, 1]
        reconst_x: reconstructed tensor [batch_size, 28, 28, 1]
        z_mean: mean of the latent variable [batch_size, z_dim]
        z_log_sigma_sq: log sigma square of the latent variable [batch_size, z_dim]
        eps: small constant added inside the logarithms to avoid log(0)

    Returns:
        reconst_loss: reconstruction loss (scalar, averaged over the batch)
        kl_loss: regularization loss (scalar, averaged over the batch)
    """
    # Bernoulli cross-entropy between input and reconstruction, summed over
    # all pixels of each sample.
    cross_entropy = -tf.reduce_sum(
        x * tf.log(eps + reconst_x)
        + (1. - x) * tf.log(eps + 1. - reconst_x),
        axis=[1, 2, 3])

    # KL divergence between N(z_mean, sigma^2) and N(0, 1), summed over the
    # latent dimensions of each sample.
    kl_divergence = -0.5 * tf.reduce_sum(
        1. + z_log_sigma_sq - tf.square(z_mean) - tf.exp(z_log_sigma_sq),
        axis=1)

    reconst_loss = tf.reduce_mean(cross_entropy)
    kl_loss = tf.reduce_mean(kl_divergence)

    return reconst_loss, kl_loss


def get_next_batch(data, label, batch_size):
    """
    Get 'batch_size' samples of data and label, drawn randomly without
    replacement.

    Args:
        data: data, indexed along its first axis
        label: label, aligned with `data` along the first axis
        batch_size: # of data to get

    Returns:
        batch_data: data of 'batch_size'
        batch_label: corresponding label of batch_data

    Raises:
        ValueError: if batch_size is larger than the number of samples.
    """
    n_data = data.shape[0]
    # Sample from every valid index, including 0.
    random_idx = random.sample(range(n_data), batch_size)

    batch_data = data[random_idx]
    batch_label = label[random_idx]
    return batch_data, batch_label

In [None]:
# Set hyperparameters
batch_size = 100
z_dim = 2
max_step = 20000
lr = 0.001
beta1 = 0.9

# Add a trailing channel axis to the images: [N, 28, 28] -> [N, 28, 28, 1].
# Guarded so that re-running this cell does not add a second axis.
if train_data.ndim == 3:
    train_data = np.expand_dims(train_data, 3)
if test_data.ndim == 3:
    test_data = np.expand_dims(test_data, 3)

############################# Build the model #############################
# Define image tensor x placeholder
x = tf.placeholder(tf.float32, [batch_size, 28, 28, 1], name='input_x')
# Define latent tensor z placeholder
z = tf.placeholder(tf.float32, [batch_size, z_dim], name='input_z')

# Build encoder
z_mean, z_log_sigma_sq = encoder(x, z_dim)

# Get epsilon from a standard normal distribution (mu=0, sigma=1)
eps = tf.random_normal(tf.shape(z_mean), mean=0.0, stddev=1.0, name='eps')

# Get encoded_z using z_mean, z_log_sigma_sq, eps (reparameterization):
# sigma = exp(0.5 * log(sigma^2))
encoded_z = z_mean + tf.exp(0.5 * z_log_sigma_sq) * eps

# Build decoder with encoded_z which outputs reconst_x
# reconst_x is reconstruction of the input data x
reconst_x = decoder(encoded_z)

# Build decoder with placeholder z which outputs sample_x
# sample_x is generated data from VAE (shares the decoder variables)
sample_x = decoder(z, reuse=True)

# Get reconst_loss and kl_loss
reconst_loss, kl_loss = get_loss(x, reconst_x, z_mean, z_log_sigma_sq)
loss = reconst_loss + kl_loss

# Scalar summaries for the summary writers, merged into a single op
tf.summary.scalar('reconst_loss', reconst_loss)
tf.summary.scalar('kl_loss', kl_loss)
tf.summary.scalar('total_loss', loss)
merged = tf.summary.merge_all()

# Make optimization op
opt = tf.train.AdamOptimizer(lr, beta1=beta1)

# Make train op for the whole network
train = opt.minimize(loss)

# Make initialization op
init = tf.global_variables_initializer()

## Train VAE

In [None]:
with tf.Session() as sess:
    # Define writers (closed in the `finally` below so events are flushed)
    train_writer = tf.summary.FileWriter(TRAIN_DIR, sess.graph)
    test_writer = tf.summary.FileWriter(TEST_DIR)

    try:
        # Initialize variables
        sess.run(init)

        # Before training the model, show one batch of train data and save it
        batch_x, batch_y = get_next_batch(train_data, train_label, batch_size)
        train_tiled = img_tile(batch_x, border_color=1.0)
        train_tiled = np.squeeze(train_tiled)
        print("Training data")
        plt.imshow(train_tiled, cmap=plt.cm.gray)
        plt.show()
        plt.imsave(IMAGE_DIR + '/train.png', train_tiled, cmap=plt.cm.gray)

        samples = []
        canvases = []

        # Latent grid used for the canvas images. It does not change between
        # steps, so it is built once. Each grid point is an [x, y] pair, which
        # requires z_dim == 2.
        nx = ny = 20
        x_values = np.linspace(-3, 3, nx)
        y_values = np.linspace(-3, 3, ny)
        grid_z = np.array([[xi, yi] for yi in y_values for xi in x_values])

        # Train the model
        for step in range(max_step):
            batch_x, batch_y = get_next_batch(train_data, train_label, batch_size)
            batch_z = np.random.normal(loc=0., scale=1.0, size=[batch_size, z_dim])

            # One run for the update, the losses and the summary, so that the
            # logged summary belongs to the same forward pass as the losses.
            _, reconst_losses, kl_losses, summary = sess.run(
                [train, reconst_loss, kl_loss, merged],
                feed_dict={x: batch_x, z: batch_z})
            train_writer.add_summary(summary, step)

            # Save generated data to make gif files
            if step % 50 == 0:
                r_x, s_x = sess.run([reconst_x, sample_x],
                                    feed_dict={x: batch_x, z: batch_z})
                sample_tiled = img_tile(s_x, border_color=1.0)
                sample_tiled = np.squeeze(sample_tiled)
                samples.append(sample_tiled)

                # Decode the whole latent grid in chunks of batch_size instead
                # of one session call per grid point.
                decoded = []
                for start in range(0, len(grid_z), batch_size):
                    chunk = grid_z[start:start + batch_size]
                    n_valid = len(chunk)
                    if n_valid < batch_size:
                        # The z placeholder has a fixed batch dimension, so
                        # the last chunk is padded and the padding dropped.
                        padding = np.zeros((batch_size - n_valid, z_dim))
                        chunk = np.concatenate([chunk, padding])
                    x_mean = sess.run(sample_x, feed_dict={z: chunk})
                    decoded.append(x_mean[:n_valid])
                decoded = np.concatenate(decoded)

                canvas = np.empty((28 * ny, 28 * nx))
                for grid_idx, digit in enumerate(decoded):
                    i, j = divmod(grid_idx, nx)
                    # Row 0 of the grid (smallest y) goes to the bottom.
                    canvas[(ny - i - 1) * 28:(ny - i) * 28,
                           j * 28:(j + 1) * 28] = digit.reshape(28, 28)
                canvases.append(canvas)

                # Log loss and save train data and reconstructed data.
                # Nested here because it uses r_x, s_x and sample_tiled from
                # above; every multiple of 200 is also a multiple of 50.
                if step % 200 == 0:
                    plot_network_output(batch_x, r_x, s_x, step)
                    print("{} steps |  total_loss: {:.4f}, KL_loss: {:.4f}, reconst_loss: {:.4f}".format(
                        step, kl_losses + reconst_losses, kl_losses, reconst_losses))
                    plt.imshow(sample_tiled, cmap=plt.cm.gray)
                    plt.show()
    finally:
        train_writer.close()
        test_writer.close()

In [None]:
import imageio
# Make gif files
# The collected frames are float arrays; convert them to 8-bit explicitly
# (values clipped to [0, 1], then scaled to 0..255) instead of relying on
# imageio's implicit, lossy float conversion.
sample_frames = [(np.clip(frame, 0.0, 1.0) * 255 + 0.5).astype(np.uint8)
                 for frame in samples]
canvas_frames = [(np.clip(frame, 0.0, 1.0) * 255 + 0.5).astype(np.uint8)
                 for frame in canvases]
imageio.mimsave(SUMMARY_DIR + '/generated.gif', sample_frames)
imageio.mimsave(SUMMARY_DIR + '/canvase.gif', canvas_frames)