In [1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

# True when the CI environment variable is set; presumably consulted by later
# cells to shorten runs (not used in the cells visible here).
smoke_test = 'CI' in os.environ
# This notebook is written against the Pyro 1.0.0 API.
assert pyro.__version__.startswith('1.0.0')
pyro.set_rng_seed(0)  # fixed seed so the printed sample values are reproducible
pyro.enable_validation(True)

In [2]:
# Enumeration can be used either as a stand-alone strategy via
# infer_discrete, or as a component of other strategies. Thus enumeration
# allows Pyro to marginalize out discrete latent variables in HMC and SVI
# models, and to use variational enumeration of discrete variables in SVI
# guides.
def model():
    """Draw the single categorical site ``"z"`` and print whatever it returns.

    Under a plain ``Trace_ELBO`` the site yields one sampled integer; under
    ``TraceEnum_ELBO`` with an enumerated guide it yields the whole support.
    """
    weights = torch.ones(5)  # equal weights -> uniform over the 5 categories
    z_value = pyro.sample("z", dist.Categorical(weights))
    print('model z = {}'.format(z_value))

def guide():
    """Guide mirroring ``model``: draw ``"z"`` uniformly and print the value."""
    weights = torch.ones(5)  # same uniform 5-way prior shape as the model
    z_value = pyro.sample("z", dist.Categorical(weights))
    print('guide z = {}'.format(z_value))

# Ordinary (non-enumerated) ELBO: every sample site draws a single value.
# The captured output shows the guide executing before the model.
elbo = Trace_ELBO()
elbo.loss(model=model, guide=guide);

guide z = 4
model z = 4


In [3]:
# Under the enumeration interpretation, however, the same sample site returns
# a fully enumerated set of values, built from its distribution's
# .enumerate_support() method. With "parallel" enumeration that whole support
# comes back as one tensor in a single model/guide execution.
elbo = TraceEnum_ELBO(max_plate_nesting=0)  # model and guide use no pyro.plate
elbo.loss(
    model=model,
    guide=config_enumerate(guide, default="parallel"),
);

guide z = tensor([0, 1, 2, 3, 4])
model z = tensor([0, 1, 2, 3, 4])


In [4]:
# To support dynamic program structure you can instead use "sequential"
# enumeration: it re-runs the entire model/guide pair once per value in the
# support, so each execution sees a single concrete value, at the cost of
# running the model multiple times.
elbo = TraceEnum_ELBO(max_plate_nesting=0)  # model and guide use no pyro.plate
elbo.loss(
    model=model,
    guide=config_enumerate(guide, default="sequential"),
);

guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0


In [5]:
@config_enumerate
def model():
    """Three-step discrete chain x -> y -> z with every site enumerated.

    ``"p"`` is a learnable 3x3 parameter constrained to the simplex. Row 0
    parameterizes ``x``, and the row selected by the previous value
    parameterizes ``y`` and ``z``. The printed shapes show each enumerated
    site occupying its own new leftmost batch dimension.

    Returns:
        The tuple ``(x, y, z)`` of sampled/enumerated values.
    """
    probs = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(probs[0]))
    y = pyro.sample("y", dist.Categorical(probs[x]))  # row indexed by enumerated x
    z = pyro.sample("z", dist.Categorical(probs[y]))  # row indexed by enumerated y
    for template, value in (('model x.shape = {}', x),
                            ('model y.shape = {}', y),
                            ('model z.shape = {}', z)):
        print(template.format(value.shape))
    return x, y, z

def guide():
    """Empty guide: every latent site of ``model`` is discrete and marked for
    enumeration, so ``TraceEnum_ELBO`` sums them out and nothing is sampled here."""

# Start from an empty param store so "p" is freshly initialized by the model
# rather than reused from any earlier run in this kernel.
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=0)  # no pyro.plate in this model
elbo.loss(model=model, guide=guide);

model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
