# Digits Generation with VAE
## Import Packages

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import time
import pandas as pd
from IPython import display

## Utilities

In [None]:
def sample_images(images, row_count, column_count, cmap=None):
    """Display up to ``row_count * column_count`` images in a grid, row-major.

    Args:
        images: indexable collection of images (array or tensor) supporting
            ``len()``. A trailing single-channel axis, e.g. ``(H, W, 1)``, is
            squeezed so matplotlib draws it as a 2-D image.
        row_count: number of grid rows (may be 1).
        column_count: number of grid columns (may be 1).
        cmap: optional matplotlib colormap name (e.g. ``"gray"``); ``None``
            keeps matplotlib's default.
    """
    # squeeze=False keeps `axs` 2-D even for a single row or column, so
    # `axs[i, j]` indexing always works.
    _, axs = plt.subplots(row_count, column_count, figsize=(10, 10), squeeze=False)
    image_count = len(images)
    for i in range(row_count):
        for j in range(column_count):
            ax = axs[i, j]
            ax.axis('off')
            index = i * column_count + j
            # Leave surplus cells blank instead of raising IndexError.
            if index < image_count:
                ax.imshow(np.squeeze(np.asarray(images[index])), cmap=cmap)
    plt.show()

## Import Datasets

In [None]:
# Side length of the preview grid; one batch fills an item_size x item_size grid.
item_size = 10
batch_size = item_size * item_size
n_epochs = 10

# Images are resized to image_width x image_width, single channel.
image_width = 32
input_shape = [image_width, image_width, 1]

# Two latent dimensions so the latent space can be scatter-plotted directly.
latent_dimension = 2

In [None]:
def preprocess_image(item):
    """Convert a TFDS record into a scaled, resized float image.

    Reads ``item["image"]``, casts it to float, divides by 255 (assumes uint8
    pixels, so values land in [0, 1]), and resizes to
    ``(image_width, image_width)``. The label is dropped.
    """
    scaled = tf.cast(item["image"], "float") / 255.0
    return tf.image.resize(scaled, (image_width, image_width))

In [None]:
# Training stream: shuffled, fixed-size batches, repeated n_epochs times so a
# single pass over `train` drives the whole training loop.
train = (
    tfds.load("mnist", split='train', as_supervised=False)
    .map(preprocess_image)
    .shuffle(1024)
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
    .repeat(n_epochs)
)
# Test stream: unshuffled, same batching, single pass.
test = (
    tfds.load("mnist", split='test', as_supervised=False)
    .map(preprocess_image)
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)

Let's see what the images look like.

In [None]:
# Preview one (shuffled) training batch as an item_size x item_size grid.
images = next(iter(train.take(1)))
sample_images(images, item_size, item_size)

## Model Development

In [None]:
# Reset Keras' global session state before building the models.
keras.backend.clear_session()

### Sampling Class

In [None]:
class Sampling(tf.keras.layers.Layer):
    """Reparameterisation layer: draws z = mu + exp(0.5 * sigma) * epsilon.

    ``sigma`` is used as a log-variance: exp(0.5 * sigma) is the standard
    deviation. ``epsilon`` is fresh standard-normal noise with the same
    (batch, dim) shape as ``mu``.
    """

    def call(self, inputs):
        mu, sigma = inputs
        rows, cols = tf.shape(mu)[0], tf.shape(mu)[1]
        noise = tf.keras.backend.random_normal(shape=(rows, cols))
        std_dev = tf.exp(0.5 * sigma)
        return mu + std_dev * noise

### Build the Encoder

In [None]:
def get_encoder(latent_dimension, input_shape):
    """Build the convolutional VAE encoder.

    Two stride-2 conv + batch-norm stages, then a 20-unit dense bottleneck,
    feeding two heads: ``mu`` and ``sigma`` (log-variance, as consumed by
    ``Sampling``). The model outputs ``(mu, sigma, z)``, with ``z`` drawn by
    ``Sampling``.

    Returns:
        (encoder_model, conv_shape), where ``conv_shape`` is the static shape
        of the last conv feature map (after batch norm), including the batch
        dimension; the decoder mirrors it.
    """
    layers = tf.keras.layers
    image_in = layers.Input(shape=input_shape)

    h = layers.Conv2D(filters=32, kernel_size=3, strides=2, padding="same",
                      activation="relu", name="encode_conv1")(image_in)
    h = layers.BatchNormalization()(h)
    h = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same",
                      activation="relu", name="encode_conv2")(h)
    conv_features = layers.BatchNormalization()(h)

    h = layers.Flatten(name="encode_flatten")(conv_features)
    h = layers.Dense(20, activation="relu", name="encode_dense")(h)
    h = layers.BatchNormalization()(h)

    mu = layers.Dense(latent_dimension, name="mu")(h)
    sigma = layers.Dense(latent_dimension, name="sigma")(h)
    z = Sampling()((mu, sigma))

    encoder = tf.keras.Model(image_in, outputs=(mu, sigma, z))
    return encoder, conv_features.shape

In [None]:
# conv_shape is the encoder's last conv feature-map shape; the decoder mirrors it.
encoder, conv_shape = get_encoder(latent_dimension=latent_dimension, input_shape=input_shape)
tf.keras.utils.plot_model(encoder, show_shapes=True)

### Build the Decoder

In [None]:
def get_decoder(latent_dimension, conv_shape):
    """Build the decoder mapping a latent vector back to a 1-channel image.

    A dense layer expands the latent vector to ``conv_shape[1:]`` worth of
    units and reshapes it into a feature map. Two stride-2 transposed
    convolutions then upsample it, and a final sigmoid transposed convolution
    emits one channel with values in (0, 1).

    Args:
        latent_dimension: size of the latent input vector.
        conv_shape: 4-element shape ``(batch, height, width, channels)``;
            only the last three entries are used.
    """
    layers = tf.keras.layers
    latent_in = layers.Input(shape=(latent_dimension,))
    height, width, channels = conv_shape[1], conv_shape[2], conv_shape[3]

    h = layers.Dense(height * width * channels, activation="relu",
                     name="decode_dense1")(latent_in)
    h = layers.BatchNormalization()(h)
    h = layers.Reshape((height, width, channels), name="decode_reshape")(h)

    h = layers.Conv2DTranspose(filters=64, kernel_size=3, strides=2, padding="same",
                               activation="relu", name="decode_conv2d_2")(h)
    h = layers.BatchNormalization()(h)
    h = layers.Conv2DTranspose(filters=32, kernel_size=3, strides=2, padding="same",
                               activation="relu", name="decode_conv2d_3")(h)
    h = layers.BatchNormalization()(h)

    reconstruction = layers.Conv2DTranspose(filters=1, kernel_size=3, strides=1,
                                            padding="same", activation="sigmoid",
                                            name="decode_final")(h)
    return tf.keras.Model(latent_in, reconstruction)

In [None]:
# Decoder sized to invert the encoder's last conv feature map.
decoder = get_decoder(latent_dimension=latent_dimension, conv_shape=conv_shape)
tf.keras.utils.plot_model(decoder, show_shapes=True)

### KL Divergence

In [None]:
def kl_divergence(encoder, decoder, mu, sigma):
    """KL divergence between N(mu, exp(sigma)) and N(0, I), averaged over the batch.

    ``sigma`` is a log-variance, consistent with ``Sampling``'s use of
    ``exp(0.5 * sigma)`` as the standard deviation.

    The per-example KL is summed over the latent axis and only then averaged
    over the batch. The reconstruction term in the training loop is likewise
    a per-image sum (mean BCE times pixels per image), so the two terms are on
    the same scale. Averaging over the latent axis too would shrink the KL by
    a factor of ``latent_dimension``.

    Args:
        encoder, decoder: unused; kept for call-site compatibility.
        mu: latent means, shape ``(batch, latent_dimension)``.
        sigma: latent log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: mean over the batch of the per-example KL.
    """
    per_dim = 1 + sigma - tf.square(mu) - tf.exp(sigma)
    kl_per_example = -0.5 * tf.reduce_sum(per_dim, axis=-1)
    return tf.reduce_mean(kl_per_example)

### Build the VAE Model

In [None]:
def get_vae(encoder, decoder, input_shape):
    """Chain ``encoder`` and ``decoder`` into one end-to-end VAE model.

    The encoder's sampled ``z`` feeds the decoder. The KL term computed from
    the encoder's ``mu``/``sigma`` is attached with ``add_loss``, so it
    appears in ``model.losses`` after each forward pass. The reconstruction
    loss is not included here; the training loop adds it.

    Args:
        encoder: model returning ``(mu, sigma, z)``.
        decoder: model mapping ``z`` to a reconstructed image.
        input_shape: shape of a single input image.

    Returns:
        tf.keras.Model mapping images to reconstructions.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    mu, sigma, z = encoder(inputs)
    reconstructed = decoder(z)
    model = tf.keras.Model(inputs=inputs, outputs=reconstructed)
    # Pass the actual sub-models into kl_divergence's encoder/decoder slots.
    kl_loss = kl_divergence(encoder, decoder, mu, sigma)
    model.add_loss(kl_loss)
    return model

In [None]:
# Assemble the end-to-end VAE and print its layer summary.
vae = get_vae(encoder=encoder, decoder=decoder, input_shape=input_shape)
vae.summary()

## Train the Model

In [None]:
# Objects used by the custom training loop.
bce_loss = tf.keras.losses.BinaryCrossentropy()  # pixel-wise reconstruction loss
loss_metric = tf.keras.metrics.Mean()             # running mean of the total loss
optimizer = tf.keras.optimizers.Adam()            # default Adam hyper-parameters

In [None]:
# Custom training loop over the repeated `train` stream (n_epochs passes).
# 60000 is the MNIST train split size; with drop_remainder the number of
# batches per epoch is 60000 // batch_size.
steps_per_epoch = 60000 // batch_size
for step, images in enumerate(train):
    with tf.GradientTape() as tape:
        # training=True: BatchNormalization uses batch statistics and updates
        # its moving averages. Without it, BN stays in inference mode.
        reconstructed = vae(images, training=True)
        flattened_inputs = tf.reshape(images, shape=(-1,))
        flattened_outputs = tf.reshape(reconstructed, shape=(-1,))
        # BCE is averaged over every pixel; scaling by pixels per image gives
        # the per-image summed reconstruction error.
        loss = bce_loss(flattened_inputs, flattened_outputs) * (image_width ** 2)
        # KL term registered via add_loss in get_vae.
        loss += sum(vae.losses)
    grads = tape.gradient(loss, vae.trainable_weights)
    optimizer.apply_gradients(zip(grads, vae.trainable_weights))
    loss_metric(loss)

    # (step + 1) fires after the last batch of every epoch, including the final one.
    if (step + 1) % steps_per_epoch == 0:
        epoch = (step + 1) // steps_per_epoch
        display.clear_output(wait=False)
        print(f"Epoch {epoch}: mean loss = {float(loss_metric.result()):.4f}")
        loss_metric.reset_states()
        # Decode random latent vectors; separate name so the batch variable
        # is not clobbered.
        samples = decoder.predict(tf.random.normal(shape=[item_size ** 2, latent_dimension]))
        sample_images(samples, item_size, item_size)

## Latent Space Visualization

In [None]:
def plot_label_clusters(encoder, data, labels):
    """Scatter the first two dims of the encoder's ``mu`` output, coloured by label.

    Args:
        encoder: model whose first output is the latent mean.
        data: input accepted by ``encoder.predict`` (array or dataset).
        labels: one label per encoded example, used as scatter colours.
    """
    latent_means = encoder.predict(data)[0]
    plt.figure(figsize=(12, 10))
    plt.scatter(latent_means[:, 0], latent_means[:, 1], c=labels)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()
# Both this pipeline and `test` read the test split unshuffled with identical
# batching (drop_remainder), so the labels should line up with `test`'s images.
test_labels = (
    tfds.load("mnist", split='test', as_supervised=True)
    .map(lambda image, label: label)
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)
test_labels = np.asarray(list(test_labels)).reshape(-1)
plot_label_clusters(encoder, test, test_labels)

## Evaluation

In [None]:
# Reconstruction error on the test split.
# Keras metrics are stateful: each call folds the batch into a running mean,
# so accumulate with update_state and read result() once at the end.
mse = tf.keras.metrics.MeanSquaredError()
mae = tf.keras.metrics.MeanAbsoluteError()
# Per-batch values, computed statelessly so each entry reflects one batch only.
test_metrics = {"mse": [], "mae": []}
for images in test:
    gen_images = vae(images, training=False)
    residual = images - gen_images
    test_metrics["mse"].append(tf.reduce_mean(tf.square(residual)).numpy())
    test_metrics["mae"].append(tf.reduce_mean(tf.abs(residual)).numpy())
    mse.update_state(images, gen_images)
    mae.update_state(images, gen_images)
print("MSE: ", mse.result().numpy())
print("MAE: ", mae.result().numpy())

## Save the Model

In [None]:
# NOTE(review): the encoder contains the custom Sampling layer, so reloading
# likely needs custom_objects={"Sampling": Sampling}.
encoder.save(filepath="encoder.h5")

In [None]:
# The decoder uses only built-in Keras layers.
decoder.save(filepath="decoder.h5")

In [None]:
# NOTE(review): the VAE wraps the encoder (custom Sampling layer), so reloading
# likely needs custom_objects={"Sampling": Sampling}.
vae.save(filepath="vae.h5")