# Intro to Stochastic Variational Inference (SVI) in Pyro

Pyro has been designed with particular attention paid to supporting stochastic variational inference as a general purpose inference algorithm. Let's see how we go about doing variational inference in Pyro.

## Setup

We're going to assume we've already defined our model in Pyro. As a quick reminder, the model is given as a stochastic function `model(*args, **kwargs)`, which in the general case takes arguments. The different pieces of `model()` are encoded via the mapping:
1. observations <-> `pyro.sample` with the `obs` argument
2. latent random variables <-> `pyro.sample`
3. parameters <-> `pyro.param`

Now let's establish some notation. The model has observations __x__ and latent random variables __z__ as well as parameters θ. It has a joint probability density of the form

\begin{equation}
p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x}|\mathbf{z})p_{\theta}(\mathbf{z})
\end{equation}

We assume that the various probability distributions $p_{i}$ that make up $p_{\theta}(\mathbf{x}, \mathbf{z})$ have the following properties:
1. We can sample from each $p_i$
2. We can compute the pointwise log pdf of $p_i$
3. $p_i$ is differentiable w.r.t. the parameters $\theta$

## Model Learning

In this context our criterion for learning a good model will be maximizing the log evidence, i.e. we want to find the value of $\theta$ given by

\begin{equation}
\theta_{max} = \underset{\theta}{\operatorname{arg max}}{\log p_{\theta}(\mathbf{x})}
\end{equation}

where the log evidence $\log p_{\theta}(\mathbf{x})$ is given by 

\begin{equation}
\log p_{\theta}(\mathbf{x}) = \log \int d\mathbf{z} \; p_{\theta}(\mathbf{x}, \mathbf{z})
\end{equation}

In the general case this is a doubly difficult problem. This is because (even for a fixed $\theta$) the integral over the latent random variables $\mathbf{z}$ is often intractable. Furthermore, even if we know how to calculate the log evidence for all values of $\theta$, maximizing the log evidence as a function of $\theta$ will in general be a difficult non-convex optimization problem.
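For a single one-dimensional latent variable the evidence integral can still be checked numerically. The sketch below is an illustration only, under an assumed toy model: a Beta(10, 10) prior on a coin's heads-probability $z$ and one fixed sequence of 6 heads and 4 tails. It integrates $p(\mathbf{x}, z)$ over $z$ with a midpoint rule and compares the result against the closed form that this particular conjugate model happens to have. All helper names are hypothetical.

```python
import math

def log_beta_fn(a, b):
    """Log of the Beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def joint_density(z, heads, tails, a0=10.0, b0=10.0):
    """p(x, z) = Beta(z; a0, b0) * z^heads * (1 - z)^tails for one fixed flip sequence."""
    log_prior = (a0 - 1) * math.log(z) + (b0 - 1) * math.log(1 - z) - log_beta_fn(a0, b0)
    log_lik = heads * math.log(z) + tails * math.log(1 - z)
    return math.exp(log_prior + log_lik)

def log_evidence_numeric(heads, tails, num_bins=10_000):
    """Midpoint-rule approximation of the log of the integral of p(x, z) over z in (0, 1)."""
    width = 1.0 / num_bins
    total = sum(joint_density((i + 0.5) * width, heads, tails) for i in range(num_bins))
    return math.log(total * width)

def log_evidence_exact(heads, tails, a0=10.0, b0=10.0):
    """Closed form, available only because the Beta prior is conjugate to the likelihood."""
    return log_beta_fn(a0 + heads, b0 + tails) - log_beta_fn(a0, b0)

numeric = log_evidence_numeric(6, 4)
exact = log_evidence_exact(6, 4)
print(f"numeric: {numeric:.6f}  exact: {exact:.6f}")
```

The cost of a grid like this grows exponentially with the dimension of $\mathbf{z}$, which is one way to see why direct integration stops being an option for models with many latent variables.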

In addition to finding $\theta_{max}$, we would like to calculate the posterior over the latent variables $\mathbf{z}$:

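\begin{equation}
p_{\theta_{max}}(\mathbf{z}|\mathbf{x}) = \frac{p_{\theta_{max}}(\mathbf{x}, \mathbf{z})}{\int d\mathbf{z} \; p_{\theta_{max}}(\mathbf{x}, \mathbf{z})}
\end{equation}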

Note that the denominator of this expression is the (usually intractable) evidence. Variational inference offers a scheme for finding $\theta_{max}$ and computing an approximation to the posterior $p_{\theta_{max}}(\mathbf{z}|\mathbf{x})$. Let's see how that works.

## Guide
The basic idea is that we introduce a parameterized distribution $q_{\phi}(\mathbf{z})$, where $\phi$ are known as the variational parameters. This distribution is called the variational distribution in much of the literature, and in the context of Pyro it's called the guide. The guide will serve as an approximation to the posterior.

Just like the model, the guide is encoded as a stochastic function `guide()` that contains `pyro.sample` and `pyro.param` statements. It does not contain observed data, since the guide needs to be a properly normalized distribution. Note that Pyro enforces that `model()` and `guide()` have the same call signature, i.e. both callables should take the same arguments.

Since the guide is an approximation to the posterior $p_{\theta_{max}}(\mathbf{z}|\mathbf{x})$, the guide needs to provide a valid joint probability density over all the latent random variables in the model. Recall that when random variables are specified in Pyro with the primitive statement `pyro.sample`, the first argument denotes the name of the random variable. These names will be used to align the random variables in the model and guide. To be very explicit, if the model contains a random variable `z_1`, then the guide needs to have a matching `sample` statement with the same name.

The distributions used in the two cases can be different, but the names must line up 1-to-1.

Once we've specified a guide, we're ready to proceed to inference. Learning will be set up as an optimization problem where each iteration of training takes a step in $\theta - \phi$ space that moves the guide closer to the exact posterior. To do this we need to define an appropriate objective function.

## ELBO

A simple derivation yields what we're after: the evidence lower bound (ELBO). The ELBO, which is a function of both $\theta$ and $\phi$, is defined as an expectation w.r.t. samples from the guide:

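\begin{equation}
{\rm ELBO} \equiv \mathbb{E}_{q_{\phi}(\mathbf{z})} \left[ \log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z}) \right]
\end{equation}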

By assumption we can compute the log probabilities inside the expectation. And since the guide is assumed to be a parametric distribution we can sample from, we can compute Monte Carlo estimates of this quantity. Crucially, the ELBO is a lower bound to the log evidence, i.e. for all choices of $\theta$ and $\phi$ we have that

\begin{equation}
\log p_{\theta}(\mathbf{x}) \ge {\rm ELBO}
\end{equation}
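The Monte Carlo estimate and the bound can be checked by hand on a small case. The sketch below is a plain-Python illustration, not Pyro's implementation. It assumes a toy coin model (Beta(10, 10) prior, 6 heads and 4 tails, chosen because its log evidence and posterior Beta(16, 14) are available in closed form) and a Beta guide. The function names are hypothetical.

```python
import math
import random

def log_beta_fn(a, b):
    """Log of the Beta function B(a, b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

def beta_log_pdf(z, a, b):
    """Log density of Beta(a, b) evaluated at z."""
    return (a - 1) * math.log(z) + (b - 1) * math.log(1 - z) - log_beta_fn(a, b)

def log_joint(z, heads, tails, a0=10.0, b0=10.0):
    """log p(x, z): Beta(a0, b0) prior plus Bernoulli likelihood of the flips."""
    return beta_log_pdf(z, a0, b0) + heads * math.log(z) + tails * math.log(1 - z)

def elbo_estimate(a_q, b_q, heads, tails, num_samples=20_000, seed=0):
    """Average of log p(x, z) - log q(z) over samples z drawn from the guide Beta(a_q, b_q)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.betavariate(a_q, b_q)
        total += log_joint(z, heads, tails) - beta_log_pdf(z, a_q, b_q)
    return total / num_samples

heads, tails = 6, 4
log_evidence = log_beta_fn(10 + heads, 10 + tails) - log_beta_fn(10, 10)
elbo_initial = elbo_estimate(15.0, 15.0, heads, tails)  # a guide that is slightly off
elbo_exact = elbo_estimate(16.0, 14.0, heads, tails)    # guide equal to the exact posterior
print(f"log evidence:            {log_evidence:.4f}")
print(f"ELBO, Beta(15, 15) guide: {elbo_initial:.4f}")
print(f"ELBO, Beta(16, 14) guide: {elbo_exact:.4f}")
```

When the guide equals the exact posterior, every sample contributes exactly $\log p_{\theta}(\mathbf{x})$, so the estimate has zero variance and the bound is tight. For any other guide the estimate falls below the log evidence, up to Monte Carlo noise.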

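So if we take (stochastic) gradient steps to maximize the ELBO, we will also be pushing the log evidence higher (in expectation). Furthermore, it can be shown that the gap between the ELBO and the log evidence is given by the KL divergence between the guide and the posterior:

\begin{equation}
\log p_{\theta}(\mathbf{x}) - {\rm ELBO} = \rm{KL}\!\left( q_{\phi}(\mathbf{z}) \lVert p_{\theta}(\mathbf{z}|\mathbf{x}) \right)
\end{equation}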
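The statement about the gap follows in a few lines from the definitions above. Since $\log p_{\theta}(\mathbf{x})$ does not depend on $\mathbf{z}$, we can take its expectation under the guide and use $p_{\theta}(\mathbf{x}) = p_{\theta}(\mathbf{x}, \mathbf{z}) / p_{\theta}(\mathbf{z}|\mathbf{x})$:

\begin{equation}
\begin{aligned}
\log p_{\theta}(\mathbf{x}) &= \mathbb{E}_{q_{\phi}(\mathbf{z})} \left[ \log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log p_{\theta}(\mathbf{z}|\mathbf{x}) \right] \\
&= \mathbb{E}_{q_{\phi}(\mathbf{z})} \left[ \log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z}) \right] + \mathbb{E}_{q_{\phi}(\mathbf{z})} \left[ \log q_{\phi}(\mathbf{z}) - \log p_{\theta}(\mathbf{z}|\mathbf{x}) \right] \\
&= {\rm ELBO} + \rm{KL}\!\left( q_{\phi}(\mathbf{z}) \lVert p_{\theta}(\mathbf{z}|\mathbf{x}) \right)
\end{aligned}
\end{equation}

Because the KL divergence is non-negative, this also gives the lower bound, with equality exactly when the guide matches the posterior.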

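The KL divergence is a non-negative measure of closeness between two distributions. For a fixed $\theta$, as we take steps in $\phi$ space that increase the ELBO, we decrease the KL divergence between the guide and the posterior, i.e. we move the guide toward the posterior. In the general case we take gradient steps in both $\theta$ and $\phi$ simultaneously, so that the guide and model both move and the guide is chasing a moving posterior $p_{\theta}(\mathbf{z}|\mathbf{x})$. Despite the moving target, this optimization can be solved to a useful level of approximation for many problems.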

At a high level, then, variational inference is easy: define a guide and compute gradients of the ELBO. There are some complications, but they are beyond the scope of this notebook.
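To make the recipe concrete without any library machinery, here is a minimal plain-Python sketch of stochastic gradient ascent on the ELBO. It is not how Pyro implements `SVI`. It assumes a toy model chosen for illustration ($z \sim \mathcal{N}(0, 1)$, $x | z \sim \mathcal{N}(z, 1)$, so the exact posterior is $\mathcal{N}(x/2, 1/2)$), a Gaussian guide with hypothetical parameters `m` and `rho`, and hand-derived reparameterized gradients.

```python
import math
import random

def fit_gaussian_guide(x, num_steps=5000, lr=0.01, num_particles=10, seed=0):
    """Stochastic gradient ascent on the ELBO for the toy model
    z ~ Normal(0, 1), x | z ~ Normal(z, 1), with guide q(z) = Normal(m, s^2)."""
    rng = random.Random(seed)
    m, rho = 0.0, 0.0  # s = exp(rho) keeps the guide's scale positive
    for _ in range(num_steps):
        s = math.exp(rho)
        grad_m, grad_rho = 0.0, 0.0
        for _ in range(num_particles):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps             # reparameterized sample from the guide
            dlogp_dz = (x - z) - z      # derivative of log p(x, z) w.r.t. z
            # dz/dm = 1; at fixed eps, log q(z) does not depend on m
            grad_m += dlogp_dz
            # dz/drho = s * eps; -log q(z) = rho + const at fixed eps, contributing +1
            grad_rho += dlogp_dz * eps * s + 1.0
        m += lr * grad_m / num_particles
        rho += lr * grad_rho / num_particles
    return m, math.exp(rho)

x_obs = 1.0
m_fit, s_fit = fit_gaussian_guide(x_obs)
print(f"guide:     mean={m_fit:.3f}  std={s_fit:.3f}")
print(f"posterior: mean={x_obs / 2:.3f}  std={math.sqrt(0.5):.3f}")
```

The fitted values fluctuate around the exact posterior parameters because each step uses a noisy gradient estimate. Averaging over more particles or lowering the learning rate reduces that jitter.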

## `SVI` Class

In Pyro the machinery for doing variational inference is encapsulated in the `SVI` class.

The user needs to provide three things: the model, the guide, and an optimizer. We've discussed the model and guide above.

```python
import pyro
from pyro.infer import SVI, Trace_ELBO
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
```

## Optimizers

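The model and guide can be arbitrary stochastic functions, provided they satisfy the conditions above: the guide contains no observed `pyro.sample` statements, the model and guide have the same call signature, and the names of the latent random variables line up.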

(Some further detail on optimizers omitted here, since I don't need to repeat it.)

# Example (my data)

We have a two-sided coin, and the query is whether the coin is fair. Our prior belief is that it's a US coin with some wear.

We encode our prior as a beta distribution, Beta(10, 10), which is symmetric and peaked at 0.5.

Some observations are stored in `data`.

In [13]:
import numpy as np

In [18]:
data = (np.random.random(10) > 0.5).astype(int)

In [19]:
data

array([0, 0, 0, 1, 0, 0, 0, 0, 0, 1])

In [27]:
import pyro.distributions as dist

def model(data):
    # define the hyperparameters that control the beta prior
    α_0 = torch.tensor(10.)
    β_0 = torch.tensor(10.)
    # sample f from the beta distribution prior
    f = pyro.sample("latent_fairness", dist.Beta(α_0, β_0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood Bernoulli(f)
        # each data point is a realization of the Bernoulli distribution with the fairness specified by f
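        # convert the NumPy integer to a float tensor before passing it as the observation
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=torch.tensor(float(data[i])))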

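The next task is to define a guide. The only requirement is that q(f) should be a valid probability distribution over the range [0, 1], since f doesn't make sense outside that range. A simple choice is to use another beta distribution. This is the 'right' choice in this case because the beta distribution is conjugate to the Bernoulli likelihood, so the exact posterior is itself a beta distribution.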
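To spell out the conjugacy argument, write $h$ and $t$ for the number of ones and zeros in `data` (notation introduced here for illustration). Multiplying the Beta(10, 10) prior by the Bernoulli likelihood gives

\begin{equation}
p(f | \mathbf{x}) \propto f^{10 - 1} (1 - f)^{10 - 1} \cdot f^{h} (1 - f)^{t} = f^{(10 + h) - 1} (1 - f)^{(10 + t) - 1}
\end{equation}

which is the unnormalized density of Beta(10 + $h$, 10 + $t$). A beta guide can therefore, in principle, match the exact posterior.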

In [28]:
def guide(data):
    # register the two variational parameters with Pyro
    α_q = pyro.param("α_q", torch.tensor(15.), constraint=constraints.positive)
    β_q = pyro.param("β_q", torch.tensor(15.), constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(α_q, β_q)
    pyro.sample('latent_fairness', dist.Beta(α_q, β_q))

* names line up exactly
* same args
* variational params are torch tensors
* specify constraints

In [34]:
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
from pyro.distributions import constraints

In [35]:
import pyro

In [36]:
import torch

In [37]:
# set up the optimizer

adam_params = {'lr': 0.0005, 'betas': (0.9, 0.999)}
optimizer = Adam(adam_params)

# set up the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 5000
# do gradient steps
for step in range(n_steps):
    svi.step(data)

In [40]:
import math

In [41]:

# grab the learned variational parameters
α_q = pyro.param("α_q").item()
β_q = pyro.param("β_q").item()

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = α_q / (α_q + β_q)
# compute inferred standard deviation
factor = β_q / (α_q * (1.0 + α_q + β_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))


based on the data and our prior belief, the fairness of the coin is 0.397 +- 0.087


# Example (as written)

In [44]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

n_steps = 2000

# enable validation (validate parameters of distributions)
pyro.enable_validation(True)

# clear the param store 
pyro.clear_param_store()

# create some data with 6 obs. heads and 4 obs. tails.
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.))
    
def model(data):
    alpha0 = torch.tensor(10.)
    beta0 = torch.tensor(10.)
    f = pyro.sample('latent_fairness', dist.Beta(alpha0, beta0))
    for i in range(len(data)):
        pyro.sample('obs_{}'.format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    alpha_q = pyro.param('alpha_q', torch.tensor(15.0), constraint=constraints.positive)
    beta_q = pyro.param('beta_q', torch.tensor(15.), constraint=constraints.positive)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
    
optimizer=Adam({})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end='')


....................

In [45]:
# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))


based on the data and our prior belief, the fairness of the coin is 0.534 +- 0.090
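Because this model is conjugate, both SVI results can be sanity-checked against the exact posterior Beta(10 + heads, 10 + tails). The following plain-Python sketch does that, using the same beta mean and standard deviation formulas as above. The function name is hypothetical, and the counts are read off the two datasets shown earlier.

```python
import math

def exact_beta_posterior(heads, tails, alpha0=10.0, beta0=10.0):
    """Mean and standard deviation of the exact posterior Beta(alpha0 + heads, beta0 + tails)."""
    a, b = alpha0 + heads, beta0 + tails
    mean = a / (a + b)
    std = math.sqrt(a * b / ((a + b) ** 2 * (a + b + 1.0)))
    return mean, std

# first run: the sampled array shown above has 2 ones and 8 zeros
mean_1, std_1 = exact_beta_posterior(heads=2, tails=8)
# second run: 6 heads and 4 tails
mean_2, std_2 = exact_beta_posterior(heads=6, tails=4)
print("exact posterior, first run:  %.3f +- %.3f" % (mean_1, std_1))
print("exact posterior, second run: %.3f +- %.3f" % (mean_2, std_2))
```

The exact values are roughly 0.400 +- 0.088 and 0.533 +- 0.090, so both SVI runs land within about 0.003 of the true posterior mean and standard deviation.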
