# `pyro.contrib.funsor`: a new backend for Pyro (pt. 3)

## Introduction

In Part 1 of this tutorial, we looked at new low-level functionality, especially the new primitives `pyro.to_funsor` and `pyro.to_data`, and in Part 2 we saw a simple example of how those low-level components could be used to implement a powerful inference algorithm.

In this, the final part of the `pyro.contrib.funsor` tutorial, we'll look at how the new implementations of higher-level Pyro machinery can be used together with Funsor to drastically simplify the implementation of Pyro's most powerful general-purpose inference engine, `pyro.infer.TraceEnum_ELBO`. What we'll end up with is not a toy: it's a fully functional version of `TraceEnum_ELBO`.

In [None]:
import torch
import funsor

from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.infer import ELBO
from pyro.infer import TraceEnum_ELBO as OrigTraceEnum_ELBO

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)

As before, we'll use `pyroapi` to run existing models with the `pyro.contrib.funsor` backend.

In [None]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import handlers, infer, ops, optim, pyro
from pyroapi import distributions as dist

We'll start by reviewing what's actually being computed. Readers who need a refresher on the basics of variational inference in Pyro should check out the variational inference tutorials.
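
Here is a brief summary of the computation. The notation is ours, not the original text's: $x$ stands for observed sites, $z$ for latent sites that the guide $q$ samples or enumerates, and $y$ for discrete latent sites enumerated only in the model $p$. Under these assumptions, the code below computes

$$
\mathrm{ELBO} = \mathbb{E}_{q(z)}\left[\log \sum_{y} p(x, y, z) - \log q(z)\right] \le \log p(x).
$$

The sum over $y$ sits inside the logarithm; this is exact marginalization, done by `integrate_model_vars`. The expectation over $z$ sits outside it and is done by `integrate_guide_vars`. The inequality follows from Jensen's inequality.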


Now we're ready to dive into the actual computation. We'll start from the topmost level (defining the `TraceEnum_ELBO` class) and work our way through each subroutine of the algorithm.

In [None]:
class TraceEnum_ELBO(ELBO):

    def _get_trace(self, *args, **kwargs):  # must be defined to avoid NotImplementedError
        raise ValueError("shouldn't be here")

    @pyroapi.pyro_backend("contrib.funsor")
    def differentiable_loss(self, model, guide, *args, **kwargs):
        
        # get traces: this part is exactly the same in Pyro and the new backend
        model_tr, guide_tr = get_traces(model, guide, -self.max_plate_nesting-1, *args, **kwargs)

        # extract terms from the model and guide traces
        model_terms, guide_terms = accumulate_terms(model_tr, guide_tr)

        # contract out auxiliary variables in the model
        model_costs = integrate_model_vars(model_terms["log_measures"], model_terms["log_factors"],
                                           model_terms["measure_vars"] - guide_terms["measure_vars"],
                                           model_terms["plate_vars"] | guide_terms["plate_vars"])
        
        # compute guide costs (-log(q) terms)
        guide_costs = [-log_q for log_q in guide_terms["log_factors"]]

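        # integrate out guide variables: every variable enumerated in the guide must be
        # summed here, even though the replayed model trace also depends on it
        plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
        elbo = integrate_guide_vars(guide_terms["log_measures"], model_costs + guide_costs,
                                    guide_terms["measure_vars"] - plate_vars,
                                    plate_vars)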

        assert not elbo.inputs
        with funsor.memoize.memoize():
            # elbo is a lazy expression that we rewrite to an optimized form
            return -pyro.to_data(funsor.optimizer.apply_optimizer(elbo))
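
To make the overall structure of `differentiable_loss` concrete, here is a minimal pure-Python sketch on a tiny discrete model. It is not the funsor implementation. Plain floats and lists stand in for funsors. The model, its probabilities, and the helpers `model_costs`, `elbo`, `log_evidence`, and `exact_posterior` are all hypothetical. The steps mirror the stages above: contract the model-only variable `y` inside the log, add the guide costs `-log q(z)`, then take the expectation under the guide.

```python
import math


def logsumexp(xs):
    # Numerically stable log(sum(exp(x) for x in xs)), i.e. repeated logaddexp.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Hypothetical toy model (numbers are illustrative):
#   z ~ Categorical(P_Z)               latent, enumerated in the guide
#   y ~ Categorical(P_Y_GIVEN_Z[z])    latent, enumerated only in the model
#   x ~ Bernoulli(P_X1_GIVEN_Y[y])     observed with x = 1
P_Z = [0.6, 0.4]
P_Y_GIVEN_Z = [[0.7, 0.3], [0.2, 0.8]]
P_X1_GIVEN_Y = [0.1, 0.9]


def model_costs(z):
    # Uncontracted factor: log p(z) involves no model-only variable.
    log_pz = math.log(P_Z[z])
    # Contracted factors: sum y out inside the log (logaddexp over y, add across factors).
    log_px_given_z = logsumexp(
        [math.log(P_Y_GIVEN_Z[z][y]) + math.log(P_X1_GIVEN_Y[y]) for y in range(2)]
    )
    return [log_pz, log_px_given_z]


def elbo(q_z):
    # Integrate every cost against the guide: the expectation sits outside the log.
    total = 0.0
    for z, q in enumerate(q_z):
        costs = model_costs(z) + [-math.log(q)]  # guide cost: -log q(z)
        total += q * sum(costs)
    return total


def log_evidence():
    # log p(x = 1): here z is also summed out inside the log.
    return logsumexp([sum(model_costs(z)) for z in range(2)])


def exact_posterior():
    # p(z | x = 1), the guide that makes the bound tight.
    log_joint = [sum(model_costs(z)) for z in range(2)]
    log_norm = logsumexp(log_joint)
    return [math.exp(lj - log_norm) for lj in log_joint]


loss = -elbo([0.5, 0.5])  # like differentiable_loss, return the negative ELBO
```

With a uniform guide, `elbo` falls below `log_evidence()`. With `exact_posterior()` as the guide, the two coincide, which makes a cheap sanity check for this kind of implementation.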

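We'll go through the four helper functions one by one, starting with `get_traces`. This is essentially identical to the code in the original `TraceEnum_ELBO`. It runs the guide under `handlers.enum`, then replays the model against the guide trace so that the model sees the guide's sampled and enumerated values.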

In [None]:
def get_traces(model, guide, first_available_dim, *args, **kwargs):
    with handlers.enum(first_available_dim=first_available_dim):
        guide_tr = handlers.trace(guide).get_trace(*args, **kwargs)
        model_tr = handlers.trace(handlers.replay(model, trace=guide_tr)).get_trace(*args, **kwargs)
    return model_tr, guide_tr
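
The `first_available_dim=-self.max_plate_nesting-1` argument passed to `get_traces` reserves the rightmost batch dimensions for plates. Enumerated values are therefore broadcast along dimensions further to the left. Below is a small arithmetic sketch of that layout. It assumes the simple case where each enumerated site takes the next free dimension and no dimensions are recycled; Pyro can reuse enumeration dims, for example under `pyro.markov`. The helpers `enum_dims` and `enum_shape` are hypothetical, not Pyro API.

```python
def enum_dims(max_plate_nesting, num_enumerated_sites):
    # Hypothetical helper. Plates occupy batch dims -1, ..., -max_plate_nesting, so
    # enumeration starts at the first dim to their left and moves further left,
    # assuming no dims are recycled.
    first_available_dim = -max_plate_nesting - 1
    return [first_available_dim - i for i in range(num_enumerated_sites)]


def enum_shape(dim, cardinality):
    # Hypothetical helper. The enumerated support at `dim` has size `cardinality`
    # in that dim and singleton sizes in every dim to its right, so it broadcasts
    # against plate dims.
    return (cardinality,) + (1,) * (-dim - 1)


# With two nested plates, three enumerated sites land in dims -3, -4, -5.
dims = enum_dims(max_plate_nesting=2, num_enumerated_sites=3)
shape = enum_shape(dims[0], cardinality=4)
```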

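Let's look at the `accumulate_terms` helper next. This function extracts from the model and guide traces all of the funsors needed to compute the ELBO: log-density factors (`log_factors`), log-measures of enumerated sites (`log_measures`), and the names of the plate variables and of the variables those terms depend on.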

In [None]:
def accumulate_terms(model_tr, guide_tr):
    model_terms = {"log_factors": [], "log_measures": [], "plate_vars": frozenset(), "measure_vars": frozenset()}
    guide_terms = {"log_factors": [], "log_measures": [], "plate_vars": frozenset(), "measure_vars": frozenset()}
    for terms, tr in zip((model_terms, guide_terms), (model_tr, guide_tr)):
        for name, node in tr.nodes.items():
            if node["type"] != "sample" or type(node["fn"]).__name__ == "_Subsample":
                continue
            # replayed and observed sites contribute a log_prob factor; sites enumerated
            # only in the model contribute just a log_measure (handled below)
            if name in guide_tr.nodes or node["is_observed"]:
                terms["log_factors"].append(node["funsor"]["log_prob"])
            if node["funsor"].get("log_measure", None) is not None:
                terms["log_measures"].append(node["funsor"]["log_measure"])
                terms["measure_vars"] |= frozenset(node["funsor"]["log_measure"].inputs)
            terms["plate_vars"] |= frozenset(f.name for f in node["cond_indep_stack"] if f.vectorized)
            terms["measure_vars"] |= frozenset(node["funsor"]["log_prob"].inputs)
    return model_terms, guide_terms
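
To show which sites end up where, here is a sketch of the same bookkeeping on hypothetical, simplified site records. It uses plain dicts and records site names instead of funsors. It reuses the toy model from the summary: `z` is enumerated in the guide and replayed in the model, `y` is enumerated only in the model, and `x` is observed. The function `accumulate` and the site dicts are illustrative stand-ins, not Pyro data structures.

```python
# Hypothetical stand-ins for trace nodes: for each sample site, the variable names its
# log_prob depends on, whether it is observed, and whether it is enumerated in that trace
# (and so carries a log_measure).
MODEL_SITES = {
    "z": {"inputs": {"z"}, "is_observed": False, "enumerated": False},  # replayed from guide
    "y": {"inputs": {"z", "y"}, "is_observed": False, "enumerated": True},  # model-only enum
    "x": {"inputs": {"y"}, "is_observed": True, "enumerated": False},  # observed data
}
GUIDE_SITES = {
    "z": {"inputs": {"z"}, "is_observed": False, "enumerated": True},
}


def accumulate(sites, guide_sites):
    # Mirrors the logic of accumulate_terms, recording site names instead of funsors.
    terms = {"log_factors": [], "log_measures": [], "measure_vars": frozenset()}
    for name, site in sites.items():
        # Sites the guide also samples, and observed sites, contribute a log_prob factor.
        if name in guide_sites or site["is_observed"]:
            terms["log_factors"].append(name)
        # Enumerated sites contribute a log_measure.
        if site["enumerated"]:
            terms["log_measures"].append(name)
        terms["measure_vars"] |= frozenset(site["inputs"])
    return terms


model_terms = accumulate(MODEL_SITES, GUIDE_SITES)
guide_terms = accumulate(GUIDE_SITES, GUIDE_SITES)
# Variables enumerated only in the model, computed as in differentiable_loss.
model_only_vars = model_terms["measure_vars"] - guide_terms["measure_vars"]
```

Here `y` contributes only a measure in the model, while `z` and `x` contribute factors. Subtracting the guide's variables leaves `{"y"}` as the set to contract in the model.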

With the `model_terms` and `guide_terms` data structures in hand, we're ready to integrate out the variables enumerated only in the model. The function we'll use, `integrate_model_vars`, works entirely on `funsor.Funsor`s. It takes lists of log-measures and log-factors along with `frozenset`s of variable names, and returns a new list of `funsor.Funsor`s in which the model-only variables have been summed out.

In [None]:
@funsor.interpreter.interpretation(funsor.terms.lazy)
def integrate_model_vars(log_measures, log_factors, measure_vars, plate_vars):
    contracted_terms = [t for t in log_factors if measure_vars & frozenset(t.inputs)]
    uncontracted_terms = [t for t in log_factors if not measure_vars & frozenset(t.inputs)]
    return uncontracted_terms + funsor.sum_product.partial_sum_product(
        funsor.ops.logaddexp, funsor.ops.add,
        log_measures + contracted_terms,
        plates=plate_vars, eliminate=measure_vars - plate_vars
    )
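
To illustrate the contracted/uncontracted split and the `logaddexp`/`add` semiring, here is a sketch that eliminates a single model-only variable from small log-probability tables. The helpers `factor` and `eliminate` are hypothetical. They handle one variable and no plates, so this stands in for only a small piece of `partial_sum_product`. The model is the same toy model as before.

```python
import itertools
import math


def factor(names, fn):
    # Hypothetical toy factor: (variable names, table from assignments to log-values).
    # Stands in for a funsor.Funsor with discrete inputs; every variable takes values {0, 1}.
    return (tuple(names), {a: fn(*a) for a in itertools.product(range(2), repeat=len(names))})


def eliminate(factors, var):
    # Combine factors with add (a product in log space), then sum `var` out with logaddexp.
    names = sorted(set().union(*(f[0] for f in factors)) - {var})
    table = {}
    for a in itertools.product(range(2), repeat=len(names)):
        env = dict(zip(names, a))
        terms = []
        for v in range(2):
            env[var] = v
            terms.append(sum(t[tuple(env[n] for n in ns)] for ns, t in factors))
        m = max(terms)
        table[a] = m + math.log(sum(math.exp(x - m) for x in terms))
    return (tuple(names), table)


P_Z = [0.6, 0.4]
P_Y_GIVEN_Z = [[0.7, 0.3], [0.2, 0.8]]
P_X1_GIVEN_Y = [0.1, 0.9]

log_measures = [factor(["z", "y"], lambda z, y: math.log(P_Y_GIVEN_Z[z][y]))]
log_factors = [
    factor(["z"], lambda z: math.log(P_Z[z])),  # log p(z)
    factor(["y"], lambda y: math.log(P_X1_GIVEN_Y[y])),  # log p(x = 1 | y)
]
measure_vars = {"y"}  # enumerated only in the model

# Same split as integrate_model_vars: only factors touching a model-only variable are contracted.
contracted = [f for f in log_factors if measure_vars & set(f[0])]
uncontracted = [f for f in log_factors if not measure_vars & set(f[0])]
model_costs = uncontracted + [eliminate(log_measures + contracted, "y")]
```

The result keeps `log p(z)` untouched and replaces the `y`-dependent terms with a single factor over `z`, namely `log sum_y p(y | z) p(x = 1 | y)`.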

Note the use of `funsor.sum_product.partial_sum_product` in `integrate_model_vars`. This is an implementation of the tensor variable elimination algorithm. It matches, nearly line-by-line, the description in the paper that introduced the algorithm, "Tensor Variable Elimination for Plated Factor Graphs" (Obermeyer et al., 2019).
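
Plates are what make this elimination tractable. For a local discrete variable inside a plate, each `y_i` can be summed out separately and the results combined across the plate by adding logs. This avoids enumerating all joint assignments. Here is a pure-Python check of that identity on a hypothetical plated model with illustrative numbers. The function names are ours, and neither function is a funsor API.

```python
import itertools
import math

# Hypothetical plated model: for each i in the plate, y_i ~ P_Y and x_i ~ P_X_GIVEN_Y[y_i].
P_Y = [0.3, 0.7]
P_X_GIVEN_Y = [[0.9, 0.1], [0.2, 0.8]]
DATA = [1, 0, 1, 1]  # observed x_1, ..., x_N


def plated_log_marginal(data):
    # Sum each local y_i out inside the plate, then combine across the plate by adding
    # logs (a product over i): O(N * K) work for K = 2 values per variable.
    total = 0.0
    for x in data:
        total += math.log(sum(P_Y[y] * P_X_GIVEN_Y[y][x] for y in range(2)))
    return total


def brute_force_log_marginal(data):
    # Reference: enumerate every joint assignment (y_1, ..., y_N), i.e. K ** N terms.
    total = 0.0
    for ys in itertools.product(range(2), repeat=len(data)):
        p = 1.0
        for y, x in zip(ys, data):
            p *= P_Y[y] * P_X_GIVEN_Y[y][x]
        total += p
    return math.log(total)
```

Both functions return the same log-marginal. Only the plated version scales linearly in the plate size.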

After integrating out variables that were enumerated in the model, all that's left are variables that were enumerated in the guide. The `integrate_guide_vars` function that we'll use to compute the final ELBO also operates entirely on `funsor.Funsor`s.

In [None]:
@funsor.interpreter.interpretation(funsor.terms.lazy)
def integrate_guide_vars(log_measures, costs, measure_vars, plate_vars):
    result = funsor.to_funsor(0, output=funsor.reals())
    with funsor.memoize.memoize():
        for cost in costs:
            log_measure = funsor.sum_product.sum_product(
                funsor.ops.logaddexp, funsor.ops.add, log_measures,
                plates=plate_vars, eliminate=(plate_vars | measure_vars) - frozenset(cost.inputs)
            )
            term = funsor.Integrate(log_measure, cost, measure_vars & frozenset(cost.inputs))
            term = term.reduce(funsor.ops.add, plate_vars & frozenset(cost.inputs))
            result = result + term
    return result

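Like `integrate_model_vars`, `integrate_guide_vars` relies heavily on `funsor.sum_product`. Unlike `integrate_model_vars`, which sums model-only variables out inside the logarithm (with `logaddexp`), `integrate_guide_vars` places the integral over guide variables outside the logarithm, as an expectation. Because expectation is linear, the ELBO decomposes into a sum of expectations of the individual cost terms with respect to the guide distribution. Each of these expectations needs only the guide's marginal over the variables that cost term depends on, which is what the `sum_product` call inside the loop computes.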
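
The decomposition can be checked on a small, non-factorized guide. Below is a hypothetical pure-Python sketch that mirrors the loop in `integrate_guide_vars`, with plates omitted and probabilities instead of log-probabilities. For each cost, it marginalizes the guide down to that cost's inputs and takes the expectation. The result is compared with a single expectation of the total cost under the full joint guide. The guide, the cost terms, and the helper names are illustrative.

```python
import itertools
import math

# Hypothetical guide q(z1, z2) = q(z1) q(z2 | z1) over binary variables.
Q_Z1 = [0.25, 0.75]
Q_Z2_GIVEN_Z1 = [[0.6, 0.4], [0.1, 0.9]]


def q_joint(z1, z2):
    return Q_Z1[z1] * Q_Z2_GIVEN_Z1[z1][z2]


# Cost terms, each depending on a subset of the guide variables. The first two are the
# guide's -log q terms; the third is an illustrative stand-in for a model cost.
COSTS = [
    (("z1",), lambda z1: -math.log(Q_Z1[z1])),
    (("z1", "z2"), lambda z1, z2: -math.log(Q_Z2_GIVEN_Z1[z1][z2])),
    (("z2",), lambda z2: 2.0 * z2 - 1.0),
]


def marginal(inputs):
    # Sum the guide measure over every variable not among `inputs`.
    table = {}
    for z1, z2 in itertools.product(range(2), repeat=2):
        env = {"z1": z1, "z2": z2}
        key = tuple(env[n] for n in inputs)
        table[key] = table.get(key, 0.0) + q_joint(z1, z2)
    return table


def expected_cost_by_terms():
    # One expectation per cost, each against the guide marginal over that cost's inputs.
    return sum(
        sum(p * fn(*key) for key, p in marginal(inputs).items()) for inputs, fn in COSTS
    )


def expected_cost_full_joint():
    # Reference: one expectation of the summed costs under the full joint guide.
    total = 0.0
    for z1, z2 in itertools.product(range(2), repeat=2):
        env = {"z1": z1, "z2": z2}
        total += q_joint(z1, z2) * sum(fn(*(env[n] for n in inputs)) for inputs, fn in COSTS)
    return total
```

The two totals agree, so each cost can be integrated against a smaller marginal. With `memoize`, shared pieces of those marginals can then be computed once.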