In [1]:
import src.models.builders as mb
import src.inference.helpers as ih
import plotly.graph_objs as go
import numpy as np
import pandas as pd
import src.models.var_builders as var_builders

import src.data.breathe_data as breathe_data
from typing import List

from src.inference.long_inf_slicing import SharedNodeVariable

## Breathe data processing

In [2]:
# Load the Breathe FEV1 / O2 saturation measurements through the project loader
# (named for an Excel source).
df = breathe_data.load_o2_fev1_df_from_excel()
# Preview the first rows to check which columns were loaded.
df.head()

Unnamed: 0,ID,Date Recorded,FEV1,O2 Saturation,ecFEV1,Age,Sex,Height,Predicted FEV1,Healthy O2 Saturation,ecFEV1 % Predicted,FEV1 % Predicted,O2 Saturation % Healthy
0,101,2019-02-20,1.31,97.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.287474,99.767593
1,101,2019-02-21,1.29,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,35.733466,98.739061
2,101,2019-02-22,1.32,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.564477,98.739061
3,101,2019-02-23,1.28,97.0,1.33,53,Male,173.0,3.610061,97.22596,36.841481,35.456463,99.767593
4,101,2019-02-24,1.33,98.0,1.36,53,Male,173.0,3.610061,97.22596,37.672492,36.841481,100.796125


In [3]:
# Create factor graph
# Demographics used only to build the variable definitions in this cell.
height = 170
age = 30
sex = "Male"
(
    HFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
) = var_builders.o2sat_fev1_point_in_time(height, age, sex)

# Maps each observed model variable name to the raw measurement column in df.
var_to_colname = {
    ecFEV1.name: "ecFEV1",
    O2Sat.name: "O2 Saturation",
}

# Discretise the raw measurements: for each observed variable, store element [1]
# of get_bin_for_value(...) in a new column named after the variable
# (presumably the bin index - the values shown below are integers).
# Series.map works on the single source column instead of rebuilding every row.
for observed_var in (ecFEV1, O2Sat):
    df[observed_var.name] = df[var_to_colname[observed_var.name]].map(
        lambda value, var=observed_var: var.get_bin_for_value(value)[1]
    )
df.head()

Unnamed: 0,ID,Date Recorded,FEV1,O2 Saturation,ecFEV1,Age,Sex,Height,Predicted FEV1,Healthy O2 Saturation,ecFEV1 % Predicted,FEV1 % Predicted,O2 Saturation % Healthy,ecFEV1 (L),O2 saturation (%)
0,101,2019-02-20,1.31,97.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.287474,99.767593,26,47
1,101,2019-02-21,1.29,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,35.733466,98.739061,26,46
2,101,2019-02-22,1.32,96.0,1.32,53,Male,173.0,3.610061,97.22596,36.564477,36.564477,98.739061,26,46
3,101,2019-02-23,1.28,97.0,1.33,53,Male,173.0,3.610061,97.22596,36.841481,35.456463,99.767593,26,47
4,101,2019-02-24,1.33,98.0,1.36,53,Male,173.0,3.610061,97.22596,37.672492,36.841481,100.796125,27,48


## Slicing algorithm

In [23]:
def query_across_days(
    df,
    belief_propagation,
    shared_variables: List[SharedNodeVariable],
    variables: List[str],
    n_epochs,
):
    """Sweep over the rows of ``df`` repeatedly, querying the graph once per row.

    For every row, the values found in the ``variables`` columns are passed as
    evidence, and each shared variable contributes the virtual message it
    returns for that day (when it is not ``None``). After the query, the
    message stored under ``shared_var.graph_key`` is saved back on the shared
    variable for that day.

    Relies on the notebook-level ``HFEV1`` and ``HO2Sat`` variables and on
    ``get_diff``. Only the shared variables are queried, so ``shared_variables``
    is expected to include the shared versions of both.

    Args:
        df: One row per day, with a datetime-like "Date Recorded" column and
            one column per name in ``variables``.
        belief_propagation: Object exposing
            ``query(variables, evidence, virtual_evidence, get_messages=True)``
            and returning ``(res, messages)``.
        shared_variables: Variables whose messages are carried across days.
        variables: Names of the observed columns of ``df`` used as evidence.
        n_epochs: Index of the last epoch. Epochs are numbered from 0 and the
            function returns once epoch ``n_epochs`` has completed, so
            ``n_epochs + 1`` sweeps are run (always at least one).

    Returns:
        A tuple ``(df_res, df_res_hfev1)``: ``df_res`` has one row per epoch
        with the epoch number and, per shared variable, the values obtained on
        the last row of that epoch; ``df_res_hfev1`` is indexed by
        ``HFEV1.bins_str`` with one column per epoch holding the HFEV1 values.

    Raises:
        ValueError: If ``df`` has no rows (there would be no result to report).
    """
    if len(df) == 0:
        raise ValueError("query_across_days needs at least one row in df")

    shared_names = [shared_var.name for shared_var in shared_variables]
    result_columns = ["Epoch"] + shared_names

    df_res_hfev1 = pd.DataFrame(index=HFEV1.bins_str)
    # Rows are collected here and turned into a DataFrame once, at the end.
    epoch_rows = []

    post_hfev1_old_epoch = HFEV1._uniform_prior()
    post_ho2sat_old_epoch = HO2Sat._uniform_prior()

    epoch = 0
    while True:
        print(f"epoch {epoch}")

        post_hfev1_old_day = HFEV1._uniform_prior()
        post_ho2sat_old_day = HO2Sat._uniform_prior()
        # Day-to-day changes within this epoch; kept for convergence
        # diagnostics (see the commented summary print below).
        diffs_hfev1_day = []
        diffs_ho2sat_day = []
        for i in range(len(df)):
            day = df["Date Recorded"].iloc[i].strftime("%Y-%m-%d")

            evidence = {variable: df[variable].iloc[i] for variable in variables}

            virtual_evidence = {}
            for shared_var in shared_variables:
                virtual_message = shared_var.get_virtual_message(day)
                if virtual_message is not None:
                    virtual_evidence[shared_var.name] = virtual_message

            # Query the graph (a fresh list is passed on every call)
            res, messages = belief_propagation.query(
                list(shared_names), evidence, virtual_evidence, get_messages=True
            )

            # Save message for current day
            for shared_var in shared_variables:
                shared_var.add_message(day, messages[shared_var.graph_key])

            post_hfev1_old_day, diff_hfev1_day = get_diff(
                res, post_hfev1_old_day, HFEV1
            )
            post_ho2sat_old_day, diff_ho2sat_day = get_diff(
                res, post_ho2sat_old_day, HO2Sat
            )
            # Record the scalar diffs (not the posterior arrays).
            diffs_hfev1_day.append(diff_hfev1_day)
            diffs_ho2sat_day.append(diff_ho2sat_day)

        # print(
        #     f"Epoch {epoch} - Sum daily diffs for HFEV1: {sum(diffs_hfev1_day)}, and HO2Sat: {sum(diffs_ho2sat_day)}"
        # )

        # Compare the last day's result of this epoch with the previous epoch's.
        post_hfev1_old_epoch, diff_hfev1_epoch = get_diff(
            res, post_hfev1_old_epoch, HFEV1
        )
        post_ho2sat_old_epoch, diff_ho2sat_epoch = get_diff(
            res, post_ho2sat_old_epoch, HO2Sat
        )
        print(
            f"Epoch {epoch} - Posteriors' diff for HFEV1: {diff_hfev1_epoch}, and HO2Sat: {diff_ho2sat_epoch}"
        )

        # One row per epoch: the epoch number, then one array per shared variable
        epoch_rows.append([epoch] + [res[name].values for name in shared_names])
        df_res_hfev1[f"{epoch}"] = res[HFEV1.name].values

        if epoch >= n_epochs:
            df_res = pd.DataFrame(epoch_rows, columns=result_columns)
            return df_res, df_res_hfev1
        epoch += 1


def get_diff(res, old, var):
    """Return the current values for ``var`` and their L1 distance to ``old``.

    Args:
        res: Mapping from variable name to an object exposing ``.values``.
        old: Previous values to compare against.
        var: Object whose ``.name`` is the key to look up in ``res``.

    Returns:
        A tuple ``(current, distance)`` where ``distance`` is the sum of the
        absolute element-wise differences between ``current`` and ``old``.
    """
    current = res[var.name].values
    distance = np.sum(np.absolute(current - old))
    return current, distance

In [24]:
# Meta
key_hfev1 = f"['{ecFEV1.name}', '{HFEV1.name}', '{AR.name}'] -> {HFEV1.name}"
key_ho2sat = f"['{O2SatFFA.name}', '{HO2Sat.name}', '{AR.name}'] -> {HO2Sat.name}"

# Specific to ID 101
HFEV1shared = SharedNodeVariable(HFEV1.name, len(HFEV1.bins), key_hfev1)
HO2Satshared = SharedNodeVariable(HO2Sat.name, len(HO2Sat.bins), key_ho2sat)

shared_vars = [HFEV1shared, HO2Satshared]
obs_vars = [ecFEV1.name, O2Sat.name]

df_for_ID = df[df["ID"] == "101"]

height = df_for_ID.Height.iloc[0]
age = df_for_ID.Age.iloc[0]
sex = df_for_ID.Sex.iloc[0]
model, inf_alg, HFEV1, ecFEV1, AR, HO2Sat, O2SatFFA, IA, UO2Sat, O2Sat = (
    mb.o2sat_fev1_point_in_time_model_2(height, age, sex)
)

n_epochs = 5
df_res, df_res_hfev1 = query_across_days(
    df_for_ID, inf_alg, shared_vars, obs_vars, n_epochs
)

epoch 0
Epoch 0 - Posteriors' diff for HFEV1: 1.113643328496882, and HO2Sat: 1.5164763809838036
epoch 1
Epoch 1 - Posteriors' diff for HFEV1: 0.0, and HO2Sat: 0.0
epoch 2
Epoch 2 - Posteriors' diff for HFEV1: 0.0, and HO2Sat: 0.0
epoch 3
Epoch 3 - Posteriors' diff for HFEV1: 0.0, and HO2Sat: 0.0
epoch 4
Epoch 4 - Posteriors' diff for HFEV1: 0.0, and HO2Sat: 0.0
epoch 5
Epoch 5 - Posteriors' diff for HFEV1: 0.0, and HO2Sat: 0.0


## When to stop the loopy BP?

In [113]:
# One row per epoch; each shared-variable column holds the array of values
# recorded by query_across_days at the end of that epoch.
df_res

Unnamed: 0,Epoch,Healthy FEV1 (L),Healthy O2 saturation (%)
0,0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
1,1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."


In [121]:
# Plot the results per epoch for an ID
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# One subplot row per epoch: HFEV1 in the left column, HO2Sat in the right one.
# NOTE(review): query_across_days runs epochs 0..n_epochs inclusive, so df_res
# has one more row than is plotted here; the last epoch is not shown.
fig = make_subplots(rows=n_epochs, cols=2)
for epoch in range(n_epochs):
    for col, var in enumerate((HFEV1, HO2Sat), start=1):
        ih.plot_histogram(
            fig,
            var,
            df_res[var.name].iloc[epoch],
            var.a,
            var.b,
            epoch + 1,
            col,
            "blue",
            False,
        )

fig.update_layout(
    title=f"HFEV1 and HO2Sat posterior distributions over time for ID 101 ({len(df_for_ID)} data-points)",
    width=1000,
    height=400,
    font=dict(size=10),
    showlegend=False,
)
fig.show()

In [91]:
# Heatmap of the HFEV1 values recorded per epoch: one column per epoch (x axis),
# one row per HFEV1 bin label (y axis), colour encodes the value.
x = df_res_hfev1.columns
y = df_res_hfev1.index
z = df_res_hfev1

# Two-stop colour scale from white (lowest) to blue (highest).
colorscale = [[0, "white"], [1, "blue"]]
heatmap = go.Heatmap(z=z, x=x, y=y, colorscale=colorscale)
fig = go.Figure(data=[heatmap])

layout_options = dict(
    yaxis_title=HFEV1.name,
    xaxis_title="Epoch",
    width=300,
    height=800,
    font=dict(size=5),
)
fig.update_layout(**layout_options)
fig.show()

In [55]:
# Pairwise correlation of the numeric columns of df.
# NOTE(review): the output stored with this cell lists sepal_* / petal_* /
# species_id columns, which are not columns of the Breathe frame loaded above -
# the saved output looks stale; re-run after Restart & Run All.
df.corr(numeric_only=True)

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species_id
sepal_length,1.0,-0.109369,0.871754,0.817954,0.782561
sepal_width,-0.109369,1.0,-0.420516,-0.356544,-0.419446
petal_length,0.871754,-0.420516,1.0,0.962757,0.949043
petal_width,0.817954,-0.356544,0.962757,1.0,0.956464
species_id,0.782561,-0.419446,0.949043,0.956464,1.0


## Plot final results for an ID

In [None]:
# Plot final results for an ID
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Four stacked panels: rows 1-2 are filled here with the final HFEV1 and
# HO2Sat histograms; rows 3-4 are left for the scatter plots added below.
fig = make_subplots(rows=4, cols=1)

# `res` only exists inside query_across_days, so the final values are read
# from the last row of df_res instead.
ih.plot_histogram(
    fig, HFEV1, df_res[HFEV1.name].iloc[-1], HFEV1.a, HFEV1.b, 1, 1, "blue"
)
ih.plot_histogram(
    fig, HO2Sat, df_res[HO2Sat.name].iloc[-1], HO2Sat.a, HO2Sat.b, 2, 1, "blue"
)


def plot_scatter(fig, x, y, row, col, colour, title=None):
    """Add a markers-only scatter trace to one subplot of ``fig`` and label its axes.

    Args:
        fig: Figure with a subplot grid (rows/cols addressable via ``row``/``col``).
        x: Values for the x axis.
        y: Values for the y axis.
        row: Subplot row to draw in.
        col: Subplot column to draw in.
        colour: Marker colour; when falsy the marker colour is left untouched.
        title: Text for the y axis title (the x axis title is always "Days").
    """
    fig.add_trace(go.Scatter(x=x, y=y, mode="markers"), row=row, col=col)

    # Style the traces of this subplot in a single pass.
    marker_style = {"size": 2}
    if colour:
        marker_style["color"] = colour
    fig.update_traces(marker=marker_style, row=row, col=col)

    # Axis titles: caller-provided text on the y axis, fixed label on the x axis.
    fig.update_yaxes(title_text=title, row=row, col=col)
    fig.update_xaxes(title_text="Days", row=row, col=col)


# Observations over the recorded dates: ecFEV1 in row 3, O2 saturation in row 4.
for subplot_row, observed_var in ((3, ecFEV1), (4, O2Sat)):
    plot_scatter(
        fig,
        df_for_ID["Date Recorded"],
        df_for_ID[observed_var.name],
        subplot_row,
        1,
        "black",
        observed_var.name,
    )

# Figure-wide layout, then render.
layout_options = dict(
    title=f"HFEV1 and HO2Sat posterior distributions over time for ID 101 ({len(df_for_ID)} data-points)",
    width=1000,
    height=800,
    font=dict(size=10),
    showlegend=False,
)
fig.update_layout(**layout_options)
fig.show()