# Variational Autoencoder (VAE)

## Import dependencies

In [None]:
import os
import sys
import datetime as dt

import numpy as np
import tensorflow as tf
import tflearn
import matplotlib.pyplot as plt

from dataset import Dataset

%matplotlib inline

## Loading the dataset

In [None]:
# Source image directory and the pickle file the processed dataset is cached in.
data_dir = 'datasets/pokemon/'
save_file = 'datasets/data/save.pkl'
data = Dataset(data_dir=data_dir, dest_dir='datasets/save/', background=Dataset.COLOR_WHITE)
# Presumably uncommented to rebuild the dataset from `data_dir` and overwrite
# the cached pickle — TODO confirm against the Dataset class.
# data.create()
# data.save(save_file=save_file, force=True)
# Rebinds `data` to whatever `load` returns; later cells use it as the Dataset
# (data.size, data.channels, data.images, data.next_batch) — assumes `load`
# returns a Dataset instance; verify.
data = data.load(save_file=save_file)

## Hyperparameters

### Image dimensions

In [None]:
# Image geometry, read once from the loaded dataset.
img_size, img_channel = data.size, data.channels
img_shape = data.images.shape[1:]      # per-image shape (batch axis dropped)
img_size_flat = np.prod(img_shape)     # number of values in one flattened image
print('Size: {}\tChannel: {}\t shape: {}\tFlattened: {:,}'.format(
    img_size, img_channel, img_shape, img_size_flat))

### Network Hyperparameters

In [None]:
keep_prob = 0.8    # presumably a dropout keep probability — NOTE(review): no visible cell applies dropout
hidden_dim = 256   # presumably the width of the hidden fully connected layers
latent_dim = 128   # length of a latent vector; the training loop samples random codes of this size

### Training Hyperparameters

In [None]:
batch_size = 24         # images per training step (passed to data.next_batch)
learning_rate = 1e-1    # RMSProp step size — NOTE(review): 0.1 is unusually high for RMSProp; confirm it is intended
iterations = 10000      # number of training steps run by the training loop
save_interval = 100     # save a checkpoint and a TensorBoard summary every N iterations
log_interval = 1000     # plot images decoded from random latent codes every N iterations

### Helpers

In [None]:
# weights
def weight(shape, name):
    """Create a trainable weight variable called `name` with the given `shape`.

    Initial values come from a truncated normal distribution with mean 0 and
    standard deviation 0.4.
    """
    return tf.Variable(tf.truncated_normal(shape=shape, mean=0, stddev=0.4), name=name)

# biases
def bias(shape, name):
    """Create a trainable, zero-initialised bias vector called `name`.

    Despite its name, `shape` is the vector length (a single integer); it is
    wrapped in a list to form the 1-D shape.
    """
    zeros = tf.zeros(shape=[shape])
    return tf.Variable(zeros, name=name)

def leakyReLU(X, alpha=0.3):
    """Leaky ReLU computed element-wise as max(X, alpha * X).

    For 0 <= alpha <= 1 this passes non-negative values through unchanged and
    scales negative values by `alpha`.
    """
    scaled = tf.multiply(X, alpha)
    return tf.maximum(X, scaled)

# convolutional block
def conv(X, W, b):
    """Convolutional block: conv + bias -> ReLU -> 2x2 max-pool -> batch norm.

    The convolution uses stride 1 with 'SAME' padding. The max-pool uses a
    2x2 window with stride 2 ('SAME' padding), so it halves both spatial
    dimensions (rounding up).
    """
    unit_strides = [1, 1, 1, 1]
    pool_window = [1, 2, 2, 1]  # used as both the pool size and the pool stride
    pre_activation = tf.nn.conv2d(X, W, strides=unit_strides, padding='SAME') + b
    pooled = tf.nn.max_pool(tf.nn.relu(pre_activation), ksize=pool_window,
                            strides=pool_window, padding='SAME')
    return tf.contrib.layers.batch_norm(pooled)

# deconvolutional block
def deconv(X, W, b):
    """Transposed-convolution (upsampling) block — not implemented yet.

    Raises instead of silently returning None, so any caller fails at the call
    site rather than later with an obscure error on a None tensor.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError('deconv is not implemented yet')

# fully connected block
def dense(X, W, b, activation=leakyReLU, batch_norm=False):
    """Fully connected block: X @ W + b, then optional activation and batch norm.

    Args:
        X: 2-D input tensor.
        W, b: layer weight matrix and bias vector.
        activation: callable applied to the affine output; skipped when falsy
            (e.g. None), giving a purely linear layer.
        batch_norm: when true, batch normalisation is applied after the activation.
    """
    out = tf.matmul(X, W) + b
    out = activation(out) if activation else out
    return tf.contrib.layers.batch_norm(out) if batch_norm else out

# flatten
def flatten(layer):
    """Flatten every non-batch dimension of `layer`.

    Works for tensors of any rank >= 2 (not only 4-D conv outputs).

    Returns:
        (flat, features): `flat` is `layer` reshaped to [batch, features], where
        `features` is the product of all static dimensions after the first.

    Raises:
        ValueError: if any non-batch dimension is not statically known, since the
            feature count is needed to size the next weight matrix.
    """
    shape = layer.get_shape()
    features = shape[1:].num_elements()
    if features is None:
        raise ValueError('flatten needs fully defined non-batch dimensions, got {}'.format(shape))
    return tf.reshape(layer, [-1, features]), features

# Plot images in grid
def plot_images(imgs, name=None, smooth=False, **kwargs):
    """Show images in a square grid.

    The grid side is floor(sqrt(len(imgs))), so only the first grid**2 images
    are drawn. Each image is reshaped to the notebook-level `img_shape` before
    display. Nothing is drawn for an empty input.

    Args:
        imgs: sequence or array of images.
        name: optional figure title.
        smooth: use 'spline16' interpolation instead of 'nearest'.
        **kwargs: forwarded to `Axes.imshow`.
    """
    grid = int(np.sqrt(len(imgs)))
    if grid == 0:
        # plt.subplots(0, 0) raises, so bail out early on empty input.
        return
    interpolation = 'spline16' if smooth else 'nearest'
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` always works.
    fig, axes = plt.subplots(grid, grid, squeeze=False)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for img, ax in zip(imgs, axes.flat):
        ax.imshow(img.reshape(img_shape), interpolation=interpolation, **kwargs)
        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])
    if name:
        plt.suptitle(name)
    plt.show()

## The Variational Autoencoder

### Encoder

In [None]:
def encoder(image, filter_size=5, conv1_size=32, conv2_size=64,
            fc1_size=None, fc2_size=None, latent_size=None):
    """Encode a batch of images into a sampled latent code.

    Architecture:
    - two `conv` blocks (conv -> ReLU -> 2x2 max-pool -> batch norm);
    - two leaky-ReLU `dense` layers;
    - two linear heads giving the mean and the log standard deviation of a
      diagonal Gaussian q(z | x).

    A latent sample is drawn with the reparameterisation trick,
    z = mean + exp(stddev) * eps with eps ~ N(0, I), so gradients flow
    through both heads.

    Args:
        image: 4-D image batch [batch, height, width, channels]; the channel
            count must be statically known.
        filter_size: side length of the square convolution kernels.
        conv1_size, conv2_size: number of filters in the two conv blocks.
        fc1_size, fc2_size: widths of the two dense layers; default to the
            notebook-level `hidden_dim`. The conv defaults are starting values — tune.
        latent_size: latent dimension; defaults to the notebook-level `latent_dim`.

    Returns:
        (encoded, mean, stddev), each of shape [batch, latent_size]. `stddev`
        is the *log* standard deviation, so exp(2 * stddev) is the variance.
    """
    # Resolve notebook-level hyperparameters at call time.
    fc1_size = hidden_dim if fc1_size is None else fc1_size
    fc2_size = hidden_dim if fc2_size is None else fc2_size
    latent_size = latent_dim if latent_size is None else latent_size
    in_channels = image.get_shape().as_list()[-1]

    with tf.name_scope('encoder'):
        # 1st convolutional layer
        W_conv1 = weight(shape=[filter_size, filter_size, in_channels, conv1_size], name='W_conv1')
        b_conv1 = bias(shape=conv1_size, name='b_conv1')
        conv1 = conv(image, W_conv1, b_conv1)
        # 2nd convolutional layer
        W_conv2 = weight(shape=[filter_size, filter_size, conv1_size, conv2_size], name='W_conv2')
        b_conv2 = bias(shape=conv2_size, name='b_conv2')
        conv2 = conv(conv1, W_conv2, b_conv2)
        # Flatten
        flattened, n_features = flatten(conv2)
        # 1st Fully connected layer
        W_fc1 = weight(shape=[n_features, fc1_size], name='W_fc1')
        b_fc1 = bias(shape=fc1_size, name='b_fc1')
        fc1 = dense(flattened, W_fc1, b_fc1)
        # 2nd Fully connected layer
        W_fc2 = weight(shape=[fc1_size, fc2_size], name='W_fc2')
        b_fc2 = bias(shape=fc2_size, name='b_fc2')
        fc2 = dense(fc1, W_fc2, b_fc2)
        # Mean of q(z|x): linear head (no activation)
        W_mean = weight(shape=[fc2_size, latent_size], name='W_mean')
        b_mean = bias(shape=latent_size, name='b_mean')
        mean = dense(fc2, W_mean, b_mean, activation=None)
        # Log standard deviation of q(z|x): linear head (no activation)
        W_stddev = weight(shape=[fc2_size, latent_size], name='W_stddev')
        b_stddev = bias(shape=latent_size, name='b_stddev')
        stddev = dense(fc2, W_stddev, b_stddev, activation=None)
        # Reparameterisation trick: sample z without blocking gradients to mean/stddev.
        epsilon = tf.random_normal(shape=tf.shape(stddev), mean=0.0, stddev=1.0, name='epsilon')
        encoded = mean + tf.exp(stddev) * epsilon
        return encoded, mean, stddev

### Decoder

In [None]:
def decoder(encoded, fc1_size=None, fc2_size=None):
    """Decode latent codes back into images shaped like the input placeholder.

    Architecture:
    - two leaky-ReLU `dense` layers;
    - a linear output layer with img_size * img_size * img_channel units;
    - a sigmoid, then a reshape to [batch, img_size, img_size, img_channel],
      so the result can be compared with `X` and logged as an image summary.

    Args:
        encoded: 2-D latent tensor [batch, latent]; the latent width must be
            statically known to size the first weight matrix.
        fc1_size, fc2_size: widths of the hidden dense layers; default to the
            notebook-level `hidden_dim`.

    Returns:
        Decoded image batch with values in (0, 1).
    """
    fc1_size = hidden_dim if fc1_size is None else fc1_size
    fc2_size = hidden_dim if fc2_size is None else fc2_size
    latent_size = encoded.get_shape().as_list()[-1]
    n_pixels = int(img_size * img_size * img_channel)

    with tf.name_scope('decoder'):
        # 1st Fully connected layer
        W_fc1 = weight(shape=[latent_size, fc1_size], name='W_fc1')
        b_fc1 = bias(shape=fc1_size, name='b_fc1')
        fc1 = dense(encoded, W_fc1, b_fc1)
        # 2nd Fully connected layer
        W_fc2 = weight(shape=[fc1_size, fc2_size], name='W_fc2')
        b_fc2 = bias(shape=fc2_size, name='b_fc2')
        fc2 = dense(fc1, W_fc2, b_fc2)
        # Output layer: one unit per pixel/channel value.
        W_out = weight(shape=[fc2_size, n_pixels], name='W_out')
        b_out = bias(shape=n_pixels, name='b_out')
        logits = dense(fc2, W_out, b_out, activation=None)
        # NOTE(review): sigmoid assumes the Dataset scales pixels to [0, 1] — confirm.
        pixels = tf.nn.sigmoid(logits)
        return tf.reshape(pixels, [-1, img_size, img_size, img_channel])

In [None]:
# Start from an empty default graph so re-running this cell starts fresh.
tf.reset_default_graph()

# Input image batch: [batch, img_size, img_size, img_channel].
X = tf.placeholder(tf.float32, shape=[None, img_size, img_size, img_channel])
# `encoder` must return (latent sample, mean, stddev); the KL term in the loss
# cell treats `stddev` as a log standard deviation (it uses exp(2 * stddev)).
encoded, mean, stddev = encoder(X)
decoded = decoder(encoded)

### Loss function

In [None]:
# Reconstruction error: squared error summed over height, width and channels,
# giving one value per example (shape [batch]). Summing only axis 1 would leave
# a [batch, width, channels] tensor that mis-broadcasts against kl_term.
rec_loss = tf.reduce_sum(tf.squared_difference(decoded, X), axis=[1, 2, 3])
# Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian whose log-std is `stddev`:
#   -0.5 * sum(1 + 2*log(sigma) - mu^2 - sigma^2)   (shape [batch], >= 0)
# Without the -0.5 factor the optimizer would *maximise* the divergence.
kl_term = -0.5 * tf.reduce_sum(1.0 + 2.0 * stddev - tf.square(mean) - tf.exp(2.0 * stddev), axis=1)
# Negative ELBO averaged over the batch.
loss = tf.reduce_mean(rec_loss + kl_term)

### Optimizer

In [None]:
# Non-trainable step counter; `minimize` increments it once per training step,
# and the training loop uses it to tag checkpoints and summaries.
global_step = tf.Variable(0, trainable=False, name='global_step')
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
# Running `trainer` performs one RMSProp update of the trainable variables to reduce `loss`.
trainer = optimizer.minimize(loss, global_step=global_step)

## TensorFlow `Session`

In [None]:
# Session shared by all later cells; initialise every global variable in the
# graph built above (the checkpoint cell may then overwrite them).
sess = tf.Session()
sess.run(tf.global_variables_initializer())

### TensorBoard

In [None]:
# Output locations: TensorBoard event files and model checkpoints.
tensorboard_dir = 'tensorboard'
logdir = os.path.join(tensorboard_dir, 'log')
save_path = 'models'
saved_model = os.path.join(save_path, 'model')  # checkpoint path prefix passed to saver.save

# Summaries: scalar means and full per-example distributions of both loss terms,
# plus up to 4 decoded images.
tf.summary.scalar('loss', loss)
tf.summary.scalar('rec_loss_mean', tf.reduce_mean(rec_loss))
tf.summary.scalar('kl_term_mean', tf.reduce_mean(kl_term))
tf.summary.histogram('rec_loss', rec_loss)
tf.summary.histogram('kl_term', kl_term)
tf.summary.image('decoded', decoded, max_outputs=4)
merged = tf.summary.merge_all()

# saver & writer (the writer also records the graph for TensorBoard's graph view)
saver = tf.train.Saver()
writer = tf.summary.FileWriter(logdir=logdir, graph=sess.graph)

### Loading a pretrained model

In [None]:
# Best-effort resume: restore the newest checkpoint in `save_path` if one exists,
# otherwise keep the freshly initialised weights.
if tf.gfile.Exists(save_path):
    sys.stdout.write('INFO: Attempting to load last checkpoint\n')
    last_ckpt = tf.train.latest_checkpoint(save_path)
    if last_ckpt is None:
        sys.stdout.write('INFO: No checkpoint found in {}\n'.format(save_path))
    else:
        try:
            # Restoring is done by the Saver (tf.Session has no `restore`).
            saver.restore(sess=sess, save_path=last_ckpt)
            sys.stdout.write('INFO: Loaded checkpoint from {}\n'.format(last_ckpt))
        except (tf.errors.OpError, ValueError) as e:
            # e.g. the checkpoint does not match the current graph; continue from scratch.
            sys.stderr.write('{}\n'.format(e))
else:
    tf.gfile.MakeDirs(save_path)
    sys.stdout.write('Created checkpoint directory – {}\n'.format(save_path))

## Training

In [None]:
# Main training loop: one optimizer step per batch, with periodic checkpoints,
# TensorBoard summaries and sample plots.
start_time = dt.datetime.now()
for i in range(iterations):
    img_batch = data.next_batch(batch_size=batch_size)
    _, _loss, _global_step = sess.run([trainer, loss, global_step], feed_dict={X: img_batch})
    if i % save_interval == 0:
        # `saved_model` is the checkpoint prefix defined alongside the Saver.
        saver.save(sess=sess, save_path=saved_model, global_step=_global_step)
        summary = sess.run(merged, feed_dict={X: img_batch})
        writer.add_summary(summary=summary, global_step=_global_step)
    if i % log_interval == 0:
        # Decode 9 latent vectors drawn from the N(0, I) prior and show them as a 3x3 grid.
        randoms = np.random.normal(0, 1, size=(9, latent_dim))
        imgs = sess.run(decoded, feed_dict={encoded: randoms})
        # No `size=` kwarg: plot_images forwards extra kwargs to imshow, which rejects it.
        plot_images(imgs, name='Test images', smooth=True)
    # '\r' rewrites a single progress line instead of concatenating every iteration.
    sys.stdout.write('\rIteration: {:,}\tGlobal steps: {:,}\tLoss: {:.2f}\tTime taken: {}'
                     .format(i + 1, _global_step, _loss, dt.datetime.now() - start_time))
    sys.stdout.flush()
sys.stdout.write('\n')