<a href="https://colab.research.google.com/github/rajatha94/faiProject/blob/master/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.losses import binary_crossentropy
from keras import backend as K

import numpy as np
import matplotlib.pyplot as plot
import tensorflow as tf

In [0]:
def reparameterization_trick(args):
    """Sample a latent vector as ``mean + sigma * epsilon``.

    Instead of sampling z directly, noise ``epsilon`` is drawn from a unit
    normal distribution and then shifted and scaled by the encoder outputs.

    Args:
        args: pair ``(mean, std_dev)`` of rank-2 tensors (axis 0 is the
            batch). Despite its name, the second entry is consumed as a
            log-variance: the scale applied to the noise is
            ``exp(0.5 * std_dev)``.

    Returns:
        ``mean + exp(0.5 * std_dev) * epsilon``, with ``epsilon`` drawn per
        element from N(0, 1).
    """
    mean, log_variance = args
    # Derive the noise shape from `mean` itself rather than from a notebook
    # global, so the function works for any latent size and does not depend
    # on which cells have already been executed.
    batch = K.shape(mean)[0]      # dynamic: batch size is only known at run time
    dim = K.int_shape(mean)[1]    # static latent width
    epsilon = K.random_normal(shape=(batch, dim), mean=0.0, stddev=1.0)
    return mean + K.exp(0.5 * log_variance) * epsilon


In [0]:
def plot_current_epoch(decoder, samples, epoch):
    """Decode latent samples and show/save them as a grid of 28x28 images.

    Args:
        decoder: object exposing ``predict(sample)``; the first row of each
            prediction is reshaped to 28x28 and drawn.
        samples: iterable of latent inputs, one per grid cell.
        epoch: epoch label; printed and used in the output file name
            ``vae_epoch_<epoch>.png`` (written to the working directory).
    """
    digit_size = 28
    samples = list(samples)

    # Keep the 6x6 layout for up to 36 samples; grow the grid only when more
    # are passed, instead of failing on an out-of-range subplot index.
    grid = 6
    while grid * grid < len(samples):
        grid += 1

    fig = plot.figure(figsize=(6, 6))
    print("Epoch :: ", epoch)
    for position, sample in enumerate(samples, start=1):
        image = decoder.predict(sample)
        # Only the first decoded row of each prediction is displayed.
        digit = image[0].reshape(digit_size, digit_size)
        plot.subplot(grid, grid, position)
        plot.imshow(digit, cmap='gray')
        plot.axis('off')
    plot.savefig("vae_epoch_" + str(epoch) + ".png")
    plot.show()
    # This runs once per epoch; close the figure explicitly so figures do
    # not pile up in memory over a long training run.
    plot.close(fig)

In [0]:
# Hyperparameters
batch_size = 64
latent_dim = 2
epochs = 5000

# Data: MNIST images flattened to 784-element rows and divided by 255
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(-1, 784).astype('float32') / 255

# Encoder stack: 784 -> 512 -> 256, followed by two latent_dim-wide heads
inputs = Input(shape=(784,))
encoder_hidden_1 = Dense(units=512, activation='relu')(inputs)
encoder_hidden_2 = Dense(units=256, activation='relu')(encoder_hidden_1)
mean = Dense(units=latent_dim)(encoder_hidden_2)
# NOTE: reparameterization_trick scales the noise by exp(0.5 * std_dev),
# i.e. this head is used as a log-variance rather than a standard deviation.
std_dev = Dense(units=latent_dim)(encoder_hidden_2)

# Latent sample drawn through the reparameterization trick
z = Lambda(reparameterization_trick)([mean, std_dev])

# Stand-alone encoder: maps an input image to its latent mean
encoder = Model(inputs=inputs, outputs=mean)

# Decoder layers are created once so that the stand-alone decoder and the
# full VAE below share the same weights
latent_inputs = Input(shape=(latent_dim,))
decoder_hidden_layer_1 = Dense(units=256, activation='relu')
decoder_hidden_layer_2 = Dense(units=512, activation='relu')
decoder_output_layer = Dense(units=784, activation='sigmoid')

# Stand-alone decoder: latent vector -> 784 sigmoid outputs
decoder_hidden_1 = decoder_hidden_layer_1(latent_inputs)
decoder_hidden_2 = decoder_hidden_layer_2(decoder_hidden_1)
decoder_output = decoder_output_layer(decoder_hidden_2)
decoder = Model(inputs=latent_inputs, outputs=decoder_output)

# Full VAE: input image -> sampled z -> reconstruction through the same layers
vae_decoder_hidden_1 = decoder_hidden_layer_1(z)
vae_decoder_hidden_2 = decoder_hidden_layer_2(vae_decoder_hidden_1)
vae_decoder_output = decoder_output_layer(vae_decoder_hidden_2)
vae = Model(inputs=inputs, outputs=vae_decoder_output)

In [0]:
    # VAE loss: per-sample reconstruction error plus KL divergence, averaged
    # over the batch.
    #
    # binary_crossentropy averages over the last (pixel) axis and returns one
    # value per sample, so multiply by the input width to recover the
    # per-sample sum over pixels. Do not reduce over the batch axis here:
    # that would collapse the term to a scalar whose size depends on the
    # batch size and mis-weight it against the per-sample KL term below.
    reconstruction_loss = K.int_shape(inputs)[-1] * tf.keras.losses.binary_crossentropy(inputs, vae_decoder_output)
    # `std_dev` is used as a log-variance here (K.exp(std_dev) is the
    # variance), matching exp(0.5 * std_dev) in reparameterization_trick.
    kl_loss = -0.5 * K.sum(1 + std_dev - K.square(mean) - K.exp(std_dev), axis=-1)
    vae.add_loss(K.mean(reconstruction_loss + kl_loss))

In [0]:
    # The loss was attached with vae.add_loss, so compile only needs an optimizer.
    vae.compile(optimizer='adam')

    # 36 fixed latent batches (one per grid cell), drawn once and reused at
    # every epoch so successive image grids are directly comparable.
    # plot_current_epoch only displays the first row of each batch.
    random_samples = [
        np.random.normal(0, 1, size=[batch_size, latent_dim])
        for _ in range(36)
    ]

    # Grid produced by the untrained decoder, labelled epoch 0.
    plot_current_epoch(decoder, random_samples, 0)

    # Training loss per epoch, keyed by 1-based epoch number.
    lossDict = {}
    # Train one epoch at a time so a sample grid can be rendered after each.
    for epoch in range(epochs):
        history = vae.fit(train_images, epochs=1, batch_size=batch_size)
        # fit() ran a single epoch, so its history holds exactly one loss value.
        lossDict[epoch + 1] = history.history['loss'][0]
        plot_current_epoch(decoder, random_samples, epoch + 1)