______________________________________________________________________________________________________
This notebook trains and visualizes a simple variational autoencoder (VAE); the model classes are imported from the `models` package.

The following resources have been helpful:
* https://arxiv.org/pdf/1606.05908v2.pdf
* https://jmetzen.github.io/2015-11-27/vae.html
_______________________________________________________________________________________________________

# Setup

In [None]:
from __future__ import print_function

import numpy as np
import time
import matplotlib.pyplot as plt
import tensorflow as tf

import sys
sys.path.append('..')
import models.VAE as vae
import models.CVAE as cvae
import models.GAN as gan

# Seed NumPy's global random number generator so NumPy-based randomness in
# this notebook is repeatable. TensorFlow is not seeded in this cell.
np.random.seed(0)

In [None]:
# load data
import models.DataReader as Data

# get data handler
data_type = 'mnist'  # 'mnist' | 'cifar'
data_dir = '/home/mattw/Dropbox/git/dreamscape/data/'

# Strings are compared with `==`; `is` tests object identity and only appears
# to work when the interpreter happens to intern both string objects.
if data_type == 'mnist':
    data = Data.DataReaderMNIST(data_dir + 'mnist/', one_hot=True)
elif data_type == 'cifar':
    data = Data.DataReaderCIFAR(data_dir + 'cifar/', one_hot=True)
else:
    # Fail here instead of with a NameError on `data` in a later cell.
    raise ValueError('Invalid data_type: %r' % (data_type,))

# Define and Train a VAE

In [None]:
# Configure, build and train the selected network, optionally saving it.
saving = False
save_dir = '/home/mattw/Dropbox/git/dreamscape/tmp/'
net_type = 'cvae' # 'vae' | 'cvae' | 'gan'

# define model params
# (784 presumably is the flattened 28x28 image size used by the plotting
# cells -- confirm against the data reader)
layers_encoder = [784, 400, 400]
layer_latent = 20
layers_decoder = [400, 400, 784]
num_categories = 10

# define training params
batch_size = 128
epochs = {
    'training': 30,
    'disp': 5,
    'ckpt': None,
    'summary': 3
}
use_gpu = 1

# `reload` is a builtin on Python 2 only; fall back to importlib on Python 3.
try:
    reload
except NameError:
    from importlib import reload

# Pick up on-disk edits to the model modules without restarting the kernel.
# Both modules are reloaded because either class may be built below.
reload(vae)
reload(cvae)

# initialize network
# (`==` compares string contents; `is` would compare object identity)
if net_type == 'vae':
    net = vae.VAE(
        layers_encoder=layers_encoder,
        layer_latent=layer_latent,
        layers_decoder=layers_decoder,
        learning_rate=1e-3)
elif net_type == 'cvae':
    net = cvae.CVAE(
        layers_encoder=layers_encoder,
        layer_latent=layer_latent,
        layers_decoder=layers_decoder,
        num_categories=num_categories,
        learning_rate=1e-3)
elif net_type == 'gan':
    raise NotImplementedError
else:
    # ValueError replaces the undefined name `Error`, which raised NameError.
    raise ValueError('Invalid net_type: %r' % (net_type,))

# start the tensorflow session
config = tf.ConfigProto(device_count = {'GPU': use_gpu})
sess = tf.Session(config=config, graph=net.graph)
sess.run(net.init)

# train network
time_start = time.time()
net.train(sess,
          data=data,
          batch_size=batch_size,
          epochs_training=epochs['training'],
          epochs_disp=epochs['disp'],
          epochs_ckpt=epochs['ckpt'],
          epochs_summary=epochs['summary'],
          output_dir=save_dir)
time_end = time.time()
print('time_elapsed: %g' % (time_end - time_start))

# save network
if saving:
    net.save_model(sess, save_dir)

# NOTE: the session is intentionally left open. The visualization cells below
# pass `sess` to net.reconstruct / net.generate, which cannot run on a closed
# session. Call sess.close() manually once you are done with the notebook.

# Visualize Model

## Reconstruction Visualization

In [None]:
# Draw one training batch; x[0] holds the inputs that get reconstructed.
x = data.train.next_batch(net.batch_size)
# All-zero eps -- presumably the sampling noise, so the reconstruction would
# use the latent mean; confirm against net.reconstruct.
eps = np.zeros((net.batch_size, net.num_lvs))
recon = net.reconstruct(sess, x[0], eps)

# Top row: first five inputs. Bottom row: their reconstructions.
f, ax = plt.subplots(2,5)
for row, images in enumerate((x[0], recon)):
    for j in range(5):
        panel = ax[row, j]
        panel.imshow(np.reshape(images[j,:], (28, 28)),
                     interpolation="nearest",
                     cmap="gray")
        # hide both axes so only the image is shown
        panel.axes.get_xaxis().set_visible(False)
        panel.axes.get_yaxis().set_visible(False)

plt.show()

## Latent Space Visualization of Already Trained Model

In [None]:
# Tile the images generated from a regular grid of 2-D latent means.
# NOTE(review): z_mean has two entries, so this cell presumably expects a
# model with a 2-D latent space, whereas the training cell in this notebook
# sets layer_latent = 20 -- confirm which model `net`/`sess` refer to here.
nx = ny = 20
x_values = np.linspace(-3, 3, nx)
y_values = np.linspace(-3, 3, ny)

# ny tiles down by nx tiles across, each tile 28x28 pixels
canvas = np.empty((28*ny, 28*nx))
# Rows follow y_values and columns follow x_values, so the grid stays
# correct if nx and ny are ever set to different values.
for i, yi in enumerate(y_values):
    for j, xi in enumerate(x_values):
        z_mean = np.array([[xi, yi]])
        x_mean = net.generate(sess, z_mean=z_mean)
        # rows are filled bottom-up: i = 0 lands in the last row of tiles
        canvas[(ny-i-1)*28:(ny-i)*28, j*28:(j+1)*28] = x_mean[0].reshape(28, 28)

plt.figure(figsize=(8, 10))
plt.imshow(canvas, origin="upper",
           interpolation="nearest",
           cmap="gray")
plt.tight_layout()
plt.show()

In [None]:
# Provenance stamp: (re)load the watermark extension, then print the author
# (-a), date (-d), Python version (-v), machine info (-m) and the installed
# numpy/tensorflow versions (-p).
%reload_ext watermark
%watermark -a "Matt Whiteway" -d -v -m -p numpy,tensorflow