# Problem 1: Gradient-based variational inference

Let us consider the exact same model and data as in the previous exercise set:
\begin{align*}
x_n &\sim \text{Poisson}(uv),\\
u &\sim \text{Gamma}(6, 1),\\
v &\sim \text{Gamma}(3, 3),
\end{align*}
with the same data $\mathbf{x}=[5, 3, 9, 13, 5, 3, 5, 1, 2, 7, 6, 5, 6, 7, 4]$.

Now we use gradient-based methods to learn a variational approximation for the posterior.

Write code that:
1. Specifies the approximation using a suitable distribution from torch.distributions
2. Evaluates the ELBO by using $M$ samples drawn from the approximation with .rsample()
3. Optimizes the ELBO with respect to the parameters of the approximation
4. Plots the approximation on top of Gibbs samples
5. Plots convergence of the ELBO estimate (note that with small $M$ this may be somewhat noisy). Use this plot to check that your optimizer actually has converged to a good solution.
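
As a reminder of what step 2 estimates, the Monte Carlo ELBO based on $M$ reparameterized samples can be written as follows. The symbol $\phi$ (the parameters of the approximation) and the superscript $(m)$ (the index of a sample) are notation introduced here for illustration, not taken from the exercise text.

\begin{align*}
\text{ELBO}(\phi) &= \mathbb{E}_{q_\phi(u,v)}\left[\log p(\mathbf{x}, u, v) - \log q_\phi(u,v)\right] \\
&\approx \frac{1}{M}\sum_{m=1}^{M}\left[\sum_{n=1}^{N}\log p(x_n \mid u^{(m)}, v^{(m)}) + \log p(u^{(m)}) + \log p(v^{(m)}) - \log q_\phi(u^{(m)}, v^{(m)})\right],
\end{align*}
where $(u^{(m)}, v^{(m)}) \sim q_\phi(u,v)$ are obtained with .rsample() so that gradients can flow back to $\phi$.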

Use the code to try out alternative models and approximations. For each of the three alternatives below, always show both the convergence curve and the resulting posterior approximation and **explain what you see**.
1. Use the same approximation family as before, so that $q(u,v)$ is a product of two gamma distributions. Do you get the same result as with CAVI?
2. Use a multivariate normal approximation, so that $q(u,v)$ is a bivariate normal distribution. Explain what changed.
3. Change the prior $p(u,v)$ from a product of two gamma distributions to a product of two half-normal distributions with scales of your own choosing. Explain what happens.
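
For the first alternative, the product-of-gammas approximation could be set up roughly as in the sketch below. This is only one possible layout, not the model solution: the name `raw`, its 2x2 arrangement and the initial values are assumptions made for illustration, and the ELBO itself is left out.

```python
import torch
from torch.nn.functional import softplus

torch.manual_seed(0)

# Unconstrained learnable parameters (hypothetical layout):
# rows are (u, v), columns are (shape, rate). Initial values are arbitrary.
raw = torch.tensor([[5.0, 5.0], [5.0, 5.0]], requires_grad=True)

# softplus maps the unconstrained values to positive gamma parameters
params = softplus(raw)
q_u = torch.distributions.Gamma(params[0, 0], params[0, 1])
q_v = torch.distributions.Gamma(params[1, 0], params[1, 1])

# Reparameterized samples, differentiable with respect to raw
M = 10
u = q_u.rsample((M,))
v = q_v.rsample((M,))

# q(u,v) factorizes, so the log-density is a sum of the two terms
log_q = q_u.log_prob(u) + q_v.log_prob(v)   # shape (M,)
```

Because the distributions are rebuilt from `raw`, they need to be re-created inside the optimization loop after every parameter update, as the code template further down already does.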

HINTS:
1. https://pytorch.org/docs/stable/distributions.html helps with the distribution syntax etc.
2. You can use any $M$, but probably it is best to avoid very small ones. It is a good idea to quickly explore how the estimate behaves as a function of $M$.
3. The easiest way to parameterize the multivariate normal is to use some 2x2 matrix $A$ as learnable parameters but then give "L = torch.tril(A)" as the Cholesky factor (the scale_tril argument) of the multivariate normal distribution. We have one extra parameter in $A$ that is never optimized or used (the top right corner), but that does not matter.
4. Be careful with bounds: You have a positivity requirement for two things here, for the **parameters of the approximation terms** (for the gamma approximation) and additionally for **the samples u and v drawn from the approximation** (for the normal approximation, which could produce negative samples). You can, for example, use "torch.nn.functional.softplus(alpha_unconstrained)" as the alpha-parameter in a gamma distribution, and you can truncate the samples from the normal distribution to a small positive constant or push those also through softplus. The former (constraining the parameters) is always valid, but the latter (modifying the samples) is strictly speaking wrong, as we then in effect use a truncated normal as the approximation; we can ignore this problem here for simplicity.
5. It is possible to write general code that directly supports arbitrary distributions, but that is quite tedious. You can definitely have separate copies of your code for the different choices of $q(u,v)$ if that is easier -- that's what I will be doing in model solutions anyway.
6. You already know the optimal solution for the gamma approximation based on Exercise 3. The values are quite large, so you might want to also initialize your approximation with numbers of similar magnitude to make the optimization problem a bit easier.
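
Hints 3 and 4 can be illustrated with the short sketch below. It deviates from hint 3 in one respect, which is my own addition: the diagonal of $A$ is passed through softplus, because torch.distributions.MultivariateNormal expects a scale_tril with a positive diagonal and plain "torch.tril(A)" may violate that during optimization. The names `mean`, `uv_pos` and all initial values are assumptions chosen for illustration.

```python
import torch
from torch.nn.functional import softplus

torch.manual_seed(0)

# Unconstrained learnable parameters (initial values are arbitrary)
mean = torch.tensor([3.0, 1.5], requires_grad=True)
A = torch.eye(2, requires_grad=True)

# Strictly lower part of A is used as is; the diagonal goes through
# softplus so that it stays positive (an addition to hint 3)
L = torch.tril(A, diagonal=-1) + torch.diag(softplus(torch.diagonal(A)))

q_uv = torch.distributions.MultivariateNormal(mean, scale_tril=L)

# Reparameterized samples, shape (M, 2); columns are u and v
M = 5
uv = q_uv.rsample((M,))

# Normal samples can be negative, so push them through softplus
# before using them as the Poisson rate (see hint 4)
uv_pos = softplus(uv)
```

The top right corner of `A` never enters `L`, so its gradient stays at zero, which is the unused parameter mentioned in hint 3.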

The cell below again has the model definition and Gibbs sampler for ease of result presentation.

In [21]:
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import random
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 6)

u_alpha = 6.
u_beta = 1.
v_alpha = 3.
v_beta = 3.

x = np.array([5,3,9,13,5,3,5,1,2,7,6,5,6,7,4])
N = len(x)

# Priors
u_prior = stats.gamma(u_alpha, scale=1./u_beta)
v_prior = stats.gamma(v_alpha, scale=1./v_beta)

def log_density(data, u, v):
    likelihood = stats.poisson(u*v)
    return np.sum(likelihood.logpmf(data)) + u_prior.logpdf(u) + v_prior.logpdf(v)

def Gibbs(x, u, v, T):
    # Storage for samples
    samples = [u, v]
    sumx = sum(x)
    N = len(x)
    for t in range(T):
        # Sample u conditional on v and data: Gamma(sum(x) + u_alpha, N*v + u_beta)
        shape = sumx + u_alpha
        rate = N*v + u_beta
        u = np.random.gamma(shape, scale = 1.0 / rate)

        # Sample v conditional on u and data: Gamma(sum(x) + v_alpha, N*u + v_beta)
        shape = sumx + v_alpha
        rate = N*u + v_beta
        v = np.random.gamma(shape, scale = 1.0 / rate)

        samples = np.vstack([samples, [u, v]])
    return samples

In [None]:
import torch
from torch.nn.functional import softplus as sp

# Define data and the optimization problem
data = torch.tensor(x)
parameters = ... # Set also some initial values here
opt = torch.optim.Adam(parameters, lr=...)
M = ...
nIter = ...

# Priors
u_prior = torch.distributions.Gamma(u_alpha, u_beta)
v_prior = torch.distributions.Gamma(v_alpha, v_beta)

    
ELBOS = list()
for i in range(nIter):
    opt.zero_grad()

    # Define approximation
    q_uv = ...

    # Obtain samples
    uv = ...
    
    # Compute ELBO
    ELBO = ...
    
    # We need a minimization problem but ELBO is to be maximized
    loss = -ELBO
    loss.backward()
    ELBOS.append(ELBO.item())
    opt.step()
    
# Plotting goes here
# 1. Convergence plot
# 2. Illustration of the posterior

# Problem 2: Variational autoencoder

The code below implements most parts of a **variational autoencoder for the MNIST digits data** and also downloads the data for you on the first run. We use two-dimensional representations (K=2) for ease of plotting, fully connected neural networks for all components of the model, and normal likelihood for simplicity. A better model using convolutional networks, higher-dimensional representations and more suitable likelihoods would follow the exact same general algorithm.

You are free to consult external sources (including ones that provide code) to understand the model better, but remember to mention what you looked at.

Complete the implementation by
1. Specifying the prior distribution of the latent variables as a normal distribution with zero mean and unit covariance
2. Computing the parameters of the approximation for the set of samples in the current mini-batch. Remember that the standard deviation has to be positive.
3. Forming the actual approximation and obtaining samples from it using .rsample()
4. Computing the ELBO, using a normal likelihood with fixed standard deviation (obs_sigma). Since every data point is independent both in terms of the prior and the likelihood, it is probably easiest to write the expression for a single data point and then take the mean of those. By taking the mean instead of the sum, you get numbers that are comparable across batch sizes. Remember that our observations have D=784 dimensions and you need to sum over those in the log-likelihood part.
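
The shape bookkeeping in step 4 is easy to get wrong, so here is a small sketch of it using random stand-ins. It is not the intended solution: the data are random numbers, a single linear layer plays the role of the decoder, and the closed-form KL divergence from torch.distributions is used in place of the sample-based term $\log p(z) - \log q(z|x)$, which would work as well.

```python
import torch
from torch.distributions import Normal, kl_divergence
from torch.nn.functional import softplus

torch.manual_seed(0)
B, D, K = 4, 784, 2          # small stand-in batch; D and K as in the exercise
obs_sigma = 0.1

# Random stand-ins for the data and the encoder outputs (not the real networks)
x = torch.rand(B, D)
mu_approx = torch.randn(B, K)
sigma_approx = softplus(torch.randn(B, K))   # softplus keeps the std positive

p_z = Normal(torch.zeros(K), torch.ones(K))
q_z_x = Normal(mu_approx, sigma_approx)
z = q_z_x.rsample()                          # shape (B, K)

# A single linear layer stands in for the decoder here
decoder = torch.nn.Linear(K, D)
x_mean = decoder(z)                          # shape (B, D)

# Log-likelihood: sum over the D pixels, one value per data point
log_lik = Normal(x_mean, obs_sigma).log_prob(x).sum(dim=1)   # shape (B,)
# KL(q || p): sum over the K latent dimensions
kl = kl_divergence(q_z_x, p_z).sum(dim=1)                    # shape (B,)

ELBO_for_one_point = log_lik - kl
loss = -torch.mean(ELBO_for_one_point)
```

Both terms end up with one entry per data point, so taking the mean afterwards gives a value that does not depend on the batch size.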

Then run the code and inspect how it works. You should see the ELBO improve, and the mean reconstructions of the images should look somewhat reasonable. If this is not the case, try to guess what is wrong and debug your code. You can also change the parameters (number of hidden layers etc.) if you want to further improve the results.

In [None]:
import torch
import torch.nn as nn
import torchvision
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Set hyperparameters of the model and optimization
K = 2
obs_sigma = 0.1
batch_size = 50
numEpoch = 20 
lr = 0.001

# MNIST data 
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                             ])),
  batch_size=batch_size, shuffle=True)

# Prior distribution for latent variables
p_z = ...

# Encoder and decoder specifications. Both are fully connected networks, so no CNN magic here
D = 28*28
H = 40
encoder_mu = nn.Sequential(nn.Linear(D,H), nn.ReLU(),
                           nn.Linear(H,H), nn.ReLU(),
                           nn.Linear(H,K,bias=True))
encoder_sigma = nn.Sequential(nn.Linear(D,H), nn.ReLU(),
                              nn.Linear(H,H), nn.ReLU(),
                              nn.Linear(H,K,bias=True))
decoder = nn.Sequential(nn.Linear(K,H), nn.ReLU(),
                        nn.Linear(H,H), nn.ReLU(),
                        nn.Linear(H,D,bias=True))

# Optimize over parameters of all networks
params = list(encoder_mu.parameters()) + list(encoder_sigma.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=lr)

elbos = []
for i in tqdm(range(numEpoch)):
    batches = iter(train_loader)

    epochloss = 0.
    for j in range(len(batches)):
        optimizer.zero_grad()

        # Next batch of samples
        batch_data, batch_targets = next(batches)
        x = batch_data.reshape((batch_size,-1))
    
        # Form parameters of approximation
        mu_approx = ...
        sigma_approx = ...
        
        # Sample from approximation
        q_z_x = ...
        z = ...

        # Find mean parameters of observed data
        x_mean = ...
    
        # ELBO
        ELBO_for_one_point = ...
        loss = - torch.mean(ELBO_for_one_point)
        # Accumulate a plain float so that no computation graph is kept alive
        epochloss += loss.item()
    
        loss.backward()
        optimizer.step()
    elbos.append(-epochloss/len(batches))

Below you find code for making the plots you need. Feel free to improve the plots if you want to show something else, for example illustrate the variance of the representations.

In [None]:
# Convergence plot
plt.rcParams["figure.figsize"] = (10, 5)
plt.plot(elbos)
plt.title("Convergence plot")
plt.xlabel('Epoch')
_ = plt.ylabel('ELBO')
plt.show()


# Illustration of some samples, showing the mean of the reconstruction
plt.rcParams["figure.figsize"] = (10, 10)
# Note: Uses the values from the last iteration of the algorithm, so this shows some
# examples in the last minibatch
for sam in range(8):
    plt.subplot(4,4,sam*2+1)
    plt.imshow(x[sam,:].reshape(28,28))

    plt.subplot(4,4,sam*2+2)
    plt.imshow(x_mean[sam,:].detach().reshape(28,28))
plt.show()
    
    
# 2D visualization of the data in the latent space
# Do this for 10,000 examples, forming a batch of those and computing q() for them
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('files/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                             ])),
  batch_size=10000, shuffle=True)

batches = iter(train_loader)
batch_data, batch_targets = next(batches)
x = batch_data.reshape((10000,-1))

# Find the parameters of the approximation in the same way as during optimization
mu_approx = ...
sigma_approx = ...

plt.rcParams["figure.figsize"] = (10, 10)
for c in range(10):
    plt.plot(mu_approx.detach()[batch_targets==c,0], mu_approx.detach()[batch_targets==c,1], '.', alpha=0.8)
    # Perhaps add here some way of illustrating the variance of the embedding
    #plt.plot(...)

plt.title("Latent representation")
plt.show()

# Problem 3: Normalizing flows as variational approximation

Read through the paper "Sylvester Normalizing Flows for Variational Inference" by van den Berg et al. (UAI, 2018) available at http://auai.org/uai2018/proceedings/papers/156.pdf and watch the 15-minute conference presentation explaining the paper available at https://www.youtube.com/watch?v=VeYyUcIDVHI&ab_channel=UAI2018. Note that this is not the first paper that proposed using normalizing flows for variational inference, but I chose this because it is easier to read and understand.

Answer the following questions. If you use the notebook to write the answers, please use the 'Markdown' mode for the cell and write equations in LaTeX notation inside dollar signs. Hand-written answers are also fine, and illustrations that help understand the concepts are appreciated.
1. Explain briefly **how the Sylvester flow works** -- explain it also using mathematical notation and describe its main characteristics. No need to go through any proofs or the details for different special cases, but focus on explaining the main principle.
2. Tell how we can **use the flow as a variational approximation**, explaining it in words while also providing the details. You can either describe the details mathematically or write python-like pseudocode where you explain how specific quantities are being transformed, what we optimize for, etc.
3. What do you think the result would be if you applied this for our gamma-Poisson problem studied throughout the exercises? What would the resulting posterior approximation most likely be? Do you see challenges in using this for that problem?