In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import numpy as np
import torch
import pyro
# Show which pyro release produced the results below
display(pyro.__version__)

# For reproducibility: seed before anything random is drawn
pyro.set_rng_seed(12345)

# Toy regression data: five inputs and noisy targets around the line y = 2x - 1
x = torch.randn(5)
noise = torch.randn(len(x))
y = 2*x - 1 + 0.5*noise

# Stochastic Variational Inference with  [Pyro](https://pyro.ai/)

Previously we learned 

- the fundamental ideas of variational inference
- how to write probabilistic models in `pyro`

In this lesson we will continue our `pyro` tutorial focused on performing [Stochastic Variational Inference](https://www.jmlr.org/papers/volume14/hoffman13a/hoffman13a.pdf) (SVI). 

SVI allows variational inference (VI) to scale to large databases by subsampling, i.e. estimating quantities of interest using minibatches of data. Typically we will be interested in first-order methods based on gradient descent to optimize our cost functions: Stochastic gradient descent (SGD). Basically

> VI + SGD = SVI

The unified interface for SVI in `pyro` is located in [`pyro.infer.SVI`](http://docs.pyro.ai/en/stable/inference_algos.html). The parameters of the SVI object are

```python
pyro.infer.SVI(model, # A function that defines our generative model 
               guide, # A function that defines our approximate posterior
               optim, # A "gradient-descent-based" optimizer
               loss, # The cost function: Some variant of the ELBO
               ...
)
```

We already saw how to write models. Now we will focus on how to write guides, using the Bayesian linear regression from the previous pyro lesson as an example

In [None]:
from pyro.distributions import Normal, Uniform

def model_obs(x, y=None):
    """Generative model for the Bayesian linear regression example.

    Draws the slope "w", the intercept "b" and the noise scale "s" from
    their priors, then draws (or conditions on) the site "y" inside a plate
    whose size is ``len(x)``.

    Args:
        x: Inputs of the regression.
        y: Observed targets, passed to the "y" site as ``obs``. With the
            default ``None`` the site is sampled instead of conditioned on.

    Returns:
        The value of the "y" sample site.
    """
    # Priors of the latent variables
    slope_prior = Normal(0.0, 10.0)
    intercept_prior = Normal(0.0, 10.0)
    noise_prior = Uniform(0.01, 10.0)

    w = pyro.sample("w", slope_prior)
    b = pyro.sample("b", intercept_prior)
    s = pyro.sample("s", noise_prior)

    # Likelihood: one Normal per entry of x
    with pyro.plate('dataset', size=len(x)):
        likelihood = Normal(x*w + b, s)
        return pyro.sample("y", likelihood, obs=y)

## Guide function

The guide represents our approximate posterior $q_\nu(\theta)$, i.e. it has to specify the distribution for the posterior of the parameters of the model. 

More technically

- Every `pyro.sample` in the model that represents a latent variable has to be in the guide (using the same name)
- The primitive [`pyro.param`](https://docs.pyro.ai/en/stable/primitives.html#pyro.primitives.param) is used to register the hyperparameters of these latent variables
- The arguments of the guide have to be the same as those in the model
- The function does not need to return anything

The registered parameters correspond to $\nu$. These are the values that we will optimize later

For the Bayesian linear regression $\theta = (w, b, s)$. We will create a fully factorized (mean-field) normal guide

$$
q_\nu(w,b,s) = q_{\nu_w}(w)q_{\nu_b}(b) q_{\nu_s}(s) = \mathcal{N}(w|\mu_w, \sigma_w^2) \mathcal{N}(b|\mu_b, \sigma_b^2) \mathcal{N}(s|\mu_s, 0.05^2)
$$

The hyperparameters are $\nu = (\mu_w, \sigma_w, \mu_b, \sigma_b, \mu_s)$

In [None]:
from torch.distributions import constraints

def guide(x, y=None):
    """Fully factorized Normal guide (approximate posterior) for ``model_obs``.

    Registers the variational parameters with ``pyro.param`` and draws one
    sample per latent site ("w", "b", "s"), using the same site names as the
    model. The arguments mirror those of the model but are not used here.

    Args:
        x: Inputs of the regression (unused).
        y: Observed targets (unused).

    Returns:
        None. The sample statements are only needed for their side effects.
    """
    # slope
    w_loc = pyro.param("w_loc", torch.tensor(0.))
    w_scale = pyro.param("w_scale", torch.tensor(1.), constraint=constraints.positive)
    pyro.sample("w", Normal(w_loc, w_scale))
    # intercept
    b_loc = pyro.param("b_loc", torch.tensor(0.))
    b_scale = pyro.param("b_scale", torch.tensor(1.), constraint=constraints.positive)
    pyro.sample("b", Normal(b_loc, b_scale))
    # noise standard deviation (fixed guide scale of 0.05, only the location is learned)
    s_loc = pyro.param("s_loc", torch.tensor(1.), constraint=constraints.positive)
    # `constraint` is an option of pyro.param, not of pyro.sample: extra keyword
    # arguments of pyro.sample are forwarded to the distribution call, so it
    # must not be passed here.
    # NOTE: Normal is unbounded while the model prior for "s" is
    # Uniform(0.01, 10.0); a sample outside that interval has zero prior density.
    pyro.sample("s", Normal(s_loc, torch.tensor(0.05)))
    

## Cost function

In the previous lesson we studied the Evidence Lower Bound (ELBO)

$$
\begin{align}
\hat \nu &= \text{arg}\max_\nu \mathcal{L}(\nu) \nonumber \\
&= \text{arg}\max_\nu \int q_\nu(\theta) \log \frac{p(\mathcal{D}|\theta) p (\theta)}{q_\nu(\theta)} d\theta
\end{align}
$$

where

- The model function defines $p(\mathcal{D}|\theta) p (\theta)$ 
- The guide function defines $q_\nu(\theta)$ 

Pyro offers several versions of the [ELBO](https://docs.pyro.ai/en/stable/inference_algos.html#module-pyro.infer.elbo)

- `Trace_ELBO`: Default ELBO. Reduces variance of the gradients using "Rao-Blackwellization"
- `TraceEnum_ELBO`: Performs exhaustive enumeration for discrete variables
- `TraceMeanField_ELBO`: Assumes Mean-field structure. Reduce variance of gradients using analytical KL when possible

We will study the importance of gradient variance later

## Training

The main method of the SVI object is 

- `svi.step(*args)`: Performs a gradient step, similar to the `backward()` plus `step()` in pytorch, and returns an estimate of the loss (the negative ELBO)

The `step()` method receives the inputs for guide and model as arguments

In this example we select the default ELBO and, as optimizer, `ClippedAdam` (SGD with adaptive learning rate and gradient clipping)

In [None]:
# Activate additional debug checks of model/guides
pyro.enable_validation(True)

# Start from an empty parameter store
pyro.clear_param_store()

optimizer = pyro.optim.ClippedAdam({"lr": 0.01})
elbo_loss = pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)
svi = pyro.infer.SVI(model=model_obs, guide=guide, optim=optimizer, loss=elbo_loss)

# Quantities tracked during training: the value returned by svi.step
# (stored under "ELBO") followed by the variational parameters
param_names = ["ELBO", "w_loc", "b_loc", "s_loc", "w_scale", "b_scale"]
param_evolution = {name: [] for name in param_names}

# One panel with one (initially empty) line per tracked quantity
fig, ax = plt.subplots(1, 6, figsize=(10, 2.5), dpi=80, tight_layout=True)
lines = []
for ax_, name in zip(ax, param_names):
    lines.append(ax_.plot([], [])[0])
    ax_.set_title(name)

In [None]:
# Run SVI and refresh the live plot of the tracked quantities.
n_steps = 3000
refresh_every = 100

for k in tqdm(range(n_steps)):
    # svi.step performs one gradient step and returns the loss estimate
    param_evolution["ELBO"].append(svi.step(x, y))
    for name in param_names[1:]:
        param_evolution[name].append(pyro.param(name).item())

    # Refresh periodically and once more after the last step, so the figure
    # always ends on the final state
    if k % refresh_every == 0 or k == n_steps - 1:
        # After step k the histories hold k+1 values (steps 0..k)
        steps = np.arange(k + 1)
        for line, name in zip(lines, param_names):
            line.set_data(steps, param_evolution[name][:k + 1])
        for ax_ in ax.ravel():
            ax_.relim()
            ax_.autoscale_view()
        fig.canvas.draw()

## Predictive posterior

In [None]:
# Predictive distribution built from the model and the trained guide.
# (The `posterior_samples=sampler.get_samples()` argument is the alternative
# left by the author for samples coming from a sampler instead of a guide.)
predictive = pyro.infer.Predictive(
    model_obs,
    guide=guide,
    num_samples=1000,
    return_sites=("w", "b", "s", "y"),
)

# Grid of test inputs
x_test = np.linspace(-5, 5, num=100).astype('float32')
predictive_samples = predictive(torch.from_numpy(x_test))

# Median and 5%-95% band of "y", reduced over axis 0
y_samples = predictive_samples["y"]
med = y_samples.median(axis=0).values.numpy()
qua = y_samples.quantile(torch.tensor([0.05, 0.95]), axis=0).numpy()

fig, ax = plt.subplots(figsize=(7, 3), tight_layout=True)
ax.plot(x_test, med)
ax.fill_between(x_test, qua[0], qua[1], alpha=0.5);

# Training data with error bars given by the median of the sampled "s"
noise_scale = predictive_samples["s"].median().numpy()
ax.errorbar(x.numpy(), y.numpy(), xerr=0, yerr=noise_scale,
            fmt='none', c='k', zorder=100);

## Posterior of the parameters

We can plot the approximate posterior of the parameters using the samples returned by `Predictive`

In [None]:
import corner
# Stack the samples of "b", "w" and "s" and transpose for corner
site_order = ("b", "w", "s")
posterior_draws = np.stack(
    [predictive_samples[site].detach().numpy() for site in site_order]
).T

figure = corner.corner(
    posterior_draws,
    smooth=1.,
    bins=20,
    quantiles=[0.16, 0.5, 0.84],
    labels=["b", "w", "s"],
    show_titles=True,
    title_kwargs={"fontsize": 12},
)

And we can get the name and value of the hyper-parameters using

In [None]:
# Print the name and current value of every registered parameter
param_store = pyro.get_param_store()
for name in param_store.keys():
    print(name, pyro.param(name))

## Autoguides

Given a model, `pyro` offers functions to generate guides automatically from it. These are found in the [`pyro.infer.autoguide`](https://docs.pyro.ai/en/stable/infer.autoguide.html) module

For example the guide we previously wrote is roughly equivalent to `AutoDiagonalNormal`, a guide where the latent variables are normal and independent

Other interesting auto guides are

- `AutoMultivariateNormal`: Adds a correlation matrix for the latent variables 
- `AutoLowRankMultivariateNormal`: Similar to the previous one, but with a low rank covariance
- `AutoDelta`: Returns the MAP
- `AutoLaplaceApproximation`: Multivariate normal guide centered on the MAP, with covariance equal to the inverse of the Hessian of the negative log posterior at the MAP
- `AutoNormalizingFlow`: Uses a sequence of bijective transformations starting from a Gaussian to obtain a more flexible distribution (more on this in future lessons)

Let's test some of them

In [None]:
from pyro.infer.autoguide import AutoDelta, AutoLaplaceApproximation, AutoDiagonalNormal, AutoMultivariateNormal

pyro.enable_validation(True)
pyro.clear_param_store()

# Pick one automatic guide; the commented lines are alternatives to try.
# NOTE: this rebinds the name `guide`, replacing the hand-written guide function.
#guide = AutoDelta(model_obs)
#guide = AutoDiagonalNormal(model_obs, init_scale=1e-2)
guide = AutoMultivariateNormal(model_obs, init_scale=1e-2)

svi = pyro.infer.SVI(model=model_obs,
                     guide=guide,
                     optim=pyro.optim.ClippedAdam({"lr": 0.01}),# Optimizer
                     loss=pyro.infer.Trace_ELBO(num_particles=10,
                                                vectorize_particles=True) # Loss function
                    )

fig, ax = plt.subplots(figsize=(6, 3), tight_layout=True)
ELBO = []

n_steps = 3000
refresh_every = 100

for k in tqdm(range(n_steps)):
    # svi.step performs one gradient step and returns the loss estimate
    ELBO.append(svi.step(x, y))

    # Refresh periodically and once more after the last step, so the figure
    # always ends on the final state
    if k % refresh_every == 0 or k == n_steps - 1:
        ax.cla()
        ax.plot(ELBO)
        # cla() wipes the labels, so they are set again on every refresh
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Loss (negative ELBO)")
        fig.canvas.draw()

In [None]:
# Print the name and current value of every parameter in the store
param_store = pyro.get_param_store()
for name in param_store.keys():
    print(name, pyro.param(name))

In [None]:
# Predictive distribution built from the model and the current `guide`
predictive = pyro.infer.Predictive(
    model_obs,
    guide=guide,
    num_samples=1000,
    return_sites=("w", "b", "s", "y"),
)

# Evaluate on a grid of test inputs
x_test = np.linspace(-5, 5, num=100).astype('float32')
x_test_tensor = torch.from_numpy(x_test)
predictive_samples = predictive(x_test_tensor)

import corner
# Stack the samples of "b", "w" and "s" and transpose for corner
latent_draws = [
    predictive_samples[site].detach().numpy() for site in ("b", "w", "s")
]
corner_input = np.stack(latent_draws).T

figure = corner.corner(
    corner_input,
    smooth=1.,
    bins=20,
    quantiles=[0.16, 0.5, 0.84],
    labels=["b", "w", "s"],
    show_titles=True,
    title_kwargs={"fontsize": 12},
)

## Self-study

- Pyro's SVI tutorial part [I](https://pyro.ai/examples/svi_part_i.html) and [II](https://pyro.ai/examples/svi_part_ii.html)
- [Pyro tutorial tips and tricks](https://pyro.ai/examples/svi_part_iv.html)

TODO: Do example on [correlated gaussian](https://jmhldotorg.files.wordpress.com/2013/11/slidescharlesuniversitylaplacevi2013.pdf)