In [10]:
from keras.datasets import mnist
from keras.layers import Input,Dense, Conv2D,MaxPooling2D,Dropout,Flatten,Lambda, Layer,Reshape, Conv2DTranspose
from keras.models import Model
import keras.backend as K
from keras.utils import np_utils
from keras.callbacks import TensorBoard
import cv2
import numpy as np
from keras.metrics import binary_crossentropy

In [11]:
import matplotlib.pyplot as plt

In [12]:
# TensorBoard callback handed to vae.fit(): writes to ./logs with
# histograms and the graph enabled, image summaries disabled.
tensorboard = TensorBoard(
    log_dir='./logs',
    histogram_freq=1,
    write_graph=True,
    write_images=False,
)

In [13]:
# Load MNIST, then give every image an explicit trailing channel axis,
# cast to float32 and divide by 255 in a single expression per split.
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32') / 255

In [14]:
# One-hot encode the integer digit labels.
# The ndim guard keeps this cell idempotent: without it, re-running the cell
# would feed already one-hot encoded arrays back into to_categorical.
if Y_train.ndim == 1 and Y_test.ndim == 1:
    # Use one shared class count so both splits get the same number of
    # columns instead of each inferring it from its own largest label.
    num_classes = int(max(Y_train.max(), Y_test.max())) + 1
    Y_train = np_utils.to_categorical(Y_train, num_classes)
    Y_test = np_utils.to_categorical(Y_test, num_classes)
num_classes = Y_test.shape[1]


In [None]:
# Dimensionality of the latent space; it sizes the z_mean / z_log_var Dense layers.
# NOTE(review): this cell is shown unexecuted (In [None]) although later cells
# read latent_dim -- make sure it runs before the encoder cell on a fresh kernel.
latent_dim = 2

In [15]:
# Display the image tensor's shape after the reshape to (samples, 28, 28, 1).
X_train.shape

(60000, 28, 28, 1)

In [20]:
# Encoder: a convolutional stack followed by a small dense layer that feeds
# two parallel heads, one for the latent mean and one for the log-variance.
input_layer = Input(shape=X_train.shape[1:])

conv_1 = Conv2D(32, (3, 3), padding='same', activation='relu')(input_layer)
# The only strided convolution in the encoder (28x28 -> 14x14 in the summary).
conv_2 = Conv2D(64, (3, 3), strides=(2, 2), padding='same', activation='relu')(conv_1)
conv_3 = Conv2D(64, (3, 3), padding='same', activation='relu')(conv_2)
conv_4 = Conv2D(64, (3, 3), padding='same', activation='relu')(conv_3)

# Remember the feature-map shape so the decoder can rebuild it from a vector.
shape_before_flattening = K.int_shape(conv_4)

flat = Flatten()(conv_4)
dense_1 = Dense(32, activation='relu')(flat)

# Linear (no activation) heads, each of width latent_dim.
z_mean = Dense(latent_dim)(dense_1)
z_log_var = Dense(latent_dim)(dense_1)

In [29]:
def sampling(input_layer):
    """Draw a latent sample as ``z_mean + exp(0.5 * z_log_var) * epsilon``.

    Parameters
    ----------
    input_layer : sequence of two tensors
        ``[z_mean, z_log_var]``. The name is kept for compatibility with
        existing callers; inside this function it shadows the notebook-level
        ``input_layer`` tensor.

    Returns
    -------
    tensor
        A sample with the same shape as ``z_mean``, where ``epsilon`` is
        standard-normal noise.
    """
    z_mean, z_log_var = input_layer
    # Size the noise from z_mean itself rather than the notebook-level
    # latent_dim, so the function does not depend on which cells have run
    # and works for any latent width.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim), mean=0., stddev=1.)
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
# Wrap sampling() in a Lambda layer so the latent draw is a layer of the model.
z = Lambda(sampling)([z_mean,z_log_var])

In [40]:
# Decoder: a standalone model that maps a latent vector back to an image.
decoder_input = Input(shape=K.int_shape(z)[1:])

# Expand to as many units as the last encoder feature map has elements,
# then restore that feature map's (height, width, channels) layout.
decode_dense_1 = Dense(
    np.prod(shape_before_flattening[1:]),
    activation='relu',
)(decoder_input)
decode_dense_2 = Reshape(shape_before_flattening[1:])(decode_dense_1)

# Stride-2 transposed convolution, mirroring the encoder's stride-2 convolution.
trans_conv2d = Conv2DTranspose(
    32, (3, 3), strides=(2, 2), padding='same', activation='relu'
)(decode_dense_2)

# Single-channel sigmoid output layer.
decoder_conv2d_1 = Conv2D(1, (3, 3), padding='same', activation='sigmoid')(trans_conv2d)

decoder = Model(decoder_input, decoder_conv2d_1)

# Reconstruction obtained by decoding the sampled latent tensor.
z_decoded = decoder(z)

In [62]:
class CustomerVariationalLayer(Layer):
    """Pass-through layer that registers the VAE loss on the model.

    The loss is attached with ``add_loss`` inside ``call``. The layer reads
    ``z_mean`` and ``z_log_var`` from the notebook's global scope; they are
    not passed in as inputs.
    """

    def vae_loss(self, x, z_decoded):
        """Return the scalar VAE loss for a batch.

        Parameters
        ----------
        x : tensor
            Original images.
        z_decoded : tensor
            Reconstructions produced by the decoder.

        Returns
        -------
        tensor
            Mean of the binary cross-entropy between the flattened inputs
            plus the KL term, which is scaled by 5e-4.
        """
        flat_target = K.flatten(x)
        flat_reconstruction = K.flatten(z_decoded)
        reconstruction_term = binary_crossentropy(flat_target, flat_reconstruction)
        # KL term computed from the global encoder outputs.
        kl_inner = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_term = -5e-4 * K.mean(kl_inner, axis=-1)
        return K.mean(reconstruction_term + kl_term)

    def call(self, inputs):
        """Register the loss for ``inputs = [x, z_decoded]`` and return ``x`` unchanged."""
        original, reconstruction = inputs[0], inputs[1]
        self.add_loss(self.vae_loss(original, reconstruction), inputs=inputs)
        return original
    


In [63]:
# Route the input image and its reconstruction through the loss layer;
# the layer registers the VAE loss and returns the input image unchanged.
y = CustomerVariationalLayer()([input_layer, z_decoded])

In [64]:
# End-to-end model: from the input image to the output of the loss-carrying layer.
vae = Model(input_layer, y)

In [65]:
# No loss is given to compile(): the objective is the one the custom layer
# registers through add_loss().
vae.compile(optimizer='rmsprop', loss=None)

In [66]:
# Print the layer-by-layer summary of the assembled model.
vae.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_5 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 28, 28, 32)   320         input_5[0][0]                    
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 14, 14, 64)   18496       conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 14, 14, 64)   36928       conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_15 

In [None]:
# Training hyperparameters.
batch_size = 128
num_epoch = 3

# Model training. No targets are passed (y=None, and None in validation_data);
# the compiled model was given no loss of its own.
model_log = vae.fit(
    x=X_train,
    y=None,
    batch_size=batch_size,
    epochs=num_epoch,
    verbose=1,
    validation_data=(X_test, None),
    callbacks=[tensorboard],
)

Train on 60000 samples, validate on 10000 samples
