```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import genjax

sns.set_theme(style="white")

# Pretty printing.
console = genjax.pretty(width=80)

# Reproducibility.
key = jax.random.PRNGKey(314159)
```

Gen helps probabilistic programmers design and implement models and inference algorithms by automating the (often) complicated inference math. The generative function interface is the key abstraction layer that provides this automation. Designers of generative function languages can implement the interface for new kinds of generative functions, providing domain-specific patterns and optimizations that users benefit from automatically.

One key class of generative function languages is _combinators_: higher-order functions which accept generative functions as input and produce a new generative function as output.

Combinators functionally transform the generative structure that we pass into them, expressing useful patterns such as chain-like computations, IID sampling, or generative computations which form grammar-like structures.
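As a loose analogy in plain JAX (this is not GenJAX's API), a "combinator" is just a function that takes a sampler and returns a new sampler. The hypothetical helper `iid` below wraps a single-draw sampler with `jax.vmap` to get an IID sampler. Unlike a real GenJAX combinator, it only draws samples and implements none of the other interface methods (scores, updates, importance weights).

```python
import jax


def iid(sampler, n):
    """Hypothetical toy combinator: turn a one-draw sampler into an n-draw IID sampler.

    `sampler` maps a PRNG key to one sample; the returned function has the
    same calling convention but produces n independent draws.
    """
    def repeated(key):
        keys = jax.random.split(key, n)  # one independent key per draw
        return jax.vmap(sampler)(keys)   # vectorize the single-draw sampler over keys
    return repeated


def std_normal(key):
    return jax.random.normal(key)


draws = iid(std_normal, 5)(jax.random.PRNGKey(0))
print(draws.shape)  # (5,)
```

The input and the output have the same "shape" as objects (key in, samples out), which is what lets combinators be composed with each other.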

Combinators also expose optimization opportunities - by registering the patterns as generative functions, implementors (e.g. the library authors) can specialize the implementation of the generative function interface methods. Users of combinators can then take advantage of this interface specialization to express asymptotically optimal updates (useful in e.g. MCMC kernels), or optimized importance weight calculations.

In this notebook, we'll be discussing `Unfold` - a combinator for expressing generative computations which are reminiscent of state-space (or Markov) models. To keep things simple, we'll explore a hidden Markov model example - but combinator usage generalizes to models with much richer structure.

## Introducing `Unfold`

Let's discuss `Unfold`.^[A quick reminder: when in doubt, you can use the console from `console = genjax.pretty()` to inspect the classes which we discuss in the notebooks.]

How do we make an instance of `Unfold`? Given an existing generative function which is a _kernel_, meaning its return value has the same type as its (state) argument so that the output of one step can be fed back in as the input to the next, we can create a valid `Unfold` instance.^[This is not strictly true. `Unfold` also allows you to pass in a set of _static arguments_ which are provided to the kernel _after the state argument_, unchanged, at each time step. We show this at the bottom of the notebook.]

Here's an example kernel:

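```python
@genjax.gen
def kernel(prev_latent):
    # Latent state: z ~ Normal(prev_latent, 1).
    new_latent = genjax.Normal(prev_latent, 1.0) @ "z"
    # Observation: x ~ Normal(z, 1); recorded in the trace, but not returned.
    new_obs = genjax.Normal(new_latent, 1.0) @ "x"
    # The return value has the same type as the input state.
    return new_latent
```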
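To see why a kernel's return type must match its input type, here is a plain-JAX sketch (not GenJAX code) of the same Normal-Normal step as `kernel` above, iterated with `jax.lax.scan`. The names `kernel_step` and `run_chain` are hypothetical. `scan` feeds each step's returned state back in as the next input, so it requires the carried state to keep the same shape and dtype across steps.

```python
import jax
import jax.numpy as jnp


def kernel_step(key, prev_latent):
    """Hypothetical plain-JAX analogue of `kernel`: state in, state out."""
    key_z, key_x = jax.random.split(key)
    new_latent = prev_latent + jax.random.normal(key_z)  # z ~ Normal(prev_latent, 1)
    new_obs = new_latent + jax.random.normal(key_x)      # x ~ Normal(z, 1)
    return new_latent, new_obs


def run_chain(key, init_latent, num_steps):
    def body(carry, step_key):
        new_latent, new_obs = kernel_step(step_key, carry)
        # The new carry must match `carry` in shape and dtype, or lax.scan
        # raises an error: the same constraint a kernel has to satisfy.
        return new_latent, (new_latent, new_obs)

    keys = jax.random.split(key, num_steps)
    _, (zs, xs) = jax.lax.scan(body, init_latent, keys)
    return zs, xs


zs, xs = run_chain(jax.random.PRNGKey(0), jnp.float32(0.3), 10)
print(zs.shape, xs.shape)  # (10,) (10,)
```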

```python
# Simulate a single kernel step from the state 0.3.
key, tr = jax.jit(kernel.simulate)(key, (0.3,))
tr
```

To create an `Unfold` instance, we provide two things:

* The kernel generative function.
* A static maximum chain length (`max_length`). Dynamically, `Unfold` may not unroll all the way up to this maximum, but for JAX/XLA compilation we need to provide this maximum value as an invariant upper bound for any invocation of `Unfold`.
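The static maximum matters because JAX needs array shapes at trace time. A minimal sketch in plain JAX, with a hypothetical helper `masked_unfold` (not necessarily how GenJAX implements `Unfold`), and keeping only the latent part of the kernel: always scan `max_length` steps, and use a dynamic cutoff `index` purely as a mask, so that step `t` updates the state when `t <= index` and otherwise carries it forward unchanged.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="max_length")
def masked_unfold(key, index, init_latent, max_length):
    """Hypothetical sketch of Unfold-like semantics in plain JAX.

    Always runs `max_length` steps (a static bound, so shapes are fixed at
    trace time); `index` may be a traced value and only acts as a mask.
    """
    def body(carry, inputs):
        t, step_key = inputs
        z = carry + jax.random.normal(step_key)  # latent step: z ~ Normal(carry, 1)
        active = t <= index                      # dynamic comparison, fine under jit
        # A draw happens at every step; inactive steps discard it and carry
        # the previous state forward unchanged.
        new_carry = jnp.where(active, z, carry)
        return new_carry, (new_carry, active)

    keys = jax.random.split(key, max_length)
    steps = jnp.arange(max_length)
    _, (retvals, active) = jax.lax.scan(body, init_latent, (steps, keys))
    return retvals, active


retvals, active = masked_unfold(jax.random.PRNGKey(0), 5, jnp.float32(0.3), max_length=10)
print(int(active.sum()))  # 6
```

Because `max_length` is static, changing it triggers a recompilation, while different `index` values reuse the same compiled function.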

```python
chain = genjax.Unfold(kernel, max_length=10)
chain
```

To invoke an interface method, the arguments which `Unfold` expects are a tuple whose first element is the maximum **index** in the resulting chain and whose second element is the initial state.

::: {.callout-important}

## Usage of index argument vs. a length argument

Note how we've bolded **index** above: think of the index value as an upper bound on the active indices of the resulting chain. An _active index_ is one at which the value was evolved from the previous value using `kernel`. Passing in `index = 5` means that the values after `return[5]` are not evolved; they are just filled with the `return[5]` value.

Indexing follows Python convention (zero-based), so passing in `0` as the index means that the kernel is applied **exactly once** to the initial state; evolution then halts, and every remaining slot is filled with that single result.

:::
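The index convention can be pinned down with a tiny reference loop. `unfold_fill` is a hypothetical pure-Python helper (no JAX, no randomness) that uses the deterministic kernel `s + 1`, so the filled values are easy to read off: steps `t <= index` apply the kernel, and every later slot repeats the last evolved value.

```python
def unfold_fill(step, index, init, max_length):
    """Hypothetical reference loop for the index convention (not GenJAX code)."""
    out, state = [], init
    for t in range(max_length):
        if t <= index:       # active slot: evolve the state with the kernel
            state = step(state)
        out.append(state)    # inactive slot: repeat the last evolved value
    return out


print(unfold_fill(lambda s: s + 1, 5, 0, 10))  # [1, 2, 3, 4, 5, 6, 6, 6, 6, 6]
print(unfold_fill(lambda s: s + 1, 0, 0, 10))  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
```

With `index = 5` the kernel runs six times (slots 0 through 5); with `index = 0` it runs once.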

```python
# Arguments: (index, initial state).
key, tr = jax.jit(chain.simulate)(key, (5, 0.3))
tr
```

```python
tr.indices
```

```python
tr.get_retval()
```

Note how `tr.indices` keeps track of where the chain stopped evolving, according to the index argument to `Unfold`. In `tr.get_retval()`, the last evolved value sits at `index = 5`; after that, evolution stops and the remaining entries repeat it.

## Combinator choice maps

Typically, each combinator has its own choice map type. The choice map represents the structure of the random choices made by the generative function that the combinator produces, and it also exposes optimization opportunities which a user can take advantage of.

Let's study the choice map for `UnfoldTrace`.

```python
chm = tr.get_choices()
chm
```

Again, let's look at the indices.

```python
chm.indices
```

No surprises - the choice map also keeps track of which indices are active, and which indices are inactive. 

Inactive indices **do not** participate in inference metadata computations such as scores. For example, consider the score of the trace:

```python
tr.get_score()
```

The score equals the sum of the per-step `"z"` and `"x"` scores over the active indices 0 through 5, i.e. the Python slice `[0:6]`:

```python
# Per-step scores of "z" and "x", summed over the active indices 0..5 (slice [0:6]).
np.sum(tr.get_subtree("z").get_score()[0:6] + tr.get_subtree("x").get_score()[0:6])
```
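The masking rule can be reproduced by hand. A minimal sketch, assuming the model defined by `kernel` (each `z` is Normal around the previous latent, with the initial state used at `t = 0`, and each `x` is Normal around its `z`, both with unit scale) and per-slot values stacked into arrays of length `max_length`. `masked_chain_score` is a hypothetical helper, not part of GenJAX, and the inputs are arbitrary stand-in values rather than a real trace. Zeroing the log densities of inactive slots with the mask `t <= index` gives the same total as summing the slice `[0:index + 1]`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def masked_chain_score(zs, xs, init_latent, index):
    """Hypothetical helper: log joint of the Normal-Normal chain over active slots.

    zs, xs hold the latent and observed values at every slot (shape (max_length,)),
    including inactive slots; only slots t <= index contribute to the total.
    """
    # Previous latent for each slot: the initial state at t = 0, then zs[t - 1].
    prev = jnp.concatenate([jnp.atleast_1d(init_latent), zs[:-1]])
    per_step = norm.logpdf(zs, loc=prev, scale=1.0) + norm.logpdf(xs, loc=zs, scale=1.0)
    active = jnp.arange(zs.shape[0]) <= index
    return jnp.sum(jnp.where(active, per_step, 0.0)), per_step


k_z, k_x = jax.random.split(jax.random.PRNGKey(0))
zs = jax.random.normal(k_z, (10,))  # stand-in values for "z" at all 10 slots
xs = jax.random.normal(k_x, (10,))  # stand-in values for "x" at all 10 slots
total, per_step = masked_chain_score(zs, xs, jnp.float32(0.3), index=5)
# Masking with t <= 5 should match summing the slice [0:6].
print(jnp.allclose(total, per_step[0:6].sum()))
```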

The reason we have an `index` argument is that it lets us dynamically choose how much of the chain contributes to the generative computation. The `index` argument can even come from another generative function: it need not be a static value at JAX trace time.

With this in mind, it's best to think of `Unfold` as representing a space of processes which unroll up to some maximum static length - but the active generative process can halt before that maximum length.
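To illustrate that the cutoff can be a random, traced value, here is a plain-JAX sketch (hypothetical names, not the GenJAX API, latent random walk only) that draws the index uniformly from `{0, ..., MAX_LENGTH - 1}` and masks the chain with it. Vectorizing over keys gives a batch of chains with different active lengths that all share the static shape `(MAX_LENGTH,)`, which is one way to picture the "space of processes" described above.

```python
import jax
import jax.numpy as jnp

MAX_LENGTH = 10  # static upper bound shared by every chain


def random_length_chain(key, init_latent):
    """Hypothetical sketch: draw the cutoff index at random, then mask the chain."""
    key_index, key_chain = jax.random.split(key)
    # The index is itself a random (traced) value, uniform on {0, ..., MAX_LENGTH - 1}.
    index = jax.random.randint(key_index, (), 0, MAX_LENGTH)

    def body(carry, inputs):
        t, step_key = inputs
        z = carry + jax.random.normal(step_key)  # latent random-walk step
        new_carry = jnp.where(t <= index, z, carry)
        return new_carry, new_carry

    keys = jax.random.split(key_chain, MAX_LENGTH)
    _, zs = jax.lax.scan(body, init_latent, (jnp.arange(MAX_LENGTH), keys))
    return index, zs


# A batch of chains, each with its own randomly drawn index, compiled once.
batch_keys = jax.random.split(jax.random.PRNGKey(0), 4)
indices, paths = jax.jit(jax.vmap(random_length_chain, in_axes=(0, None)))(
    batch_keys, jnp.float32(0.3)
)
print(indices.shape, paths.shape)  # (4,) (4, 10)
```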