<a href="https://colab.research.google.com/github/sarmientoj24/EE298/blob/master/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import matplotlib.pyplot as plt
import numpy as nps
from keras import backend as K
from keras.datasets import mnist
from keras.layers import Lambda, Input, Dense, BatchNormalization, LeakyReLU, GlobalAveragePooling2D, Activation
from keras.layers import Conv2D, Conv2DTranspose, Flatten, Reshape
from keras.losses import mse, binary_crossentropy
from keras.models import Model
from keras.utils import plot_model

In [0]:
# reparameterization trick
# instead of sampling from Q(z|X), sample eps = N(0,I)
# z = z_mean + sqrt(var)*eps
def sampling2(args):
  """Sample z with the reparameterization trick.

  Args:
    args: a pair ``(z_mean, z_log_var)`` of backend tensors.

  Returns:
    ``z_mean + exp(0.5 * z_log_var) * eps`` where ``eps`` is drawn from
    ``K.random_normal`` with shape ``(batch, dim)`` taken from ``z_mean``.
  """
  mu, log_variance = args
  # The batch size is read dynamically; the latent width is read statically.
  n_rows = K.shape(mu)[0]
  n_cols = K.int_shape(mu)[1]
  # random_normal defaults to mean=0 and std=1.0
  noise = K.random_normal(shape=(n_rows, n_cols))
  std_dev = K.exp(0.5 * log_variance)
  return mu + std_dev * noise


# DISCRETE EPSILON for Code Generation
# To Use:
# The encoder's last layer should output z_mean, z_log_var, and possibly z.
# When encoding the dataset for Code Generation, we disregard z and use this
# function to compute z from z_mean, z_log_var and a supplied discrete epsilon.
def discrete_z_sampling(args):
  """Compute z from the encoder statistics and a caller-supplied epsilon.

  Args:
    args: a triple ``(z_mean, z_log_var, epsilon_discrete)``.

  Returns:
    ``z_mean + exp(0.5 * z_log_var) * epsilon_discrete``.
  """
  mu, log_variance, supplied_noise = args
  std_dev = K.exp(0.5 * log_variance)
  return mu + std_dev * supplied_noise

def randomly_sample_epsilon_discrete(dimension):
  """Draw one epsilon of shape ``(1, dimension)`` and print it.

  Args:
    dimension: width of the epsilon row to draw.

  Returns:
    The tensor produced by ``K.random_normal(shape=(1, dimension))``.
  """
  noise = K.random_normal(shape=(1, dimension))
  # Announce the draw, then show the tensor itself.
  for item in ("Generating discrete epsilon...", noise):
    print(item)
  return noise


In [0]:
# Plotter functions go here

In [116]:
# Implement VAE here

class VAE():
  """Convolutional variational autoencoder built from Keras layers.

  Constructing an instance builds three models and prints their summaries:

  * ``self.encoder`` maps an image to ``[z_mean, z_log_var, z]``.
  * ``self.decoder`` maps a latent vector to an image-like tensor.
  * ``self.vae`` chains the encoder's sampled ``z`` into the decoder.

  The encoder and decoder diagrams are also written to
  ``vae_cnn_encoder.png`` and ``vae_cnn_decoder.png`` via ``plot_model``.
  """

  # Input image shape as passed to ``Input(shape=...)``.
  image_shape = (80, 60, 3)

  def __init__(self):
    """Build the encoder, decoder and combined VAE models."""
    # Build encoder
    # Adopted from https://github.com/YongWookHa/VAE-Keras/blob/master/VAE.py
    filter_dim = 128
    z_dim = 10

    self.encoder_inputs = Input(shape=self.image_shape)
    x = self.encoder_inputs
    # Five stride-2 conv stages; the filter count doubles each stage
    # (filter_dim/16 ... filter_dim).
    for divisor in (16, 8, 4, 2, 1):
      x = Conv2D(filter_dim // divisor, kernel_size=(2, 2), strides=(2, 2), padding='same')(x)
      x = BatchNormalization()(x)
      x = LeakyReLU(0.2)(x)

    # Remember the conv feature-map shape so the decoder can rebuild it.
    enc_shape = K.int_shape(x)
    print(enc_shape)

    # generate latent vector Q(z|X)
    x = Flatten()(x)
    x = Dense(64)(x)
    x = LeakyReLU(0.2)(x)
    z_mean = Dense(z_dim, name='z_mean')(x)
    z_log_var = Dense(z_dim, name='z_log_var')(x)

    # use reparameterization trick to push the sampling out as input
    # note that "output_shape" isn't necessary with the TensorFlow backend
    # (sampling2 is the reparameterization function defined in this file;
    # the latent width is z_dim.)
    z = Lambda(sampling2, output_shape=(z_dim,), name='z')([z_mean, z_log_var])

    # instantiate encoder model
    self.encoder = Model(self.encoder_inputs, [z_mean, z_log_var, z], name='encoder')
    self.encoder.summary()
    plot_model(self.encoder, to_file='vae_cnn_encoder.png', show_shapes=True)

    # Decoder
    self.latent_inputs = Input(shape=(z_dim,), name='z_sampling')
    d = Dense(enc_shape[1] * enc_shape[2] * enc_shape[3], activation='relu')(self.latent_inputs)
    d = Reshape((enc_shape[1], enc_shape[2], enc_shape[3]))(d)
    # Four stride-2 transposed-conv stages, halving the filter count each time.
    for divisor in (2, 4, 8, 16):
      d = Conv2DTranspose(filter_dim // divisor, kernel_size=(2, 2), strides=2, padding='same')(d)
      d = BatchNormalization()(d)
      d = Activation('relu')(d)
    # Final stage maps back to the image channel count.
    d = Conv2DTranspose(self.image_shape[-1], kernel_size=(2, 2), strides=(2, 2), padding='same')(d)
    decoder_outputs = Activation('tanh')(d)
    # NOTE(review): with the (3, 2) encoder feature map printed for an 80x60
    # input, five stride-2 'same' transposed convolutions yield 96x64, not
    # 80x60. A reconstruction loss against the input will need the decoder
    # output cropped (or the input padded) -- confirm the intended fix.

    # instantiate decoder model
    self.decoder = Model(self.latent_inputs, decoder_outputs, name='decoder')
    self.decoder.summary()
    plot_model(self.decoder, to_file='vae_cnn_decoder.png', show_shapes=True)

    # instantiate VAE model; index 2 of the encoder outputs is the sampled z.
    outputs = self.decoder(self.encoder(self.encoder_inputs)[2])
    # Keep the combined model on the instance so it can be compiled, trained
    # and saved after construction.
    self.vae = Model(self.encoder_inputs, outputs, name='vae')
    self.vae.summary()

  def save_encoder_to_h5(self):
    """Placeholder for saving the encoder; not implemented yet."""
    pass

  def save_decoder_to_h5(self):
    """Placeholder for saving the decoder; not implemented yet."""
    pass

  def save_vae_to_h5(self):
    """Placeholder for saving the combined VAE; not implemented yet."""
    pass

# Build the VAE; construction prints the encoder feature-map shape and the
# model summaries.
x = VAE()

(None, 3, 2, 128)
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_46 (InputLayer)           (None, 80, 60, 3)    0                                            
__________________________________________________________________________________________________
conv2d_222 (Conv2D)             (None, 40, 30, 8)    104         input_46[0][0]                   
__________________________________________________________________________________________________
batch_normalization_301 (BatchN (None, 40, 30, 8)    32          conv2d_222[0][0]                 
__________________________________________________________________________________________________
leaky_re_lu_238 (LeakyReLU)     (None, 40, 30, 8)    0           batch_normalization_301[0][0]    
__________________________________________________________________________

(None, 3, 2, 128)
Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_40 (InputLayer)           (None, 80, 60, 3)    0                                            
__________________________________________________________________________________________________
conv2d_192 (Conv2D)             (None, 40, 30, 8)    104         input_40[0][0]                   
__________________________________________________________________________________________________
batch_normalization_247 (BatchN (None, 40, 30, 8)    32          conv2d_192[0][0]                 
__________________________________________________________________________________________________
leaky_re_lu_202 (LeakyReLU)     (None, 40, 30, 8)    0           batch_normalization_247[0][0]    
__________________________________________________________________________