## Variational Inference

At a high level, variational inference is easy: all we need to do is define a guide and compute gradients of the ELBO. In practice, computing gradients for general model and guide pairs leads to some complications (see the tutorial SVI Part III for a discussion). For the purposes of this tutorial, let’s consider that a solved problem and look at the support that Pyro provides for doing variational inference.

In Pyro the user needs to provide three things: the model, the guide, and an optimizer. We’ve discussed the model and guide above and we’ll discuss the optimizer in some detail below, so let’s assume we have all three ingredients at hand.

The SVI object provides two methods, step() and evaluate_loss(), that encapsulate the logic for variational learning and evaluation:

1. The method step() takes a single gradient step and returns an estimate of the loss (i.e. minus the ELBO). If provided, the arguments to step() are piped to model() and guide().
2. The method evaluate_loss() returns an estimate of the loss without taking a gradient step. Just like for step(), if provided, arguments to evaluate_loss() are piped to model() and guide().

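For the case where the loss is the ELBO, the number of samples used to estimate it is controlled by an optional argument num_particles. In Pyro 1.x this argument is passed to the loss constructor, e.g. Trace_ELBO(num_particles=10), rather than to step() or evaluate_loss(). It sets the number of samples used to compute the loss (in the case of evaluate_loss) and the loss and gradient (in the case of step).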


## A simple example

We finish with a simple example. You’ve been given a two-sided coin. You want to determine whether the coin is fair or not, i.e. whether it falls heads or tails with the same frequency. You have a prior belief about the likely fairness of the coin based on two observations:

* it’s a standard quarter issued by the US Mint
* it’s a bit banged up from years of use

So while you expect the coin to have been quite fair when it was first produced, you allow for its fairness to have since deviated from a perfect 1:1 ratio. So you wouldn’t be surprised if it turned out that the coin preferred heads over tails at a ratio of 11:10. By contrast you would be very surprised if it turned out that the coin preferred heads over tails at a ratio of 5:1—it’s not that banged up.

To learn something about the fairness of the coin that is more precise than our somewhat vague prior, we need to do an experiment and collect some data. Let’s say we flip the coin 10 times and record the result of each flip. In practice we’d probably want to do more than 10 trials.

The model for this experiment (the model function in the code below) has a single latent random variable ('latent_fairness'), which is distributed according to $\text{Beta}(10,10)$. Conditioned on that random variable, we observe each of the datapoints using a Bernoulli likelihood. Note that each observation is assigned a unique name in Pyro.
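To see how a $\text{Beta}(10,10)$ prior encodes the intuition about the 11:10 and 5:1 coins, one can measure how far each coin sits from the prior mean in units of the prior standard deviation. The helper names below are hypothetical, and the calculation uses only the standard formulas for the mean and variance of a Beta distribution.

```python
import math

# Beta(10, 10) prior from the text
alpha0, beta0 = 10.0, 10.0

prior_mean = alpha0 / (alpha0 + beta0)
prior_var = (alpha0 * beta0) / ((alpha0 + beta0) ** 2 * (alpha0 + beta0 + 1.0))
prior_std = math.sqrt(prior_var)

def heads_probability(heads, tails):
    """Convert a heads:tails ratio into a probability of heads."""
    return heads / (heads + tails)

def prior_std_units(f):
    """Distance of f from the prior mean, in prior standard deviations."""
    return (f - prior_mean) / prior_std

mild = prior_std_units(heads_probability(11, 10))    # the 11:10 coin
extreme = prior_std_units(heads_probability(5, 1))   # the 5:1 coin

print("prior mean %.3f, prior std %.3f" % (prior_mean, prior_std))
print("11:10 coin: %.2f std from the mean" % mild)
print("5:1 coin:   %.2f std from the mean" % extreme)
```

The 11:10 coin lands roughly 0.2 standard deviations from the prior mean, while the 5:1 coin is about 3 standard deviations away, which matches the "not surprised" versus "very surprised" description above.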

Our next task is to define a corresponding guide, i.e. an appropriate variational distribution for the latent random variable $f$. The only real requirement here is that $q(f)$ should be a probability distribution over the range $[0.0, 1.0]$, since $f$ doesn’t make sense outside of that range. A simple choice is to use another Beta distribution parameterized by two trainable parameters $\alpha_q$ and $\beta_q$. In fact, in this particular case this is the ‘right’ choice, since conjugacy of the Bernoulli and Beta distributions means that the exact posterior is a Beta distribution.
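To spell out the conjugacy argument, write $\alpha_0$ and $\beta_0$ for the prior parameters, $x_1, \dots, x_n \in \{0, 1\}$ for the observed flips, and $h$ and $t$ for the resulting numbers of heads and tails (this notation is introduced here only for illustration). Multiplying the Beta prior by the Bernoulli likelihood gives

$$
p(f \mid x_{1:n}) \;\propto\; f^{\alpha_0 - 1}(1 - f)^{\beta_0 - 1} \prod_{i=1}^{n} f^{x_i}(1 - f)^{1 - x_i} \;=\; f^{\alpha_0 + h - 1}(1 - f)^{\beta_0 + t - 1},
$$

which is the unnormalized density of $\text{Beta}(\alpha_0 + h, \beta_0 + t)$. A Beta guide can therefore represent the exact posterior, with the optimum at $\alpha_q = \alpha_0 + h$ and $\beta_q = \beta_0 + t$.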

There are a few things to note about the model and guide in the code below:

* We’ve taken care that the names of the random variables line up exactly between the model and guide.
* model(data) and guide(data) take the same arguments. The variational parameters are torch.Tensors, and pyro.param automatically sets their requires_grad flag to True.
* We use constraint=constraints.positive to ensure that alpha_q and beta_q remain positive during optimization. Under the hood, the optimizer works on unconstrained values that are mapped through an exponential transform, which ensures positivity.
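The idea behind the positivity constraint can be illustrated without Pyro. The sketch below is not Pyro's internal code; it uses made-up numbers (alpha, grad_alpha, lr) to show why taking the gradient step on $u = \log \alpha$ and mapping back through $\exp$ keeps the parameter positive even when a direct step would not.

```python
import math

# hypothetical numbers chosen so that a naive update overshoots zero
alpha = 1.0        # current (positive) parameter value
grad_alpha = 1.5   # gradient of some loss with respect to alpha
lr = 1.0           # learning rate

# naive update applied directly to alpha: leaves the valid range
naive_alpha = alpha - lr * grad_alpha

# update applied to the unconstrained value u = log(alpha) instead
u = math.log(alpha)
grad_u = grad_alpha * alpha           # chain rule: d(alpha)/du = exp(u) = alpha
u_new = u - lr * grad_u
constrained_alpha = math.exp(u_new)   # exp maps any real u to a positive value

print(naive_alpha, constrained_alpha)
```

Here the direct update gives -0.5, which is not a valid Beta parameter, while the update in log space gives $\exp(-1.5) \approx 0.223$.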

```python
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

assert pyro.__version__.startswith('1.8.4')

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

def model(data):
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the Bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    # register the two variational parameters with Pyro
    # - both parameters will have initial value 15.0.
    # - because we invoke constraints.positive, the optimizer
    # will take gradients on the unconstrained parameters
    # (which are related to the constrained parameters by a log)
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end='')

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the Beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nBased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))
```

```
....................
Based on the data and our prior belief, the fairness of the coin is 0.535 +- 0.090
```


This estimate is to be compared to the exact posterior mean, which in this case is given by $16/30 \approx 0.533$.
Note that the final estimate of the fairness of the coin lies between the fairness preferred by the prior (namely $0.50$)
and the fairness suggested by the raw empirical frequencies ($6/10 = 0.60$).
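The exact posterior can be checked by hand. For a Beta prior with Bernoulli observations, the standard conjugate update adds the number of heads to the first Beta parameter and the number of tails to the second. The sketch below applies that update to the prior and data used above; the variable names are chosen here for illustration.

```python
import math

alpha0, beta0 = 10.0, 10.0   # Beta prior from the model
heads, tails = 6, 4          # observed data

# conjugate update: add the counts to the prior parameters
alpha_post = alpha0 + heads
beta_post = beta0 + tails

# mean and standard deviation of the Beta(16, 14) posterior
exact_mean = alpha_post / (alpha_post + beta_post)
exact_var = (alpha_post * beta_post) / (
    (alpha_post + beta_post) ** 2 * (alpha_post + beta_post + 1.0))
exact_std = math.sqrt(exact_var)

# reference points mentioned in the text
prior_mean = alpha0 / (alpha0 + beta0)
empirical_freq = heads / (heads + tails)

print("exact posterior: %.3f +- %.3f" % (exact_mean, exact_std))
```

This gives roughly 0.533 +- 0.090, so the SVI estimate reported above agrees with the exact posterior in both mean and spread to within a few thousandths.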