# SVI Part II: Subsampling, Conditional Independence, and Amortization

## The Goal: Scaling SVI to Large Datasets

For a model with $N$ observations, running the `model` and `guide` and constructing the ELBO involves evaluating log probability densities at a cost that scales with $N$. This is a problem if we want to scale to large datasets. Luckily, the ELBO objective naturally supports subsampling, provided that our model and guide have some conditional independence structure that we can take advantage of. For example, when the observations are conditionally independent given the latents, the log likelihood term in the ELBO can be approximated with

$$ \sum_{i=1}^N \log p({\bf x}_i | {\bf z}) \approx  \frac{N}{M}
\sum_{i\in{\mathcal{I}_M}} \log p({\bf x}_i | {\bf z})  $$

where $\mathcal{I}_M$ is a mini-batch of indices of size $M$ with $M<N$ (for a discussion please see references [1,2]). Great, problem solved! But how do we do this in Pyro?
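Before turning to Pyro, let's check the approximation numerically with a plain-Python sketch. Nothing here is Pyro code: the helper names, the coin-flip dataset, and the value `f = 0.6` are assumptions chosen for illustration. The sketch averages the rescaled mini-batch estimate over every possible mini-batch of size $M$ and compares it with the full sum.

```python
import itertools
import math
import random


def bernoulli_log_prob(x, f):
    # log p(x | f) for a single coin flip x in {0, 1}
    return math.log(f) if x == 1 else math.log(1.0 - f)


def full_log_likelihood(data, f):
    # the exact sum over all N observations
    return sum(bernoulli_log_prob(x, f) for x in data)


def subsampled_log_likelihood(data, f, batch_indices):
    # the mini-batch estimate: rescale the partial sum by N / M
    N, M = len(data), len(batch_indices)
    return (N / M) * sum(bernoulli_log_prob(data[i], f) for i in batch_indices)


data = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]  # assumed toy data: 6 heads, 4 tails
f = 0.6                                # assumed value of the latent fairness
M = 3                                  # mini-batch size

exact = full_log_likelihood(data, f)

# one random mini-batch gives a noisy estimate of `exact`
random_batch = random.sample(range(len(data)), M)
noisy_estimate = subsampled_log_likelihood(data, f, random_batch)

# averaging over all mini-batches of size M recovers `exact`
all_batches = list(itertools.combinations(range(len(data)), M))
estimates = [subsampled_log_likelihood(data, f, b) for b in all_batches]
mean_estimate = sum(estimates) / len(estimates)
```

Any single mini-batch estimate is noisy, but `mean_estimate` agrees with `exact` up to floating-point error, which is what makes the rescaled sum a sensible stochastic stand-in for the full log likelihood.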

## Marking Conditional Independence in Pyro

To do this sort of thing in Pyro, the user first needs to make sure that the model and guide are written in such a way that Pyro can leverage the relevant conditional independencies. Let's see how this is done. Pyro provides two language primitives for marking conditional independencies: `irange` and `iarange`. Let's start with the simpler of the two.

### `irange`

Let's return to the example we used in the previous tutorial [**INSERT LINK**]. For convenience let's replicate the main logic of `model` here:

In [None]:
def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.beta, alpha0, beta0)
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.observe("obs_{}".format(i), dist.bernoulli, data[i], f)

For this model the observations are conditionally independent given the latent random variable `latent_fairness`. To explicitly mark this in Pyro we basically just need to replace the Python builtin `range` with the Pyro construct `irange`:

In [None]:
def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.beta, alpha0, beta0)
    # loop over the observed data
    for i in pyro.irange("data_loop", data):  # <== we only changed this line
        # observe datapoint i using the bernoulli likelihood
        pyro.observe("obs_{}".format(i), dist.bernoulli,
                     data[i], f)

We see that `pyro.irange` is very similar to `range` with two small differences. First, each invocation of `irange` requires the user to provide a unique name. Second, while `irange` is an iterator over integers just like `range`, instead of taking an explicit length argument like `range(3)`, we instead pass in the iterable we'd like to iterate over (in this case `data`). The only requirement is that the iterable supports `len()`.

So far so good. Pyro can now leverage the conditional independence of the observations given the latent random variable. But how does this actually work? Basically, `pyro.irange` is implemented using a context manager. At every execution of the body of the `for` loop we enter a new (conditional) independence context, which is then exited at the end of the `for` loop body. Let's be very explicit about this:

- because each `pyro.observe` statement occurs within a different execution of the body of the `for` loop, Pyro marks each observation as independent
- that this independence is properly a _conditional_ independence _given_ `latent_fairness` follows because `latent_fairness` is sampled _outside_ of the context of `data_loop`.

Before moving on, let's mention some gotchas to be avoided when using `irange`. Consider the following variant of the above code snippet:

In [None]:
my_reified_list = list(pyro.irange("data_loop", data))
for i in my_reified_list:  
    pyro.observe("obs_{}".format(i), dist.bernoulli, data[i], f)

This will _not_ achieve the desired behavior, since `list()` will enter and exit the `data_loop` context completely before a single `pyro.observe` statement is called. Similarly, the user needs to take care not to leak mutable computations across the boundary of the context manager, as this may lead to subtle bugs.
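To see why the `list()` variant goes wrong, it helps to trace the order of events. The sketch below uses a hypothetical generator `toy_irange` that logs when a context is entered and exited. It is a stand-in for the behavior described above, not Pyro's code, and the `trace_loop` helper and event strings are made up for illustration.

```python
def toy_irange(name, iterable, events):
    # log an "enter" before handing out each index and an "exit" when the
    # generator is resumed (i.e. when the loop body has finished)
    for i in range(len(iterable)):
        events.append("enter {}[{}]".format(name, i))
        try:
            yield i
        finally:
            events.append("exit {}[{}]".format(name, i))


def trace_loop(data, reify):
    events = []
    indices = toy_irange("data_loop", data, events)
    if reify:
        # exhausts the generator up front: every context is entered
        # and exited before the loop body below ever runs
        indices = list(indices)
    for i in indices:
        events.append("observe obs_{}".format(i))
    return events


data = [1.0, 0.0]
direct_events = trace_loop(data, reify=False)
reified_events = trace_loop(data, reify=True)
```

In `direct_events` each `observe` entry sits between its matching `enter` and `exit`. In `reified_events` all the `enter` and `exit` entries come first and the `observe` entries come last, outside of any context.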

### `iarange`

Conceptually `iarange` is the same as `irange` except that it is a vectorized operation. As such it potentially enables large speed-ups compared to the explicit `for` loop that appears with `irange`. Let's see how this would look for our running example. First we need `data` to be in the form of a tensor:

In [None]:
data = Variable(torch.zeros(10, 1))
data[0:6, 0] = torch.ones(6)  # 6 heads and 4 tails

Then we have:

In [None]:
# iarange is used as a context manager (via the pyro namespace, like pyro.irange)
with pyro.iarange('observe_data'):
    pyro.observe('obs', dist.bernoulli, data, f)

Let's compare this to the analogous `irange` construction point-by-point:

- just like `irange`, `iarange` requires the user to specify a unique name.
- note that this code snippet only introduces a single (observed) random variable (namely `obs`), since the entire tensor is considered at once.
- since there is no need for an iterator in this case, there is no need to specify the length of the tensor(s) involved in the independent context.
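To make the second point concrete, let's write out what the single site `obs` contributes. We assume here that the log probability of a batched site is the sum over the elements of the tensor, which is exactly what the conditional independence of the observations amounts to. For the `data` defined above, with 6 heads and 4 tails:

$$ \log p({\bf x} | f) = \sum_{i=1}^{10} \log \mathrm{Bernoulli}(x_i | f) = 6 \log f + 4 \log (1 - f) $$

This is the same total that ten separate `obs_i` sites in the `irange` version would contribute, just computed in a single vectorized evaluation.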

## Subsampling

We now know how to mark conditional independence in Pyro. This is useful in and of itself (see the dependency tracking section in [**INSERT LINK**]), but we'd like to do subsampling so that we can do SVI on large datasets. Depending on the structure of the model and guide, Pyro supports several ways of doing subsampling. Let's go through these one by one.

## Amortization [placeholder: needs to be totally rewritten]

The purpose of the guide (i.e. the variational distribution) is to provide a (parameterized) approximation to the exact posterior $p({\bf z}_{1:T}|{\bf x}_{1:T})$. Actually, there's an implicit assumption here which we should make explicit, so let's take a step back. 
Suppose our dataset $\mathcal{D}$ consists of $N$ sequences 
$\{ {\bf x}_{1:T_1}^1, {\bf x}_{1:T_2}^2, \ldots, {\bf x}_{1:T_N}^N \}$. Then the posterior we're actually interested in is given by 
$p({\bf z}_{1:T_1}^1, {\bf z}_{1:T_2}^2, \ldots, {\bf z}_{1:T_N}^N | \mathcal{D})$, i.e. we want to infer the latents for _all_ $N$ sequences. Even for small $N$ this is a very high-dimensional distribution that will require a very large number of parameters to specify. In particular if we were to directly parameterize the posterior in this form, the number of parameters required would grow (at least) linearly with $N$. One way to avoid this nasty growth with the size of the dataset is *amortization*.

This works as follows. Instead of introducing variational parameters for each sequence in our dataset, we're going to learn a single parametric function $f({\bf x}_{1:T})$ and work with a variational distribution that has the form $\prod_{n=1}^N q({\bf z}_{1:T_n}^n | f({\bf x}_{1:T_n}^n))$. The function $f(\cdot)$, which basically maps a given observed sequence to a set of variational parameters tailored to that sequence, will need to be sufficiently rich to capture the posterior accurately, but now we can handle large datasets without having to introduce an obscene number of variational parameters. This approach has other benefits too: for example, during learning $f(\cdot)$ effectively allows us to share statistical power among different sequences.
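The parameter-count argument can be illustrated with a small plain-Python sketch. Everything in it is an assumption made for illustration: `num_local_params` counts a Gaussian mean and scale per latent, and `ToyAmortizedGuide` is a deliberately tiny stand-in for $f(\cdot)$ that looks at one observation at a time. A realistic $f$ would be much richer (for example a neural network that reads the whole sequence).

```python
import math


def num_local_params(sequence_lengths, latent_dim=1):
    # non-amortized case: every latent z_t^n gets its own mean and scale
    return sum(2 * latent_dim * T for T in sequence_lengths)


class ToyAmortizedGuide:
    """Hypothetical stand-in for f: four shared weights map x_t to (loc, scale)."""

    def __init__(self, w_loc=0.5, b_loc=0.0, w_scale=0.1, b_scale=-1.0):
        self.params = [w_loc, b_loc, w_scale, b_scale]

    def num_params(self):
        return len(self.params)

    def __call__(self, sequence):
        w_loc, b_loc, w_scale, b_scale = self.params
        # exp keeps every scale strictly positive
        return [(w_loc * x + b_loc, math.exp(w_scale * x + b_scale))
                for x in sequence]


guide = ToyAmortizedGuide()
small_dataset = [[0.1, 0.2, 0.3], [1.0, -1.0]]
large_dataset = small_dataset * 1000  # same sequences repeated, N = 2000

local_small = num_local_params([len(s) for s in small_dataset])
local_large = num_local_params([len(s) for s in large_dataset])

# the same guide produces variational parameters for any sequence
variational_params = guide(small_dataset[0])
```

The local parameter count grows in step with the dataset, while `guide.num_params()` stays fixed no matter how many sequences we feed through it.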

## References

[1] `Stochastic Variational Inference`,
<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Matthew D. Hoffman, David M. Blei, Chong Wang, John Paisley

[2] `Auto-Encoding Variational Bayes`,
<br/>&nbsp;&nbsp;&nbsp;&nbsp;
Diederik P. Kingma, Max Welling