EE4685
===
Machine Learning: A Bayesian Perspective
===

# Computer Exercise 2: Variational Autoencoders (VAEs)

In this tutorial, we will cover:

1. Reparameterization trick
2. Variational Autoencoders (VAEs)
3. Latent Variable Visualization
4. New Data Generation


Our info:

- Justin Dauwels (j.h.g.dauwels@tudelft.nl)

Adapted from work of:
- Luca Moschella (moschella@di.uniroma1.it)
- Antonio Norelli (norelli@di.uniroma1.it)

Course:

- *Brightspace page here*

**Import dependencies (run the following cells)**

In [None]:
# @title import dependencies

from typing import Mapping, Union, Optional
from pathlib import Path

import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import plotly.graph_objects as go
import plotly.express as px
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, models, transforms

import os
import pickle
from tqdm import tqdm



In [None]:
# @title reproducibility stuff

import random
torch.manual_seed(42)
np.random.seed(42)
random.seed(0)

torch.cuda.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False

# Quick Recap

Many machine learning and deep learning methods are based on a variety of concepts from probability theory and information theory. 

Since variational autoencoders lean heavily on these concepts, it is useful to brush up on some fundamentals.

If you are not already confident with the concepts of marginal probability, Shannon entropy, cross-entropy, Kullback–Leibler divergence and mutual information, we suggest two resources: for a quick read, this [blog post](http://colah.github.io/posts/2015-09-Visual-Information/) with awesome visualizations by Chris Olah; and as a book reference, [*Information Theory, Inference, and Learning Algorithms*](https://www.inference.org.uk/itprnn/book.pdf) by David J.C. MacKay.


## Variational Autoencoders (VAEs)

You may have already encountered autoencoders in the context of signal processing or information theory as compression systems. 
The task of data compression is a good starting point to understand also neural autoencoders.

In the context of deep learning we refer to autoencoders (AEs) as neural networks where the expected output coincides with the input and the architecture contains a bottleneck (see figure below). The bottleneck favours a precise behaviour: the first part of the AE, the **Encoder**, learns to distill the information useful to *distinguish* between input samples into a **latent code** whose size is limited by the bottleneck; the second part of the network, the **Decoder**, decodes this code to reconstruct the original input. To accomplish this, the Decoder should learn all the information about the data domain that is not useful to distinguish between samples but is needed to reconstruct them. The loss of a standard AE is simply a reconstruction loss, e.g. for an image a pixel-wise comparison, typically with a cross-entropy.

From a very pragmatic point of view, VAEs are just AEs trained with a latent code disturbed by noise. The amount of noise is itself a trainable parameter, and a new term in the loss promotes higher noise. This new term balances the reconstruction term of the loss, which alone would push the amount of noise to zero.

Strengthened by this interpretation of VAEs as a way of regularizing AEs, we can now venture into their rich theoretical justification based on probability theory.



### [Auto-encoding Variational Bayes](https://arxiv.org/abs/1312.6114)
We start by considering the inputs, hidden representations, and reconstructed outputs of a VAE as random variables within a directed graphical model.

From a Bayesian perspective:
- The encoder becomes a variational inference network, mapping observed inputs to (approximate) posterior distributions for each latent attribute. That is, our encoder describes an entire probability distribution for each latent attribute and the latent code is *sampled* from these distributions.
- The decoder becomes a generative network, capable of mapping latent codes to new samples similar to the input ones.

So the training dataset becomes the tip of an iceberg; the observed result of an **underlying causative probabilistic process**.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/vae1.png)


The power of the resulting model is captured by Feynman’s famous [chalkboard quote](http://archives-dc.library.caltech.edu/islandora/object/ct1:483): “What I cannot create, I do not understand.”

The following gif shows the decoded samples of a continuous path over the latent space of a VAE trained on MNIST.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/interp.gif)
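A continuous path over the latent space can be built by linear interpolation between two latent codes. The snippet below is only a sketch of that idea: `toy_decoder`, `z_start` and `z_end` are hypothetical stand-ins (an untrained linear layer and two arbitrary 2-d codes), not the trained VAE used to produce the gif.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a trained decoder: maps a 2-d latent code to 784 pixels.
toy_decoder = nn.Sequential(nn.Linear(2, 784), nn.Sigmoid())

z_start = torch.tensor([-2.0, 0.0])  # hypothetical latent code of one digit
z_end = torch.tensor([2.0, 1.0])     # hypothetical latent code of another digit

# 8 evenly spaced interpolation weights in [0, 1], shaped [8, 1] for broadcasting
weights = torch.linspace(0.0, 1.0, steps=8).unsqueeze(1)
path = (1 - weights) * z_start + weights * z_end  # latent path, shape [8, 2]

# Decode every point on the path into a 28x28 image (one animation frame each)
with torch.no_grad():
    frames = toy_decoder(path).view(-1, 1, 28, 28)

print(path.shape, frames.shape)
```

With a trained decoder in place of `toy_decoder`, each row of `path` would decode to one frame of such an animation.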

Rather than directly outputting values for the latent state as we would in a standard AE, the encoder model of a VAE outputs entire probability distributions. During training we sample from these distributions to obtain a latent code for the decoder.

Traditional **VAEs make the strong assumption that latent codes follow a multivariate Gaussian with a diagonal covariance structure**, so each component is normally distributed and independent of the others.

An $n$-dimensional Gaussian is defined by a mean vector $(\mu_1, \dots, \mu_n)$ and a covariance matrix $\Sigma \in \mathbb{R}^{n \times n}$. Since we are assuming independent components, $\Sigma$ is a diagonal matrix and we can describe it with a single vector of size $n$ containing the diagonal.

So the Encoder of a VAE outputs two vectors collecting the means and the variances (instead of encoding the variance $\sigma^2$ we will work for convenience with the logarithm of the variance $\log(\sigma^2)$).
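To make the diagonal-covariance description concrete, here is a small sketch with made-up numbers (`mu`, `logvar` and `z` are illustrative values, not encoder outputs). It converts a log-variance vector into standard deviations and checks that $n$ independent 1-d Gaussians give the same density as one multivariate Gaussian with a diagonal $\Sigma$.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

# Illustrative values standing in for the two vectors an encoder would output
mu = torch.tensor([0.5, -1.0, 2.0])
logvar = torch.tensor([0.0, -2.0, 1.0])

# log(sigma^2) -> sigma; any real-valued logvar yields a valid (positive) std
std = torch.exp(0.5 * logvar)

# Option A: n independent univariate Gaussians, treated as one n-dim event
diag_gaussian = Independent(Normal(mu, std), 1)

# Option B: full multivariate Gaussian whose covariance matrix is diagonal
full_gaussian = MultivariateNormal(mu, covariance_matrix=torch.diag(std ** 2))

z = torch.tensor([0.0, -1.2, 1.5])  # an arbitrary point in latent space
lp_diag = diag_gaussian.log_prob(z)
lp_full = full_gaussian.log_prob(z)
print(lp_diag.item(), lp_full.item())  # the two log-densities should agree
```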

Ponder a bit about the strength of such an assumption. For instance, a VAE trained on MNIST will build a unimodal distribution on the latent space containing the whole dataset. But is sampling from this distribution the true underlying causative probabilistic process? Would it not be more reasonable to have a ten-modal distribution, or maybe a distinct probability distribution for each digit? Certainly it depends on what you want to do with these representations. If you want to use VAEs as generative models, a unimodal Gaussian is convenient; for instance, you cannot obtain these nice interpolations between different digits with a multimodal distribution. But if your goal is to obtain a useful representation of the world, for instance in reinforcement learning, is this a satisfactory assumption? What does it mean to interpolate between the two very distinct concepts of a 3 and a 5?


![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/implementation.png)



### Intuition: encoding probability distributions (from Jeremy Jordan's [blog post](https://www.jeremyjordan.me/variational-autoencoders/))



Let's suppose we've trained an autoencoder model on a large dataset of faces with an encoding dimension of 6. An ideal autoencoder will learn descriptive attributes of faces such as skin color, whether or not the person is wearing glasses, etc. in an attempt to describe an observation in some compressed representation.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/int1.png)

In the example above, we've described the input image in terms of its latent attributes using a **single value to describe each attribute**. 

However, we may prefer to represent each latent attribute as a range of possible values. For instance, what *single value* would you assign for the smile attribute if you feed in a photo of the Mona Lisa? Using a variational autoencoder, we can describe latent attributes in probabilistic terms.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/int2.png)


With this approach, we'll now represent each latent attribute for a given input as a probability distribution. When decoding from the latent state, we'll randomly sample from each latent state distribution to generate a vector as input for our decoder model.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/int3.png)

By constructing our encoder model to output a range of possible values (a statistical distribution) from which we'll randomly sample to feed into our decoder model, we're essentially enforcing a continuous, smooth latent space representation. For any sampling of the latent distributions, we're expecting our decoder model to be able to accurately reconstruct the input. Thus, values which are nearby to one another in latent space should correspond with very similar reconstructions.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/int4.png)

# 1. Reparameterization Trick

### Implementation overview

Now that we are close to the code, have you noticed the elephant in the room? VAEs need to sample from the distribution predicted by the encoder to obtain a latent code for the decoder, but how can we backpropagate gradients through a fundamentally random operation such as sampling?



#### Reparameterization trick

We need the reparameterization trick in order to backpropagate through a random node. Assume we sample a variable $z$ from a distribution $q_\phi$ and pass it to a function $f(z)$. For backpropagation, we want **the gradient of the expectation** of $f(z)$ with respect to $\phi$, which is

![Equation To Differentiate](https://raw.githubusercontent.com/sentient-codebot/cmp_exr_vae/main/images/reparam_eq.png)

Normally, such an expectation is approximated by **Monte Carlo** sampling. However, when the distribution itself is **parameterized** by $\phi$ and we want the gradient w.r.t. $\phi$, plain Monte Carlo does not work, because the sampling step is not differentiable with respect to $\phi$. We need to find a way to **move $\phi$ out of the distribution!**
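As a toy illustration of moving the parameter out of the distribution, the sketch below assumes $q_\phi = \mathcal{N}(\mu, 1)$ with $\phi = \mu$ and the test function $f(z) = z^2$ (both chosen here only for illustration). Writing $z = \mu + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 1)$ makes the randomness independent of $\mu$, so a Monte Carlo average becomes differentiable. For this choice the exact answer is known: $\mathbb{E}[z^2] = \mu^2 + 1$, so the gradient is $2\mu$.

```python
import torch

torch.manual_seed(0)

mu = torch.tensor(1.5, requires_grad=True)  # the parameter phi of q_phi
sigma = 1.0                                 # fixed standard deviation (assumption)

# Noise is drawn from a parameter-free distribution ...
eps = torch.randn(100_000)
# ... and the sample is a deterministic, differentiable function of mu
z = mu + sigma * eps

# Monte Carlo estimate of E[f(z)] with f(z) = z^2
expectation = (z ** 2).mean()
expectation.backward()

grad_estimate = mu.grad.item()
grad_exact = 2 * mu.item()  # d/dmu (mu^2 + sigma^2) = 2 * mu
print(f"estimate: {grad_estimate:.3f}, exact: {grad_exact:.3f}")
```

The estimate fluctuates with the random seed and sample count, but it should stay close to the exact value.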

## Question 1
Given:

In [None]:
# Some data
data = torch.rand(1,10)

# Our network
mid_layer = torch.nn.Linear(in_features=10,  out_features=10)
out_layer = torch.nn.Linear(in_features=10,  out_features=10)

# Target output
target = torch.ones(1,10)

We want to optimize the parameters in order to obtain a prediction of `target` from the following computation:

In [None]:
# data -> mid_layer -> pred_mean
#                       pred_mean, pred_std -> pred_distribution -> sample -> prediction

# Generate the mean and standard deviation of a normal distribution
pred_mean = mid_layer(data) # get the mean from our data
pred_std = torch.ones(size=(1,10)) # assume a standard deviation (and hence a variance) of 1

pred_dist = torch.distributions.normal.Normal(pred_mean, pred_std) # create our prediction distribution
sampling = torch.normal(mean=pred_mean, std=pred_std)     # option 1: random numbers drawn from **separate** normal distributions
                                                          # whose mean and standard deviation are given.
sampling = pred_dist.sample()                             # option 2: equivalent draw from the distribution object (overrides option 1)

pred_output = out_layer(sampling)

# Loss function
loss = F.mse_loss(pred_output, target)
mid_layer.zero_grad()
out_layer.zero_grad()
loss.backward(retain_graph=True)


But...

In [None]:
print(mid_layer.weight.grad)


### Question 1.1
Why do we not have gradients of the `mid_layer`? Fill in the blanks in the following code.
 
Hint: we can check the `grad_fn` and `requires_grad` of each variable in our operations. 
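As a reminder of what these two attributes report, here is a tiny sketch on unrelated tensors (`w`, `y` and `y_cut` are illustrative names, not part of the exercise). A tensor produced by a differentiable operation carries a `grad_fn`; a tensor that has been cut out of the graph does not.

```python
import torch

w = torch.ones(3, requires_grad=True)  # a leaf tensor we want gradients for
y = w * 2                              # result of a differentiable operation
y_cut = y.detach()                     # same values, but cut out of the graph

print('y     ', y.grad_fn, y.requires_grad)          # has a grad_fn, requires grad
print('y_cut ', y_cut.grad_fn, y_cut.requires_grad)  # no grad_fn, no grad
```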


In [None]:
print('grad_fn      ', loss.grad_fn)
print('requires_grad', loss.requires_grad)
print('-------------')

print('grad_fn      ', ...) # check the `grad_fn` of `pred_output`
print('requires_grad', ...) # check whether `pred_output` requires gradient
print('-------------')

print('grad_fn      ', ...) # check the `grad_fn` of ...
print('requires_grad', ...) # check whether ... requires `gradient`


### Question 1.2
Let us now reparametrize the operations! Please fill in the blanks in the code below. 

Recap:
 
![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/trick.png)


In [None]:
# Posterior distribution prediction
pred_mean = mid_layer(data)
pred_std = torch.ones_like(data)

# Replace the `pred_dist.sample()` above with following code
epsilon = ...                   # random numbers drawn from **separate** canonical normal distributions
sampling = ...                  # Rescale each normal distribution to match our predictions

pred_output = out_layer(sampling)

# Loss function
loss = F.mse_loss(pred_output, target)
mid_layer.zero_grad()
out_layer.zero_grad()
loss.backward()


Now print out the gradient of `mid_layer.weight`

In [None]:
print(...)

### Extended Reading
PyTorch has already implemented some ready-to-use differentiable sampling functions. 

In [None]:
# Forward
data = torch.rand(1,10)
pred_mean = mid_layer(data)
pred_std = torch.ones_like(data)

normal_distribution = torch.distributions.Normal(loc=pred_mean, scale=pred_std,)
sampling = normal_distribution.rsample()  # rsample allows pathwise derivatives, i.e. implements the reparametrization trick

pred_output = out_layer(sampling)

# Loss
loss = F.mse_loss(pred_output, target)
mid_layer.zero_grad()
out_layer.zero_grad()
loss.backward()

mid_layer.weight.grad

#### Note
- Check out the [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) package to learn more.
- Check out [this](https://bochang.me/blog/posts/pytorch-distributions/) and [this](https://ericmjl.github.io/blog/2019/5/29/reasoning-about-shapes-and-probability-distributions/) to understand `batch_shape`, `sample_shape` and `event_shape` in the distributions package.

**Warning** Reparameterization is not straightforward for every distribution. Some distributions in the `torch.distributions` package do not implement `rsample`.
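One way to check whether a distribution supports reparameterized sampling is its `has_rsample` attribute. The sketch below compares a Gaussian with a categorical distribution; the parameter values are arbitrary.

```python
import torch
from torch.distributions import Categorical, Normal

normal = Normal(loc=torch.zeros(3), scale=torch.ones(3))
categorical = Categorical(probs=torch.tensor([0.2, 0.3, 0.5]))

# `has_rsample` tells us whether `rsample()` is available for a distribution
for dist in (normal, categorical):
    print(type(dist).__name__, "has_rsample =", dist.has_rsample)

# Calling `rsample()` on a distribution without support raises an error
try:
    categorical.rsample()
    rsample_failed = False
except NotImplementedError:
    rsample_failed = True
print("Categorical.rsample() failed:", rsample_failed)
```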

# 2. Implementing a VAE



#### **Encoder and Decoder**
We'll use a convolutional encoder and decoder, which generally gives better performance than fully connected versions that have the same number of parameters.

**Bottleneck**

In the convolution layers, we increase the number of channels as we approach the bottleneck. Between the two convolutions the total number of features still decreases: the channels double, but each stride-2 convolution halves both spatial dimensions, so the spatial size shrinks by a factor of 4.


**Deconvolution (Transposed Convolution)**

In the decoder we are using a transposed convolution, also known as deconvolution (not to be confused with "deconvolution" in signal processing). You can think of it as a convolution in the opposite direction (more visualizations [here](https://github.com/vdumoulin/conv_arithmetic)).

*Blue maps are inputs, and cyan maps are outputs*

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/deconv.gif)



**Kernel size**

We are using `kernel_size=4`, which is divisible by the stride of 2. The motivation behind this choice is to lessen the checkerboard artifacts described [here](https://distill.pub/2016/deconv-checkerboard).
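The spatial sizes produced by these layers follow simple arithmetic. The sketch below uses a hypothetical 32x32 RGB input and small channel counts (deliberately not the MNIST setup of this exercise) to show that a stride-2 convolution with `kernel_size=4, padding=1` halves the spatial size and the matching transposed convolution doubles it again. The helper names `conv_out` and `deconv_out` are ours; they assume dilation 1 and no output padding.

```python
import torch
import torch.nn as nn


def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a transposed convolution (dilation 1, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


x = torch.zeros(1, 3, 32, 32)  # hypothetical 32x32 RGB image, not MNIST
down = nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1)
up = nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1)

h = down(x)     # downsampled feature map
x_back = up(h)  # upsampled back to the input resolution

print(h.shape, conv_out(32, kernel=4, stride=2, padding=1))
print(x_back.shape, deconv_out(16, kernel=4, stride=2, padding=1))
```

The same two formulas can be applied to the layers of the exercise to work out the missing dimensions.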

## Question 2
Please answer the following questions and complete the code according to instructions. 

In [None]:
# Define hyperparameters

# 2-d latent space, parameter count in same order of magnitude
# as in the original VAE paper (VAE paper has about 3x as many)
latent_dims = 2
num_epochs = 40
batch_size = 128
capacity = 64
learning_rate = 1e-3
variational_beta = 1
use_gpu = True

device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")

### Question 2.1 
MNIST images show digits from `0-9` in `28x28` grayscale images. We do not center them at 0, because we will be using a binary cross-entropy loss that treats pixel values as probabilities $p \in [0,1]$. 
Please complete the following code for data loading. 
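Binary cross-entropy is not restricted to hard 0/1 targets; it accepts any target in $[0, 1]$, which is why unnormalized grayscale intensities in that range work directly. The sketch below uses random tensors as stand-ins for images and reconstructions and compares `F.binary_cross_entropy` with the formula written out by hand.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Random stand-ins: "images" with intensities in [0, 1], predictions in (0, 1)
target = torch.rand(4, 784)
pred = torch.sigmoid(torch.randn(4, 784))

# Library implementation, summed over all pixels
bce = F.binary_cross_entropy(pred, target, reduction='sum')

# Same quantity written out: -sum[ t * log(p) + (1 - t) * log(1 - p) ]
manual = -(target * pred.log() + (1 - target) * (1 - pred).log()).sum()

print(bce.item(), manual.item())  # the two values should agree
```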


In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

img_transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(...) # setup a dataloader from `train_dataset`, use predefined batch size, and set `shuffle` to True

test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
test_dataloader = DataLoader(test_dataset, batch_size=max(10000, batch_size), shuffle=True)

### Question 2.2
Please complete the following code for defining the network. Why do we calculate the log variance instead of variance? 

In [None]:
class Encoder(nn.Module):
    def __init__(self, hidden_channels: int, latent_dim: int) -> None:
        """
        Simple encoder module

        It predicts the `mean` and `log(variance)` parameters.

        The choice to use the `log(variance)` is for stability reasons:
        https://stats.stackexchange.com/a/353222/284141
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, 
                               out_channels=hidden_channels, 
                               kernel_size=4, 
                               stride=2, 
                               padding=1) # out: ...

        self.conv2 = nn.Conv2d(in_channels=hidden_channels, 
                               out_channels=hidden_channels*2, 
                               kernel_size=4, 
                               stride=2, 
                               padding=1) # out: ...

        self.fc_mu = nn.Linear(in_features=..., # please calculate the correct input dimension
                               out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=..., # please calculate the correct input dimension
                                   out_features=latent_dim)
            
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor):
        """
        :param x: batch of images with shape [batch, channels, h, w]
        :returns: the predicted mean and log(variance)
        """
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))

        x = x.view(x.shape[0], -1)

        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)

        return x_mu, x_logvar

class Decoder(nn.Module):
    def __init__(self, hidden_channels: int, latent_dim: int) -> None:
        """
        Simple decoder module
        """
        super().__init__()
        self.hidden_channels = hidden_channels

        self.fc = nn.Linear(in_features=latent_dim, 
                            out_features=...) # please calculate the correct output dimension

        self.conv2 = nn.ConvTranspose2d(in_channels=hidden_channels*2, 
                                        out_channels=hidden_channels, 
                                        kernel_size=4, 
                                        stride=2, 
                                        padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=hidden_channels, 
                                        out_channels=1, 
                                        kernel_size=4, 
                                        stride=2, 
                                        padding=1)
        
        self.activation = nn.ReLU()
        
            
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: a sample from the distribution governed by the mean and log(var)
        :returns: a reconstructed image with size [batch, 1, h, w]
        """
        x = self.fc(x)
        x = x.view(x.size(0), self.hidden_channels*2, ..., ...) # please fill in the last two dimensions
        x = self.activation(self.conv2(x))
        x = torch.sigmoid(self.conv1(x)) # last layer before output is sigmoid, since we are using BCE as reconstruction loss
        return x
        

#### **VAE**

The VAE definition is straightforward since we have all the pieces we need:
- The Encoder
- The reparameterization trick we saw in the last section
- The Decoder




In [None]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, hidden_channels: int, latent_dim: int):
        super().__init__()
        self.encoder = Encoder(hidden_channels=hidden_channels, 
                               latent_dim=latent_dim)
        self.decoder = Decoder(hidden_channels=hidden_channels, 
                               latent_dim=latent_dim)
    
    def forward(self, x):
        latent_mu, latent_logvar = self.encoder(x)
        latent = self.latent_sample(latent_mu, latent_logvar)
        x_recon = self.decoder(latent)
        return x_recon, latent_mu, latent_logvar
    
    def latent_sample(self, mu, logvar):

        if self.training:
            # the reparameterization trick
            std = (logvar * 0.5).exp()
            return ... # return a sample drawn from the normal distribution parameterized by `mu` and `std`. hint: the reparameterization trick
            # std = logvar.mul(0.5).exp_()
            # eps = torch.empty_like(std).normal_()
            # return eps.mul(std).add_(mu)
        else:
            return mu

vae = VariationalAutoencoder(hidden_channels=capacity, latent_dim=latent_dims)
vae = vae.to(device)

num_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)

#### Loss function



In [None]:
def vae_loss(recon_x, x, mu, logvar):
    recon_loss = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    # You can look at the derivation of the KL term here https://arxiv.org/pdf/1907.08956.pdf
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + variational_beta * kldivergence

What role does `variational_beta` play here? 
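The closed-form KL term used in `vae_loss` can be cross-checked numerically. The sketch below uses random values in place of encoder outputs and compares the formula with `torch.distributions.kl_divergence` between a diagonal Gaussian and the standard normal prior.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Random stand-ins for encoder outputs: batch of 4, latent dimension 2
mu = torch.randn(4, 2)
logvar = torch.randn(4, 2)

# Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over batch and dimensions
kl_closed_form = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference value computed by PyTorch from the two distributions
posterior = Normal(mu, torch.exp(0.5 * logvar))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_reference = kl_divergence(posterior, prior).sum()

print(kl_closed_form.item(), kl_reference.item())  # should agree

# The KL term vanishes when the posterior equals the prior (mu = 0, logvar = 0)
zeros = torch.zeros(4, 2)
kl_at_prior = -0.5 * torch.sum(1 + zeros - zeros.pow(2) - zeros.exp())
```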

### Training VAE


You don't need to change the code in the block below. 

In [None]:
# @title training utilities

import pandas as pd
import numpy as np

from typing import Callable, Optional
def make_averager() -> Callable[[Optional[float]], float]:
    """ Returns a function that maintains a running average

    :returns: running average function
    """
    count = 0
    total = 0

    def averager(new_value: Optional[float]) -> float:
        """ Running averager

        :param new_value: number to add to the running average,
                          if None returns the current average
        :returns: the current average
        """
        nonlocal count, total
        if new_value is None:
            return total / count if count else float("nan")
        count += 1
        total += new_value
        return total / count

    return averager

def save_in_dataframe(df_log, labels, mus, stddevs, epoch):
    df = pd.DataFrame()

    df['index'] = np.arange(len(mus[:,0])) * epoch 
    df['image_ind'] = np.arange(len(mus[:,0]))
    df['class'] = labels.data.numpy().astype(str)
    df['mu_x'] = mus[:,0]
    df['mu_y'] = mus[:,1]
    df['std_x'] = stddevs[:,0]
    df['std_y'] = stddevs[:,1]
    df['epoch'] = np.ones(len(mus[:,0])) * epoch
    
    df_log = pd.concat([df_log, df])

    return df_log

def run_on_testbatch(df_log, vae, epoch, x, y):
    with torch.no_grad():
        x = x.to(device)
        x, mus, stddevs = vae(x)
        x = x.to('cpu')
        mus = mus.to('cpu').data.numpy()
        stddevs = stddevs.to('cpu').mul(0.5).exp_().data.numpy()

    return save_in_dataframe(df_log, y, mus, stddevs, epoch)

def plot_loss(losses):
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=list(range(len(losses))),
        y=losses,
        # name="Name of Trace 1"       # this sets its legend entry
    ))

    fig.update_layout(
        title="Train loss",
        xaxis_title="Epoch",
        yaxis_title="Loss",
        font=dict(
            family="Courier New, monospace",
            size=18,
            color="#7f7f7f"
        )
    )
    return fig

def refresh_bar(bar, desc):
    bar.set_description(desc)
    bar.refresh()


### Question 2.3 
Please fill in the missing code below. 

In [None]:
# 🚥 Training 🚥

vae = VariationalAutoencoder(hidden_channels=capacity, latent_dim=latent_dims)
vae = vae.to(device)
optimizer = torch.optim.Adam(...) # setup an Adam optimizer using predefined lr, and a weight decay of 1e-5. 

# set to training mode
vae.train()

df_log = pd.DataFrame()
test_batch_x, test_batch_y = next(iter(test_dataloader)) # return one batch of test data
df_log = run_on_testbatch(df_log, vae, 0, test_batch_x, test_batch_y )

train_loss_avg = []

print('Training ...')

tqdm_bar = tqdm(range(1, num_epochs+1), desc="epoch [loss: ...]")
# tqdm_iter = trange(1, num_epochs+1, desc="epoch [loss: ...]")
for epoch in tqdm_bar:
    train_loss_averager = make_averager()
        
    batch_bar =  tqdm(train_dataloader, leave=False, desc='batch', total=len(train_dataloader))
    for image_batch, _ in batch_bar:
        
        image_batch = image_batch.to(device)

        # vae reconstruction
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)
        
        # reconstruction error
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)
        
        # backpropagation, please fill in code here
        ...
        
        # one step of the optimizer
        optimizer.step()
        
        refresh_bar(batch_bar, f"train batch [loss: {train_loss_averager(loss.item()):.3f}]")
    
    refresh_bar(tqdm_bar, f"epoch [loss: {train_loss_averager(None):.3f}]")

    train_loss_avg.append(train_loss_averager(None))
    df_log = run_on_testbatch(df_log, vae, epoch, test_batch_x, test_batch_y )

df_log = df_log.set_index(['index'])
plot_loss(train_loss_avg)

# 3. Latent Variable Visualization

### Extended Reading: Latent space (from Jeremy Jordan's [blog post](https://www.jeremyjordan.me/variational-autoencoders/))

Using VAEs we are able to learn **smooth latent state representations** of the input data. For standard autoencoders, we simply need to learn an encoding which allows us to reproduce the input. 

As you can see in the left-most figure, focusing only on reconstruction loss does allow us to separate out the classes (in this case, MNIST digits) which should allow our decoder model the ability to reproduce the original handwritten digit, but there's an uneven distribution of data within the latent space. In other words, there are areas in latent space which don't represent any of our observed data.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/latent.png)



On the flip side, if we focus only on ensuring that the latent distribution is similar to the prior distribution (through our KL divergence loss term), we end up describing every observation using the same unit Gaussian, which we subsequently sample from to describe the latent dimensions visualized. This effectively treats every observation as having the same characteristics; in other words, we've failed to describe the original data.

However, when the two terms are optimized simultaneously, we're encouraged to describe the latent state for an observation with distributions close to the prior but deviating when necessary to describe salient features of the input.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/distr.png)


In the following cells you can visualize how the training modified the latent space of our network.




### Load Pretrained Model

If you could not run the full training above, you can download a pretrained model and avoid running the training for the number of epochs needed to get good reconstructions (~100 epochs).

In [None]:
if False:
    filename = 'vae_2d.pth'
    import urllib.request  # the `request` submodule must be imported explicitly
    if not os.path.isdir('./pretrained'):
        os.makedirs('./pretrained')
    print('downloading ...')
    filepath = Path("./pretrained") / filename
    urllib.request.urlretrieve(f"http://geometry.cs.ucl.ac.uk/creativeai/pretrained/{filename}", 
                               filepath)
    vae.load_state_dict(torch.load(filepath, map_location=device))  # map_location lets GPU-saved weights load on CPU
    vae = vae.to(device)
    print(f'done, loaded in {device}')

In [None]:
# @title Trained VAE

import plotly.express as px

n_samples = 1470 #@param {type:"slider", min:10, max:5000, step:10}
size_exactly_as_std = False #@param {type:"boolean"}

def plot_latent_params(df, max_size: bool):
    # use the `max_size` argument (rather than the global flag) to set the maximum marker size
    if max_size:
        size_max = 200
    else:
        size_max = None

    return px.scatter(
        df.loc[df['image_ind'] < n_samples], 
        x="mu_x", y="mu_y", 
        animation_frame="epoch", animation_group="image_ind",
        size="std_x", 
        color="class", 
        hover_name="image_ind", #facet_col="class", 
        color_discrete_sequence=px.colors.qualitative.Plotly, 
        width=800, 
        height=800, 
        size_max=size_max,
        range_x=[-5, 5], 
        range_y=[-5, 5])
    
plot_latent_params(df_log, max_size = size_exactly_as_std)

In [None]:
# @title Untrained VAE

import plotly.express as px

n_samples = 190 #@param {type:"slider", min:10, max:5000, step:10}
size_exactly_as_std = False #@param {type:"boolean"}

vae_zero = VariationalAutoencoder(hidden_channels=capacity, latent_dim=latent_dims)
vae_zero = vae_zero.to(device)

df_log3 = pd.DataFrame()
df_log3 = run_on_testbatch(df_log3, vae_zero, 0, test_batch_x, test_batch_y )

    
plot_latent_params(df_log3, max_size = size_exactly_as_std)

### Evaluate on the Test Set


In [None]:
# set to evaluation mode
vae.eval()

test_loss_averager = make_averager()

with torch.no_grad():

    test_bar = tqdm(test_dataloader, total=len(test_dataloader), desc = 'batch [loss: ...]')
    for image_batch, _ in test_bar:    
        image_batch = image_batch.to(device)

        # vae reconstruction
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # reconstruction error
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        refresh_bar(test_bar, f"test batch [loss: {test_loss_averager(loss.item()):.3f}]")

    
print(f'Average test loss: {test_loss_averager(None)}')


# Build a dictionary label2images for future use
from collections import defaultdict
label2img = defaultdict(list)
for img_batch, label_batch in test_dataloader:
    img_batch = img_batch.to(device)
    for i in range(img_batch.shape[0]):
        # maintain the singleton batch dimension with [i]
        label2img[label_batch[i].item()].append(img_batch[[i], ...])


### Visualize latent distributions

Inspect the latent dimensions for a few samples from the data to see the characteristics of the distribution.

![](https://raw.githubusercontent.com/lucmos/DLAI-s2-2020-tutorials/master/08/6-vae.png)


If we observe that the latent distributions appear to be very tight, we may decide to give higher weight to the KL divergence term with a parameter $\beta > 1$, encouraging the network to learn broader distributions.

This simple insight has led to the growth of a new class of models - disentangled variational autoencoders. As it turns out, by placing a larger emphasis on the KL divergence term we're also implicitly enforcing that the learned latent dimensions are uncorrelated (through our simplifying assumption of a diagonal covariance matrix).

$$
\mathcal{L}\left(x, \hat{x}\right) + \beta \sum_{j} \mathrm{KL}\left(q_j\left(z \mid x\right) \,\|\, \mathcal{N}\left(0, 1\right)\right)
$$

### Question 3.1 
Visualize the distribution of $z$. Please complete the code below. 

Hint: [Multivariate Normal in PyTorch](https://pytorch.org/docs/stable/distributions.html#multivariatenormal) and [Normal Distribution in PyTorch](https://pytorch.org/docs/stable/distributions.html#normal)

In [None]:
#@title Latent distributions { run: "auto", output-height: 5000 }

def reconstruct_images(images, model):
    model.eval()
    with torch.no_grad():
        images, _, _ = model(images.to(device))
        images = images.clamp(0, 1)
        return images


# visualize a reconstructed image
digit = 5 #@param {type:"slider", min:0, max:9, step:1}
digits_style = 20 #@param {type:"slider", min:0, max:42, step:1}

image = label2img[digit][digits_style]
bigimage = torch.cat((image,reconstruct_images(image, vae))).cpu()

mus, logvars = ...          # get the mean and log variance of z
std = ...                   # get the standard deviation
mus = ...                   # reshape mean to omit the batch dimension

plt.figure(figsize = (2, 17))
plt.imshow(torchvision.utils.make_grid(bigimage, 10, 5).permute(1, 2, 0))
plt.title("Source - Reconstructed")
plt.axis('off')
plt.show()

# visualize the joint distribution of `z_0` and `z_1`
samples = 5000
bivariate_nd = ...          # instantiate a bivariate distribution of z according to the mean and variance, see Hint. 
bidist = bivariate_nd.sample((samples,))

fig = go.Figure()
fig.add_trace(go.Histogram2d(x=bidist[:, 0], y=bidist[:, 1], 
                             xbins=dict(start=-2, end=2, size=0.1),
                             ybins=dict(start=-2, end=2, size=0.1),
                             histnorm='probability'))
fig.update_traces(opacity=0.6)
fig.update_layout(
    title="Distribution from which we sample for each image",
    font=dict(
        family="Courier New, monospace",
        size=18,
        color="#7f7f7f"
    ),
    height=800,
    width=800,
    autosize=False,
)
fig.show()

# visualize the marginal distributions of `z_0` and `z_1`
samples = 5000
nd1 = ...                       # the marginal distribution of `z_0`
nd2 = ...                       # the marginal distribution of `z_1`

dist1 = ...                     # sample
dist2 = ...                     # sample

fig = go.Figure()
fig.add_trace(go.Histogram(x=dist1, 
                           xbins=dict(start=-2, end=2, size=0.005),
                           histnorm='probability', name="0-th latent dim"))
fig.add_trace(go.Histogram(x=dist2, 
                           xbins=dict(start=-2, end=2, size=0.005),
                           histnorm='probability', 
                           name="1-th latent dim"))
fig.update_layout(barmode='overlay')
fig.update_traces(opacity=0.4)
fig.update_layout(
    title="Probability distributions of each latent dimension",
    font=dict(
        family="Courier New, monospace",
        size=18,
        color="#7f7f7f"
    ),
    height=300
)
fig.show()


### Question 3.2 
Visualize reconstructions in the test set. Please complete the code below. 

Hint:
- Use of the `DataLoader` class: [Datasets & DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)

In [None]:
import torchvision.utils

images, labels = ...            # sample a batch of test samples
reconstruced_images = ...       # get reconstructed images

# Matplotlib plot, much faster for static images
# First visualise the original images
# (.cpu() is a no-op if the tensor already lives on the CPU)
plt.figure(figsize = (17, 17))
plt.imshow(torchvision.utils.make_grid(images[:50].cpu(), 10, 5).permute(1, 2, 0))
plt.title("Some original images")
plt.axis('off')
plt.show()

# Reconstruct and visualise the images using the vae
plt.figure(figsize = (17, 17))
plt.imshow(torchvision.utils.make_grid(reconstruced_images[:50].cpu(), 10, 5).permute(1, 2, 0))
plt.title("Some VAE reconstructions")
plt.axis('off')
plt.show()

# # To use plotly:
# # First visualise the original images
# px.imshow(torchvision.utils.make_grid(images[1:50],10,5).permute(1, 2, 0),
#           title="Some original images",
#           color_continuous_scale='grayscale',
#           color_continuous_midpoint=0.5).show()

# # Reconstruct and visualise the images using the vae
# px.imshow(torchvision.utils.make_grid(reconstruced_images[1:50], 10, 5).permute(1, 2, 0),
#           title="Some VAE reconstruction", 
#           color_continuous_scale='grayscale',
#           color_continuous_midpoint=0.5).show()

### Question 3.3 Interpolate in Latent Space
Now we would like to interpolate in the latent space and visualize the decoded images. Given two images, the encoder yields two means, which correspond to two points in the latent space. If we interpolate between these two points, what will the corresponding decoded images look like? Please complete the code below and describe your findings.
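
A minimal sketch of the interpolation idea, using two made-up latent vectors rather than encoder outputs. The helper name `lerp_latents` and the convention that a weight of 0 returns the first point are assumptions for illustration; the exercise code may follow either convention.

```python
import torch


def lerp_latents(z_start, z_end, weight):
    """Hypothetical helper: convex combination of two latent vectors."""
    return (1.0 - weight) * z_start + weight * z_end


z_a = torch.tensor([[-1.0, 2.0]])   # stand-in for the mean of the first image
z_b = torch.tensor([[3.0, 0.0]])    # stand-in for the mean of the second image

# Walk from z_a to z_b; each intermediate point could be passed to a decoder
for w in (0.0, 0.5, 1.0):
    print(w, lerp_latents(z_a, z_b, w))
```

Every intermediate point lies on the straight line segment between the two latent vectors, so decoding a sequence of such points gives a gradual transition between the two source images.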

In [None]:
def interpolation(lambda1, model, img1, img2):
    # lambda1: interpolation weight in [0, 1]
    
    with torch.no_grad():
    
        # latent vector of first image
        img1 = img1.to(device)
        latent_1, _ = ...                   # get the mean of `z` conditioned on the first image

        # latent vector of second image
        img2 = img2.to(device)
        latent_2, _ = ...                   # get the mean of `z` conditioned on the second image

        # interpolation of the two latent vectors
        inter_latent = ...                  # get the interpolated points according to lambda

        # reconstruct interpolated image
        inter_image = model.decoder(inter_latent)
        inter_image = inter_image.clamp(0, 1).cpu()

        return inter_image

In [None]:
#@title Playground interpolation { run: "auto", output-height: 5000 }

vae.eval()


num_interpolations = 40 #@param {type:"slider", min:10, max:100, step:10}

start_digit = 1 #@param {type:"slider", min:0, max:9, step:1}
end_digit = 7 #@param {type:"slider", min:0, max:9, step:1}
digits_style = 9 #@param {type:"slider", min:0, max:42, step:1}


start_image = label2img[start_digit][digits_style] # get one sample of `start_digit`
end_image = label2img[end_digit][digits_style]

# interpolation lambdas
lambda_range=np.linspace(0,1,num_interpolations)
bigimage = torch.cat([interpolation(x, vae, start_image, end_image) for x in lambda_range])


# Matplotlib plot, much faster for static images
plt.figure(figsize = (17, 17))
plt.imshow(torchvision.utils.make_grid(bigimage, 10).permute(1, 2, 0))
plt.title("VAE interpolation: from top left to lower right")
plt.axis('off')
plt.show()

# # To use plotly:
# px.imshow(torchvision.utils.make_grid(bigimage, 10).permute(1, 2, 0),
#           title="VAE interpolation: from top left to lower right", 
#           color_continuous_scale='grayscale',
#           color_continuous_midpoint=0.5).show()

# 4. Generate New Data (VAE as Generator)

By sampling from the latent space, we can use the decoder network to form a generative model capable of creating new data similar to what was observed during training. Specifically, we'll sample from the prior distribution $p(z)$ which we assumed follows a unit Gaussian distribution.

Although the generated digits are not perfect, they are usually better than for a non-variational Autoencoder (compare results for the 10d VAE to the results for the autoencoder).

Similar to autoencoders, the manifold of latent vectors that decode to valid digits is sparser in higher-dimensional latent spaces. Increasing the weight of the KL-divergence term in the loss (increasing `variational_beta`) makes the manifold less sparse at the cost of a lower-quality reconstruction. A pre-trained model with `variational_beta = 10` is available at `./pretrained/vae_10d_beta10.pth`.

### Question 4.1
Sample from our prior distribution of $z$ (recall which distribution that is), and visualize the decoded images. Complete the code below.

In [None]:
#@title Playground random generations { run: "auto", output-height: 5000 }

vae.eval()

with torch.no_grad():

    # sample latent vectors from the normal distribution, our prior distribution
    latent = ...                # sample from prior distribution. hint: what dimension should it be? 

    # reconstruct images from the latent vectors
    img_recon = vae.decoder(latent) # decode images
    img_recon = img_recon.cpu().clamp(0, 1)

    # Matplotlib plot, much faster for static images
    plt.figure(figsize = (17, 17))
    plt.imshow(torchvision.utils.make_grid(img_recon.data[:100],10,5).permute(1, 2, 0))
    plt.title("VAE generation from prior distribution")
    plt.axis('off')
    plt.show()
    
    # # To use plotly:
    # px.imshow(torchvision.utils.make_grid(img_recon.data[:100],10,5).permute(1, 2, 0),
    #       title="VAE generation from prior distribution", 
    #       color_continuous_scale='grayscale',
    #       color_continuous_midpoint=0.5,
    #       height=1000).show()

### Question 4.2 Traverse in the 2D Latent Space

We plot a figure below to visualize the data generated by the decoder network of a variational autoencoder trained on the MNIST handwritten digits dataset. 

Here, we linearly **interpolate a grid of values** and display the output of our decoder network. Please fill in the missing code below. 

In [None]:
#@title Playground 2D linear interp { run: "auto", output-height: 5000 }

import matplotlib.pyplot as plt

latents_lims = ...      # absolute bound of our traversal. what value would be appropriate? hint: think of our assumed prior distribution
num_interpolations = 30 #@param {type:"slider", min:10, max:50, step:2}

# load a network that was trained with a 2d latent space
if latent_dims != 2:
    print('Please change the parameters to two latent dimensions.')
    
with torch.no_grad():
    
    # create a sample grid in 2d latent space
    latent_interpolation = torch.linspace(-latents_lims, latents_lims, num_interpolations)
    latent_grid = torch.stack(
        (
            latent_interpolation.repeat(num_interpolations, 1),
            latent_interpolation[:, None].repeat(1, num_interpolations)
        ), dim=-1).view(-1, 2)

    # reconstruct images from the latent vectors
    latent_grid = latent_grid.to(device)
    image_recon = vae.decoder(latent_grid)
    image_recon = image_recon.cpu()

    # Matplotlib plot, much faster for static images
    plt.figure(figsize = (17, 17))
    plt.imshow(torchvision.utils.make_grid(image_recon.data[:num_interpolations ** 2], 
                                          num_interpolations).permute(1, 2, 0))
    plt.title("2D latent space")
    plt.axis('off')
    plt.show()
    
    # To use plotly:
    # px.imshow(torchvision.utils.make_grid(image_recon.data[:num_interpolations ** 2], 
    #                                       num_interpolations),
    #       title="2D latent space", 
    #       color_continuous_scale='grayscale',
    #       color_continuous_midpoint=0.5,
    #       height=1000).show()

...why is it distorted?

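We are interpolating linearly over a (multivariate) Gaussian distribution! Equally spaced grid points do not cover equal probability mass under the prior: most of the mass lies near the origin, while the outer cells of the grid are rarely visited when sampling.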
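
To make the distortion concrete, the sketch below compares, for a single latent coordinate, how much prior probability mass falls into each cell of a linearly spaced grid versus a grid whose edges sit at linearly spaced percentiles. It uses only the standard library; the names `cell_masses`, `n_cells` and `bound` are hypothetical, and the bound is an arbitrary illustrative value rather than a recommendation.

```python
from statistics import NormalDist

prior = NormalDist(mu=0.0, sigma=1.0)   # one coordinate of the unit Gaussian prior


def cell_masses(edges):
    """Prior probability mass between consecutive grid edges."""
    return [prior.cdf(b) - prior.cdf(a) for a, b in zip(edges, edges[1:])]


n_cells = 6
bound = 3.0  # arbitrary illustrative bound (assumption)

# Linearly spaced edges in latent space: equal width, unequal mass
linear_edges = [-bound + 2 * bound * i / n_cells for i in range(n_cells + 1)]
linear_mass = cell_masses(linear_edges)

# Edges at linearly spaced percentiles (inverse CDF): unequal width, equal mass
lo, hi = prior.cdf(-bound), prior.cdf(bound)
icdf_edges = [prior.inv_cdf(lo + (hi - lo) * i / n_cells) for i in range(n_cells + 1)]
icdf_mass = cell_masses(icdf_edges)

print([round(m, 3) for m in linear_mass])
print([round(m, 3) for m in icdf_mass])
```

On the linear grid the central cells carry far more mass than the outer ones, so the outer rows and columns of the decoded image grid show regions that the prior almost never produces.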

### Extended Reading: Sample proportionally to the distribution
To get an undistorted sense of the full latent manifold, we can sample and decode latent space coordinates proportionally to the model’s distribution over latent space. In other words, we simply sample relative to our chosen prior distribution over $z$. In our case, this means sampling linearly spaced percentiles from the [inverse CDF](http://work.thaslwanter.at/Stats/html/statsDistributions.html#other-important-presentations-of-probability-densities) of a spherical Gaussian. For the icdf implementation of torch.distributions, see [Normal Distribution](https://pytorch.org/docs/stable/distributions.html#normal). 

In [None]:
#@title Playground 2D icdf interp { run: "auto", output-height: 5000 }

import matplotlib.pyplot as plt

num_interpolations = 50 #@param {type:"slider", min:10, max:50, step:2}

nd = torch.distributions.Normal(loc=torch.as_tensor([0.]),
                                scale=torch.as_tensor([1.]))

with torch.no_grad():
    
    # create a sample grid in 2d latent space
    latent_interpolation = torch.linspace(0.001, 0.999, num_interpolations)
    latent_grid = torch.stack(
        (
            latent_interpolation.repeat(num_interpolations, 1),
            latent_interpolation[:, None].repeat(1, num_interpolations)
        ), dim=-1).view(-1, 2)

    latent_grid = nd.icdf(latent_grid) 
    # reconstruct images from the latent vectors
    latent_grid = latent_grid.to(device) # move `latent_grid` to the same device as the model
    image_recon = vae.decoder(latent_grid) # decode latent grid
    image_recon = image_recon.cpu()

    # Matplotlib plot, much faster for static images
    plt.figure(figsize = (17, 17))
    plt.imshow(torchvision.utils.make_grid(image_recon.data[:num_interpolations ** 2], 
                                          num_interpolations).permute(1, 2, 0))
    plt.title("2D latent space")
    plt.axis('off')
    plt.show()
    

# Further Discussion
1. Are $z_0$ and $z_1$ correlated under their prior and posterior distributions? If not, why not, and how could we incorporate correlation into our model?
2. According to the latent space visualization, there is some correlation between the latent distribution and the digit label. Try to show this correlation. Using your (pre)trained VAE, try to build a classifier by replacing the decoder with a few fully connected layers.