# Simple Variational AutoEncoder training and inference on noisy MNIST dataset

**Description**: this notebook walks through training a simple variational autoencoder (VAE) on the MNIST dataset and running inference with it.
The model receives noisy images as input and is trained to reconstruct the original, clean images.
The noisy inputs are MNIST images with random noise added.
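
As an illustration of how noisy inputs of this kind can be built, here is a minimal NumPy sketch. It assumes zero-mean Gaussian noise scaled by a `noise_factor` and pixel values in [0, 1]; the actual noise model used by `dlproject` is not shown in this notebook, and the helper name `add_gaussian_noise` is hypothetical.

```python
import numpy as np

def add_gaussian_noise(images, noise_factor=0.4, seed=0):
    """Add zero-mean Gaussian noise scaled by noise_factor, then clip to [0, 1].

    Assumes `images` is a float array already scaled to [0, 1].
    This is a hypothetical helper, not part of dlproject.
    """
    rng = np.random.default_rng(seed)
    noise = rng.normal(loc=0.0, scale=1.0, size=images.shape)
    noisy = images + noise_factor * noise
    # Clipping keeps the noisy images in the same value range as the originals.
    return np.clip(noisy, 0.0, 1.0)

# Random stand-in for a small batch of flattened 28x28 images.
clean = np.random.default_rng(1).random((4, 28 * 28))
noisy = add_gaussian_noise(clean, noise_factor=0.4)
```

A larger `noise_factor` makes the reconstruction task harder, since less of the original digit survives in the input.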

## Imports, definitions and setup

The first block is needed only when the current environment doesn't have the `dlproject` package installed.
If you have already cloned the repository and run `pip install -e .`, you can skip it.

If you're running this notebook on a standalone Jupyter server, without a local clone of the repository, run the first block as well to obtain the necessary dependencies.

In [None]:
!git clone https://github.com/peiva-git/deep_learning_project.git
%cd deep_learning_project
!pip install -e .

In [None]:
import dlproject as dlp
import tensorflow as tf

import os

## Load the MNIST dataset

In [None]:
dataset_builder = dlp.data.MNISTDatasetBuilder()
dataset_builder.preprocess_dataset_simple_vae(noise_factor=0.4)
train_x, test_x = dataset_builder.train_x, dataset_builder.test_x
train_y, test_y = dataset_builder.train_y, dataset_builder.test_y
noisy_train_data, noisy_test_data = dataset_builder.noisy_train_data, dataset_builder.noisy_test_data

## Instantiate the model

In [None]:
simple_vae = dlp.models.SimpleVAE(input_dim=28 * 28, latent_dim=2)
vae = simple_vae.vae
encoder = simple_vae.encoder
decoder = simple_vae.decoder
vae.compile(optimizer='adam')
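
For readers new to VAEs, the sketch below shows the two ingredients such models typically rely on: sampling a latent vector with the reparameterization trick, and the KL divergence term that regularizes the latent space. It is a generic NumPy illustration under the usual diagonal-Gaussian assumptions, not the `SimpleVAE` implementation; the function names are hypothetical.

```python
import numpy as np

def sample_latent(z_mean, z_log_var, rng):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(z_mean.shape)
    return z_mean + np.exp(0.5 * z_log_var) * eps

def kl_divergence(z_mean, z_log_var):
    """KL(N(mu, sigma^2) || N(0, I)), summed over the latent dimensions."""
    return -0.5 * np.sum(
        1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1
    )

rng = np.random.default_rng(0)
# Three samples in a 2-dimensional latent space, matching latent_dim=2.
z_mean = np.zeros((3, 2))
z_log_var = np.zeros((3, 2))

z = sample_latent(z_mean, z_log_var, rng)
kl = kl_divergence(z_mean, z_log_var)  # zero when the posterior equals the prior
```

In a typical VAE, the training loss adds this KL term to a reconstruction loss.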

## Train the model

Train the instantiated model on the noisy MNIST dataset, using the noisy images as inputs and the original images as targets.

This block also saves a backup and a checkpoint every 20 epochs, so that training can resume automatically if it gets interrupted.

In [None]:
model_dir_path = os.path.join(os.getcwd(), 'output', 'training-callback-results', f'{vae.name}_noisy_mnist')

# Create the output directories; exist_ok=True makes the cell safe to re-run
# and also creates any subdirectory that is missing from an existing parent
for subdir in ('backup', 'model_checkpoints', 'tensorboard_logs'):
    os.makedirs(os.path.join(model_dir_path, subdir), exist_ok=True)

vae.fit(
    noisy_train_data, train_x,
    epochs=100,
    batch_size=32,
    validation_data=(noisy_test_data, test_x),
    callbacks=[
        tf.keras.callbacks.BackupAndRestore(
            backup_dir=os.path.join(model_dir_path, 'backup'),
            save_freq=37500  # in batches: 20 epochs * 1875 batches per epoch
        ),
        tf.keras.callbacks.ModelCheckpoint(os.path.join(model_dir_path, 'model_checkpoints'), save_freq=37500),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(model_dir_path, 'tensorboard_logs'))
    ]
)
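
The `save_freq=37500` value in the cell above is expressed in batches, not epochs. The short calculation below shows where it comes from, assuming the standard MNIST training split of 60,000 images; the variable names are illustrative only.

```python
import math

num_train_samples = 60_000   # assumed size of the MNIST training split
batch_size = 32              # same value passed to vae.fit
epochs_between_saves = 20

# Number of batches processed in one epoch.
steps_per_epoch = math.ceil(num_train_samples / batch_size)

# An integer save_freq is counted in batches, so convert epochs to batches.
save_freq = epochs_between_saves * steps_per_epoch
print(steps_per_epoch, save_freq)  # 1875 37500
```

If you change `batch_size` or the dataset size, recompute `save_freq` accordingly, otherwise the checkpoints will no longer line up with epoch boundaries.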

## Save the trained model

Save the weights of the trained VAE, encoder and decoder for later use.

In [None]:
# exist_ok=True makes the explicit existence check unnecessary
os.makedirs(os.path.join(os.getcwd(), 'output', 'models'), exist_ok=True)

vae.save_weights(os.path.join(os.getcwd(), 'output', 'models', f'{vae.name}_weights_noisy_mnist.keras'))
encoder.save_weights(os.path.join(os.getcwd(), 'output', 'models', f'{encoder.name}_weights_noisy_mnist.keras'))
decoder.save_weights(os.path.join(os.getcwd(), 'output', 'models', f'{decoder.name}_weights_noisy_mnist.keras'))

## Load the model

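Instead of training the model, you can load its weights from previously saved `.keras` files.
Note that this block reads from the `models` directory under the current working directory, while the saving block above writes to `output/models`; adjust the paths if you want to load weights you have just saved.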

In [None]:
vae.load_weights(os.path.join(os.getcwd(), 'models', f'{vae.name}_weights_noisy_mnist.keras'))
encoder.load_weights(os.path.join(os.getcwd(), 'models', f'{encoder.name}_weights_noisy_mnist.keras'))
decoder.load_weights(os.path.join(os.getcwd(), 'models', f'{decoder.name}_weights_noisy_mnist.keras'))

## Visualization

Display a scatter plot of the encoded test data.

In [None]:
dlp.data.show_encoder_scatter_plot(noisy_test_data, test_y, encoder)

Display artificially generated digits, obtained by decoding points sampled from the latent plane.

In [None]:
dlp.data.show_latent_plane_sampled_points(decoder, (-1, 1), (-1, 1), number_of_figures=15, figure_size=28)
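
To make the idea behind this plot concrete, the sketch below builds a regular grid of points in a 2-dimensional latent space and tiles the decoded images into a single canvas. It is an assumption about what a function like `show_latent_plane_sampled_points` might do internally, not its actual implementation; `latent_grid` and `tile_images` are hypothetical helpers, and the decoder output is replaced by a zero-filled stand-in.

```python
import numpy as np

def latent_grid(x_range=(-1.0, 1.0), y_range=(-1.0, 1.0), n=15):
    """Return an (n * n, 2) array of evenly spaced latent points."""
    xs = np.linspace(x_range[0], x_range[1], n)
    ys = np.linspace(y_range[0], y_range[1], n)
    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

def tile_images(decoded, n=15, size=28):
    """Arrange n * n flattened images into one (n * size, n * size) canvas."""
    tiles = decoded.reshape(n, n, size, size)
    # Reorder to (row, pixel_row, col, pixel_col) before merging the axes.
    return tiles.transpose(0, 2, 1, 3).reshape(n * size, n * size)

points = latent_grid()
# Stand-in for decoder.predict(points); a real decoder would return digits.
fake_decoded = np.zeros((points.shape[0], 28 * 28))
canvas = tile_images(fake_decoded)
```

With a trained decoder, the resulting canvas could be displayed with any image plotting function to show how the digits change across the latent plane.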

## Metrics

Compute the mean PSNR (peak signal-to-noise ratio) and SSIM (structural similarity index) between the original test images and the images that the trained VAE reconstructs from their noisy versions.

In [None]:
reconstructed_images = vae.predict(noisy_test_data)
print(dlp.evaluation.compute_mean_psnr(test_x, reconstructed_images))
print(dlp.evaluation.compute_mean_ssim(test_x, reconstructed_images))
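
For reference, PSNR is defined as 10 * log10(MAX^2 / MSE), where MAX is the largest possible pixel value and MSE is the mean squared error between the two images. The NumPy sketch below computes the mean PSNR over a batch, assuming pixel values in [0, 1]; it is an illustration of the metric, not the implementation of `dlp.evaluation.compute_mean_psnr`, and `mean_psnr` is a hypothetical name.

```python
import numpy as np

def mean_psnr(originals, reconstructions, max_val=1.0):
    """Mean PSNR in dB over a batch of images, assuming values in [0, max_val]."""
    originals = np.asarray(originals, dtype=np.float64)
    reconstructions = np.asarray(reconstructions, dtype=np.float64)
    n = originals.shape[0]
    # Per-image mean squared error, computed over all pixels of each image.
    diff = originals.reshape(n, -1) - reconstructions.reshape(n, -1)
    mse = np.mean(diff ** 2, axis=1)
    # Guard against division by zero for identical images.
    mse = np.maximum(mse, 1e-12)
    return float(np.mean(10.0 * np.log10(max_val ** 2 / mse)))

# A constant error of 0.1 per pixel gives an MSE of 0.01, i.e. about 20 dB.
a = np.zeros((2, 28, 28))
b = np.full((2, 28, 28), 0.1)
psnr_value = mean_psnr(a, b)
```

Higher PSNR means the reconstruction is closer to the original. SSIM complements it by comparing local structure rather than raw pixel error.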