<a href="https://colab.research.google.com/github/sg5g10/VI_masterclass/blob/main/ADVI_masterclass_Linear_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch.distributions import Normal, Gamma, Bernoulli

import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange

# Set default type to float64 (instead of float32): every tensor created below
# without an explicit dtype (simulated data, variational parameters) will be
# double precision.
torch.set_default_dtype(torch.float64)

# ADVI
The algorithm's main steps are as follows:
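1. Sample $\epsilon \sim \mathcal{N}(0,1)$
2. Generate $\xi=\mu + \exp(\log\sigma)\,\epsilon = \mu + \sigma\epsilon$ (we optimise $\log\sigma$ so that $\sigma$ stays positive)
3. Support transform, if needed, using $\theta=\mathcal{T}(\xi)$
4. Evaluate the Monte Carlo ELBO estimate $\mathcal{L}_{MC}(\mu,\log\sigma)$; when a transform is used, this must include the log-Jacobian $\log\left|\det \partial\theta/\partial\xi\right|$ (for $\theta=\exp(\xi)$ it is simply $\xi$)
5. Use automatic differentiation to compute $\nabla_{\mu,\log\sigma}\mathcal{L}_{MC}(\mu,\log\sigma)$ and take a stochastic gradient (ascent) step on the ELBO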

## A re-usable class for the variational approximation. 

This handles the parameterisation, sampling (with the reparameterisation trick) and log-density evaluations. Note that PyTorch distributions that support the reparameterisation trick expose it through `rsample()` (unlike `sample()`), so gradients flow from the samples back to the variational parameters.



In [2]:
from posixpath import sameopenfile
# Variational Approximation
class VarApprox():
    """Mean-field Gaussian variational approximation for one parameter site.

    The variational family is a diagonal Gaussian q(xi) = N(m, exp(log_s)^2) on
    the unconstrained space. If ``suppTrans`` is True, the parameter is positive
    and is obtained as theta = exp(xi). Samples and log densities are then
    reported on that constrained (positive) space.

    Attributes:
        lam: leaf tensor stacking [m, log_s] (row 0 is the mean, row 1 the log
            standard deviation); this is what the optimiser updates.
        size: dimension of the parameter.
        suppTrans: whether the exp support transform is applied.
    """

    def __init__(self, size, m=None, log_s=None, suppTrans=True):
        """Args:
            size: dimension of the parameter.
            m: optional initial mean of the unconstrained Gaussian; drawn from
                N(0, 1) if None.
            log_s: optional initial log standard deviation; drawn from N(0, 1)
                if None.
            suppTrans: if True, map samples to the positive reals with exp.
        """
        if m is None:
            # Set the mean for the unconstrained variational distribution.
            m = torch.randn(size)

        if log_s is None:
            # Set the log standard deviation for the unconstrained variational
            # distribution.
            log_s = torch.randn(size)

        # Variational parameters. detach() guarantees ``lam`` is a leaf tensor,
        # so enabling requires_grad also works when the caller passes tensors
        # that are part of an autograd graph.
        self.lam = torch.stack([torch.as_tensor(m), torch.as_tensor(log_s)]).detach()
        self.lam.requires_grad_(True)

        # Dimension of the variational parameters
        self.size = size

        # Define support constraint
        self.suppTrans = suppTrans

    def dist(self):
        """Unconstrained variational distribution -- always a Gaussian."""
        return torch.distributions.Normal(self.lam[0], self.lam[1].exp())

    def rsample(self, n=torch.Size([20])):
        """Draw reparameterised samples; gradients flow back to ``lam``.

        Args:
            n: sample shape prepended to the parameter shape (default: 20 draws).

        Returns:
            Tensor of shape (*n, size); strictly positive when ``suppTrans``.
        """
        # Same as: self.lam[0] + torch.randn(n) * self.lam[1].exp()
        sample = self.dist().rsample(n)
        if self.suppTrans:
            return torch.exp(sample)  # Transform to positive real
        return sample

    def log_q(self, real):
        """Log density of the approximation at ``real``, summed over the last dim.

        ``real`` lives on the same space as the output of ``rsample``. When
        ``suppTrans`` is True, this is the density of theta = exp(xi), i.e. a
        log-normal density. By the change-of-variables formula:

            log q_theta(theta) = log q_xi(log theta) - log theta

        The ``- log theta`` term is the log-Jacobian that ADVI requires. Without
        it, an ELBO whose log prior is evaluated on theta would be biased.

        Args:
            real: samples of shape (..., size).

        Returns:
            Tensor of shape (...): one log density per sample.
        """
        if self.suppTrans:
            xi = torch.log(real)
            return (self.dist().log_prob(xi) - xi).sum(dim=-1)
        return self.dist().log_prob(real).sum(dim=-1)


## First example: Linear regression

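$$
\begin{aligned}
\beta_k &\sim \mathcal{N}(0,1), \quad k = 1,\dots,K\\
\sigma &\sim \mathrm{Gamma}(1,1) \quad \text{(shape 1, rate 1)}\\
y_n &\sim \mathcal{N}(\boldsymbol{X}_n\boldsymbol{\beta}, \sigma^2), \quad n = 1,\dots,N
\end{aligned}
$$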

Define the necessary functions for computing the following quantities: 
1. ``log_prior`` 
2. ``loglike``
3. ``log_q``
4. ``ELBO``

In [3]:
def log_prior(params):
  """Joint log prior density, one value per Monte Carlo sample.

  beta_k ~ N(0, 1) and sig ~ Gamma(1, 1). Both densities are evaluated at the
  sampled values (sig on its positive, constrained scale) and summed over the
  parameter dimension.
  """
  beta_prior = Normal(0, 1)
  sig_prior = Gamma(1, 1)
  beta_part = beta_prior.log_prob(params['beta']).sum(dim=1)
  sig_part = sig_prior.log_prob(params['sig']).sum(dim=1)
  return beta_part + sig_part

def loglike(y, x, params):
  """Gaussian log-likelihood of all observations, one value per MC sample.

  Each sampled coefficient vector gives the mean X @ beta; sig is the noise
  standard deviation. The point-wise log densities are summed over observations.
  """
  mean = params['beta'] @ x.T
  noise_sd = params['sig']
  pointwise = Normal(mean, noise_sd).log_prob(y)
  return pointwise.sum(dim=-1)

def log_q(var_approx, params):
  """Total log variational density: the sum of every site's ``log_q`` at its sample."""
  return sum((var_approx[name].log_q(params[name]) for name in var_approx), 0.0)

def elbo(y, x, var_approx):
  """Monte Carlo estimate of the ELBO, averaged over the drawn samples.

  Each site draws reparameterised samples via ``rsample`` (20 by default).
  For each draw we form log p(y | params) + log p(params) - log q(params),
  and the result is the mean over the draws.
  """
  draws = {name: approx.rsample() for name, approx in var_approx.items()}
  per_draw = loglike(y, x, draws) + log_prior(draws) - log_q(var_approx, draws)
  return per_draw.mean()

### Generate some simulated data

In [None]:
# Generate data
# Seed first so the simulated data set (and the random VarApprox
# initialisations that follow) are reproducible on Restart & Run All.
torch.manual_seed(0)
N = 1000
# Design matrix: an intercept column of ones plus one standard-normal covariate.
x = torch.stack([torch.ones(N), torch.randn(N)], -1)
k = x.shape[1]
beta = torch.tensor([2., -3.])  # true regression coefficients
sig = 0.5  # true noise standard deviation
# Observed data need no gradients, so a plain sample is enough.
y = Normal(x.matmul(beta), sig).sample()

# Plot data
plt.scatter(x[:, 1].numpy(), y.numpy())
plt.xlabel("x")
plt.ylabel("y")
plt.title("Simulated data")
plt.show()

### Maximise the ELBO by stochastic gradient ascent

The function below runs the stochastic gradient optimisation, using Monte Carlo estimates of the ELBO gradient. Plain (vanilla) SGD is rarely used in practice, so we use the ``Adam`` variant instead.

In [5]:
def run_advi(_y, _x, approx, max_iter=5000, lr=0.1, seed=1):
  """Fit the variational parameters by maximising the MC ELBO with Adam.

  Args:
    _y: observed responses.
    _x: design matrix, one row per observation.
    approx: dict mapping site name -> VarApprox; each ``lam`` is optimised in place.
    max_iter: number of optimisation steps.
    lr: Adam learning rate.
    seed: torch RNG seed set before optimising, for reproducible MC noise.

  Returns:
    List with the per-observation ELBO estimate at every iteration.
  """
  # Optimise the sites of ``approx`` itself (not a notebook global), so the
  # function works for whatever dict is passed in.
  optimizer = torch.optim.Adam([approx[site].lam for site in approx], lr=lr)
  elbo_hist = []

  torch.manual_seed(seed)

  # Progress bar for SGD. max_iter determines the final iteration.
  # mininterval=1 refreshes the bar at most once per second.
  iters = trange(max_iter, mininterval=1)
  N = _x.shape[0]
  # Stochastic gradient descent on the negative per-observation ELBO.
  for _ in iters:
    loss = -elbo(_y, _x, approx) / N
    elbo_hist.append(-loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print progress bar
    iters.set_description('ELBO: {}'.format(elbo_hist[-1]), refresh=False)
  return elbo_hist

### Press the inference button!
1. Instantiate the ``VarApprox`` class. 

2. Run VI

In [6]:
# One variational factor per parameter site: beta is real-valued (no support
# transform); sig must be positive, so it keeps the default exp transform.
var_approx = {'beta': VarApprox(size=k, suppTrans=False), 'sig': VarApprox(size=1)}
# NB: despite its name, ``loss`` holds the per-observation ELBO history
# returned by run_advi.
loss = run_advi(y, x, var_approx)

ELBO: -0.7744201054638852: 100%|██████████| 5000/5000 [00:22<00:00, 223.16it/s]


In [None]:
# Plot ELBO history (``loss`` holds the per-observation ELBO returned by run_advi)
BURN_IN = 500  # early iterations to drop when zooming in on convergence

plt.plot(loss)
plt.title('complete elbo history')
plt.xlabel('iteration')
plt.ylabel('ELBO / N')
plt.show()

# Plot ELBO history after the first BURN_IN iterations, keeping the real
# iteration numbers on the x-axis.
plt.plot(range(BURN_IN, len(loss)), loss[BURN_IN:])
plt.title('tail of elbo history')
plt.xlabel('iteration')
plt.ylabel('ELBO / N')
plt.show()

In [None]:
# Inspect posterior
nsamps = 1000
sig_post = var_approx['sig'].rsample([nsamps]).detach().numpy()
print('True beta: {}'.format(beta.detach().numpy()))
print('beta mean: {}'.format(var_approx['beta'].lam[0].detach().numpy()))
print('beta sd: {}'.format(var_approx['beta'].lam[1].exp().detach().numpy()))
print()

print('True sigma: {}'.format(sig))
print('sig mean: {} | sig sd: {}'.format(sig_post.mean(), sig_post.std()))

In [None]:
#@title Exercise: Implement mini-batch ADVI by changing the ``run_advi``, ``loglike`` and ``elbo`` functions appropriately.
def run_advi(_y, _x, approx, minibatch_size=100, max_iter=5000, lr=0.1, seed=1):
  """Mini-batch ADVI: maximise a sub-sampled MC ELBO with Adam.

  Args:
    _y: observed responses.
    _x: design matrix, one row per observation.
    approx: dict mapping site name -> VarApprox; each ``lam`` is optimised in place.
    minibatch_size: number of observations per step. Sampling is with
      replacement only if this exceeds the data size.
    max_iter: number of optimisation steps.
    lr: Adam learning rate.
    seed: torch RNG seed; it fixes both the MC noise and the mini-batches.

  Returns:
    List with the per-observation ELBO estimate at every iteration.
  """
  # Optimise the sites of ``approx`` itself, not a notebook global.
  optimizer = torch.optim.Adam([approx[site].lam for site in approx], lr=lr)
  elbo_hist = []

  torch.manual_seed(seed)

  # Progress bar for SGD. max_iter determines the final iteration.
  # mininterval=1 refreshes the bar at most once per second.
  iters = trange(max_iter, mininterval=1)
  N = _x.shape[0]
  # Loop-invariant: only resample with replacement if the batch exceeds the data.
  sample_with_replacement = minibatch_size > N
  # Stochastic gradient descent
  for _ in iters:
    # Draw indices with torch's RNG so ``seed`` also makes the mini-batches
    # reproducible (torch.manual_seed does not seed np.random).
    if sample_with_replacement:
      batch_idx = torch.randint(N, (minibatch_size,))
    else:
      batch_idx = torch.randperm(N)[:minibatch_size]
    # Slice the function's own data arguments, not the notebook globals.
    loss = -elbo(_y[batch_idx], _x[batch_idx], approx, full_data_size=N) / N

    elbo_hist.append(-loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print progress bar
    iters.set_description('ELBO: {}'.format(elbo_hist[-1]), refresh=False)
  return elbo_hist

def loglike(y, x, params, full_data_size=None):
  """Gaussian log-likelihood of a (mini-)batch, rescaled to the full data set.

  Args:
    y: responses of the batch.
    x: design-matrix rows of the batch.
    params: dict with sampled 'beta' and 'sig', as drawn by ``VarApprox.rsample``
      in this notebook (one row per MC sample).
    full_data_size: total number of observations N. The batch log-likelihood is
      averaged over the batch and multiplied by N, which estimates the
      full-data log-likelihood without bias when the batch is drawn uniformly.
      If None, no rescaling is applied and the plain sum over the batch is
      returned (the same as the full-batch ``loglike``), so older call sites
      that omit this argument keep working.

  Returns:
    One log-likelihood value per MC sample.
  """
  beta = params['beta']
  sig = params['sig']
  pointwise = Normal(beta.matmul(x.T), sig).log_prob(y)
  if full_data_size is None:
    return pointwise.sum(dim=-1)
  return pointwise.mean(dim=-1) * full_data_size

def elbo(y, x, var_approx, full_data_size=None):
  """Monte Carlo ELBO estimate from a (mini-)batch of data.

  Args:
    y: responses of the batch.
    x: design-matrix rows of the batch.
    var_approx: dict mapping site name -> VarApprox.
    full_data_size: total number of observations N, used to rescale the batch
      log-likelihood. Defaults to the batch size, i.e. the batch is treated as
      the whole data set (equivalent to the full-batch ``elbo``), so callers
      that omit it keep working.

  Returns:
    Scalar tensor: the ELBO estimate averaged over the MC draws.
  """
  if full_data_size is None:
    # No rescaling requested: the batch is the whole data set.
    full_data_size = y.shape[-1]
  params = {site: var_approx[site].rsample() for site in var_approx}

  out = loglike(y, x, params, full_data_size)
  out = out + log_prior(params) - log_q(var_approx, params)

  return out.mean()

# Fresh approximation (beta unconstrained, sig positive via exp), fitted with
# the mini-batch version of run_advi.
var_approx = dict(beta=VarApprox(size=k, suppTrans=False), sig=VarApprox(size=1))
loss = run_advi(y, x, approx=var_approx)

# Inspect posterior: sig from Monte Carlo draws on its positive scale; beta's
# mean and sd read straight off the variational parameters.
nsamps = 1000
sig_post = var_approx['sig'].rsample(torch.Size([nsamps])).detach().numpy()
print(f'\n True beta: {beta.detach().numpy()}')
print(f"beta mean: {var_approx['beta'].lam.detach()[0].numpy()}")
print(f"beta sd: {var_approx['beta'].lam.detach()[1].exp().numpy()}")
print()

print(f'True sigma: {sig}')
print(f'sig mean: {sig_post.mean()} | sig sd: {sig_post.std()}')