### Stochastic Variational Inference (SVI)
In the last notebook, we learned how to build models from Pyro primitives and by composing models out of other models. We also learned how to perform variational inference on a toy weight model.

In this notebook, we discuss stochastic variational inference in detail. The focus is on using Pyro to perform variational inference, and the theory is kept to a minimum.

As discussed, a model is a Python function that may, in general, take arguments. The following mapping between model components and Pyro primitives is important:

1. Observations ($x$) are specified using ``pyro.sample`` with the ``obs`` argument.
2. Latent random variables ($z$) are specified using ``pyro.sample``.
3. Parameters are specified using ``pyro.param``.

Given a full model $p_{\theta}(x,z)$, VI looks for a distribution $q_{\lambda}(z)$, indexed by variational parameters $\lambda$, that approximates the posterior $p_{\theta}(z|x)$. For now, treat $\theta$ as fixed and drop it from the notation for convenience. SVI does this by maximizing an objective function called the ELBO:

$$\mathcal{L}(\lambda) = \mathbb{E}_{q_{\lambda}(z)}\big[\log p(x,z) - \log q_{\lambda}(z)\big]$$
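To make the expectation concrete, here is a pure-Python sketch that estimates $\mathcal{L}(\lambda)$ by Monte Carlo for the coin model used later in this notebook (a $\text{Beta}(10, 10)$ prior, Bernoulli flips, 6 heads and 4 tails) with a $\text{Beta}(a_q, b_q)$ guide. The helper names are hypothetical. Because this model is conjugate, the exact $\log p(x)$ is also available for comparison:

```python
import math
import random

def log_beta_pdf(z, a, b):
    # log density of Beta(a, b) at z in (0, 1)
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (a - 1.0) * math.log(z) + (b - 1.0) * math.log1p(-z) - log_norm

def log_joint(z, data, alpha0=10.0, beta0=10.0):
    # log p(x, z) = log Beta(z | alpha0, beta0) + sum_i log Bernoulli(x_i | z)
    heads = sum(data)
    tails = len(data) - heads
    return log_beta_pdf(z, alpha0, beta0) + heads * math.log(z) + tails * math.log1p(-z)

def elbo_estimate(a_q, b_q, data, num_samples=10000, seed=0):
    # Monte Carlo estimate of E_q[log p(x, z) - log q(z)] with q = Beta(a_q, b_q)
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.betavariate(a_q, b_q)
        total += log_joint(z, data) - log_beta_pdf(z, a_q, b_q)
    return total / num_samples

def log_evidence(data, alpha0=10.0, beta0=10.0):
    # exact log p(x) = log B(alpha0 + heads, beta0 + tails) - log B(alpha0, beta0)
    heads = sum(data)
    tails = len(data) - heads
    def log_B(a, b):
        return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return log_B(alpha0 + heads, beta0 + tails) - log_B(alpha0, beta0)

data = [1.0] * 6 + [0.0] * 4
print("ELBO estimate with q = Beta(15, 15):", elbo_estimate(15.0, 15.0, data))
print("exact log p(x):                    ", log_evidence(data))
```

With $q = \text{Beta}(16, 14)$, which is the exact posterior here, every sample gives $\log p(x,z) - \log q(z) = \log p(x)$, so the estimate matches $\log p(x)$ exactly; any other choice, such as $\text{Beta}(15, 15)$, lands below it.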

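Concretely, SVI maximizes the ELBO by taking stochastic gradient steps in the space of the variational parameters $\lambda$; the gradients are stochastic because the expectation is estimated from samples of $q_{\lambda}(z)$. (In Pyro, ``Trace_ELBO`` returns the negative ELBO as a loss, which the optimizer minimizes.)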
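As a hedged illustration of what these gradient steps involve, the sketch below runs stochastic gradient ascent on the ELBO by hand for a hypothetical conjugate Gaussian toy (not the coin model): prior $p(z) = \mathcal{N}(0, 1)$, likelihood $p(x|z) = \mathcal{N}(z, 1)$, and guide $q_{\lambda}(z) = \mathcal{N}(m, s^2)$ with $\lambda = (m, \log s)$. Gradients use the reparameterization $z = m + s\epsilon$, $\epsilon \sim \mathcal{N}(0, 1)$. Pyro computes such gradients with automatic differentiation; the hand-derived formulas here are only for this toy:

```python
import math
import random

def svi_gaussian_toy(x, num_steps=5000, lr=0.01, num_samples=10, seed=0):
    """Hypothetical toy: ascend the ELBO for p(z) = N(0, 1), p(x | z) = N(z, 1),
    guide q(z) = N(m, s^2), using reparameterized gradient estimates."""
    rng = random.Random(seed)
    m, log_s = 0.0, 0.0  # optimize log s so that s stays positive
    for _ in range(num_steps):
        s = math.exp(log_s)
        grad_m, grad_log_s = 0.0, 0.0
        for _ in range(num_samples):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps                # reparameterized sample from q
            dlogp_dz = x - 2.0 * z         # d/dz [log p(z) + log p(x | z)]
            grad_m += dlogp_dz             # dz/dm = 1
            grad_log_s += dlogp_dz * s * eps  # dz/d(log s) = s * eps
        # the entropy of q is log s + const, so it adds 1 to the log s gradient
        m += lr * grad_m / num_samples
        log_s += lr * (grad_log_s / num_samples + 1.0)
    return m, math.exp(log_s)

m, s = svi_gaussian_toy(x=2.0)
print("learned q: N(%.3f, %.3f^2)" % (m, s))
```

The exact posterior for this toy is $\mathcal{N}(x/2, 1/2)$, so for $x = 2$ the learned values should land near $m = 1$ and $s = \sqrt{0.5} \approx 0.707$.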

The ELBO is called the evidence lower bound because, for every choice of $\theta$ and $\lambda$, it never exceeds the log evidence (the log marginal likelihood): $\log p(x) \geq \mathcal{L}(\lambda)$.
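A short derivation sketch of why the bound holds, using the identity $p(x,z) = p(x)\,p(z|x)$ and the fact that $\log p(x)$ does not depend on $z$:

$$
\log p(x) - \mathcal{L}(\lambda)
= \mathbb{E}_{q_{\lambda}(z)}\big[\log p(x) - \log p(x,z) + \log q_{\lambda}(z)\big]
= \mathbb{E}_{q_{\lambda}(z)}\big[\log q_{\lambda}(z) - \log p(z|x)\big]
= \mathrm{KL}\big(q_{\lambda}(z)\,\|\,p(z|x)\big) \geq 0
$$

For fixed $\theta$, maximizing the ELBO over $\lambda$ is therefore the same as minimizing this KL divergence, and the bound is tight exactly when $q_{\lambda}(z) = p(z|x)$.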

In general, we can optimize the ELBO with respect to both the variational parameters $\lambda$ and the model parameters $\theta$. Intuitively, the variational distribution then chases a moving posterior $p_{\theta}(z|x)$, because the posterior changes every time $\theta$ is updated.
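A sketch of this joint optimization on a hypothetical Gaussian toy: the prior now has a learnable mean, $p_{\theta}(z) = \mathcal{N}(\theta, 1)$, with likelihood $p(x|z) = \mathcal{N}(z, 1)$ and guide $q_{\lambda}(z) = \mathcal{N}(m, s^2)$. Each step updates $\theta$ and $\lambda = (m, \log s)$ together, so the target posterior $p_{\theta}(z|x) = \mathcal{N}\big((\theta + x)/2, 1/2\big)$ shifts while $q_{\lambda}$ follows it:

```python
import math
import random

def joint_svi_gaussian_toy(x, num_steps=5000, lr=0.01, num_samples=10, seed=0):
    """Hypothetical toy: p_theta(z) = N(theta, 1), p(x | z) = N(z, 1),
    guide q(z) = N(m, s^2). Ascends the ELBO in theta, m and log s at once."""
    rng = random.Random(seed)
    theta, m, log_s = 0.0, 0.0, 0.0
    for _ in range(num_steps):
        s = math.exp(log_s)
        g_theta = g_m = g_log_s = 0.0
        for _ in range(num_samples):
            eps = rng.gauss(0.0, 1.0)
            z = m + s * eps                    # reparameterized sample from q
            dlogp_dz = (theta - z) + (x - z)   # d/dz log p_theta(x, z)
            g_theta += z - theta               # d/dtheta log p_theta(z)
            g_m += dlogp_dz
            g_log_s += dlogp_dz * s * eps
        # simultaneous update of model and variational parameters
        theta += lr * g_theta / num_samples
        m += lr * g_m / num_samples
        log_s += lr * (g_log_s / num_samples + 1.0)  # +1 from the entropy term
    return theta, m, math.exp(log_s)

theta, m, s = joint_svi_gaussian_toy(x=2.0)
print("theta = %.3f, q = N(%.3f, %.3f^2)" % (theta, m, s))
```

Because the Gaussian guide family contains the exact posterior in this toy, maximizing over $\theta$ amounts to maximizing $\log p_{\theta}(x)$ with $p_{\theta}(x) = \mathcal{N}(x;\, \theta, 2)$, so both $\theta$ and $m$ should approach $x$ while $s$ approaches $\sqrt{0.5}$.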

### SVI using Pyro

Suppose we record the results of flipping a coin 10 times, where 1 denotes heads and 0 denotes tails.

In [15]:
data = [1., 0., 1., 0., 1., 1., 1., 0., 1., 1.]

We want to reason about the fairness of this coin. Since we cannot observe the coin's fairness directly, it is naturally modeled as a latent variable, which makes this a good example of a latent variable model.
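To make the latent-variable structure concrete, here is a sketch of the generative story that the model below encodes: draw a fairness $f$ from a $\text{Beta}(10, 10)$ prior, then flip a coin with heads probability $f$ ten times. It uses only Python's ``random`` module, and ``simulate_coin`` is a hypothetical helper. Inference runs this story in reverse: we see the flips and ask what $f$ was.

```python
import random

def simulate_coin(num_flips=10, alpha0=10.0, beta0=10.0, seed=None):
    """Hypothetical helper: draw one sample from the coin model's generative story."""
    rng = random.Random(seed)
    # latent variable: the coin's fairness, drawn from the Beta(alpha0, beta0) prior
    fairness = rng.betavariate(alpha0, beta0)
    # observations: Bernoulli(fairness) flips, 1.0 = heads, 0.0 = tails
    flips = [1.0 if rng.random() < fairness else 0.0 for _ in range(num_flips)]
    return fairness, flips

fairness, flips = simulate_coin(seed=0)
print("hidden fairness: %.3f, observed flips: %s" % (fairness, flips))
```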

In [17]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

# this notebook was run with Pyro 0.3.0; the assert fails on any other version
assert pyro.__version__.startswith('0.3.0')
# enable validation (e.g. validate parameters of distributions)
pyro.enable_validation(True)

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

def model(data):
    # define the hyperparameters that control the beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    # register the two variational parameters with Pyro
    # - both parameters will have initial value 15.0.
    # - because we invoke constraints.positive, the optimizer
    # will take gradients on the unconstrained parameters
    # (which are related to the constrained parameters by a log)
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end='')

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nbased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))

....................
based on the data and our prior belief, the fairness of the coin is 0.529 +- 0.090
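Because the Beta prior is conjugate to the Bernoulli likelihood, the exact posterior here is $\text{Beta}(\alpha_0 + n_{\text{heads}}, \beta_0 + n_{\text{tails}})$, and the guide family contains it. A small pure-Python check against the SVI estimate above (the helper names ``beta_posterior`` and ``beta_mean_std`` are hypothetical):

```python
import math

def beta_posterior(data, alpha0=10.0, beta0=10.0):
    # conjugate update: add the head count to alpha and the tail count to beta
    heads = sum(data)
    tails = len(data) - heads
    return alpha0 + heads, beta0 + tails

def beta_mean_std(a, b):
    # mean and standard deviation of a Beta(a, b) distribution
    mean = a / (a + b)
    std = math.sqrt(a * b / ((a + b) ** 2 * (a + b + 1.0)))
    return mean, std

data = [1.0] * 6 + [0.0] * 4
a_post, b_post = beta_posterior(data)
mean, std = beta_mean_std(a_post, b_post)
print("exact posterior Beta(%.0f, %.0f): %.3f +- %.3f" % (a_post, b_post, mean, std))
# prints: exact posterior Beta(16, 14): 0.533 +- 0.090
```

The exact answer, about 0.533 ± 0.090, is close to the 0.529 ± 0.090 found by SVI.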


In [18]:
alpha_q

15.876687049865723

In [19]:
beta_q

14.123374938964844