In [None]:
'''
 * Copyright (c) 2008 Radhamadhab Dalai
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
'''

## Sum-Product Algorithm in Factor Graphs

Consider a simple factor graph illustrated in **Fig.51**, where the unnormalized joint distribution is given by:

![image.png](attachment:image.png)
Fig.51 A simple factor graph used to illustrate the sum-product algorithm.


$$
p(x) = f_a(x_1, x_2) f_b(x_2, x_3) f_c(x_2, x_4). \tag{8.73}
$$

## Message Passing

### From Leaves to Root

The following sequence of messages is propagated from the leaf nodes $x_1$ and $x_4$ towards the root $x_3$:

1. **Message from $x_1$ to $f_a$:**

   $$
   \mu_{x_1 \to f_a}(x_1) = 1. \tag{8.74}
   $$

2. **Message from $f_a$ to $x_2$:**

   $$
   \mu_{f_a \to x_2}(x_2) = \sum_{x_1} f_a(x_1, x_2). \tag{8.75}
   $$

3. **Message from $x_4$ to $f_c$:**

   $$
   \mu_{x_4 \to f_c}(x_4) = 1. \tag{8.76}
   $$

4. **Message from $f_c$ to $x_2$:**

   $$
   \mu_{f_c \to x_2}(x_2) = \sum_{x_4} f_c(x_2, x_4). \tag{8.77}
   $$

5. **Message from $x_2$ to $f_b$:**

   $$
   \mu_{x_2 \to f_b}(x_2) = \mu_{f_a \to x_2}(x_2) \mu_{f_c \to x_2}(x_2). \tag{8.78}
   $$

6. **Message from $f_b$ to $x_3$:**

   $$
   \mu_{f_b \to x_3}(x_3) = \sum_{x_2} f_b(x_2, x_3) \mu_{x_2 \to f_b}(x_2). \tag{8.79}
   $$

### From Root to Leaves

After propagating messages to the root $x_3$, messages are propagated back to the leaves:

1. **Message from $x_3$ to $f_b$:**

   $$
   \mu_{x_3 \to f_b}(x_3) = 1. \tag{8.80}
   $$

2. **Message from $f_b$ to $x_2$:**

   $$
   \mu_{f_b \to x_2}(x_2) = \sum_{x_3} f_b(x_2, x_3). \tag{8.81}
   $$

3. **Message from $x_2$ to $f_a$:**

   $$
   \mu_{x_2 \to f_a}(x_2) = \mu_{f_b \to x_2}(x_2) \mu_{f_c \to x_2}(x_2). \tag{8.82}
   $$

4. **Message from $f_a$ to $x_1$:**

   $$
   \mu_{f_a \to x_1}(x_1) = \sum_{x_2} f_a(x_1, x_2) \mu_{x_2 \to f_a}(x_2). \tag{8.83}
   $$

5. **Message from $x_2$ to $f_c$:**

   $$
   \mu_{x_2 \to f_c}(x_2) = \mu_{f_a \to x_2}(x_2) \mu_{f_b \to x_2}(x_2). \tag{8.84}
   $$

6. **Message from $f_c$ to $x_4$:**

   $$
   \mu_{f_c \to x_4}(x_4) = \sum_{x_2} f_c(x_2, x_4) \mu_{x_2 \to f_c}(x_2). \tag{8.85}
   $$

## Marginal Computation

After completing message passing in both directions, the marginal distributions can be computed. For example, the marginal distribution $p(x_2)$ is given by:

$$
p(x_2) = \mu_{f_a \to x_2}(x_2) \mu_{f_b \to x_2}(x_2) \mu_{f_c \to x_2}(x_2). \tag{8.86}
$$

Expanding the messages:

$$
p(x_2) = \sum_{x_1} \sum_{x_3} \sum_{x_4} f_a(x_1, x_2) f_b(x_2, x_3) f_c(x_2, x_4),
$$

which corresponds to the marginalization of the joint distribution $p(x)$:

$$
p(x_2) = \sum_{x_1} \sum_{x_3} \sum_{x_4} p(x). \tag{8.87}
$$

## Observed Variables

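If a subset of the variables is observed, partition $x$ into hidden variables $h$ and observed variables $v$, and denote the observed values by $v^\ast$. The joint distribution $p(x)$ is then modified as: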

$$
p(h, v = v^\ast) = p(x) \prod_{i} I(v_i, v_i^\ast),
$$

where $I(v, v^\ast)$ is an indicator function:

$$
I(v, v^\ast) = 
\begin{cases} 
1 & \text{if } v = v^\ast, \\
0 & \text{otherwise}.
\end{cases}
$$

The sum-product algorithm can then compute the posterior marginals $p(h_i | v = v^\ast)$ efficiently.


![image-2.png](attachment:image-2.png)

Fig.52 Flow of messages for the sum-product algorithm applied to the example graph in Fig.51. (a) From the leaf nodes $x_1$ and $x_4$ towards the root node $x_3$. (b) From the root node towards the leaf nodes.


In [1]:
import numpy as np

# Pairwise factors of the graph in Fig.51, written as functions of variable *values*
factors = {
    'fa': lambda x1, x2: np.exp(-((x1 - 1)**2 + (x2 - 2)**2)),
    'fb': lambda x2, x3: np.exp(-((x2 - 2)**2 + (x3 - 3)**2)),
    'fc': lambda x2, x4: np.exp(-((x2 - 2)**2 + (x4 - 1)**2)),
}

# States shared by every variable
x_range = np.arange(0, 5, 1)  # Variable range [0, 4]
n_states = len(x_range)

# Tabulate each factor on the state grid: tables[name][i, j] = f(x_range[i], x_range[j]).
# Messages below are indexed by state *position*, not by state value, so x_range
# may hold any values (it no longer has to be 0, 1, 2, ...).
tables = {
    name: factor(x_range[:, None], x_range[None, :])
    for name, factor in factors.items()
}

# Leaf variables (and the root, for the backward pass) send the unit message
# (eqs. 8.74, 8.76, 8.80)
messages = {
    'x1->fa': np.ones(n_states),
    'x4->fc': np.ones(n_states),
    'x3->fb': np.ones(n_states),
}

# Forward pass: From leaves to root
messages['fa->x2'] = messages['x1->fa'] @ tables['fa']   # sum over x1 (axis 0), eq. 8.75
messages['fc->x2'] = tables['fc'] @ messages['x4->fc']   # sum over x4 (axis 1), eq. 8.77
messages['x2->fb'] = messages['fa->x2'] * messages['fc->x2']   # eq. 8.78
messages['fb->x3'] = messages['x2->fb'] @ tables['fb']   # sum over x2 (axis 0), eq. 8.79

# Backward pass: From root to leaves
messages['fb->x2'] = tables['fb'] @ messages['x3->fb']   # sum over x3 (axis 1), eq. 8.81
messages['x2->fa'] = messages['fb->x2'] * messages['fc->x2']   # eq. 8.82
messages['x2->fc'] = messages['fa->x2'] * messages['fb->x2']   # eq. 8.84
messages['fa->x1'] = tables['fa'] @ messages['x2->fa']   # sum over x2 (axis 1), eq. 8.83
messages['fc->x4'] = messages['x2->fc'] @ tables['fc']   # sum over x2 (axis 0), eq. 8.85

# Compute marginal probabilities: product of all messages arriving at x2 (eq. 8.86)
marginals = {
    'x2': messages['fa->x2'] * messages['fb->x2'] * messages['fc->x2']
}

# The normalisation constant must be the same whichever variable it is computed at;
# compare the value at x2 with the one obtained from the message arriving at the root x3.
partition = np.sum(marginals['x2'])
assert np.isclose(partition, np.sum(messages['fb->x3'])), "inconsistent normalisation constant"

# Normalize the marginal probabilities
marginals['x2'] /= partition

print("Marginal distribution for x2:", marginals['x2'])


Marginal distribution for x2: [5.58774846e-06 4.52779947e-02 9.09432835e-01 4.52779947e-02
 5.58774846e-06]


## Max-Sum Algorithm for Factor Graphs

The **max-sum algorithm** finds the configuration $ \mathbf{x} $ that maximizes the joint probability $ p(\mathbf{x}) $ in a factor graph. This is achieved through message passing, where summations in the sum-product algorithm are replaced by maximizations, and products are replaced by sums of logarithms.

---

## Joint Distribution

The joint probability of a factor graph can be written as:
$$
p(\mathbf{x}) = \prod_{f \in F} f(\mathbf{x}_f),
$$
where $ f(\mathbf{x}_f) $ is a factor over a subset of variables $ \mathbf{x}_f $.

We aim to find:
$$
\mathbf{x}_{\text{max}} = \arg \max_{\mathbf{x}} p(\mathbf{x}),
$$
and
$$
p_{\text{max}} = \max_{\mathbf{x}} p(\mathbf{x}).
$$

---

## Message Passing Rules

### From Factor $ f $ to Variable $ x $:
$$
\mu_{f \to x}(x) = \max_{\mathbf{x}_f \setminus x} \left[ \ln f(\mathbf{x}_f) + \sum_{x' \in \text{ne}(f) \setminus x} \mu_{x' \to f}(x') \right].
$$

### From Variable $ x $ to Factor $ f $:
$$
\mu_{x \to f}(x) = \sum_{f' \in \text{ne}(x) \setminus f} \mu_{f' \to x}(x).
$$

---

## Initialization

For leaf nodes:
- Variable-to-factor messages:
$$
\mu_{x \to f}(x) = 0.
$$
- Factor-to-variable messages:
$$
\mu_{f \to x}(x) = \ln f(x).
$$

---

## Root Node

At the root node $ x_r $, the messages are sums of log-factors, so the maximum of the joint probability is obtained in the log domain:
$$
\ln p_{\text{max}} = \max_{x_r} \sum_{f \in \text{ne}(x_r)} \mu_{f \to x_r}(x_r),
$$
and the corresponding most probable value of the root variable is:
$$
x_r^{\text{max}} = \arg \max_{x_r} \sum_{f \in \text{ne}(x_r)} \mu_{f \to x_r}(x_r).
$$

---

## Backward Pass

To compute the most probable configuration $ \mathbf{x}_{\text{max}} $, backtrack from the root node to the leaves using:
$$
x_{\text{max}} = \arg \max_x \left[ \ln f(\mathbf{x}_f) + \sum_{x' \in \text{ne}(f) \setminus x} \mu_{x' \to f}(x') \right].
$$

---

## Algorithm Steps

1. **Forward Pass**:
   - Propagate messages from the leaves to the root.
   - Compute $ \mu_{f \to x}(x) $ and $ \mu_{x \to f}(x) $ for all variables and factors.

2. **Compute $ p_{\text{max}} $ and $ x_{\text{max}} $ at the root**:
   - Find the maximum joint probability and the configuration at the root.

3. **Backward Pass**:
   - Propagate messages from the root back to the leaves.
   - Use the messages to trace the most probable configuration $ \mathbf{x}_{\text{max}} $.

---

## Example

Consider a simple factor graph with factors $ f_a(x_1, x_2) $, $ f_b(x_2, x_3) $, and $ f_c(x_2, x_4) $. The message updates are:

1. **Factor-to-variable messages**:
$$
\mu_{f_a \to x_2}(x_2) = \max_{x_1} \left[ \ln f_a(x_1, x_2) + \mu_{x_1 \to f_a}(x_1) \right],
$$

$$
\mu_{f_c \to x_2}(x_2) = \max_{x_4} \left[ \ln f_c(x_2, x_4) + \mu_{x_4 \to f_c}(x_4) \right].
$$

2. **Variable-to-factor messages**:
$$
\mu_{x_2 \to f_b}(x_2) = \mu_{f_a \to x_2}(x_2) + \mu_{f_c \to x_2}(x_2).
$$

3. **Root computation**:
$$
\mu_{f_b \to x_3}(x_3) = \max_{x_2} \left[ \ln f_b(x_2, x_3) + \mu_{x_2 \to f_b}(x_2) \right].
$$

4. **Backward pass**:
Use the messages to determine:
$$
x_{\text{max}} = \arg \max_x \mu_{f \to x}(x).
$$


In [2]:
import numpy as np

# Pairwise factors of the example graph, written as functions of variable *values*
factors = {
    'fa': lambda x1, x2: np.exp(-((x1 - 1)**2 + (x2 - 2)**2)),
    'fb': lambda x2, x3: np.exp(-((x2 - 2)**2 + (x3 - 3)**2)),
    'fc': lambda x2, x4: np.exp(-((x2 - 2)**2 + (x4 - 1)**2)),
}

# Define variable range
x_range = np.arange(0, 5, 1)  # Variable range [0, 4]
n_states = len(x_range)

# Log-factor tables on the state grid: log_tables[name][i, j] = ln f(x_range[i], x_range[j]).
# Everything below is indexed by state *position*, so x_range may hold any values.
log_tables = {
    name: np.log(factor(x_range[:, None], x_range[None, :]))
    for name, factor in factors.items()
}

# Leaf variables send the zero message in the log domain
messages = {
    'x1->fa': np.zeros(n_states),
    'x4->fc': np.zeros(n_states),
}

# phi[var][k] = index of the best state of `var` given that its parent is in state k.
# These tables are what makes the back-trace consistent with the maximum found at the root.
phi = {}

# Forward pass: From leaves to root
scores = log_tables['fa'] + messages['x1->fa'][:, None]   # rows: x1, columns: x2
messages['fa->x2'] = scores.max(axis=0)
phi['x1'] = scores.argmax(axis=0)                         # best x1 for each x2

scores = log_tables['fc'] + messages['x4->fc'][None, :]   # rows: x2, columns: x4
messages['fc->x2'] = scores.max(axis=1)
phi['x4'] = scores.argmax(axis=1)                         # best x4 for each x2

messages['x2->fb'] = messages['fa->x2'] + messages['fc->x2']

scores = log_tables['fb'] + messages['x2->fb'][:, None]   # rows: x2, columns: x3
messages['fb->x3'] = scores.max(axis=0)
phi['x2'] = scores.argmax(axis=0)                         # best x2 for each x3

# Maximum of the (unnormalised) log joint, read off at the root x3
p_max = np.max(messages['fb->x3'])
x3_idx = np.argmax(messages['fb->x3'])

# Backward pass: follow the stored argmax tables from the root down to the leaves.
# x2 must be chosen *given* the root state (through f_b); x1 and x4 given x2.
x2_idx = phi['x2'][x3_idx]
x1_idx = phi['x1'][x2_idx]
x4_idx = phi['x4'][x2_idx]

x1_max, x2_max, x3_max, x4_max = (
    x_range[x1_idx], x_range[x2_idx], x_range[x3_idx], x_range[x4_idx]
)

# Sanity check: the traced configuration must attain the maximum found at the root
log_joint_at_max = (
    np.log(factors['fa'](x1_max, x2_max))
    + np.log(factors['fb'](x2_max, x3_max))
    + np.log(factors['fc'](x2_max, x4_max))
)
assert np.isclose(log_joint_at_max, p_max), "back-trace does not attain the maximum"

# p_max is a log value (max-sum works in the log domain), so label it as such
print("Maximum log joint probability (unnormalized):", p_max)
print("Most probable configuration: x1 =", x1_max, "x2 =", x2_max, "x3 =", x3_max, "x4 =", x4_max)


Maximum joint probability: 0.0
Most probable configuration: x1 = 1 x2 = 2 x3 = 3 x4 = 1


![image.png](attachment:image.png)

Fig.53 A lattice, or trellis, diagram showing explicitly the $K$ possible states (one per row of the diagram) for each of the variables $x_n$ in the chain model. In this illustration $K = 3$. The arrow shows the direction of message passing in the max-product algorithm. For every state $k$ of each variable $x_n$ (corresponding to column $n$ of the diagram) the function $\phi(x_n)$ defines a unique state at the previous variable, indicated by the black lines. The two paths through the lattice correspond to configurations that give the global maximum of the joint probability distribution, and either of these can be found by tracing back along the black lines in the opposite direction to the arrow.

## Max-Sum Algorithm with Backtracking

The **max-sum algorithm** finds the most probable configuration $ \mathbf{x}_{\text{max}} $ of variables in a tree-structured factor graph. This description includes the forward message-passing phase and the backtracking step for consistent solutions.

---

## Forward Message Passing

For a chain of variables $ \{x_1, x_2, \ldots, x_N\} $:

### Messages from Variable to Factor:
$$
\mu_{x_n \to f_{n,n+1}}(x_n) = \mu_{f_{n-1,n} \to x_n}(x_n).
$$

### Messages from Factor to Variable:
$$
\mu_{f_{n-1,n} \to x_n}(x_n) = \max_{x_{n-1}} \left[ \ln f_{n-1,n}(x_{n-1}, x_n) + \mu_{x_{n-1} \to f_{n-1,n}}(x_{n-1}) \right].
$$

#### Initial Message:
At the leaf node $ x_1 $:

$$
\mu_{x_1 \to f_{1,2}}(x_1) = 0.
$$

---

## Finding the Most Probable Value at the Root

At the root node $ x_N $, the most probable value is given by:
$$
x_N^{\text{max}} = \arg \max_{x_N} \mu_{f_{N-1,N} \to x_N}(x_N).
$$

---

## Backtracking for Consistent States

During the forward pass, store the values of $ x_{n-1} $ that maximize the probability at each step:
$$
\phi(x_n) = \arg \max_{x_{n-1}} \left[ \ln f_{n-1,n}(x_{n-1}, x_n) + \mu_{x_{n-1} \to f_{n-1,n}}(x_{n-1}) \right].
$$

To determine the most probable configuration $ \mathbf{x}_{\text{max}} = \{x_1^{\text{max}}, x_2^{\text{max}}, \ldots, x_N^{\text{max}}\} $, propagate backward using:
$$
x_{n-1}^{\text{max}} = \phi(x_n^{\text{max}}).
$$

---

## Trellis Diagram Representation

The computation can be visualized in a **trellis diagram** where:
- Each column represents states of a variable $ x_n $.
- Each edge between states represents the maximizing configuration based on $ \phi(x_n) $.

---

## Generalization to Tree-Structured Factor Graphs

For a factor node $ f $ connected to variable $ x $:
$$
\mu_{f \to x}(x) = \max_{\{x_1, \ldots, x_M\} \setminus x} \left[ \ln f(x, x_1, \ldots, x_M) + \sum_{x_i \in \text{ne}(f) \setminus x} \mu_{x_i \to f}(x_i) \right].
$$

During the forward pass, store the maximizing values for backtracking. After determining $ x_{\text{max}} $ at the root, use backtracking to compute the consistent maximizing configuration.

---

## Special Case: Hidden Markov Models

In the context of Hidden Markov Models (HMMs), the max-sum algorithm is called the **Viterbi algorithm**. It is used to find the most probable sequence of hidden states given observed evidence.

---

## Observed Variables

When evidence is included, observed variables are clamped to their observed values. Maximization is performed over the hidden variables, which can be formalized by including identity functions into the factor functions.

---

## Comparison to Iterated Conditional Modes (ICM)

- **ICM Algorithm**:
  - Simpler: Each step maximizes the conditional distribution at a single node.
  - Not guaranteed to find a global maximum, even for tree-structured graphs.

- **Max-Sum Algorithm**:
  - Messages are functions of the node variables, so each message comprises $ K $ values (one per state).
  - Guaranteed to find a global maximum for tree-structured graphs.


In [4]:
import numpy as np

def max_sum_algorithm(factors, tree_structure, root, num_states=None):
    """
    Run the max-sum algorithm on a tree of variables coupled by pairwise log-factors.

    Parameters:
    - factors: dict of pairwise factor functions returning log values. Each factor is
      keyed by the tuple of its two variables and is called with the state indices of
      those variables in the order given by the key. For example:
        {('x1', 'x2'): lambda x1, x2: np.log(psi_x1x2(x1, x2))}
      The sorted key is tried first, then either orientation of the pair.
    - tree_structure: dict representing the adjacency list between variables.
      Example: {'x1': ['x2'], 'x2': ['x1', 'x3'], 'x3': ['x2']}
      It must describe a tree; cycles are not detected.
    - root: the variable toward which messages are collected and from which
      backtracking starts.
    - num_states: number of discrete states per variable. Either an int (shared by all
      variables), a dict mapping each variable to its own count, or None. With None the
      count is inferred by calling the factor with `np.arange(2)` for both arguments and
      taking `len(result[0])`; this only works for factors that index a 2-D table as
      `table[a][b]`, so passing `num_states` explicitly is recommended.

    Returns:
    - max_prob: the maximum log-probability of the joint configuration.
    - max_config: a dict mapping each variable to its state index in the most probable
      configuration.

    Raises:
    - KeyError: if an edge of `tree_structure` has no entry in `factors`.
    - ValueError: if `num_states` is None and the state count cannot be inferred.
    """
    # Best child state for every state of the parent, keyed by (parent, child)
    backtrack = {}

    # Helper function to find neighbors excluding a specific node
    def neighbors(node, exclude=None):
        return [n for n in tree_structure[node] if n != exclude]

    def factor_key(u, v):
        """Return the key under which the factor coupling u and v is stored."""
        for key in (tuple(sorted((u, v))), (u, v), (v, u)):
            if key in factors:
                return key
        raise KeyError(f"no factor defined for the edge {u!r} - {v!r}")

    def log_factor(u, x_u, v, x_v):
        """Evaluate the (u, v) log-factor, passing the arguments in the key's order."""
        key = factor_key(u, v)
        if key[0] == u:
            return factors[key](x_u, x_v)
        return factors[key](x_v, x_u)

    def state_count(var, via):
        """Number of states of `var`; `via` is any neighbour, used only for inference."""
        if isinstance(num_states, dict):
            return num_states[var]
        if num_states is not None:
            return num_states
        # Legacy inference kept for backward compatibility with table-indexing factors
        try:
            probe = factors[factor_key(var, via)](np.arange(2), np.arange(2))
            return len(probe[0])
        except (TypeError, IndexError) as exc:
            raise ValueError(
                "cannot infer the number of states from the factors; pass num_states"
            ) from exc

    def collect(current, parent):
        """
        Post-order pass: return, for every state of `current`, the best log-score of
        the subtree hanging below it (the message `current` sends toward `parent`).
        """
        children = neighbors(current, exclude=parent)
        via = parent if parent is not None else children[0]
        n_current = state_count(current, via)
        belief = [0.0] * n_current

        for child in children:
            child_belief = collect(child, current)
            best_states = []
            for x_current in range(n_current):
                # Factor-to-variable message: maximise over the child's state,
                # including everything the child collected from its own subtree.
                scores = [
                    log_factor(current, x_current, child, x_child) + child_belief[x_child]
                    for x_child in range(len(child_belief))
                ]
                x_best = max(range(len(scores)), key=scores.__getitem__)
                belief[x_current] += scores[x_best]
                best_states.append(x_best)
            backtrack[(current, child)] = best_states

        return belief

    # A root without neighbours has no factors: the log-joint is the empty sum
    if not neighbors(root):
        return 0.0, {root: 0}

    # Forward pass: propagate messages from the leaves to the root
    root_belief = collect(root, None)

    # The root combines the messages from *all* of its neighbours
    x_root = max(range(len(root_belief)), key=root_belief.__getitem__)
    max_prob = root_belief[x_root]
    max_config = {root: x_root}

    # Backtracking through the graph using the stored argmax tables
    pending = [(root, None)]
    while pending:
        current, parent = pending.pop()
        for child in neighbors(current, exclude=parent):
            max_config[child] = backtrack[(current, child)][max_config[current]]
            pending.append((child, current))

    return max_prob, max_config


## Exact Inference in General Graphs

The **sum-product** and **max-sum** algorithms provide efficient and exact solutions to inference problems in tree-structured graphs. However, many practical applications involve graphs with loops. For such cases, the **junction tree algorithm** offers an exact inference procedure. Below is a summary of the key steps involved.

---

### Key Steps in the Junction Tree Algorithm

1. **Moralization** (if starting from a directed graph):
   - Convert the directed graph into an undirected graph by *moralizing* it.
   - In this step, parents of each node are connected by edges, and all directed edges are replaced with undirected ones.

2. **Triangulation**:
   - Eliminate all chord-less cycles containing four or more nodes by adding edges (chords).
   - For example, in the graph below, the cycle $ A \to C \to B \to D \to A $ is chord-less. Adding an edge between $ A $ and $ B $ or between $ C $ and $ D $ removes the chord-less cycle.

3. **Construct the Join Tree**:
   - From the triangulated graph, create a new **tree-structured undirected graph** called a join tree:
     - Nodes correspond to the **maximal cliques** of the triangulated graph.
     - Links connect pairs of cliques that share variables.
   - **Maximal Spanning Tree**:
     - Select the spanning tree such that the weight of the tree is maximized.
     - Weight of a link = number of variables shared between the two cliques it connects.
     - Weight of the tree = sum of the weights of its links.
   - Condense the tree:
     - Absorb any clique that is a subset of another into the larger clique to form the **junction tree**.

4. **Running Intersection Property**:
   - The triangulation ensures that the junction tree satisfies the *running intersection property*:
     - If a variable appears in two cliques, it also appears in all cliques on the path connecting those two cliques.
   - This ensures consistent inference about variables across the graph.

5. **Message Passing**:
   - Apply a two-stage message passing algorithm (similar to the sum-product algorithm) on the junction tree:
     - Compute marginals and conditionals efficiently.

---

### Computational Complexity and Treewidth

- The computational cost of the junction tree algorithm depends on the **size of the largest clique**:
  - For discrete variables, cost grows exponentially with the size of the largest clique.
- **Treewidth**:
  - Defined as $ \text{treewidth} = \text{size of the largest clique} - 1 $.
  - The $ -1 $ in the definition ensures that a tree has a treewidth of 1.
  - A graph with high treewidth makes the junction tree algorithm impractical.

---

### Summary

The junction tree algorithm uses graphical operations to organize computations, exploiting factorization properties of the distribution. It is exact for arbitrary graphs and efficient in the sense that it reuses shared sub-computations, but its cost grows exponentially with the size of the largest clique. For high-treewidth graphs, the algorithm can therefore become infeasible.


In [6]:
import networkx as nx
from itertools import combinations

def moralize_graph(directed_graph):
    """
    Build the moral graph of a directed graph.

    Starts from the undirected version of `directed_graph` and then "marries" the
    parents of every node by adding an edge between each pair of them.

    Parameters:
    - directed_graph: graph exposing `to_undirected()`, `nodes` and `predecessors(node)`.

    Returns:
    - The undirected graph produced by `to_undirected()` with the extra parent-parent edges.
    """
    married = directed_graph.to_undirected()
    for child in directed_graph.nodes:
        co_parents = list(directed_graph.predecessors(child))
        # Connect every unordered pair of parents exactly once
        for position, first in enumerate(co_parents):
            for second in co_parents[position + 1:]:
                married.add_edge(first, second)
    return married

def triangulate_graph(graph):
    """
    Return a copy of `graph` with chords added until no chord-less cycle is reported.

    The input graph is not modified. Each iteration asks `find_chordless_cycle` for a
    cycle and lets `add_chord` insert one edge into it; the loop ends when
    `find_chordless_cycle` returns nothing.
    """
    filled = graph.copy()
    cycle = find_chordless_cycle(filled)
    while cycle:
        add_chord(filled, cycle)
        cycle = find_chordless_cycle(filled)
    return filled

def find_chordless_cycle(graph):
    """
    Find a chord-less cycle of length >= 4 in the graph.

    A cycle basis is not enough for this search: every basis cycle may be a triangle
    even though the graph contains a longer chord-less cycle (e.g. the rim of a wheel).
    Instead, for every node `hub` and every pair of its neighbours that are not adjacent
    to each other, look for a shortest path between the pair that avoids `hub` and all
    of `hub`'s other neighbours. Such a path has no internal shortcuts and none of its
    inner nodes touches `hub`, so closing it through `hub` gives a chord-less cycle.

    Parameters:
    - graph: undirected graph exposing `nodes`, `neighbors(node)` and `has_edge(u, v)`.

    Returns:
    - A list of the cycle's nodes in cycle order, or None if no such cycle exists.
    """
    from collections import deque

    for hub in graph.nodes:
        around = [n for n in graph.neighbors(hub) if n != hub]
        closed_neighbourhood = set(around)
        closed_neighbourhood.add(hub)

        for start, goal in combinations(around, 2):
            if graph.has_edge(start, goal):
                continue  # hub, start, goal form a triangle; nothing to find here

            # Breadth-first search from start to goal outside hub's neighbourhood
            forbidden = closed_neighbourhood - {start, goal}
            came_from = {start: None}
            frontier = deque([start])
            while frontier and goal not in came_from:
                node = frontier.popleft()
                for nxt in graph.neighbors(node):
                    if nxt not in came_from and nxt not in forbidden:
                        came_from[nxt] = node
                        frontier.append(nxt)

            if goal in came_from:
                # Walk the predecessor links back from goal to start
                path = []
                node = goal
                while node is not None:
                    path.append(node)
                    node = came_from[node]
                path.reverse()
                return [hub] + path

    return None

def has_chord(graph, cycle):
    """
    Tell whether `cycle` has a chord in `graph`.

    A chord is an edge of the graph joining two nodes of the cycle that are not
    consecutive along the cycle.

    Parameters:
    - graph: object exposing `has_edge(u, v)`.
    - cycle: sequence of nodes in cycle order (the last node links back to the first).

    Returns:
    - True if at least one chord exists, False otherwise.
    """
    length = len(cycle)

    # Collect the cycle's own edges in both orientations so lookups are order-free
    on_cycle = set()
    for index, node in enumerate(cycle):
        successor = cycle[(index + 1) % length]
        on_cycle.add((node, successor))
        on_cycle.add((successor, node))

    return any(
        graph.has_edge(u, v) and (u, v) not in on_cycle
        for u, v in combinations(cycle, 2)
    )


def add_chord(graph, cycle):
    """
    Insert a single chord into `cycle`.

    Scans the node pairs of the cycle in `combinations` order and adds an edge for the
    first pair that is not yet connected in `graph`. The graph is modified in place;
    nothing happens if every pair is already connected.

    Parameters:
    - graph: object exposing `has_edge(u, v)` and `add_edge(u, v)`.
    - cycle: sequence of nodes forming the cycle.
    """
    unconnected = (
        pair for pair in combinations(cycle, 2) if not graph.has_edge(*pair)
    )
    first_gap = next(unconnected, None)
    if first_gap is not None:
        graph.add_edge(*first_gap)

def construct_join_tree(triangulated_graph):
    """
    Construct a join tree from the triangulated graph.

    The maximal cliques of the graph become the nodes of a clique graph; two cliques
    are linked when they share at least one variable, weighted by the number of shared
    variables. The result is a maximum-weight spanning tree of that clique graph,
    passed through `condense_tree`.

    Every clique is registered as a node before any link is added, so a graph with a
    single clique, or a clique that shares no variable with the others, still appears
    in the result instead of being dropped.

    Parameters:
    - triangulated_graph: networkx undirected graph.

    Returns:
    - networkx graph whose nodes are tuples of variables (one per clique).
    """
    cliques = [tuple(clique) for clique in nx.find_cliques(triangulated_graph)]

    clique_graph = nx.Graph()
    clique_graph.add_nodes_from(cliques)
    for first, second in combinations(cliques, 2):
        weight = len(set(first) & set(second))
        if weight > 0:
            clique_graph.add_edge(first, second, weight=weight)

    max_spanning_tree = nx.maximum_spanning_tree(clique_graph)
    return condense_tree(max_spanning_tree)

def condense_tree(tree):
    """
    Drop every clique node that is contained in another, different clique node.

    The tree is modified in place and also returned. Nodes are removed through
    `tree.remove_nodes_from`; no edges are added to reconnect the remaining nodes.

    Parameters:
    - tree: object exposing `nodes` (iterable of variable tuples) and
      `remove_nodes_from(nodes)`.

    Returns:
    - The same `tree` object after removal.
    """
    absorbed = [
        small
        for small in tree.nodes
        if any(small != big and set(small).issubset(big) for big in tree.nodes)
    ]
    tree.remove_nodes_from(absorbed)
    return tree

def message_passing(join_tree):
    """
    Build one message per edge of the join tree.

    Edges are visited in the order produced by `nx.dfs_edges`, and each message is
    whatever `compute_message` returns for that edge.

    Returns:
    - dict mapping each visited edge (source, target) to its message.
    """
    return {
        edge: compute_message(join_tree, edge)
        for edge in nx.dfs_edges(join_tree)
    }

def compute_message(tree, edge):
    """
    Describe the message sent along `edge`.

    This is a placeholder: it returns a text description naming the two cliques and
    the variables they share, not a numerical message. `tree` is accepted but not used.

    Parameters:
    - tree: the join tree (unused).
    - edge: pair (source, target) of cliques, each an iterable of variables.

    Returns:
    - str describing the message.
    """
    source, target = edge
    # Separator set: the variables common to both cliques
    shared_vars = set(source).intersection(target)
    return f"Message from {source} to {target} over {shared_vars}"

# Example usage: run the full pipeline on a small directed graph
if __name__ == "__main__":
    # A and B are both parents of C; the rest is the chain C -> D -> E -> F -> G
    dg = nx.DiGraph()
    dg.add_edges_from([
        ("A", "C"),
        ("B", "C"),
        ("C", "D"),
        ("D", "E"),
        ("E", "F"),
        ("F", "G"),
    ])

    # Steps 1-3: moralize, triangulate, then build the join tree
    moral_graph = moralize_graph(dg)
    triangulated_graph = triangulate_graph(moral_graph)
    join_tree = construct_join_tree(triangulated_graph)

    # Step 4: one (placeholder) message per join-tree edge
    messages = message_passing(join_tree)

    # Show the messages in the order they were generated
    print("Messages:")
    for description in messages.values():
        print(description)


Messages:
Message from ('F', 'G') to ('F', 'E') over {'F'}
Message from ('F', 'E') to ('E', 'D') over {'E'}
Message from ('E', 'D') to ('C', 'D') over {'D'}
Message from ('C', 'D') to ('C', 'B', 'A') over {'C'}
