In [None]:
'''
 * Copyright (c) 2008 Radhamadhab Dalai
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
'''

## Loopy Belief Propagation and Learning the Graph Structure

## Loopy Belief Propagation
For many practical problems, exact inference is infeasible, and approximation methods become necessary. Two major categories of approximation methods are:

1. **Variational methods**: Deterministic approaches discussed in detail in Chapter 10.
2. **Sampling methods**: Stochastic Monte Carlo methods, explored in Chapter 11.

Here, we explore *loopy belief propagation*, a simple approach to approximate inference in graphs with loops. This builds on exact inference in tree-structured graphs using the **sum-product algorithm**.

### Key Idea
The **loopy belief propagation algorithm** applies the sum-product algorithm to graphs with loops, despite no guarantee of correctness. This is possible because the message-passing rules are purely local. However, due to loops, information can circulate multiple times around the graph. 

- For some models, the algorithm converges.
- For others, it does not converge.

### Message Passing Schedule
A message passing schedule determines how and when messages are passed in the algorithm. 

- **Initialization**:  
  Every message is initialized to the unit function and treated as already passed across every link in both directions, so every node is in a position to send a message.
  
- **Schedules**:  
  - **Flooding schedule**: Messages are passed simultaneously across every link in both directions at each time step.
  - **Serial schedules**: Messages are passed one at a time.
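
The two schedules can be compared with a minimal sketch. It assumes a pairwise model rather than a general factor graph, and the names `run_lbp`, `unary`, and `pairwise` are hypothetical, chosen only for this illustration. Flooding computes every new message from a frozen snapshot of the old ones, while the serial variant lets each update see the messages sent just before it.

```python
import numpy as np

def run_lbp(unary, pairwise, schedule="flooding", max_sweeps=100, tol=1e-9):
    """Sum-product on a pairwise model; returns (beliefs, sweeps_used).

    unary[i] is a 1-D array over the states of node i; pairwise[(i, j)] is a
    table indexed as [x_i, x_j]. Both containers are hypothetical inputs.
    """
    # Each undirected link carries two directed messages.
    links = [(i, j) for i, j in pairwise] + [(j, i) for i, j in pairwise]
    msgs = {(i, j): np.ones(len(unary[j])) for i, j in links}  # unit-function init

    def new_message(i, j, current):
        table = pairwise[(i, j)] if (i, j) in pairwise else pairwise[(j, i)].T
        prod = unary[i].astype(float)
        for k, t in links:
            if t == i and k != j:  # every neighbour of i except the recipient j
                prod = prod * current[(k, i)]
        out = table.T @ prod  # sum over the states of node i
        return out / out.sum()

    sweeps = 0
    for sweeps in range(1, max_sweeps + 1):
        # Flooding reads a frozen snapshot; serial reads the live dictionary.
        current = dict(msgs) if schedule == "flooding" else msgs
        delta = 0.0
        for i, j in links:
            updated = new_message(i, j, current)
            delta = max(delta, float(np.abs(updated - msgs[(i, j)]).max()))
            msgs[(i, j)] = updated
        if delta < tol:
            break

    beliefs = {}
    for i in unary:
        b = unary[i].astype(float)
        for k, t in links:
            if t == i:
                b = b * msgs[(k, i)]
        beliefs[i] = b / b.sum()
    return beliefs, sweeps


# A three-node chain (a tree), so both schedules should reach the exact marginals.
unary = {0: np.array([0.6, 0.4]), 1: np.array([1.0, 1.0]), 2: np.array([0.3, 0.7])}
pairwise = {(0, 1): np.array([[0.8, 0.2], [0.1, 0.9]]),
            (1, 2): np.array([[0.7, 0.3], [0.4, 0.6]])}
flood_beliefs, flood_sweeps = run_lbp(unary, pairwise, "flooding")
serial_beliefs, serial_sweeps = run_lbp(unary, pairwise, "serial")
```

On a graph with loops the same function can be called unchanged, but the two schedules are then not guaranteed to converge, nor to agree if they do.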

### Pending Messages
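Following Kschischang et al. (2001), a (variable or factor) node $a$ has a message **pending** on its link to a node $b$ if:
- Node $a$ has received any message on any of its other links since the last time it sent a message to $b$.

When a node receives a message on one of its links, this creates pending messages on all of its other links. Only pending messages need to be sent, because any other message would simply duplicate the previous message on that link.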
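
The bookkeeping behind pending messages can be sketched without computing any message values. The function `send_pending` and the two small graphs below are hypothetical, and the first-in, first-out order is just one possible serial schedule. Because this sketch never checks whether a message actually changed, it can only stop on its own when the graph is a tree; the `max_sends` cap is an assumption added so that the loopy case halts.

```python
from collections import deque

def send_pending(neighbours, max_sends=50):
    """Send only pending messages; returns (terminated, list of links used)."""
    # After unit-function initialization every node may send on every link.
    queue = deque((a, b) for a in neighbours for b in neighbours[a])
    pending = set(queue)
    sent = []
    while queue and len(sent) < max_sends:
        a, b = queue.popleft()
        pending.discard((a, b))
        sent.append((a, b))
        # Receipt at b makes b's messages on all of its other links pending.
        for c in neighbours[b]:
            if c != a and (b, c) not in pending:
                pending.add((b, c))
                queue.append((b, c))
    return not queue, sent

# A tree-structured factor graph: x1 - f1 - x2.
tree = {"x1": ["f1"], "f1": ["x1", "x2"], "x2": ["f1"]}
# A factor graph with a single loop: x1 - f1 - x2 - f2 - x1.
loop = {"x1": ["f1", "f2"], "f1": ["x1", "x2"],
        "x2": ["f1", "f2"], "f2": ["x2", "x1"]}

tree_done, tree_sent = send_pending(tree)
loop_done, loop_sent = send_pending(loop)
print(tree_done, len(tree_sent))
print(loop_done, len(loop_sent))
```

On the tree the queue empties after every link has been used in both directions. On the loop each delivery keeps creating a new pending message, so the run ends only at the cap.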

### Convergence and Marginals
- For **tree-structured graphs**, the algorithm terminates once every message has been passed across each link in both directions, and the marginal probabilities are exact.
- For **graphs with loops**, termination is not guaranteed, but the algorithm often converges in practice within a reasonable time. 

Once the algorithm has converged or stopped, approximate local marginals are computed using the product of the most recently received incoming messages at each node.
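
As a small illustration of this last step, the helper below (a hypothetical name, with made-up message values) multiplies the incoming messages at one variable elementwise and normalizes the result.

```python
import numpy as np

def belief(incoming_messages):
    """Normalized elementwise product of the messages arriving at one variable."""
    product = np.prod(np.asarray(incoming_messages, dtype=float), axis=0)
    return product / product.sum()

# Hypothetical messages arriving at a binary variable from two neighbouring factors.
b = belief([[0.9, 1.1], [0.5, 0.5]])
```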

### Effectiveness
The algorithm's performance varies:
- In some cases, it provides poor results.
- In other cases, it is highly effective, e.g., for decoding error-correcting codes.

## Learning the Graph Structure
In addition to inference, there is interest in learning the structure of the graph from data (Friedman and Koller, 2003).

### Bayesian Approach
From a Bayesian perspective, the posterior distribution over graph structures is:
$$
p(m|D) \propto p(m)p(D|m)
$$
where:
- $m$: Graph structure.
- $D$: Observed data set.
- $p(m)$: Prior over graphs.
- $p(D|m)$: Model evidence, which serves as the score for each graph structure.
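
To make the scoring rule concrete, here is a minimal sketch that turns log priors and log evidences into a normalized posterior over a few candidate structures. The structure names and the numerical scores are invented for illustration. In practice obtaining $\log p(D|m)$ is the hard part and is simply assumed to be given here.

```python
import math

def structure_posterior(log_prior, log_evidence):
    """Return p(m | D) for each candidate m from log p(m) and log p(D | m)."""
    log_joint = {m: log_prior[m] + log_evidence[m] for m in log_prior}
    peak = max(log_joint.values())  # subtract the maximum before exponentiating
    weights = {m: math.exp(v - peak) for m, v in log_joint.items()}
    total = sum(weights.values())
    return {m: w / total for m, w in weights.items()}

# Hypothetical scores for three candidate structures over the same variables.
log_prior = {"chain": math.log(1 / 3), "fork": math.log(1 / 3), "empty": math.log(1 / 3)}
log_evidence = {"chain": -1040.2, "fork": -1042.5, "empty": -1061.0}

posterior = structure_posterior(log_prior, log_evidence)
best = max(posterior, key=posterior.get)
```

Working in log space and subtracting the largest score avoids underflow, since evidences for realistic data sets are far too small to exponentiate directly.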

### Challenges
1. **Marginalization**: Computing $p(D|m)$ involves marginalizing over latent variables, which is computationally challenging.
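2. **Structure Space**: The number of possible graph structures grows at least exponentially with the number of nodes (there are already $2^{n(n-1)/2}$ undirected graphs on $n$ labelled nodes), so exhaustive search is impractical and heuristics are needed to find good candidates.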
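
The size of the structure space can be made concrete for one family of models. Assuming the candidates are labelled directed acyclic graphs, Robinson's recurrence counts them exactly. The function name `count_dags` is chosen for this sketch only.

```python
from functools import lru_cache
from math import comb

@lru_cache(maxsize=None)
def count_dags(n):
    """Number of labelled directed acyclic graphs on n nodes (Robinson's recurrence)."""
    if n == 0:
        return 1
    return sum((-1) ** (k + 1) * comb(n, k) * 2 ** (k * (n - k)) * count_dags(n - k)
               for k in range(1, n + 1))

counts = [count_dags(n) for n in range(1, 7)]
print(counts)  # [1, 3, 25, 543, 29281, 3781503]
```

Six nodes already give several million candidates, which is why scoring every structure is not an option.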


In [1]:
import numpy as np

class LoopyBeliefPropagation:
    def __init__(self, factor_graph, max_iter=100, tol=1e-6):
        """
        Initialize the loopy belief propagation algorithm.
        
        Parameters:
        - factor_graph: Dictionary containing 'factors' and 'edges'. 
                        Each factor is represented as a dictionary with 
                        'variables' (list of connected variables) and 'table' (probabilities).
        - max_iter: Maximum number of iterations to run.
        - tol: Tolerance for convergence.
        """
        self.factors = factor_graph['factors']
        self.edges = factor_graph['edges']
        self.max_iter = max_iter
        self.tol = tol
        self.messages = self._initialize_messages()
    
    def _initialize_messages(self):
        """Initialize every message to the unit function (a vector of ones)."""
        messages = {}
        for var, factor in self.edges:
            # A message on the (var, factor) link has one entry per state of var,
            # i.e. the size of var's axis in the factor table.
            axis = self.factors[factor]['variables'].index(var)
            size = np.asarray(self.factors[factor]['table']).shape[axis]
            messages[(var, factor)] = np.ones(size)
            messages[(factor, var)] = np.ones(size)
        return messages
    
    def _normalize(self, message):
        """Normalize a message to sum to 1."""
        return message / np.sum(message)
    
    def _compute_message_factor_to_variable(self, factor, variable):
        """Compute the message from a factor node to a variable node."""
        factor_info = self.factors[factor]
        variables = factor_info['variables']
        message = np.asarray(factor_info['table'], dtype=float)
        
        # Multiply each incoming message along its own axis and sum that axis out.
        # Going from the last axis to the first keeps the remaining axis indices valid.
        for axis in reversed(range(len(variables))):
            if variables[axis] == variable:
                continue
            incoming = self.messages[(variables[axis], factor)]
            message = np.tensordot(message, incoming, axes=([axis], [0]))
        
        return self._normalize(message)
    
    def _compute_message_variable_to_factor(self, variable, factor):
        """Compute the message from a variable node to a factor node."""
        # Start from the unit function so that a leaf variable sends a uniform message
        combined_message = np.ones_like(self.messages[(variable, factor)])
        for v, other_factor in self.edges:  # edges are (variable, factor) pairs
            if v == variable and other_factor != factor:
                combined_message = combined_message * self.messages[(other_factor, variable)]
        return self._normalize(combined_message)
    
    def run(self):
        """Run the loopy belief propagation algorithm."""
        for iteration in range(self.max_iter):
            max_change = 0
            
            # Update all messages
            for edge in self.edges:
                var, factor = edge
                
                # Message from variable to factor
                new_message_vf = self._compute_message_variable_to_factor(var, factor)
                change_vf = np.linalg.norm(new_message_vf - self.messages[(var, factor)])
                self.messages[(var, factor)] = new_message_vf
                max_change = max(max_change, change_vf)
                
                # Message from factor to variable
                new_message_fv = self._compute_message_factor_to_variable(factor, var)
                change_fv = np.linalg.norm(new_message_fv - self.messages[(factor, var)])
                self.messages[(factor, var)] = new_message_fv
                max_change = max(max_change, change_fv)
            
            # Check for convergence
            if max_change < self.tol:
                print(f"Converged in {iteration + 1} iterations.")
                break
        else:
            print("Reached maximum iterations without convergence.")
    
    def compute_marginals(self):
        """Compute the approximate marginal probabilities for all variables."""
        marginals = {}
        # Edges are (variable, factor) pairs; sort so the output order is stable
        for variable in sorted({v for v, f in self.edges}):
            incoming_messages = [
                self.messages[(factor, variable)]
                for v, factor in self.edges if v == variable
            ]
            marginal = np.prod(incoming_messages, axis=0)
            marginals[variable] = self._normalize(marginal)
        return marginals


In [2]:
# Define a simple factor graph
factor_graph = {
    "factors": {
        "f1": {"variables": ["X1", "X2"], "table": np.array([[0.8, 0.2], [0.1, 0.9]])},
        "f2": {"variables": ["X2", "X3"], "table": np.array([[0.7, 0.3], [0.4, 0.6]])}
    },
    "edges": [("X1", "f1"), ("X2", "f1"), ("X2", "f2"), ("X3", "f2")]
}

# Initialize the loopy belief propagation algorithm
lbp = LoopyBeliefPropagation(factor_graph)

# Run the algorithm
lbp.run()

# Compute marginals
marginals = lbp.compute_marginals()

# Print results
for var, marginal in marginals.items():
    print(f"Marginal for {var}: {marginal}")


Converged in 2 iterations.
Marginal for X1: [0.5 0.5]
Marginal for X2: [0.45 0.55]
Marginal for X3: [0.535 0.465]


In [5]:
class LoopyBeliefPropagation:
    def __init__(self, factor_graph, max_iter=100, tol=1e-6):
        """
        Initialize the loopy belief propagation algorithm.
        
        Parameters:
        - factor_graph: Dictionary containing 'factors' and 'edges'. 
                        Each factor is represented as a dictionary with 
                        'variables' (list of connected variables) and 'table' (list of lists for probabilities).
        - max_iter: Maximum number of iterations to run.
        - tol: Tolerance for convergence.
        """
        self.factors = factor_graph['factors']
        self.edges = factor_graph['edges']
        self.max_iter = max_iter
        self.tol = tol
        self.messages = self._initialize_messages()
    
    def _initialize_messages(self):
        """Initialize every message to the unit function (a list of ones)."""
        messages = {}
        for var, factor in self.edges:
            # One entry per state of var: rows index variables[0], columns index variables[1]
            table = self.factors[factor]['table']
            axis = self.factors[factor]['variables'].index(var)
            size = len(table) if axis == 0 else len(table[0])
            messages[(var, factor)] = [1] * size
            messages[(factor, var)] = [1] * size
        return messages

    def _normalize(self, message):
        """Normalize a message to sum to 1."""
        total = sum(message)
        if total == 0:
            return [0] * len(message)
        return [m / total for m in message]
    
    def _compute_message_factor_to_variable(self, factor, variable):
        """Compute the message from a pairwise factor node to a variable node.

        The table is indexed as table[i][j], with i the state of variables[0]
        and j the state of variables[1].
        """
        factor_info = self.factors[factor]
        factor_table = factor_info['table']
        axis = factor_info['variables'].index(variable)
        other_var = factor_info['variables'][1 - axis]
        incoming = self.messages[(other_var, factor)]
        
        # Weight each table entry by the other variable's message, then sum that variable out
        if axis == 0:
            message = [sum(p * incoming[j] for j, p in enumerate(row))
                       for row in factor_table]
        else:
            message = [sum(row[j] * incoming[i] for i, row in enumerate(factor_table))
                       for j in range(len(factor_table[0]))]
        
        return self._normalize(message)
    
    def _compute_message_variable_to_factor(self, variable, factor):
        """Compute the message from a variable node to a factor node."""
        # Collect incoming messages from all factors connected to the variable, excluding the current factor.
        # Edges are stored as (variable, factor) pairs.
        incoming_messages = [
            self.messages[(other_factor, variable)]
            for var, other_factor in self.edges if var == variable and other_factor != factor
        ]
        
        # A leaf variable has no other neighbours, so it sends a uniform message
        if not incoming_messages:
            return [1] * len(self.messages[(factor, variable)])
        
        # Combine incoming messages
        combined_message = [1] * len(incoming_messages[0])
        for msg in incoming_messages:
            combined_message = [cm * m for cm, m in zip(combined_message, msg)]
        
        return self._normalize(combined_message)
    
    def run(self):
        """Run the loopy belief propagation algorithm."""
        for iteration in range(self.max_iter):
            max_change = 0
            
            # Update all messages
            for edge in self.edges:
                var, factor = edge
                
                # Message from variable to factor
                new_message_vf = self._compute_message_variable_to_factor(var, factor)
                change_vf = max(abs(new_message_vf[i] - self.messages[(var, factor)][i])
                                for i in range(len(new_message_vf)))
                self.messages[(var, factor)] = new_message_vf
                max_change = max(max_change, change_vf)
                
                # Message from factor to variable
                new_message_fv = self._compute_message_factor_to_variable(factor, var)
                change_fv = max(abs(new_message_fv[i] - self.messages[(factor, var)][i])
                                for i in range(len(new_message_fv)))
                self.messages[(factor, var)] = new_message_fv
                max_change = max(max_change, change_fv)
            
            # Check for convergence
            if max_change < self.tol:
                print(f"Converged in {iteration + 1} iterations.")
                break
        else:
            print("Reached maximum iterations without convergence.")
    
    def compute_marginals(self):
        """Compute the approximate marginal probabilities for all variables."""
        marginals = {}
        # Edges are (variable, factor) pairs; sort so the output order is stable
        for variable in sorted({v for v, f in self.edges}):
            incoming_messages = [
                self.messages[(factor, variable)]
                for var, factor in self.edges if var == variable
            ]
            marginal = [1] * len(incoming_messages[0])
            for msg in incoming_messages:
                marginal = [m * im for m, im in zip(marginal, msg)]
            marginals[variable] = self._normalize(marginal)
        return marginals


# Example usage
if __name__ == "__main__":
    # Define a simple factor graph
    factor_graph = {
        "factors": {
            "f1": {"variables": ["X1", "X2"], "table": [[0.8, 0.2], [0.1, 0.9]]},
            "f2": {"variables": ["X2", "X3"], "table": [[0.7, 0.3], [0.4, 0.6]]}
        },
        "edges": [("X1", "f1"), ("X2", "f1"), ("X2", "f2"), ("X3", "f2")]
    }

    # Initialize the loopy belief propagation algorithm
    lbp = LoopyBeliefPropagation(factor_graph)

    # Run the algorithm
    lbp.run()

    # Compute marginals
    marginals = lbp.compute_marginals()

    # Print results, rounded to three decimals for readability
    for var, marginal in marginals.items():
        print(f"Marginal for {var}: {[round(p, 3) for p in marginal]}")


Converged in 2 iterations.
Marginal for X1: [0.5, 0.5]
Marginal for X2: [0.45, 0.55]
Marginal for X3: [0.535, 0.465]
