I've used this notebook to derive the belief propagation (BP) algorithm on top of the pgmpy library, because pgmpy's own BP implementation isn't a real BP algorithm.

In [1]:
from functools import reduce

import numpy as np
import plotly.graph_objs as go

import src.inference.helpers as ih
import src.models.builders as mb
import src.models.graph_builders as graph_builders
import src.models.var_builders as var_builders

In [2]:
# Create factor graph
# Fixed individual used throughout the notebook
height, age, sex = 170, 30, "Male"

# The eight point-in-time FEV1 / O2 saturation variables for that individual
HFEV1, ecFEV1, AR, HO2Sat, O2SatFFA, IA, UO2Sat, O2Sat = (
    var_builders.o2sat_fev1_point_in_time(height, age, sex)
)

# Assemble the variables into a factor graph.
# check_model=False presumably disables the builder's model validation (TODO confirm).
model = graph_builders.fev1_o2sat_point_in_time_factor_graph(
    HFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
    check_model=False,
)

# BP message computation

# Written as a free function taking the model as ``self`` so that it can be
# called directly as ``point_mass_message(model, var, obs)`` or attached to the
# FactorGraph class as a method.
def point_mass_message(self, var, obs):
    """
    Return the point-mass (one-hot) message for an observed variable.

    Parameters
    ----------
    self : object exposing ``get_cardinality(var)`` (e.g. the factor graph model)
    var : name of the observed variable
    obs : index of the observed state; must satisfy ``0 <= obs < cardinality``

    Returns
    -------
    np.ndarray of length ``cardinality`` with 1.0 at ``obs`` and 0.0 elsewhere.

    Raises
    ------
    IndexError
        If ``obs`` is outside ``[0, cardinality)``. Negative indices are rejected
        explicitly because numpy would otherwise wrap them silently to the end of
        the array and put the mass on the wrong state.
    """
    card = self.get_cardinality(var)
    if not 0 <= obs < card:
        raise IndexError(
            f"Observed state index {obs} is out of range for {var!r} "
            f"with cardinality {card}"
        )
    # 1 at the index of the evidence, 0 elsewhere
    message = np.zeros(card)
    message[obs] = 1
    return message

In [19]:
# Name of the variable whose posterior is computed below
var_to_infer = "Healthy FEV1 (L)"

# Observed ecFEV1 value, and the bin index it maps to
# (element [1] of get_bin_for_value's result)
ecFEV1_obs = 3.2
ecFEV1_idx = ecFEV1.get_bin_for_value(ecFEV1_obs)[1]
print(ecFEV1_idx)


def variable_node_message(incoming_messages):
    """
    Return the outgoing message of a variable node.

    The outgoing message is the element-wise product of all incoming messages,
    normalised to sum to 1. A single incoming message is passed through
    unchanged (and is not normalised).

    Parameters
    ----------
    incoming_messages : sequence of np.ndarray
        Expected to be 1-D arrays of equal length, one entry per state of the
        variable.

    Returns
    -------
    np.ndarray

    Raises
    ------
    ValueError
        If no messages are given, or if the product of the messages is zero
        everywhere (contradictory evidence: there is nothing to normalise).
    """
    if len(incoming_messages) == 0:
        raise ValueError("A variable node needs at least one incoming message")
    if len(incoming_messages) == 1:
        return incoming_messages[0]
    # Fold pairwise. np.multiply is a binary ufunc, so np.multiply(*messages)
    # would treat a third message as the ``out`` buffer (overwriting it and
    # leaving it out of the product), and would fail for four or more messages.
    outgoing_message = reduce(np.multiply, incoming_messages)
    total = np.sum(outgoing_message)
    if total == 0:
        raise ValueError(
            "Product of incoming messages is zero everywhere; cannot normalise"
        )
    # Normalise
    return outgoing_message / total


def factor_node_message(incoming_messages, factor, target_var):
    """
    Return the message a factor node sends to ``target_var``.

    The factor's table (CPT) is multiplied by each incoming message and the
    corresponding variable is summed out. The result is a vector over the states
    of ``target_var``.

    Parameters
    ----------
    incoming_messages : list of 1-D np.ndarray
        One message per factor variable other than ``target_var``, in the same
        order as those variables appear in ``factor.variables``.
    factor : object with ``values`` (np.ndarray, one axis per variable) and
        ``variables`` (variable names in axis order), presumably a pgmpy
        DiscreteFactor.
    target_var : name of the variable receiving the message.

    Returns
    -------
    np.ndarray over the states of ``target_var`` (not normalised).

    Raises
    ------
    AssertionError
        If the number of messages is not ``factor.values.ndim - 1``.
    ValueError
        If ``target_var`` is not one of ``factor.variables``.
    """
    cpt = factor.values
    factor_vars = list(factor.variables)

    assert len(incoming_messages) == cpt.ndim - 1, (
        f"Error computing factor node message for {target_var}. "
        f"The number of incoming messages must equal the number of CPT "
        f"dimensions minus 1 (expected {cpt.ndim - 1}, got {len(incoming_messages)})"
    )

    # Put the target variable on axis 0. The remaining axes keep their original
    # relative order, which is the order the incoming messages are given in.
    cpt = np.moveaxis(cpt, factor_vars.index(target_var), 0)

    # matmul with a 1-D right operand contracts (sums out) the LAST axis, so the
    # messages are consumed from the last variable backwards.
    return reduce(np.matmul, reversed(incoming_messages), cpt)

64


In [20]:
# ecFEV1 -> factor
# ecFEV1 is observed, so its message is a point mass on the observed bin.
# Call the module-level function directly: model.point_mass_message only exists
# if the FactorGraph library has been patched by hand, so this works on a fresh install.
m = point_mass_message(model, ecFEV1.name, ecFEV1_idx)
ecFEV1toF = variable_node_message([m])

# AR -> factor: AR's message is its own CPT (presumably its prior distribution)
artoF = variable_node_message([AR.cpt])

# factor -> var_to_infer
# Select the factor by its variables, not by position in get_factors().
ecfev1_factor = next(
    (
        f
        for f in model.get_factors()
        if {var_to_infer, ecFEV1.name, AR.name} <= set(f.variables)
    ),
    None,
)
assert ecfev1_factor is not None, "No factor links HFEV1, ecFEV1 and AR"
# Incoming messages must follow the factor's own variable order
messages_to_factor = {ecFEV1.name: ecFEV1toF, AR.name: artoF}
toHFEV1 = factor_node_message(
    [messages_to_factor[v] for v in ecfev1_factor.variables if v != var_to_infer],
    ecfev1_factor,
    var_to_infer,
)

# hfev1 posterior: product of the factor's message and the HFEV1 prior, normalised
hfev1_posterior = variable_node_message([toHFEV1, HFEV1.cpt])
hfev1_posterior.shape

(100,)

In [21]:
# Reference posterior from the equivalent Bayesian network model, used to
# validate the hand-rolled BP result (inference on the factor graph model is slow)
model1, inf_alg1, _, _, _ = mb.fev1_point_in_time_model(height, age, sex)
res = ih.infer(
    inf_alg1,
    [HFEV1],
    [[ecFEV1, ecFEV1_obs]],
)
res[HFEV1.name].values.shape

(100,)

In [28]:
# Check that the hand-rolled BP posterior matches the Bayes net posterior.
# The difference must be bounded in absolute value: a signed check such as
# (a - b < tol).all() lets arbitrarily large negative deviations pass. An
# assertion also fails loudly instead of producing a discarded boolean.
bn_posterior = res[HFEV1.name].values
np.testing.assert_allclose(hfev1_posterior, bn_posterior, rtol=0, atol=1e-5)

# Plot the two posteriors side by side
fig = go.Figure()
fig.add_trace(go.Bar(x=HFEV1.bins, y=hfev1_posterior, name="Factor Graph"))
fig.add_trace(go.Bar(x=HFEV1.bins, y=bn_posterior, name="Bayes Net"))
fig.update_layout(
    barmode="group",
    title="Comparison of HFEV1 posteriors",
    xaxis_title=HFEV1.name,
    yaxis_title="Probability",
    width=800,
    height=300,
)
fig.show()