# Hands-on session 2.1: Variational Autoencoder
## Variational Autoencoders Applied to Cardiac MRI

Made by **Nathan Painchaud** and **Pierre-Marc Jodoin** from the Université de Sherbrooke, Canada.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%%capture packages_install

# Make sure the repo's package and its dependencies are installed
!pip install -e ../.

In [None]:
%%capture project_path_setup

import sys

# Make the repo's root importable (e.g. `src.data`), without adding it twice
if '../' not in sys.path:
    sys.path.append('../')
print(sys.path)

### Once again, let's start by looking at our data
Here we load cardiac MRI images and their ground-truth segmentation maps from the **ACDC dataset**.

In [None]:
import numpy as np
from src.data.acdc.dataset import Acdc
from src.visualization.utils import display_data_samples

# ACDC consists of 256x256 images with segmentation maps for 3 classes + background,
# so the one-hot encoded segmentation maps (the data our autoencoder reconstructs) have shape
data_shape = (4, 256, 256)

# Download and prepare data
acdc_train = Acdc("../data/acdc.h5", image_set="train")
acdc_val = Acdc("../data/acdc.h5", image_set="val")

# Check data by displaying random images
samples_indices = np.random.randint(len(acdc_train), size=10)
imgs, gts = zip(*[acdc_train[sample_idx] for sample_idx in samples_indices])
display_data_samples(mri=imgs, segmentation=gts)
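
The `data_shape` of `(4, 256, 256)` corresponds to segmentation maps in one-hot form: each categorical 256x256 map becomes 4 binary channels, one per class. The notebook performs this conversion with `torchmetrics.utilities.data.to_onehot` in the forward pass; the sketch below uses PyTorch's `F.one_hot` instead, on a random stand-in map (an assumption, not real ACDC data), to show the shapes involved.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for an ACDC ground truth: a categorical map where each pixel
# holds a class index in {0 (background), 1, 2, 3 (cardiac structures)}
num_classes = 4
gt = torch.randint(0, num_classes, (256, 256))

# F.one_hot appends the class dimension last: (256, 256) -> (256, 256, 4).
# Moving it first gives the channel-first layout expected by nn.Conv2d: (4, 256, 256)
onehot = F.one_hot(gt, num_classes=num_classes).permute(2, 0, 1).float()

print(gt.shape, onehot.shape)  # torch.Size([256, 256]) torch.Size([4, 256, 256])

# Each pixel has exactly one active channel, and argmax over channels recovers the labels
assert torch.all(onehot.sum(dim=0) == 1)
assert torch.equal(onehot.argmax(dim=0), gt)
```

The `argmax(dim=0)` round trip is the same operation `vae_forward_pass` uses (over `dim=1`, because of the batch dimension) to turn the decoder's output back into a displayable categorical map.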

### Let's build a deep autoencoder specialized for image processing: a convolutional autoencoder

Since convolutional networks are more complex than fully-connected networks and require a bit more code, let's tackle
one half of the autoencoder at a time. Let's start with the **encoder**.

In [None]:
from torch import nn

# Let's define the encoder architecture we want,
# with some options to configure the input and output size
def downsampling_block(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
    )

def make_encoder(data_shape, latent_space_size):
    in_channels = data_shape[0]
    shape_at_bottleneck = data_shape[1] // 16, data_shape[2] // 16
    size_at_bottleneck = shape_at_bottleneck[0] * shape_at_bottleneck[1] * 48
    return nn.Sequential(
        downsampling_block(in_channels, 48),    # Block 1 (input)
        downsampling_block(48, 96),             # Block 2
        downsampling_block(96, 192),            # Block 3
        downsampling_block(192, 48),            # Block 4 (limits number of channels to reduce total number of parameters)
        nn.Flatten(),                           # Flatten before FC-layer at the bottleneck
        nn.Linear(size_at_bottleneck, latent_space_size),   # Bottleneck
    )

# Now let's build our encoder, with an arbitrary dimensionality of the latent space
# and an input size depending on the data.
latent_space_size = 32
# The encoder outputs latent_space_size * 2 values because it predicts both a *mean* and a *log-variance* vector
encoder = make_encoder(data_shape, latent_space_size * 2)
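
The `data_shape[1] // 16` in `make_encoder` comes from the four downsampling blocks, each of which halves the spatial size. The sketch below checks this with the `nn.Conv2d` output-size formula from the PyTorch documentation; the helper names are hypothetical and only mirror the layer settings used above.

```python
def conv2d_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis of nn.Conv2d (formula from the PyTorch docs)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def downsampling_block_out(size):
    # First conv: kernel 3, padding 1, stride 2 -> halves the size (rounding up for odd sizes)
    size = conv2d_out(size, kernel_size=3, stride=2, padding=1)
    # Second conv: kernel 3, padding 1, stride 1 -> keeps the size
    return conv2d_out(size, kernel_size=3, stride=1, padding=1)

size = 256
sizes = [size]
for _ in range(4):  # 4 downsampling blocks in make_encoder
    size = downsampling_block_out(size)
    sizes.append(size)

print(sizes)             # [256, 128, 64, 32, 16]
print(size * size * 48)  # 12288 inputs to the bottleneck nn.Linear
```

Note that the `// 16` shortcut only matches the real output when the input size is divisible by 16: a 250-pixel side would still end up at 16 after the four blocks, whereas `250 // 16` is 15, so the bottleneck `nn.Linear` would receive the wrong number of inputs.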

### Now let's look at the structure of the encoder that we have just created

In [None]:
from torchinfo import summary

summary_kwargs = dict(col_names=["input_size", "output_size", "kernel_size", "num_params"], depth=3, verbose=0)

summary(encoder, input_size=data_shape, batch_dim=0, **summary_kwargs)

## Questions:

* How many **neurons** does this encoder network have?
* How many **parameters** does this encoder network have?
* How come some elements of the encoder network have **no kernel shape**?
* What is the size of the latent space of that encoder?
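
To double-check the parameter count reported by `torchinfo`, the count can also be derived by hand: a `Conv2d` has one `kernel_size x kernel_size` filter per (input channel, output channel) pair plus one bias per output channel, and a `Linear` layer has a full weight matrix plus one bias per output. The helpers below are a hypothetical sketch that mirrors the layers of `make_encoder((4, 256, 256), 64)` and only counts layers that have weights.

```python
def conv2d_params(in_channels, out_channels, kernel_size, bias=True):
    # Weights: out_channels * in_channels * k * k, plus one bias per output channel
    return out_channels * in_channels * kernel_size * kernel_size + (out_channels if bias else 0)

def linear_params(in_features, out_features, bias=True):
    # Weights: in_features * out_features, plus one bias per output feature
    return in_features * out_features + (out_features if bias else 0)

def downsampling_block_params(in_channels, out_channels):
    # Two 3x3 convolutions, as in downsampling_block
    return conv2d_params(in_channels, out_channels, 3) + conv2d_params(out_channels, out_channels, 3)

# Mirrors make_encoder((4, 256, 256), 64): 4 downsampling blocks + bottleneck Linear layer
encoder_params = (
    downsampling_block_params(4, 48)
    + downsampling_block_params(48, 96)
    + downsampling_block_params(96, 192)
    + downsampling_block_params(192, 48)
    + linear_params(16 * 16 * 48, 64)
)
print(encoder_params)  # compare with "Total params" in the torchinfo summary
```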

### Now that the encoder is ready, let's build a decoder that mirrors it

In [None]:
from src.modules import layers

# Mirror of the encoder's building block: upsample with a transposed convolution, then refine with a regular convolution
def upsampling_block(in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=2, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1),
        nn.ReLU(),
    )

def make_decoder(data_shape, latent_space_size):
    out_channels = data_shape[0]
    shape_at_bottleneck = data_shape[1] // 16, data_shape[2] // 16
    size_at_bottleneck = shape_at_bottleneck[0] * shape_at_bottleneck[1] * 48
    return nn.Sequential(
        # Bottleneck
        nn.Linear(latent_space_size, size_at_bottleneck),
        nn.ReLU(),
        layers.Reshape((48, *shape_at_bottleneck)),    # Restore shape before convolutional layers

        upsampling_block(48, 192),     # Block 1
        upsampling_block(192, 96),     # Block 2
        upsampling_block(96, 48),      # Block 3
        nn.ConvTranspose2d(in_channels=48, out_channels=48, kernel_size=2, stride=2), # Block 4 (output)
        nn.ReLU(),
        nn.Conv2d(in_channels=48, out_channels=out_channels, kernel_size=3, padding=1),
    )

# Now let's build our decoder, with the dimensionality of the latent space matching that of the encoder
# and an output size depending on the data.
decoder = make_decoder(data_shape, latent_space_size)
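
The decoder's `layers.Reshape` comes from the repo's own `src.modules`. Assuming it reshapes each sample to the given shape while keeping the batch dimension, PyTorch's built-in `nn.Unflatten` does the same job; the sketch below rebuilds only the decoder's bottleneck with it, as an illustration rather than the repo's implementation.

```python
import torch
from torch import nn

latent_space_size = 32
shape_at_bottleneck = (16, 16)

# Decoder bottleneck with nn.Unflatten in place of layers.Reshape (assumed equivalent):
# each (48 * 16 * 16)-dim vector becomes a (48, 16, 16) feature map; the batch dimension is untouched
bottleneck = nn.Sequential(
    nn.Linear(latent_space_size, 48 * shape_at_bottleneck[0] * shape_at_bottleneck[1]),
    nn.ReLU(),
    nn.Unflatten(dim=1, unflattened_size=(48, *shape_at_bottleneck)),
)

z = torch.randn(8, latent_space_size)  # batch of 8 latent codes
print(bottleneck(z).shape)  # torch.Size([8, 48, 16, 16])
```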

### Just like for the encoder, let's display the structure of the decoder network

In [None]:
summary(decoder, input_size=(latent_space_size,), batch_dim=0, **summary_kwargs)

## Question:

* Do you remember what a *ConvTranspose2d* is?
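
One way to see what the decoder's `ConvTranspose2d` layers do to the feature maps is through their output-size formula from the PyTorch documentation. With `kernel_size=2` and `stride=2`, each input pixel is spread over its own non-overlapping 2x2 output patch, so the spatial size exactly doubles. The helper below is a hypothetical sketch of that formula, applied to the decoder's four upsampling stages.

```python
def conv_transpose2d_out(size, kernel_size, stride=1, padding=0, output_padding=0, dilation=1):
    """Output length along one spatial axis of nn.ConvTranspose2d (formula from the PyTorch docs)."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

size = 16  # spatial size right after the decoder's reshape
sizes = [size]
for _ in range(4):  # one kernel_size=2, stride=2 transposed conv per upsampling stage
    size = conv_transpose2d_out(size, kernel_size=2, stride=2)
    sizes.append(size)

print(sizes)  # [16, 32, 64, 128, 256]
```

The 3x3 `Conv2d` with `padding=1` that follows each transposed convolution keeps the size unchanged, so the decoder ends at 256x256, mirroring the encoder.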

# Forward pass and loss function

Unfortunately, we cannot copy-paste the forward pass function from the MNIST notebook, for the following two
reasons:

1. Here we are predicting segmentation maps whose pixels take **one of 4 class values instead of the 2 black-and-white values** of the MNIST
images.  Thus, instead of minimizing the binary cross-entropy, we need a **4-class cross-entropy** as the reconstruction
term in our **VAE** loss:  
$$ \mathcal{L} = \mathrm{CrossEntropy}(x, \hat{x}) + \lambda \, D_{KL}\big(q(z \mid x) \,\|\, p(z)\big) $$  
as shown in the hands-on document (the code below uses $\lambda = 10^{-4}$).

2. With the MNIST dataset, we were using a fully-connected neural network fed with a **vector of pixels** instead of a 2D image.
Here, we are using a _convolutional_ autoencoder, so we **need to preserve the 2D structure of our input data**. Since
the segmentation maps already come as 2D tensors, we don't need to flatten them into vectors anymore; we only need to
one-hot encode them into 4 channels before feeding them to the encoder.
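
The 4-class cross-entropy term can be checked on random tensors: `F.cross_entropy` takes the decoder's raw scores (logits) of shape `(N, 4, H, W)` and categorical targets of shape `(N, H, W)`, applies a log-softmax over the class dimension, and averages the per-pixel loss. The sketch below (random stand-in data, smaller 64x64 maps for speed) recomputes the same value by hand.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 4, 64, 64

# Decoder-like output: unnormalized scores (logits), one channel per class, shape (N, C, H, W)
logits = torch.randn(N, C, H, W)
# Targets stay categorical: one class index per pixel, shape (N, H, W), dtype long
target = torch.randint(0, C, (N, H, W))

# Log-softmax over dim 1, then mean of the negative log-likelihood over all N*H*W pixels
ce = F.cross_entropy(logits, target)

# Same value by hand: pick the log-probability of the true class at every pixel
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs.gather(1, target.unsqueeze(1)).mean()

print(torch.allclose(ce, manual, atol=1e-5))  # True
```

This is why `vae_forward_pass` can pass the decoder output and the categorical maps `x` directly to `F.cross_entropy`, while the encoder still needs the one-hot version of `x`.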

In [None]:
import torch
import torchmetrics
import torch.nn.functional as F

def kl_div(mu, logvar):
    kl_div_by_samples = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return torch.mean(kl_div_by_samples)

def vae_forward_pass(encoder, decoder, x):
    """VAE forward pass.

    Args:
        encoder: neural net that predicts a mean and a logvar vector
        decoder: neural net that projects a point in the latent space back into the image space
        x: batch of N ACDC segmentation maps

    Returns:
        loss: crossentropy + kl_divergence loss
        x_hat: batch of N reconstructed segmentation maps
    """
    # We don't need to flatten the input images to (N, num_pixels) anymore,
    # but we need to convert them from one-channel categorical data to multi-channel one-hot format
    encoder_input = torchmetrics.utilities.data.to_onehot(x, num_classes=4).float()

    encoding_distr = encoder(encoder_input)  # Forward pass on the encoder (to get the latent space posterior)

    # We use the same trick as before to easily extract the components of the posterior distribution (mean and logvar latent vectors)
    mu, logvar = encoding_distr[:, :latent_space_size], encoding_distr[:, latent_space_size:]

    # Reparametrization trick
    # (same as before, since the latent codes are vectors regardless of input data's structure)
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std

    # Similar to the input that we didn't need to vectorize, we don't need to reshape the output to a 2D shape anymore,
    # since the convolutional network already produces a structured output
    x_hat = decoder(z)  # Forward pass on the decoder (to get the reconstructed input)
    loss = F.cross_entropy(x_hat, x)  # Reconstruction loss: x_hat holds class logits (N, 4, H, W), x class indices (N, H, W)
    loss += 1e-4 * kl_div(mu, logvar)  # Add the KL divergence term, weighted by lambda = 1e-4
    return loss, x_hat.argmax(dim=1) # Transform segmentation back to categorical so that it can be displayed easily
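
The `kl_div` function implements the closed form of the KL divergence between the diagonal Gaussian posterior $\mathcal{N}(\mu, \sigma^2)$ and the standard normal prior, $D_{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$, with `logvar` standing for $\log\sigma^2$. As a sanity check, the sketch below compares it with `torch.distributions.kl_divergence` on random (stand-in) mean and log-variance vectors.

```python
import torch
from torch.distributions import Normal, kl_divergence

def kl_div(mu, logvar):
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and averaged over the batch
    kl_div_by_samples = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return torch.mean(kl_div_by_samples)

torch.manual_seed(0)
mu = torch.randn(16, 32)      # batch of 16 mean vectors, 32 latent dims
logvar = torch.randn(16, 32)  # matching log-variance vectors

posterior = Normal(mu, torch.exp(0.5 * logvar))            # std = exp(logvar / 2)
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))  # N(0, I)
reference = kl_divergence(posterior, prior).sum(dim=1).mean()

print(torch.allclose(kl_div(mu, logvar), reference, atol=1e-5))  # True
```

The term is zero only when every latent dimension has $\mu_j = 0$ and $\sigma_j = 1$, i.e. when the posterior matches the prior; this is the pull that the $\lambda$-weighted term applies during training.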

## Question:

* Can you spot the difference between this `vae_forward_pass` and the `vae_forward_pass` of the MNIST autoencoder notebook?
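
One part that stays identical between both notebooks is the reparametrization trick. Writing the sample as `z = mu + eps * std` moves the randomness into `eps`, so `z` is a differentiable function of `mu` and `logvar` and gradients can reach the encoder. The sketch below, on tiny stand-in tensors, checks the gradients autograd computes through this expression.

```python
import torch

torch.manual_seed(0)
mu = torch.zeros(4, requires_grad=True)
logvar = torch.zeros(4, requires_grad=True)

# Reparametrization trick, as in vae_forward_pass
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)  # noise sampled outside the computational path of mu and logvar
z = mu + eps * std

z.sum().backward()
print(mu.grad)  # tensor([1., 1., 1., 1.]) since dz/dmu = 1
# dz/dlogvar = 0.5 * eps * exp(0.5 * logvar), which is 0.5 * eps here since logvar = 0
print(torch.allclose(logvar.grad, 0.5 * eps))  # True
```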

# Training
For the training algorithm, we can reuse the generic training code from our fully-connected
autoencoders on the MNIST dataset. The main difference is that we now reconstruct the **segmentation maps**
(the targets returned by the dataset) instead of the **grayscale images**.

In [None]:
import os
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# Define some training hyperparameters
epochs = 30
batch_size = 64

def train(forward_pass_fn, encoder, decoder, optimizer, train_data, val_data, device="cuda"):
    # Create dataloaders from the data
    # Those are PyTorch's abstraction to help iterate over the data
    data_loader_kwargs = {"batch_size": batch_size, "num_workers": os.cpu_count() - 1, "pin_memory": True}
    train_dataloader = DataLoader(train_data, shuffle=True, **data_loader_kwargs)
    val_dataloader = DataLoader(val_data, **data_loader_kwargs)

    # Make sure the models' parameters live on the same device as the data
    encoder.to(device)
    decoder.to(device)

    fit_pbar = tqdm(range(epochs), desc="Training", unit="epoch")
    pbar_metrics = {"train_loss": None, "val_loss": None}
    for epoch in fit_pbar:
        # Set models in training mode before training
        encoder.train()
        decoder.train()

        # Train once over all the training data
        # (the dataset yields (image, segmentation) pairs; we only reconstruct the segmentation)
        for _, y in train_dataloader:
            y = y.to(device)    # Move the data tensor to the device
            optimizer.zero_grad()   # Make sure gradients are reset
            train_loss, _ = forward_pass_fn(encoder, decoder, y)    # Forward pass+loss
            train_loss.backward()   # Backward pass
            optimizer.step()    # Update parameters w.r.t. optimizer and gradients
            pbar_metrics["train_loss"] = train_loss.item()
            fit_pbar.set_postfix(pbar_metrics)

        # At the end of the epoch, check performance against the validation data
        # (evaluation mode, without tracking gradients)
        encoder.eval()
        decoder.eval()
        val_losses = []
        with torch.no_grad():
            for _, y in val_dataloader:
                y = y.to(device)    # Move the data tensor to the device
                val_loss, _ = forward_pass_fn(encoder, decoder, y)
                val_losses.append(val_loss.item())
        pbar_metrics["val_loss"] = sum(val_losses) / len(val_losses)    # Mean over validation batches
        fit_pbar.set_postfix(pbar_metrics)

### Ready to train the variational autoencoder!
Note: training may take about 5 minutes.

In [None]:
optimizer = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()])
train(vae_forward_pass, encoder, decoder, optimizer, acdc_train, acdc_val)

### Now, let's take a look at the results on the validation set
Note: each time you execute the following cell, you will get different results (among other things, the
reparametrization trick samples a new latent code at every forward pass).

If you want better-looking reconstructed cardiac shapes, you can retrain the model for more epochs.

In [None]:
from src.visualization.utils import display_autoencoder_results

display_autoencoder_results(
    acdc_val, lambda x: vae_forward_pass(encoder, decoder, x.cuda())[1], reconstruct_target=True
)
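
Since the KL term pulls the posterior towards $\mathcal{N}(0, I)$, a VAE can, in principle, also generate new shapes by decoding latent codes drawn directly from that prior. The function below is a hypothetical sketch of this; to keep it self-contained, it is exercised on an untrained stand-in decoder with tiny 8x8 outputs rather than on the notebook's trained `decoder`.

```python
import torch
from torch import nn

@torch.no_grad()
def sample_shapes(decoder, num_samples, latent_space_size, device="cpu"):
    """Generate segmentation maps by decoding latent codes drawn from the N(0, I) prior."""
    decoder.eval()
    z = torch.randn(num_samples, latent_space_size, device=device)
    logits = decoder(z)          # (num_samples, num_classes, H, W)
    return logits.argmax(dim=1)  # (num_samples, H, W) categorical maps

# Hypothetical untrained stand-in decoder, only used to exercise the function
toy_decoder = nn.Sequential(nn.Linear(32, 4 * 8 * 8), nn.Unflatten(1, (4, 8, 8)))
samples = sample_shapes(toy_decoder, num_samples=5, latent_space_size=32)
print(samples.shape)  # torch.Size([5, 8, 8])
```

With the trained model, a call such as `sample_shapes(decoder, 10, latent_space_size, device="cuda")` should return ten 256x256 maps in the same categorical format that `display_data_samples` receives as `segmentation`.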

To visualize the latent space, we can use a dimensionality reduction algorithm, in our case [UMAP](https://towardsdatascience.com/how-exactly-umap-works-13e3040e1668),
to project latent codes to (and back from) a 2D space. This gives us an estimate of how the data is
distributed in the latent space.

In [None]:
from src.visualization.latent_space import explore_latent_space

explore_latent_space(
    acdc_val,
    lambda x: encoder(torchmetrics.utilities.data.to_onehot(x, num_classes=4).float())[:, :latent_space_size],
    lambda z: decoder(z).argmax(dim=1),
    data_to_encode="target",
    batch_size=64,
)
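
To get a feel for what such a 2D projection involves without depending on UMAP, the sketch below swaps in a simpler, linear technique: PCA computed with `torch.linalg.svd`. It projects latent codes onto their two directions of largest variance. The random codes are a stand-in for the mean vectors the encoder would produce on the validation set.

```python
import torch

def pca_2d(z):
    """Project latent codes (N, D) onto their first 2 principal components (linear stand-in for UMAP)."""
    z_centered = z - z.mean(dim=0, keepdim=True)
    # Rows of vh are the principal directions, sorted by decreasing singular value
    _, _, vh = torch.linalg.svd(z_centered, full_matrices=False)
    return z_centered @ vh[:2].T  # (N, 2)

torch.manual_seed(0)
z = torch.randn(100, 32)  # stand-in for the mean vectors of 100 encoded validation maps
z_2d = pca_2d(z)
print(z_2d.shape)  # torch.Size([100, 2])
```

Unlike UMAP, PCA can only capture linear structure, so clusters that UMAP separates may overlap in a PCA plot; on the other hand, it is deterministic and fast.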

## Question:

* You may want to retrain the VAE with a smaller latent space.  Why do you think the reconstructed cardiac shapes are of very poor quality when we use a latent space of size 2, as we did for MNIST?
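
Whatever latent space size you try, another way to probe what it encodes is to decode points along a straight line between two latent codes and check whether the shapes change smoothly. The helper below is a hypothetical sketch of linear interpolation in latent space, shown on random stand-in codes.

```python
import torch

def interpolate_latent(z_start, z_end, num_steps):
    """Return num_steps latent codes evenly spaced from z_start to z_end, shape (num_steps, D)."""
    alphas = torch.linspace(0, 1, num_steps).unsqueeze(1)  # (num_steps, 1), broadcast over the D latent dims
    return (1 - alphas) * z_start + alphas * z_end

# Stand-in latent codes; in the notebook these could be the mean vectors of two encoded validation maps
z_a, z_b = torch.randn(32), torch.randn(32)
z_path = interpolate_latent(z_a, z_b, num_steps=8)
print(z_path.shape)  # torch.Size([8, 32])
```

With the trained model, `decoder(z_path.cuda()).argmax(dim=1)` would turn the 8 codes into 8 segmentation maps to compare side by side.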