# BayesTree

A `BayesTree` is a graphical model that represents the result of multifrontal variable elimination on a `FactorGraph`. It is a tree whose nodes are *cliques*. Each clique stores a conditional density $P(\text{Frontals} \mid \text{Separator})$ over its frontal variables given its separator variables.

Key properties:
*   **Cliques:** Each node (clique) groups variables that are eliminated together.
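*   **Frontal Variables:** The variables eliminated in a given clique. Every variable is a frontal variable of exactly one clique.
*   **Separator Variables:** The variables a clique's conditional depends on, shared with its parent clique. They are eliminated later in the ordering and appear as frontal variables in ancestor cliques. A root clique has an empty separator.
*   **Tree Structure:** Arranges the conditionals produced by elimination into a tree (a forest if the graph is disconnected). This makes marginal computation and incremental updates (as in iSAM2) efficient, especially for sparse problems.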
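The structural rules above can be checked mechanically. The following is a minimal pure-Python sketch, not GTSAM code: the `Clique` class and the `is_valid_bayes_tree` function are hypothetical names introduced here, and the small tree is made up for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(eq=False)
class Clique:
    """Hypothetical stand-in for a Bayes tree clique holding P(frontals | separator)."""
    frontals: Tuple[str, ...]
    separator: Tuple[str, ...] = ()
    parent: Optional["Clique"] = field(default=None, repr=False)
    children: List["Clique"] = field(default_factory=list)

    def add_child(self, child: "Clique") -> "Clique":
        child.parent = self
        self.children.append(child)
        return child


def is_valid_bayes_tree(root: Clique) -> bool:
    """Check the frontal/separator rules on the tree rooted at `root`."""
    seen_frontals = set()
    stack = [root]
    while stack:
        c = stack.pop()
        # No variable may be frontal in two cliques.
        if seen_frontals & set(c.frontals):
            return False
        seen_frontals |= set(c.frontals)
        if c.parent is None:
            # A root clique conditions on nothing.
            if c.separator:
                return False
        elif not set(c.separator) <= set(c.parent.frontals) | set(c.parent.separator):
            # The separator must be contained in the parent clique.
            return False
        stack.extend(c.children)
    return True


# A small hypothetical tree; the variable names are illustrative only.
root = Clique(frontals=("x2", "x3"))
c1 = root.add_child(Clique(frontals=("x1",), separator=("x2",)))
c1.add_child(Clique(frontals=("x0",), separator=("x1",)))
root.add_child(Clique(frontals=("l1",), separator=("x2", "x3")))
print(is_valid_bayes_tree(root))  # True
```

Together with the empty root separator, the containment check guarantees that every separator variable is frontal in some ancestor clique.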

Like `FactorGraph` and `BayesNet`, `BayesTree` is templated on the type of conditional/clique (e.g., `GaussianBayesTree`).

<a href="https://colab.research.google.com/github/borglab/gtsam/blob/develop/gtsam/inference/doc/BayesTree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

```python
%pip install gtsam
```

```python
import gtsam
import numpy as np

# We need concrete graph types and elimination to get a BayesTree
from gtsam import GaussianFactorGraph, Ordering, GaussianBayesTree
from gtsam import symbol_shorthand

X = symbol_shorthand.X
L = symbol_shorthand.L
```

## Creating a BayesTree (via Elimination)

BayesTrees are typically obtained by performing multifrontal elimination on a `FactorGraph`.
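To make the multifrontal idea concrete, here is a hedged pure-Python sketch that works only on variable names (no numbers). It first runs symbolic elimination to obtain the scope of each conditional $P(\text{variable} \mid \text{separator})$, then groups those conditionals into cliques: visiting them in reverse elimination order, a conditional joins its parent clique when its separator equals the parent's full set of variables, and otherwise starts a new child clique. This grouping follows the Bayes tree construction from a chordal Bayes net described in the iSAM2 work of Kaess et al.; as we understand it, GTSAM's `eliminateMultifrontal` takes a different route (through a junction tree). The function names and the ordering are hypothetical.

```python
def eliminate_symbolic(factors, ordering):
    """Symbolic variable elimination on factor scopes.

    Returns (variable, separator) pairs in elimination order; each pair is the
    scope of one conditional P(variable | separator) of the resulting Bayes net.
    """
    position = {v: i for i, v in enumerate(ordering)}
    factors = [frozenset(f) for f in factors]
    conditionals = []
    for var in ordering:
        involved = [f for f in factors if var in f]
        factors = [f for f in factors if var not in f]
        separator = tuple(sorted(set().union(*involved) - {var}, key=position.__getitem__))
        conditionals.append((var, separator))
        if separator:
            factors.append(frozenset(separator))  # new factor on the separator
    return conditionals


def assemble_bayes_tree(conditionals, ordering):
    """Group conditionals into cliques, visiting them in reverse elimination order."""
    position = {v: i for i, v in enumerate(ordering)}
    cliques = []    # each clique: {"frontals": [...], "separator": (...), "parent": index or None}
    clique_of = {}  # variable -> index of the clique in which it is frontal
    for var, sep in reversed(conditionals):
        if not sep:
            cliques.append({"frontals": [var], "separator": (), "parent": None})  # new root
        else:
            # Parent: the clique holding the first-eliminated separator variable.
            p = clique_of[min(sep, key=position.__getitem__)]
            parent = cliques[p]
            if set(parent["frontals"]) | set(parent["separator"]) == set(sep):
                parent["frontals"].insert(0, var)  # same scope: merge into the parent
                clique_of[var] = p
                continue
            cliques.append({"frontals": [var], "separator": sep, "parent": p})  # new child
        clique_of[var] = len(cliques) - 1
    return cliques


# Factor scopes of the example graph used in this notebook, with a hand-picked ordering.
factors = [{"x0"}, {"x0", "x1"}, {"x1", "x2"}, {"l1", "x0"}, {"l1", "x1"}, {"l2", "x1"}, {"l2", "x2"}]
ordering = ["l1", "l2", "x0", "x1", "x2"]
for c in assemble_bayes_tree(eliminate_symbolic(factors, ordering), ordering):
    print(", ".join(c["frontals"]), ":", ", ".join(c["separator"]), "| parent:", c["parent"])
```

With this ordering the sketch yields two cliques: a root with frontals (l2, x1, x2) and a child (l1, x0) with separator x1. A different ordering, such as the COLAMD ordering used below, can produce a different tree.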

```python
# Create a small 1-D Gaussian factor graph: poses x0, x1, x2 and landmarks l1, l2
graph = GaussianFactorGraph()
model = gtsam.noiseModel.Isotropic.Sigma(1, 1.0)
graph.add(X(0), -np.eye(1), np.zeros(1), model)                    # Prior on x0
graph.add(X(0), -np.eye(1), X(1), np.eye(1), np.zeros(1), model)   # x0 -> x1
graph.add(X(1), -np.eye(1), X(2), np.eye(1), np.zeros(1), model)   # x1 -> x2
graph.add(L(1), -np.eye(1), X(0), np.eye(1), np.zeros(1), model)   # l1 -> x0 (measurement)
graph.add(L(1), -np.eye(1), X(1), np.eye(1), np.zeros(1), model)   # l1 -> x1 (measurement)
graph.add(L(2), -np.eye(1), X(1), np.eye(1), np.zeros(1), model)   # l2 -> x1 (measurement)
graph.add(L(2), -np.eye(1), X(2), np.eye(1), np.zeros(1), model)   # l2 -> x2 (measurement)

print("Original Factor Graph:")
graph.print()

# Eliminate multifrontally using a COLAMD ordering
ordering = Ordering.Colamd(graph)
# If the graph were disconnected, the result would have one root per connected component
bayes_tree = graph.eliminateMultifrontal(ordering)

print("\nResulting BayesTree:")
bayes_tree.print()
```

```
Original Factor Graph: size 7
Factor 0: JacobianFactor(keys = [8070450532247928832], Z = [ -1 ], b = [ 0 ], model = diagonal sigmas [1])
Factor 1: JacobianFactor(keys = [8070450532247928832; 8070450532247928833], A[0] = [ -1  1 ], b = [ 0 ], model = diagonal sigmas [1])
Factor 2: JacobianFactor(keys = [8070450532247928833; 8070450532247928834], A[0] = [ -1  1 ], b = [ 0 ], model = diagonal sigmas [1])
Factor 3: JacobianFactor(keys = [7783684379976990720; 8070450532247928832], A[0] = [ -1  1 ], b = [ 0 ], model = diagonal sigmas [1])
Factor 4: JacobianFactor(keys = [7783684379976990720; 8070450532247928833], A[0] = [ -1  1 ], b = [ 0 ], model = diagonal sigmas [1])
Factor 5: JacobianFactor(keys = [7783684379976990721; 8070450532247928833], A[0] = [ -1  1 ], b = [ 0 ], model = diagonal sigmas [1])
Factor 6: JacobianFactor(keys = [7783684379976990721; 8070450532247928834], A[0] = [ -1  1 ], b = [ 0 ], model = diagonal sigmas [1])

Resulting BayesTree: cliques: 3, variables: 5
Root(s):
Co
```

## Properties and Access

A `BayesTree` allows access to its root cliques and provides a way to look up the clique containing a specific variable.
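Conceptually, looking up the clique of a variable only needs a map from each variable to the clique in which it is frontal, and the roots are simply the cliques without a parent. The plain-Python sketch below illustrates this on a hypothetical forest; `roots`, `build_index`, and `path_to_root` are names invented here, not GTSAM API, and the second root stands in for a disconnected component.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical forest: clique id -> (frontals, separator, parent id or None).
CLIQUES: Dict[int, Tuple[Tuple[str, ...], Tuple[str, ...], Optional[int]]] = {
    0: (("l2", "x1", "x2"), (), None),
    1: (("l1", "x0"), ("x1",), 0),
    2: (("y0",), (), None),  # a second root, e.g. from a disconnected component
}


def roots(cliques) -> List[int]:
    """Root cliques are those without a parent."""
    return [cid for cid, (_, _, parent) in cliques.items() if parent is None]


def build_index(cliques) -> Dict[str, int]:
    """Map each variable to the single clique in which it is a frontal variable."""
    index = {}
    for cid, (frontals, _, _) in cliques.items():
        for v in frontals:
            assert v not in index, f"{v} is frontal in two cliques"
            index[v] = cid
    return index


def path_to_root(cliques, cid) -> List[int]:
    """Follow parent links from clique `cid` up to its root."""
    path = [cid]
    while cliques[path[-1]][2] is not None:
        path.append(cliques[path[-1]][2])
    return path


index = build_index(CLIQUES)
print(roots(CLIQUES))                      # [0, 2]
print(index["x0"])                         # 1
print(path_to_root(CLIQUES, index["l1"]))  # [1, 0]
```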

```python
print(f"BayesTree number of cliques: {bayes_tree.size()}")

# Access roots
roots = bayes_tree.roots()
print(f"Number of roots: {len(roots)}")
if roots:
    # Access the conditional associated with the first root clique
    root_conditional = roots[0].conditional()
    print(f"Root clique 0 conditional frontals: {root_conditional.frontals()}")

# Find the clique in which a specific variable is a frontal variable
clique_x1 = bayes_tree.clique(X(1))
# clique_x1 is the clique object itself (e.g., a GaussianBayesTreeClique)
print(f"\nClique containing x1 ({X(1)}):")
clique_x1.print()  # Print the clique itself
```

```
BayesTree number of cliques: 3
Number of roots: 1
Root clique 0 conditional frontals: [7783684379976990720, 8070450532247928833]

Clique containing x1 (8070450532247928833):
Conditional density P(l1, x1 | l2, x2) = P(l1 | x1) P(x1 | l2, x2) 
  size: 2
  Conditional P(l1 | x1): GaussianConditional( P(l1 | x1) = dl1 - R*dx1 - d), R = [ 0.5 ], d = [ 0 ], sigmas = [ 0.866025 ])

  Conditional P(x1 | l2, x2): GaussianConditional( P(x1 | l2, x2) = dx1 - R1*dl2 - R2*dx2 - d), R1 = [ 0.333333 ], R2 = [ 0.333333 ], d = [ 0 ], sigmas = [ 0.745356 ])
```




## Solution and Marginals

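Like `BayesNet`, derived types such as `GaussianBayesTree` provide an `optimize()` method that computes the MLE solution by back-substitution, starting at the root cliques and working down to the leaves. The tree also supports efficient computation of the marginal on a single variable (`marginalFactor`) and the joint marginal on a pair of variables (`joint`). Both reuse the tree structure: a clique's marginal combines its conditional with the marginal on its separator, which is obtained recursively from the parent clique, so only the path to the root is involved rather than the whole graph.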
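For a Gaussian problem, `optimize()` amounts to back-substitution over the upper-triangular (square-root information) factor produced by elimination. The dense NumPy sketch below ignores the clique structure that GTSAM exploits: it eliminates the example graph's Jacobian with a QR factorization in a hypothetical ordering, then solves from the last-eliminated (root) variable back to the first. Because the notebook's right-hand sides are all zero, its MLE is identically zero, so the sketch uses a made-up nonzero `b` to make the check meaningful.

```python
import numpy as np

# Jacobian of the example graph's seven unit-sigma factors; columns are [x0, x1, x2, l1, l2].
A = np.array([
    [-1,  0,  0,  0,  0],  # prior on x0
    [-1,  1,  0,  0,  0],  # x0 -> x1
    [ 0, -1,  1,  0,  0],  # x1 -> x2
    [ 1,  0,  0, -1,  0],  # l1 -> x0
    [ 0,  1,  0, -1,  0],  # l1 -> x1
    [ 0,  1,  0,  0, -1],  # l2 -> x1
    [ 0,  0,  1,  0, -1],  # l2 -> x2
], dtype=float)
b = np.array([1.0, 2.0, 0.5, 0.0, 0.3, 0.0, -0.2])  # hypothetical nonzero right-hand side


def eliminate_qr(A, b, order):
    """Dense elimination: row i of the upper-triangular R is the (square-root)
    Gaussian conditional of variable order[i] given the variables eliminated after it."""
    Q, R = np.linalg.qr(A[:, order])  # reduced QR of the column-permuted Jacobian
    return R, Q.T @ b


def back_substitute(R, d):
    """Solve R y = d, starting from the last-eliminated (root) variable."""
    y = np.zeros_like(d)
    for i in reversed(range(len(d))):
        y[i] = (d[i] - R[i, i + 1:] @ y[i + 1:]) / R[i, i]
    return y


order = [3, 4, 0, 1, 2]  # hypothetical ordering: l1, l2, x0, x1, x2
R, d = eliminate_qr(A, b, order)
x = np.empty_like(d)
x[order] = back_substitute(R, d)  # undo the column permutation
print(np.allclose(x, np.linalg.lstsq(A, b, rcond=None)[0]))  # True
```

A Bayes tree performs the same back-substitution clique by clique: each clique's frontal values are solved once its separator values, which live in ancestor cliques, are known.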

```python
# Optimize to find the MLE solution (for GaussianBayesTree)
mle_solution = bayes_tree.optimize()
print("Optimized Solution (MLE):")
mle_solution.print()

# Compute marginal factor on a single variable (returns a Conditional)
marginal_x1 = bayes_tree.marginalFactor(X(1))
print("\nMarginal Factor on x1:")
marginal_x1.print()

# Compute joint marginal factor graph on two variables
joint_x0_x2 = bayes_tree.joint(X(0), X(2))
print("\nJoint Marginal Factor Graph on (x0, x2):")
joint_x0_x2.print()
```

```
Optimized Solution (MLE):
Values with 5 values:
Value l1: [0.]
Value l2: [0.]
Value x0: [0.]
Value x1: [0.]
Value x2: [0.]


Marginal Factor on x1:
Conditional density P(x1 | l2, x2) = P(x1 | l2, x2) 
  size: 1
  Conditional P(x1 | l2, x2): GaussianConditional( P(x1 | l2, x2) = dx1 - R1*dl2 - R2*dx2 - d), R1 = [ 0.333333 ], R2 = [ 0.333333 ], d = [ 0 ], sigmas = [ 0.745356 ])


Joint Marginal Factor Graph on (x0, x2):
Factor Graph: size 2
Factor 0: GaussianConditional( P(x0 | x2) = dx0 - R*dx2 - d), R = [ 0.25 ], d = [ 0 ], sigmas = [ 0.866025 ])
Factor 1: GaussianConditional( P(x2) = dx2 - d), d = [ 0 ], sigmas = [ 0.774597 ])
```
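On a problem this small, marginals can be cross-checked densely. The marginal information on a set of variables is the Schur complement of the full information matrix $\Lambda = A^\top A$ after eliminating all other variables, and its inverse must equal the corresponding block of $\Lambda^{-1}$. The NumPy sketch below applies this to the example graph with unit sigmas; the helper name `marginal_information` is hypothetical. Worked by hand, the marginal variance of x1 is 5/3, and the joint covariance of (x0, x2) is [[1, 1], [1, 7/3]].

```python
import numpy as np

# Jacobian of the example graph (unit sigmas); columns are [x0, x1, x2, l1, l2].
A = np.array([
    [-1,  0,  0,  0,  0],
    [-1,  1,  0,  0,  0],
    [ 0, -1,  1,  0,  0],
    [ 1,  0,  0, -1,  0],
    [ 0,  1,  0, -1,  0],
    [ 0,  1,  0,  0, -1],
    [ 0,  0,  1,  0, -1],
], dtype=float)
Lam = A.T @ A  # information matrix


def marginal_information(Lam, keep):
    """Schur complement of Lam onto the indices in `keep` (all other variables eliminated)."""
    rest = [i for i in range(Lam.shape[0]) if i not in keep]
    K = Lam[np.ix_(keep, keep)]
    B = Lam[np.ix_(keep, rest)]
    C = Lam[np.ix_(rest, rest)]
    return K - B @ np.linalg.solve(C, B.T)


X0, X1, X2 = 0, 1, 2
var_x1 = 1.0 / marginal_information(Lam, [X1])[0, 0]
cov_x0_x2 = np.linalg.inv(marginal_information(Lam, [X0, X2]))
full_cov = np.linalg.inv(Lam)
print(var_x1)  # ≈ 1.6667 (5/3)
print(np.allclose(cov_x0_x2, full_cov[np.ix_([X0, X2], [X0, X2])]))  # True
```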



## Visualization

Bayes trees can be visualized using Graphviz. `dot()` returns a DOT string in which each node is a clique labeled `frontals : separator`.

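```python
dot_string = bayes_tree.dot()
print(dot_string)

# To render, either save the string and call Graphviz from the command line:
#   open("bayestree.dot", "w").write(dot_string)
#   dot -Tpng bayestree.dot -o bayestree.png
# or, with the `graphviz` Python package installed:
#   import graphviz
#   graphviz.Source(dot_string)
```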

```
digraph G{
0[label="l1, x1 : l2, x2"];
0->1
1[label="l2, x2 : x0"];
1->2
2[label="x0 : "];
}
```
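In this label format, `x0 : ` is a root clique, since its separator is empty. The sketch below uses a hypothetical helper that emits the same `frontals : separator` labels from a plain list of cliques. Drawing edges from parent to child is a choice made here, and the three example cliques reuse the labels printed above.

```python
def bayes_tree_dot(cliques):
    """Render (frontals, separator, parent_index) triples as a DOT digraph.

    Hypothetical helper: node labels use the 'frontals : separator' convention
    and edges are drawn from parent to child.
    """
    lines = ["digraph G{"]
    for i, (frontals, separator, parent) in enumerate(cliques):
        label = f"{', '.join(frontals)} : {', '.join(separator)}"
        lines.append(f'{i}[label="{label}"];')
        if parent is not None:
            lines.append(f"{parent}->{i}")
    lines.append("}")
    return "\n".join(lines)


example = [
    (("x0",), (), None),
    (("l2", "x2"), ("x0",), 0),
    (("l1", "x1"), ("l2", "x2"), 1),
]
print(bayes_tree_dot(example))
```

The resulting string can be rendered the same way as the output of `dot()`, for example with `graphviz.Source`.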