# Hands-on session 2.1: Variational Autoencoder
## Building Autoencoders in PyTorch

Made by **Nathan Painchaud** and **Pierre-Marc Jodoin** from the Université de Sherbrooke, Canada.

Inspired by a [similar tutorial](https://blog.keras.io/building-autoencoders-in-keras.html) for Keras by François Chollet.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%capture packages_install

# Make sure the repo's package and its dependencies are installed
!pip install -e ../.

In [None]:
%%capture project_path_setup

import sys

if '../' in sys.path:
    print(sys.path)
else:
    sys.path.append('../')
    print(sys.path)

## PyTorch programming

PyTorch is one of the most widely used deep-learning libraries in the world. Unfortunately, despite all of our efforts to simplify the code as much as possible, some parts of this hands-on session might look a bit cryptic to those who are new to PyTorch. Ideally, beginners should first get familiar with PyTorch through one or two tutorials such as

* [Deep Learning With PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
* [Training a Classifier on CIFAR10](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

before coding autoencoders. However, you might not have enough time to go through these tutorials during the limited time of the hands-on session.

We thus suggest that you start by playing around with these two autoencoder notebooks: tweak some hyperparameters, retrain the networks, and ask questions if you have any.

Afterwards, if you want to know more about the nuts and bolts of PyTorch, take the time to go through these PyTorch tutorials. You will see that PyTorch is a formidable tool!

### Let's start by looking at our data
In the next cell, we will load the MNIST dataset and visualize some of its images.

In [None]:
import numpy as np
from torchvision.datasets import MNIST
from torchvision import transforms
from src.visualization.utils import display_data_samples

# MNIST consists of 28x28 images, so the size of the data is
data_shape = 28, 28
data_size = data_shape[0] * data_shape[1]

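# Download (if needed) and prepare the data
transform = transforms.Compose([
    transforms.ToTensor(),  # Converts each image to a (1, 28, 28) float tensor with values in [0, 1]
])
mnist_train = MNIST("../data", train=True, download=True, transform=transform)
mnist_test = MNIST("../data", train=False, download=True, transform=transform)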

# Check data by displaying random images
samples_indices = np.random.randint(len(mnist_train), size=10)
mnist_img_list = [mnist_train[sample_idx][0] for sample_idx in samples_indices]
display_data_samples(data=mnist_img_list)

In [None]:
# What are `mnist_train` and `mnist_test`?  Let's look at them.
print(mnist_train)
print(mnist_test)

### PyTorch datasets

Please note that `mnist_train` and `mnist_test` are PyTorch **datasets**. A dataset is an object that gives indexed access to the data samples: here, `mnist_train[i]` returns the `i`-th image as a PyTorch tensor together with its class label. You can access (and visualize) the data with the following code.
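Under the hood, a map-style PyTorch dataset is simply an object implementing `__len__` and `__getitem__`. As a minimal sketch, the hypothetical `ToyImageDataset` below (not part of this repository, and filled with random pixels instead of real digits) mimics what `MNIST` with `ToTensor()` returns: a `(1, 28, 28)` float tensor with values in [0, 1], plus an integer label.

```python
import torch
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset of random 28x28 "images", for illustration only."""

    def __init__(self, num_samples=10):
        generator = torch.Generator().manual_seed(0)
        # Raw uint8 pixels in [0, 255], like the images stored on disk by MNIST
        self.images = torch.randint(0, 256, (num_samples, 28, 28), dtype=torch.uint8, generator=generator)
        self.labels = torch.randint(0, 10, (num_samples,), generator=generator)

    def __len__(self):
        # Number of samples, used by len(dataset) and by data loaders
        return len(self.images)

    def __getitem__(self, idx):
        # Mimic ToTensor: add a channel dimension and rescale to floats in [0, 1]
        image = self.images[idx].unsqueeze(0).float() / 255
        return image, int(self.labels[idx])


dataset = ToyImageDataset()
image, label = dataset[0]
print(len(dataset), image.shape, image.dtype)
```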

In [None]:
import matplotlib.pyplot as plt

# Get the first training image and its class label
sample_image = mnist_train[0][0] # sample_image is a "PyTorch tensor"
sample_label = mnist_train[0][1] 

# Convert the Tensor into a numpy array
sample_image_np = sample_image.numpy()  
print("Image size = ", sample_image_np.shape)

# Call "squeeze(0)" to remove the first (channel) dimension, which has size 1
sample_image_np = sample_image_np.squeeze(0)
print("Image size = ", sample_image_np.shape)

# Plot
plt.imshow(sample_image_np)
print("The image label is ", sample_label)

### Let's build a deep deterministic autoencoder

Here, we will build a simple autoencoder with only **dense** (aka fully-connected) layers and **ReLUs**. In PyTorch, a dense layer is called **Linear**.

Both the encoder and the decoder have **3 layers** and the latent space has **32 dimensions**.

Since the pixels have values between 0 and 1, the last activation function is a **Sigmoid**.
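To make the shapes concrete, here is a standalone sketch that rebuilds the same layer sizes as in the next cell (assuming 28x28 inputs and a 32-dimensional latent space), passes a dummy batch of random tensors through both networks and counts their trainable parameters.

```python
import torch
from torch import nn

data_size, latent_space_size = 28 * 28, 32

# Same layer sizes as the encoder/decoder built in the next cell
encoder = nn.Sequential(
    nn.Linear(data_size, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, latent_space_size),
)
decoder = nn.Sequential(
    nn.Linear(latent_space_size, 64), nn.ReLU(),
    nn.Linear(64, 128), nn.ReLU(),
    nn.Linear(128, data_size), nn.Sigmoid(),
)


def count_params(net):
    # Total number of trainable scalars (weights + biases)
    return sum(p.numel() for p in net.parameters())


x = torch.rand(4, data_size)  # dummy batch of 4 flattened "images"
z = encoder(x)                # (4, 32): one latent vector per image
x_hat = decoder(z)            # (4, 784): the sigmoid keeps every pixel in [0, 1]

print(z.shape, x_hat.shape, count_params(encoder), count_params(decoder))
```

With these sizes, the encoder and decoder each hold roughly 110k parameters, most of them in the layer connected to the 784 pixels.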

In [None]:
from torch import nn

# Let's define the encoder architecture we want,
# with some options to configure the input and output size
def make_encoder(data_size, latent_space_size):
    return nn.Sequential(
        nn.Linear(data_size, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, latent_space_size),
    )

# Same thing for the decoder
def make_decoder(data_size, latent_space_size):
    return nn.Sequential(
        nn.Linear(latent_space_size, 64),
        nn.ReLU(),
        nn.Linear(64, 128),
        nn.ReLU(),
        nn.Linear(128, data_size),
        nn.Sigmoid(),
    )

# Now let's build our networks, with an arbitrary dimensionality of the latent space
# and an input and output size depending on the data.
encoder = make_encoder(data_size, 32)
decoder = make_decoder(data_size, 32)

## Questions:
    
* Can you see how the architecture of the encoder is the dual of that of the decoder?
* Why do you think that the decoder has an output **sigmoid** activation function?
* What is the latent space size of the autoencoder?

We also have to define our **forward pass** algorithm.
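The reconstruction loss used in the forward pass below is the binary cross-entropy, averaged over all $M$ pixels of the batch: $\text{BCE}(x, \hat{x}) = -\frac{1}{M}\sum_{i=1}^{M} \left[x_i \log \hat{x}_i + (1 - x_i) \log(1 - \hat{x}_i)\right]$. It accepts any targets in [0, 1], such as grayscale MNIST pixels, not only binary ones. A small sketch on random tensors (a stand-in for real images) checks that the hand-written formula matches `F.binary_cross_entropy`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(2, 1, 28, 28)                        # targets: pixel intensities in [0, 1]
x_hat = torch.rand(2, 1, 28, 28).clamp(0.01, 0.99)  # predictions, kept away from 0 and 1

# Binary cross-entropy written out by hand, averaged over every pixel of every image
manual_bce = -(x * torch.log(x_hat) + (1 - x) * torch.log(1 - x_hat)).mean()
# F.binary_cross_entropy uses reduction="mean" by default, i.e. the same average
builtin_bce = F.binary_cross_entropy(x_hat, x)
print(manual_bce.item(), builtin_bce.item())
```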

In [None]:
import torch
import torch.nn.functional as F

def autoencoder_forward_pass(encoder, decoder, x):
    """AE forward pass.

    Args:
        encoder: neural net that predicts a latent vector
        decoder: neural net that projects a point in the latent space back into the image space
        x: batch of N MNIST images

    Returns:
        loss: binary cross-entropy reconstruction loss
        x_hat: batch of N reconstructed images
    """
    in_shape = x.shape  # Save the input shape
    encoder_input = torch.flatten(x, start_dim=1)   # Flatten the 2D image to a 1D tensor (for the linear layer)
    z = encoder(encoder_input)  # Forward pass on the encoder (to get the latent space vector)
    x_hat = decoder(z)  # Forward pass on the decoder (to get the reconstructed input)
    x_hat = x_hat.reshape(in_shape)    # Restore the output to the original shape
    loss = F.binary_cross_entropy(x_hat, x) # Compute the reconstruction loss
    return loss, x_hat

## Training algorithm

Before we can train our model, we have to define our training algorithm.
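Stripped of data loaders and validation, every PyTorch training step follows the same four-step pattern: reset the gradients, run the forward pass, back-propagate, and update the parameters. The sketch below applies it to a hypothetical toy regression problem (fitting `y = 3x - 1` with a single `nn.Linear`), unrelated to MNIST, just to show the loss going down.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Toy regression problem (hypothetical): learn y = 3x - 1 from slightly noisy samples
x = torch.rand(64, 1)
y = 3 * x - 1 + 0.01 * torch.randn(64, 1)

model = nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

losses = []
for step in range(300):
    optimizer.zero_grad()                       # 1. reset gradients from the previous step
    loss = nn.functional.mse_loss(model(x), y)  # 2. forward pass
    loss.backward()                             # 3. backward pass: fills p.grad for every parameter
    optimizer.step()                            # 4. update parameters using their gradients
    losses.append(loss.item())

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```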

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Define some training hyperparameters
epochs = 25
batch_size = 256

def train(forward_pass_fn, encoder, decoder, optimizer, train_data, val_data, device="cuda"):
    # Create dataloaders from the data
    # Those are PyTorch's abstraction to help iterate over the data
    num_workers = max((os.cpu_count() or 1) - 1, 0)  # os.cpu_count() may return None
    data_loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers, "pin_memory": True}
    train_dataloader = DataLoader(train_data, shuffle=True, **data_loader_kwargs)
    val_dataloader = DataLoader(val_data, **data_loader_kwargs)

    # Ensure that the networks are on the requested device (typically a GPU)
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    fit_pbar = tqdm(range(epochs), desc="Training", unit="epoch")
    pbar_metrics = {"train_loss": None, "val_loss": None}
    for epoch in fit_pbar:

        # Train once over all the training data
        for x, _ in train_dataloader:
            x = x.to(device)    # Move the data tensor to the device
            optimizer.zero_grad()   # Make sure gradients are reset
            train_loss, _ = forward_pass_fn(encoder, decoder, x)    # Forward pass
            train_loss.backward()   # Backward pass
            optimizer.step()    # Update parameters w.r.t. optimizer and gradients
            pbar_metrics["train_loss"] = train_loss.item()
            fit_pbar.set_postfix(pbar_metrics)

        # At the end of the epoch, check performance against the validation data.
        # No gradients are needed here, so we disable them to save memory and time,
        # and we report the average loss over all the validation batches.
        with torch.no_grad():
            val_losses = []
            for x, _ in val_dataloader:
                x = x.to(device)    # Move the data tensor to the device
                val_loss, _ = forward_pass_fn(encoder, decoder, x)
                val_losses.append(val_loss.item())
        pbar_metrics["val_loss"] = sum(val_losses) / len(val_losses)
        fit_pbar.set_postfix(pbar_metrics)

## Questions:

The previous `train(...)` function contains a typical **PyTorch training loop**. That training loop contains a **forward pass**, a **backward pass**, a **gradient step** (*optimizer.step()*) and a **validation check**. **Data loaders** are also common in PyTorch: a data loader is an object that wraps a dataset and provides an iterable over its content in mini-batches (shuffled, in the case of the training data).
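To see what a data loader produces, here is a small sketch on a hypothetical dataset of 10 random tensors wrapped in `TensorDataset`: with `batch_size=4`, the loader yields mini-batches of 4, 4 and 2 samples, and every sample is visited exactly once per pass.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy dataset: 10 "images" of size 1x28x28 with integer labels
images = torch.rand(10, 1, 28, 28)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# Group samples into mini-batches of 4; shuffle=True draws a new order at each pass
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch_shapes = [tuple(x.shape) for x, _ in loader]
print(batch_shapes)
```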

* Do you see what the data loaders are used for?
* Do you see what the forward pass does?  What are the inputs and outputs of that function?


Finally, let's train our model!

In [None]:
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()])
train(autoencoder_forward_pass, encoder, decoder, optimizer, mnist_train, mnist_test)

Now, let's take a look at the results on the test set.

In [None]:
from src.visualization.utils import display_autoencoder_results

display_autoencoder_results(mnist_test, lambda x: autoencoder_forward_pass(encoder, decoder, x.cuda())[1])

### Latent space

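Before we move on to the variational autoencoder, go back to the cell where `encoder` and `decoder` are built, replace the latent space size of 32 with 2, and re-run the cells from there on (including the creation of the optimizer and the training) to retrain the autoencoder.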

Once this is done, execute the following cell to visualize the latent space.


In [None]:
# Run this cell only if the autoencoder has a latent space size of 2.

from src.visualization.latent_space import explore_latent_space

latent_space_size = 2

explore_latent_space(
    mnist_test,
    lambda x: encoder(torch.flatten(x, start_dim=1)),
    lambda z: decoder(z).reshape(data_shape),
    encodings_label="target",
)

## Question:

Why do you think that with a 2D latent space we end up reconstructing less accurate (more blurry) images?

### Let's turn our autoencoder variational

Variational autoencoders (VAE) are very similar to autoencoders.  The differences are threefold:

* The VAE's encoder outputs a mean vector and a (log-)variance vector
* The input of the decoder is a vector randomly sampled from a Normal distribution parameterized by the predicted mean and variance vectors
* The loss has 2 terms: the reconstruction loss (as for the regular AE) + the KL divergence between the distribution predicted by the encoder and a standard Normal prior
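For a diagonal Gaussian posterior $q(z \mid x) = \mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big)$ and a standard Normal prior $p(z) = \mathcal{N}(0, I)$ (the prior that the `kl_div` function further down assumes), the KL divergence has a closed form over the $d$ latent dimensions:

$$
D_{\mathrm{KL}}\big(q(z \mid x) \,\|\, p(z)\big) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

With $\log \sigma_j^2$ stored as `logvar`, this is the quantity that `kl_div` sums over `dim=1` before averaging over the batch.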

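Since gradients cannot back-propagate through a random sampling operation, VAEs rely on the **reparametrization trick**: instead of sampling `z` directly, we sample `eps` from a standard Normal distribution and compute `z = mu + eps * std`, which makes `z` a differentiable function of `mu` and `std`.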
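The sketch below illustrates the problem and the fix on a 3-dimensional latent vector with `mu = 0` and `logvar = 0` (arbitrary values chosen for illustration). A sample drawn with `torch.distributions.Normal(...).sample()` is detached from the computational graph, whereas `z = mu + eps * std` lets gradients reach both `mu` and `logvar`:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
mu = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)
std = torch.exp(0.5 * logvar)

# Plain sampling: the sample is cut off from mu and std (no gradient can flow back)
z_sampled = Normal(mu, std).sample()

# Reparametrization: the randomness lives in eps, which does not depend on mu or logvar
eps = torch.randn_like(std)
z = mu + eps * std
z.sum().backward()

print(z_sampled.requires_grad, z.requires_grad)
print(mu.grad)      # dz/dmu = 1 for every dimension
print(logvar.grad)  # dz/dlogvar = 0.5 * eps * std
```

For reference, `torch.distributions.Normal` also provides `rsample()`, which draws a reparametrized (differentiable) sample in the same way.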

In [None]:
# This time, we start right away with a 2D latent space to visualize it easily afterwards
latent_space_size = 2

# In practice, a small trick to easily implement the two heads of the encoder is to simply
# double the size of its output. Then, we can slice the output in half during the forward pass!
vae_encoder = make_encoder(data_size, latent_space_size * 2) 
vae_decoder = make_decoder(data_size, latent_space_size)

## Questions:

* In the previous cell, we used the same functions to build our VAE's encoder and decoder networks as for the AE. The only difference is that the output size of the encoder is multiplied by 2. Why do you think that is?
* In the next cell, we include the **reparametrization trick** in the **forward pass**. Do you remember why this has to be done?
* What is the latent space size of the VAE?
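Regarding the doubled encoder output: a single network with `2 * latent_space_size` outputs can act as two heads, the first half being read as the mean and the second as the log-variance. A small sketch on a random stand-in tensor shows that the slicing used in `vae_forward_pass` is equivalent to `torch.chunk`:

```python
import torch

latent_space_size = 2
torch.manual_seed(0)
encoder_output = torch.randn(5, latent_space_size * 2)  # stand-in for the encoder's (N, 4) output

# Slicing, as done in vae_forward_pass
mu, logvar = encoder_output[:, :latent_space_size], encoder_output[:, latent_space_size:]

# Equivalent alternative: split the last dimension into 2 equal chunks
mu_c, logvar_c = torch.chunk(encoder_output, 2, dim=1)

print(mu.shape, logvar.shape, torch.equal(mu, mu_c), torch.equal(logvar, logvar_c))
```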

We also need to change the forward pass, since it now has to implement the **reparametrization trick** and add the KL divergence to the loss. The training algorithm itself (`train(...)`) stays the same.
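Before using it, one can sanity-check the closed-form `kl_div` below against PyTorch's own `torch.distributions.kl_divergence`. This standalone sketch copies `kl_div` verbatim and compares both on random means and log-variances (8 samples and 2 latent dimensions, chosen arbitrarily):

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_div(mu, logvar):
    # Same closed form as in the next cell: sum over latent dims, mean over the batch
    kl_div_by_samples = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return torch.mean(kl_div_by_samples)


torch.manual_seed(0)
mu, logvar = torch.randn(8, 2), torch.randn(8, 2)

# Reference value: KL between N(mu, sigma^2) and N(0, 1), computed per dimension by PyTorch
posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(posterior, prior).sum(dim=1).mean()

print(kl_div(mu, logvar).item(), reference.item())
```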

In [None]:
def kl_div(mu, logvar):
    kl_div_by_samples = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return torch.mean(kl_div_by_samples)

def vae_forward_pass(encoder, decoder, x):
    """VAE forward pass.

    Args:
        encoder: neural net that predicts a mean and a logvar vector
        decoder: neural net that projects a point in the latent space back into the image space
        x: batch of N MNIST images

    Returns:
        loss: binary cross-entropy + weighted KL divergence loss
        x_hat: batch of N reconstructed images
    """
    in_shape = x.shape  # Save the input shape
    encoder_input = torch.flatten(x, start_dim=1)  # Flatten the 2D image to a 1D tensor (for the linear layer)
    encoding_distr = encoder(encoder_input)  # Forward pass on the encoder (to get the latent space posterior)
    # Nothing changed so far!

    # Second part of our trick!
    # We split the encoder's single output (the parameters of the latent space posterior) into its two halves: mean and logvar
    mu, logvar = encoding_distr[:, :latent_space_size], encoding_distr[:, latent_space_size:]

    # Reparametrization trick
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std

    # Decoding mostly stays the same. The only difference is the added 4th line below
    x_hat = decoder(z)  # Forward pass on the decoder (to get the reconstructed input)
    x_hat = x_hat.reshape(in_shape)    # Restore the output to the original shape
    loss = F.binary_cross_entropy(x_hat, x) # Compute the reconstruction loss
    loss += 5e-3 * kl_div(mu, logvar)  # Loss now also includes the KL divergence term, weighted by 5e-3 (a hyperparameter)
    return loss, x_hat

Now it's time to train our variational autoencoder!

In [None]:
optimizer = torch.optim.Adam([*vae_encoder.parameters(), *vae_decoder.parameters()])
train(vae_forward_pass, vae_encoder, vae_decoder, optimizer, mnist_train, mnist_test)

Now, let's take a look at the results on the test set.

In [None]:
display_autoencoder_results(mnist_test, lambda x: vae_forward_pass(vae_encoder, vae_decoder, x.cuda())[1])

## More visualization 

Now that we have a latent space in two dimensions, we can easily visualize it and look at how the data is
distributed.
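The visualization below places each test image at the **mean** predicted by the encoder, i.e. the first half of its output, rather than at a random sample, so that each image gets a deterministic position. Here is a standalone sketch of that encoding step, with a hypothetical small encoder and random images standing in for `vae_encoder` and MNIST:

```python
import torch
from torch import nn

latent_space_size = 2
# Hypothetical stand-in for vae_encoder: any network with a (N, 2 * latent_space_size) output
vae_encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, latent_space_size * 2))

images = torch.rand(100, 1, 28, 28)  # dummy batch standing in for MNIST test images
with torch.no_grad():  # no gradients are needed for visualization
    encoder_output = vae_encoder(torch.flatten(images, start_dim=1))
    mu = encoder_output[:, :latent_space_size]  # keep only the mean half

# One 2D point per image, e.g. ready for plt.scatter(mu[:, 0], mu[:, 1])
print(mu.shape)
```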

### Do you see the difference between this latent space and that of the previous autoencoder?

In [None]:
from src.visualization.latent_space import explore_latent_space

explore_latent_space(
    mnist_test,
    lambda x: vae_encoder(torch.flatten(x, start_dim=1))[:, :latent_space_size],
    lambda z: vae_decoder(z).reshape(data_shape),
    encodings_label="target",
)

### In the next cell, we decode a chosen vector `z` from the latent space. Change the content of that vector and see what happens!

In [None]:
import matplotlib.pyplot as plt

z = [-1, -1]  # 2D latent vector

z_torch = torch.tensor(z, dtype=torch.float).cuda()  # convert z into a PyTorch tensor on the GPU

with torch.no_grad():  # no gradients are needed just to decode
    sample = vae_decoder(z_torch).reshape(data_shape)  # decode the latent vector with the VAE decoder

plt.imshow(sample.cpu().numpy())  # plot the resulting image