# Inference with discrete latent variables

This tutorial describes Pyro's enumeration strategy for discrete latent variable models.
First read the [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html) and the [Gaussian Mixture Model Tutorial](http://pyro.ai/examples/gmm.html).

#### Summary 

- Pyro implements automatic enumeration over discrete latent variables.
- This strategy can be used alone or inside SVI (via [TraceEnum_ELBO](http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.TraceEnum_ELBO)), HMC, or NUTS.
- Annotate a sample site with `infer={"enumerate": "parallel"}` to trigger enumeration.
- If a sample site determines downstream structure, instead use `{"enumerate": "sequential"}`.
- Write your models to allow arbitrarily deep batching on the left, e.g. use broadcasting.
- Inference cost is exponential in treewidth, so try to write models with narrow treewidth.
- If you have trouble, let us know on [forum.pyro.ai](https://forum.pyro.ai)!

#### Table of contents

- [Overview](#Overview)
- [Mechanics of enumeration](#Mechanics-of-enumeration)

In [1]:
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate

pyro.enable_validation()
pyro.set_rng_seed(0)

## Overview <a class="anchor" id="Overview"></a>

Pyro's enumeration strategy encompasses popular algorithms including variable elimination, exact message passing, forward-filter-backward-sample, inside-out, Baum-Welch, and many other special-case algorithms. Aside from enumeration, Pyro implements a number of inference strategies including variational inference ([SVI](http://docs.pyro.ai/en/dev/inference_algos.html)) and Monte Carlo ([HMC](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.HMC) and [NUTS](http://docs.pyro.ai/en/dev/mcmc.html#pyro.infer.mcmc.NUTS)). Enumeration can be used either as a stand-alone strategy or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides.

## Mechanics of enumeration  <a class="anchor" id="Mechanics-of-enumeration"></a>

The core idea of enumeration is to interpret discrete [pyro.sample](http://docs.pyro.ai/en/dev/primitives.html#pyro.sample) statements as full enumeration rather than random sampling. Other inference algorithms can then sum out the enumerated values. For example, under the standard "sample" interpretation a sample statement might return a scalar (we'll illustrate with a trivial model and guide):

In [2]:
def model():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print('model z = {}'.format(z))

def guide():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print('guide z = {}'.format(z))

elbo = Trace_ELBO()
elbo.loss(model, guide);

guide z = 4
model z = 4


However, under the enumeration interpretation the same sample site returns the full set of values in its support, as computed by its distribution's [.enumerate_support()](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.enumerate_support) method.

In [3]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"));

guide z = tensor([ 0,  1,  2,  3,  4])
model z = tensor([ 0,  1,  2,  3,  4])


Note that we've used "parallel" enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but it requires that downstream program structure not branch on the value of `z`. To support dynamic program structure, you can instead use "sequential" enumeration, which runs the entire model-guide pair once per sample value.

In [4]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "sequential"));

guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0


Parallel enumeration is cheaper but more complex than sequential enumeration, so we'll focus the rest of this tutorial on the parallel variant. Note that both forms can be interleaved.

### Multiple latent variables

We just saw that a single discrete sample site can be enumerated via a nonstandard interpretation. A model with a single discrete latent variable is a mixture model. Models with multiple discrete latent variables can be more complex, including HMMs, CRFs, DBNs, and other structured models. In models with multiple discrete latent variables, Pyro enumerates each variable in a different tensor dimension (counting from the right; see the [Tensor Shapes Tutorial](http://pyro.ai/examples/tensor_shapes.html)). This allows Pyro to determine the dependency graph among variables and then perform cheap exact inference using variable elimination algorithms.

To understand how enumeration dimensions are allocated, consider the following model. Here we marginalize the variables out of the model itself rather than enumerating them in the guide.

In [5]:
@config_enumerate(default="parallel")
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))
    print('model x.shape = {}'.format(x.shape))
    print('model y.shape = {}'.format(y.shape))
    print('model z.shape = {}'.format(z.shape))
    
def guide():
    pass

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide);

model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
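
The following plain-PyTorch sketch (no Pyro inference involved) shows why these shapes matter, using the same x → y → z chain with z observed. It builds every factor over the enumerated dimensions by broadcasting, using the shapes printed above for x and y. It then checks that three computations give the same marginal likelihood: summing the broadcast joint, eliminating one variable at a time, and brute-force looping over all assignments. The normalized transition matrix, the random seed, and the observed value `z_obs = 2` are illustrative assumptions, and Pyro's own contraction order may differ.

```python
import itertools
import torch

torch.manual_seed(0)
# Hypothetical normalized matrix playing the role of p[0], p[x], p[y] in the model above.
p = torch.randn(3, 3).exp()
p = p / p.sum(-1, keepdim=True)
z_obs = 2  # hypothetical observation of z

# Enumerated values in the shapes printed above: x on dim -1, y on dim -2.
x = torch.arange(3)                # shape [3]
y = torch.arange(3).reshape(3, 1)  # shape [3, 1]

# Broadcasting builds the joint over all enumerated values at once.
joint = p[0][x] * p[x, y] * p[y, z_obs]  # shape [3, 3], indexed [y, x]
broadcast_total = joint.sum()

# Variable elimination: sum out y first, then x (a chain of matrix-vector products).
msg_y = p @ p[:, z_obs]  # msg_y[x] = sum_y p(y | x) p(z_obs | y)
elim_total = p[0] @ msg_y  # sum_x p(x) msg_y[x]

# Brute force over all 9 joint assignments, for comparison.
brute_total = sum(p[0, i] * p[i, j] * p[j, z_obs]
                  for i, j in itertools.product(range(3), range(3)))
print(broadcast_total, elim_total, brute_total)
```

The brute-force loop grows exponentially with the number of variables. Elimination along the chain stays linear, which is the sense in which narrow treewidth keeps inference tractable.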


## Plates and enumeration

## Time series example

## How to write a tractable model

## Whether to enumerate in the guide or model