# `pyro.contrib.funsor`: a new backend for Pyro (pt. 2)

In [1]:
from collections import OrderedDict
import functools

import torch
from torch.distributions import constraints

import funsor

from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.ops.indexing import Vindex
from pyro.poutine.messenger import Messenger

# Select the PyTorch backend for funsor before any funsor objects are created below.
funsor.set_backend("torch")
# Make float32 the default floating-point dtype for newly created torch tensors.
torch.set_default_dtype(torch.float32)
# Fix the random seed so that the sampled values in this notebook are reproducible.
pyro_set_rng_seed(101)

## Introduction

In part 1 of this tutorial, we were introduced to the new `pyro.contrib.funsor` backend for Pyro.

Here we'll look at how to use the components in `pyro.contrib.funsor` to implement a variable elimination inference algorithm from scratch. As before, we'll use `pyroapi` so that we can write our model with standard Pyro syntax.

In [2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import infer, handlers, ops, optim, pyro
from pyroapi import distributions as dist

We will be working with the following model throughout:

In [3]:
def model(data):
    """Discrete-state Markov model with Gaussian observations.

    Declares a 3x3 simplex-constrained parameter ``probs`` and a length-3
    parameter ``locs_mean``, samples the continuous site ``locs`` around
    ``locs_mean``, and then walks over ``data``: at step ``i`` it samples a
    discrete site ``x{i}`` (marked for parallel enumeration) from the row of
    ``probs`` selected by the previous state, and observes ``data[i]`` at
    site ``y{i}`` under a unit-scale Normal centred on the entry of ``locs``
    selected by ``x{i}``.

    The shapes of ``locs`` and of every ``x{i}`` are printed as they are
    produced.

    Args:
        data: indexable sequence of observations, one per time step.
    """
    trans_probs = pyro.param(
        "probs", lambda: torch.rand((3, 3)), constraint=constraints.simplex
    )
    prior_mean = pyro.param("locs_mean", lambda: torch.ones((3,)))
    locs = pyro.sample("locs", dist.Normal(prior_mean, 1.).to_event(1))
    print("locs", locs.shape)

    # The chain starts in state 0; each step conditions on the previous state.
    state = 0
    for t in pyro.markov(range(len(data))):
        state_name = f"x{t}"
        state = pyro.sample(
            state_name,
            dist.Categorical(trans_probs[state]),
            infer={"enumerate": "parallel"},
        )
        print(state_name, state.shape)
        # Vindex selects the entry of locs that matches the current state.
        emission = dist.Normal(Vindex(locs)[..., state], 1.)
        pyro.sample(f"y{t}", emission, obs=data[t])

## Enumerating discrete variables

Our first step is to implement an effect handler that performs parallel enumeration of discrete latent variables. We'll do that by constructing a `funsor.Tensor` representing the support of each discrete latent variable and using the new `pyro.to_data` primitive from part 1 to convert it to a `torch.Tensor` with the appropriate shape.

In part 1 we also saw that it was necessary to provide the number of "visible" dimensions used in a block of code by calling `_DIM_STACK.set_first_available_dim`. To avoid this tedious bit of bookkeeping, we'll have our enumeration effect handler inherit from the `BaseEnumMessenger` class provided in `pyro.contrib.funsor`, which takes care of setting `first_available_dim` and resetting it after the handler has exited. Our enumeration handler's constructor will take a `first_available_dim` keyword argument because of this, just like Pyro's `poutine.enum`.

In [4]:
from pyro.contrib.funsor.handlers.named_messenger import BaseEnumMessenger

class EnumMessenger(BaseEnumMessenger):
    """Effect handler that enumerates discrete latent sample sites in parallel.

    For every sample site that is not yet done, is not observed, and is
    annotated with ``infer={"enumerate": "parallel"}``, the site's value is
    replaced by the full support of its distribution. The support is wrapped
    in a ``funsor.Tensor`` with a single input named after the site, and
    ``pyro.to_data`` converts it to a ``torch.Tensor``.

    Args:
        first_available_dim: forwarded unchanged to
            ``BaseEnumMessenger.__init__``.
    """

    # although our __init__ does not do anything extra, we specify it explicitly here
    # to show the argument that needs to be passed to `BaseEnumMessenger.__init__`.
    def __init__(self, first_available_dim=-1):
        super().__init__(first_available_dim=first_available_dim)

    @pyroapi.pyro_backend("contrib.funsor")  # necessary since we invoke pyro.to_data
    def _pyro_sample(self, msg):
        """Replace the value of an enumerable site with its full support.

        Sites that are already done, observed, or not marked for parallel
        enumeration are left untouched.
        """
        if msg["done"] or msg["is_observed"] or msg["infer"].get("enumerate") != "parallel":
            return

        raw_value = msg["fn"].enumerate_support(expand=False)
        size = raw_value.numel()
        # Flatten with reshape(-1) rather than squeeze(): squeeze() would turn a
        # single-value support into a 0-d tensor, which no longer lines up with
        # the one named input of size `size` declared below.
        funsor_value = funsor.Tensor(
            raw_value.reshape(-1), OrderedDict([(msg["name"], funsor.bint(size))]), size)

        msg["value"] = pyro.to_data(funsor_value)
        # Mark the site as done so no value is drawn for it downstream.
        msg["done"] = True

## Vectorizing a model across multiple samples

Next, since our priors over global variables are continuous and cannot be enumerated exactly, we will implement an effect handler that uses a global dimension to draw multiple samples in parallel from the model.

Recall that in part 1 we saw that `DimType.GLOBAL` dimensions must be deallocated manually or they will persist until the final effect handler has exited. This low-level detail is taken care of automatically by the `GlobalNamedMessenger` handler provided in `pyro.contrib.funsor` as a base class for any effect handlers that allocate global dimensions. Our vectorization effect handler will inherit from this class.

In [5]:
from pyro.poutine.broadcast_messenger import BroadcastMessenger
from pyro.poutine.indep_messenger import CondIndepStackFrame

from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger
from pyro.contrib.funsor.handlers.runtime import DimType

class VectorizeMessenger(GlobalNamedMessenger):
    """Effect handler that vectorizes a model over a named global dimension.

    On entry, a ``funsor.Tensor`` of indices ``0..size-1`` with a single input
    called ``name`` is converted to a ``torch.Tensor`` via ``pyro.to_data``
    using ``DimType.GLOBAL``. Every sample site seen while the handler is
    active gets a ``CondIndepStackFrame`` for that dimension prepended to its
    ``cond_indep_stack`` and is then passed to
    ``BroadcastMessenger._pyro_sample``.

    Args:
        size: number of parallel samples (length of the vectorized dimension).
        name: name of the vectorized dimension.
    """

    def __init__(self, size, name="_PARTICLES"):
        super().__init__()
        self.name = name
        self.size = size
        self._indices = funsor.Tensor(
            torch.arange(0, self.size),
            OrderedDict([(self.name, funsor.bint(self.size))])
        )

    @pyroapi.pyro_backend("contrib.funsor")
    def __enter__(self):
        super().__enter__()  # do this first to take care of global dim recycling
        # Here we indicate that the vectorization dimension is a DimType.GLOBAL dimension,
        # as opposed to a DimType.VISIBLE dimension that we would use in pyro.plate
        indices = pyro.to_data(self._indices, dim_type=DimType.GLOBAL)
        # extract the dimension allocated by to_data to match plate's current behavior.
        # reshape(-1) is used instead of squeeze() so that the stored indices stay
        # one-dimensional even when size == 1 (squeeze() would yield a 0-d tensor).
        self.dim, self.indices = -indices.dim(), indices.reshape(-1)
        return self

    def _pyro_sample(self, msg):
        """Prepend this handler's frame to the site and broadcast it."""
        frame = CondIndepStackFrame(self.name, self.dim, self.size, 0)
        msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
        BroadcastMessenger._pyro_sample(msg)

## Computing an ELBO with variable elimination

Our final effect handler will build up a lazy Funsor expression for the log-joint density of the model. Reducing this expression over the enumerated variables yields the marginal likelihood, the loss function we'll be using to train our approximate posterior over the global variables.

In [6]:
class LogJointMessenger(Messenger):
    """Effect handler that accumulates a lazy funsor for the log-joint.

    ``log_joint`` is reset to a zero-valued real funsor each time the handler
    is entered. After every sample site, the site's distribution and value
    are converted to funsors and the distribution evaluated at that value is
    added to ``log_joint`` under the lazy interpretation.
    """

    @pyroapi.pyro_backend("contrib.funsor")
    def __enter__(self):
        # Start each run from a fresh accumulator equal to zero.
        self.log_joint = pyro.to_funsor(0., funsor.reals())
        return super().__enter__()

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_post_sample(self, msg):
        lazy = funsor.interpreter.interpretation(funsor.terms.lazy)
        with lazy:
            site_dist = pyro.to_funsor(msg["fn"], output=funsor.reals())
            # Convert the site's value using the domain of the distribution's "value" input.
            site_value = pyro.to_funsor(msg["value"], site_dist.inputs["value"])
            self.log_joint += site_dist(value=site_value)

And the actual loss function:

In [7]:
@pyroapi.pyro_backend("contrib.funsor")
def log_z(model, *args):
    """Run ``model`` under the three handlers and reduce its log-joint.

    The model is executed inside ``LogJointMessenger``, ``EnumMessenger`` and
    ``VectorizeMessenger(20)``. The accumulated log-joint is then reduced
    lazily in two steps: every input other than ``"_PARTICLES"`` is reduced
    with ``logaddexp``, and the ``"_PARTICLES"`` input is reduced with
    ``add``. The resulting expression is passed through funsor's optimizer
    and converted back with ``pyro.to_data``.

    Args:
        model: callable to run.
        *args: positional arguments forwarded to ``model``.

    Returns:
        The result of ``pyro.to_data`` applied to the optimized expression.
    """
    with LogJointMessenger() as joint_tracker, EnumMessenger(), VectorizeMessenger(20):
        model(*args)

    lazy = funsor.interpreter.interpretation(funsor.terms.lazy)
    with lazy:
        particle_vars = frozenset({"_PARTICLES"})
        # Every remaining named input of the log-joint is reduced with logaddexp.
        enum_vars = frozenset(joint_tracker.log_joint.inputs) - particle_vars
        marginalized = joint_tracker.log_joint.reduce(funsor.ops.logaddexp, enum_vars)
        expr = marginalized.reduce(funsor.ops.add, particle_vars)

    return pyro.to_data(funsor.optimizer.apply_optimizer(expr))

## Putting it all together

Finally, with all this machinery implemented, we can perform inference in our model.

In [9]:
# Ten copies of the same scalar observation.
data = [torch.tensor(1.)] * 10
log_marginal = log_z(model, data)
# Gradients are taken with respect to the unconstrained values of both parameters.
params = [pyro.param(name).unconstrained() for name in ("probs", "locs_mean")]
print(log_marginal, torch.autograd.grad(log_marginal, params))

x0 torch.Size([3, 1])
x1 torch.Size([3, 1, 1])
x2 torch.Size([3, 1])
x3 torch.Size([3, 1, 1])
x4 torch.Size([3, 1])
x5 torch.Size([3, 1, 1])
x6 torch.Size([3, 1])
x7 torch.Size([3, 1, 1])
x8 torch.Size([3, 1])
x9 torch.Size([3, 1, 1])
tensor(-333.6899, grad_fn=<SumBackward0>) (tensor([[ 3.1844,  1.0427, -4.2271],
        [ 0.4608,  0.1888, -0.6496],
        [-1.3316, -0.3965,  1.7282]]), tensor([ 7.4881,  5.9110, 17.3450]))
