<a href="https://colab.research.google.com/github/sg5g10/VI_masterclass/blob/main/ADVI_masterclass_logistic_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch.distributions import Normal, Gamma, Bernoulli

import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
import seaborn as sns
# Seaborn styling. The "paper" context is passed to sns.set itself:
# sns.set (a.k.a. set_theme) re-applies its own `context` argument
# (default "notebook"), so a separate sns.set_context call placed before it
# would be discarded. The rc overrides are applied last and still win.
sns.set(context="paper", font_scale=1, style="white",
        rc={"figure.figsize": (9, 9), "font.size": 16, "axes.titlesize": 16,
            "axes.labelsize": 16, "xtick.labelsize": 15, "ytick.labelsize": 15})

# Set default type to float64 (instead of float32)
torch.set_default_dtype(torch.float64)

# ADVI
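The algorithm's main steps are as follows:
1. Sample $\epsilon \sim \mathcal{N}(0, I)$
2. Affine transform to generate $\xi = \mu + \exp(\log\sigma) \odot \epsilon = \mu + \sigma \odot \epsilon$
3. Support transform, if needed, using $\theta=\mathcal{T}(\xi)$
4. Evaluate the Monte Carlo ELBO estimate $\mathcal{L}_{MC}(\mu,\log\sigma)$; when a support transform is used, it must include the log-Jacobian term $\log\left|\det J_{\mathcal{T}}(\xi)\right|$
5. Use automatic differentiation to compute $\nabla_{\mu,\log\sigma}\mathcal{L}_{MC}(\mu,\log\sigma)$ and take a stochastic-gradient (ascent) step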

## A re-usable class for the variational approximation. 

This handles the parameterisation, sampling (with the reparameterisation trick) and log-density evaluations. Note that PyTorch distributions which support it expose reparameterised sampling through `rsample()` (as opposed to `sample()`), so gradients flow back to the distribution's parameters.



In [2]:
from posixpath import sameopenfile
# Variational Approximation
class VarApprox:
    """Mean-field Gaussian variational approximation for one parameter site.

    The variational family is a fully factorised Gaussian on the
    unconstrained space, parameterised by its means ``m`` and log standard
    deviations ``log_s``. When ``suppTrans`` is True, samples are mapped to
    the positive reals with ``exp``.

    Args:
        size: passed to ``torch.randn`` to draw default initial values
            (an int gives ``size`` independent components).
        m: initial means; drawn from N(0, 1) if None. Expected to have the
            same shape as ``log_s``.
        log_s: initial log standard deviations; drawn from N(0, 1) if None.
        suppTrans: if True, samples are returned as ``exp(xi)`` and
            ``log_q`` maps its input back with ``log`` first.

    Attributes:
        lam: leaf tensor ``stack([m, log_s])`` with ``requires_grad=True``;
            this is the tensor an optimiser should update.
        size: the ``size`` argument as given.
        suppTrans: the ``suppTrans`` argument as given.
    """

    def __init__(self, size, m=None, log_s=None, suppTrans=True):
        if m is None:
            # Mean of the unconstrained variational distribution.
            m = torch.randn(size)
        if log_s is None:
            # Log standard deviation of the unconstrained distribution.
            log_s = torch.randn(size)

        # detach() makes `lam` a fresh leaf tensor even when the caller's
        # initial values are part of an autograd graph; requires_grad can only
        # be switched on for leaf tensors.
        self.lam = torch.stack([torch.as_tensor(m), torch.as_tensor(log_s)]).detach()
        self.lam.requires_grad_(True)

        self.size = size
        self.suppTrans = suppTrans

    def dist(self):
        """Return the unconstrained Gaussian q(xi) built from the current ``lam``."""
        return torch.distributions.Normal(self.lam[0], self.lam[1].exp())

    def rsample(self, n=torch.Size([20])):
        """Draw reparameterised samples with sample shape ``n``.

        Equivalent to ``lam[0] + randn(n) * lam[1].exp()`` (followed by ``exp``
        when ``suppTrans`` is True), so gradients flow back to ``lam``.
        """
        sample = self.dist().rsample(n)
        return torch.exp(sample) if self.suppTrans else sample

    def log_q(self, real):
        """Log density of q summed over the site's components.

        ``real`` holds samples as returned by ``rsample`` (constrained values
        when ``suppTrans`` is True). The result has ``real``'s shape with the
        last (component) dimension summed out, so both batched and single
        samples are supported.

        NOTE(review): when ``suppTrans`` is True this is the density of
        ``log(real)`` under the unconstrained Gaussian; no log-Jacobian term
        (``-log(real).sum(-1)``) is included, so an ELBO that uses a prior on
        the constrained space must add that correction elsewhere. Confirm
        before using positive-constrained sites.
        """
        if self.suppTrans:
            real = torch.log(real)
        return self.dist().log_prob(real).sum(dim=-1)


In [3]:
def log_q(var_approx, params):
    """Total variational log density of one set of sampled parameters.

    Adds up every site's own ``log_q`` contribution, pairing each site of
    ``var_approx`` with the sample stored under the same key in ``params``.
    Returns 0.0 when there are no sites.
    """
    total = 0.0
    for site, approx in var_approx.items():
        total += approx.log_q(params[site])
    return total

In [4]:
def elbo(y, x, var_approx):
    """Monte Carlo estimate of the ELBO, averaged over the drawn samples.

    One batch of samples is drawn per site via its ``rsample`` method, and
    the estimate is mean(log-likelihood + log-prior - log q) over them.
    """
    # rsample() is expected to be reparameterised, so the estimate stays
    # differentiable with respect to the variational parameters.
    params = {site: var_approx[site].rsample() for site in var_approx}

    estimate = loglike(y, x, params)
    estimate += log_prior(params)
    estimate -= log_q(var_approx, params)
    return estimate.mean()

To better visualise the ELBO, we can use a moving average. Specifically, we are going to use a small class to code-up an exponential moving average.

In [5]:
class EMAMetric:
    """Exponentially weighted moving average of a stream of scalars.

    Each ``step`` applies ``val <- gamma * val + (1 - gamma) * x``, starting
    from 0 (no start-up bias correction is applied).
    """

    def __init__(self, gamma=.99):
        super().__init__()
        self._gamma = gamma
        self._val = 0.

    def step(self, x):
        """Fold ``x`` into the average and return the updated value.

        Tensors are detached and converted to NumPy first, so the average
        never holds on to an autograd graph.
        """
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        weighted_prev = self._gamma * self._val
        self._val = weighted_prev + (1 - self._gamma) * x
        return self._val

    @property
    def val(self):
        """Current value of the moving average."""
        return self._val

## Example: Bayesian logistic regression

$$ \beta_k\sim \mathcal{N}(0,1)\\
\gamma \sim \mathcal{N}(0,1)\\
y_n \sim \operatorname{Bern}(\operatorname{Logit}^{-1}(\boldsymbol{X}_n\boldsymbol{\beta} + \gamma))$$

Define the necessary functions for this model:
1. ``log_prior`` 
2. ``loglike``


In [6]:
#@title Exercise: Implement the ``log_prior`` and ``loglike`` functions.
def log_prior(params):
    """Log prior density of each sampled parameter set.

    The coefficients ``beta`` and the intercept ``gamma`` both have
    independent standard-normal priors. Component log densities are summed
    along dim 1, giving one value per sample row.
    """
    std_normal = Normal(0, 1)
    lp_beta = std_normal.log_prob(params['beta']).sum(dim=1)
    lp_gamma = std_normal.log_prob(params['gamma']).sum(dim=1)
    return lp_beta + lp_gamma

def loglike(y, x, params):
    """Bernoulli log-likelihood of ``y`` for each sampled parameter set.

    Args:
        y: binary responses; expected shape ``(N,)``.
        x: design matrix; expected shape ``(N, k)``.
        params: dict with ``'beta'`` (expected ``(S, k)``) and ``'gamma'``
            (expected ``(S, 1)``), one row per Monte Carlo sample.

    Returns:
        ``log p(y | beta, gamma)`` summed over the last (observation)
        dimension, i.e. one value per sample row.
    """
    beta = params['beta']
    gamma = params['gamma']
    # (S, k) @ (k, N) -> (S, N); gamma broadcasts across the N columns.
    logit_prob = beta.matmul(x.T) + gamma
    # Parameterising by logits lets Bernoulli evaluate log-probabilities
    # stably, without forming log(sigmoid(.)) explicitly.
    return Bernoulli(logits=logit_prob).log_prob(y).sum(dim=-1)

In [None]:
#@title Exercise: Generate simulated data.
N = 1000
# Two independent standard-normal covariates, shape (N, 2).
x = torch.stack([torch.randn(N), torch.randn(N)], -1)
k = x.shape[1]
# True coefficients (one per covariate, so sized by k) and intercept.
beta = torch.randn(k)
gamma = torch.randn(1)
logit_true = beta @ x.T + gamma
y = Bernoulli(logits=logit_true).sample()

# Plot the binary responses against the second covariate.
plt.scatter(x[:, 1].numpy(), y.numpy())
plt.xlabel("x")
plt.ylabel("y")
plt.show()

### Here is the run_advi routine from the last exercise

In [17]:
def run_advi(_y, _x, approx, max_iter=3000, lr=0.1):
    """Fit a variational approximation by maximising the ELBO with Adam.

    Args:
        _y: observed responses, forwarded to ``elbo``.
        _x: design matrix, forwarded to ``elbo``; its first dimension is used
            as the number of observations N that scales the loss.
        approx: dict mapping site name -> variational approximation; each
            site's ``lam`` tensor is optimised in place.
        max_iter: number of optimisation steps.
        lr: Adam learning rate.

    Returns:
        List with one entry per iteration: the exponential moving average of
        the loss ``-ELBO / N`` (i.e. the smoothed *negative* ELBO per
        observation).
    """
    # Optimise the approximation that was passed in, not a same-named global.
    optimizer = torch.optim.Adam([approx[site].lam for site in approx], lr=lr)
    elbo_hist = []
    loss_metric = EMAMetric()
    torch.manual_seed(1)  # fixes the Monte Carlo noise for reproducible runs
    iters = trange(max_iter, mininterval=1)
    N = _x.shape[0]
    for _ in iters:
        # Dividing by N keeps the loss on a per-observation scale.
        loss = -elbo(_y, _x, approx) / N
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_metric.step(loss)
        iters.set_description('ELBO: {}'.format(-loss_metric.val), refresh=False)
        elbo_hist.append(loss_metric.val)
    return elbo_hist

### Press the inference button!
1. Instantiate the ``VarApprox`` class. Gotcha: in this example we don't need a support transformation, because all parameters take values on the whole real line.

2. Run VI

In [None]:
#@title Exercise: Instantiate the variational approx.
# beta has one component per covariate (k) and gamma is a single intercept;
# both are unconstrained reals, so no support transform is applied.
var_approx = {'beta': VarApprox(size=k, suppTrans=False), 'gamma': VarApprox(size=1, suppTrans=False)}
# run_advi returns its per-iteration, EMA-smoothed loss history.
loss = run_advi(y, x, var_approx)

In [None]:
# Plot ELBO history
# NOTE(review): run_advi appends the EMA of its loss (-ELBO / N), so these
# curves show the smoothed *negative* per-observation ELBO despite the titles.
plt.plot(loss)
plt.title('complete elbo history')
plt.show()

# Plot ELBO history from index 500 onwards (presumably to drop the start-up
# transient of the zero-initialised moving average)
plt.plot(loss[500:])
plt.title('tail of elbo history')

In [None]:
# Inspect posterior: compare the true data-generating values with the fitted
# variational means and standard deviations (sd = exp(log_s)).
nsamps = 1000  # NOTE(review): not used by any visible code

print('True beta: {}'.format(beta.detach().numpy()))
print('beta mean: {}'.format(var_approx['beta'].lam[0].detach().numpy()))
print('beta sd: {}'.format(var_approx['beta'].lam[1].exp().detach().numpy()))
print()

# gamma is the intercept of the logistic model (not a scale parameter), so it
# is labelled by its own name.
print('True gamma: {}'.format(gamma.detach().numpy()))
print('gamma mean: {}'.format(var_approx['gamma'].lam[0].detach().numpy()))
print('gamma sd: {}'.format(var_approx['gamma'].lam[1].exp().detach().numpy()))
print()

### Comparison with MCMC.

This is a simple model (computationally), so we can easily run MCMC to infer the posteriors. Let's do so.

For this purpose we will use a probabilistic programming library (PPL), such as Stan. In this case we will use Pyro, a very powerful PPL built on PyTorch.



In [None]:
!pip3 install pyro-ppl

In [22]:
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS



Probabilistic models in Pyro are specified as Python functions ``model(*args, **kwargs)`` that generate observed data from latent variables using special primitive functions whose behavior can be changed by Pyro’s internals depending on the high-level computation being performed.

Specifically, the different mathematical pieces of ``model()`` are encoded via the mapping:

1.   latent random variables   ``pyro.sample``
2.   observed random variables   ``pyro.sample`` with the obs keyword argument

In the code below I have used a loop over the datapoints, but this can be avoided using ``pyro.plate`` notation.

In [23]:
def model(y, x):
    """Pyro model for Bayesian logistic regression.

    beta ~ N(0, I) with one entry per column of x, gamma ~ N(0, 1), and each
    y[n] ~ Bernoulli(logits = (beta @ x.T + gamma)[n]).

    Args:
        y: observed binary responses, one per row of x.
        x: design matrix; assumes shape (N, k) -- TODO confirm for other callers.
    """
    beta = pyro.sample("beta", dist.Normal(torch.zeros(x.shape[1]), torch.ones(x.shape[1])))
    gamma = pyro.sample("gamma", dist.Normal(torch.zeros(1), torch.ones(1)))

    logit_prob = beta.matmul(x.T) + gamma 

    # One named observation site per datapoint ("obs_0", "obs_1", ...);
    # a vectorised pyro.plate would avoid the Python loop.
    for n in range(len(y)):
        # observe datapoint n using the Bernoulli
        # likelihood Bernoulli(logit)
        pyro.sample("obs_{}".format(n), dist.Bernoulli(logits=logit_prob[n]), obs=y[n])

In [None]:
# NUTS sampler on the Pyro model; jit_compile=True asks Pyro to JIT-trace the
# model's log-density computation.
nuts_kernel = NUTS(model, jit_compile=True)
# 2 chains; in Pyro, warmup_steps and num_samples are counted per chain.
mcmc = MCMC(
        nuts_kernel,
        num_samples=1000,
        warmup_steps=500,
        num_chains=2,
    )
mcmc.run(y, x)  # arguments are forwarded to model(y, x)
# Print per-parameter posterior summaries with 50% intervals.
mcmc.summary(prob=0.5)

Collect the parameter samples from the MCMC run, and draw parameter samples from the variational approximation we have learnt.

In [25]:
# Posterior draws as 2-D arrays whose columns are [beta_1, ..., beta_k, gamma].
# MCMC: the retained draws returned by get_samples() (chains pooled by default).
mc_params = np.concatenate((mcmc.get_samples()["beta"], mcmc.get_samples()["gamma"]),axis=1)

# Variational: 100 draws per site from the fitted approximation, detached from
# the autograd graph before conversion to NumPy.
vb_params = np.concatenate((var_approx['beta'].rsample([100]).detach().numpy(),
                  var_approx['gamma'].rsample([100]).detach().numpy()), 
                 axis=1)

### Compare the two inference methods

In [None]:
# Compare the marginal posteriors from ADVI (magenta) and NUTS (orange) for
# each parameter; the black vertical line marks the true generating value.
# Columns of vb_params / mc_params are ordered [beta_1, beta_2, gamma].
param_names = [r"$\beta_1$", r"$\beta_2$", r"$\gamma$"]
# Flatten the true values into one 1-D array. np.array([*beta, gamma]) would
# mix 0-d tensors with the shape-(1,) gamma, which NumPy cannot stack evenly.
real_params = np.concatenate([beta.detach().numpy().ravel(),
                              gamma.detach().numpy().ravel()])
n_params = len(param_names)
for i, name in enumerate(param_names):
    # One KDE panel per parameter, stacked vertically.
    plt.subplot(n_params, 1, i + 1)
    plt.axvline(real_params[i], linewidth=2.5, color='black')
    sns.kdeplot(vb_params[:, i], color='magenta', linewidth=2.5, label='Variational')
    sns.kdeplot(mc_params[:, i], color='orange', linewidth=2.5, label='MCMC')
    # KDE curves are probability densities, not counts.
    plt.ylabel('Density')
    plt.xlabel(name)
    if i == 0:
        # A single shared legend above the first panel.
        plt.legend(bbox_to_anchor=(0., 1.02, 1., .102), loc='lower center', ncol=2, fontsize=18)
plt.subplots_adjust(hspace=0.7)
plt.tight_layout()