In [None]:
%matplotlib inline
# import some dependencies
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable

import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim

# Inference in Pyro: from stochastic functions to marginal distributions

Stochastic functions induce a joint probability distribution over their latent variables and return values. For non-primitive stochastic functions, we can no longer explicitly compute the probability of an output `p(y | x)` or draw samples from the marginal distribution over return values `y ~ p( . | x)`.

*Inference* in a universal PPL like Pyro is the problem of constructing the marginal distribution over return values of a stochastic function given inputs so that we can perform these computations.  Pyro accomplishes this by collecting a number of weighted execution traces of the function, then collapsing those traces into a histogram over possible return values given a particular set of arguments.
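The following minimal plain-Python sketch makes the idea of collapsing weighted traces into a histogram concrete. It does not use Pyro at all. The stochastic function `two_flips` and the helpers `collect_traces` and `collapse` are hypothetical illustrations. Because the traces are drawn from the prior, every trace gets the same weight.

```python
import random
from collections import defaultdict

def two_flips():
    # Hypothetical stochastic function: number of heads in two fair coin flips.
    return int(random.random() < 0.5) + int(random.random() < 0.5)

def collect_traces(fn, num_traces):
    # Sampling from the prior: each execution trace gets the same weight, 1.0.
    return [(fn(), 1.0) for _ in range(num_traces)]

def collapse(traces):
    # Add up the weights of traces that returned the same value, then normalize.
    hist = defaultdict(float)
    for value, weight in traces:
        hist[value] += weight
    total = sum(hist.values())
    return {value: w / total for value, w in hist.items()}

random.seed(0)
hist = collapse(collect_traces(two_flips, 10000))
```

With other proposal distributions, the weights would differ from trace to trace. The collapsing step stays the same.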

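*Bayesian inference* is the special case in which some of the function's random variables are observed, and we want the distribution over the remaining latent variables conditioned on those observations (the posterior distribution).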

## A simple example

Consider the following model:

In [None]:
def normal_normal_model(x):
    z = pyro.sample("z", dist.diagnormal, x, Variable(torch.ones(1)))
    y = pyro.sample("y", dist.diagnormal, z, Variable(torch.ones(1)))
    return y

Collecting execution histories can be done either through sampling or, for models with only discrete latent variables, exact enumeration.  To create a basic importance sampler over execution traces (using the prior as the proposal distribution), we can write:

In [None]:
posterior = pyro.infer.Importance(normal_normal_model)
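When every latent variable is discrete, exact enumeration can replace sampling. Each joint assignment of the latent variables is one execution trace, weighted by its probability. The sketch below is plain Python and uses a hypothetical two-coin model. `P_HEADS`, `enumerate_traces`, and `enumerated_marginal` are illustrative names, not Pyro API.

```python
import itertools
from collections import defaultdict

# Hypothetical discrete model: two biased coins; the return value is the number of heads.
P_HEADS = (0.3, 0.6)

def enumerate_traces():
    # Each trace is one joint assignment of the coins, weighted by its probability.
    for flips in itertools.product((0, 1), repeat=len(P_HEADS)):
        weight = 1.0
        for flip, p in zip(flips, P_HEADS):
            weight *= p if flip else 1.0 - p
        yield flips, sum(flips), weight

def enumerated_marginal():
    # Collapse the enumerated traces into an exact distribution over return values.
    hist = defaultdict(float)
    for _, value, weight in enumerate_traces():
        hist[value] += weight
    return dict(hist)

exact_hist = enumerated_marginal()
```

No sampling error is involved here. The cost is that the number of traces grows exponentially with the number of discrete latent variables.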

`posterior` is not a particularly useful object on its own, though advanced users can call it with arguments for `normal_normal_model` to sample a raw execution trace.  Instead, `posterior` must be consumed by `pyro.infer.Marginal`, which creates a primitive stochastic function with the same input and output types as `normal_normal_model`.  

In [None]:
marginal = pyro.infer.Marginal(posterior)

When called with an input `x`, `marginal` first uses `posterior` to generate a sequence of weighted execution traces given `x`, then builds a histogram over return values from the traces, and finally returns a sample drawn from the histogram.  
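The three steps above can be sketched in plain Python. `make_marginal` and the trace generator `noisy_round` are hypothetical stand-ins for `pyro.infer.Marginal` and `posterior`. The sketch mirrors the described behavior, not Pyro's actual implementation.

```python
import random
from collections import defaultdict

def make_marginal(run_weighted_trace, num_traces=100):
    # run_weighted_trace(x) -> (return_value, weight); a stand-in for `posterior`.
    def marginal_fn(x):
        # 1. generate weighted execution traces given x
        traces = [run_weighted_trace(x) for _ in range(num_traces)]
        # 2. build a histogram over return values
        hist = defaultdict(float)
        for value, weight in traces:
            hist[value] += weight
        # 3. return a single sample drawn from the histogram
        values = list(hist)
        return random.choices(values, weights=[hist[v] for v in values])[0]
    return marginal_fn

def noisy_round(x):
    # Hypothetical trace generator: one Gaussian latent, a rounded return value, unit weight.
    return round(x + random.gauss(0.0, 0.1)), 1.0

random.seed(0)
marginal_sketch = make_marginal(noisy_round)
```

Return values whose traces all have zero weight are never drawn.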

`pyro.infer.Marginal` also accepts the optional keyword argument `sites=[name1, name2, ...]` that provides a list of names of latent variables.  When `sites` is specified, `marginal` will return a dictionary where the keys are the names in `sites` and the values are values at those sites from a single execution trace.  This is useful because we may wish to compute many different marginals from the same posterior object.
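The `sites` behavior can be sketched in plain Python. Each trace is represented as a dict from site name to value. The hypothetical helper `sample_sites` picks one trace in proportion to its weight and then reads all requested sites from it. Reading every site from the same trace preserves the correlations between sites, which independent per-site histograms would lose.

```python
import random

def sample_sites(weighted_traces, sites):
    # weighted_traces: list of (trace, weight), each trace mapping site name -> value.
    traces, weights = zip(*weighted_traces)
    trace = random.choices(traces, weights=weights)[0]
    # All returned values come from this one trace, so they stay jointly consistent.
    return {name: trace[name] for name in sites}

# Hypothetical traces with strongly correlated sites z and y.
example_traces = [({"z": 0.1, "y": 0.3}, 1.0), ({"z": 5.0, "y": 5.2}, 1.0)]
random.seed(0)
draw = sample_sites(example_traces, ["z", "y"])
```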

## Conditioning models on data

The real utility of probabilistic programming is in the ability to condition generative models on observed data. In Pyro, we separate the expression of conditioning from its evaluation via inference, making it possible to write a model once and condition it on many different observations.

Consider `normal_normal_model` once again.  Suppose we want to sample from the marginal distribution of `z` given input `x = 0`, but now we have observed that `y == 1`.  Pyro provides the function `pyro.condition` to allow us to constrain the values of sample statements.  `pyro.condition` is a higher-order function that takes a model and a dictionary of data and returns a new model that has the same input and output signatures but always uses the same value at observed `sample` statements:

In [None]:
conditioned_normal_normal_model = pyro.condition(normal_normal_model, 
                                                 data={"y": Variable(torch.ones(1))})
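To show what "a higher-order function that returns a new model" means, here is a minimal mini-PPL sketch in plain Python. Models receive a `sample(name, draw)` callback. `mini_condition` wraps that callback so that observed names return their data value instead of drawing. All names are hypothetical, and Pyro's real `condition` works differently internally.

```python
import random

def normal_normal_sketch(sample, x):
    # Hypothetical mini-PPL model: every site is drawn through the `sample` callback.
    z = sample("z", lambda: random.gauss(x, 1.0))
    y = sample("y", lambda: random.gauss(z, 1.0))
    return y

def mini_run(model, *args):
    # Execute the model once, recording every site's value in a trace dict.
    trace = {}
    def sample(name, draw):
        trace[name] = draw()
        return trace[name]
    return model(sample, *args), trace

def mini_condition(model, data):
    # Returns a new model with the same signature whose observed sites
    # always take their value from `data` instead of being drawn.
    def conditioned(sample, *args):
        def cond_sample(name, draw):
            if name in data:
                return sample(name, lambda: data[name])
            return sample(name, draw)
        return model(cond_sample, *args)
    return conditioned

random.seed(0)
y_value, trace = mini_run(mini_condition(normal_normal_sketch, {"y": 1.0}), 0.0)
```

This sketch only fixes values. An inference algorithm must also score the observed site, that is, weight the trace by p(y = 1 | z), which is omitted here.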

Because `pyro.condition` returns an ordinary Python function, conditioning can be deferred or parametrized with a regular `def` or with Python's `lambda`:

In [None]:
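def deferred_conditioned_normal_normal_model(y, x):
    # condition on the observed y at call time, then run the model on input x
    return pyro.condition(normal_normal_model, data={"y": y})(x)
# equivalently: lambda y, x: pyro.condition(normal_normal_model, data={"y": y})(x)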

In some cases it might be more convenient to pass observations directly to individual `pyro.sample` statements instead of using `pyro.condition`.  The optional `obs` keyword argument is reserved by `pyro.sample` for that purpose; in addition, `pyro.observe` is an alias for `pyro.sample` with `obs`:

In [None]:
# equivalent to pyro.condition(normal_normal_model, data={"y": Variable(torch.ones(1))})
def conditioned_normal_normal_model_2(x):
    z = pyro.sample("z", dist.diagnormal, x, Variable(torch.ones(1)))
    y = pyro.sample("y", dist.diagnormal, z, Variable(torch.ones(1)),
                    obs=Variable(torch.ones(1)))  # here we attach an observation y == 1
    return y

# equivalent to conditioned_normal_normal_model_2:
def conditioned_normal_normal_model_3(x):
    z = pyro.sample("z", dist.diagnormal, x, Variable(torch.ones(1)))
    # pyro.observe takes the observed value before the distribution's parameters
    y = pyro.observe("y", dist.diagnormal, Variable(torch.ones(1)),  # observation y == 1
                     z, Variable(torch.ones(1)))
    return y

However, hardcoding observations into the model is usually not recommended, because it is invasive and does not compose.  By contrast, with `pyro.condition` conditioning can be composed freely to form complex queries on a probabilistic model without modifying the model itself.  The only restriction is that a single site may be constrained only once.

In [None]:
def model2():
    mu = pyro.sample("mu", dist.diagnormal, Variable(torch.zeros(1)), Variable(torch.ones(1)))
    sigma = torch.exp(pyro.sample("log_sigma", dist.diagnormal,
                                  Variable(torch.zeros(1)), Variable(torch.ones(1))))
    x = pyro.sample("x", dist.diagnormal, mu, sigma)
    return sigma

# conditioning composes: 
# the following are all equivalent and do not interfere with each other
conditioned_model2_1 = pyro.condition(
    pyro.condition(model2, data={"mu": Variable(torch.ones(1))}), 
    data={"x": Variable(torch.ones(1))})

conditioned_model2_2 = pyro.condition(
    pyro.condition(model2, data={"x": Variable(torch.ones(1))}), 
    data={"mu": Variable(torch.ones(1))})

conditioned_model2_3 = pyro.condition(model2, data={"x": Variable(torch.ones(1)), "mu": Variable(torch.ones(1))})
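The composition rule can be sketched with a plain-Python mini-PPL. In this sketch, nested conditioning on disjoint sites is treated as conditioning on the union of their data, and constraining the same site twice raises an error, mirroring the restriction stated above. `model2_sketch`, `mini_condition`, and `mini_run` are hypothetical. Pyro implements `condition` with its poutine machinery, not by merging dictionaries.

```python
import math
import random

def model2_sketch(sample):
    # Hypothetical mini-PPL analogue of model2.
    mu = sample("mu", lambda: random.gauss(0.0, 1.0))
    sigma = math.exp(sample("log_sigma", lambda: random.gauss(0.0, 1.0)))
    sample("x", lambda: random.gauss(mu, sigma))
    return sigma

def mini_condition(model, data):
    # Unwrap an already-conditioned model so nested conditioning merges into one dict.
    base = getattr(model, "base_model", model)
    previous = getattr(model, "observed", {})
    overlap = set(previous) & set(data)
    if overlap:
        raise ValueError(f"sites constrained more than once: {sorted(overlap)}")
    observed = {**previous, **data}

    def conditioned(sample, *args):
        def cond_sample(name, draw):
            return sample(name, (lambda: observed[name]) if name in observed else draw)
        return base(cond_sample, *args)

    conditioned.base_model = base
    conditioned.observed = observed
    return conditioned

def mini_run(model, *args):
    # Execute the model once, recording every site's value in a trace dict.
    trace = {}
    def sample(name, draw):
        trace[name] = draw()
        return trace[name]
    return model(sample, *args), trace

c1 = mini_condition(mini_condition(model2_sketch, {"mu": 1.0}), {"x": 1.0})
c2 = mini_condition(mini_condition(model2_sketch, {"x": 1.0}), {"mu": 1.0})
c3 = mini_condition(model2_sketch, {"x": 1.0, "mu": 1.0})
```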

In addition to conditioning for incorporating observations, Pyro also contains `pyro.do`, an implementation of Pearl's `do`-operator used for causal inference with an identical interface to `pyro.condition`.  `condition` and `do` can be mixed and composed freely, making Pyro a powerful tool for model-based causal inference.  See the [causal inference tutorial](https://pyro.ai/examples/causal_inference.html) for more details about `pyro.do` and a simple example of causal inference in a model of disease diagnosis.
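Self-normalized importance sampling in plain Python can illustrate the difference between observing and intervening. In this hypothetical two-site model, z ~ Normal(0, 1) and y ~ Normal(z, 1). Conditioning on y = 2 weights each trace by p(y | z), which pulls the estimate of z toward its posterior mean of 1. By contrast, `do(y = 2)` cuts y off from its cause, so nothing is scored and z keeps its prior mean of 0. This sketches the semantics only. The helper names are hypothetical and this is not Pyro's implementation.

```python
import math
import random

def weighted_trace(y_value, mode):
    # Hypothetical model: z ~ Normal(0, 1), y ~ Normal(z, 1).
    z = random.gauss(0.0, 1.0)
    if mode == "condition":
        # Observing y: keep the y-site and score it, log p(y | z) up to a constant.
        log_w = -0.5 * (y_value - z) ** 2
    elif mode == "do":
        # Intervening on y: the y-site is replaced by a constant and nothing is scored.
        log_w = 0.0
    else:
        raise ValueError(mode)
    return z, log_w

def mean_z(y_value, mode, num_samples=20000):
    # Self-normalized importance estimate of E[z] under the chosen query.
    samples = [weighted_trace(y_value, mode) for _ in range(num_samples)]
    max_lw = max(lw for _, lw in samples)
    weights = [math.exp(lw - max_lw) for _, lw in samples]
    return sum(w * z for w, (z, _) in zip(weights, samples)) / sum(weights)

random.seed(0)
conditioned_mean = mean_z(2.0, "condition")
intervened_mean = mean_z(2.0, "do")
```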

## Flexible approximate inference with guide functions

Consider `deferred_conditioned_normal_normal_model`.  Now that we have constrained `y` against some data, we can approximate the marginal distribution over `z` given `x` and `y == data` with importance sampling.  

`pyro.infer.Importance` allows us to use arbitrary stochastic functions, which we will call *guide functions*, as proposal distributions, provided that:
1. every unobserved sample statement that appears in the model also appears in the guide;
2. the guide has the same input signature as the model.
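The importance-sampling computation behind a guide can be sketched in plain Python for the normal-normal model with input `x = 0` and observation `y = 1`. The exact posterior mean of `z` in this case is 0.5. The guide proposes the single unobserved site `z` and takes the same inputs as the model. Each sample is weighted by p(z | x) p(y | z) / q(z). The helper names and the guide parameters are hypothetical.

```python
import math
import random

def normal_logpdf(v, mu, sd):
    # Log density of Normal(mu, sd) at v.
    return -0.5 * ((v - mu) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2 * math.pi)

def importance_mean_z(x, y, guide_mu, guide_sd, num_samples=5000):
    # The guide proposes the only unobserved site, z.
    zs = [random.gauss(guide_mu, guide_sd) for _ in range(num_samples)]
    # Importance weight: model joint density divided by guide density.
    log_ws = [normal_logpdf(z, x, 1.0) + normal_logpdf(y, z, 1.0)
              - normal_logpdf(z, guide_mu, guide_sd) for z in zs]
    m = max(log_ws)
    ws = [math.exp(lw - m) for lw in log_ws]
    return sum(w * z for w, z in zip(ws, zs)) / sum(ws)

random.seed(0)
prior_guide_estimate = importance_mean_z(0.0, 1.0, guide_mu=0.0, guide_sd=1.0)
shifted_guide_estimate = importance_mean_z(0.0, 1.0, guide_mu=0.5, guide_sd=1.0)
```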

The guide function should be chosen so that it closely approximates the distribution over all unobserved `sample` statements in the model.  The simplest guide for `deferred_conditioned_normal_normal_model` matches the prior distribution over `z`:

In [None]:
def normal_normal_prior_guide(y, x):
    return pyro.sample("z", dist.diagnormal, x, Variable(torch.ones(1)))

posterior = pyro.infer.Importance(deferred_conditioned_normal_normal_model, 
                                  guide=normal_normal_prior_guide,
                                  num_samples=20)

marginal = pyro.infer.Marginal(posterior, sites=["z"])

Of course, the prior distribution is generally not a very good model of the posterior distribution.  For `normal_normal_model`, it can be shown that the posterior distribution over `z` given `x` and the observed `y` is itself Gaussian, so it can be written exactly as a guide:

In [None]:
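def normal_normal_posterior_guide(y, x):
    # y holds n observations of site "y"; the exact posterior over z is Normal(a, b)
    n = y.size(0)
    a = (x + torch.sum(y)) / (n + 1.0)
    b = Variable(torch.ones(1)) / (n + 1.0) ** 0.5  # posterior standard deviation
    return pyro.sample("z", dist.diagnormal, a, b)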
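The guide above follows from a standard conjugate-Gaussian calculation. Assume that `y` holds $n$ i.i.d. observations $y_1, \dots, y_n$ of the `y` site, and that the second argument of `diagnormal` is a standard deviation. Completing the square gives:

$$
p(z \mid x, y) \;\propto\; \exp\!\Big(-\tfrac{1}{2}(z - x)^2\Big) \prod_{i=1}^{n} \exp\!\Big(-\tfrac{1}{2}(y_i - z)^2\Big)
\;\propto\; \exp\!\Big(-\tfrac{n+1}{2}\Big(z - \frac{x + \sum_{i=1}^{n} y_i}{n+1}\Big)^2\Big)
$$

$$
z \mid x, y \;\sim\; \mathcal{N}\!\left(\frac{x + \sum_{i=1}^{n} y_i}{n+1},\; \frac{1}{n+1}\right)
$$

Here the second argument of $\mathcal{N}$ is the variance. The guide's location is therefore $a = (x + \sum_i y_i)/(n+1)$, and its scale is the standard deviation $b = 1/\sqrt{n+1}$.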

Guide functions can serve as programmable, data-dependent proposal distributions for importance sampling, rejection sampling, sequential Monte Carlo, MCMC, and independent Metropolis-Hastings, and as variational distributions or inference networks for stochastic variational inference.  Currently, only importance sampling and stochastic variational inference are implemented in Pyro, but the other algorithms will be added in the future.
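As one example of a guide used as a proposal outside importance sampling, here is a plain-Python sketch of independent Metropolis-Hastings. The guide is a fixed Normal proposal. The target is the normal-normal posterior with `x = 1` and `y = 1`, whose exact form is Normal(1, variance 1/2). `log_joint` and `independent_mh` are hypothetical names, and this is not a Pyro API.

```python
import math
import random

def log_joint(z, x=1.0, y=1.0):
    # Unnormalized log p(z, y | x) for z ~ Normal(x, 1), y ~ Normal(z, 1).
    return -0.5 * (z - x) ** 2 - 0.5 * (y - z) ** 2

def independent_mh(num_steps, guide_mu, guide_sd, z0=0.0):
    def log_q(z):
        # Unnormalized log density of the Normal(guide_mu, guide_sd) guide.
        return -0.5 * ((z - guide_mu) / guide_sd) ** 2
    z, chain = z0, []
    for _ in range(num_steps):
        proposal = random.gauss(guide_mu, guide_sd)
        # Independence-sampler acceptance ratio: p(z') q(z) / (p(z) q(z')).
        log_alpha = (log_joint(proposal) + log_q(z)) - (log_joint(z) + log_q(proposal))
        if log_alpha >= 0 or random.random() < math.exp(log_alpha):
            z = proposal
        chain.append(z)
    return chain

random.seed(0)
chain = independent_mh(20000, guide_mu=1.0, guide_sd=1.0)
kept = chain[1000:]  # discard burn-in
```

A guide that is closer to the posterior gets more proposals accepted, so the chain mixes faster.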

## Parametrized stochastic functions and variational inference

Although we could write out the exact posterior distribution for `normal_normal_model`, in general it is intractable to specify a guide that is a good approximation to the posterior distribution of an arbitrary conditioned stochastic function.  What we can do instead is use the top-level function `pyro.param` to specify a *family* of guides indexed by named parameters, and search for the member of that family that is the best approximation.  This approach to approximate posterior inference is called *variational inference*.
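One standard way to make "best approximation" precise, and the objective behind the `loss="ELBO"` setting used below, is the evidence lower bound. Write $q_\phi(z)$ for the guide with parameters $\phi$ (for example, the named parameters `a` and `b`):

$$
\log p(y \mid x) \;=\; \underbrace{\mathbb{E}_{q_\phi(z)}\big[\log p(z, y \mid x) - \log q_\phi(z)\big]}_{\mathrm{ELBO}(\phi)} \;+\; \mathrm{KL}\big(q_\phi(z) \,\|\, p(z \mid x, y)\big)
$$

The left-hand side does not depend on $\phi$, and the KL divergence is non-negative. Maximizing the ELBO over $\phi$ therefore minimizes the KL divergence from the guide to the true posterior.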

`pyro.param` is a frontend for Pyro's key-value *parameter store*, described in more detail in the [SVI tutorial](https://pyro.ai/examples/svi_part_i.html). Like `pyro.sample`, `pyro.param` is always called with a name as its first argument.  The first time `pyro.param` is called with a particular name, it stores its argument in the parameter store and then returns that value.  After that, when it is called with that name, it returns the value from the parameter store regardless of any other arguments.  It is similar to `simple_param_store.setdefault` here, but with some additional tracking and management functionality.

```python
simple_param_store = {}
a = simple_param_store.setdefault("a", torch.randn(1))
```
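The class below extends the `setdefault` analogy slightly, to show why optimization results persist between calls. `SimpleParamStore` is a hypothetical sketch, not Pyro's parameter store. The first call stores the initial value, later calls ignore their argument, and an optimizer writes updated values back under the same name.

```python
class SimpleParamStore:
    """Hypothetical setdefault-style parameter storage (not Pyro's ParamStore)."""

    def __init__(self):
        self._params = {}

    def param(self, name, init_value):
        # First call stores init_value; later calls return the stored value
        # and ignore init_value entirely.
        return self._params.setdefault(name, init_value)

    def set(self, name, value):
        # An optimizer would write updated values back under the same name.
        self._params[name] = value

store = SimpleParamStore()
first = store.param("a", 1.0)
second = store.param("a", 5.0)  # ignored: "a" already exists
store.set("a", 2.0)             # e.g. after a gradient step
after_update = store.param("a", 9.0)
```

Like `dict.setdefault`, this evaluates its argument eagerly. An initializer such as `torch.randn(1)` still runs on every call, and its result is simply discarded once the name exists.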

For example, we can parametrize `a` and `b` in `normal_normal_posterior_guide` instead of specifying them by hand:

In [None]:
def normal_normal_parametrized_guide(y, x):
    a = pyro.param("a", Variable(torch.randn(1), requires_grad=True))
    b = pyro.param("b", Variable(torch.exp(torch.randn(1)), requires_grad=True))
    return pyro.sample("z", dist.diagnormal, a, b)

Pyro is built to enable *stochastic variational inference*. Parameters are always real-valued tensors. A loss function is estimated by Monte Carlo from sampled execution histories of the model and guide, and stochastic gradient descent searches for the optimal parameters.  Combining stochastic gradient descent with PyTorch's GPU-accelerated tensor math and automatic differentiation allows us to scale variational inference to very high-dimensional parameter spaces and massive datasets.
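The loop structure of stochastic variational inference can be seen in a small plain-Python sketch for the normal-normal model with `x = 1` and `y = 1`. The exact posterior here is Normal(1, 1/√2). The guide family is Normal(a, exp(log_b)). Each step draws `num_particles` reparameterized samples z = a + b·eps, forms a Monte Carlo estimate of the ELBO gradient, and takes a gradient-ascent step. The gradients are derived by hand, in place of PyTorch's automatic differentiation. `svi_normal_normal` and its hyperparameters are hypothetical.

```python
import math
import random

def svi_normal_normal(x, y, num_steps=5000, num_particles=20, lr=0.01):
    # Guide family q(z) = Normal(a, exp(log_b)); log_b keeps the scale positive.
    a, log_b = 0.0, 0.0
    for _ in range(num_steps):
        b = math.exp(log_b)
        grad_a = grad_log_b = 0.0
        for _ in range(num_particles):
            eps = random.gauss(0.0, 1.0)
            z = a + b * eps  # reparameterized sample from the guide
            # d/dz [log p(z | x) + log p(y | z)] for z ~ N(x, 1), y ~ N(z, 1)
            dlogp_dz = -(z - x) + (y - z)
            grad_a += dlogp_dz
            # chain rule through z = a + b * eps, plus d/dlog_b of the entropy term log b
            grad_log_b += dlogp_dz * eps * b + 1.0
        # gradient ascent on the Monte Carlo ELBO estimate
        a += lr * grad_a / num_particles
        log_b += lr * grad_log_b / num_particles
    return a, math.exp(log_b)

random.seed(0)
a_hat, b_hat = svi_normal_normal(x=1.0, y=1.0)
```

With these settings, the parameters should settle near the exact posterior values a = 1 and b ≈ 0.707.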

Pyro's SVI functionality is described in detail in the [SVI tutorial](https://pyro.ai/examples/svi_part_i.html). Here is a very simple example applying it to `normal_normal_model`:

In [None]:
svi = pyro.infer.SVI(model=deferred_conditioned_normal_normal_model, 
                     guide=normal_normal_parametrized_guide,
                     optim=pyro.optim.SGD({"lr": 0.001}),
                     loss="ELBO")

for _ in range(10):
    svi.step(Variable(torch.ones(1)), Variable(torch.ones(1)))

Note that optimization will update the guide parameters, but does not produce a posterior distribution object itself. Once we find good parameter values, we can use the guide as a representation of the model's approximate posterior for downstream tasks.  

For example, we can use the optimized guide as the proposal distribution for importance sampling, and estimate the marginal distribution over `z` with many fewer samples than we would need with the prior as the proposal:

In [None]:
posterior = pyro.infer.Importance(deferred_conditioned_normal_normal_model, normal_normal_parametrized_guide)
marginal_z = pyro.infer.Marginal(posterior, sites=["z"])
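One common way to quantify "many fewer samples" is the effective sample size of the importance weights, ESS = (Σw)² / Σw². The plain-Python sketch below compares two proposals for the model with `x = 1` and `y = 1`. The first is the prior, Normal(1, 1). The second is the exact posterior, Normal(1, 1/√2), which a well-optimized guide should approach. With the exact posterior as proposal, every weight is equal and the ESS equals the number of samples. The helper names are hypothetical.

```python
import math
import random

def normal_logpdf(v, mu, sd):
    # Log density of Normal(mu, sd) at v.
    return -0.5 * ((v - mu) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2 * math.pi)

def effective_sample_size(x, y, guide_mu, guide_sd, num_samples=2000):
    log_ws = []
    for _ in range(num_samples):
        z = random.gauss(guide_mu, guide_sd)
        # Importance weight: model joint density divided by guide density.
        log_ws.append(normal_logpdf(z, x, 1.0) + normal_logpdf(y, z, 1.0)
                      - normal_logpdf(z, guide_mu, guide_sd))
    m = max(log_ws)
    ws = [math.exp(lw - m) for lw in log_ws]
    return sum(ws) ** 2 / sum(w * w for w in ws)

random.seed(0)
ess_prior = effective_sample_size(1.0, 1.0, guide_mu=1.0, guide_sd=1.0)
ess_fitted = effective_sample_size(1.0, 1.0, guide_mu=1.0, guide_sd=1 / math.sqrt(2))
```

In this one-dimensional toy model the gap is modest. The prior keeps roughly 87% of its samples' worth, but the gap typically widens as the posterior becomes more concentrated relative to the prior.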

## Next steps

In the [Variational Autoencoder tutorial](https://pyro.ai/examples/vae.html), we'll augment `normal_normal_model` and `normal_normal_parametrized_guide` with deep neural networks and use stochastic variational inference to build a generative model of images.