In [None]:
import keras
from keras import backend as K
from keras.layers import Input, Lambda, Dense, Conv2D, Conv2DTranspose, MaxPool2D, Flatten, Reshape, Layer
from keras.models import Model
from keras import metrics
from keras.losses import mse, binary_crossentropy
from keras.datasets import mnist
import tensorflow as tf
from tensorflow.python import debug as tf_debug
import numpy as np

In [None]:
## Training / model configuration.
batch_size, epochs = 128, 50   # minibatch size; training epochs
image_size = (28, 28, 1)       # MNIST input: height, width, channels
latent_dimension = 3           # 3-D latent space so representation clusters can be viewed directly

In [None]:
## Input layer for MNIST digits; per-example shape comes from `image_size` (28x28, one channel).
input_image = Input(name='encoder_input', shape=image_size)

In [None]:
## Inference (encoder) network: maps an image to the parameters of a diagonal Gaussian over z.
## Three 3x3 'same' convolutions (16 -> 8 -> 4 filters) with 2x2 max-pooling after the first two,
## so a 28x28 input is reduced to a 7x7x4 feature map before the dense head.
encoder = Conv2D(filters=16, kernel_size=(3, 3), padding='same', activation='relu',
                 name='encoder_conv1')(input_image)
encoder = MaxPool2D(pool_size=(2, 2), padding='same', name='encoder_pool1')(encoder)
encoder = Conv2D(filters=8, kernel_size=(3, 3), padding='same', activation='relu',
                 name='encoder_conv2')(encoder)
encoder = MaxPool2D(pool_size=(2, 2), padding='same', name='encoder_pool2')(encoder)
encoder = Conv2D(filters=4, kernel_size=(3, 3), padding='same', activation='relu',
                 name='encoder_conv3')(encoder)
# Remember the last conv map's static shape; the decoder uses it to mirror this layout.
encoder_shape = K.int_shape(encoder)
encoder = Flatten()(encoder)  # (None, 7 * 7 * 4) == (None, 196)
encoder = Dense(units=32, activation='relu')(encoder)

# Two linear heads. Despite its name, `z_var` is used as the LOG-variance of q(z|x)
# (the KL term computes K.exp(z_var)); a linear output can therefore be any real number.
z_mean = Dense(units=latent_dimension, name='z_mean')(encoder)
z_var = Dense(units=latent_dimension, name='z_var')(encoder)

In [None]:
## Reparameterisation trick: draw z ~ N(mu, sigma^2) as mu + sigma * eps with eps ~ N(0, I),
## so gradients flow through mu and sigma rather than through the random draw.
def normal_sample(args):
    """Sample latent vectors from the diagonal Gaussian produced by the encoder.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors with identical shape
            ``(batch, latent_dim)``. ``z_log_var`` is the log-variance
            (this is how the KL term of the VAE loss interprets it).

    Returns:
        Tensor shaped like ``z_mean``: ``z_mean + exp(0.5 * z_log_var) * eps``.
    """
    z_mean, z_log_var = args
    # Size eps from z_mean itself: dynamic batch size, static latent width.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    # sigma = exp(log(sigma^2) / 2). Using exp(z_log_var) would scale eps by the
    # variance instead of the standard deviation, inconsistent with the KL term.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
z = Lambda(normal_sample, name="z_sample", output_shape=(latent_dimension,))([z_mean, z_var])
encoder = Model(input_image, [z_mean, z_var, z], name='encoder')

In [None]:
## Print the encoder model's layers, output shapes and parameter counts.
encoder.summary()

In [None]:
## Generator (decoder) network: maps a latent vector back to an image-shaped output.

filters = 16      # filters of the first transposed conv; halved after each upsampling step
kernel_size = 3

latent_inputs = Input(name='z_sampling', shape=(latent_dimension,))
# Project z to the flattened size of the encoder's last feature map, then restore that map's shape.
decoder = Dense(int(np.prod(encoder_shape[1:])), activation='relu')(latent_inputs)
decoder = Reshape(target_shape=encoder_shape[1:])(decoder)
# Two stride-2 transposed convolutions (16, then 8 filters) undo the encoder's two 2x2 poolings.
for i in range(2):
    decoder = Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=2,
                              padding='same', activation='relu')(decoder)
    filters = filters // 2
# Single-channel sigmoid output: per-pixel values in (0, 1), as required by the BCE loss.
outputs = Conv2DTranspose(filters=1, kernel_size=kernel_size, padding='same',
                          activation='sigmoid', name='decoder_output')(decoder)
decoder = Model(inputs=latent_inputs, outputs=outputs, name='decoder')

In [None]:
# End-to-end graph: encode input_image, take the sampled z (the encoder's last output), decode it.
outputs = decoder(encoder(input_image)[-1])

In [None]:
## Assemble the end-to-end VAE and train it with the negative ELBO as its only loss.
vae = Model(input_image, outputs, name='vae')

# binary_crossentropy averages over its last axis, so on batch-flattened tensors it yields the
# mean per-pixel BCE. Multiply by the pixels per image to get a per-image sum, on the same scale
# as the per-image KL sum below; without this the KL term dominates and reconstructions collapse.
reconstruction_loss = binary_crossentropy(K.flatten(input_image),
                                          K.flatten(outputs))
reconstruction_loss *= int(np.prod(image_size))
# Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian; z_var holds the log-variance.
kl_loss = 1 + z_var - K.square(z_mean) - K.exp(z_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
# The whole objective is attached via add_loss, so compile() gets no per-output loss.
vae.compile(optimizer='rmsprop', loss=None)
vae.summary()

# Train the VAE on MNIST digits (labels are kept for later inspection of latent clusters).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale uint8 pixels to [0, 1] (matching the sigmoid/BCE output) and append a channel axis.
x_train = x_train.astype('float32') / 255.
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.astype('float32') / 255.
x_test = x_test.reshape(x_test.shape + (1,))

# No targets are passed: the loss depends only on the inputs (see add_loss above).
history = vae.fit(x_train,
                  shuffle=True,
                  epochs=epochs,
                  batch_size=batch_size,
                  validation_data=(x_test, None))