# Max-product algorithm on tree-structured graphs

Consider a probabilistic graphical model over seven variables whose conditional independence structure is represented by the following tree:

![image](../Figures/tree_graph.png)

We assume that the variables are discrete with three possible states, $x_i \in \{1, 2, 3\}$, and that the joint distribution of $x_1, \ldots, x_7$ decomposes into pairwise potentials, up to a normalizing constant $Z$:
$$
p(x_1, x_2, x_3, x_4, x_5, x_6, x_7) = \frac{1}{Z} \psi_{12}(x_1, x_2) \psi_{23}(x_2, x_3) \psi_{24}(x_2, x_4) \psi_{35}(x_3, x_5) \psi_{46}(x_4, x_6) \psi_{47}(x_4, x_7)
$$

Our goal in this notebook is to find the most likely state of this probability model using the max-product algorithm, when the pairwise potential functions are given by

$$
\psi_{ij}(x_i, x_j) = \Psi_{x_i x_j},
$$

for every pair of adjacent nodes $i \sim j$, where $\Psi$ is a randomly generated $3 \times 3$ matrix shared by all six edges.

In [1]:
import numpy as np
import os
os.chdir('..')
from models import Potential, TreeGraph
from algorithms import MaxProduct

For comparison, we first compute the full joint probability table explicitly (it has $3^7 = 2187$ entries). This is not required to run the max-product algorithm.

In [2]:
np.random.seed(2353)
Psi = np.random.uniform(size=(3,3))

Psi12 = Psi[:,:,None,None,None,None,None]
Psi23 = Psi[None,:,:,None,None,None,None]
Psi24 = Psi[None,:,None,:,None,None,None]
Psi35 = Psi[None,None,:,None,:,None,None]
Psi46 = Psi[None,None,None,:,None,:,None]
Psi47 = Psi[None,None,None,:,None,None,:]

joint_probability = Psi12 * Psi23 * Psi24 * Psi35 * Psi46 * Psi47
Z = joint_probability.sum()
joint_probability /= Z
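The broadcasting construction above can also be written as a single `np.einsum` call, which may be easier to check against the factorization. The following is a minimal sketch and not part of the original notebook; the names `unnormalized` and `joint` are illustrative, and the index letters `a` to `g` are assumed to stand for $x_1$ to $x_7$.

```python
import numpy as np

np.random.seed(2353)
Psi = np.random.uniform(size=(3, 3))

# One subscript pair per edge of the tree: (1,2), (2,3), (2,4), (3,5), (4,6), (4,7).
# The output subscript "abcdefg" keeps all seven variables as separate axes.
unnormalized = np.einsum("ab,bc,bd,ce,df,dg->abcdefg", Psi, Psi, Psi, Psi, Psi, Psi)

Z = unnormalized.sum()    # normalizing constant
joint = unnormalized / Z  # joint probability table of shape (3,) * 7
```

Each entry `joint[x1, x2, ..., x7]` (with 0-based indices) is the product of the six edge potentials divided by `Z`.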

Now, run the max-product algorithm to compute the most likely state of this model.

In [3]:
np.random.seed(2353)
Psi = np.random.uniform(size=(3,3))

edge_potentials = [Potential(Psi, variables=[1,2]),
                   Potential(Psi, variables=[2,3]),
                   Potential(Psi, variables=[2,4]),
                   Potential(Psi, variables=[3,5]),
                   Potential(Psi, variables=[4,6]),
                   Potential(Psi, variables=[4,7])]

f = TreeGraph(edge_potentials, state_type='mode') # Set up graphical model
MaxProduct(f, root_node=1) # Run max-product, taking node 1 as the root

# Get most likely states obtained by max-product
i1, i2, i3, i4, i5, i6, i7 = list(f.states.values())

# Relabel 0-based indices as 1-based states: 0 -> 1, 1 -> 2, 2 -> 3
state_dict = {0: 1, 1: 2, 2: 3}
relabel = lambda i: state_dict[i]
p1, p2, p3, p4, p5, p6, p7 = list(map(relabel, [i1, i2, i3, i4, i5, i6, i7]))

print(f"Most likely state: {p1, p2, p3, p4, p5, p6, p7}")

Most likely state: (2, 3, 2, 2, 3, 3, 3)


Below, we verify that the state found by the max-product algorithm is indeed the most likely one, i.e., that it attains the maximum of the joint probability table.

In [4]:
np.max(joint_probability) == joint_probability[i1, i2, i3, i4, i5, i6, i7]

True
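For a model this small, the most likely state can also be read directly off the joint table by brute force, which gives an independent reference to compare against. The sketch below is an illustration under that assumption, not the notebook's method; `map_state` and `labels` are hypothetical names.

```python
import numpy as np

np.random.seed(2353)
Psi = np.random.uniform(size=(3, 3))

# Unnormalized joint table; normalizing does not change the location of the maximum
joint = np.einsum("ab,bc,bd,ce,df,dg->abcdefg", Psi, Psi, Psi, Psi, Psi, Psi)
joint /= joint.sum()

flat_index = np.argmax(joint)                          # index into the flattened table
map_state = np.unravel_index(flat_index, joint.shape)  # tuple of seven 0-based indices
labels = tuple(int(i) + 1 for i in map_state)          # 1-based states, as in the text

print(f"Most likely state (brute force): {labels}")
```

This search visits all $3^7$ configurations, whereas max-product only needs a handful of $3 \times 3$ operations per edge, which is why the brute-force approach does not scale to larger trees.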

## Computation by hand

Below, we run the max-product algorithm by hand to better understand its inner workings. This is essentially what happens inside the `MaxProduct` call used above.

### Leaves to root
We first propagate messages from leaf nodes up to the root node, taking node 1 as the root, to compute the maximal probability.
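In symbols, and writing $m_{j \to i}$ for the message sent from node $j$ to its parent $i$ (a notation introduced here for illustration), the upward pass for this tree can be sketched as:

$$
\begin{aligned}
m_{5 \to 3}(x_3) &= \max_{x_5} \psi_{35}(x_3, x_5), \qquad
m_{6 \to 4}(x_4) = \max_{x_6} \psi_{46}(x_4, x_6), \qquad
m_{7 \to 4}(x_4) = \max_{x_7} \psi_{47}(x_4, x_7), \\
m_{3 \to 2}(x_2) &= \max_{x_3} \psi_{23}(x_2, x_3)\, m_{5 \to 3}(x_3), \\
m_{4 \to 2}(x_2) &= \max_{x_4} \psi_{24}(x_2, x_4)\, m_{6 \to 4}(x_4)\, m_{7 \to 4}(x_4), \\
m_{2 \to 1}(x_1) &= \max_{x_2} \psi_{12}(x_1, x_2)\, m_{3 \to 2}(x_2)\, m_{4 \to 2}(x_2), \\
\max_{x_1, \ldots, x_7} p(x_1, \ldots, x_7) &= \frac{1}{Z} \max_{x_1} m_{2 \to 1}(x_1).
\end{aligned}
$$

Each maximization runs over the child's state, which corresponds to `axis=1` of $\Psi$ when the parent indexes the rows.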

In [5]:
# Messages from the leaves: 5 -> 3, 6 -> 4 and 7 -> 4 (identical, since every edge shares Psi)
m53 = m64 = m74 = np.max(Psi, axis=1)

m32 = np.max(Psi * m53[None], axis=1) # Message from node 3 to node 2
m42 = np.max(Psi * (m64 * m74)[None], axis=1) # Message from node 4 to node 2

m21 = np.max(Psi * (m32 * m42)[None], axis=1) # Message from node 2 to node 1

p_ = np.max(m21) / Z  # Maximal probability (Z is the normalizing constant computed earlier)

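We check below that the obtained probability is indeed maximal under this model. Since the two values are computed by different sequences of floating-point operations, we compare them up to a small tolerance:

In [6]:
np.isclose(p_, np.max(joint_probability))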

True

### Root to leaves
Next, we backtrack from the root node to the leaf nodes to find the state that gave rise to this maximal probability.
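Using the same message notation $m_{j \to i}$ (introduced here for illustration, with $x_i^*$ denoting the selected state of node $i$), the downward pass can be sketched as:

$$
\begin{aligned}
x_1^* &= \arg\max_{x_1} m_{2 \to 1}(x_1), \\
x_2^* &= \arg\max_{x_2} \psi_{12}(x_1^*, x_2)\, m_{3 \to 2}(x_2)\, m_{4 \to 2}(x_2), \\
x_3^* &= \arg\max_{x_3} \psi_{23}(x_2^*, x_3)\, m_{5 \to 3}(x_3), \qquad
x_4^* = \arg\max_{x_4} \psi_{24}(x_2^*, x_4)\, m_{6 \to 4}(x_4)\, m_{7 \to 4}(x_4), \\
x_5^* &= \arg\max_{x_5} \psi_{35}(x_3^*, x_5), \qquad
x_6^* = \arg\max_{x_6} \psi_{46}(x_4^*, x_6), \qquad
x_7^* = \arg\max_{x_7} \psi_{47}(x_4^*, x_7).
\end{aligned}
$$

Each node's state is chosen conditional on the already selected state of its parent, so the result is a single consistent joint configuration.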

In [7]:
i1 = np.argmax(m21) # Compute state at node 1

i2 = np.argmax(Psi[i1] * (m32 * m42)) # Compute state at node 2

i3 = np.argmax(Psi[i2] * m53) # Compute state at node 3
i4 = np.argmax(Psi[i2] * (m64 * m74)) # Compute state at node 4

i5 = np.argmax(Psi[i3]) # Compute state at node 5
i6 = i7 = np.argmax(Psi[i4]) # Compute state at nodes 6 and 7

# Relabel 0-based indices as 1-based states: 0 -> 1, 1 -> 2, 2 -> 3
state_dict = {0: 1, 1: 2, 2: 3}
relabel = lambda i: state_dict[i]
p1, p2, p3, p4, p5, p6, p7 = list(map(relabel, [i1, i2, i3, i4, i5, i6, i7]))

print(f"Most likely state: {p1, p2, p3, p4, p5, p6, p7}")

Most likely state: (2, 3, 2, 2, 3, 3, 3)
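The two passes above generalize to any rooted tree. The following is a minimal sketch under the assumptions of this notebook (every edge shares the same potential, indexed as `psi[parent_state, child_state]`); the function `max_product` and the `children` dictionary are hypothetical names and are not the implementation behind `MaxProduct`.

```python
import numpy as np

def max_product(children, psi, root):
    """Return the maximal unnormalized probability and the maximizing 0-based states."""
    incoming = {}  # incoming[node][x] = product of the messages from node's children

    def collect(node):
        # Leaves-to-root pass
        prod = np.ones(psi.shape[0])
        for child in children.get(node, ()):
            collect(child)
            # Message child -> node: maximize over the child's state (axis 1)
            prod = prod * np.max(psi * incoming[child][None, :], axis=1)
        incoming[node] = prod

    collect(root)
    max_value = incoming[root].max()

    # Root-to-leaves pass: pick each child's state given its parent's state
    states = {root: int(np.argmax(incoming[root]))}

    def backtrack(node):
        for child in children.get(node, ()):
            states[child] = int(np.argmax(psi[states[node]] * incoming[child]))
            backtrack(child)

    backtrack(root)
    return max_value, states

np.random.seed(2353)
Psi = np.random.uniform(size=(3, 3))
children = {1: [2], 2: [3, 4], 3: [5], 4: [6, 7]}  # the tree, rooted at node 1

max_value, states = max_product(children, Psi, root=1)
best = tuple(states[i] for i in range(1, 8))  # 0-based states of x1, ..., x7
print("Most likely state:", tuple(s + 1 for s in best))
```

Dividing `max_value` by the normalizing constant `Z` gives the maximal probability computed by hand earlier.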
