## Introduction to Stochastic Variational Inference in Pyro

```python
import torch
import pyro
```

Pyro has been designed with particular attention paid to supporting stochastic variational inference (SVI) as a general-purpose inference algorithm.

Now let’s establish some notation. The model has observations ${\bf x}$, latent random variables ${\bf z}$, and parameters $\theta$. It has a joint probability density of the form
$$
p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})
$$

We assume that the various probability distributions $p_i$ that make up $p_{\theta}({\bf x}, {\bf z})$ have the following properties:

1. we can sample from each $p_i$
2. we can compute the pointwise log pdf $\log p_i$
3. $p_i$ is differentiable w.r.t. the parameters $\theta$
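Here is a minimal sketch of the three properties. It assumes PyTorch's `torch.distributions`, on which Pyro's distributions are built. The parameter name `theta` and the unit-scale Normal are illustrative choices and do not come from any particular model:

```python
import torch
from torch.distributions import Normal

# Illustrative parameter theta: the mean of a unit-scale Normal p_theta(x).
theta = torch.tensor(0.5, requires_grad=True)
p = Normal(theta, 1.0)

# 1. Sampling: draw three values (sample() does not track gradients).
x = p.sample((3,))

# 2. Pointwise log pdf: one log-density value per sample.
log_px = p.log_prob(x)

# 3. Differentiability: backpropagate the summed log pdf to theta.
log_px.sum().backward()
grad_theta = theta.grad  # for a unit-scale Normal this equals sum(x - theta)
```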

### Model Learning

In this context our criterion for learning a good model will be maximizing the log evidence, i.e. we want to find the value of $\theta$ given by

$$\theta_{\rm{max}} = \underset{\theta}{\operatorname{argmax}} \log p_{\theta}({\bf x})$$


where the log evidence $\log p_{\theta}({\bf x})$ is given by

$$\log p_{\theta}({\bf x}) = \log \int\! p_{\theta}({\bf x}, {\bf z})\, d{\bf z}$$

In addition to finding $\theta_{\rm{max}}$, we would like to calculate the posterior over the latent variables ${\bf z}$:

$$p_{\theta_{\rm{max}}}({\bf z} | {\bf x}) = \frac{p_{\theta_{\rm{max}}}({\bf x} , {\bf z})}{
\int \! d{\bf z}\; p_{\theta_{\rm{max}}}({\bf x} , {\bf z}) }$$
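When the latent variable is one-dimensional, both integrals above can be computed numerically on a grid. Doing so makes the definitions concrete. This sketch assumes a hypothetical conjugate model, ${\bf z} \sim \mathcal{N}(0, 1)$ and ${\bf x} | {\bf z} \sim \mathcal{N}({\bf z}, 1)$, with no learnable $\theta$. It compares the grid results with the closed-form evidence and posterior mean. Grid integration does not scale beyond a few latent dimensions, and that limitation is what motivates variational inference:

```python
import math
import torch
from torch.distributions import Normal

# Hypothetical conjugate model with no learnable theta:
#   z ~ Normal(0, 1),  x | z ~ Normal(z, 1)
dtype = torch.float64
x_obs = torch.tensor(1.2, dtype=dtype)               # a single observation
z = torch.linspace(-10.0, 10.0, 20001, dtype=dtype)  # grid over the latent
dz = 20.0 / 20000                                    # grid spacing
zero = torch.zeros((), dtype=dtype)
one = torch.ones((), dtype=dtype)

# log p(x, z) = log p(x | z) + log p(z), evaluated at every grid point
log_joint = Normal(z, one).log_prob(x_obs) + Normal(zero, one).log_prob(z)

# log evidence: log of the integral of p(x, z) over z (Riemann sum on the grid)
log_evidence = torch.logsumexp(log_joint, dim=0) + math.log(dz)

# posterior density: joint divided by the evidence, then its mean
posterior = torch.exp(log_joint - log_evidence)
posterior_mean = (z * posterior).sum() * dz

# Closed-form values for this model: p(x) = Normal(x; 0, sqrt(2)),
# p(z | x) = Normal(x / 2, sqrt(1 / 2))
exact_log_evidence = Normal(zero, math.sqrt(2.0) * one).log_prob(x_obs)
exact_posterior_mean = x_obs / 2
```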

Variational inference offers a scheme for finding $\theta_{\rm{max}}$ and computing an approximation to the posterior $p_{\theta_{\rm{max}}}({\bf z} | {\bf x})$. It introduces a parameterized distribution $q_{\phi}({\bf z})$, where $\phi$ are known as the variational parameters. This distribution is called the variational distribution; in the context of Pyro it’s called the **guide**. The **guide** in Pyro serves as an approximation to the posterior.


### Guide

The guide is encoded as a stochastic function **guide()** that contains **pyro.sample** and **pyro.param** statements. However, it does not contain observed data, since the guide needs to be a properly normalized distribution. It is further required to provide a valid joint probability density over all the latent random variables in the model. In Pyro, both model() and guide() should have the same call signature, i.e. both callables should take the same arguments, even if the distributions used in the two cases differ. For example, if the model contains a random variable *z_1*:

```python
def model():
    pyro.sample("z_1", ...)
```
then the guide needs to have a matching sample statement:
```python
def guide():
    pyro.sample("z_1", ...)
```

Once a guide has been specified, we can perform learning and inference, which amounts to an optimization problem: maximizing the evidence lower bound (ELBO). The ELBO is a function of both $\theta$ and $\phi$, defined as an expectation w.r.t. samples from the guide:

$${\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [
\log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z})
\right]$$
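The expectation can be estimated by Monte Carlo: draw ${\bf z}$ from the guide and average $\log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z})$. The sketch below does this in plain PyTorch for a hypothetical conjugate Gaussian model with prior $\mathcal{N}(0, 1)$, likelihood $\mathcal{N}({\bf z}, 1)$, and one observation. Since $\log p_{\theta}({\bf x}) - {\rm ELBO} = \mathrm{KL}(q_{\phi}({\bf z}) \,\|\, p_{\theta}({\bf z}|{\bf x}))$, the estimate matches the log evidence when the guide equals the exact posterior and falls below it otherwise:

```python
import math
import torch
from torch.distributions import Normal

torch.manual_seed(0)
x_obs = torch.tensor(1.2)  # hypothetical single observation

def elbo_estimate(q_loc, q_scale, num_samples=10000):
    q = Normal(q_loc, q_scale)
    z = q.sample((num_samples,))  # samples from the guide
    # log p(x, z) = log p(x | z) + log p(z)
    log_joint = Normal(z, 1.0).log_prob(x_obs) + Normal(0.0, 1.0).log_prob(z)
    return (log_joint - q.log_prob(z)).mean()

log_evidence = Normal(0.0, math.sqrt(2.0)).log_prob(x_obs)
elbo_exact_q = elbo_estimate(0.6, math.sqrt(0.5))  # guide = exact posterior
elbo_poor_q = elbo_estimate(-1.0, 2.0)             # a mismatched guide
```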

### SVI Class

The **SVI** class is a unified interface for stochastic variational inference in Pyro. To use this class you need to provide:
- the model,
- the guide, and
- an optimizer, which is a wrapper for a PyTorch optimizer (discussed below).

```python
from pyro.infer import SVI, Trace_ELBO
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
```

The SVI object provides two methods, **step()** and **evaluate_loss()**:
- The method step() takes a single gradient step and returns an estimate of the loss (i.e. minus the ELBO). 
- The method evaluate_loss() returns an estimate of the loss without taking a gradient step.

Both of these methods accept an optional argument, **num_particles**, which denotes the number of samples used to compute the loss and gradient.

### Optimizers

The module **pyro.optim** provides support for optimization in Pyro. In particular, it provides **PyroOptim**, which is used to wrap PyTorch optimizers and manage optimizers for dynamically generated parameters. **PyroOptim** takes two arguments:
- a constructor for PyTorch optimizers, *optim_constructor*, and
- a specification of the optimizer arguments, *optim_args*.

```python
from pyro.optim import Adam
adam_params = {"lr": 0.005, "betas": (0.95, 0.999)}
optimizer = Adam(adam_params)
```

### Conditional Independence, Subsampling, and Amortization

For a model with $N$ observations, running the model and guide and constructing the ELBO involves evaluating log pdfs whose cost grows at least linearly with $N$, since every observation contributes a term. This is a problem if we want to scale to large datasets. Luckily, the ELBO objective naturally supports subsampling, provided that our model/guide have some conditional independence structure that we can take advantage of. For example, in the case that the observations are conditionally independent given the latents, the log likelihood term in the ELBO can be approximated by

$$\sum_{i=1}^N \log p({\bf x}_i | {\bf z}) \approx  \frac{N}{M}
\sum_{i\in{\mathcal{I}_M}} \log p({\bf x}_i | {\bf z})$$

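where $\mathcal{I}_M$ is a mini-batch of indices of size $M$ with $M < N$. The factor $N/M$ rescales the mini-batch sum so that, when the mini-batch is drawn uniformly at random, its expectation equals the full sum.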
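Here is a quick numerical check of the approximation in plain PyTorch, using a hypothetical dataset and a fixed value of the latent. When averaged over many random mini-batches, the rescaled estimate recovers the full-data log likelihood:

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
N, M = 1000, 50
x = torch.randn(N) + 2.0   # hypothetical dataset of N observations
z = torch.tensor(2.0)      # a fixed value of the latent

log_lik = Normal(z, 1.0).log_prob(x)  # per-datapoint log p(x_i | z)
full_sum = log_lik.sum()

def minibatch_estimate():
    idx = torch.randperm(N)[:M]       # a uniformly random mini-batch I_M
    return (N / M) * log_lik[idx].sum()

# Averaging many rescaled mini-batch estimates approaches the full sum.
estimates = torch.stack([minibatch_estimate() for _ in range(2000)])
mean_estimate = estimates.mean()
```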

To implement this in Pyro, we first need to make sure that the model and guide are written in such a way that Pyro can leverage the relevant conditional independencies. Pyro provides two language primitives for marking conditional independencies: **irange** and **iarange**.

#### iarange
iarange is a context manager for conditionally independent ranges of variables. It is similar to torch.arange() in that it yields an array of indices by which other tensors can be indexed. However, iarange differs from torch.arange() in that:

- It informs inference algorithms that the variables being indexed are conditionally independent. 

```python
with pyro.iarange("name", size) as ind:
    # ...do conditionally independent stuff with ind...
    pass
```
- Additionally, iarange can take advantage of the conditional independence assumptions by subsampling the indices and informing inference algorithms to scale various computed values. This is typically used to subsample minibatches of data:

```python
with pyro.iarange("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100
```

#### irange
irange is a non-vectorized (sequential) version of iarange:

```python
for i in pyro.irange("data", len(data), subsample_size=100):
    batch = data[i]
```

### Subsampling in Pyro
Subsampling allows us to run SVI on large datasets. Depending on the structure of the model and guide, Pyro supports several ways of doing subsampling.

#### 1. Automatic subsampling with irange and iarange
In the simplest case, we get subsampling for free with one or two additional arguments to irange and iarange:

- using irange
```python
for i in pyro.irange("data", len(data), subsample_size=50):
    pyro.sample("obs_{}".format(i), dist.Normal(loc, scale), obs=data[i])
```

- using iarange
```python
with pyro.iarange("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100
    pyro.sample('obs', dist.Normal(loc, scale), obs=batch)
```

**Limitation**: For a sufficiently large dataset, even after a large number of iterations, there’s a nonnegligible probability that some of the datapoints will never have been selected.
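A short calculation shows the size of this effect. Assume each iteration draws a fresh mini-batch of $M$ indices uniformly from $N$. A given datapoint is then missed with probability $1 - M/N$ per iteration, so it is never selected after $T$ iterations with probability $(1 - M/N)^T \approx e^{-MT/N}$. With the illustrative numbers below, the $T$ iterations touch $MT = N$ indices in total, yet roughly 37% of the datapoints have never been seen:

```python
import math

def prob_never_selected(N, M, T):
    # each iteration draws M of N indices uniformly without replacement,
    # so a fixed datapoint is missed with probability (1 - M/N) per iteration
    return (1.0 - M / N) ** T

p = prob_never_selected(N=1_000_000, M=100, T=10_000)
approx = math.exp(-100 * 10_000 / 1_000_000)  # e^{-MT/N}
expected_unseen = 1_000_000 * p               # datapoints never selected
```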


#### 2. Custom subsampling strategies with irange and iarange
We can take control of subsampling by passing our own indices through the subsample argument to irange and iarange:

```python
batchsize = 20
data_size = 100
# draw our own mini-batch of indices (here uniformly, with replacement)
ind = torch.randint(0, data_size, (batchsize,)).long()
with pyro.iarange('data', data_size, subsample=ind):
    obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
```

### Amortization

Consider a model with global ($\beta$) and local (${\bf z}_i$) latent random variables, together with a guide that has local variational parameters $\lambda_i$:
$$p({\bf x}, {\bf z}, \beta) = p(\beta)
\prod_{i=1}^N p({\bf x}_i | {\bf z}_i)\, p({\bf z}_i | \beta)$$
and 
$$q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)$$
For small to medium-sized $N$, using local variational parameters like this can be a good approach. If $N$ is large, however, the fact that the space we’re optimizing over grows with $N$ can be a real problem. One way to avoid this growth with the size of the dataset is amortization.

Amortization works as follows: instead of introducing local variational parameters, we learn a single parametric function $f(\cdot)$ and work with a variational distribution of the form

$$q(\beta) \prod_{i=1}^N q({\bf z}_i | f({\bf x}_i))$$

This approach has other benefits too: for example, during learning $f(\cdot)$ effectively allows us to share statistical power among different datapoints. Note that this is precisely the approach used in the variational autoencoder (VAE).