In [None]:
'''Demonstrates a variational autoencoder (VAE) on image data.

 References

 - Auto-Encoding Variational Bayes
   https://arxiv.org/abs/1312.6114
 - Joint Multi-Modal VAE
   https://arxiv.org/pdf/1611.01891.pdf
'''
import warnings
import os
from glob import glob
import numpy as np
np.random.seed(0)
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.mlab as mlab
plt.rcParams["figure.figsize"] = [20,20]
from scipy.stats import norm
from keras.utils.vis_utils import model_to_dot
from keras.layers import Input, Dense, Lambda, Flatten, Reshape, Layer, LeakyReLU, BatchNormalization
from keras.layers import Conv2D, Conv2DTranspose
from keras.regularizers import l2
from keras.models import Model
from keras import backend as K
from keras import metrics
import keras
from skimage.transform import resize
from PIL import Image

import sys, os
sys.path.append(os.path.expanduser('/opt/repositories/twbserver_notebook/notebook/tools'))
import vae_tools
from vae_tools import plot_model, layers, nb_tools, viz, loader, build_model, sanity, sampling, custom_variational_layer

# Project helper from vae_tools.nb_tools; presumably adjusts the notebook
# display width -- TODO confirm against the helper's implementation.
nb_tools.notebook_resize()
# Project helper from vae_tools.sanity; presumably validates the runtime
# environment -- TODO confirm against the helper's implementation.
sanity.check()

# Batch size requested when reading each image directory below, i.e. the
# maximum number of images taken from every directory in one go.
data_set_size = 1710

In [None]:
def get_iterator(data_set = 'img_0', batch_size = 1, shuffle=False, target_size=(48, 64)):
    """Build a Keras directory iterator that yields batches of RGB images.

    Pixel values are multiplied by 1/255. All augmentation options (shear,
    zoom, flips, shifts) are explicitly switched off, so images are only
    resized to `target_size` using nearest-neighbour interpolation.

    Args:
        data_set: Directory handed to `flow_from_directory`.
        batch_size: Number of images per yielded batch.
        shuffle: Whether the iterator shuffles the image order.
        target_size: `(height, width)` every image is resized to.
            Defaults to the previously hard-coded `(48, 64)`.

    Returns:
        The iterator produced by `ImageDataGenerator.flow_from_directory`.
        It is created with `class_mode=None`, so no labels are requested.
    """
    from keras.preprocessing.image import ImageDataGenerator
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0,            # no shearing
        zoom_range=0,             # no zooming
        horizontal_flip=False,    # no mirroring
        width_shift_range=0.0,    # no horizontal shift
        height_shift_range=0.0)   # no vertical shift

    train_generator = train_datagen.flow_from_directory(data_set, interpolation='nearest',
            color_mode='rgb', shuffle=shuffle, seed=None,
            target_size=target_size,
            batch_size=batch_size,
            #save_to_dir='img_0_augmented',
            class_mode=None)
    return train_generator

# Image directories that together form the training set.
data_set_dirs = (
    'img_0',
    'img_pitch_pi2',
    'img_pitch_-pi2',
    'img_roll_pi',
    'img_roll_pi2',
    'img_roll_-pi2',
)

# Pull a single batch of up to `data_set_size` images from every directory
# and stack the batches along the sample axis.
X_train = np.concatenate(
    [next(get_iterator(data_set = directory, batch_size = data_set_size)) for directory in data_set_dirs],
    axis=0)
# Pad from (:, 48, 64, 3) to (:, 64, 64, 3) by replicating the first and the
# last image row 8 times each (edge padding along the height axis only).
X_train = np.pad(X_train, ((0, 0), (8, 8), (0, 0), (0, 0)), mode='edge')
print(X_train.shape)

In [None]:
# Preview one rescaled image after it has been padded to a square.
x = next(get_iterator(shuffle = True, batch_size = 10))
# Replicate the top and bottom rows 8 times each so 48 rows become 64.
x = np.pad(x, ((0, 0), (8, 8), (0, 0), (0, 0)), mode='edge')
first_image = x[0]
print(first_image.shape)
plt.imshow(first_image)
plt.show()

In [None]:
# --- Input image dimensions -------------------------------------------------
batch_size = 64
n_channels = 3
image_rows_cols_chns = (64, 64, n_channels)
original_dim = np.prod(image_rows_cols_chns)
img_chns = image_rows_cols_chns[2]
# Keras expects the channel axis first or last depending on the backend config.
if keras.backend.image_data_format() == 'channels_first':
    original_img_size = (image_rows_cols_chns[2], image_rows_cols_chns[0], image_rows_cols_chns[1])
else:
    original_img_size = image_rows_cols_chns

# --- Training ---------------------------------------------------------------
epochs = 4000
save_model = False
#beta = 0.012207031 # 50.
beta = 1.

# --- Network ----------------------------------------------------------------
n_encoder = 1024
# Single definition of the latent size (an earlier `latent_dim = 2` was
# immediately overwritten by this value and has been removed).
latent_dim = 128
decode_from_shape = (8, 8, 256)
n_decoder = np.prod(decode_from_shape)
leaky_relu_alpha = 0.2
recon_depth=9
wdecay=1e-5
bn_mom=0.9
bn_eps=1e-6

def conv_block(x, filters, leaky=True, transpose=False, name=''):
    """Create (and optionally apply) a conv -> batch-norm -> activation block.

    The convolution uses a 5x5 kernel, stride 2, 'same' padding, L2 weight
    decay `wdecay` and 'he_uniform' initialisation.

    Args:
        x: Input tensor, or None to get the un-applied layers back.
        filters: Number of convolution filters.
        leaky: Use LeakyReLU(`leaky_relu_alpha`) if True, plain ReLU otherwise.
        transpose: Use Conv2DTranspose instead of Conv2D.
        name: Prefix for the conv ('<name>conv') and batch-norm ('<name>bn')
            layer names.

    Returns:
        A list `[conv, batch_norm, activation]` of layer objects when `x` is
        None, otherwise the tensor obtained by applying them to `x` in order.
    """
    conv = Conv2DTranspose if transpose else Conv2D
    # `Activation` is not among the names imported from keras.layers at the
    # top of the notebook, so it is reached through the `keras` package;
    # the bare name raised NameError whenever leaky=False.
    activation = LeakyReLU(leaky_relu_alpha) if leaky else keras.layers.Activation('relu')
    # Named `block_layers` so the `layers` module imported from vae_tools is
    # not shadowed inside this function.
    block_layers = [
        conv(filters, 5, strides=2, padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name=name + 'conv'),
        BatchNormalization(momentum=bn_mom, epsilon=bn_eps, name=name + 'bn'),
        activation
    ]
    if x is None:
        return block_layers
    for layer in block_layers:
        x = layer(x)
    return x

# Encoder layer stack: input, three stride-2 conv blocks, then a dense head.
# Layers are created in the same order as they appear in the final list.
encoder_layers = [Input(shape=original_img_size)]
encoder_layers += conv_block(None, 64, name='enc_blk_1_')
encoder_layers += conv_block(None, 128, name='enc_blk_2_')
encoder_layers += conv_block(None, 256, name='enc_blk_3_')
encoder_layers += [
    Flatten(),
    Dense(n_encoder, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='enc_h_dense'),
    BatchNormalization(name='enc_h_bn'),
    LeakyReLU(leaky_relu_alpha),
]
# Wrapped in an outer list (presumably one entry per modality -- TODO confirm
# against vae_tools.mmvae.MmVae).
encoder = [encoder_layers]

# Decoder layer stack: dense projection reshaped to `decode_from_shape`,
# three stride-2 transposed-conv blocks, then a 5x5 output convolution.
# Layers are created in the same order as they appear in the final list.
decoder_layers = [
    Dense(n_decoder, kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', input_shape=(latent_dim,), name='dec_h_dense'),
    BatchNormalization(name='dec_h_bn'),
    LeakyReLU(leaky_relu_alpha),
    Reshape(decode_from_shape),
]
decoder_layers += conv_block(None, 256, transpose=True, name='dec_blk_1_')
decoder_layers += conv_block(None, 128, transpose=True, name='dec_blk_2_')
decoder_layers += conv_block(None, 32, transpose=True, name='dec_blk_3_')
# NOTE(review): the output uses tanh (range [-1, 1]) while the inputs are
# rescaled by 1/255 -- confirm this pairing is intended.
decoder_layers.append(
    Conv2D(n_channels, 5, activation='tanh', padding='same', kernel_regularizer=l2(wdecay), kernel_initializer='he_uniform', name='dec_output')
)
# Wrapped in an outer list (presumably one entry per modality -- TODO confirm
# against vae_tools.mmvae.MmVae).
decoder = [decoder_layers]

In [None]:
# Assemble the VAE from the encoder/decoder layer stacks with an MSE
# reconstruction metric and an un-normalised beta.
model_obj = vae_tools.mmvae.MmVae(
    latent_dim,
    encoder,
    decoder,
    [original_dim],
    beta,
    beta_is_normalized = False,
    reconstruction_loss_metrics = [vae_tools.mmvae.ReconstructionLoss.MSE],
    name='Vae',
)
vae = model_obj.get_model()
# No loss is passed to compile(); presumably MmVae attaches the loss to the
# model itself -- TODO confirm.
vae.compile(optimizer='rmsprop', loss=None)
vae_tools.viz.plot_model(vae, file = 'myVAE', print_svg = False, verbose = True)

In [None]:
# Show some examples
#viz.random_images_from_set(X_set, image_rows_cols_chns, n = 15);

In [None]:
# Fit the VAE on the stacked image set; no separate targets are supplied.
fit_options = dict(
    shuffle=True,
    epochs=epochs,
    batch_size=batch_size,
    verbose=1,
)
vae.fit(X_train, **fit_options)


In [None]:
# Store the encoder-mean model.
# `use_conv` is not defined anywhere in this notebook, so referencing it
# directly raised NameError on a fresh kernel. Honour it when it exists and
# otherwise default to True, because the encoder above is built from Conv2D
# blocks.
conv_encoder = globals().get('use_conv', True)
if conv_encoder:
    stored_model_name = "cameraRGB_conv_encoder_mean"
else:
    stored_model_name = "cameraRGB_encoder_mean"
# `overwrite` follows the `save_model` switch from the config cell.
model_obj.store_model(stored_model_name, model = model_obj.get_encoder_mean([encoder[0][0]]), overwrite = save_model)

In [None]:
from importlib import reload
# Development convenience: re-import vae_tools and its viz module so that
# edits to the helper package are picked up without restarting the kernel.
reload(vae_tools)
viz = reload(vae_tools.viz)

# Visualization
# Encode samples to get the min and max values in latent space
x_test_encoded = model_obj.get_encoder_mean([encoder[0][0]]).predict(X_train, batch_size=batch_size)

# display a 2D manifold
nx = 20
ny = 20

# The grid spans the observed range of the first two latent dimensions.
# (Grids derived from the Gaussian inverse CDF were computed here before but
# were overwritten by these assignments without ever being used.)
# NOTE(review): latent_dim is 128 in the config cell, yet only dimensions 0
# and 1 are used below -- confirm the viz helpers support that.
grid_x = np.linspace(np.amin(x_test_encoded[:, 0]), np.amax(x_test_encoded[:, 0]), nx)
grid_y = np.linspace(np.amin(x_test_encoded[:, 1]), np.amax(x_test_encoded[:, 1]), ny)

# display a 2D plot of the encoded training images in the latent space
# (an all-zero array is passed as the second argument of scatter_encoder)
vae_tools.viz.scatter_encoder(X_train, np.zeros((len(X_train),3)), grid_x, grid_y, model_obj, figsize=(15, 15), dpi=150)

# Plot the resampled inputs
figure, x_mean_test_encoded, x_std_test_encoded = viz.get_image_dec_enc_samples(grid_x, grid_y, model_obj, image_rows_cols_chns)
plt.figure(figsize=(15, 15), dpi=96)
plt.imshow(figure, cmap='Greys_r')
plt.show()

# Plot the resampled std deviations
X, Y = np.meshgrid(np.arange(0,len(grid_x)), np.arange(0,len(grid_y)))
plt.pcolor(X, Y, x_std_test_encoded, cmap='coolwarm', vmin=x_std_test_encoded.min(), vmax=x_std_test_encoded.max())
plt.colorbar()
plt.axis("image")
plt.show()

In [None]:
# Scratch check: shape of a one-column array with one row per training sample.
np.ones(shape=(X_train.shape[0], 1)).shape