# MNIST Handwritten Digit Recognition using VAE

In [None]:
import tensorflow as tf
# Report whether TensorFlow exposes a GPU under the default device name.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print(f'Found GPU at {device_name}')
else:
    print('Running TF without GPU')

In [None]:
%matplotlib inline

import glob
import imageio
import math
import os
import seaborn as sn
import time

from abc import abstractstaticmethod
from matplotlib import gridspec
from matplotlib import pyplot as plt
from PIL import Image
from tensorflow import keras

# Switch on seaborn's default theme for the figures drawn in this notebook.
sn.set_theme()

In [None]:
# --- Training hyperparameters -------------------------------------------
TRAIN_SIZE = 60000
BATCH_SIZE = 64
EPOCHS = 200

N_EXAMPLES = 25
LEARNING_RATE = 1e-4
LATENT_DIMENSION = 3

ARCH = 'vae'

# --- Output locations ----------------------------------------------------
# `path_prefix` is prepended to every artifact path. It defaults to the
# empty string (paths relative to the working directory) so the path
# constants below are always well-defined, including outside Colab.
path_prefix = ''
IN_COLAB = False
try:
    from google.colab import drive
except ImportError:
    # Not running inside Google Colab: keep the local defaults.
    drive = None

if drive is not None:
    try:
        drive.mount('/content/drive')
    except Exception as mount_error:
        # Best effort: if Drive cannot be mounted, fall back to local paths
        # instead of aborting the notebook, but say so explicitly.
        print(f'Could not mount Google Drive ({mount_error}); writing outputs locally')
    else:
        path_prefix = '/content/drive/My Drive/gan-vae/vae/'
        IN_COLAB = True

METRICS_PATH = f'{path_prefix}metrics/{ARCH}/'
OUTPUT_PATH = f'{path_prefix}output/{ARCH}/'

# Create the output directories up front (equivalent to `mkdir -p`). This is
# done in every environment so that later `savefig` calls cannot fail on a
# missing folder; `os.makedirs` also copes with spaces in the path.
os.makedirs(METRICS_PATH, exist_ok=True)
os.makedirs(OUTPUT_PATH, exist_ok=True)

## Base Class

In [None]:
class BaseNetwork(tf.Module):
    """Sequential container that applies ``self.layers`` one after another.

    The class itself never assigns ``self.layers``; subclasses are expected
    to set it (an ordered list of callables accepting ``training=``) after
    calling ``super().__init__()``.
    """

    @tf.Module.with_name_scope
    def __call__(self, input_data, training=False) -> tf.Tensor:
        """Feed ``input_data`` through every layer and return the last output.

        ``training`` is forwarded unchanged to each layer call.
        """
        activation = input_data
        for stage in self.layers:
            activation = stage(activation, training=training)
        return activation

## Encoder Network
### Layers
<pre>
<b>Input</b>:  input_size=(28, 28), target_shape=(784,)
<b>Dense</b>:  units=256, activation=tanh
<b>Dense</b>:  units=3*2, target_shape=(3, 2)
</pre>
### Output
The output layer of the encoder network represents the mean (&mu;) and log-variance (log &sigma;<sup>2</sup>) of the posterior distribution which the decoder should sample from. The output looks as follows:
<pre>
         [ &mu;<sub>1</sub>  log &sigma;<sub>1</sub><sup>2</sup> ]
e(x)  ~  [ &mu;<sub>2</sub>  log &sigma;<sub>2</sub><sup>2</sup> ]
         [ &mu;<sub>3</sub>  log &sigma;<sub>3</sub><sup>2</sup> ]
</pre>
Using the reparameterization trick, the input to the decoder network becomes:
<center><i><b>z</b></i> = &mu; + &sigma; &xodot; &epsilon;</center>
where &epsilon; ~ <i>N</i>(0,<i>I</i>). The network outputs log &sigma;<sup>2</sup> rather than &sigma;<sup>2</sup> because its final dense layer has no non-linearity, so its outputs are unbounded real numbers. log &sigma;<sup>2</sup> likewise ranges over all real numbers, whereas &sigma;<sup>2</sup> must be positive. The variance is recovered as &sigma;<sup>2</sup> = exp(log &sigma;<sup>2</sup>), so no artificial restriction has to be placed on the values that parameterize the latent distribution.

In [None]:
class Encoder(BaseNetwork):
    """Maps an image batch to the parameters of a diagonal Gaussian and samples from it.

    The layer stack flattens the input, applies a 256-unit dense layer with
    tanh, then a linear dense layer with ``2 * latent_dimension`` units that
    is reshaped to ``(latent_dimension, 2)``. Along the last axis, index 0 is
    read as the mean and index 1 as the log-variance.
    """

    def __init__(self, input_shape: tuple, latent_dimension: int) -> None:
        """Build the layer stack.

        Args:
            input_shape: Shape of one input sample (without the batch axis).
            latent_dimension: Number of latent variables.
        """
        super().__init__()
        self.latent_dimension = latent_dimension
        self.layers = [
            keras.layers.InputLayer(input_shape=input_shape),
            keras.layers.Flatten(),
            keras.layers.Dense(units=256),
            keras.layers.Activation(tf.nn.tanh),
            keras.layers.Dense(units=latent_dimension*2),
            keras.layers.Reshape(target_shape=(latent_dimension, 2))
        ]

    @tf.Module.with_name_scope
    def __call__(self, input_data, training=False) -> tuple:
        """Encode ``input_data`` and draw one latent sample per input.

        Returns:
            A 3-tuple ``(z, mean, var)`` where ``z = mean + stddev * epsilon``
            with ``epsilon ~ N(0, I)``, and ``var = exp(log_var)``.
        """
        output_data = input_data
        for layer in self.layers:
            output_data = layer(output_data, training=training)
        mean, log_var = output_data[:, :, 0], output_data[:, :, 1]
        # sd = sqrt(var) = e^(0.5 * log(var))
        stddev = tf.math.exp(0.5 * log_var)
        # Use the dynamic shape so sampling also works when the batch
        # dimension is not statically known.
        epsilon = tf.random.normal(shape=tf.shape(log_var), dtype=log_var.dtype)
        # Reparameterization trick: z = mean + stddev * epsilon, so that
        # z ~ N(mean, stddev^2). The noise must enter exactly once.
        return mean + stddev * epsilon, mean, tf.math.exp(log_var)

## Decoder Network
### Layers
<pre>
<b>Input</b>:   input_size=(3,)
<b>Dense</b>:   units=256, activation=tanh
<b>Dense</b>:   units=784, activation=sigmoid, target_shape=(28, 28)
</pre>

### Optimizer
<pre>
<b>Adam</b>: learning_rate=0.0001
</pre>

### Loss
<pre>
<b>&Lscr;</b>(<i>&theta;</i>,<i>&phi;</i>) = -<i>D<sub>KL</sub></i>(<i>q<sub>&phi;</sub></i>(<i><b>z</b></i>&vert;<i><b>x</b></i>) &vert;&vert; <i>p<sub>&theta;</sub></i>(<i><b>z</b></i>)) + &Eopf;<sub><i>q<sub>&phi;</sub></i>(<i><b>z</b></i>&vert;<i><b>x</b></i>)</sub>log <i>p<sub>&theta;</sub></i>(<i><b>x</b></i>&vert;<i><b>z</b></i>)
        = &frac12; &Sum;<sub>j</sub> [1 + log(<i>&sigma;<sub>j</sub><sup>2</sup></i>) - <i>&mu;<sub>j</sub><sup>2</sup></i> - <i>&sigma;<sub>j</sub><sup>2</sup></i>] + &Eopf;<sub><i>q<sub>&phi;</sub></i>(<i><b>z</b></i>&vert;<i><b>x</b></i>)</sub>log <i>p<sub>&theta;</sub></i>(<i><b>x</b></i>&vert;<i><b>z</b></i>)
        &approx; &frac12; &Sum;<sub>j</sub> [1 + log(<i>&sigma;<sub>j</sub><sup>2</sup></i>) - <i>&mu;<sub>j</sub><sup>2</sup></i> - <i>&sigma;<sub>j</sub><sup>2</sup></i>] +  &Sum;<sub>i</sub> [<i>x<sub>i</sub></i>log(<i>y<sub>i</sub></i>) + (1 - <i>x<sub>i</sub></i>)&lowast;log(1 - <i>y<sub>i</sub></i>)]
</pre>

### Goal
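Find argmax<sub><i>&theta;</i>,<i>&phi;</i></sub> <b>&Lscr;</b>(<i>&theta;</i>,<i>&phi;</i>) (equivalently, minimize the loss -<b>&Lscr;</b>(<i>&theta;</i>,<i>&phi;</i>), which is what the training loop does), and use <i>p<sub>&theta;</sub></i>(<i><b>x</b></i>&vert;<i><b>z</b></i>) as a generator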

In [None]:
class Decoder(BaseNetwork):
    """Maps latent vectors back to image space.

    The layer stack is a 256-unit dense layer with tanh, followed by a dense
    layer with one unit per output element and a sigmoid, reshaped to
    ``output_shape``.
    """

    def __init__(self, latent_dimension: int, output_shape: tuple) -> None:
        """Build the layer stack.

        Args:
            latent_dimension: Size of the latent vector fed to the decoder.
            output_shape: Shape of one generated sample (without the batch axis).
        """
        super().__init__()
        self.latent_dimension = latent_dimension
        # `Dense(units=...)` expects a plain Python integer; convert the
        # scalar tensor returned by `reduce_prod` explicitly instead of
        # relying on the layer to coerce it.
        n_outputs = int(tf.math.reduce_prod(output_shape))
        self.layers = [
            keras.layers.InputLayer(input_shape=(latent_dimension,)),
            keras.layers.Dense(units=256),
            keras.layers.Activation(tf.nn.tanh),
            keras.layers.Dense(units=n_outputs),
            keras.layers.Activation(tf.nn.sigmoid),
            keras.layers.Reshape(target_shape=output_shape),
        ]

## Preprocessing
* Import MNIST images
* Normalize to \[0, 1\]
* Collapse pixels to 0/1 for sharper image features
* Shuffle and batch dataset

In [None]:
# Load MNIST; the labels are discarded.
(train_images, _), (test_images, _) = keras.datasets.mnist.load_data()

# Training split: scale by 1/255, round every pixel to 0 or 1, cast to float32.
train_images = tf.dtypes.cast(
    tf.round(train_images[:TRAIN_SIZE] / 255.),
    tf.float32,
)
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(TRAIN_SIZE)
    .batch(BATCH_SIZE)
)
# Per-sample shape (everything after the leading sample axis).
image_shape = train_images.shape[1:]

# Test split: only the first N_EXAMPLES images are kept, same preprocessing.
test_images = tf.dtypes.cast(
    tf.round(test_images[:N_EXAMPLES] / 255.),
    tf.float32,
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .shuffle(TRAIN_SIZE)
    .batch(BATCH_SIZE)
)

### Initialize Encoder and Decoder

In [None]:
# Both networks share the same latent size; the decoder emits samples with
# the same per-sample shape the encoder consumes.
encoder = Encoder(input_shape=image_shape, latent_dimension=LATENT_DIMENSION)
decoder = Decoder(latent_dimension=LATENT_DIMENSION, output_shape=image_shape)

### Define metrics to record through training

In [None]:
# Names of the scalar metrics tracked during training.
metric_names = ['kl_loss', 'rec_loss', 'loss']
# One independent list per metric, appended to during training.
batch_history = {metric: list() for metric in metric_names}
# Fixed set of test images reused for every visualization snapshot.
sample_encoded_tests = test_images[:N_EXAMPLES]

### Capture training visualization

In [None]:
def record_sample(encoder: Encoder, decoder: Decoder, epoch: int, save: bool=True, show: bool=True):
    """Render a three-panel snapshot of the model and optionally save/show it.

    The panels are: the fixed test images (`sample_encoded_tests`), their
    reconstruction (decoder applied to the encoder's latent sample), and
    images decoded from standard-normal latent noise.

    Args:
        encoder: Network producing ``(z, mean, var)`` for an image batch.
        decoder: Network mapping latent vectors to images.
        epoch: Epoch number, used only in the output file name.
        save: Write the figure to ``OUTPUT_PATH/epoch_XXXX.png``.
        show: Display the figure inline.
    """
    encoded_input, _, _ = encoder(sample_encoded_tests)
    sampled_input = tf.random.normal(encoded_input.shape)

    panels = [
        ('Ground Truth', sample_encoded_tests),
        ('Encoder Latent Distribution', decoder(encoded_input)),
        ('Standard Normal Distribution', decoder(sampled_input)),
    ]

    figure = plt.figure(figsize=(12, 4))
    figure.suptitle('VAE Output')
    outer = gridspec.GridSpec(1, len(panels))
    for column, (title, images) in enumerate(panels):
        n_images = images.shape[0]
        # Smallest square grid that holds every image; rounding *up* keeps
        # non-square sample counts from indexing past the grid.
        side = max(1, math.ceil(math.sqrt(n_images)))
        inner = gridspec.GridSpecFromSubplotSpec(side, side, subplot_spec=outer[column])
        for j in range(n_images):
            ax = plt.Subplot(figure, inner[j])
            ax.imshow(images[j] * 255, cmap=plt.cm.gray)
            ax.axis('off')
            # Put the panel title above the middle cell of the first row.
            if j == side // 2:
                ax.set_title(title)
            figure.add_subplot(ax)

    if save:
        # Make sure the target directory exists before writing.
        os.makedirs(OUTPUT_PATH, exist_ok=True)
        figure.savefig(os.path.join(OUTPUT_PATH, f'epoch_{epoch:04d}.png'))
    if show:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(figure)

## VAE Algorithm
<pre>
<i>&theta;</i>,<i>&phi;</i> &leftarrow; Initialize parameters
<b>repeat</b>
    <i>X<sup>M</sup></i>  &leftarrow; Random minibatch of <i>M</i> datapoints (drawn from full dataset)
    <i>&epsilon;</i>   &leftarrow; Random samples from noise distribution <i>p</i>(<i>&epsilon;</i>)
    <i>g</i>   &leftarrow; &Del;<sub><i>&theta;</i>,<i>&phi;</i></sub>&Lscr;<sup>M</sup>(<i>&theta;</i>,<i>&phi;</i>; <i>X<sup>M</sup></i>,<i>&epsilon;</i>) (Gradients of minibatch estimator)
    <i>&theta;</i>,<i>&phi;</i> &leftarrow; Update parameters using gradients <i><b>g</b></i> (e.g. SGD or Adagrad)
<b>until</b> convergence of parameters (<i>&theta;</i>,<i>&phi;</i>)
<b>return</b> <i>&theta;</i>,<i>&phi;</i>
</pre>

Source: https://arxiv.org/pdf/1312.6114.pdf

In [None]:
# Snapshot of the untrained model (epoch 0) for the training animation.
record_sample(encoder, decoder, 0, show=False)
optimizer = tf.optimizers.Adam(learning_rate=LEARNING_RATE)
# Small constant added inside every log() to avoid log(0).
epsilon = 1e-8
# Number of batches per epoch; the last batch may be smaller than BATCH_SIZE.
steps_per_epoch = math.ceil(TRAIN_SIZE / BATCH_SIZE)

for epoch in range(1, EPOCHS+1):
    print(f'\nepoch {epoch}/{EPOCHS}')
    # The target must be the integer step count so the bar can reach 100%.
    progress_bar = keras.utils.Progbar(steps_per_epoch, stateful_metrics=metric_names)

    for i, image_batch in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            posterior_sample, mean, var = encoder(image_batch, training=True)

            reconstructed_image_batch = decoder(posterior_sample,training=True)

            # Closed-form *negative* KL divergence between N(mean, var) and
            # N(0, I), summed over the latent axis (one value per sample).
            analytic_kl_divergence = 0.5 * tf.reduce_sum(
                1 + tf.math.log(var+epsilon) - tf.math.square(mean) - var,
                axis=1,
            )
            # Bernoulli log-likelihood of the input pixels under the decoder
            # output, summed over both image axes (one value per sample).
            log_likelihood = tf.reduce_sum(
                tf.math.add(
                    tf.math.multiply(image_batch, tf.math.log(reconstructed_image_batch+epsilon)),
                    tf.math.multiply((1-image_batch), tf.math.log(1-reconstructed_image_batch+epsilon)),
                ),
                axis=[1,2],
            )
            # Negate and average over the batch: minimizing `loss` maximizes
            # the evidence lower bound.
            kl_loss = -tf.reduce_mean(analytic_kl_divergence)
            rec_loss = -tf.reduce_mean(log_likelihood)
            loss = kl_loss + rec_loss

        training_metrics = {
            'kl_loss': kl_loss,
            'rec_loss': rec_loss,
            'loss': loss,
        }

        # Record loss history as plain floats so the history does not keep
        # one live tensor per step for the whole run.
        for metric in batch_history:
            batch_history[metric].append(float(training_metrics[metric]))

        metric_values = training_metrics.items()
        # Progbar counts completed steps, hence the 1-based index.
        progress_bar.update(i + 1, values=metric_values)

        model_vars = [*encoder.trainable_variables, *decoder.trainable_variables]
        grad = tape.gradient(loss, model_vars)
        optimizer.apply_gradients(zip(grad, model_vars))

    record_sample(encoder, decoder, epoch, show=False)

### Loss Plots

In [None]:
gs = gridspec.GridSpec(2, 2)
plt.figure(figsize=(12, 12))

# Every plotted point is the mean over one epoch's batches, so the x-axis
# counts epochs rather than batches.
epoch_label = 'Epoch #'
loss_label = 'Loss'
batches_per_epoch = math.ceil(TRAIN_SIZE / BATCH_SIZE)

# (grid cell, panel title, key into batch_history)
panels = [
    (gs[0, 0], 'KL Divergence', 'kl_loss'),
    (gs[0, 1], 'Reconstruction Loss', 'rec_loss'),
    (gs[1, :], 'Total Loss', 'loss'),
]

for cell, title, metric in panels:
    ax = plt.subplot(cell)
    plt.title(title)
    plt.xlabel(epoch_label)
    plt.ylabel(loss_label)
    history = batch_history[metric]
    # Drop a trailing partial epoch (e.g. after an interrupted run) so the
    # history always reshapes into whole epochs.
    n_complete = (len(history) // batches_per_epoch) * batches_per_epoch
    per_epoch_mean = tf.math.reduce_mean(
        tf.reshape(history[:n_complete], shape=(-1, batches_per_epoch)),
        axis=1,
    )
    plt.plot(per_epoch_mean)

plt.savefig(os.path.join(METRICS_PATH, 'loss.png'))
plt.show()
plt.close()

### Final Result

In [None]:
# Collect the per-epoch snapshots; the zero-padded epoch number in the file
# name makes lexicographic order equal to chronological order.
filenames = sorted(glob.glob(os.path.join(OUTPUT_PATH, 'epoch*.png')))

with imageio.get_writer(os.path.join(METRICS_PATH, f'{ARCH}.gif'), mode='I') as writer:
    for filename in filenames:
        # Read each snapshot once and append it twice; the original code
        # duplicated every frame (presumably to slow the animation down).
        image = imageio.imread(filename)
        writer.append_data(image)
        writer.append_data(image)

# Copy the most recent snapshot next to the metrics. It is taken from the
# files actually on disk rather than from the training loop's `epoch`
# variable, so this cell also works after a kernel restart or an
# interrupted run.
if filenames:
    last_snapshot = filenames[-1]
    with Image.open(last_snapshot) as final:
        final.save(os.path.join(METRICS_PATH, os.path.basename(last_snapshot)))