I've used this notebook to derive a custom belief propagation (BP) algorithm on top of the pgmpy library, whose built-in BP implementation isn't a real BP algorithm.

In [1]:
import src.models.builders as mb
import src.inference.helpers as ih
import plotly.graph_objs as go
import numpy as np

In [11]:
# Create the factor graph for a reference individual (170 cm, 30 years old, male)
height, age, sex = 170, 30, "Male"
(
    model,
    inf_alg,
    HFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
) = mb.o2sat_fev1_point_in_time_model_2(height, age, sex)

# BP message computation

Add this method to the `FactorGraph` class of the library (the cells below call it as `model.get_point_mass_message`):

def point_mass_message(self, var, obs):
    """
    Builds the one-hot (point mass) message of an observed variable.

    var: name of the variable, used to look up its cardinality
    obs: index of the observed state
    Returns a float array of length card(var), equal to 1 at `obs` and 0 elsewhere.
    """
    n_states = self.get_cardinality(var)
    point_mass = np.zeros(n_states)
    # All the probability mass sits on the observed state
    point_mass[obs] = 1
    return point_mass

In [3]:
from functools import reduce


# Functions moved to the custom BP algorithm
def variable_node_message(incoming_messages):
    """
    Returns the outgoing message of a variable node: the normalised element-wise product of all incoming messages.

    incoming_messages: non-empty sequence of 1-D arrays, all of the variable's cardinality.
    A single incoming message is returned as is (same object, not normalised).

    Raises ValueError if incoming_messages is empty.
    """
    if len(incoming_messages) == 0:
        raise ValueError("variable_node_message needs at least one incoming message")
    if len(incoming_messages) == 1:
        return incoming_messages[0]
    # reduce handles any number of messages; np.multiply(*msgs) would treat a 3rd array as `out`
    # (overwriting it and dropping it from the product)
    outgoing_message = reduce(np.multiply, incoming_messages)
    # Normalise
    return outgoing_message / np.sum(outgoing_message)


def factor_node_message(incoming_messages, factor, target_var):
    """
    Returns the outgoing message of a factor node towards target_var: the factor function (CPT) contracted
    with every incoming message, then normalised.

    incoming_messages: one 1-D message per factor variable other than target_var, in the same order as those
        variables appear in factor.variables (i.e. the CPT's dimensions with target_var removed)
    factor: object exposing `values` (N-D array) and `variables` (ordered variable names, one per dimension)
    target_var: name of the variable the message is sent to

    Raises AssertionError if the number of incoming messages is not the number of CPT dimensions - 1.
    Note: if the contraction is all zeros, the normalisation divides by zero and yields NaNs.
    """
    cpt = factor.values
    factor_vars = list(factor.variables)

    assert (
        len(incoming_messages) == cpt.ndim - 1
    ), f"Error computing factor node message for {target_var}. The number of incoming messages must equal the number of CPT dimensions - 1"

    # Put target_var on the CPT's 0th axis (a no-op when it is already there)
    cpt = np.moveaxis(cpt, factor_vars.index(target_var), 0)

    # matmul of an N-D array with a 1-D message contracts the last axis, so consume the messages from
    # the last CPT dimension backwards until only the target axis is left
    outgoing_message = reduce(np.matmul, reversed(incoming_messages), cpt)

    # Normalise
    return outgoing_message / np.sum(outgoing_message)

In [7]:
# Test implementation: hand-schedule the BP messages from the two observed variables
# (ecFEV1 and O2Sat) up to HFEV1, using the message functions defined above
var_to_infer = HFEV1.name

ecFEV1_obs = 1
ecFEV1_idx = ecFEV1.get_bin_for_value(ecFEV1_obs)[1]

o2sat_obs = 100
o2sat_idx = O2Sat.get_bin_for_value(o2sat_obs)[1]

# Fetch the factors once. The indexing assumes the order printed by model.get_factors()
# later in this notebook (HFEV1, ecFEV1, AR, HO2Sat, O2SatFFA, IA, UO2Sat, O2Sat)
factors = model.get_factors()
hfev1_factor = factors[0]
ecfev1_factor = factors[1]
ar_factor = factors[2]
ho2sat_factor = factors[3]
o2satffa_factor = factors[4]
ia_factor = factors[5]
uo2sat_factor = factors[6]
o2sat_factor = factors[7]

# The factors are not normalised here: every factor_node_message output is normalised instead

# O2sat -> O2sat factor
o2sat_obs_m = model.get_point_mass_message(O2Sat.name, o2sat_idx)
o2sat_up = variable_node_message([o2sat_obs_m])

# Only one message comes to UO2Sat
to_uo2sat = factor_node_message([o2sat_up], o2sat_factor, UO2Sat.name)
# Only one message incoming UO2Sat, hence
uo2sat_up = variable_node_message([to_uo2sat])
assert (uo2sat_up == to_uo2sat).all()

# IA -> O2Sat factor
ia_right = variable_node_message([IA.cpt])

# uo2sat_factor -> o2satffa
to_o2satffa = factor_node_message([uo2sat_up, ia_right], uo2sat_factor, O2SatFFA.name)

# O2SatFFA to O2SatFFA factor
o2satffa_up = variable_node_message([to_o2satffa])
assert (o2satffa_up == to_o2satffa).all()

# HO2Sat to O2SatFFA factor
ho2sat_down = variable_node_message([HO2Sat.cpt])

# O2SatFFA factor to AR
to_ar = factor_node_message([o2satffa_up, ho2sat_down], o2satffa_factor, AR.name)

# AR -> ecfev1 factor: AR's prior times the message coming up from the O2 branch
ar_left = variable_node_message([AR.cpt, to_ar])

# ecFEV1 -> ecfev1 factor
ecfev1_obs_m = model.get_point_mass_message(ecFEV1.name, ecFEV1_idx)
ecfev1_up = variable_node_message([ecfev1_obs_m])

# ecfev1 -> hfev1
to_hfev1 = factor_node_message([ecfev1_up, ar_left], ecfev1_factor, var_to_infer)

# hfev1 posterior
hfev1_posterior = variable_node_message([to_hfev1, HFEV1.cpt])
hfev1_posterior.shape

(100,)

In [8]:
def _sum_over_leading_axes(values, n_axes=3):
    """Sums `values` over axis 0, `n_axes` times in a row (same order of operations as chained np.sum(..., axis=0))."""
    reduced = values
    for _ in range(n_axes):
        reduced = np.sum(reduced, axis=0)
    return reduced


# Total mass of the O2SatFFA factor, before and after pgmpy's normalize() (in place by default)
# NOTE(review): for a CPT, the per-parent check would sum over the child axis (axis 0) only
print(o2satffa_factor.values.sum())
print(_sum_over_leading_axes(o2satffa_factor.values))
o2satffa_factor.normalize()
print(o2satffa_factor.values.sum())
_sum_over_leading_axes(o2satffa_factor.values)

1.0
1.0000000000000002
1.0


1.0000000000000002

In [9]:
# Compare results against the bayes net model (factor graph model takes ages to compute)
model1, inf_alg1, _, _, _, _, _, _, _, _ = mb.o2sat_fev1_point_in_time_model(
    height, age, sex
)
res = ih.infer(
    inf_alg1, variables=[HFEV1], evidences=[[ecFEV1, ecFEV1_obs], [O2Sat, o2sat_obs]]
)
hfev1_posterior_pgmpy = res[HFEV1.name].values

# Compare absolute differences: a signed `diff < tol` check passes for any negative difference
posteriors_match = (np.abs(hfev1_posterior - hfev1_posterior_pgmpy) < 1e-5).all()
print(f"Custom BP and pgmpy HFEV1 posteriors match (abs tol 1e-5): {posteriors_match}")

# Plot the two posteriors side by side
fig = go.Figure()
fig.add_trace(
    go.Bar(
        x=HFEV1.bins,
        y=hfev1_posterior,
        name="Factor Graph with custom<br>BP computation",
    )
)
fig.add_trace(go.Bar(x=HFEV1.bins, y=hfev1_posterior_pgmpy, name="Bayes Net"))
fig.update_layout(
    barmode="group", title="Comparison of HFEV1 posteriors", width=800, height=300
)
fig.show()


invalid value encountered in divide



In [6]:
# Compare absolute differences: a signed `diff < tol` check passes for any negative difference
posteriors_match = (np.abs(hfev1_posterior - res[HFEV1.name].values) < 1e-5).all()
print(f"Custom BP and pgmpy HFEV1 posteriors match (abs tol 1e-5): {posteriors_match}")

# Plot the two posteriors side by side
fig = go.Figure()
fig.add_trace(go.Bar(x=HFEV1.bins, y=hfev1_posterior, name="Factor Graph"))
fig.add_trace(go.Bar(x=HFEV1.bins, y=res[HFEV1.name].values, name="Bayes Net"))
fig.update_layout(
    barmode="group", title="Comparison of HFEV1 posteriors", width=800, height=300
)
fig.show()

# BP scheduling with recursion

In [6]:
# Returns the posterior distribution of the queried variable, recursively going through the graph until a node stopping criterion is met.


# def query(model, inf_alg, var, evidences):
#     """
#     Returns the posterior distribution of the queried variable, recursively going through the graph until reaching a root variable (with no parent), or an observed variable.
#     """
#     return process_var(model, inf_alg, evidences, var)


# def process_var(model, inf_alg, evidences, var, from_factor=None, debug=False):
#     """
#     Returns the message outgoing from the variable node, given the incoming messages from its neighbouring factors.

#     evidences: dict with the observed variables and their values
#     var: str, the variable from which we want to compute the outgoing message
#     from_factor: str, the factor asking to process that variable, as part of the recursion.
#     from_factor is None for the first call, i.e. for the queried variable from which we want to compute the posterior.
#     """
#     print(f"Processing variable {var}")
#     # Stopping criteria: if the variable is observed, return the point mass message of the observation
#     if var in evidences.keys():
#         print(f"Returning point mass message for {var}")
#         return model.point_mass_message(var, evidences[var])
#     # Else, get the incoming messages from all neighbouring factors
#     else:
#         print(f"Variable {var} has multiple connections, going deeper in the graph")
#         incoming_messages = []
#         for factor in model.neighbors(var):
#             if factor != from_factor:
#                 incoming_messages.append(
#                     process_factor(model, inf_alg, evidences, factor, from_var=var)
#                 )
#         print(f"Computing outgoing message for the {var}\n")
#         return inf_alg.variable_node_message(incoming_messages)


# def process_factor(model, inf_alg, evidences, factor, from_var: str):
#     """
#     Returns the message outgoing from the factor node, given the incoming messages from its neighbouring variables.

#     factor: str, the factor from which we want to compute the outgoing message
#     from_var: str, the variable asking to process that factor, as part of the recursion.
#     from_var is never None: a factor is always reached from one of its variables (enforced by the assertion below).
#     """
#     print(f"Processing a factor of {from_var}")
#     # from_var can't be null
#     assert from_var is not None, "from_var must be specified"

#     vars = factor.variables
#     # Stopping criteria: if the factor is connected to only one variable, return the factor function which is the prior of from_var
#     if len(vars) == 1:
#         print(f"The factor is {from_var}'s prior, returning it as the message")
#         prior = factor.values
#         assert prior.ndim == 1, "The factor function must be a 1D array"
#         return prior
#     # Else, get the incoming messages from all neighbouring variables
#     else:
#         print(
#             f"This {from_var}'s factor has multiple connections, going deeper in the graph"
#         )
#         incoming_messages = []
#         for var in vars:
#             if var != from_var:
#                 incoming_messages.append(
#                     process_var(model, inf_alg, evidences, var, from_factor=factor)
#                 )
#         return inf_alg.factor_node_message(incoming_messages, factor, from_var)

In [3]:
# Reference Bayes net model, used to check the custom BP results (factor graph model takes ages to compute)
pgmpy_model_outputs = mb.o2sat_fev1_point_in_time_model(height, age, sex)
model1, inf_alg1, _, _, _, _, _, _, _, _ = pgmpy_model_outputs


def compare_inf_results(pgmpy_alg, custom_alg, var, evidences, tol=1e-5):
    """
    Compares the posteriors of `var` from the pgmpy and the custom BP algorithms for the given evidences.
    Asserts that they agree within `tol`, then plots them side by side.

    pgmpy_alg: pgmpy inference algorithm; its result for var is read via `.values`
    custom_alg: custom BP algorithm; its result for var is used directly as an array
    var: model variable exposing `.name` and `.bins`
    evidences: list of [variable, observed value] pairs, forwarded to ih.infer
    tol: maximum allowed absolute difference between the two posteriors

    Returns -1 (kept for backwards compatibility; the outputs are the assertion and the figure).
    Raises AssertionError if the posteriors differ by `tol` or more anywhere.
    """
    res_pgmpy = ih.infer(pgmpy_alg, [var], evidences)
    posterior_pgmpy = res_pgmpy[var.name].values

    res_custom = ih.infer(custom_alg, [var], evidences)
    posterior_custom = res_custom[var.name]

    # Absolute difference: a signed `diff < tol` check passes for any negative difference
    assert (
        np.abs(posterior_custom - posterior_pgmpy) < tol
    ).all(), "It looks like there's an error in the custom BP algorithm!"

    fig = go.Figure()
    # Add bar plot for the factor graph model
    fig.add_trace(
        go.Bar(
            x=var.bins,
            y=posterior_custom,
            name="Factor Graph with custom<br>BP computation and scheduling",
        )
    )
    fig.add_trace(go.Bar(x=var.bins, y=posterior_pgmpy, name="Bayes Net"))
    fig.update_layout(
        barmode="group",
        title=f"Comparison of {var.name} posteriors",
        width=800,
        height=300,
    )
    fig.show()

    return -1

In [4]:
# HFEV1 posterior with both observations, with ecFEV1 only, and with no evidence at all
evidence_scenarios = [
    [[ecFEV1, 1], [O2Sat, 100]],
    [[ecFEV1, 3.5]],
    [],
]
for scenario_evidences in evidence_scenarios:
    comparison_result = compare_inf_results(inf_alg1, inf_alg, HFEV1, scenario_evidences)
comparison_result

  phi.values = phi.values / phi1.values



invalid value encountered in divide



-1

In [14]:
# Check the custom BP against pgmpy on the latent variables, with both observations set
query_vars = [AR, HO2Sat, O2SatFFA, UO2Sat]
res_custom = ih.infer(inf_alg, query_vars, [[ecFEV1, 5.9], [O2Sat, 85]])
res_pgmpy = ih.infer(inf_alg1, query_vars, [[ecFEV1, 5.9], [O2Sat, 85]])

for variable in query_vars:
    # Absolute difference: a signed `diff < tol` check passes for any negative difference
    abs_diff = np.abs(res_custom[variable.name] - res_pgmpy[variable.name].values)
    assert (
        abs_diff < 1e-5
    ).all(), f"It looks like there's an error in the custom BP algorithm for {variable.name}!"

# Comparing performance of custom vs existing BP algorithms

In [20]:
# Custom is instant!
perf_query_vars = [HFEV1, AR, HO2Sat, O2SatFFA, UO2Sat]
perf_evidences = [[ecFEV1, 5.9], [O2Sat, 85]]
res_custom = ih.infer(inf_alg, perf_query_vars, perf_evidences)

In [21]:
# PGMPY implementation takes 26 seconds
perf_query_vars = [HFEV1, AR, HO2Sat, O2SatFFA, UO2Sat]
perf_evidences = [[ecFEV1, 5.9], [O2Sat, 85]]
res_pgmpy = ih.infer(inf_alg1, perf_query_vars, perf_evidences)

# Getting pgmpy tests to work

In [2]:
from pgmpy.models import BayesianNetwork
from pgmpy.factors.discrete import TabularCPD
from src.inference.inf_algs import apply_pgmpy_bp

In [22]:
# Toy network A -> B, B -> C, B -> D, queried with virtual evidence on A
bayes_net = BayesianNetwork(
    [
        ("A", "B"),
        ("B", "C"),
        ("B", "D"),
    ]
)

bayes_net.add_cpds(
    TabularCPD("A", 2, [[0.4], [0.6]], [], []),
    TabularCPD("B", 3, [[0.2, 0.05], [0.3, 0.15], [0.5, 0.8]], ["A"], [2]),
    TabularCPD("C", 2, [[0.4, 0.5, 0.1], [0.6, 0.5, 0.9]], ["B"], [3]),
    TabularCPD("D", 3, [[0.1, 0.1, 0.2], [0.3, 0.2, 0.1], [0.6, 0.7, 0.7]], ["B"], [3]),
)

bayes_net.check_model()

inf_alg_pgmpy = apply_pgmpy_bp(bayes_net)

# Virtual evidence on A, proportional to [0.1 * 0.3, 0.9 * 0.7]; normalize() updates the CPD in place
cpd = TabularCPD("A", 2, [[0.1 * 0.3], [0.9 * 0.7]])
cpd.normalize()

# Named so as not to shadow the `vars` builtin
toy_query_vars = ["B", "C"]
res = inf_alg_pgmpy.query(
    variables=toy_query_vars,
    virtual_evidence=[cpd],
    show_progress=False,
    joint=False,
)
for variable in toy_query_vars:
    print(variable, res[variable].values)

res

B [0.05461538 0.15461538 0.79076923]
C [0.17823077 0.82176923]


{'B': <DiscreteFactor representing phi(B:3) at 0x7fe9ab6b0e50>,
 'C': <DiscreteFactor representing phi(C:2) at 0x7fe9ab6b3670>}

In [21]:
# Normalised weight of A's first state in the virtual evidence used above
unnormalised_weights = np.array([0.1 * 0.3, 0.9 * 0.7])
unnormalised_weights[:1] / unnormalised_weights.sum()

array([0.04545455])

# Custom BP: also return all the computed messages

In [12]:
# Query HFEV1 without evidence, asking the custom BP to also return its messages.
# Named so as not to shadow the `vars` builtin.
query_var_names = [HFEV1.name]
res, messages = inf_alg.query(
    variables=query_var_names,
    get_messages=True,
)
factors = model.get_factors()
print(factors)
messages

[<DiscreteFactor representing phi(Healthy FEV1 (L):100) at 0x7fe3c2b9e1d0>, <DiscreteFactor representing phi(ecFEV1 (L):120, Healthy FEV1 (L):100, Airway resistance (%):45) at 0x7fe390fb8340>, <DiscreteFactor representing phi(Airway resistance (%):45) at 0x7fe3c2aee980>, <DiscreteFactor representing phi(Healthy O2 saturation (%):20) at 0x7fe3c2aed750>, <DiscreteFactor representing phi(O2 saturation if fully functional alveoli (%):40, Healthy O2 saturation (%):20, Airway resistance (%):45) at 0x7fe3c2aed030>, <DiscreteFactor representing phi(Inactive alveoli (%):30) at 0x7fe3c2aec430>, <DiscreteFactor representing phi(Underlying O2 saturation (%):100, O2 saturation if fully functional alveoli (%):40, Inactive alveoli (%):30) at 0x7fe3c2f5cd90>, <DiscreteFactor representing phi(O2 saturation (%):51, Underlying O2 saturation (%):100) at 0x7fe3c2f5d000>]


{"From factor ['Healthy FEV1 (L)'] -> to var Healthy FEV1 (L)": array([2.27577091e-08, 3.29616508e-08, 4.77422539e-08, 6.91287304e-08,
        1.00030287e-07, 1.44605212e-07, 2.08777706e-07, 3.00957882e-07,
        4.33039370e-07, 6.21773572e-07, 8.90651194e-07, 1.27246160e-06,
        1.81275029e-06, 2.57445578e-06, 3.64408041e-06, 5.13983578e-06,
        7.22230205e-06, 1.01082494e-05, 1.40883852e-05, 1.95499066e-05,
        2.70048398e-05, 3.71252279e-05, 5.07862575e-05, 6.91183727e-05,
        9.35692772e-05, 1.25976433e-04, 1.68650192e-04, 2.24466989e-04,
        2.96971079e-04, 3.90482016e-04, 5.10203557e-04, 6.62327814e-04,
        8.54126424e-04, 1.09401835e-03, 1.39160174e-03, 1.75763549e-03,
        2.20395468e-03, 2.74330370e-03, 3.38907152e-03, 4.15491598e-03,
        5.05426795e-03, 6.09971266e-03, 7.30225371e-03, 8.67047555e-03,
        1.02096320e-02, 1.19207014e-02, 1.37994598e-02, 1.58356372e-02,
        1.80122269e-02, 2.03050231e-02, 2.26824592e-02, 2.51058119e-02,
 

# Add slicing: share HFEV1 and HO2Sat across days

In [30]:
import src.data.breathe_data as breathe_data
from functools import reduce

In [31]:
# Load the O2 saturation / FEV1 measurements and preview the first rows
df = breathe_data.load_o2_fev1_df_from_excel()
df.head(n=5)

Unnamed: 0,ID,Date Recorded,FEV1,O2 Saturation,ecFEV1,Age,Sex,Height,Predicted FEV1,Healthy O2 Saturation,ecFEV1 % Predicted,FEV1 % Predicted,O2 Saturation % Healthy
0,101,2019-02-20,1.31,97.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.287474,99.767593
1,101,2019-02-21,1.29,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,35.733466,98.739061
2,101,2019-02-22,1.32,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.564477,98.739061
3,101,2019-02-23,1.28,97.0,1.33,53,Male,173.0,3.610061,97.22596,36.841481,35.456463,99.767593
4,101,2019-02-24,1.33,98.0,1.36,53,Male,173.0,3.610061,97.22596,37.672492,36.841481,100.796125


In [32]:
class SharedNodeVariable:
    """
    A variable shared across several copies (plates, e.g. days) of the same graph.

    Each plate contributes one message about the variable. The product of all stored messages is
    used as virtual evidence on the variable when querying another plate.

    name: variable name, used as the virtual-evidence key
    card: cardinality of the variable (length of every message)
    graph_key: key of the message to store, in the messages returned by a BP query
    """

    def __init__(self, name, card, graph_key):
        self.name = name
        self.card = card
        self.graph_key = graph_key
        # plate_key (e.g. the day) -> message array of shape (card,)
        self.messages = {}

    def add_message(self, plate_key, message):
        """Stores the message for plate_key, replacing any previous message for that plate."""
        assert message.shape == (
            self.card,
        ), "The message must have the same shape as the variable's cardinality"
        self.messages[plate_key] = message

    def get_virtual_message(self):
        """
        Returns the element-wise product of all stored messages, or None if there is none.

        A single message is returned as is. With several messages, the running product is renormalised
        after each multiplication. Multiplying many probability vectors (one per day) would otherwise
        underflow to all zeros. Only the relative values matter for virtual evidence.
        """
        if not self.messages:
            return None
        stored = list(self.messages.values())
        if len(stored) == 1:
            return stored[0]
        combined = np.ones(self.card)
        for message in stored:
            combined = combined * message
            total = combined.sum()
            # Leave an all-zero product as is rather than dividing by zero
            if total > 0:
                combined = combined / total
        return combined

In [33]:
# Column of df holding the raw observations of each observed model variable
var_to_colname = {
    ecFEV1.name: "ecFEV1",
    O2Sat.name: "O2 Saturation",
}

# Add one column per observed variable, named after the model variable and holding the bin index of the
# raw observation. query_across_days reads these columns as BP evidence.
# Series.map works on the single column, instead of building a Series per row as df.apply(axis=1) does.
for obs_var in (ecFEV1, O2Sat):
    df[obs_var.name] = df[var_to_colname[obs_var.name]].map(
        lambda value, v=obs_var: v.get_bin_for_value(value)[1]
    )
df.head()

Unnamed: 0,ID,Date Recorded,FEV1,O2 Saturation,ecFEV1,Age,Sex,Height,Predicted FEV1,Healthy O2 Saturation,ecFEV1 % Predicted,FEV1 % Predicted,O2 Saturation % Healthy,ecFEV1 (L),O2 saturation (%)
0,101,2019-02-20,1.31,97.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.287474,99.767593,26,47
1,101,2019-02-21,1.29,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,35.733466,98.739061,26,46
2,101,2019-02-22,1.32,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.564477,98.739061,26,46
3,101,2019-02-23,1.28,97.0,1.33,53,Male,173.0,3.610061,97.22596,36.841481,35.456463,99.767593,26,47
4,101,2019-02-24,1.33,98.0,1.36,53,Male,173.0,3.610061,97.22596,37.672492,36.841481,100.796125,27,48


In [35]:
from typing import List


def _build_virtual_evidence(shared_variables):
    """Maps each shared variable's name to its combined stored message, skipping variables with no message yet."""
    virtual_evidence = {}
    for shared_var in shared_variables:
        virtual_message = shared_var.get_virtual_message()
        if virtual_message is not None:
            virtual_evidence[shared_var.name] = virtual_message
    return virtual_evidence


def query_across_days(
    df,
    belief_propagation,
    shared_variables: List["SharedNodeVariable"],
    variables: List[str],
):
    """
    Runs one BP query per row (day) of df, carrying information between days through the shared variables.

    For each row:
    - the values of the `variables` columns are used as evidence;
    - the messages stored so far on each shared variable (from earlier rows) are passed as virtual evidence;
    - the shared variables are queried, and the returned message at each shared variable's graph_key is
      stored on it under the row's "Date Recorded" value.

    df: one row per day, with a "Date Recorded" column and one column per name in `variables`
    belief_propagation: object whose query(variables, evidence, virtual_evidence, get_messages=True)
        returns (result, messages)
    shared_variables: variables shared across days (updated in place with each day's message)
    variables: names of the observed variables, i.e. columns of df

    Returns the query result of the last row, or None if df is empty.
    """
    res = None
    for i in range(len(df)):
        day = df["Date Recorded"].iloc[i]
        evidence = {variable: df[variable].iloc[i] for variable in variables}
        virtual_evidence = _build_virtual_evidence(shared_variables)
        var_to_infer = [shared_var.name for shared_var in shared_variables]

        # Query the graph
        res, messages = belief_propagation.query(
            var_to_infer, evidence, virtual_evidence, get_messages=True
        )

        # Save message for current day
        for shared_var in shared_variables:
            shared_var.add_message(day, messages[shared_var.graph_key])

    return res


# Specific to ID 101. Compare as strings so the filter works whether "ID" is stored as str or int.
PATIENT_ID = "101"
df_for_ID = df[df["ID"].astype(str) == PATIENT_ID]
assert len(df_for_ID) > 0, f"No rows found for ID {PATIENT_ID}"

# Build this patient's model first, so that the shared variables' names, cardinalities and message
# keys below come from the same model instance that is queried
height = df_for_ID.Height.iloc[0]
age = df_for_ID.Age.iloc[0]
sex = df_for_ID.Sex.iloc[0]
model, inf_alg, HFEV1, ecFEV1, AR, HO2Sat, O2SatFFA, IA, UO2Sat, O2Sat = (
    mb.o2sat_fev1_point_in_time_model_2(height, age, sex)
)

# Meta: keys, in the messages returned by the BP query, of the factor -> variable messages to share across days
key_hfev1 = f"['{ecFEV1.name}', '{HFEV1.name}', '{AR.name}'] -> {HFEV1.name}"
key_ho2sat = f"['{O2SatFFA.name}', '{HO2Sat.name}', '{AR.name}'] -> {HO2Sat.name}"

HFEV1shared = SharedNodeVariable(HFEV1.name, len(HFEV1.bins), key_hfev1)
HO2Satshared = SharedNodeVariable(HO2Sat.name, len(HO2Sat.bins), key_ho2sat)

shared_vars = [HFEV1shared, HO2Satshared]
obs_vars = [ecFEV1.name, O2Sat.name]

res = query_across_days(df_for_ID, inf_alg, shared_vars, obs_vars)

In [36]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Rows 1-2: posteriors of the shared variables from the query on the last day
fig = make_subplots(rows=4, cols=1)
ih.plot_histogram(fig, HFEV1, res[HFEV1.name].values, HFEV1.a, HFEV1.b, 1, 1, "blue")
ih.plot_histogram(
    fig, HO2Sat, res[HO2Sat.name].values, HO2Sat.a, HO2Sat.b, 2, 1, "blue"
)


def plot_scatter(fig, x, y, row, col, colour, title=None):
    """
    Adds a marker-only scatter trace of y against x to the (row, col) subplot of fig.

    colour: marker colour; a falsy value keeps plotly's default colour
    title: y axis title of the subplot (the x axis is titled "Days")
    """
    # Style the new trace directly: update_traces(row=..., col=...) would restyle every trace in the subplot
    marker = dict(size=2)
    if colour:
        marker["color"] = colour
    fig.add_trace(
        go.Scatter(
            x=x,
            y=y,
            mode="markers",
            marker=marker,
        ),
        row=row,
        col=col,
    )
    # Axis titles
    fig.update_yaxes(title_text=title, row=row, col=col)
    fig.update_xaxes(title_text="Days", row=row, col=col)


# Rows 3-4: the observations over time
plot_scatter(
    fig, df_for_ID["Date Recorded"], df_for_ID[ecFEV1.name], 3, 1, "black", ecFEV1.name
)
plot_scatter(
    fig, df_for_ID["Date Recorded"], df_for_ID[O2Sat.name], 4, 1, "black", O2Sat.name
)

# Read the ID from the data so the title cannot drift from the filter used to build df_for_ID
patient_id = df_for_ID["ID"].iloc[0]
fig.update_layout(
    title=f"HFEV1 and HO2Sat posterior distributions over time for ID {patient_id} ({len(df_for_ID)} data-points)",
    width=1000,
    height=800,
    font=dict(size=10),
    showlegend=False,
)
fig.show()