# CHAPTER 5.6

### Implementing a variational autoencoder

Variational Autoencoders (VAEs), They differ from the rest of the autoencoders in that, instead of learning an arbitrary function, they learn a probability distribution of the input images. We can then sample this distribution to produce new, unseen data points.A VAE is, in fact, a generative model.

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.layers import *
from tensorflow.keras.losses import mse
from tensorflow.keras.optimizers import Adam

In [2]:
tf.config.experimental_run_functions_eagerly(True)

Instructions for updating:
Use `tf.config.run_functions_eagerly` instead of the experimental version.


In [3]:
class VAE(object):
    def __init__(self,
                 original_dimension=784,
                 encoding_dimension=512,
                 latent_dimension=2):
        self.original_dimension = original_dimension
        self.encoding_dimension = encoding_dimension
        self.latent_dimension = latent_dimension

        self.z_log_var = None
        self.z_mean = None

        self.inputs = None
        self.outputs = None

        self.encoder = None
        self.decoder = None
        self.vae = None

    def build_vae(self):
        # Build encoder
        self.inputs = Input(shape=(self.original_dimension,))
        x = Dense(self.encoding_dimension)(self.inputs)
        x = ReLU()(x)
        self.z_mean = Dense(self.latent_dimension)(x)
        self.z_log_var = Dense(self.latent_dimension)(x)

        z = Lambda(sampling)([self.z_mean, self.z_log_var])

        self.encoder = Model(self.inputs,
                             [self.z_mean, self.z_log_var, z])

        # Build decoder
        latent_inputs = Input(shape=(self.latent_dimension,))
        x = Dense(self.encoding_dimension)(latent_inputs)
        x = ReLU()(x)
        self.outputs = Dense(self.original_dimension)(x)
        self.outputs = Activation('sigmoid')(self.outputs)
        self.decoder = Model(latent_inputs, self.outputs)

        # Build end-to-end VAE.
        self.outputs = self.encoder(self.inputs)[2]
        self.outputs = self.decoder(self.outputs)
        self.vae = Model(self.inputs, self.outputs)

    @tf.function
    def train(self,
              X_train,
              X_test,
              epochs=50,
              batch_size=64):
        reconstruction_loss = mse(self.inputs, self.outputs)
        reconstruction_loss *= self.original_dimension

        kl_loss = (1 + self.z_log_var -
                   K.square(self.z_mean) -
                   K.exp(self.z_log_var))
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5

        vae_loss = K.mean(reconstruction_loss + kl_loss)

        self.vae.add_loss(vae_loss)
        self.vae.compile(optimizer=Adam(lr=1e-3))
        self.vae.fit(X_train,
                     epochs=epochs,
                     batch_size=batch_size,
                     validation_data=(X_test, None))

        return self.encoder, self.decoder, self.vae


In [4]:
def sampling(arguments):
    z_mean, z_log_var = arguments
    batch = K.shape(z_mean)[0]
    dimension = K.int_shape(z_mean)[1]

    epsilon = K.random_normal(shape=(batch, dimension))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [5]:
def generate_and_plot(decoder, grid_size=5):
    cell_size = 28

    figure_shape = (grid_size * cell_size,
                    grid_size * cell_size)
    figure = np.zeros(figure_shape)
    grid_x = np.linspace(-4, 4, grid_size)
    grid_y = np.linspace(-4, 4, grid_size)[::-1]

    for i, z_log_var in enumerate(grid_y):
        for j, z_mean in enumerate(grid_x):
            z_sample = np.array([[z_mean, z_log_var]])
            generated = decoder.predict(z_sample)[0]

            # Reshape as image.
            fashion_item = generated.reshape(cell_size,
                                             cell_size)

            # Assign to the corresponding cell in the grid.
            y_slice = slice(i * cell_size,
                            (i + 1) * cell_size)
            x_slice = slice(j * cell_size,
                            (j + 1) * cell_size)
            figure[y_slice, x_slice] = fashion_item

    plt.figure(figsize=(10, 10))
    start = cell_size // 2
    end = (grid_size - 2) * cell_size + start + 1
    pixel_range = np.arange(start, end, cell_size)

    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)

    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel('z_mean')
    plt.ylabel('z_log_var')
    plt.imshow(figure)
    plt.show()


In [6]:
(X_train, _), (X_test, _) = fashion_mnist.load_data()

X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

X_train = X_train.reshape((X_train.shape[0], -1))
X_test = X_test.reshape((X_test.shape[0], -1))

In [7]:
vae = VAE(original_dimension=784,
          encoding_dimension=512,
          latent_dimension=2)

In [11]:
vae.build_vae()


In [12]:
vae

<__main__.VAE at 0x1f4b508aca0>

In [13]:
_, decoder_model, vae_model = vae.train(X_train, X_test,
                                        epochs=100)

  super(Adam, self).__init__(name, **kwargs)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [30]:
vae_model.save('vae_fit.h5')



In [31]:
decoder_model.save('decoder_model_fit.h5')



In [26]:
cell_size = 28
grid_size=5
grid_x = np.linspace(-4, 4, grid_size)
print(grid_x)
grid_y = np.linspace(-4, 4, grid_size)[::-1]
print(grid_y)
plt.figure(figsize=(10, 10))
start = cell_size // 2
print(start)
end = (grid_size - 2) * cell_size + start + 1
print(end)
pixel_range = np.arange(start, end, cell_size)
print(pixel_range)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)

[-4. -2.  0.  2.  4.]
[ 4.  2.  0. -2. -4.]
14
99
[14 42 70 98]


<Figure size 720x720 with 0 Axes>

In [None]:
generate_and_plot(decoder_model, grid_size=7)

![image.png](attachment:image.png)