# Simple Variational AutoEncoder training and inference CIFAR10 example

**Description**: this notebook showcases the training process and inference capabilities of a simple variational autoencoder on the CIFAR10 dataset.

## Imports, definitions and setup

The first block is needed only when the current environment doesn't have the `dlproject` package installed.
If you have already cloned the whole repository and run `pip install -e .`, you can skip it.

If you're running this notebook on its own on a Jupyter server, without a local clone of the repository, run the first block as well to obtain the necessary dependencies.

In [None]:
!git clone https://github.com/peiva-git/deep_learning_project.git
%cd deep_learning_project
!pip install -e .

In [None]:
import dlproject as dlp
import tensorflow as tf

import os

## Load the CIFAR10 dataset

The CIFAR10 dataset contains 50000 32x32 color images in the training set and 10000 images in the testing set, labeled over 10 categories.

To use the images with this simple model, the pre-processing step converts them to 32x32 grayscale images.

In [None]:
dataset_builder = dlp.data.CIFAR10DatasetBuilder()
dataset_builder.preprocess_dataset_simple_vae()
train_x, test_x = dataset_builder.train_x, dataset_builder.test_x
train_y, test_y = dataset_builder.train_y, dataset_builder.test_y
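
The internals of `preprocess_dataset_simple_vae()` are not shown in this notebook. As a rough sketch of what a grayscale conversion of this kind could look like, the snippet below converts a batch of RGB images with NumPy using the ITU-R BT.601 luma weights. The helper name `to_flat_grayscale`, the scaling to [0, 1], and the flattening to 1024 values (chosen to match `input_dim=32 * 32`) are assumptions made for illustration, not the package's implementation.

```python
import numpy as np

# ITU-R BT.601 luma weights, a common choice for RGB -> grayscale conversion
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def to_flat_grayscale(images_rgb):
    """Hypothetical helper: (N, 32, 32, 3) uint8 RGB -> (N, 1024) float32 in [0, 1]."""
    images = images_rgb.astype(np.float32) / 255.0   # scale pixel values to [0, 1]
    gray = images @ LUMA_WEIGHTS                      # weighted sum over the channel axis
    return gray.reshape(len(gray), -1)                # flatten each 32x32 image


# Random stand-in data with the same shape as a small CIFAR10 batch
fake_batch = np.random.randint(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
flat = to_flat_grayscale(fake_batch)
print(flat.shape)
```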

## Instantiate the model

In [None]:
simple_vae = dlp.models.SimpleVAE(input_dim=32 * 32, latent_dim=2)
vae = simple_vae.vae
encoder = simple_vae.encoder
decoder = simple_vae.decoder
vae.compile(optimizer=tf.keras.optimizers.Adam())

## Train the model

Train the instantiated model on the CIFAR10 dataset.

This block also saves a backup and a checkpoint every 20 epochs, so that training can resume automatically if it is interrupted.

In [None]:
model_dir_path = os.path.join(os.getcwd(), 'output', 'training-callback-results', f'{vae.name}_cifar10')

# Create the output directories; exist_ok=True makes this safe to re-run,
# and it also creates any subdirectory that is missing from an existing model directory
for subdir in ('backup', 'model_checkpoints', 'tensorboard_logs'):
    os.makedirs(os.path.join(model_dir_path, subdir), exist_ok=True)

vae.fit(
    train_x, train_x,
    epochs=100,
    batch_size=32,
    validation_data=(test_x, test_x),
    callbacks=[
        tf.keras.callbacks.BackupAndRestore(
            backup_dir=os.path.join(model_dir_path, 'backup'),
            save_freq=31260  # every 20 epochs: 20 * 1563 batches per epoch (ceil(50000 / 32))
        ),
        tf.keras.callbacks.ModelCheckpoint(os.path.join(model_dir_path, 'model_checkpoints'), save_freq=31260),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(model_dir_path, 'tensorboard_logs'))
    ]
)
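
The `save_freq=31260` value used above follows from the training set size and the batch size. A small sketch of the arithmetic, useful if you change `batch_size` or the saving interval (the variable names are illustrative):

```python
import math

train_samples = 50000       # CIFAR10 training set size
batch_size = 32             # same value passed to vae.fit
epochs_between_saves = 20

# An integer save_freq is counted in batches; the final, partial batch of each epoch counts too
steps_per_epoch = math.ceil(train_samples / batch_size)
save_freq = epochs_between_saves * steps_per_epoch
print(steps_per_epoch, save_freq)  # 1563 31260
```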

## Save the trained model

Save the trained model's weights for later use.

In [None]:
os.makedirs(os.path.join(os.getcwd(), 'output', 'models'), exist_ok=True)

vae.save_weights(os.path.join(os.getcwd(), 'output', 'models', f'{vae.name}_weights_cifar10.keras'))
encoder.save_weights(os.path.join(os.getcwd(), 'output', 'models', f'{encoder.name}_weights_cifar10.keras'))
decoder.save_weights(os.path.join(os.getcwd(), 'output', 'models', f'{decoder.name}_weights_cifar10.keras'))

## Load the model

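Instead of training the model, you can load its weights from a previously saved `.keras` file.

Note that this block reads from the `models` directory under the current working directory, while the previous block writes to `output/models`; adjust the paths if you want to reload weights you have just saved.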

In [None]:
vae.load_weights(os.path.join(os.getcwd(), 'models', f'{vae.name}_weights_cifar10.keras'))
encoder.load_weights(os.path.join(os.getcwd(), 'models', f'{encoder.name}_weights_cifar10.keras'))
decoder.load_weights(os.path.join(os.getcwd(), 'models', f'{decoder.name}_weights_cifar10.keras'))

## Visualization

Display a scatter plot of the encoded test data.

In [None]:
dlp.data.show_encoder_scatter_plot(test_x, test_y, encoder)

Display images generated by decoding points sampled from the latent plane.

In [None]:
dlp.data.show_latent_plane_sampled_points(decoder, (-1, 1), (-1, 1), number_of_figures=15, figure_size=32)
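
The call above samples the two-dimensional latent plane over the ranges (-1, 1) and (-1, 1). The plotting helper's internals are not shown here, so the following is only a sketch of how such a regular grid of latent points could be built with NumPy; `latent_grid` is a hypothetical name, and the grid layout is an assumption.

```python
import numpy as np


def latent_grid(x_range, y_range, number_of_figures):
    """Hypothetical helper: return (n * n, 2) latent points on a regular grid."""
    xs = np.linspace(x_range[0], x_range[1], number_of_figures)
    ys = np.linspace(y_range[0], y_range[1], number_of_figures)
    grid_x, grid_y = np.meshgrid(xs, ys)
    # One row per latent point: column 0 is the first latent coordinate, column 1 the second
    return np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)


points = latent_grid((-1, 1), (-1, 1), number_of_figures=15)
print(points.shape)  # (225, 2)
```

Each row could then be passed to `decoder.predict(points)` and, assuming the decoder outputs flat 1024-element vectors, reshaped to 32x32 for display.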

## Metrics

Compute the mean PSNR (peak signal-to-noise ratio) and SSIM (structural similarity index) between the original test images and their reconstructions produced by the trained VAE.

In [None]:
reconstructed_images = vae.predict(test_x)
print(dlp.evaluation.compute_mean_psnr(test_x, reconstructed_images))
print(dlp.evaluation.compute_mean_ssim(test_x, reconstructed_images))
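
For reference, PSNR is defined as 10 * log10(MAX^2 / MSE), where MAX is the largest possible pixel value. The sketch below computes a mean PSNR with NumPy on synthetic data. It is not the implementation behind `dlp.evaluation.compute_mean_psnr`; the helper name `mean_psnr` and the assumption that pixel values lie in [0, 1] are illustrative.

```python
import numpy as np


def mean_psnr(originals, reconstructions, max_value=1.0):
    """Hypothetical helper: mean PSNR in dB over a batch, one image per row."""
    originals = np.asarray(originals, dtype=np.float64)
    reconstructions = np.asarray(reconstructions, dtype=np.float64)
    # Per-image mean squared error, averaging over every non-batch axis
    axes = tuple(range(1, originals.ndim))
    mse = np.mean((originals - reconstructions) ** 2, axis=axes)
    mse = np.maximum(mse, 1e-12)  # avoid division by zero for identical images
    return float(np.mean(10.0 * np.log10(max_value ** 2 / mse)))


# Synthetic stand-in data: random "images" and a noisy copy of them
rng = np.random.default_rng(0)
clean = rng.random((8, 1024))
noisy = np.clip(clean + rng.normal(0.0, 0.05, clean.shape), 0.0, 1.0)
print(mean_psnr(clean, noisy))
```

Higher values indicate reconstructions closer to the originals.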