In [1]:
%config InlineBackend.figure_format = 'svg'

In [2]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import dataclasses
import genjax
from genjax import Diff
from typing import List, Tuple, Any

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=80)

# Reproducibility.
key = jax.random.PRNGKey(314159)

The `update` interface method for generative functions defines an update operation on traces produced by generative functions. 

`update` allows the user to provide new constraints, as well as new arguments. It returns an updated trace which is consistent with the new constraints, along with an incremental importance weight which measures how the score of the trace changes under the model. `update` is used to implement many iterative inference algorithms, including several families of MCMC.
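
To make the shape of this operation concrete, here is a minimal sketch of an `update`-like function for a toy one-choice model, $z \sim \mathcal{N}(\mu, 1)$, written in plain JAX. The name `toy_update` and its arguments are assumptions made for illustration; this is not GenJAX's implementation or signature.

```python
from jax.scipy.stats import norm


def toy_update(old_z, old_mu, new_mu, constraint=None):
    """Update a one-choice 'trace' z ~ Normal(mu, 1) (hypothetical helper).

    Returns (new_z, weight, discard): the new value of the choice, the new
    log density minus the old log density, and the overwritten value
    (None if z was not constrained).
    """
    new_z = old_z if constraint is None else constraint
    discard = None if constraint is None else old_z
    weight = norm.logpdf(new_z, new_mu, 1.0) - norm.logpdf(old_z, old_mu, 1.0)
    return new_z, weight, discard


# Constrain z to a new value while leaving the argument unchanged.
new_z, weight, discard = toy_update(old_z=0.5, old_mu=0.0, new_mu=0.0, constraint=1.5)

# With no new constraint and no argument change, the weight is zero.
_, unchanged_weight, _ = toy_update(old_z=0.5, old_mu=0.0, new_mu=0.0)
```

In the first call the weight is the log density at 1.5 minus the log density at 0.5, and the old value 0.5 is returned as the discard.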

The specification of `update` only requires that a modeling language support the above behavior. Nonetheless, modeling languages can implement `update` with custom optimizations to reduce the cost of repeatedly calling `update` (e.g. inside an iterative MCMC inference procedure).

In this notebook, we'll focus on these optimization opportunities within the implementation of `update` for the `BuiltinGenerativeFunction` language. We'll describe a system which supports incremental computation by propagating change information (called `Diff` in the codebase).^[Think of a value of `Diff` type as representing a new value $v^\prime$ using a decomposition $v^\prime = v \oplus dv$ where $dv$ is the change to the value and $v$ is the original value.]

While we'll be focused on the distribution and builtin languages, this system is also applicable to the combinator implementations of `update`. In another notebook, we'll see how the incremental computing system can be used to efficiently compute `update` for `UnfoldCombinator`.

## What is `update` used for?

Before we discuss how `update` can be optimized by a generative function implementor, it's worth constructing a simple example which shows how `update` is used and why optimizing it is worthwhile.

One common usage of `update` is in MCMC algorithm kernels. MCMC is often repeatedly applied to generate a chain of samples: any optimization opportunities that we identify and take advantage of will provide runtime gains which are multiplied over the length of the chain.

Let's examine this scenario using a pedagogical example - remember that the potential optimization pattern (based upon random variable dependency information) we'll describe extends to all generative functions.

### Pedagogical example

Consider the following generative function:

In [3]:
@genjax.gen
def model(x):
    a = genjax.trace("a", genjax.Normal)(x, 1.0)
    b = genjax.trace("b", genjax.Normal)(x, 1.0)
    c = genjax.trace("c", genjax.Normal)(a + b, 1.0)
    return c

The variable dependency graph is shown below.

```{mermaid}
flowchart LR
  x[Argument] --> a[a]
  x --> b[b]
  a --> c[c]
  b --> c
  c --> r[Return]
```

Now, when we simulate a trace from this model - we get choices for `"a"`, `"b"`, and `"c"`.

In [4]:
key, tr = model.simulate(key, (2.0,))
tr

Iterative inference techniques like Metropolis-Hastings (and other MCMC methods) start with an initial trace, propose an update to the trace using a proposal, and then compute a criterion for accepting or rejecting the update.

In Metropolis-Hastings, the criterion involves an _accept-reject ratio_ computation - which requires computing the probability of transitioning from the current trace to the new trace, as well as the probability of transitioning from the new trace back to the current trace, under a kernel defined by the algorithm.
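
Written in log space, and assuming the standard Metropolis-Hastings construction with a model density $p$ and a proposal density $q$ (symbols introduced here for illustration), the ratio for a move from the current trace $t$ to a new trace $t'$ is:

$$
\log \alpha = \log \frac{p(t')}{p(t)} + \log q(t \mid t') - \log q(t' \mid t)
$$

The move is accepted when $\log u < \log \alpha$ for $u \sim \mathrm{Uniform}(0, 1)$. The first term is the change in model score between the two traces, which is the quantity `update` reports as its incremental weight.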

The library implementation of Metropolis-Hastings is shown below - `MetropolisHastings.apply` shows the main content of the algorithm (it's safe to ignore other methods for now).

In [5]:
@dataclasses.dataclass
class MetropolisHastings(genjax.MCMCKernel):
    selection: genjax.Selection
    proposal: genjax.GenerativeFunction

    def flatten(self):
        return (), (self.selection, self.proposal)

    def apply(self, key, trace: genjax.Trace, proposal_args: Tuple):
        model = trace.get_gen_fn()
        model_args = trace.get_args()
        proposal_args_fwd = (trace.get_choices(), *proposal_args)
        key, proposal_tr = self.proposal.simulate(key, proposal_args_fwd)
        fwd_weight = proposal_tr.get_score()
        diffs = jtu.tree_map(Diff.no_change, model_args)
        key, (_, weight, new, discard) = model.update(
            key, trace, proposal_tr.get_choices(), diffs
        )
        proposal_args_bwd = (new, *proposal_args)
        key, (bwd_weight, _) = self.proposal.importance(key, discard, proposal_args_bwd)
        alpha = weight - fwd_weight + bwd_weight
        key, sub_key = jax.random.split(key)
        check = jnp.log(jax.random.uniform(sub_key)) < alpha
        return (
            key,
            jax.lax.cond(
                check,
                lambda *args: (new, True),
                lambda *args: (trace, False),
            ),
        )

    def reversal(self):
        return self

This computation involves `update` - which _incrementally_ updates a trace to be consistent with new arguments and constraints, and computes an importance weight (the difference between the trace's new score and the old score).

::: {.callout-important}

In the invocation of `update`, there's an interesting not-yet-explained argument: `diffs` - a tuple of `Diff` values, which represent _changes_ to the arguments of the call that produced the trace we are updating. We'll come back to these values in a moment.

:::

If we naively evaluate the required log probability by re-evaluating the entire model - we're performing extra computation. We can see this by considering a specific target address - let's consider `"a"`. If the update changes `"a"`, what other generative function calls do we need to visit to compute the correct update - both to the trace, and the importance weight? 

The graph below shows the answer.

```{mermaid}
flowchart LR
  x[Argument] --> a[a]
  x --> b[b]
  a --> c[c]
  b --> c
  c --> r[Return]
  style a fill:#f9f,stroke:#333,stroke-width:4px
  style c fill:#f9f,stroke:#333,stroke-width:4px
```

An update to `"a"` requires that we re-evaluate the log probability at `"c"` because the return value of the generative function call at `"a"` flows into the generative function call at `"c"` - but we do not need to re-visit `"b"` because none of the values which flow into `"b"` have changed. 
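
The reasoning above can be sketched as a small propagation pass over the dependency graph. This is a hand-written illustration in plain Python, not GenJAX's mechanism: the `PARENTS` table and `sites_to_revisit` are hypothetical names, and the sketch assumes that a choice without a new constraint keeps its previous value.

```python
# Parents of each traced choice in the example model (from the diagram).
PARENTS = {"a": ["x"], "b": ["x"], "c": ["a", "b"]}
ORDER = ["a", "b", "c"]  # a topological order of the choices


def sites_to_revisit(changed_args, constrained):
    """Return the choices whose log probability must be re-evaluated."""
    value_changed = set(changed_args)
    revisit = []
    for site in ORDER:
        inputs_changed = any(p in value_changed for p in PARENTS[site])
        if inputs_changed or site in constrained:
            revisit.append(site)
        # Only a new constraint changes a choice's value; an unconstrained
        # choice keeps its old value even if its inputs changed.
        if site in constrained:
            value_changed.add(site)
    return revisit


print(sites_to_revisit(changed_args=set(), constrained={"a"}))  # ['a', 'c']
```

Under the same assumption, changing only the argument `x` revisits `"a"` and `"b"` but not `"c"`, because the values flowing into `"c"` stay the same.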

When computing the weight difference, unchanged sites thus contribute nothing.^[The important idea is that tracking what values have changed allows us to identify what parts of the computation graph are required - and what parts do not need to be re-visited or re-computed.]
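
As a quick numerical check of this claim, the sketch below scores the example model by hand with `jax.scipy.stats.norm.logpdf` and compares the weight obtained by re-scoring every site against the weight obtained by re-scoring only `"a"` and `"c"`. The helper `log_joint` and the specific numbers are made up for illustration.

```python
from jax.scipy.stats import norm


def log_joint(x, a, b, c):
    """Log density of all three choices in the example model."""
    return (
        norm.logpdf(a, x, 1.0)
        + norm.logpdf(b, x, 1.0)
        + norm.logpdf(c, a + b, 1.0)
    )


x, a, b, c = 2.0, 1.7, 2.4, 3.9  # arbitrary old trace values
a_new = 2.2  # proposed new value for "a"

# Naive weight: re-score every site before and after the change.
full_weight = log_joint(x, a_new, b, c) - log_joint(x, a, b, c)

# Incremental weight: only "a" and "c" are re-scored; the "b" term cancels.
incremental_weight = (
    norm.logpdf(a_new, x, 1.0) - norm.logpdf(a, x, 1.0)
) + (
    norm.logpdf(c, a_new + b, 1.0) - norm.logpdf(c, a + b, 1.0)
)
```

The two weights agree up to floating-point error, while the incremental version evaluates one fewer density before and after the change. In larger models the saving grows with the number of untouched sites.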

## Change information

The specification of `update` doesn't require that an implementation track or use the change information - but generative function implementations can choose to optimize their `update` implementation. 

With that in mind, several of the languages which GenJAX exposes can be instructed to perform optimized `update` computations using `Diff` values.

A `Diff` value consists of a base value `v` and a value of `Change` type, which represents the change to the base value. The new argument value for `update` is given by $v \oplus dv$ where `dv :: Change`.

The $\oplus$ operation must be appropriately defined for the change type lattice - we implement this operation for common change types in GenJAX, but users can define their own change types for `Pytree` data classes.
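
To illustrate what defining $\oplus$ for a change type can look like, here is a small sketch in plain Python. The classes `ToyNoChange`, `ToyAdd`, and `ToyUnknown` and the function `oplus` are hypothetical names invented for this example; they are not GenJAX's change types and may not match how GenJAX represents changes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyNoChange:
    """The value is known to be identical to the old one."""


@dataclass(frozen=True)
class ToyAdd:
    """The value changed by a known additive amount."""
    delta: float


@dataclass(frozen=True)
class ToyUnknown:
    """The value may have changed arbitrarily; carries the new value."""
    new_value: float


def oplus(v, dv):
    """Apply a change dv to a base value v, returning the new value."""
    if isinstance(dv, ToyNoChange):
        return v
    if isinstance(dv, ToyAdd):
        return v + dv.delta
    if isinstance(dv, ToyUnknown):
        return dv.new_value
    raise TypeError(f"unsupported change type: {type(dv).__name__}")


unchanged = oplus(5.0, ToyNoChange())
shifted = oplus(5.0, ToyAdd(1.5))
replaced = oplus(5.0, ToyUnknown(2.0))
```

The useful property for incremental computation is that a consumer can inspect the change alone: when it sees the no-change case, it knows the value is the same without comparing values.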

In [6]:
genjax.Diff.new(5.0, genjax.NoChange)

### Diffs for distributions

Let's explore the basics with distributions.

In [7]:
key, dist_tr = genjax.Normal.simulate(key, (0.0, 1.0))
dist_tr

In [8]:
# dist_tr.update(key, ...) is equivalent to genjax.Normal.update(key, dist_tr, ...)
key, (ret_diff, w, tr, d) = dist_tr.update(
    key,
    genjax.EmptyChoiceMap(),
    (
        genjax.Diff.new(1.0, genjax.UnknownChange),
        genjax.Diff.new(1.0, genjax.NoChange),
    ),
)

The return value does not change, because no new constraint was provided for the sampled value.

In [9]:
(dist_tr.get_retval(), ret_diff.val)

The weight is non-zero because the first argument is marked as changed (`UnknownChange`), implying that we must re-evaluate the log probability.

In [10]:
w

What does the code look like when there is no new constraint and neither argument changes?

In [11]:
# dist_tr.update(key, ...) is equivalent to genjax.Normal.update(key, dist_tr, ...)
jaxpr = jax.make_jaxpr(dist_tr.update)(
    key,
    genjax.EmptyChoiceMap(),
    (
        genjax.Diff.new(0.0, genjax.NoChange),
        genjax.Diff.new(1.0, genjax.NoChange),
    ),
)
jaxpr

As expected, no computation is required - so the flattened arguments are just forwarded to the return.
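
One way to see why a Jaxpr can come out empty: if the change tag is ordinary Python data, a branch on it is resolved while JAX traces the function, so the skipped computation never enters the Jaxpr. The sketch below illustrates the idea in plain JAX. The function `make_toy_update` and its boolean flag are hypothetical stand-ins for change tags, and this is an assumption about the general technique rather than a description of GenJAX internals.

```python
import jax
from jax.scipy.stats import norm


def make_toy_update(mean_changed: bool):
    """Build an update for z ~ Normal(mean, 1) given a static change flag."""

    def toy_update(value, old_score, new_mean):
        if not mean_changed:
            # Resolved at trace time: nothing is staged into the Jaxpr.
            return 0.0, old_score
        new_score = norm.logpdf(value, new_mean, 1.0)
        return new_score - old_score, new_score

    return toy_update


static_jaxpr = jax.make_jaxpr(make_toy_update(False))(0.3, -1.0, 0.0)
dynamic_jaxpr = jax.make_jaxpr(make_toy_update(True))(0.3, -1.0, 0.0)

print(len(static_jaxpr.jaxpr.eqns), len(dynamic_jaxpr.jaxpr.eqns))
```

The no-change version contains no equations and simply forwards its inputs, while the changed version stages the log density evaluation.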

## `cache`: change aware memoization

The `BuiltinGenerativeFunction` language exposes a primitive called `cache` that interacts with the change tracking system to support memoization of deterministic computations (even deterministic computations which depend on random choices).