# `pyro.contrib.funsor`: a new backend for Pyro

## Summary


## Introduction

In this tutorial we'll cover the basics of `pyro.contrib.funsor`, a new backend for the Pyro probabilistic programming system that is intended to replace the current internals of Pyro and significantly expand its capabilities as both a modelling tool and an inference research platform.

The material here builds on the introduction to Funsor given in (intro tutorial) and on the generic Pyro API `pyroapi`. Additional documentation for Funsor itself can be found on the Pyro website, on GitHub, and in the research paper "Functional Tensors for Probabilistic Programming."

This tutorial is aimed at readers interested in developing custom inference algorithms and understanding Pyro's current and future internals. Those who are less interested in such details should find that they can already use the general-purpose algorithms in `contrib.funsor` with their existing Pyro models via `pyroapi`, as illustrated in (HMM examples).

## Reinterpreting existing Pyro models with `pyroapi`

The new backend uses the [`pyroapi` package](https://pyro.ai/api/) to integrate with existing Pyro code.

First, we import some dependencies:

In [None]:
from collections import OrderedDict

import torch
import funsor

from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.ops.indexing import Vindex
from pyro.poutine.messenger import Messenger

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)

Importing `pyro.contrib.funsor` registers the `"contrib.funsor"` backend with `pyroapi`, which can now be passed as an argument to the `pyroapi.pyro_backend` context manager.

In [None]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import handlers, infer, ops, optim, pyro
from pyroapi import distributions as dist

# this is already done in pyro.contrib.funsor, but we repeat it here
pyroapi.register_backend("contrib.funsor", ...)  # TODO

And we're off! From here on, any `pyro.(...)` statement should be understood as dispatching to the new backend.

## Two new primitives: `to_funsor` and `to_data`

The first and most important new concept in `pyro.contrib.funsor` is the new pair of primitives `pyro.to_funsor` and `pyro.to_data`.

These are *effectful* versions of `funsor.to_funsor` and `funsor.to_data`, i.e. versions whose behavior can be intercepted, controlled, or used to trigger side effects by Pyro's library of algebraic effect handlers. Let's briefly review these two underlying functions before diving into the effectful versions in `pyro.contrib.funsor`.

As one might expect from the name, `to_funsor` takes as inputs objects that are not `funsor.Funsor`s and attempts to convert them into Funsor terms. For example, calling `funsor.to_funsor` on a Python number converts it to a `funsor.terms.Number` object:

In [None]:
funsor_one = funsor.to_funsor(float(1))
print(funsor_one)

funsor_two = funsor.to_funsor(torch.tensor(2.))
print(funsor_two)

Similarly, calling `funsor.to_data` on an atomic `funsor.Funsor` converts it to a regular Python object like a `float` or a `torch.Tensor`:

In [None]:
data_one = funsor.to_data(funsor.terms.Number(float(1), 'real'))
print(data_one)

data_two = funsor.to_data(funsor.Tensor(torch.tensor(2.), OrderedDict(), 'real'))
print(data_two)

In many cases it is necessary to provide an output type to uniquely convert a piece of data to a `funsor.Funsor`. This also means that, strictly speaking, `funsor.to_funsor` and `funsor.to_data` are not inverses. For example, `funsor.to_funsor` will automatically convert Python strings to `funsor.Variable`s, but only when given an output `funsor.domains.Domain`:

In [None]:
var_x = funsor.to_funsor("x", output=funsor.reals(2))
print(var_x)

However, it is often impossible to convert objects to and from Funsor expressions uniquely without additional type information about inputs. Consider the following `torch.Tensor`, which could be converted to a `funsor.Tensor` in several ways.

To resolve this ambiguity, we need to provide `to_funsor` and `to_data` with type information that describes how to convert positional dimensions to and from unordered named Funsor dimensions. This information comes in the form of dictionaries mapping batch dimensions to dimension names or vice versa.

In [None]:
ambiguous_tensor = torch.zeros((3, 1, 2))

# case 1: treat all dimensions as output/event dimensions
funsor1 = funsor.to_funsor(ambiguous_tensor, output=funsor.reals(3, 1, 2))
print(funsor1.inputs, funsor1.output)

# case 2: treat the leftmost dimension as a batch dimension
funsor2 = funsor.to_funsor(ambiguous_tensor, output=funsor.reals(1, 2), dim_to_name={-1: "a"})
print(funsor2.inputs, funsor2.output)

# case 3: treat the leftmost 2 dimensions as batch dimensions; empty batch dimensions are ignored
funsor3 = funsor.to_funsor(ambiguous_tensor, output=funsor.reals(2), dim_to_name={-1: "b", -2: "a"})
print(funsor3.inputs, funsor3.output)

# case 4: treat all dimensions as batch dimensions; empty batch dimensions are ignored
funsor4 = funsor.to_funsor(ambiguous_tensor, output=funsor.reals(), dim_to_name={-1: "c", -2: "b", -3: "a"})
print(funsor4.inputs, funsor4.output)
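
The four cases differ only in how many trailing dimensions are treated as the output and how the remaining batch dimensions are named. The following pure-Python sketch reproduces that bookkeeping on shapes alone. `batch_inputs` is a hypothetical helper, not part of Funsor, and it assumes (as the comments in the cell state) that size-1 batch dimensions are dropped.

```python
def batch_inputs(shape, event_rank, dim_to_name):
    """Hypothetical helper mirroring the cases above on plain shapes.

    The rightmost `event_rank` entries of `shape` form the output shape; the
    rest are batch dims, named via `dim_to_name` (negative batch-dim index ->
    name). Batch dims of size 1 are dropped from the returned inputs.
    """
    split = len(shape) - event_rank
    batch_shape, output_shape = shape[:split], shape[split:]
    unnamed = [d for d in range(-len(batch_shape), 0)
               if d not in dim_to_name and batch_shape[d] != 1]
    if unnamed:
        raise ValueError(f"batch dims {unnamed} have size > 1 but no name")
    inputs = {}
    for dim, name in sorted(dim_to_name.items()):  # leftmost dim first
        if batch_shape[dim] != 1:
            inputs[name] = batch_shape[dim]
    return inputs, output_shape


shape = (3, 1, 2)
print(batch_inputs(shape, 3, {}))                           # ({}, (3, 1, 2))
print(batch_inputs(shape, 2, {-1: "a"}))                    # ({'a': 3}, (1, 2))
print(batch_inputs(shape, 1, {-1: "b", -2: "a"}))           # ({'a': 3}, (2,))
print(batch_inputs(shape, 0, {-1: "c", -2: "b", -3: "a"}))  # ({'a': 3, 'c': 2}, ())
```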

Similar ambiguity exists for `to_data`: the `inputs` of a `funsor.Funsor` are ordered arbitrarily, and empty dimensions in the data are squeezed away, so a mapping from names to batch dimensions must be provided to ensure unique conversion:

In [None]:
ambiguous_funsor = funsor.Tensor(torch.zeros((3, 2)), OrderedDict(a=funsor.bint(3), b=funsor.bint(2)), 'real')
print(ambiguous_funsor.inputs, ambiguous_funsor.output)

# case 1: the simplest version
tensor1 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -2, "b": -1})
print(tensor1.shape)

# case 2: an empty dimension between a and b
tensor2 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -3, "b": -1})
print(tensor2.shape)

# case 3: permuting the input dimensions
tensor3 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -1, "b": -2})
print(tensor3.shape)
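
Conversely, the shape returned by `to_data` follows from placing each named input at its assigned negative dimension and filling every gap with size 1. A minimal sketch, with `data_shape` as a hypothetical helper that works on plain `name -> size` dictionaries rather than real Funsor terms:

```python
def data_shape(inputs, name_to_dim):
    """Hypothetical helper: the shape `to_data` would produce, on plain shapes.

    `inputs` maps each input name to its size; `name_to_dim` assigns each name
    a negative dim. Positions with no input are filled with size 1, and names
    in `name_to_dim` that are not inputs are ignored.
    """
    dims = [name_to_dim[name] for name in inputs]
    if len(set(dims)) != len(dims):
        raise ValueError("two inputs share a dimension")
    shape = [1] * -min(dims, default=0)
    for name, size in inputs.items():
        shape[name_to_dim[name]] = size
    return tuple(shape)


inputs = {"a": 3, "b": 2}
print(data_shape(inputs, {"a": -2, "b": -1}))  # (3, 2)
print(data_shape(inputs, {"a": -3, "b": -1}))  # (3, 1, 2)
print(data_shape(inputs, {"a": -1, "b": -2}))  # (2, 3)
```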

Maintaining and updating this information efficiently becomes tedious and error-prone as the number of conversions increases. Fortunately, it can be automated away completely. Consider the following example:

In [None]:
name_to_dim = OrderedDict()

funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.bint(2)), 'real')
name_to_dim.update({"x": -1})
tensor_x = funsor.to_data(funsor_x, name_to_dim=name_to_dim)
print(name_to_dim, funsor_x.inputs, tensor_x.shape)

funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.bint(3), x=funsor.bint(2)), 'real')
name_to_dim.update({"y": -2})
tensor_y = funsor.to_data(funsor_y, name_to_dim=name_to_dim)
print(name_to_dim, funsor_y.inputs, tensor_y.shape)

funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.bint(2), y=funsor.bint(3)), 'real')
name_to_dim.update({"z": -3})
tensor_z = funsor.to_data(funsor_z, name_to_dim=name_to_dim)
print(name_to_dim, funsor_z.inputs, tensor_z.shape)

This is exactly the functionality provided by `pyro.to_funsor` and `pyro.to_data`, as we can see by using them in the previous example and removing the manual updates. We must also wrap the code in a `handlers.named` effect handler to ensure that the dimension dictionaries do not persist beyond its body.

In [None]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    px = ...
    funsor_px = pyro.to_funsor(px, output=funsor.reals())
    print(funsor_px.inputs)
    py = ...
    funsor_py = pyro.to_funsor(py, output=funsor.reals())
    print(funsor_py.inputs)
    pz = ...
    funsor_pz = pyro.to_funsor(pz, output=funsor.reals())
    print(funsor_pz.inputs)

Critically, `pyro.to_funsor` and `pyro.to_data` use and update the same bidirectional mapping between names and dimensions, allowing them to be combined intuitively. A typical usage pattern, and one that `pyro.contrib.funsor` uses heavily in its inference algorithm implementations, is to create a `funsor.Funsor` term directly with a new named dimension and call `pyro.to_data` on it, perform some PyTorch computations, and call `pyro.to_funsor` on the result:

In [None]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    probs = funsor.Tensor(torch.tensor([0.5, 0.4, 0.7]), OrderedDict(batch=funsor.bint(3)))
    x = funsor.Tensor(torch.tensor([0., 1., 0., 1.]), OrderedDict(x=funsor.bint(4)))
    dx = dist.Bernoulli(probs=pyro.to_data(probs))
    print(type(dx), dx.batch_shape)
    px = pyro.to_funsor(dx.log_prob(pyro.to_data(x)), output=funsor.reals())
    print(type(px), px.inputs, px.output)
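
The shapes in this pattern can be traced by hand. Assuming `batch` is placed at dimension -1 and `x` at -2 (the order in which they are first converted), `pyro.to_data` yields tensors of shape `(3,)` and `(4, 1)`, `log_prob` broadcasts them to `(4, 3)`, and reading the result back with the same mapping recovers both names. The helpers below are a hypothetical pure-Python sketch of that bookkeeping, using PyTorch's right-aligned broadcasting rule.

```python
def broadcast_shapes(*shapes):
    """Right-aligned broadcasting as in NumPy/PyTorch (pure-Python sketch)."""
    ndim = max(len(s) for s in shapes)
    result = []
    for dim in range(-ndim, 0):
        sizes = {s[dim] for s in shapes if len(s) >= -dim and s[dim] != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible sizes {sorted(sizes)} at dim {dim}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def inputs_from_shape(shape, dim_to_name):
    """Read named inputs back from a shape, skipping size-1 dims."""
    return {dim_to_name[d]: shape[d] for d in range(-len(shape), 0) if shape[d] != 1}


name_to_dim = {"batch": -1, "x": -2}  # assumed allocation order
dim_to_name = {d: n for n, d in name_to_dim.items()}
probs_shape = (3,)                    # batch at dim -1
x_shape = (4, 1)                      # x at dim -2
log_prob_shape = broadcast_shapes(probs_shape, x_shape)
print(log_prob_shape)                                  # (4, 3)
print(inputs_from_shape(log_prob_shape, dim_to_name))  # {'x': 4, 'batch': 3}
```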

## Dealing with large numbers of variables: (re-)introducing `pyro.markov`

So far, so good. However, what if the number of different named dimensions continues to increase? We face two problems: first, reusing the fixed number of available positional dimensions (25 in PyTorch), and second, computing shape information with time complexity that is independent of the number of variables.

A fully general automated solution to this problem would require deeper integration with Python or PyTorch. Instead, as an intermediate solution, we introduce the second key concept in `pyro.contrib.funsor`: the `pyro.markov` annotation, a way to indicate the shelf life of certain variables. `pyro.markov` is already part of Pyro (see enumeration tutorial) but the implementation in `pyro.contrib.funsor` is fresh.

The primary constraint on the design of `pyro.markov` is backwards compatibility: in order for `pyro.contrib.funsor` to be compatible with the large range of existing Pyro models, the new implementation had to match the shape semantics of Pyro's existing enumeration machinery as closely as possible. See the final section of this tutorial for more on this design decision.

In [None]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10)):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.bint(2)})))
        print(i, x.shape)

`pyro.markov` is a versatile piece of syntax that can be used as a context manager, a decorator, or an iterator. At present its only job is to track variable usage; it does not directly communicate conditional independence properties to inference algorithms. It is therefore only necessary to add enough annotations to ensure that tensors have correct shapes, rather than attempting to manually encode as much dependency information as possible.

`pyro.markov` takes an additional argument `history` that determines the number of previous `pyro.markov` contexts to take into account when building the mapping between names and dimensions at a given `pyro.to_funsor`/`pyro.to_data` call.

In [None]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10), history=2):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.bint(2)})))
        print(i, x.shape)
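
One way to picture the effect of `history` is a simplified allocator in which the variable created at step `i` may not share a dimension with any variable from the previous `history` steps, so dimensions are recycled with period `history + 1`. This is an assumption-laden sketch (`markov_dims` and `shape_at` are hypothetical helpers), not the actual `pyro.contrib.funsor` allocator, whose exact assignments may differ.

```python
def markov_dims(num_steps, history=1):
    """Hypothetical: dim assigned to x{i}, avoiding the last `history` steps."""
    dims = []
    for i in range(num_steps):
        live = set(dims[max(0, i - history):i])  # dims still visible at step i
        dim = -1
        while dim in live:
            dim -= 1
        dims.append(dim)
    return dims


def shape_at(dim, size=2):
    """Shape of a size-`size` tensor whose only named input sits at `dim`."""
    return (size,) + (1,) * (-dim - 1)


print(markov_dims(6, history=1))  # [-1, -2, -1, -2, -1, -2]
print(markov_dims(6, history=2))  # [-1, -2, -3, -1, -2, -3]
print([shape_at(d) for d in markov_dims(4, history=1)])  # [(2,), (2, 1), (2,), (2, 1)]
```

With this model, a longer history keeps more dimensions alive at once, so the number of positional dimensions in use grows with `history` rather than with the total number of variables.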

## Use cases beyond enumeration: global and visible dimensions

### Global dimensions

It is sometimes useful to have dimensions and variables ignore the `pyro.markov` structure of a program. For example, suppose we wanted to draw a batch of samples from a Pyro model's joint distribution. To accomplish this we indicate to `pyro.to_data` that a dimension should be treated as "global" (`DimTypes.GLOBAL`) via the `dim_type` keyword argument.

In [None]:
from pyro.contrib.funsor.runtime import _DIM_STACK, DimTypes

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_particle_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.bint(10)))
    tensor_particle_ids = pyro.to_data(funsor_particle_ids, dim_type=DimTypes.GLOBAL)
    print(funsor_particle_ids.inputs, tensor_particle_ids.shape)
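
The distinction can be sketched with a hypothetical allocator holding a global frame that survives every markov step and a local frame that is discarded at the start of each new step; local names simply avoid dimensions already held by globals. `DimStack`, `allocate` and `new_step` are illustrative names only; the real `_DIM_STACK` in `pyro.contrib.funsor` is more elaborate.

```python
class DimStack:
    """Hypothetical allocator: a global frame that outlives markov steps and a
    local frame that is discarded at the start of every step."""

    def __init__(self):
        self.global_dims = {}  # name -> dim, never recycled automatically
        self.local_dims = {}   # name -> dim, cleared by new_step()

    def _first_free(self):
        used = set(self.global_dims.values()) | set(self.local_dims.values())
        dim = -1
        while dim in used:
            dim -= 1
        return dim

    def allocate(self, name, is_global=False):
        frame = self.global_dims if is_global else self.local_dims
        if name not in frame:
            frame[name] = self._first_free()
        return frame[name]

    def new_step(self):
        self.local_dims = {}


stack = DimStack()
particle_dim = stack.allocate("n", is_global=True)
step_dims = []
for i in range(3):
    stack.new_step()  # locals from the previous step expire
    step_dims.append(stack.allocate(f"x{i}"))
print(particle_dim, step_dims)  # -1 [-2, -2, -2]
print(stack.global_dims)        # {'n': -1}
```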

`pyro.markov` does the hard work of automatically managing local dimensions, but because global dimensions ignore this structure, they must be deallocated manually or they will persist until the last active effect handler exits, just as global variables in Python persist until a program execution finishes.

In [None]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    
    funsor_plate1_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate1=funsor.bint(10)))
    tensor_plate1_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimTypes.GLOBAL)
    print(funsor_plate1_ids.inputs, tensor_plate1_ids.shape)
    
    funsor_plate2_ids = funsor.Tensor(torch.arange(9), OrderedDict(plate2=funsor.bint(9)))
    tensor_plate2_ids = pyro.to_data(funsor_plate2_ids, dim_type=DimTypes.GLOBAL)
    print(funsor_plate2_ids.inputs, tensor_plate2_ids.shape)
    
    _DIM_STACK.global_frame.free("plate1", -1)
    
    funsor_plate3_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate3=funsor.bint(10)))
    tensor_plate3_ids = pyro.to_data(funsor_plate3_ids, dim_type=DimTypes.GLOBAL)
    print(funsor_plate3_ids.inputs, tensor_plate3_ids.shape)

Performing this deallocation directly is often unnecessary, and we include this interaction primarily to illuminate the internals of `pyro.contrib.funsor`. Instead, effect handlers that introduce global dimensions, like `pyro.plate`, may inherit from the `pyro.contrib.funsor.handlers.named_messenger.GlobalNamedMessenger` effect handler, which deallocates global dimensions generically upon entry and exit. We will see an example of this in the next tutorial.
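
The behavior such a handler automates can be sketched as a context manager that allocates a global dimension on entry and frees it on exit, so that a later name can reuse the freed dimension, much as freeing `plate1` is meant to make room for `plate3` in the example above. `GlobalFrame` and `global_scope` are hypothetical; this is not the `GlobalNamedMessenger` implementation.

```python
import contextlib


class GlobalFrame:
    """Hypothetical global frame: name -> dim, freed only on request."""

    def __init__(self):
        self.name_to_dim = {}

    def allocate(self, name):
        if name not in self.name_to_dim:
            used = set(self.name_to_dim.values())
            dim = -1
            while dim in used:
                dim -= 1
            self.name_to_dim[name] = dim
        return self.name_to_dim[name]

    def free(self, name):
        del self.name_to_dim[name]


@contextlib.contextmanager
def global_scope(frame, name):
    """Allocate a global dim for `name` on entry and free it on exit."""
    dim = frame.allocate(name)
    try:
        yield dim
    finally:
        frame.free(name)


frame = GlobalFrame()
with global_scope(frame, "plate1") as d1:
    with global_scope(frame, "plate2") as d2:
        print(d1, d2)         # -1 -2
print(frame.name_to_dim)      # {}
with global_scope(frame, "plate3") as d3:
    print(d3)                 # -1, reusing the dim freed by plate1
```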

### Visible dimensions

We might also wish to preserve the meaning of the shape of a tensor of data. For this we indicate to `pyro.to_data` that a dimension should be treated as not merely global but "visible" (`DimTypes.VISIBLE`). We first need to indicate the maximum number of "visible" dimensions (an artifact of Pyro's design). Users who have come across `pyro.infer.TraceEnum_ELBO`'s `max_plate_nesting` argument are already familiar with this.

In [None]:
from pyro.contrib.funsor.runtime import _DIM_STACK, DimTypes

prev_first_available_dim = _DIM_STACK.set_first_available_dim(-2)

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_local_ids = funsor.Tensor(torch.arange(10), OrderedDict(k=funsor.bint(10)))
    tensor_local_ids = pyro.to_data(funsor_local_ids, dim_type=DimTypes.LOCAL)
    print(funsor_local_ids.inputs, tensor_local_ids.shape)
    
    funsor_global_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.bint(10)))
    tensor_global_ids = pyro.to_data(funsor_global_ids, dim_type=DimTypes.GLOBAL)
    print(funsor_global_ids.inputs, tensor_global_ids.shape)
    
    funsor_data_ids = funsor.Tensor(torch.arange(10), OrderedDict(m=funsor.bint(10)))
    tensor_data_ids = pyro.to_data(funsor_data_ids, dim_type=DimTypes.VISIBLE)
    print(funsor_data_ids.inputs, tensor_data_ids.shape)
    
# we also need to reset the first_available_dim after we're done
_DIM_STACK.set_first_available_dim(prev_first_available_dim)
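
A hypothetical sketch of how `first_available_dim` might partition dimensions: names marked visible are placed from -1 leftward and must stay to the right of `first_available_dim`, while all other names start at `first_available_dim`. This mirrors the convention in Pyro's enumeration, where `TraceEnum_ELBO` uses `first_available_dim = -1 - max_plate_nesting`. The local/global distinction is deliberately left out, and `DimAllocator` is an illustrative name, not part of `pyro.contrib.funsor`.

```python
class DimAllocator:
    """Hypothetical allocator: dims to the right of `first_available_dim` are
    reserved for visible names; every other name starts at first_available_dim."""

    def __init__(self, first_available_dim):
        self.first_available_dim = first_available_dim
        self.name_to_dim = {}

    def allocate(self, name, visible=False):
        if name in self.name_to_dim:
            return self.name_to_dim[name]
        used = set(self.name_to_dim.values())
        dim = -1 if visible else self.first_available_dim
        while dim in used:
            dim -= 1
        if visible and dim <= self.first_available_dim:
            raise ValueError("no visible dims left; lower first_available_dim")
        self.name_to_dim[name] = dim
        return dim


max_plate_nesting = 1
alloc = DimAllocator(first_available_dim=-1 - max_plate_nesting)  # -2
print(alloc.allocate("k"))                # -2, a local name
print(alloc.allocate("n"))                # -3, a global name
print(alloc.allocate("m", visible=True))  # -1, the single visible dim
```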

Visible dimensions are also global and must therefore be deallocated manually or they will persist until the last effect handler exits, as in the previous example.

Again, interacting directly with the dimension allocator is almost always unnecessary, and we include it here only to illuminate the inner workings of `pyro.contrib.funsor`. Instead, effect handlers like `pyro.handlers.enum` that may introduce non-visible dimensions that could conflict with visible ones should inherit from the base `pyro.contrib.funsor.handlers.named_messenger.BaseEnumMessenger` effect handler. We will see an example of this in the next tutorial.

## Appendix: design rationale

A final note: to some extent, `pyro.contrib.funsor` is a workaround for lack of easy program analysis and manipulation in Python and PyTorch. However, from PyTorch 1.5 and beyond, we expect this underlying issue to largely disappear, and hope to support increasing (but selective) use of Funsor in model code, not just inference code, and more systematically support propagating real-valued `funsor.Variable`s in model code for fully lazy inference evaluation. See `funsor.minipyro` for more on this.