Idea of cutset conditioning: it's a way to run exact inference on a model with loops. You cut the loop by observing one of the variables in the loop to all the possible states, then fuse the results in a smart way.

 Cutset Conditioning is a technique for solving nearly-tree-structured CSPs in which some variables are assigned to separately from the rest, removed from the constraint graph, and leaving a tree-structured CSP for those remaining.

 Cutsets are some set of variables that are cut (severing edges) from the original constraint graph and solved separately.

 Conditioning is the process of assigning a value to some variable in a cutset, performing forward checking on its neighbor domains before cutting, and finally, severing it from the original graph.

https://forns.lmu.build/classes/spring-2019/cmsi-282/lecture-13M.html#backtracking++

In [1]:
import src.data.breathe_data as bd

# import src.inference.long_inf_slicing as slicing
import src.models.builders as mb
import src.data.helpers as dh

# import src.models.var_builders as var_builders
import src.inference.helpers as ih
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go

# import src.models.helpers as mh

import pandas as pd
import numpy as np

Figure per entry that has the AR from obs FEF2575 on top and on the bottom the point mass AR obtained by repeating model runs with several point mass HFEV1 (3, 3.5, 4, 4.5, 5, etc)

In [2]:
df = bd.load_meas_from_excel("BR_O2_FEV1_FEF2575_conservative_smoothing_with_idx")
# df = bd.load_meas_from_excel("BR_O2_FEV1_FEF2575_with_idx")

INFO:root:* Checking for same day measurements *


# Visualisations of the alignment between the message from FEF25-75 and from FEV1/HFEV1 factors to AR

### Two plots

In [10]:
# With each run I should retrieve
# 1/ the message from FEF25-75%FEFV1 to AR
# 2/ the point mass message from the factor ecFEV1, HFEV1 to AR
# Use the point in time model, there is no shared variables.


def can_messages_align_for_ID(df_for_ID):
    df_for_ID.reset_index(inplace=True, drop=True)
    height = df_for_ID.loc[0, "Height"]
    age = df_for_ID.loc[0, "Age"]
    sex = df_for_ID.loc[0, "Sex"]
    id = df_for_ID.loc[0, "ID"]
    (
        model,
        inf_alg,
        HFEV1,
        ecFEV1,
        AR,
        HO2Sat,
        O2SatFFA,
        IA,
        UO2Sat,
        O2Sat,
        ecFEF2575prctecFEV1,
    ) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(height, age, sex)

    FEV_to_AR_key = "['ecFEV1 (L)', 'Healthy FEV1 (L)', 'Airway resistance (%)'] -> Airway resistance (%)"
    FEF2575_to_AR_key = (
        "['ecFEF25-75 % ecFEV1 (%)', 'Airway resistance (%)'] -> Airway resistance (%)"
    )

    HFEV1_obs_list = [2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5]
    colour_list = px.colors.sample_colorscale(
        "YlGnBu", [i / (len(HFEV1_obs_list) - 1) for i in range(len(HFEV1_obs_list))]
    )

    df_for_ID = df_for_ID.sort_values(by="ecFEF2575%ecFEV1", ascending=True)
    # Take 4 idx in 5, 30, 60, 95 percentiles of the data
    idx_list = list((len(df_for_ID) * np.array([0.05, 0.5, 0.95])).astype(int))
    df_for_ID_sub = df_for_ID.iloc[idx_list, :]

    res_per_idx = []

    for idx in df_for_ID_sub.index:
        FEV1_obs = df_for_ID.loc[idx, "ecFEV1"]
        FEF2575prctFEV1_obs = df_for_ID.loc[idx, "ecFEF2575%ecFEV1"]
        FEV_m_list = []

        # Query AR
        for HFEV1_obs in HFEV1_obs_list:
            # HFEV1_obs must be > ecFEV1_obs
            evidence = [
                [ecFEV1, FEV1_obs],
                [ecFEF2575prctecFEV1, FEF2575prctFEV1_obs],
                [HFEV1, HFEV1_obs],
            ]
            _, messages = ih.infer_on_factor_graph(
                inf_alg, [AR], evidence, get_messages=True
            )

            FEV_m_list.append(messages[FEV_to_AR_key])
            FEF2575_m = messages[FEF2575_to_AR_key]

        res_per_idx.append([FEV1_obs, FEF2575prctFEV1_obs, FEV_m_list, FEF2575_m])

    fig = make_subplots(rows=6, cols=1, vertical_spacing=0.05)
    plot_row = 1
    for FEV1_obs, FEF2575prctFEV1_obs, FEV_m_list, FEF2575_m in res_per_idx:

        for HFEV1_obs, FEV_m, colour in zip(HFEV1_obs_list, FEV_m_list, colour_list):
            ih.plot_histogram(
                fig,
                AR,
                FEV_m,
                AR.a,
                AR.b,
                plot_row,
                1,
                name=f"HFEV1 = {HFEV1_obs}",
                annot=False,
            )
            # Change the last trace's colour
            fig.data[-1].marker.color = colour
            # Hide legend if plot_row > 1
            if plot_row > 1:
                fig.data[-1].showlegend = False

        ih.plot_histogram(
            fig,
            AR,
            FEF2575_m,
            AR.a,
            AR.b,
            plot_row + 1,
            1,
            annot=False,
            title=AR.name,
            colour="grey",
        )
        # hide this last trace's legend
        fig.data[-1].showlegend = False
        # Add message from ecFEV1/HFEV1 factor on y axis row 1 title
        fig.update_yaxes(title_text=f"ecFEV1<br>{FEV1_obs:.2f}L", row=plot_row, col=1)
        fig.update_yaxes(
            title_text=f"ecFEF25-75%ecFEV1<br>{FEF2575prctFEV1_obs:.2f}%",
            row=plot_row + 1,
            col=1,
        )
        plot_row += 2

    # Reduce font size and margins
    title = f"ID {id} - Can points mass messages from HFEV1, ecFEV1 align with messages from FEF25-75"
    # Reduce margins between plots
    fig.update_layout(
        font=dict(size=8),
        margin=dict(l=10, r=10, t=30, b=10),
        height=750,
        width=600,
        barmode="overlay",
        bargap=0.1,
        title=title,
    )
    fig.update_xaxes(title_standoff=6)

    fig.write_image(
        dh.get_path_to_main() + f"/PlotsBreathe/Cutset_conditioning/{title}.pdf"
    )
    # fig.show()


interesting_ids = [
    "132",
    "146",
    "177",
    "180",
    "202",
    "527",
    "117",
    "131",
    "134",
    "191",
    "139",
    "253",
    "101",
    # Also from consec values
    "405",
    "272",
    "201",
    "203",
]

# df[df.ID.isin(interesting_ids)].groupby("ID").apply(can_messages_align_for_ID)

df_for_ID = df[df.ID == "101"]
can_messages_align_for_ID(df_for_ID)

### Heatmaps of FEF2575 messages vs FEV1 messages for different HFEV1

In [6]:
# With each run I should retrieve
# 1/ the message from FEF25-75%FEFV1 to AR
# 2/ the point mass message from the factor ecFEV1, HFEV1 to AR
# Use the point in time model, there is no shared variables.


def can_messages_align_for_ID_heatmap(df_for_ID, save=True):
    df_for_ID.reset_index(inplace=True, drop=True)
    height = df_for_ID.loc[0, "Height"]
    age = df_for_ID.loc[0, "Age"]
    sex = df_for_ID.loc[0, "Sex"]
    (
        model,
        inf_alg,
        HFEV1,
        ecFEV1,
        AR,
        HO2Sat,
        O2SatFFA,
        IA,
        UO2Sat,
        O2Sat,
        ecFEF2575prctecFEV1,
    ) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(height, age, sex)

    FEV_to_AR_key = "['ecFEV1 (L)', 'Healthy FEV1 (L)', 'Airway resistance (%)'] -> Airway resistance (%)"
    FEF2575_to_AR_key = (
        "['ecFEF25-75 % ecFEV1 (%)', 'Airway resistance (%)'] -> Airway resistance (%)"
    )

    HFEV1_obs_list = [2, 3, 4, 5]
    # Compare obs list to min obs fev1
    min_obs_fev1 = df_for_ID.ecFEV1.min()
    HFEV1_obs_list = [
        HFEV1_obs for HFEV1_obs in HFEV1_obs_list if HFEV1_obs > min_obs_fev1
    ]

    # Dates on the xaxis, AR on the y axis
    FEV_m_arr = np.zeros((AR.card, len(df_for_ID)))
    FEF2575_m_arr = np.zeros((AR.card, len(df_for_ID)))

    for i, row in df_for_ID.iterrows():
        FEV1_obs = row.ecFEV1
        FEF2575prctFEV1_obs = row["ecFEF2575%ecFEV1"]

        # Query AR
        FEV_m_one_day = np.zeros(AR.card)
        for HFEV1_obs in HFEV1_obs_list:
            # HFEV1_obs must be > ecFEV1_obs
            evidence = [
                [ecFEV1, FEV1_obs],
                [ecFEF2575prctecFEV1, FEF2575prctFEV1_obs],
                [HFEV1, HFEV1_obs],
            ]
            _, messages = ih.infer_on_factor_graph(
                inf_alg, [AR], evidence, get_messages=True
            )

            # Since the messages are "almost" point mass (max over 2 bins)
            # we'll just put the value for the heatmap at the location of the mean
            AR_mean_val = AR.get_mean(messages[FEV_to_AR_key])
            AR_mean_idx = AR.get_bin_for_value(AR_mean_val)[1]
            # Add intensity value at the location of the AR mean
            FEV_m_one_day[AR_mean_idx] = HFEV1_obs

        FEV_m_arr[:, i] = FEV_m_one_day
        fef2575_m = messages[FEF2575_to_AR_key]
        # Make sure the messages are normalised - yes it is the case indeed
        fef2575_m = fef2575_m / fef2575_m.sum()
        FEF2575_m_arr[:, i] = fef2575_m

    df_for_ID["Date"] = pd.to_datetime(df_for_ID["Date Recorded"]).copy()
    df_for_ID["Date"] = df_for_ID["Date"].dt.strftime("%d-%m-%Y")

    fig = go.Figure(
        data=go.Heatmap(
            z=FEF2575_m_arr,
            x=df_for_ID["Date"],
            y=AR.get_bins_str(),
            opacity=0.8,
            colorscale="Blues",
            # Exclude from colour bar
            showscale=False,
        )
    )

    colorscale = [
        [0, "rgba(0, 0, 0, 0)"],  # Transparent for value 0
        [1 / 5, "rgba(0, 0, 0, 0)"],  # Transparent for value 0
        [1 / 5, "rgb(255, 245, 235)"],  # Light orange for value 2
        [2 / 5, "rgb(255, 245, 235)"],  # Light orange for value 2
        [2 / 5, "rgb(254, 230, 206)"],  # Medium-light orange for value 3
        [3 / 5, "rgb(254, 230, 206)"],  # Medium-light orange for value 3
        [3 / 5, "rgb(253, 174, 107)"],  # Medium orange for value 4
        [4 / 5, "rgb(253, 174, 107)"],  # Medium orange for value 4
        [4 / 5, "rgb(241, 105, 19)"],  # Dark orange for value 5
        # [5/5, 'rgb(241, 105, 19)'],  # Dark orange for value 5
        # [5/5, 'rgb(217, 72, 1)'],  # Darker orange for value 6
        [1, "rgb(217, 72, 1)"],  # Darker orange for value 6
    ]

    fig.add_traces(
        go.Heatmap(
            z=FEV_m_arr,
            x=df_for_ID["Date"],
            y=AR.get_bins_str(),
            # Change colour
            colorscale=colorscale,
        )
    )

    title = f"{df_for_ID.loc[0, 'ID']} - Heatmaps messages alignment from HFEV1, ecFEV1 to AR and FEF25-75 to AR"
    fig.update_layout(
        font=dict(size=6), height=600, width=len(df_for_ID) + 400, title=title
    )
    # Add Date on x axis
    fig.update_xaxes(title_text="Date", tickangle=45)
    fig.update_yaxes(title_text="Airway resistance (%)")

    if save:
        fig.write_image(
            dh.get_path_to_main() + f"/PlotsBreathe/Cutset_conditioning/{title}.png",
            scale=3,
        )
    else:
        fig.show()

    return fig, FEV_m_arr, FEF2575_m_arr


interesting_ids = [
    "132",
    "146",
    "177",
    "180",
    "202",
    "527",
    "117",
    "131",
    "134",
    "191",
    "139",
    "253",
    "101",
    # Also from consec values
    "405",
    "272",
    "201",
    "203",
]

df[df.ID.isin(interesting_ids)].groupby("ID").apply(can_messages_align_for_ID_heatmap)

# df_for_ID = df[df.ID == "191"]
# fig, FEV_m_arr, FEF2575_m_arr = can_messages_align_for_ID_heatmap(df_for_ID, save=False)


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide



ID
101    (Figure({\n    'data': [{'colorscale': [[0.0, ...
117    (Figure({\n    'data': [{'colorscale': [[0.0, ...
131    (Figure({\n    'data': [{'colorscale': [[0.0, ...
132    (Figure({\n    'data': [{'colorscale': [[0.0, ...
134    (Figure({\n    'data': [{'colorscale': [[0.0, ...
139    (Figure({\n    'data': [{'colorscale': [[0.0, ...
146    (Figure({\n    'data': [{'colorscale': [[0.0, ...
177    (Figure({\n    'data': [{'colorscale': [[0.0, ...
180    (Figure({\n    'data': [{'colorscale': [[0.0, ...
191    (Figure({\n    'data': [{'colorscale': [[0.0, ...
201    (Figure({\n    'data': [{'colorscale': [[0.0, ...
202    (Figure({\n    'data': [{'colorscale': [[0.0, ...
203    (Figure({\n    'data': [{'colorscale': [[0.0, ...
253    (Figure({\n    'data': [{'colorscale': [[0.0, ...
272    (Figure({\n    'data': [{'colorscale': [[0.0, ...
405    (Figure({\n    'data': [{'colorscale': [[0.0, ...
527    (Figure({\n    'data': [{'colorscale': [[0.0, ...
dtype: object

# Fusing the weights

### Evaluate computational speedup by avoiding to calculate doublons

In [7]:
(
    model,
    inf_alg,
    HFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
    ecFEF2575prctecFEV1,
) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(120, 12, "Male")

In [65]:
def get_speedup_prct_for_id(df_for_ID):
    # How many entries have the same bin in ecFEV1 and ecFEF2575%ecFEV1
    # This trick wouldn't improve the computation time much
    n_data_no_duplicates = len(
        df_for_ID.groupby(
            ["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1", "idx O2 saturation (%)"]
        )
        .size()
        .sort_values(ascending=False)
    )
    # 10 s for 10 entries
    time_per_entry = 12.5 / 10
    return (
        len(df_for_ID),
        len(df_for_ID) * time_per_entry,
        n_data_no_duplicates,
        n_data_no_duplicates * time_per_entry,
    )


times = df.groupby("ID").apply(get_speedup_prct_for_id).sort_values(ascending=False)

t_before = 0
t_after = 0
n_before = 0
n_after = 0
for i in range(len(times)):
    n_ops_b, t_ops_b, n_ops_a, t_ops_a = times.values[i]
    t_before += t_ops_b
    t_after += t_ops_a
    n_before += n_ops_b
    n_after += n_ops_a
print(
    f"Before: {t_before/3600:.0f} h, After: {t_after/3600:.0f} h. Speedup: {t_before/t_after:.2f}"
)
print(f"Before: {n_before}, After: {n_after:}. Speedup: {n_before/n_after:.2f}")

times

Before: 14 h, After: 9 h. Speedup: 1.56
Before: 41260, After: 26456. Speedup: 1.56


ID
101      (1680, 2100.0, 206, 257.5)
123     (1128, 1410.0, 601, 751.25)
240    (1101, 1376.25, 509, 636.25)
133      (1066, 1332.5, 502, 627.5)
405     (1035, 1293.75, 234, 292.5)
                   ...             
225              (1, 1.25, 1, 1.25)
213              (1, 1.25, 1, 1.25)
516              (1, 1.25, 1, 1.25)
160              (1, 1.25, 1, 1.25)
428              (1, 1.25, 1, 1.25)
Length: 352, dtype: object

### Actually fusing weights

In [30]:
def compute_log_p_D_given_M_per_entry_per_HFEV1_obs(
    df_for_ID_in, debug=False, save=False, speedup=True, ar_prior="uniform"
):
    df_for_ID_in = df_for_ID_in.copy().reset_index(drop=True)
    id = df_for_ID_in.loc[0, "ID"]
    height = df_for_ID_in.loc[0, "Height"]
    age = df_for_ID_in.loc[0, "Age"]
    sex = df_for_ID_in.loc[0, "Sex"]

    (
        _,
        inf_alg,
        HFEV1,
        ecFEV1,
        _,
        _,
        _,
        _,
        _,
        _,
        ecFEF2575prctecFEV1,
    ) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(
        height, age, sex, ar_prior=ar_prior
    )

    # HFEV1 can't be above max observed ecFEV1
    HFEV1_obs_list = HFEV1.midbins[
        HFEV1.midbins - HFEV1.bin_width / 2 >= df_for_ID_in.ecFEV1.max()
    ]
    print(
        f"ID {id} - Number of HFEV1 specific models: {len(HFEV1_obs_list)}, max ecFEV1: {df_for_ID_in.ecFEV1.max()}, first possible bin for HFEV1: {HFEV1.get_bin_for_value(HFEV1_obs_list[0])[0]}"
    )

    N = len(df_for_ID_in)
    df_for_ID = df_for_ID_in.copy()

    # Speed up code by removing duplicates and adding them later on
    if speedup:
        print(f"{N} entries before speedup")
        df_for_ID = df_for_ID.sort_values(
            by=["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1"], ascending=False
        )
        df_duplicates = (
            df_for_ID.groupby(["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1"])
            .size()
            .reset_index()
        )
        df_duplicates.columns = [
            "idx ecFEV1 (L)",
            "idx ecFEF2575%ecFEV1",
            "n duplicates",
        ]
        df_duplicates = df_duplicates.sort_values(
            by=["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1"], ascending=False
        ).reset_index(drop=True)
        n_dups = df_duplicates["n duplicates"].values
        # Keep only the first entry for each pair of ecFEV1 and ecFEF2575%ecFEV1]
        # Create df_for_ID without duplicates
        df_for_ID = df_for_ID.drop_duplicates(
            subset=["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1"], keep="first"
        ).reset_index(drop=True)
        print(f"{len(df_for_ID)} entries after speedup")
        print(
            f"Number of duplicates {N - len(df_for_ID)}, speedup removes {(N-len(df_for_ID))/N*100:.2f}% of entries"
        )

    H = len(HFEV1_obs_list)
    N_maybe_no_dups = len(df_for_ID) if speedup else N
    log_p_D_given_M = np.zeros((N_maybe_no_dups, H))
    AR_dist_given_M_matrix = np.zeros((N_maybe_no_dups, AR.card, H))

    # Get the joint probability of ecFEV1 and ecFEF2575 given the model for this individual
    # For each entry
    for n, row in df_for_ID.iterrows():
        if debug:
            print(f"Processing row {n+1}/{N_maybe_no_dups}")

        # For each model given an HFEV1 observation
        for h, HFEV1_obs in enumerate(HFEV1_obs_list):

            # Getting the joint probabilities of ecFEF2575 and ecFEV1 under the model
            res1, _ = ih.infer_on_factor_graph(
                inf_alg,
                [ecFEV1, ecFEF2575prctecFEV1],
                [[HFEV1, HFEV1_obs]],
                get_messages=True,
            )
            dist_ecFEV1 = res1[ecFEV1.name].values

            # Observe both HFEV1 and ecFEV1 to compute the joint probability
            # P(ecFEV1, ecFEF2575 | HFEV1) = P(ecFEV1 | HFEV1) * P( ecFEF2575 | HFEV1, ecFEV1)
            res2, _ = ih.infer_on_factor_graph(
                inf_alg,
                [ecFEF2575prctecFEV1],
                [[HFEV1, HFEV1_obs], [ecFEV1, row.ecFEV1]],
                get_messages=True,
            )
            dist_ecFEF2575prctecFEV1 = res2[ecFEF2575prctecFEV1.name].values

            res3, _ = ih.infer_on_factor_graph(
                inf_alg,
                [AR],
                [
                    [HFEV1, HFEV1_obs],
                    [ecFEV1, row.ecFEV1],
                    [ecFEF2575prctecFEV1, row["ecFEF2575%ecFEV1"]],
                ],
                get_messages=True,
            )
            dist_AR = res3[AR.name].values

            # The probability of the data given the model is the expectation of the data given the model
            idx_obs_ecFEV1 = ecFEV1.get_bin_for_value(row.ecFEV1)[1]
            idx_obs_ecFEF2575 = ecFEF2575prctecFEV1.get_bin_for_value(
                row["ecFEF2575%ecFEV1"]
            )[1]

            # Get the probability of the data given the model
            p_ecFEV1 = dist_ecFEV1[idx_obs_ecFEV1]
            p_ecFEF2575 = dist_ecFEF2575prctecFEV1[idx_obs_ecFEF2575]

            # Save information for this round
            AR_dist_given_M_matrix[n, :, h] = dist_AR
            log_p_D_given_M[n, h] = np.log(p_ecFEV1) + np.log(p_ecFEF2575)

    if debug:
        print("log(P(D|M)), first row", log_p_D_given_M[0, :])

    if speedup:
        # Put back the duplicates
        # Repeat each element in the array by the number in the array dups
        log_p_D_given_M = np.repeat(log_p_D_given_M, n_dups, axis=0)
        AR_dist_given_M_matrix = np.repeat(AR_dist_given_M_matrix, n_dups, axis=0)
        if debug:
            print("P(D|M), first row, after applying duplicates", log_p_D_given_M[:, 0])

    # For each HFEV1 model, given HFEV1_obs_list, we compute the log probability of the model given the data
    # log(P(M|D)) = 1/N * sum_n log(P(D|M)) + Cn_avg + log(P(M))
    log_p_M_given_D = np.zeros(H)
    for h, HFEV1_obs in enumerate(HFEV1_obs_list):
        log_p_M = np.log(HFEV1.cpt[HFEV1.get_bin_for_value(HFEV1_obs)[1]])
        log_p_M_given_D[h] = np.sum(log_p_D_given_M[:, h]) + log_p_M

    # Exponentiating very negative numbers gives too small numbers
    # Setting the highest number to 1
    shift = 1 - log_p_M_given_D.max()
    log_p_M_given_D_shifted = log_p_M_given_D + shift

    # Exponentiate and normalise
    p_M_given_D = np.exp(log_p_M_given_D_shifted)
    p_M_given_D = p_M_given_D / p_M_given_D.sum()

    # Fill the p(M|D) array with zeros on the left, where the HFEV1_obs < max ecFEV1
    p_M_given_D_full = np.zeros(HFEV1.card)
    HFEV1_obs_idx = [
        HFEV1.get_bin_for_value(HFEV1_obs)[1] for HFEV1_obs in HFEV1_obs_list
    ]
    p_M_given_D_full[HFEV1_obs_idx] = p_M_given_D

    # Add plot
    layout = [
        [{"type": "scatter", "rowspan": 1, "colspan": 1}, None, None],
        [{"type": "heatmap", "rowspan": 3, "colspan": 3}, None, None],
        [None, None, None],
        [None, None, None],
    ]
    fig = make_subplots(
        rows=np.shape(layout)[0],
        cols=np.shape(layout)[1],
        specs=layout,
        vertical_spacing=0.1,
    )

    # Add HFEV1 posterior
    ih.plot_histogram(fig, HFEV1, p_M_given_D_full, 0, 6, 1, 1, annot=True)

    # Add heatmap with AR posteriors
    AR_dist_matrix = np.matmul(AR_dist_given_M_matrix, p_M_given_D)
    df1 = pd.DataFrame(
        data=AR_dist_matrix,
        columns=AR.get_bins_str(),
        index=df_for_ID_in["Date Recorded"].apply(
            lambda date: date.strftime("%Y-%m-%d")
        ),
    )
    colorscale = [
        [0, "white"],
        [0.01, "red"],
        [0.05, "yellow"],
        [0.1, "cyan"],
        [0.6, "blue"],
        [1, "black"],
    ]

    fig.add_trace(
        go.Heatmap(z=df1.T, x=df1.index, y=df1.columns, coloraxis="coloraxis1"),
        row=2,
        col=1,
    )

    speedup = " (with speedup)" if speedup else ""

    title = f"{id} - Posterior HFEV1 after fusing all P(M_h|D)<br>AR prior: {ar_prior}{speedup}"
    fig.update_layout(
        font=dict(size=12),
        height=700,
        width=1200,
        title=title,
        coloraxis1=dict(
            colorscale=colorscale,
            colorbar_x=1,
            colorbar_y=0.36,
            # colorbar_thickness=23,
            colorbar_len=0.77,
        ),
    )
    # Add Date on x axis
    fig.update_xaxes(title_text=HFEV1.name, row=1, col=1)
    fig.update_yaxes(title_text="p", row=1, col=1)
    fig.update_yaxes(title_text=AR.name, row=2, col=1)
    fig.update_xaxes(
        title_text="Date",
        row=2,
        col=1,
        nticks=50,
        type="category",
    )

    if save:
        fig.write_image(
            dh.get_path_to_main() + f"/PlotsBreathe/Cutset_conditioning/{title}.png",
            scale=3,
        )
    else:
        fig.show()

    return fig, p_M_given_D_full, p_M_given_D, AR_dist_given_M_matrix

In [29]:
# ar_prior = "breathe (2 days model, ecFEV1, ecFEF25-75)"
ar_prior = "uniform"
# p_M_given_D_plot, fig = compute_log_p_D_given_M_per_entry_per_HFEV1_obs(dftmp, debug=False, save=False, speedup=True, ar_prior=ar_prior)

dftmp = df[df.ID == "527"]
fig, p_M_given_D_full, p_M_given_D, AR_dist_given_M_matrix = (
    compute_log_p_D_given_M_per_entry_per_HFEV1_obs(
        dftmp, debug=False, save=False, speedup=True, ar_prior=ar_prior
    )
)

ID 527 - Number of HFEV1 specific models: 100, max ecFEV1: 0.92, first possible bin for HFEV1: [1.00; 1.05)
5 entries before speedup
4 entries after speedup
Number of duplicates 1, speedup removes 20.00% of entries


# Infering HFEV1, AR, IA through time while observing ecFEV1, ecFEF25-75, SpO2

In [5]:
import os
import multiprocessing

# Number of CPU cores
num_cores = os.cpu_count()  # or multiprocessing.cpu_count()
print(f"Number of CPU cores: {num_cores}")

Number of CPU cores: 10


In [6]:
import time
import itertools

In [16]:
def compute_log_p_D_given_M_per_entry_per_HFEV1_obs(
    df_for_ID_in, debug=False, save=False, speedup=True, ar_prior="uniform"
):
    df_for_ID_in = df_for_ID_in.copy().reset_index(drop=True)
    id = df_for_ID_in.loc[0, "ID"]
    height = df_for_ID_in.loc[0, "Height"]
    age = df_for_ID_in.loc[0, "Age"]
    sex = df_for_ID_in.loc[0, "Sex"]

    (
        _,
        inf_alg,
        HFEV1,
        ecFEV1,
        AR,
        HO2Sat,
        O2SatFFA,
        IA,
        UO2Sat,
        O2Sat,
        ecFEF2575prctecFEV1,
    ) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(
        height, age, sex, ar_prior=ar_prior
    )

    # HFEV1 can't be above max observed ecFEV1
    HFEV1_obs_list = HFEV1.midbins[
        HFEV1.midbins - HFEV1.bin_width / 2 >= df_for_ID_in.ecFEV1.max()
    ]
    # Create tuples of obs (HFEV1, HO2Sat) to observe
    H_obs_list = [
        list(zip([HFEV1_obs] * HO2Sat.card, HO2Sat.midbins))
        for HFEV1_obs in HFEV1_obs_list
    ]
    # Flatten the list
    H_obs_list = list(itertools.chain(*H_obs_list))

    print(
        f"ID {id} - Number of HFEV1, HO2Sat specific models: {len(H_obs_list)}, max ecFEV1: {df_for_ID_in.ecFEV1.max()}, first possible bin for HFEV1: {HFEV1.get_bin_for_value(HFEV1_obs_list[0])[0]}"
    )

    N = len(df_for_ID_in)
    df_for_ID = df_for_ID_in.copy()

    # Speed up code by removing duplicates and adding them later on
    if speedup:
        print(f"{N} entries before speedup")
        df_for_ID = df_for_ID.sort_values(
            by=["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1", "idx O2 saturation (%)"],
            ascending=False,
        )
        df_duplicates = (
            df_for_ID.groupby(
                ["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1", "idx O2 saturation (%)"]
            )
            .size()
            .reset_index()
        )
        df_duplicates.columns = [
            "idx ecFEV1 (L)",
            "idx ecFEF2575%ecFEV1",
            "idx O2 saturation (%)",
            "n duplicates",
        ]
        df_duplicates = df_duplicates.sort_values(
            by=["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1", "idx O2 saturation (%)"],
            ascending=False,
        ).reset_index(drop=True)
        n_dups = df_duplicates["n duplicates"].values
        # Keep only the first entry for each pair of ecFEV1 and ecFEF2575%ecFEV1]
        # Create df_for_ID without duplicates
        df_for_ID = df_for_ID.drop_duplicates(
            subset=["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1", "idx O2 saturation (%)"],
            keep="first",
        ).reset_index(drop=True)
        print(f"{len(df_for_ID)} entries after speedup")
        print(
            f"Number of duplicates {N - len(df_for_ID)}, speedup removes {(N-len(df_for_ID))/N*100:.2f}% of entries"
        )

    H = len(H_obs_list)
    N_maybe_no_dups = len(df_for_ID) if speedup else N
    log_p_D_given_M = np.zeros((N_maybe_no_dups, H))
    AR_dist_given_M_matrix = np.zeros((N_maybe_no_dups, AR.card, H))

    # Get the joint probability of ecFEV1 and ecFEF2575 given the model for this individual
    # For each entry
    tic = time.time()
    for n, row in df_for_ID.iterrows():
        if debug:
            print(f"Processing row {n+1}/{N_maybe_no_dups}")

        # For each model given an HFEV1 observation
        for h, (HFEV1_obs, HO2Sat_obs) in enumerate(H_obs_list):

            # Getting the joint probabilities of ecFEF2575 and ecFEV1 under the model
            res1 = ih.infer_on_factor_graph(
                inf_alg,
                [ecFEV1],
                [[HFEV1, HFEV1_obs], [HO2Sat, HO2Sat_obs]],
            )
            dist_ecFEV1 = res1[ecFEV1.name].values

            # Observe both HFEV1 and ecFEV1 to compute the joint probability
            # P(ecFEV1, ecFEF2575, O2sat | HFEV1) = P(ecFEV1 | HFEV1) * P( ecFEF2575 | HFEV1, ecFEV1) * P( O2Sat | HFEV1, ecFEV1, ecFEF2575)
            res2 = ih.infer_on_factor_graph(
                inf_alg,
                [ecFEF2575prctecFEV1, AR],
                [[HFEV1, HFEV1_obs], [HO2Sat, HO2Sat_obs], [ecFEV1, row.ecFEV1]],
            )
            dist_ecFEF2575prctecFEV1 = res2[ecFEF2575prctecFEV1.name].values

            # res3, _ = ih.infer_on_factor_graph(
            #     inf_alg,
            #     [O2Sat],
            #     [
            #         [HFEV1, HFEV1_obs],
            #        [HO2Sat, HO2Sat_obs],
            #         [ecFEV1, row.ecFEV1],
            #         [ecFEF2575prctecFEV1, row["ecFEF2575%ecFEV1"]],
            #     ],
            #     get_messages=True,
            # )
            # dist_O2Sat = res3[O2Sat.name].values

            # res4, _ = ih.infer_on_factor_graph(
            #     inf_alg,
            #     [AR],
            #     [
            #         [HFEV1, HFEV1_obs],
            #         [HO2Sat, HO2Sat_obs],
            #         [ecFEV1, row.ecFEV1],
            #         [ecFEF2575prctecFEV1, row["ecFEF2575%ecFEV1"]],
            #         # [O2Sat, row["O2 Saturation"]],
            #     ],
            #     get_messages=True,
            # )

            # Use previously inferred AR, and add message from FEF25-75
            m_to_factor = ecFEF2575prctecFEV1.get_point_message(row["ecFEF2575%ecFEV1"])
            factor_to_AR = np.matmul(m_to_factor, ecFEF2575prctecFEV1.cpt)
            factor_to_AR = factor_to_AR / factor_to_AR.sum()

            dist_AR = res2[AR.name].values * factor_to_AR
            dist_AR = dist_AR / dist_AR.sum()

            # The probability of the data given the model is the expectation of the data given the model
            idx_obs_ecFEV1 = ecFEV1.get_bin_for_value(row.ecFEV1)[1]
            idx_obs_ecFEF2575 = ecFEF2575prctecFEV1.get_bin_for_value(
                row["ecFEF2575%ecFEV1"]
            )[1]
            # idx_obs_O2Sat = O2Sat.get_bin_for_value(row["O2 Saturation"])[1]

            # Get the probability of the data given the model
            p_ecFEV1 = dist_ecFEV1[idx_obs_ecFEV1]
            p_ecFEF2575 = dist_ecFEF2575prctecFEV1[idx_obs_ecFEF2575]
            # p_O2Sat = dist_O2Sat[idx_obs_O2Sat]

            # Save information for this round
            AR_dist_given_M_matrix[n, :, h] = dist_AR
            log_p_D_given_M[n, h] = np.log(p_ecFEV1) + np.log(
                p_ecFEF2575
            )  # + np.log(p_O2Sat)
    toc = time.time()
    print(f"Time for {N_maybe_no_dups} entries: {toc-tic:.2f} s")

    if debug:
        print("log(P(D|M)), first row", log_p_D_given_M[0, :])

    if speedup:
        # Put back the duplicates
        # Repeat each element in the array by the number in the array dups
        log_p_D_given_M = np.repeat(log_p_D_given_M, n_dups, axis=0)
        AR_dist_given_M_matrix = np.repeat(AR_dist_given_M_matrix, n_dups, axis=0)
        if debug:
            print("P(D|M), first row, after applying duplicates", log_p_D_given_M[:, 0])

    # For each HFEV1 model, given HFEV1_obs_list, we compute the log probability of the model given the data
    # log(P(M|D)) = 1/N * sum_n log(P(D|M)) + Cn_avg + log(P(M))
    log_p_M_given_D = np.zeros(H)
    for h, (HFEV1_obs, HO2Sat_obs) in enumerate(H_obs_list):
        log_p_M_hfev1 = np.log(HFEV1.cpt[HFEV1.get_bin_for_value(HFEV1_obs)[1]])
        log_p_M_ho2sat = np.log(HO2Sat.cpt[HO2Sat.get_bin_for_value(HO2Sat_obs)[1]])
        log_p_M_given_D[h] = (
            np.sum(log_p_D_given_M[:, h]) + log_p_M_hfev1 + log_p_M_ho2sat
        )

    # Exponentiating very negative numbers gives too small numbers
    # Setting the highest number to 1
    shift = 1 - log_p_M_given_D.max()
    log_p_M_given_D_shifted = log_p_M_given_D + shift

    # Exponentiate and normalise
    p_M_given_D = np.exp(log_p_M_given_D_shifted)
    p_M_given_D = p_M_given_D / p_M_given_D.sum()
    AR_dist_matrix = np.matmul(AR_dist_given_M_matrix, p_M_given_D)

    # Reshape P(M|D) into a 2D array for each HFEV1_obs, HO2Sat_obs
    p_M_given_D = p_M_given_D.reshape((len(HFEV1_obs_list), HO2Sat.card))

    # Fill the p(M|D) array with zeros on the left, where the HFEV1_obs < max ecFEV1
    n_impossible_hfev1_values = HFEV1.card - len(HFEV1_obs_list)
    p_M_given_D_full = np.vstack(
        [np.zeros((n_impossible_hfev1_values, HO2Sat.card)), p_M_given_D]
    )

    # Get the probability of HFEV1
    p_HFEV1_given_D = p_M_given_D_full.sum(axis=1)

    # Add plot
    layout = [
        [{"type": "scatter", "rowspan": 1, "colspan": 1}, None, None],
        [{"type": "heatmap", "rowspan": 3, "colspan": 3}, None, None],
        [None, None, None],
        [None, None, None],
    ]
    fig = make_subplots(
        rows=np.shape(layout)[0],
        cols=np.shape(layout)[1],
        specs=layout,
        vertical_spacing=0.1,
    )

    # Add HFEV1 posterior
    ih.plot_histogram(fig, HFEV1, p_HFEV1_given_D, 0, 6, 1, 1, annot=True)

    # Add heatmap with AR posteriors
    df1 = pd.DataFrame(
        data=AR_dist_matrix,
        columns=AR.get_bins_str(),
        index=df_for_ID_in["Date Recorded"].apply(
            lambda date: date.strftime("%Y-%m-%d")
        ),
    )
    colorscale = [
        [0, "white"],
        [0.01, "red"],
        [0.05, "yellow"],
        [0.1, "cyan"],
        [0.6, "blue"],
        [1, "black"],
    ]

    fig.add_trace(
        go.Heatmap(z=df1.T, x=df1.index, y=df1.columns, coloraxis="coloraxis1"),
        row=2,
        col=1,
    )

    speedup = " (with speedup)" if speedup else ""

    title = f"{id} - Posterior HFEV1 after fusing all P(M_h|D)<br>AR prior: {ar_prior}{speedup}"
    fig.update_layout(
        font=dict(size=12),
        height=700,
        width=1200,
        title=title,
        coloraxis1=dict(
            colorscale=colorscale,
            colorbar_x=1,
            colorbar_y=0.36,
            # colorbar_thickness=23,
            colorbar_len=0.77,
        ),
    )
    # Add Date on x axis
    fig.update_xaxes(title_text=HFEV1.name, row=1, col=1)
    fig.update_yaxes(title_text="p", row=1, col=1)
    fig.update_yaxes(title_text=AR.name, row=2, col=1)
    fig.update_xaxes(
        title_text="Date",
        row=2,
        col=1,
        nticks=50,
        type="category",
    )

    if save:
        fig.write_image(
            dh.get_path_to_main() + f"/PlotsBreathe/Cutset_conditioning/{title}.png",
            scale=3,
        )
    else:
        fig.show()

    return fig, p_M_given_D_full, p_M_given_D, AR_dist_given_M_matrix
    # return

In [None]:
ar_prior = "breathe (2 days model, ecFEV1, ecFEF25-75)"
ar_prior = "uniform"
# p_M_given_D_plot, fig = compute_log_p_D_given_M_per_entry_per_HFEV1_obs(dftmp, debug=False, save=False, speedup=True, ar_prior=ar_prior)

dftmp = df[df.ID == "527"]
fig, p_M_given_D_full, p_M_given_D, AR_dist_given_M_matrix = (
    compute_log_p_D_given_M_per_entry_per_HFEV1_obs(
        dftmp, debug=False, save=True, speedup=True, ar_prior=ar_prior
    )
)

ID 527 - Number of HFEV1, HO2Sat specific models: 2000, max ecFEV1: 0.92, first possible bin for HFEV1: [1.00; 1.05)
5 entries before speedup
5 entries after speedup
Number of duplicates 0, speedup removes 0.00% of entries


In [15]:
UO2Sat.cpt

NameError: name 'UO2Sat' is not defined