# MCMC using backward filtering, forward guiding

Parameter inference on a tree with Markov transitions along the edges and observations at the leaf nodes, following the finite state space example in https://arxiv.org/abs/2203.04155. The conditioning and the upwards/downwards message passing and fusing operations follow the backward filtering, forward guiding approach of Frank van der Meulen, Moritz Schauer et al., see https://arxiv.org/abs/2010.03509 and https://arxiv.org/abs/2203.04155. The latter reference provides an accessible introduction to the scheme and to the notation used in this example.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from jax.random import PRNGKey, split
import hyperiax
import jax
from jax import numpy as jnp
from hyperiax.execution import LevelwiseTreeExecutor
from hyperiax.models import UpLambda
from hyperiax.models.functional import product_fuse_children
from hyperiax.mcmc import ParameterStore, UniformParameter
from hyperiax.mcmc.metropolis_hastings import metropolis_hastings
from hyperiax.mcmc.plotting import trace_plots

import matplotlib.pyplot as plt
from tqdm import tqdm


In [3]:
key = PRNGKey(42)

# Finite state space tree

First, we initialize the tree. Each node carries a vector in $\mathbb{R}^3$ with one entry per state, and we set the root value to the zero vector $(0,0,0)$.

In [4]:
# create tree, then set node types, values and transitions

from hyperiax.tree import HypTree, TreeNode
from hyperiax.tree.childrenlist import ChildList

# example tree, see Figure 1 / Figure 4 tree in https://arxiv.org/abs/2203.04155
root = TreeNode(); # x_{-1}
x0 = TreeNode(); x0.parent = root; root.children = ChildList([x0])
x1 = TreeNode(); x1.parent = x0;
x3 = TreeNode(); x3.parent = x0;
x0.children = ChildList([x1,x3])
x2 = TreeNode(); x2.parent = x1; x1.children = ChildList([x2])
v3 = TreeNode(); v3.parent = x2; x2.children = ChildList([v3])
x4 = TreeNode(); x4.parent = x3;
v2 = TreeNode(); v2.parent = x3; x3.children = ChildList([x4,v2])
v1 = TreeNode(); v1.parent = x4; x4.children = ChildList([v1])
v1.children = v2.children = v3.children = ChildList()

tree = HypTree(root)
print('Tree:',tree)
#tree.plot_tree()

# set types to select the right transitions
# types
troot = 0; tinner_node = 1; tleaf_node = 2;
tree.root['type'] = troot
x1['type'] = x2['type'] = x3['type'] = x4['type'] = tinner_node
v1['type'] = v2['type'] = v3['type'] = tleaf_node

# number of states
R = 3

# root value
tree.root['value'] = jnp.zeros(R)

# observations. The extra dimensions compared to https://arxiv.org/abs/2203.04155 are there because Hyperiax requires the same shape for all nodes.
v1['value'] = jnp.eye(R)[0]
v2['value'] = jnp.eye(R)[1]
v3['value'] = jnp.eye(R)[2]


# transition matrices
x0['type'] = troot

# root, initial state prior
pi1 = pi2 = pi3 = 1/3; 
km10 = lambda params: jnp.diag(jnp.array([pi1,pi2,pi3]))  # jnp.diag needs an array, not a Python list
# inner node
kst = lambda params: jnp.array([[1.-params['theta'],params['theta'],0.],
                                [.25,.5,.25],
                                [.4,.3,.3]])
# leaves. The extra dimension compared to https://arxiv.org/abs/2203.04155 is there because Hyperiax requires the same shape for all nodes.
lambdi = lambda params: jnp.array([[1.,1.,0.],
                                    [1.,1.,0.],
                                    [0.,0.,1.]])

# using jax.lax.cond instead of python ifs
def transition(value,type,params): 
    return jax.lax.cond(type == tinner_node,
        lambda: jnp.dot(kst(params),value),
        lambda: jax.lax.cond(type == tleaf_node,
            lambda: jnp.dot(lambdi(params),value),
            lambda: jnp.array([pi1,pi2,pi3])
        )
    )


Tree: HypTree with 5 levels and 9 nodes


Parameter for the transition kernel.

In [5]:
# parameters, theta with uniform prior
params = ParameterStore({
    'theta': UniformParameter(.5), # theta parameter for kst
    })
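
As a quick sanity check on how `theta` enters the model, the sketch below rebuilds the inner-node kernel as a plain function of `theta` and verifies that every row sums to one. The name `inner_kernel` is hypothetical and not part of Hyperiax, and the check assumes `theta` is restricted to $[0,1]$.

```python
import jax.numpy as jnp

# Same matrix as `kst` above, written as a function of theta directly.
# `inner_kernel` is a hypothetical helper name used only in this sketch.
def inner_kernel(theta):
    return jnp.array([[1. - theta, theta, 0.],
                      [.25, .5, .25],
                      [.4, .3, .3]])

# For theta in [0, 1] each row is a probability distribution over the child state.
for theta in (0.1, 0.5, 0.9):
    K = inner_kernel(theta)
    assert jnp.allclose(K.sum(axis=1), 1.0)

row_sums = inner_kernel(0.5).sum(axis=1)
```

Only the first row depends on `theta`, so the parameter controls how likely a transition from the first state to the second state is.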

We now define the backward filter through the `up` function. In this case, the up operation multiplies each node's message by the transition matrix selected by the node's type.

In [6]:
# backwards filter. The operation is vmap'ed over the batch dimension (leading dimension)
#@jax.jit
def up(value,type,params,**args):
    return jax.vmap(lambda value,type: {'value': transition(value,type,params)})(value,type)
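
The effect of the up step on a single edge can be illustrated outside of Hyperiax. The following is a minimal sketch in plain JAX, with the inner-node kernel fixed at `theta = 0.5`; the names `K`, `h_children` and `messages` are only used for this illustration.

```python
import jax
import jax.numpy as jnp

# inner-node kernel at theta = 0.5
K = jnp.array([[.5, .5, 0.],
               [.25, .5, .25],
               [.4, .3, .3]])

# batch of two child messages; the leading axis is the batch dimension
h_children = jnp.array([[0., 0., 1.],
                        [1., 1., 0.]])

# message sent to the parent: (K h)(x) = sum_y K[x, y] h(y),
# vmap'ed over the batch dimension as in `up`
messages = jax.vmap(lambda h: K @ h)(h_children)
```

Each row of `messages` is the function on parent states obtained by summing the child message against the transition probabilities.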


We create the model and executor for the backwards filter (up).

In [7]:
# create model and executor
upmodel = UpLambda(up_fn=up,fuse_fn=product_fuse_children(axis=0))
upexec = LevelwiseTreeExecutor(upmodel)
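
At a node with several children, the incoming messages have to be fused. The sketch below assumes that `product_fuse_children(axis=0)` multiplies the children's messages elementwise along the children axis, as its name suggests; it does not call Hyperiax and only illustrates the arithmetic for node `x3` at `theta = 0.5`.

```python
import jax.numpy as jnp

# messages arriving at x3 from its two children (theta = 0.5)
msg_from_x4 = jnp.array([1., .75, .7])   # inner kernel applied to [1, 1, 0]
msg_from_v2 = jnp.array([1., 1., 0.])    # leaf kernel applied to [0, 1, 0]

# fuse: elementwise product over the children axis (assumed behaviour)
fused = jnp.prod(jnp.stack([msg_from_x4, msg_from_v2]), axis=0)
```

States that are incompatible with any one child's observations get weight zero in the fused message.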

We are now ready to execute the upwards pass.

In [8]:
# execute backwards filter
utree = upexec.up(tree,params.values())

# print results
for node in utree.iter_bfs():
    print(node.data)


{'type': 0, 'value': Array([0.33333334, 0.33333334, 0.33333334], dtype=float32)}
{'type': 0, 'value': Array([0.109375  , 0.125     , 0.10312501], dtype=float32)}
{'type': 1, 'value': Array([0.  , 0.25, 0.3 ], dtype=float32)}
{'type': 1, 'value': Array([1.  , 0.75, 0.  ], dtype=float32)}
{'type': 1, 'value': Array([0., 0., 1.], dtype=float32)}
{'type': 1, 'value': Array([1., 1., 0.], dtype=float32)}
{'type': 2, 'value': Array([0., 1., 0.], dtype=float32)}
{'type': 2, 'value': Array([0., 0., 1.], dtype=float32)}
{'type': 2, 'value': Array([1., 0., 0.], dtype=float32)}


# MCMC

To be implemented.
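
Until then, the idea can be sketched in plain JAX without the `hyperiax.mcmc` API. The sketch below is an assumption-laden illustration, not the notebook's intended implementation: it hard-codes the backward filter for the tree above, takes the likelihood to be the prior-weighted sum of the message at `x0`, and runs a random-walk Metropolis-Hastings chain on `theta` with a uniform prior on $[0,1]$. The helper names (`log_likelihood`, `mh_step`, `run_chain`), the Gaussian proposal and the step size `0.1` are all hypothetical choices.

```python
import jax
import jax.numpy as jnp

pi = jnp.ones(3) / 3.                       # initial state prior
lambdi = jnp.array([[1., 1., 0.],
                    [1., 1., 0.],
                    [0., 0., 1.]])          # leaf kernel
eye = jnp.eye(3)
v1, v2, v3 = eye[0], eye[1], eye[2]         # observations at the leaves

def kst(theta):
    return jnp.array([[1. - theta, theta, 0.],
                      [.25, .5, .25],
                      [.4, .3, .3]])

def log_likelihood(theta):
    # backward filter written out by hand for the tree in this notebook
    K = kst(theta)
    h_x4 = lambdi @ v1                      # x4 has the single child v1
    h_x2 = lambdi @ v3                      # x2 has the single child v3
    h_x3 = (K @ h_x4) * (lambdi @ v2)       # fuse children x4 and v2
    h_x1 = K @ h_x2
    h_x0 = (K @ h_x1) * (K @ h_x3)          # fuse children x1 and x3
    return jnp.log(pi @ h_x0)

def mh_step(theta, key, step=0.1):
    key_prop, key_acc = jax.random.split(key)
    proposal = theta + step * jax.random.normal(key_prop)
    # uniform prior on [0, 1]: proposals outside the interval are rejected
    inside = (proposal >= 0.) & (proposal <= 1.)
    log_ratio = jnp.where(inside,
                          log_likelihood(proposal) - log_likelihood(theta),
                          -jnp.inf)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, theta)

def run_chain(key, theta0=0.5, n=1000):
    def body(theta, k):
        new_theta = mh_step(theta, k)
        return new_theta, new_theta
    init = jnp.asarray(theta0, dtype=jnp.float32)
    _, samples = jax.lax.scan(body, init, jax.random.split(key, n))
    return samples

samples = run_chain(jax.random.PRNGKey(0))
```

At `theta = 0.5` the hand-written filter gives the same message at `x0` as the printed output of the upwards pass, which is a useful consistency check before swapping in the Hyperiax executor and `metropolis_hastings`.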