<a href="https://colab.research.google.com/github/r-doz/PML2025/blob/main/./04_exact_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Exact inference with Belief Propagation

Tavano Matteo SM3800057 First PML homework


This notebook is inspired by [Jessica Stringham's work](https://jessicastringham.net).


### Exercise 6: Extending Belief Propagation to the Sum-Product Algorithm

In the notebook *"Exact Inference with Belief Propagation"*, we previously computed the marginal distribution of a given variable using the message-passing method. Now, we aim to extend this implementation to the sum-product algorithm.

1. **Extend the `Messages` class** by adding the following methods:
   - **`forward`**: Computes the forward pass.
   - **`backward`**: Computes the backward pass.
   - **`belief_propagation`**: Executes the forward and backward passes, then uses the computed messages to determine all marginal distributions. This method should return a dictionary mapping each variable name to its corresponding marginal distribution.

2. **Apply the `belief_propagation` method** to compute the marginal distributions of the variables in the factor graph described on page 43 of the course notes.
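
As a rough illustration of what the forward pass, the backward pass and the final combination step compute, here is a minimal NumPy sketch on a hypothetical three-variable chain `a -> b -> c`. The chain, the probability tables and the names `fwd_*`, `bwd_*` and `normalize` are assumptions made for this example only; they are not the factor graph from the course notes, nor the `Messages` implementation in this notebook.

```python
import numpy as np

# Hypothetical binary chain a -> b -> c (tables chosen only for illustration).
p_a = np.array([0.3, 0.7])                 # p(a)
p_b_given_a = np.array([[0.9, 0.1],        # p(b|a), indexed [a, b]
                        [0.2, 0.8]])
p_c_given_b = np.array([[0.6, 0.4],        # p(c|b), indexed [b, c]
                        [0.5, 0.5]])

# Forward pass: messages flow from the prior factor p(a) towards c.
fwd_a = p_a                                # message arriving at a
fwd_b = fwd_a @ p_b_given_a                # sum_a fwd_a[a] * p(b|a)
fwd_c = fwd_b @ p_c_given_b                # sum_b fwd_b[b] * p(c|b)

# Backward pass: messages flow from the leaf c back towards a.
bwd_c = np.ones(2)                         # c is a leaf, so its message is all ones
bwd_b = p_c_given_b @ bwd_c                # sum_c p(c|b) * bwd_c[c]
bwd_a = p_b_given_a @ bwd_b                # sum_b p(b|a) * bwd_b[b]

def normalize(v):
    """Rescale a non-negative vector so that it sums to one."""
    return v / v.sum()

# Each marginal is the normalized product of the two messages reaching the variable.
marginals = {
    "a": normalize(fwd_a * bwd_a),
    "b": normalize(fwd_b * bwd_b),
    "c": normalize(fwd_c * bwd_c),
}

# Brute-force check: build the full joint p(a, b, c) and sum out the other variables.
joint = p_a[:, None, None] * p_b_given_a[:, :, None] * p_c_given_b[None, :, :]
brute_force = {
    "a": joint.sum(axis=(1, 2)),
    "b": joint.sum(axis=(0, 2)),
    "c": joint.sum(axis=(0, 1)),
}
```

On a tree-structured graph the two passes compute every message exactly once, so all marginals come out of a single sweep in each direction rather than one recursive computation per variable.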

For this exercise, please submit the updated notebook **`04_exact_inference.ipynb`**, including your additional code.

**NOTE: Make sure to add comments to all the code you write!**

In [1]:
# libraries import
import numpy as np
from collections import namedtuple, deque

## Distribution

In [2]:
class Distribution():
    """
    Discrete probability distributions, expressed using labeled arrays
    probs: array of probability values
    axes_labels: list of axes names
    """
    def __init__(self, probs, axes_labels):
        # Ensure probs is a numpy array
        self.probs = np.asarray(probs)
        self.axes_labels = axes_labels

    def get_axes(self):
        # Returns a dictionary mapping each axis name to its axis index
        return {name: axis for axis, name in enumerate(self.axes_labels)}

    def get_other_axes_from(self, axis_label):
        # Returns a tuple with the indices of all axes except the one named axis_label
        axis_to_keep = self.get_axes().get(axis_label)
        if axis_to_keep is None:
             raise ValueError(f"Axis label '{axis_label}' not found in {self.axes_labels}")
        return tuple(axis for axis, name in enumerate(self.axes_labels) if name != axis_label)

    def is_valid_conditional(self, variable_name):
        # variable_name is the variable the distribution is over, e.g. 'y' in p(y|x)

        # A scalar (0-d) distribution has no axes: it is valid only if it equals 1
        if self.probs.ndim == 0:
            return bool(np.isclose(self.probs, 1.0))

        axis_to_sum = self.get_axes().get(variable_name)
        if axis_to_sum is None:
            print(f"Warning: Variable '{variable_name}' not found in axes {self.axes_labels} for conditional check.")
            return False

        # More labels than array dimensions means the distribution is malformed
        if axis_to_sum >= self.probs.ndim:
            print(f"Error: Axis {axis_to_sum} out of bounds for array with shape {self.probs.shape}")
            return False

        # Summing over the variable's own axis must give 1.0 for every
        # configuration of the conditioning variables
        sums = np.sum(self.probs, axis=axis_to_sum)
        return bool(np.all(np.isclose(sums, 1.0)))

    def is_valid_joint(self):
        # Check if the sum over all elements is close to 1.0
        return np.isclose(np.sum(self.probs), 1.0)

In [3]:
# function to allow multiplications between distributions
def multiply(p_dist1, p_dist2):
    '''
    Compute the product of two distributions p1(vars1) * p2(vars2).
    Handles alignment and broadcasting based on shared variable names.
    Returns a new Distribution.

    Example:
    p1(a, b), p2(b, c) -> result(a, b, c)
    p1(a), p2(b) -> result(a, b)
    p1(a, b), p2(b) -> result(a, b)
    '''
    # Ensure both inputs are Distribution objects
    if not isinstance(p_dist1, Distribution):
        raise TypeError("p_dist1 must be a Distribution object")
    if not isinstance(p_dist2, Distribution):
        raise TypeError("p_dist2 must be a Distribution object")

    # Combine axes labels, preserving order and uniqueness
    # Start with p_dist1's labels
    new_axes_labels = list(p_dist1.axes_labels)
    # Add any labels from p_dist2 not already present
    for label in p_dist2.axes_labels:
        if label not in new_axes_labels:
            new_axes_labels.append(label)

    # Align both arrays with new_axes_labels. For each distribution:
    #   1. transpose its own axes into the order they take in new_axes_labels;
    #   2. reshape, inserting a singleton dimension for every absent label.
    # A plain reshape without the transpose would silently scramble the data
    # whenever the two distributions list shared variables in a different order.
    def _align(dist, which):
        labels = list(dist.axes_labels)
        if dist.probs.ndim != len(labels):
            raise ValueError(
                f"Reshape failed for {which}. Check axes alignment. "
                f"Shape: {dist.probs.shape}, Labels: {labels}"
            )
        # Permutation that puts this distribution's axes in the combined order
        order = [labels.index(label) for label in new_axes_labels if label in labels]
        transposed = dist.probs.transpose(order) if order else dist.probs
        # Target shape: the real size for present labels, 1 for absent ones
        shape = [
            dist.probs.shape[labels.index(label)] if label in labels else 1
            for label in new_axes_labels
        ]
        return transposed.reshape(shape)

    reshaped_probs1 = _align(p_dist1, "p_dist1")
    reshaped_probs2 = _align(p_dist2, "p_dist2")


    # Perform element-wise multiplication (broadcasting handles the rest)
    result_probs = reshaped_probs1 * reshaped_probs2

    return Distribution(result_probs, new_axes_labels)

**Notes:** Refactored using Gemini 2.5.
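
For comparison, the same label-based product can be sketched with `np.einsum` in its sublist form, which aligns axes by integer ids instead of reshaping by hand. The helper name `multiply_labeled` and the test arrays below are hypothetical and only meant to illustrate the idea; this is not the `multiply` function used by the rest of the notebook.

```python
import numpy as np

def multiply_labeled(probs1, labels1, probs2, labels2):
    """Multiply two labeled arrays, matching axes by label (illustrative helper)."""
    # Output labels: those of the first array, then any new ones from the second.
    out_labels = list(labels1) + [l for l in labels2 if l not in labels1]
    # einsum's sublist form identifies axes by integers, so give each label an id.
    ids = {label: i for i, label in enumerate(out_labels)}
    result = np.einsum(
        probs1, [ids[l] for l in labels1],
        probs2, [ids[l] for l in labels2],
        [ids[l] for l in out_labels],
    )
    return result, out_labels

# p1(a, b) * p2(b, c) -> result(a, b, c)
p_ab = np.arange(6, dtype=float).reshape(2, 3)     # axes (a, b)
p_bc = np.arange(12, dtype=float).reshape(3, 4)    # axes (b, c)
prod, labels = multiply_labeled(p_ab, ["a", "b"], p_bc, ["b", "c"])

# Same variables listed in a different order: p1(a, b) * q(b, a) -> result(a, b)
q_ba = np.arange(6, dtype=float).reshape(3, 2)     # axes (b, a)
prod2, labels2 = multiply_labeled(p_ab, ["a", "b"], q_ba, ["b", "a"])
```

Because the axes are matched by label, the second call gives the same result as `p_ab * q_ba.T`, which is the case where a reshape-only alignment needs extra care.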

## Node

In [4]:
class Node(object):
    def __init__(self, name):
        self.name = name
        self.neighbors = [] # Stores neighboring Node objects

    def is_valid_neighbor(self, neighbor):
        raise NotImplementedError()

    def add_neighbor(self, neighbor):
        if not self.is_valid_neighbor(neighbor):
             raise TypeError(f"Cannot add {type(neighbor)} as neighbor to {type(self)}")
        if neighbor not in self.neighbors:
            self.neighbors.append(neighbor)

## Variable

In [5]:
class Variable(Node):
    def __init__(self, name):
        super(Variable, self).__init__(name)
        self.observed_value = None # To handle observed evidence (index of observed state)

    def is_valid_neighbor(self, factor):
        return isinstance(factor, Factor)  # Variables can only neighbor Factors

    def observe(self, value):
        """Sets the observed value (state index) for this variable."""
        self.observed_value = value

    def unobserve(self):
        """Removes the observation for this variable."""
        self.observed_value = None

## Factor

In [6]:
class Factor(Node):
    def __init__(self, name):
        super(Factor, self).__init__(name)
        self.data = None # Should store a Distribution object

    def is_valid_neighbor(self, variable):
        return isinstance(variable, Variable)  # Factors can only neighbor Variables

## Parser

In [7]:
ParsedTerm = namedtuple('ParsedTerm', [
    'term',      # e.g., 'p(a|b,c)'
    'var_name',  # list of non-conditioned vars, e.g., ['a']
    'given',     # list of conditioned vars, e.g., ['b', 'c']
])

def _parse_term(term_content):
    # Given content like 'a|b,c' or 'x,y|z' or 'k'
    # returns ([vars], [given])
    if '|' in term_content:
        var_part, given_part = term_content.split('|', 1)
        given_ = [g.strip() for g in given_part.split(',')]
    else:
        var_part = term_content
        given_ = []
    # An empty variable part (e.g. 'p()') yields an empty list, reported by the caller
    vars_ = [v.strip() for v in var_part.split(',')] if var_part.strip() else []
    # Empty entries are kept up to this point so that malformed input such as
    # 'p(a,)' or 'p(a|)' is rejected here instead of being silently dropped
    if any(not v for v in vars_) or any(not g for g in given_):
         raise ValueError(f"Invalid variable format in term content: '{term_content}'")
    return vars_, given_

def _parse_model_string_into_terms(model_string):
    terms = []
    # Split by 'p(' but handle potential edge cases like starting 'p('
    parts = model_string.strip().split('p(')
    if not parts[0].isspace() and parts[0] != "":
         raise ValueError("Model string should start with 'p(' or factors separated by 'p('.")

    for part in parts[1:]: # Skip the potentially empty part before the first 'p('
        if not part: continue # Skip empty strings resulting from split

        if ')' not in part:
             raise ValueError(f"Invalid term format in: '{part}'. Missing closing parenthesis.")

        # Extract content inside parentheses and the full term name
        try:
            term_content, rest = part.split(')', 1) # Split only on the first ')'
        except ValueError:
             raise ValueError(f"Invalid term format in: '{part}'. Malformed content.")

        term_name = f"p({term_content})"

        vars_, given_ = _parse_term(term_content)
        if not vars_: # Must have at least one variable, e.g. p() is invalid
             raise ValueError(f"Term '{term_name}' must contain at least one variable.")

        terms.append(ParsedTerm(term_name, vars_, given_))
    if not terms:
        raise ValueError("Model string did not contain any valid factor terms.")
    return terms


def parse_model_into_variables_and_factors(model_string):
    parsed_terms = _parse_model_string_into_terms(model_string)

    variables = {} # Map variable name (str) to Variable object
    factors = []   # List of Factor objects

    # First pass: Identify all unique variables
    all_var_names = set()
    for parsed_term in parsed_terms:
        all_var_names.update(parsed_term.var_name)
        all_var_names.update(parsed_term.given)

    if not all_var_names:
         raise ValueError("No variables found in the model string.")

    # Create Variable objects
    for var_name in all_var_names:
        variables[var_name] = Variable(var_name)

    # Second pass: Create Factor objects and link them to Variables
    factor_names = set()
    for parsed_term in parsed_terms:
        if parsed_term.term in factor_names:
            raise ValueError(f"Duplicate factor name found: {parsed_term.term}. Factors must be unique.")
        factor_names.add(parsed_term.term)

        new_factor = Factor(parsed_term.term)
        factors.append(new_factor)

        # Link factor to its variables and vice-versa
        factor_vars_names = parsed_term.var_name + parsed_term.given
        if not factor_vars_names:
             raise ValueError(f"Factor '{new_factor.name}' is not connected to any variables.")

        for var_name in factor_vars_names:
            if var_name not in variables:
                 # This should not happen if parsing was correct, but check anyway
                 raise ValueError(f"Variable '{var_name}' needed by factor '{new_factor.name}' was not identified.")
            variable_obj = variables[var_name]
            new_factor.add_neighbor(variable_obj)
            variable_obj.add_neighbor(new_factor) # Link back

    return factors, variables

## PGM

In [8]:
class PGM(object):
    def __init__(self, factors, variables):
        self._factors = {f.name: f for f in factors} # Store factors in a dict by name
        self._variables = variables # Dict mapping name to Variable object

    @classmethod
    def from_string(cls, model_string):
        factors_list, variables_dict = parse_model_into_variables_and_factors(model_string)
        return cls(factors_list, variables_dict)  # use cls so that subclasses build instances of themselves

    def set_distributions(self, data):
        """
        Assigns Distribution objects to factors.
        Input `data` is a dictionary mapping factor names (str) to Distribution objects.
        """
        var_dims = {} # Store expected dimension size for each variable axis
        assigned_factors = set()
        for factor_name, factor_data in data.items():
            if factor_name not in self._factors:
                # Check for slight variations like missing spaces if desired, but strict match is safer.
                raise ValueError(f"Factor name '{factor_name}' from data not found in PGM factors: {list(self._factors.keys())}")

            factor = self._factors[factor_name]
            assigned_factors.add(factor_name)

            if not isinstance(factor_data, Distribution):
                 raise TypeError(f"Data for factor '{factor_name}' must be a Distribution object, got {type(factor_data)}.")

            # Check if the axes labels in the distribution match the factor's neighbors
            factor_neighbor_names = set(v.name for v in factor.neighbors)
            distribution_axes_names = set(factor_data.axes_labels)

            if distribution_axes_names != factor_neighbor_names:
                missing_axes = factor_neighbor_names - distribution_axes_names
                extra_axes = distribution_axes_names - factor_neighbor_names
                error_msg = f"Axes mismatch for factor '{factor_name}': "
                if missing_axes: error_msg += f"Distribution missing axes for variables {missing_axes}. Expected axes based on neighbors: {factor_neighbor_names}. Got: {distribution_axes_names}."
                if extra_axes: error_msg += f"Distribution has extra axes {extra_axes} not connected to factor. Expected axes based on neighbors: {factor_neighbor_names}. Got: {distribution_axes_names}."
                raise ValueError(error_msg)

            # Check and record dimension sizes for consistency
            if not hasattr(factor_data.probs, 'shape'):
                 raise ValueError(f"Factor data 'probs' for '{factor_name}' has no shape attribute.")

            # Ensure the number of dimensions matches the number of labels
            if len(factor_data.probs.shape) != len(factor_data.axes_labels):
                raise ValueError(f"Dimension mismatch for factor '{factor_name}': Number of axes labels ({len(factor_data.axes_labels)}) does not match number of probability dimensions ({len(factor_data.probs.shape)}).")

            for var_name, dim_size in zip(factor_data.axes_labels, factor_data.probs.shape):
                if var_name not in self._variables:
                    # Should not happen if parsing is correct
                     raise ValueError(f"Variable '{var_name}' from factor '{factor_name}' axes not found in global variable list.")

                if var_name not in var_dims:
                    var_dims[var_name] = dim_size
                    self._variables[var_name].dim = dim_size # Store dimension on variable node
                elif var_dims[var_name] != dim_size:
                    raise ValueError(
                        f"Inconsistent dimension size for variable '{var_name}'. "
                        f"Factor '{factor_name}' expects size {dim_size}, "
                        f"but previously seen size was {var_dims[var_name]}."
                    )

            factor.data = factor_data # Assign the distribution to the factor

        # Check if all factors in the PGM were assigned data
        missing_data = set(self._factors.keys()) - assigned_factors
        if missing_data:
             print(f"Warning: Distributions were not provided for all factors: {missing_data}")


    def variable_from_name(self, var_name):
        if var_name not in self._variables:
            raise KeyError(f"Variable '{var_name}' not found in the PGM.")
        return self._variables[var_name]

    def factor_from_name(self, factor_name):
        if factor_name not in self._factors:
             raise KeyError(f"Factor '{factor_name}' not found in the PGM.")
        return self._factors[factor_name]

    def get_variables(self):
        """Returns the dictionary of variables {name: Variable}."""
        return self._variables

    def get_factors(self):
        """Returns the dictionary of factors {name: Factor}."""
        return self._factors

**Notes:** getter methods added

## Messages

In [9]:
class Messages(object):
    def __init__(self):
        self.messages = {} # Memoization cache: key=(sender_name, receiver_name), value=message_array

    def _get_message_key(self, sender, receiver):
        """Helper to create a consistent key for the messages dictionary."""
        return (sender.name, receiver.name)

    def _get_variable_dimension(self, variable, context_node):
        """ Tries to infer the dimension (number of states) of a variable. """
        if hasattr(variable, 'dim'):
            return variable.dim
        # Infer from context node (neighbor factor or variable)
        if isinstance(context_node, Factor) and context_node.data:
             try:
                 axis_index = context_node.data.axes_labels.index(variable.name)
                 return context_node.data.probs.shape[axis_index]
             except (ValueError, IndexError):
                 pass # Variable not in this factor's data or data malformed
        # Try inferring from other neighbors
        for neighbor in variable.neighbors:
            if neighbor != context_node and isinstance(neighbor, Factor) and neighbor.data:
                try:
                    axis_index = neighbor.data.axes_labels.index(variable.name)
                    dim = neighbor.data.probs.shape[axis_index]
                    variable.dim = dim # Cache dimension on variable node
                    return dim
                except (ValueError, IndexError):
                    continue
        raise RuntimeError(f"Could not determine dimension for variable '{variable.name}' from neighbors.")


    def _variable_to_factor_messages(self, variable, factor):
        r"""
        Computes the message from a variable node to a factor node.
        message(var -> fac) = product of messages from all other factors linked to var
                            = product_{fac' in neighbors(var) \ {fac}} message(fac' -> var)
        Handles observed variables.
        """
        # Key is already checked by the public wrapper method

        # Check for observed variable
        if variable.observed_value is not None:
             var_dim = self._get_variable_dimension(variable, factor)
             message = np.zeros(var_dim)
             if 0 <= variable.observed_value < var_dim:
                  message[variable.observed_value] = 1.0
             else:
                 raise ValueError(f"Observed value {variable.observed_value} for variable '{variable.name}' is out of bounds for dimension {var_dim}.")
             # print(f"  Observed Msg Var({variable.name}) -> Fac({factor.name}): {message}") # Debug
        else:
            # Get messages from all other neighboring factors
            incoming_messages = []
            for neighbor_factor in variable.neighbors:
                if neighbor_factor.name == factor.name:
                    continue # Skip the factor we are sending the message to
                # Recursively compute/retrieve the message from that factor to this variable
                incoming_message = self.factor_to_variable_message(neighbor_factor, variable)
                incoming_messages.append(incoming_message)

            # Compute the product of incoming messages
            if not incoming_messages:
                # If the variable is a leaf node (only connected to 'factor'), message is uniform (array of ones)
                var_dim = self._get_variable_dimension(variable, factor)
                message = np.ones(var_dim)
            else:
                # Product of all incoming messages. Assumes messages are 1D arrays.
                try:
                    # Stack messages into a 2D array for prod along axis 0
                    message_array = np.array(incoming_messages)
                    # Check for empty arrays or inconsistent shapes if necessary
                    if message_array.size == 0: # Handle case where incoming messages resulted in empty arrays (e.g., contradictions)
                        var_dim = self._get_variable_dimension(variable, factor)
                        message = np.zeros(var_dim) # Or handle as error?
                    else:
                         message = np.prod(message_array, axis=0)
                except ValueError as e:
                    print(f"Error computing product for messages into Var({variable.name}) from factors other than Fac({factor.name}): {incoming_messages}")
                    raise e

        # Normalize message? No, standard BP doesn't normalize var->fac messages.

        # Memoization is handled by the public wrapper method
        return message

    def _factor_to_variable_messages(self, factor, variable):
        r"""
        Computes the message from a factor node to a variable node.
        message(fac -> var) = sum_{vars in fac \ {var}} [ factor_potential * product_{var' in neighbors(fac) \ {var}} message(var' -> fac) ]
        """
        # Key is already checked by the public wrapper method

        # Start with the factor's potential (distribution)
        if factor.data is None:
            raise ValueError(f"Factor '{factor.name}' does not have distribution data set.")

        # Make a copy to avoid modifying the original factor data
        current_potential = Distribution(np.copy(factor.data.probs), list(factor.data.axes_labels))

        # Multiply by incoming messages from all other neighboring variables
        other_neighbor_variables = [v for v in factor.neighbors if v.name != variable.name]

        for neighbor_variable in other_neighbor_variables:
            # Get the message from this neighbor variable to the factor
            incoming_message = self.variable_to_factor_messages(neighbor_variable, factor)

            # Multiply the current potential by this message.
            message_dist = Distribution(incoming_message, [neighbor_variable.name])
            try:
                 current_potential = multiply(current_potential, message_dist) # Use the generalized multiply
            except ValueError as e:
                 print(f"Error multiplying potential for Fac({factor.name}) with message from Var({neighbor_variable.name})")
                 print(f"  Potential: labels={current_potential.axes_labels}, shape={current_potential.probs.shape}")
                 print(f"  Message: labels={message_dist.axes_labels}, shape={message_dist.probs.shape}")
                 raise e


        # Sum over all variables connected to the factor EXCEPT the target variable
        vars_to_sum_out = [v.name for v in other_neighbor_variables]
        axes_to_sum_out = []
        potential_axes_map = current_potential.get_axes()
        for var_name in vars_to_sum_out:
             if var_name in potential_axes_map:
                  axes_to_sum_out.append(potential_axes_map[var_name])
             else:
                 # This variable might have been introduced by a message multiplication
                 # and might have dimension 1 if it wasn't part of original factor
                 # Check shape to be sure
                 if var_name in current_potential.axes_labels:
                      axis_idx = current_potential.axes_labels.index(var_name)
                      if current_potential.probs.shape[axis_idx] == 1:
                           axes_to_sum_out.append(axis_idx) # Sum out singleton dimensions too
                      else:
                           # Should not happen if graph/messages are correct
                           print(f"Warning: Variable '{var_name}' to sum out in Fac({factor.name}) -> Var({variable.name}) has dimension > 1 but was not in original factor neighbors.")
                 # else: Variable not present, cannot sum out.

        # Perform the summation
        if axes_to_sum_out:
            # np.sum accepts a tuple of axes and reduces over all of them at once,
            # so only uniqueness and validity matter here, not the order
            unique_axes_to_sum = tuple(sorted(set(axes_to_sum_out)))
            if not all(0 <= ax < current_potential.probs.ndim for ax in unique_axes_to_sum):
                 raise IndexError(f"Invalid axis found in {unique_axes_to_sum} for potential with ndim={current_potential.probs.ndim} in Fac({factor.name}) -> Var({variable.name})")
            summed_probs = np.sum(current_potential.probs, axis=unique_axes_to_sum)
        else:
             # If no other variables to sum out (factor involves only target variable, or others were observed and handled?)
             summed_probs = current_potential.probs

        # Result should be 1D array corresponding to target variable's states.
        # It might have leading/trailing dims of size 1 from the sum. Squeeze them.
        final_message_probs = np.squeeze(summed_probs)

        # Ensure the result is still an array, even if variable has only 1 state
        if not isinstance(final_message_probs, np.ndarray):
             final_message_probs = np.array([final_message_probs])
        elif final_message_probs.ndim == 0: # Handle case where squeeze results in scalar
             final_message_probs = final_message_probs.reshape(1,)

        # print(f"  Msg Fac({factor.name}) -> Var({variable.name}): {final_message_probs}") # Debug

        # Normalize message? No, standard BP doesn't normalize fac->var messages.

        # Memoization is handled by the public wrapper method
        return final_message_probs


    def marginal(self, variable):
        """
        Computes the marginal distribution for a given variable.
        p(var) proportional to product_{fac in neighbors(var)} message(fac -> var)
        Handles observed variables.
        """
        if variable.observed_value is not None:
             # If observed, the marginal is a delta function (one-hot)
             var_dim = self._get_variable_dimension(variable, variable) # Pass variable as context
             marginal_p = np.zeros(var_dim)
             if 0 <= variable.observed_value < var_dim:
                 marginal_p[variable.observed_value] = 1.0
             else:
                  # This case should be caught earlier by variable_to_factor message if observed value invalid
                  print(f"Warning: Observed value {variable.observed_value} for variable '{variable.name}' seems out of bounds ({var_dim}). Returning zeros.")
             return marginal_p

        # If not observed, compute product of incoming messages from all neighbors
        incoming_messages = []
        if not variable.neighbors:
            # Isolated variable - marginal is typically considered uniform or based on a prior if one exists
            # In factor graph context, often means model is underspecified for this var.
            print(f"Warning: Variable '{variable.name}' has no factors connected. Cannot compute marginal belief from factors. Returning uniform.")
            # Cannot determine dimension without neighbors or prior info. Assuming dim 1? Risky.
            # Let's try to get dim if it was stored during set_distributions
            try:
                var_dim = self._get_variable_dimension(variable, variable)
                return np.ones(var_dim) / var_dim
            except RuntimeError:
                 print(f"Error: Cannot determine dimension for isolated variable '{variable.name}'. Returning [1.0].")
                 return np.array([1.0])


        for neighbor_factor in variable.neighbors:
            incoming_message = self.factor_to_variable_message(neighbor_factor, variable)
            # Ensure message is numpy array for stacking
            if not isinstance(incoming_message, np.ndarray):
                 incoming_message = np.array(incoming_message)
            # Check if message is empty (e.g. due to contradiction)
            if incoming_message.size == 0:
                  print(f"Warning: Received empty message from factor '{neighbor_factor.name}' to variable '{variable.name}'. Indicates potential contradiction. Marginal will be zero.")
                  var_dim = self._get_variable_dimension(variable, variable)
                  return np.zeros(var_dim) # Belief is zero everywhere
            incoming_messages.append(incoming_message)

        # Product of messages
        try:
            message_array = np.array(incoming_messages)
            if message_array.size == 0 : # Should be caught above, but double check
                var_dim = self._get_variable_dimension(variable, variable)
                unnorm_p = np.zeros(var_dim)
            else:
                unnorm_p = np.prod(message_array, axis=0)
        except ValueError as e:
             print(f"Error computing product of incoming messages for variable '{variable.name}': {incoming_messages}")
             raise e


        # Normalize to get a valid probability distribution
        norm_const = np.sum(unnorm_p)
        if np.isclose(norm_const, 0):
            # Avoid division by zero. Indicates contradiction or zero probability event.
            print(f"Warning: Normalization constant is close to zero for variable '{variable.name}'. Belief is zero or ill-defined. Returning uniform distribution over inferred dimension.")
            var_dim = self._get_variable_dimension(variable, variable)
            return np.ones(var_dim) / var_dim
        else:
            marginal_p = unnorm_p / norm_const

        # Ensure the result is a 1-D array of length var_dim
        var_dim = self._get_variable_dimension(variable, variable)
        if marginal_p.shape != (var_dim,):
            if marginal_p.size == var_dim:
                # Same number of entries (this also covers a 0-d scalar when var_dim == 1)
                marginal_p = marginal_p.reshape((var_dim,))
            else:
                print(f"Warning: Marginal shape mismatch for {variable.name}. Expected {(var_dim,)}, got {marginal_p.shape}. Check message calculations.")


        return marginal_p

    # --- Public wrappers for message computation with memoization ---
    def variable_to_factor_messages(self, variable, factor):
        # Compute the variable -> factor message once, then serve it from the cache
        message_key = self._get_message_key(variable, factor)
        if message_key not in self.messages:
            self.messages[message_key] = self._variable_to_factor_messages(variable, factor)
        # Return a copy to prevent accidental modification of the cached value
        return np.copy(self.messages[message_key])

    def factor_to_variable_message(self, factor, variable):
        # Compute the factor -> variable message once, then serve it from the cache
        message_key = self._get_message_key(factor, variable)
        if message_key not in self.messages:
            self.messages[message_key] = self._factor_to_variable_messages(factor, variable)
        # Return a copy to prevent accidental modification of the cached value
        return np.copy(self.messages[message_key])

    # --- New methods for structured passes and full BP ---

    def forward(self, pgm: PGM, root_variable_name: str):
        """
        Computes the messages flowing from the leaves towards the specified root variable.

        Requesting every factor-to-variable message incoming to the root triggers,
        through recursion, all messages on the paths from the leaves to the root.
        Each message is cached in `self.messages` and reused by later passes.

        Args:
            pgm (PGM): The PGM object.
            root_variable_name (str): The name of the variable to act as the root for this pass.
        """
        print(f"--- Running Forward Pass (towards root '{root_variable_name}') ---")
        if root_variable_name not in pgm.get_variables():
            raise ValueError(f"Root variable '{root_variable_name}' not found in PGM.")

        root_variable = pgm.variable_from_name(root_variable_name)

        # Trigger the computation of messages incoming to the root variable.
        # The recursive calls within factor_to_variable_message propagate
        # the computation requests towards the leaves as needed.
        print(f"Triggering message computations incoming to root '{root_variable_name}'...")
        for factor in root_variable.neighbors:
            # Requesting this message triggers recursive calls towards the leaves
            self.factor_to_variable_message(factor, root_variable)

        print(f"Forward pass towards '{root_variable_name}' complete (messages computed lazily).")


    def backward(self, pgm: PGM, root_variable_name: str):
        """
        Computes the messages flowing from the specified root variable out towards the leaves.

        Messages already cached by the forward pass are reused; any message that is
        still missing is computed recursively on demand.

        Args:
            pgm (PGM): The PGM object.
            root_variable_name (str): The name of the variable acting as the root.
        """
        print(f"--- Running Backward Pass (from root '{root_variable_name}') ---")
        if root_variable_name not in pgm.get_variables():
            raise ValueError(f"Root variable '{root_variable_name}' not found in PGM.")

        root_variable = pgm.variable_from_name(root_variable_name)

        # Traverse the graph breadth-first, starting from the root, and send
        # messages outwards along every edge.
        # A deque (imported at the top of the notebook) gives O(1) pops from the left.
        queue = deque([(root_variable, None)])  # Start with root, parent=None
        visited_edges = set()  # Guards against revisiting an edge (the graph should be a tree)

        print(f"Triggering message computations outgoing from root '{root_variable_name}'...")

        while queue:
            current_node, parent_node = queue.popleft()

            if isinstance(current_node, Variable):
                # Send messages Variable -> Factor to all neighbors except the one towards parent_node
                for neighbor_factor in current_node.neighbors:
                    # The "parent" factor is the one connecting to parent_node (if current_node is not root)
                    parent_factor = parent_node if isinstance(parent_node, Factor) else None
                    if neighbor_factor == parent_factor:
                        continue

                    edge = tuple(sorted((current_node.name, neighbor_factor.name)))
                    if edge not in visited_edges:
                        # Compute Var -> Fac message
                        self.variable_to_factor_messages(current_node, neighbor_factor)
                        visited_edges.add(edge)
                        queue.append((neighbor_factor, current_node))  # Add factor to queue

            elif isinstance(current_node, Factor):
                # Send messages Factor -> Variable to all neighbors except the one towards parent_node
                for neighbor_variable in current_node.neighbors:
                    # The "parent" variable is the one connecting to parent_node (if current_node is not root)
                    parent_variable = parent_node if isinstance(parent_node, Variable) else None
                    if neighbor_variable == parent_variable:
                        continue

                    edge = tuple(sorted((current_node.name, neighbor_variable.name)))
                    if edge not in visited_edges:
                        # Compute Fac -> Var message
                        self.factor_to_variable_message(current_node, neighbor_variable)
                        visited_edges.add(edge)
                        queue.append((neighbor_variable, current_node))  # Add variable to queue

        print(f"Backward pass from '{root_variable_name}' complete (messages computed lazily/triggered).")


    def belief_propagation(self, pgm: PGM):
        """
        Executes the full belief propagation algorithm (sum-product) to compute all marginal distributions.

        For a tree, this can be achieved by:
        1. (Optionally) Performing an explicit forward pass towards an arbitrary root.
        2. (Optionally) Performing an explicit backward pass away from that root.
        3. Calculating the marginal for each variable using the computed messages.

        Alternatively, due to the lazy/recursive nature of the message computation here,
        simply requesting the marginal for every variable will ensure all necessary messages
        are computed exactly once (due to memoization) across the entire graph. This function
        uses the simpler approach of iterating through variables and calling `marginal`.

        Args:
            pgm (PGM): The PGM object containing the graph structure and factor potentials.

        Returns:
            dict: A dictionary mapping each variable name (str) to its computed
                  marginal distribution (numpy array).
        """
        print("\n--- Running Belief Propagation (Sum-Product) ---")
        all_marginals = {}
        variables = pgm.get_variables() # Get the dictionary {name: Variable_object}

        # This implicitly triggers all necessary message computations due to recursion and memoization.
        print("Calculating marginals for all variables (triggers message computations)...")
        if not variables:
            print("Warning: No variables in PGM.")
            return {}

        for var_name, variable_obj in variables.items():
            # Compute the marginal for this variable.
            # This will trigger recursive computation of all necessary messages.
            # Previously computed messages stored in self.messages will be reused.
            # print(f"Calculating marginal for: {var_name}") # Debug
            try:
                marginal_dist = self.marginal(variable_obj)
                all_marginals[var_name] = marginal_dist
                # print(f"  Marginal p({var_name}) = {marginal_dist}") # Debug
            except Exception as e:
                print(f"ERROR calculating marginal for variable '{var_name}': {e}")
                # Store an error string so the remaining variables are still processed
                all_marginals[var_name] = f"Error: {e}"


        print("Belief Propagation complete.")
        return all_marginals

**Notes**: Gemini 2.5 was used to fix several errors that appeared during the debugging phase. Afterwards, the entire code block was refactored, commented, and checked for correctness.
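
The order in which `backward` visits edges can be illustrated on a plain adjacency dictionary. The sketch below is not the notebook's implementation: the `graph` dictionary, the factor labels `f1`, `f2`, `f3`, and the helper names `backward_schedule` and `forward_schedule` are assumptions introduced only for illustration, and the graph is assumed to be a tree.

```python
from collections import deque

# Hypothetical adjacency for a factor graph with factors
# f1(x1, x3), f2(x2, x3), f3(x3, x4, x5); keys are plain node names,
# not the notebook's Variable/Factor objects.
graph = {
    "x1": ["f1"], "x2": ["f2"], "x3": ["f1", "f2", "f3"],
    "x4": ["f3"], "x5": ["f3"],
    "f1": ["x1", "x3"], "f2": ["x2", "x3"], "f3": ["x3", "x4", "x5"],
}

def backward_schedule(graph, root):
    """Return directed edges (sender, receiver) in BFS order away from the root.

    Assumes the graph is a tree, so skipping the parent is enough to avoid loops.
    """
    schedule = []
    queue = deque([(root, None)])
    while queue:
        node, parent = queue.popleft()
        for neighbor in graph[node]:
            if neighbor == parent:
                continue  # never send a message back towards the root
            schedule.append((node, neighbor))
            queue.append((neighbor, node))
    return schedule

def forward_schedule(graph, root):
    """Leaves-to-root order: reverse the backward schedule and flip each edge."""
    return [(b, a) for a, b in reversed(backward_schedule(graph, root))]

backward = backward_schedule(graph, "x3")
forward = forward_schedule(graph, "x3")
print("forward :", forward)
print("backward:", backward)
```

On a tree, every message in the forward list appears after the messages it depends on, which is why one forward pass followed by one backward pass covers both directions of every edge.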

## Example usage
**Apply the `belief_propagation` method to the factor graph on page 43 of the PML notes**

![factor_ex](imgs/factor_example.png)
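
As a reading aid, the graph above can be written out symbolically. Writing $f_1$, $f_2$, $f_3$ for the three factors (a naming introduced here for brevity) and treating them as unnormalised potentials, the joint distribution factorises as

$$p(x_1, x_2, x_3, x_4, x_5) \propto f_1(x_1, x_3)\, f_2(x_2, x_3)\, f_3(x_3, x_4, x_5)$$

Since $x_1$, $x_2$, $x_4$ and $x_5$ are leaves and therefore send unit messages, the marginal of $x_3$ is the normalised product of the three factor-to-variable messages it receives:

$$p(x_3) \propto \Big(\sum_{x_1} f_1(x_1, x_3)\Big) \Big(\sum_{x_2} f_2(x_2, x_3)\Big) \Big(\sum_{x_4, x_5} f_3(x_3, x_4, x_5)\Big)$$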

In [10]:
# Define the graph structure string
model_string = "p(x1,x3)p(x2,x3)p(x3,x4,x5)"

# Create the PGM object
pgm = PGM.from_string(model_string)
print("PGM created with variables:", list(pgm.get_variables().keys()))
print("PGM created with factors:", list(pgm.get_factors().keys()))

# Define the factor potentials using the values from the example
f1_probs = np.array([[0.3, 0.2],
                     [0.1, 0.4]])
f1_dist = Distribution(f1_probs, ['x1', 'x3'])

f2_probs = np.array([[0.1, 0.5],
                     [0.2, 0.2]])
f2_dist = Distribution(f2_probs, ['x2', 'x3'])

f3_probs = np.array([
    [[0.1, 0. ],
     [0.1, 0.1]],
    [[0.1, 0. ],
     [0.2, 0.4]]
])
f3_dist = Distribution(f3_probs, ['x3', 'x4', 'x5'])

# Create the data dictionary mapping factor names to distributions
data = {
    "p(x1,x3)": f1_dist,
    "p(x2,x3)": f2_dist,
    "p(x3,x4,x5)": f3_dist,
}

# Set the distributions in the PGM
try:
    pgm.set_distributions(data)
    print("\nDistributions set successfully.")
    for name, var in pgm.get_variables().items():
        print(f"  Variable '{name}' has dimension {var.dim}")
except ValueError as e:
    print(f"Error setting distributions: {e}")
except Exception as e:
    print(f"An unexpected error occurred during distribution setting: {e}")

# Initialize the Messages class
m = Messages()

# Run Belief Propagation
all_marginals = m.belief_propagation(pgm)

# Print the computed marginals
print("\n--- Computed Marginal Distributions ---")
if all_marginals:
    for var_name in sorted(all_marginals.keys()):
        marginal = all_marginals[var_name]
        if isinstance(marginal, str) and marginal.startswith("Error"):
            print(f"p({var_name}) = {marginal}")
        else:
            print(f"p({var_name}) = {np.array2string(np.asarray(marginal), precision=4, suppress_small=True)}")
else:
    print("No marginals computed.")

PGM created with variables: ['x5', 'x2', 'x4', 'x1', 'x3']
PGM created with factors: ['p(x1,x3)', 'p(x2,x3)', 'p(x3,x4,x5)']

Distributions set successfully.
  Variable 'x5' has dimension 2
  Variable 'x2' has dimension 2
  Variable 'x4' has dimension 2
  Variable 'x1' has dimension 2
  Variable 'x3' has dimension 2

--- Running Belief Propagation (Sum-Product) ---
Calculating marginals for all variables (triggers message computations)...
Belief Propagation complete.

--- Computed Marginal Distributions ---
p(x1) = [0.3788 0.6212]
p(x2) = [0.6727 0.3273]
p(x3) = [0.1091 0.8909]
p(x4) = [0.1636 0.8364]
p(x5) = [0.4545 0.5455]
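
As a sanity check, the marginals above can be compared against a brute-force computation that builds the full joint table and sums out the other variables. The sketch below uses only NumPy and the factor tables from the example cell; the names `joint` and `brute_force` are introduced here for illustration and are not part of the notebook's classes. This approach is only feasible because the graph has five binary variables.

```python
import numpy as np

# Factor tables copied from the example cell above.
f1 = np.array([[0.3, 0.2], [0.1, 0.4]])        # axes: x1, x3
f2 = np.array([[0.1, 0.5], [0.2, 0.2]])        # axes: x2, x3
f3 = np.array([[[0.1, 0.0], [0.1, 0.1]],
               [[0.1, 0.0], [0.2, 0.4]]])      # axes: x3, x4, x5

# Unnormalised joint over (x1, x2, x3, x4, x5); index letters a..e map to x1..x5.
joint = np.einsum("ac,bc,cde->abcde", f1, f2, f3)
joint /= joint.sum()

# Marginalise by summing over every axis except the one of interest.
names = ["x1", "x2", "x3", "x4", "x5"]
brute_force = {
    name: joint.sum(axis=tuple(i for i in range(joint.ndim) if i != axis))
    for axis, name in enumerate(names)
}

for name in names:
    print(f"p({name}) = {np.round(brute_force[name], 4)}")
```

The printed values should agree with the belief propagation output to four decimal places.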
