# Hands-on session 2.1: Variational Autoencoder
## Building Autoencoders in PyTorch
Inspired by a [similar tutorial](https://blog.keras.io/building-autoencoders-in-keras.html) for Keras by François Chollet.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%capture packages_install

# Make sure the notebook's dependencies are installed
import sys
!{sys.executable} -m pip install -r ../requirements.txt

### Let's start by looking at our data

In [None]:
import numpy as np
from torchvision.datasets import MNIST
from torchvision import transforms
from src.visualization.utils import display_data_samples

# MNIST consists of 28x28 images, so the size of the data is
data_shape = 28, 28
data_size = data_shape[0] * data_shape[1]

# Download and prepare data
transform = transforms.Compose([
    transforms.ToTensor(),
])
mnist_train = MNIST("../data", train=True, download=True, transform=transform)
mnist_test = MNIST("../data", train=False, download=True, transform=transform)

# Check data by displaying random images
samples_indices = np.random.randint(len(mnist_train), size=10)
display_data_samples(data=[mnist_train[sample_idx][0] for sample_idx in samples_indices])

### Let's build a deep deterministic autoencoder

In [None]:
from torch import nn

# Let's define the encoder architecture we want,
# with some options to configure the input and output size
def make_encoder(data_size, encoding_size):
    return nn.Sequential(
        nn.Linear(data_size, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, encoding_size),
    )

# Same thing for the decoder
def make_decoder(data_size, encoding_size):
    return nn.Sequential(
        nn.Linear(encoding_size, 64),
        nn.ReLU(),
        nn.Linear(64, 128),
        nn.ReLU(),
        nn.Linear(128, data_size),
        nn.Sigmoid(),
    )

# Now let's build our networks, with an arbitrary dimensionality of the latent space
# and an input and output size depending on the data.
encoder = make_encoder(data_size, 32)
decoder = make_decoder(data_size, 32)
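
As a quick sanity check, the sketch below pushes a random batch through freshly built (untrained) networks and inspects the tensor shapes. It repeats the two factory functions from the cell above so that it runs on its own; the batch size of 8 and the random input are arbitrary stand-ins for a real batch of flattened MNIST images.

```python
import torch
from torch import nn

def make_encoder(data_size, encoding_size):
    return nn.Sequential(
        nn.Linear(data_size, 128), nn.ReLU(),
        nn.Linear(128, 64), nn.ReLU(),
        nn.Linear(64, encoding_size),
    )

def make_decoder(data_size, encoding_size):
    return nn.Sequential(
        nn.Linear(encoding_size, 64), nn.ReLU(),
        nn.Linear(64, 128), nn.ReLU(),
        nn.Linear(128, data_size), nn.Sigmoid(),
    )

data_size, encoding_size = 28 * 28, 32
encoder = make_encoder(data_size, encoding_size)
decoder = make_decoder(data_size, encoding_size)

# Stand-in for a batch of 8 flattened images with pixel values in [0, 1]
fake_batch = torch.rand(8, data_size)

# No gradients are needed for a shape check
with torch.no_grad():
    z = encoder(fake_batch)   # Compressed representation
    x_hat = decoder(z)        # Reconstruction, same size as the input

print(z.shape)      # torch.Size([8, 32])
print(x_hat.shape)  # torch.Size([8, 784])
```

The final `nn.Sigmoid()` keeps every reconstructed value between 0 and 1, which matches the pixel range produced by `ToTensor()`.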

Before we can train our model, we have to define our training algorithm.

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Define some training hyperparameters
epochs = 25
batch_size = 256

def fit(step_fn, encoder, decoder, optimizer, train_data, val_data, device="cuda"):
    # Create dataloaders from the data
    # DataLoaders are PyTorch's abstraction for iterating over a dataset in batches
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_data, batch_size=batch_size)

    # Ensure that the networks are on the requested device
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    fit_pbar = tqdm(range(epochs), desc="Training", unit="epoch")
    pbar_metrics = {"train_loss": None, "val_loss": None}
    for epoch in fit_pbar:
        # Set model in training mode before training
        encoder.train()
        decoder.train()

        # Train once over all the training data
        for x, _ in train_dataloader:
            x = x.to(device)    # Move the data tensor to the device
            optimizer.zero_grad()   # Make sure gradients are reset
            train_loss, _ = step_fn(encoder, decoder, x)    # Forward pass
            train_loss.backward()   # Backward pass
            optimizer.step()    # Update parameters w.r.t. optimizer and gradients
            pbar_metrics["train_loss"] = train_loss.item()
            fit_pbar.set_postfix(pbar_metrics)

        # Set model in eval mode before validation
        encoder.eval()
        decoder.eval()

        # At the end of the epoch, check performance against the validation data
        # (no parameter update happens here, so gradient tracking is disabled)
        with torch.no_grad():
            for x, _ in val_dataloader:
                x = x.to(device)    # Move the data tensor to the device
                val_loss, _ = step_fn(encoder, decoder, x)
                pbar_metrics["val_loss"] = val_loss.item()
                fit_pbar.set_postfix(pbar_metrics)
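
To make the role of the dataloaders more concrete, here is a small sketch of how a `DataLoader` splits a dataset into batches. The `TensorDataset` filled with random values is a hypothetical stand-in for MNIST, chosen so the block runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in dataset: 1000 fake "images" with integer labels
images = torch.rand(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=256, shuffle=True)

# Each iteration yields a (data, label) pair; the last batch holds the remainder
batch_sizes = [x.shape[0] for x, _ in loader]
print(batch_sizes)  # [256, 256, 256, 232]
```

As in `fit`, the labels are unpacked into `_` and ignored, since an autoencoder only needs the images themselves.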

Now we have to define a training step that is specific to the model being trained, here a plain autoencoder.

In [None]:
import torch
import torch.nn.functional as F

def autoencoder_step(encoder, decoder, x):
    in_shape = x.shape  # Save the input shape
    encoder_input = torch.flatten(x, start_dim=1)   # Flatten the 2D image to a 1D tensor (for the linear layer)
    z = encoder(encoder_input)  # Forward pass on the encoder (to get the latent space vector)
    x_hat = decoder(z)  # Forward pass on the decoder (to get the reconstructed input)
    x_hat = x_hat.reshape(in_shape)    # Restore the output to the original shape
    loss = F.binary_cross_entropy(x_hat, x) # Compute the reconstruction loss
    return loss, x_hat
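
The two tensor manipulations in this step (flattening, then reshaping back) and the loss itself can be checked in isolation. The sketch below uses random tensors as stand-ins for a batch of images and their reconstructions, and compares a hand-written binary cross-entropy with `F.binary_cross_entropy`, whose default reduction averages over all elements.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for a batch of 4 images and their reconstructions
x = torch.rand(4, 1, 28, 28)
# Reconstructions are kept away from exactly 0 and 1 so the logarithms stay finite
x_hat = torch.rand(4, 1, 28, 28).clamp(1e-6, 1 - 1e-6)

# Flattening keeps the batch dimension and merges the rest: (4, 1, 28, 28) -> (4, 784)
flat = torch.flatten(x, start_dim=1)
print(flat.shape)  # torch.Size([4, 784])

# Reshaping restores the original layout exactly
restored = flat.reshape(x.shape)

# Binary cross-entropy written out by hand, averaged over every pixel of every image
manual_bce = -(x * torch.log(x_hat) + (1 - x) * torch.log(1 - x_hat)).mean()
library_bce = F.binary_cross_entropy(x_hat, x)
print(manual_bce.item(), library_bce.item())
```

Because the loss is an average per pixel, its magnitude does not depend on the image size or the batch size.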

Finally, let's train our model!

In [None]:
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()])
fit(autoencoder_step, encoder, decoder, optimizer, mnist_train, mnist_test)

Now, let's take a look at the results on the test set.

In [None]:
from src.visualization.utils import display_autoencoder_results

display_autoencoder_results(mnist_test, lambda x: autoencoder_step(encoder, decoder, x.cuda())[1])

### Let's make our autoencoder variational!

At a high level, we have to change the encoder so that it has two outputs instead of one: the mean and the log-variance of the latent distribution.

Do you think we should change the decoder's architecture?

In [None]:
# This time, we should make the latent space 2-dimensional to visualize it easily afterwards
encoding_size = 2

# In practice, a small trick to easily implement the two heads of the encoder is to simply
# double the size of its output. Then, we can slice the output in half during the forward pass!
encoder = make_encoder(data_size, encoding_size * 2)
decoder = make_decoder(data_size, encoding_size)
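
Here is a minimal sketch of the slicing half of this trick, using a random tensor as a stand-in for the encoder output on a batch of 5 inputs.

```python
import torch

encoding_size = 2

# Stand-in for the encoder output: 2 * encoding_size values per input
encoder_output = torch.randn(5, encoding_size * 2)

# The first half of each row is read as the mean, the second half as the log-variance
mu = encoder_output[:, :encoding_size]
logvar = encoder_output[:, encoding_size:]
print(mu.shape, logvar.shape)  # torch.Size([5, 2]) torch.Size([5, 2])

# torch.chunk performs the same split in a single call
mu_alt, logvar_alt = torch.chunk(encoder_output, 2, dim=1)

# Predicting the log-variance lets the network output any real number,
# while the standard deviation recovered from it is always positive
std = torch.exp(0.5 * logvar)
```

Which half is treated as the mean is purely a convention; the network learns whatever role the loss assigns to each half.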

We also need to change the training step: it has to implement the reparameterization trick and add a KL divergence term to the loss.
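
To see why the trick matters, here is a minimal sketch with made-up posterior parameters (the values of `mu` and `logvar` below are hypothetical). Writing the sample as `mu + eps * std` keeps it a differentiable function of the parameters, so gradients can flow back to the encoder.

```python
import torch

torch.manual_seed(0)

# Hypothetical posterior parameters for a single 2-dimensional latent variable
mu = torch.tensor([0.5, -1.0], requires_grad=True)
logvar = torch.tensor([0.0, -2.0], requires_grad=True)

# Reparameterized sample: all the randomness lives in eps, which has no parameters
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

# Any downstream loss can now be differentiated w.r.t. mu and logvar
z.sum().backward()
print(mu.grad)      # dz/dmu = 1 for each dimension
print(logvar.grad)  # dz/dlogvar = 0.5 * eps * std
```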

In [None]:
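def kl_div(mu, logvar):
    # KL divergence between N(mu, exp(logvar)) and the standard normal prior N(0, I),
    # summed over the latent dimensions and averaged over the batch
    kl_div_by_samples = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return torch.mean(kl_div_by_samples)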

def vae_step(encoder, decoder, x):
    in_shape = x.shape  # Save the input shape
    encoder_input = torch.flatten(x, start_dim=1)   # Flatten the 2D image to a 1D tensor (for the linear layer)
    encoding_distr = encoder(encoder_input)  # Forward pass on the encoder (to get the parameters of the latent posterior)
    # Nothing changed so far!

    # Second part of our trick!
    # We split the single output vector of the encoder into its two halves: mean and logvar
    mu, logvar = encoding_distr[:, :encoding_size], encoding_distr[:, encoding_size:]

    # Reparametrization trick
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std

    # Decoding mostly stays the same. The only difference is the added 4th line below
    x_hat = decoder(z)  # Forward pass on the decoder (to get the reconstructed input)
    x_hat = x_hat.reshape(in_shape)    # Restore the output to the original shape
    loss = F.binary_cross_entropy(x_hat, x) # Compute the reconstruction loss
    loss += 1e-5 * kl_div(mu, logvar)  # Loss now also includes the KL divergence term, scaled by a small weight
    return loss, x_hat
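
The closed-form expression in `kl_div` can be cross-checked against PyTorch's own `torch.distributions` module. The sketch below repeats `kl_div` so that it runs on its own and uses random values as stand-ins for the encoder outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_div(mu, logvar):
    # Same closed-form formula as in the cell above
    kl_div_by_samples = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return torch.mean(kl_div_by_samples)

torch.manual_seed(0)

# Stand-ins for the encoder outputs on a batch of 16 inputs
mu = torch.randn(16, 2)
logvar = torch.randn(16, 2)

# Build the same distributions explicitly: posterior N(mu, sigma^2) and prior N(0, 1)
posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))

# kl_divergence works per dimension, so sum over dimensions and average over the batch
reference = kl_divergence(posterior, prior).sum(dim=1).mean()

print(kl_div(mu, logvar).item(), reference.item())

# The divergence vanishes when the posterior equals the prior
zero_kl = kl_div(torch.zeros(16, 2), torch.zeros(16, 2))
```

The KL term therefore pulls every posterior towards the standard normal prior, while the reconstruction loss pulls the posteriors of different inputs apart.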

Now it's time to train our variational autoencoder!

In [None]:
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()])
fit(vae_step, encoder, decoder, optimizer, mnist_train, mnist_test)

Now, let's take a look at the results on the test set.

In [None]:
display_autoencoder_results(mnist_test, lambda x: vae_step(encoder, decoder, x.cuda())[1])

Now that we've got a latent space in two dimensions, we can easily visualize it and look at how the data is distributed in it.

In [None]:
from src.visualization.latent_space import explore_latent_space

explore_latent_space(mnist_test, encoder, decoder)
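
`explore_latent_space` comes from this repository's own helpers, and its implementation is not shown here. As a rough idea of what exploring a 2D latent space can involve, the sketch below decodes a regular grid of latent points into images. The decoder is rebuilt untrained so that the block runs on its own; the grid range of [-3, 3] and the 5x5 resolution are arbitrary choices.

```python
import torch
from torch import nn

def make_decoder(data_size, encoding_size):
    return nn.Sequential(
        nn.Linear(encoding_size, 64), nn.ReLU(),
        nn.Linear(64, 128), nn.ReLU(),
        nn.Linear(128, data_size), nn.Sigmoid(),
    )

data_shape = (28, 28)
decoder = make_decoder(data_shape[0] * data_shape[1], 2)  # Untrained stand-in
decoder.eval()

# Regular 5x5 grid of latent points; [-3, 3] covers most of a standard normal prior
axis = torch.linspace(-3, 3, steps=5)
grid_x, grid_y = torch.meshgrid(axis, axis, indexing="ij")
z_grid = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1)  # Shape (25, 2)

# Decode every grid point into an image
with torch.no_grad():
    images = decoder(z_grid).reshape(-1, *data_shape)

print(images.shape)  # torch.Size([25, 28, 28])
```

With the trained `decoder` from this notebook, `z_grid` would first have to be moved to the decoder's device, and the decoded images could then be plotted to see how digits morph across the latent space.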
