In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.datasets import mnist

# Fix the global seeds before any random op is created, so weight
# initialisation, the VAE's sampling noise and fit()'s shuffling repeat across
# "Restart & Run All". (GPU kernels may still add small run-to-run differences.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Training labels are discarded: the VAE is trained without them.
# NOTE(review): y_test is not used in the cells shown here.
(x_train, _), (x_test, y_test) = mnist.load_data()

# Scale pixel values to [0, 1] so they can act as targets for the
# binary cross-entropy reconstruction loss used by the model.
x_train, x_test = x_train.astype('float32') / 255., x_test.astype('float32') / 255.

# Flatten each image into a single feature vector (one row per sample).
x_train = x_train.reshape(len(x_train), -1)
x_test = x_test.reshape(len(x_test), -1)

In [2]:
def sampling(args):
    """Draw one latent sample per row using the reparameterization trick.

    Args:
        args: a pair ``(z_mean, z_log_var)`` of same-shaped tensors holding the
            mean and log-variance of a diagonal Gaussian.

    Returns:
        ``z_mean + exp(z_log_var / 2) * eps`` with ``eps ~ N(0, 1)``, so the
        randomness lives in ``eps`` and gradients flow through mean/log-var.
    """
    mu, log_var = args
    # Standard-normal noise with the same (dynamic) shape as the mean.
    noise = K.random_normal(shape=K.shape(mu))
    # exp(log_var / 2) is the standard deviation.
    sigma = K.exp(0.5 * log_var)
    return mu + sigma * noise

In [3]:
# Sizes: original_dim is the flattened image length; the latent space is 2-D
# so it can be visualised directly.
original_dim = x_train.shape[1]
latent_dim = 2

# --- Encoder: original_dim -> /2 -> /3 -> /4 -> Gaussian posterior parameters.
x = Input(shape=(original_dim,))
x1 = Dense(original_dim // 2, activation='relu')(x)
x2 = Dense(original_dim // 3, activation='relu')(x1)
h = Dense(original_dim // 4, activation='relu')(x2)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)  # log-variance; exp() keeps the variance positive
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# --- Decoder: layers are kept as named objects (not applied inline) so the
# same weights can later be re-applied to a fresh latent input.
decoder_h = Dense(original_dim // 4, activation='relu')
dc1 = Dense(original_dim // 3, activation='relu')
dc2 = Dense(original_dim // 2, activation='relu')

h_decoded = dc2(dc1(decoder_h(z)))

decoder_mean = Dense(original_dim, activation='sigmoid')
x_decoded_mean = decoder_mean(h_decoded)

# --- Objective (negative ELBO), averaged over the batch.
# Reconstruction: binary cross-entropy per pixel, summed over each sample.
reconstruction = K.binary_crossentropy(x, x_decoded_mean)
xent_loss = K.sum(reconstruction, axis=-1)
# Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I).
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
vae_loss = K.mean(xent_loss + kl_loss)

In [None]:
# End-to-end VAE. The objective is attached via add_loss, so compile() takes
# no `loss` argument and fit() is given inputs only (no targets).
vae = Model(x, x_decoded_mean)
vae.add_loss(vae_loss)
vae.compile(optimizer=tf.keras.optimizers.Adam(amsgrad=True))

vae.fit(
    x_train,
    shuffle=True,
    epochs=5,
    batch_size=32,
    validation_data=(x_test, None),  # targets are None: the loss needs inputs only
    verbose=0,
);

In [None]:
# Stand-alone decoder ("generator"): re-apply the trained decoder layers to a
# fresh latent input so arbitrary latent points can be decoded directly.
decoder_input = Input(shape=(2,))  # must match the encoder's 2-unit latent layers
_h_decoded = decoder_h(decoder_input)
_h_decoded = dc1(_h_decoded)
_h_decoded = dc2(_h_decoded)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

# Decode an n x n grid of latent points and tile the images into one picture.
n = 10
digit_size = 28
grid_x = np.linspace(-2, 2, n)  # z[0] values, one per figure column
grid_y = np.linspace(-2, 2, n)  # z[1] values, one per figure row

# Build all latent points up front in row-major order
# (row i <-> grid_y[i], column j <-> grid_x[j]) and decode them in ONE batched
# predict call instead of n*n separate calls, each with per-call overhead.
z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
x_decoded = generator.predict(z_grid, verbose=0)

# (n*n, digit_size**2) -> (n, n, ds, ds) -> (n, ds, n, ds) -> (n*ds, n*ds):
# tile (i, j) lands in rows i*ds:(i+1)*ds and columns j*ds:(j+1)*ds.
figure = (x_decoded
          .reshape(n, n, digit_size, digit_size)
          .transpose(0, 2, 1, 3)
          .reshape(n * digit_size, n * digit_size))

plt.figure(figsize=(6, 6))
plt.imshow(figure, cmap='gray')  # decoded values are grayscale intensities in [0, 1]
plt.title('Decoded digits over a 2-D latent grid (z in [-2, 2]^2)')
plt.axis('off');  # pixel-index ticks carry no meaning here