# Stochastic Variational Inference in Pyro

Pyro has been designed with particular attention paid to supporting stochastic variational inference as a general-purpose inference algorithm. Let's see how we go about doing variational inference in Pyro.

## Setup

We're going to assume we've already defined our model in Pyro. For examples of how to do this, see [**SEE LINK**]. The model has observations ${\bf x}$ and latent random variables ${\bf z}$ as well as parameters $\theta$. It has a joint probability density of the form

$$p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})$$

The model is given as a stochastic function `model(*args, **kwargs)`, which in the general case takes arguments. The different pieces of `model()` are encoded via the mapping:

1. observations $\Longleftrightarrow$ `pyro.observe`
2. latent random variables $\Longleftrightarrow$ `pyro.sample`
3. parameters $\Longleftrightarrow$ `pyro.param`

## Learning

In this context our criterion for learning a good model will be maximizing the log evidence, i.e. we want to find the value of $\theta$ given by

$$\theta_{\rm{max}} = \underset{\theta}{\operatorname{argmax}} \log p_{\theta}({\bf x})$$

where the log evidence $\log p_{\theta}({\bf x})$ is given by

$$\log p_{\theta}({\bf x}) = \log \int\! d{\bf z}\; p_{\theta}({\bf x}, {\bf z})$$

In the general case this is a doubly difficult problem. First, even for a fixed $\theta$, the integral over the latent random variables $\bf z$ is often intractable. Second, even if we know how to calculate the log evidence for all values of $\theta$, maximizing the log evidence as a function of $\theta$ will in general be a difficult non-linear optimization problem.
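To make the evidence integral concrete, here is a minimal sketch in plain Python (not Pyro). It assumes a one-dimensional toy model that is not part of this tutorial: $z \sim \mathcal{N}(\theta, 1)$ and $x | z \sim \mathcal{N}(z, 1)$. For this conjugate model the integral happens to be available in closed form, $x \sim \mathcal{N}(\theta, 2)$, so a brute-force quadrature can be checked against it. All function names are hypothetical.

```python
import math


def log_normal(value, mean, std):
    """Log density of a univariate Normal(mean, std) evaluated at value."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(std)
            - 0.5 * ((value - mean) / std) ** 2)


def log_joint(x, z, theta):
    """log p_theta(x, z) = log p(x | z) + log p_theta(z) for the assumed toy model."""
    return log_normal(x, z, 1.0) + log_normal(z, theta, 1.0)


def log_evidence_quadrature(x, theta, lo=-20.0, hi=20.0, n=20001):
    """Approximate log of the integral of p_theta(x, z) over z (trapezoidal rule)."""
    h = (hi - lo) / (n - 1)
    total = 0.0
    for i in range(n):
        z = lo + i * h
        weight = 0.5 if i in (0, n - 1) else 1.0  # endpoints get half weight
        total += weight * math.exp(log_joint(x, z, theta))
    return math.log(total * h)


def log_evidence_exact(x, theta):
    """Closed form for this conjugate model: x ~ Normal(theta, sqrt(2))."""
    return log_normal(x, theta, math.sqrt(2.0))
```

A grid like this is only feasible because $z$ is one-dimensional here; the cost of such a grid grows exponentially with the dimension of ${\bf z}$, which is one way to see why the integral is usually intractable.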

In addition to finding $\theta_{\rm{max}}$, we would like to calculate the posterior over the latent variables $\bf z$:

$$ p_{\theta_{\rm{max}}}({\bf z} | {\bf x}) = \frac{p_{\theta_{\rm{max}}}({\bf x} , {\bf z})}{
\int \! d{\bf z}\; p_{\theta_{\rm{max}}}({\bf x} , {\bf z}) } $$

Note that the denominator of this expression is the (usually intractable) evidence.

## The ELBO

Variational inference offers a scheme for finding $\theta_{\rm{max}}$ and computing an approximation to the posterior $p_{\theta_{\rm{max}}}({\bf z} | {\bf x})$. The basic idea is that we introduce a parameterized distribution $q_{\phi}({\bf z})$, where $\phi$ are known as the variational parameters. This distribution is called the variational distribution in most of the literature, and in the context of Pyro it's called the **guide**. The guide will serve as an approximation to the posterior.

Just like the model, the guide is encoded as a stochastic function `guide()` that contains `pyro.sample` and `pyro.param` statements. It does _not_ contain `pyro.observe` statements, since the guide needs to be a properly normalized distribution. Note that Pyro enforces that `model()` and `guide()` have the same call signature, i.e. both callables should take the same arguments. 

Learning will be set up as an optimization problem where each iteration of training takes a step in $\theta$-$\phi$ space that moves the guide closer to the exact posterior.
To do this we need to define an appropriate objective function. A simple derivation (for example see reference [1]) yields what we're after: the evidence lower bound (ELBO). The ELBO, which is a function of both $\theta$ and $\phi$, is defined as an expectation w.r.t. samples from the guide:

$${\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [ 
\log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z})
\right]$$

By assumption we can compute the log probabilities inside the expectation. And since the guide is by assumption a parametric distribution we can sample from, we can compute Monte Carlo estimates of this quantity. Crucially, the ELBO is a lower bound to the log evidence. For all choices of $\theta$ and $\phi$ we have that 

$$\log p_{\theta}({\bf x}) \ge {\rm ELBO} $$
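As an illustration of the Monte Carlo estimate and the bound, the sketch below uses plain Python rather than Pyro and assumes a toy model that is not from this tutorial: $z \sim \mathcal{N}(\theta, 1)$, $x | z \sim \mathcal{N}(z, 1)$, with a Gaussian guide $q_{\phi}(z) = \mathcal{N}(\mu, \sigma)$, so $\phi = (\mu, \sigma)$. The function names and default arguments are hypothetical.

```python
import math
import random


def log_normal(value, mean, std):
    """Log density of a univariate Normal(mean, std) evaluated at value."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(std)
            - 0.5 * ((value - mean) / std) ** 2)


def elbo_estimate(x, theta, mu, sigma, num_samples=20000, seed=0):
    """Monte Carlo estimate of E_q[log p_theta(x, z) - log q_phi(z)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)  # sample the latent variable from the guide
        log_p = log_normal(x, z, 1.0) + log_normal(z, theta, 1.0)  # log joint
        log_q = log_normal(z, mu, sigma)  # log guide density
        total += log_p - log_q
    return total / num_samples


def log_evidence_exact(x, theta):
    """Closed-form log evidence for the assumed model: x ~ Normal(theta, sqrt(2))."""
    return log_normal(x, theta, math.sqrt(2.0))
```

For a guide that differs noticeably from the posterior, such as $\mu = 0$, $\sigma = 1$ with $x = 1.5$ and $\theta = 0.3$, the estimate returned by `elbo_estimate` should sit clearly below `log_evidence_exact`, up to Monte Carlo noise.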

So if we take gradient steps to maximize the ELBO, we will also be pushing the log evidence higher. Furthermore, it can be shown that the gap between the ELBO and the log evidence is given by the KL divergence between the guide and the posterior:

$$ \log p_{\theta}({\bf x}) - {\rm ELBO} = 
\rm{KL}\!\left( q_{\phi}({\bf z}) \lVert p_{\theta}({\bf z} | {\bf x}) \right) $$
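For completeness, this identity follows in a few lines from Bayes' rule, $p_{\theta}({\bf z} | {\bf x}) = p_{\theta}({\bf x}, {\bf z}) / p_{\theta}({\bf x})$, together with the fact that $\log p_{\theta}({\bf x})$ does not depend on ${\bf z}$:

$$
\begin{aligned}
{\rm KL}\!\left( q_{\phi}({\bf z}) \lVert p_{\theta}({\bf z} | {\bf x}) \right)
&= \mathbb{E}_{q_{\phi}({\bf z})} \left[ \log q_{\phi}({\bf z}) - \log p_{\theta}({\bf z} | {\bf x}) \right] \\
&= \mathbb{E}_{q_{\phi}({\bf z})} \left[ \log q_{\phi}({\bf z}) - \log p_{\theta}({\bf x}, {\bf z}) \right] + \log p_{\theta}({\bf x}) \\
&= \log p_{\theta}({\bf x}) - {\rm ELBO}
\end{aligned}
$$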

This KL divergence is a particular (non-negative) measure of 'closeness' between two distributions. So, for a fixed $\theta$, as we take steps in $\phi$ space that increase the ELBO, we decrease the KL divergence between the guide and the posterior, i.e. we move the guide towards the posterior. In the general case we take gradient steps in both $\theta$ and $\phi$ space so that the guide and model play chase, with the guide tracking a moving posterior $p_{\theta}({\bf z} | {\bf x})$. Perhaps somewhat surprisingly, this often works quite well in practice.
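The relationship between the ELBO, the log evidence, and the KL divergence can be checked numerically in a case where every quantity has a closed form. The sketch below is plain Python, not Pyro, and assumes a toy model that does not appear in this tutorial: $z \sim \mathcal{N}(\theta, 1)$, $x | z \sim \mathcal{N}(z, 1)$, with guide $q_{\phi}(z) = \mathcal{N}(\mu, \sigma)$. Under that assumption the exact posterior is $\mathcal{N}((x + \theta)/2, \sqrt{1/2})$. All names are hypothetical.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def log_evidence_exact(x, theta):
    """log p_theta(x) for the assumed model, where x ~ Normal(theta, variance 2)."""
    return -0.5 * LOG_2PI - 0.5 * math.log(2.0) - 0.25 * (x - theta) ** 2


def elbo_exact(x, theta, mu, sigma):
    """Closed-form ELBO for the guide Normal(mu, sigma)."""
    expected_log_lik = -0.5 * LOG_2PI - 0.5 * ((x - mu) ** 2 + sigma ** 2)
    expected_log_prior = -0.5 * LOG_2PI - 0.5 * ((mu - theta) ** 2 + sigma ** 2)
    entropy = 0.5 * LOG_2PI + 0.5 + math.log(sigma)  # entropy of the guide
    return expected_log_lik + expected_log_prior + entropy


def kl_normal(mu_q, sigma_q, mu_p, sigma_p):
    """KL( Normal(mu_q, sigma_q) || Normal(mu_p, sigma_p) ) for univariate normals."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)


def exact_posterior(x, theta):
    """Mean and standard deviation of p_theta(z | x) for the assumed model."""
    return 0.5 * (x + theta), math.sqrt(0.5)


def gap_and_kl(x, theta, mu, sigma):
    """Return (log evidence - ELBO, KL(guide || posterior)); the two should agree."""
    post_mu, post_sigma = exact_posterior(x, theta)
    gap = log_evidence_exact(x, theta) - elbo_exact(x, theta, mu, sigma)
    return gap, kl_normal(mu, sigma, post_mu, post_sigma)
```

Calling `gap_and_kl` with any valid `mu` and `sigma` should return two numbers that agree up to floating-point error, and both shrink towards zero as the guide parameters approach those returned by `exact_posterior`.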

So at a high level variational inference is easy: all we need to do is define a guide and compute gradients of the ELBO. Actually, computing gradients for general model and guide pairs leads to some complications; see tutorial [**INSERT LINK**]. For the purposes of this tutorial, let's consider that a solved problem and look at the support that Pyro provides for doing variational inference.
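Before turning to Pyro's machinery, the overall recipe can be sketched by hand in a case simple enough that the ELBO gradients are available exactly. This is plain Python, not Pyro, and it sidesteps the gradient-estimation complications mentioned above by assuming a toy model that is not from this tutorial: $z \sim \mathcal{N}(\theta, 1)$, $x | z \sim \mathcal{N}(z, 1)$, a single observation $x$, and a guide $\mathcal{N}(\mu, \sigma)$ with $\sigma = e^{\rho}$ so that it stays positive. For this model the ELBO is, up to a constant, $-\frac{1}{2}(x - \mu)^2 - \frac{1}{2}(\mu - \theta)^2 - \sigma^2 + \log \sigma$. The names, step count, and learning rate are hypothetical choices.

```python
import math


def elbo_gradients(x, theta, mu, log_sigma):
    """Exact gradients of the closed-form ELBO for the assumed toy model."""
    sigma = math.exp(log_sigma)
    grad_theta = mu - theta                  # d ELBO / d theta
    grad_mu = x + theta - 2.0 * mu           # d ELBO / d mu
    grad_log_sigma = 1.0 - 2.0 * sigma ** 2  # d ELBO / d log(sigma), via chain rule
    return grad_theta, grad_mu, grad_log_sigma


def fit(x, num_steps=2000, learning_rate=0.1):
    """Plain gradient ascent on the ELBO in both theta and phi = (mu, log_sigma)."""
    theta, mu, log_sigma = 0.0, 0.0, 0.0
    for _ in range(num_steps):
        g_theta, g_mu, g_log_sigma = elbo_gradients(x, theta, mu, log_sigma)
        theta += learning_rate * g_theta
        mu += learning_rate * g_mu
        log_sigma += learning_rate * g_log_sigma
    return theta, mu, math.exp(log_sigma)
```

With a single observation, the log evidence of this toy model is maximized at $\theta = x$, where the exact posterior is $\mathcal{N}(x, \sqrt{1/2})$, so `fit(x)` should return values close to `(x, x, sqrt(0.5))`. In realistic models the gradients are not available in closed form and must be estimated from samples.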

## The `SVI` Class

In Pyro the machinery for doing variational inference is encapsulated in the `SVI` class.

## References

[1] `Automated Variational Inference in Probabilistic Programming`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
David Wingate, Theo Weber

[2] `Black Box Variational Inference`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Rajesh Ranganath, Sean Gerrish, David M. Blei

[3] `Auto-Encoding Variational Bayes`,<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Diederik P Kingma, Max Welling