Idea of cutset conditioning: it's a way to run exact inference on a model with loops. You cut the loop by observing one of the variables in the loop to all the possible states, then fuse the results in a smart way.

 Cutset Conditioning is a technique for solving nearly-tree-structured CSPs in which some variables are assigned to separately from the rest, removed from the constraint graph, and leaving a tree-structured CSP for those remaining.

 Cutsets are some set of variables that are cut (severing edges) from the original constraint graph and solved separately.

 Conditioning is the process of assigning a value to some variable in a cutset, performing forward checking on its neighbor domains before cutting, and finally, severing it from the original graph.

https://forns.lmu.build/classes/spring-2019/cmsi-282/lecture-13M.html#backtracking++

In [1]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

import src.data.breathe_data as bd
import src.data.helpers as dh
import src.inference.helpers as ih
import src.models.var_builders as var_builders
import src.models.builders as mb
import src.inf_cutset_conditioning.cutset_cond_algs as algs

Figure per entry that has the AR from obs FEF2575 on top and on the bottom the point mass AR obtained by repeating model runs with several point mass HFEV1 (3, 3.5, 4, 4.5, 5, etc)

In [2]:
df = bd.load_meas_from_excel("BR_O2_FEV1_FEF2575_conservative_smoothing_with_idx")
# df = bd.load_meas_from_excel("BR_O2_FEV1_FEF2575_with_idx")

INFO:root:* Checking for same day measurements *


# Visualisations of the alignment between the message from FEF25-75 and from FEV1/HFEV1 factors to AR

### Two plots

In [10]:
# With each run I should retrieve
# 1/ the message from FEF25-75%FEFV1 to AR
# 2/ the point mass message from the factor ecFEV1, HFEV1 to AR
# Use the point in time model, there is no shared variables.


def can_messages_align_for_ID(df_for_ID):
    df_for_ID.reset_index(inplace=True, drop=True)
    height = df_for_ID.loc[0, "Height"]
    age = df_for_ID.loc[0, "Age"]
    sex = df_for_ID.loc[0, "Sex"]
    id = df_for_ID.loc[0, "ID"]
    (
        model,
        inf_alg,
        HFEV1,
        ecFEV1,
        AR,
        HO2Sat,
        O2SatFFA,
        IA,
        UO2Sat,
        O2Sat,
        ecFEF2575prctecFEV1,
    ) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(height, age, sex)

    FEV_to_AR_key = "['ecFEV1 (L)', 'Healthy FEV1 (L)', 'Airway resistance (%)'] -> Airway resistance (%)"
    FEF2575_to_AR_key = (
        "['ecFEF25-75 % ecFEV1 (%)', 'Airway resistance (%)'] -> Airway resistance (%)"
    )

    HFEV1_obs_list = [2, 2.5, 3, 3.5, 4, 4.5, 5, 5.5]
    colour_list = px.colors.sample_colorscale(
        "YlGnBu", [i / (len(HFEV1_obs_list) - 1) for i in range(len(HFEV1_obs_list))]
    )

    df_for_ID = df_for_ID.sort_values(by="ecFEF2575%ecFEV1", ascending=True)
    # Take 4 idx in 5, 30, 60, 95 percentiles of the data
    idx_list = list((len(df_for_ID) * np.array([0.05, 0.5, 0.95])).astype(int))
    df_for_ID_sub = df_for_ID.iloc[idx_list, :]

    res_per_idx = []

    for idx in df_for_ID_sub.index:
        FEV1_obs = df_for_ID.loc[idx, "ecFEV1"]
        FEF2575prctFEV1_obs = df_for_ID.loc[idx, "ecFEF2575%ecFEV1"]
        FEV_m_list = []

        # Query AR
        for HFEV1_obs in HFEV1_obs_list:
            # HFEV1_obs must be > ecFEV1_obs
            evidence = [
                [ecFEV1, FEV1_obs],
                [ecFEF2575prctecFEV1, FEF2575prctFEV1_obs],
                [HFEV1, HFEV1_obs],
            ]
            _, messages = ih.infer_on_factor_graph(
                inf_alg, [AR], evidence, get_messages=True
            )

            FEV_m_list.append(messages[FEV_to_AR_key])
            FEF2575_m = messages[FEF2575_to_AR_key]

        res_per_idx.append([FEV1_obs, FEF2575prctFEV1_obs, FEV_m_list, FEF2575_m])

    fig = make_subplots(rows=6, cols=1, vertical_spacing=0.05)
    plot_row = 1
    for FEV1_obs, FEF2575prctFEV1_obs, FEV_m_list, FEF2575_m in res_per_idx:

        for HFEV1_obs, FEV_m, colour in zip(HFEV1_obs_list, FEV_m_list, colour_list):
            ih.plot_histogram(
                fig,
                AR,
                FEV_m,
                AR.a,
                AR.b,
                plot_row,
                1,
                name=f"HFEV1 = {HFEV1_obs}",
                annot=False,
            )
            # Change the last trace's colour
            fig.data[-1].marker.color = colour
            # Hide legend if plot_row > 1
            if plot_row > 1:
                fig.data[-1].showlegend = False

        ih.plot_histogram(
            fig,
            AR,
            FEF2575_m,
            AR.a,
            AR.b,
            plot_row + 1,
            1,
            annot=False,
            title=AR.name,
            colour="grey",
        )
        # hide this last trace's legend
        fig.data[-1].showlegend = False
        # Add message from ecFEV1/HFEV1 factor on y axis row 1 title
        fig.update_yaxes(title_text=f"ecFEV1<br>{FEV1_obs:.2f}L", row=plot_row, col=1)
        fig.update_yaxes(
            title_text=f"ecFEF25-75%ecFEV1<br>{FEF2575prctFEV1_obs:.2f}%",
            row=plot_row + 1,
            col=1,
        )
        plot_row += 2

    # Reduce font size and margins
    title = f"ID {id} - Can points mass messages from HFEV1, ecFEV1 align with messages from FEF25-75"
    # Reduce margins between plots
    fig.update_layout(
        font=dict(size=8),
        margin=dict(l=10, r=10, t=30, b=10),
        height=750,
        width=600,
        barmode="overlay",
        bargap=0.1,
        title=title,
    )
    fig.update_xaxes(title_standoff=6)

    fig.write_image(
        dh.get_path_to_main() + f"/PlotsBreathe/Cutset_conditioning/{title}.pdf"
    )
    # fig.show()


interesting_ids = [
    "132",
    "146",
    "177",
    "180",
    "202",
    "527",
    "117",
    "131",
    "134",
    "191",
    "139",
    "253",
    "101",
    # Also from consec values
    "405",
    "272",
    "201",
    "203",
]

# df[df.ID.isin(interesting_ids)].groupby("ID").apply(can_messages_align_for_ID)

df_for_ID = df[df.ID == "101"]
can_messages_align_for_ID(df_for_ID)

### Heatmaps of FEF2575 messages vs FEV1 messages for different HFEV1

In [6]:
# With each run I should retrieve
# 1/ the message from FEF25-75%FEFV1 to AR
# 2/ the point mass message from the factor ecFEV1, HFEV1 to AR
# Use the point in time model, there is no shared variables.


def can_messages_align_for_ID_heatmap(df_for_ID, save=True):
    df_for_ID.reset_index(inplace=True, drop=True)
    height = df_for_ID.loc[0, "Height"]
    age = df_for_ID.loc[0, "Age"]
    sex = df_for_ID.loc[0, "Sex"]
    (
        model,
        inf_alg,
        HFEV1,
        ecFEV1,
        AR,
        HO2Sat,
        O2SatFFA,
        IA,
        UO2Sat,
        O2Sat,
        ecFEF2575prctecFEV1,
    ) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(height, age, sex)

    FEV_to_AR_key = "['ecFEV1 (L)', 'Healthy FEV1 (L)', 'Airway resistance (%)'] -> Airway resistance (%)"
    FEF2575_to_AR_key = (
        "['ecFEF25-75 % ecFEV1 (%)', 'Airway resistance (%)'] -> Airway resistance (%)"
    )

    HFEV1_obs_list = [2, 3, 4, 5]
    # Compare obs list to min obs fev1
    min_obs_fev1 = df_for_ID.ecFEV1.min()
    HFEV1_obs_list = [
        HFEV1_obs for HFEV1_obs in HFEV1_obs_list if HFEV1_obs > min_obs_fev1
    ]

    # Dates on the xaxis, AR on the y axis
    FEV_m_arr = np.zeros((AR.card, len(df_for_ID)))
    FEF2575_m_arr = np.zeros((AR.card, len(df_for_ID)))

    for i, row in df_for_ID.iterrows():
        FEV1_obs = row.ecFEV1
        FEF2575prctFEV1_obs = row["ecFEF2575%ecFEV1"]

        # Query AR
        FEV_m_one_day = np.zeros(AR.card)
        for HFEV1_obs in HFEV1_obs_list:
            # HFEV1_obs must be > ecFEV1_obs
            evidence = [
                [ecFEV1, FEV1_obs],
                [ecFEF2575prctecFEV1, FEF2575prctFEV1_obs],
                [HFEV1, HFEV1_obs],
            ]
            _, messages = ih.infer_on_factor_graph(
                inf_alg, [AR], evidence, get_messages=True
            )

            # Since the messages are "almost" point mass (max over 2 bins)
            # we'll just put the value for the heatmap at the location of the mean
            AR_mean_val = AR.get_mean(messages[FEV_to_AR_key])
            AR_mean_idx = AR.get_bin_for_value(AR_mean_val)[1]
            # Add intensity value at the location of the AR mean
            FEV_m_one_day[AR_mean_idx] = HFEV1_obs

        FEV_m_arr[:, i] = FEV_m_one_day
        fef2575_m = messages[FEF2575_to_AR_key]
        # Make sure the messages are normalised - yes it is the case indeed
        fef2575_m = fef2575_m / fef2575_m.sum()
        FEF2575_m_arr[:, i] = fef2575_m

    df_for_ID["Date"] = pd.to_datetime(df_for_ID["Date Recorded"]).copy()
    df_for_ID["Date"] = df_for_ID["Date"].dt.strftime("%d-%m-%Y")

    fig = go.Figure(
        data=go.Heatmap(
            z=FEF2575_m_arr,
            x=df_for_ID["Date"],
            y=AR.get_bins_str(),
            opacity=0.8,
            colorscale="Blues",
            # Exclude from colour bar
            showscale=False,
        )
    )

    colorscale = [
        [0, "rgba(0, 0, 0, 0)"],  # Transparent for value 0
        [1 / 5, "rgba(0, 0, 0, 0)"],  # Transparent for value 0
        [1 / 5, "rgb(255, 245, 235)"],  # Light orange for value 2
        [2 / 5, "rgb(255, 245, 235)"],  # Light orange for value 2
        [2 / 5, "rgb(254, 230, 206)"],  # Medium-light orange for value 3
        [3 / 5, "rgb(254, 230, 206)"],  # Medium-light orange for value 3
        [3 / 5, "rgb(253, 174, 107)"],  # Medium orange for value 4
        [4 / 5, "rgb(253, 174, 107)"],  # Medium orange for value 4
        [4 / 5, "rgb(241, 105, 19)"],  # Dark orange for value 5
        # [5/5, 'rgb(241, 105, 19)'],  # Dark orange for value 5
        # [5/5, 'rgb(217, 72, 1)'],  # Darker orange for value 6
        [1, "rgb(217, 72, 1)"],  # Darker orange for value 6
    ]

    fig.add_traces(
        go.Heatmap(
            z=FEV_m_arr,
            x=df_for_ID["Date"],
            y=AR.get_bins_str(),
            # Change colour
            colorscale=colorscale,
        )
    )

    title = f"{df_for_ID.loc[0, 'ID']} - Heatmaps messages alignment from HFEV1, ecFEV1 to AR and FEF25-75 to AR"
    fig.update_layout(
        font=dict(size=6), height=600, width=len(df_for_ID) + 400, title=title
    )
    # Add Date on x axis
    fig.update_xaxes(title_text="Date", tickangle=45)
    fig.update_yaxes(title_text="Airway resistance (%)")

    if save:
        fig.write_image(
            dh.get_path_to_main() + f"/PlotsBreathe/Cutset_conditioning/{title}.png",
            scale=3,
        )
    else:
        fig.show()

    return fig, FEV_m_arr, FEF2575_m_arr


interesting_ids = [
    "132",
    "146",
    "177",
    "180",
    "202",
    "527",
    "117",
    "131",
    "134",
    "191",
    "139",
    "253",
    "101",
    # Also from consec values
    "405",
    "272",
    "201",
    "203",
]

df[df.ID.isin(interesting_ids)].groupby("ID").apply(can_messages_align_for_ID_heatmap)

# df_for_ID = df[df.ID == "191"]
# fig, FEV_m_arr, FEF2575_m_arr = can_messages_align_for_ID_heatmap(df_for_ID, save=False)


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide


invalid value encountered in divide



ID
101    (Figure({\n    'data': [{'colorscale': [[0.0, ...
117    (Figure({\n    'data': [{'colorscale': [[0.0, ...
131    (Figure({\n    'data': [{'colorscale': [[0.0, ...
132    (Figure({\n    'data': [{'colorscale': [[0.0, ...
134    (Figure({\n    'data': [{'colorscale': [[0.0, ...
139    (Figure({\n    'data': [{'colorscale': [[0.0, ...
146    (Figure({\n    'data': [{'colorscale': [[0.0, ...
177    (Figure({\n    'data': [{'colorscale': [[0.0, ...
180    (Figure({\n    'data': [{'colorscale': [[0.0, ...
191    (Figure({\n    'data': [{'colorscale': [[0.0, ...
201    (Figure({\n    'data': [{'colorscale': [[0.0, ...
202    (Figure({\n    'data': [{'colorscale': [[0.0, ...
203    (Figure({\n    'data': [{'colorscale': [[0.0, ...
253    (Figure({\n    'data': [{'colorscale': [[0.0, ...
272    (Figure({\n    'data': [{'colorscale': [[0.0, ...
405    (Figure({\n    'data': [{'colorscale': [[0.0, ...
527    (Figure({\n    'data': [{'colorscale': [[0.0, ...
dtype: object

# Fusing the weights

### Evaluate computational speedup by avoiding to calculate doublons

In [7]:
(
    model,
    inf_alg,
    HFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
    ecFEF2575prctecFEV1,
) = mb.o2sat_fev1_fef2575_point_in_time_model_shared_healthy_vars(120, 12, "Male")

In [65]:
def get_speedup_prct_for_id(df_for_ID):
    # How many entries have the same bin in ecFEV1 and ecFEF2575%ecFEV1
    # This trick wouldn't improve the computation time much
    n_data_no_duplicates = len(
        df_for_ID.groupby(
            ["idx ecFEV1 (L)", "idx ecFEF2575%ecFEV1", "idx O2 saturation (%)"]
        )
        .size()
        .sort_values(ascending=False)
    )
    # 10 s for 10 entries
    time_per_entry = 12.5 / 10
    return (
        len(df_for_ID),
        len(df_for_ID) * time_per_entry,
        n_data_no_duplicates,
        n_data_no_duplicates * time_per_entry,
    )


times = df.groupby("ID").apply(get_speedup_prct_for_id).sort_values(ascending=False)

t_before = 0
t_after = 0
n_before = 0
n_after = 0
for i in range(len(times)):
    n_ops_b, t_ops_b, n_ops_a, t_ops_a = times.values[i]
    t_before += t_ops_b
    t_after += t_ops_a
    n_before += n_ops_b
    n_after += n_ops_a
print(
    f"Before: {t_before/3600:.0f} h, After: {t_after/3600:.0f} h. Speedup: {t_before/t_after:.2f}"
)
print(f"Before: {n_before}, After: {n_after:}. Speedup: {n_before/n_after:.2f}")

times

Before: 14 h, After: 9 h. Speedup: 1.56
Before: 41260, After: 26456. Speedup: 1.56


ID
101      (1680, 2100.0, 206, 257.5)
123     (1128, 1410.0, 601, 751.25)
240    (1101, 1376.25, 509, 636.25)
133      (1066, 1332.5, 502, 627.5)
405     (1035, 1293.75, 234, 292.5)
                   ...             
225              (1, 1.25, 1, 1.25)
213              (1, 1.25, 1, 1.25)
516              (1, 1.25, 1, 1.25)
160              (1, 1.25, 1, 1.25)
428              (1, 1.25, 1, 1.25)
Length: 352, dtype: object

### Actually fusing weights

In [30]:
# ar_prior = "breathe (2 days model, ecFEV1, ecFEF25-75)"
ar_prior = "uniform"
# p_M_given_D_plot, fig = compute_log_p_D_given_M_per_entry_per_HFEV1_obs(dftmp, debug=False, save=False, speedup=True, ar_prior=ar_prior)

dftmp = df[df.ID == "527"]
fig, p_M_given_D_full, p_M_given_D, AR_dist_given_M_matrix = (
    compute_log_p_D_given_M_per_entry_per_HFEV1_obs(
        dftmp, debug=False, save=False, speedup=True, ar_prior=ar_prior
    )
)

NameError: name 'compute_log_p_D_given_M_per_entry_per_HFEV1_obs' is not defined

# Infering HFEV1, AR, IA through time while observing ecFEV1, ecFEF25-75, SpO2

In [3]:
(
    HFEV1,
    uecFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
    ecFEF2575prctecFEV1,
) = var_builders.o2sat_fev1_fef2575_point_in_time_model_noise_shared_healthy_vars(
    180, 10, "Male", ecfev1_noise_model_cpt_suffix="_std_0.23"
)

In [4]:
# Create a mock df with 100 entries with id 1, the same columns as the original df but all nan
ndays = 1000
df_mock = pd.DataFrame(
    np.nan,
    index=np.arange(ndays),
    columns=df.columns,
)
df_mock["ID"] = "1"
df_mock["Height"] = 180
df_mock["Age"] = 30
df_mock["Sex"] = "Male"
# Date recorded is a datetime
df_mock["Date Recorded"] = pd.date_range("2020-01-01", periods=ndays, freq="D")

In [None]:
df_for_ID = df[df.ID == "117"]

dftmp, _, _ = dh.find_longest_consec_series(df_for_ID)

ecfef2575_cols = [
    "ecFEF2575%ecFEV1",
    "idx ecFEF2575%ecFEV1",
    "idx ecFEF25-75 % ecFEV1 (%)",
]
ecfev1_cols = [
    "ecFEV1",
    "idx ecFEV1 (L)",
]

# obs = "ecFEV1, ecFEF25-75"
obs = "ecFEV1"
dftmp[ecfef2575_cols] = np.nan
# obs = "no obs"
# dftmp[ecfev1_cols + ecfef2575_cols] = np.nan


(
    fig,
    p_M_given_D,
    log_p_D_given_M,
    AR_given_M_and_D,
    AR_given_M_and_all_D,
    res_dict,
) = algs.run_long_noise_model_through_time(
    dftmp,
    ar_prior="uniform",
    # ar_prior="uniform message to HFEV1",
    # ar_prior="breathe (2 days model, ecFEV1 addmultnoise, ecFEF25-75)",
    ar_change_cpt_suffix="_shift_span_[-20;20]_joint_sampling_3_days_model_ecfev1std0.068",
    # ar_change_cpt_suffix="_shift_span_[-20;20]_joint_sampling_3_days_model_ecfev1std0.23",
    ecfev1_noise_model_suffix="_std_0.068",
    fef2575_cpt_suffix="",
    # save=True
)
dftmp.head()

117 - Time for 69 entries: 18.66 s


Unnamed: 0,ID,Date Recorded,FEV1,O2 Saturation,FEF2575,ecFEV1,ecFEF2575,Sex,Height,Age,...,ecFEV1 % Predicted,FEV1 % Predicted,O2 Saturation % Healthy,ecFEF2575%ecFEV1,idx ecFEV1 (L),idx O2 saturation (%),idx ecFEF2575%ecFEV1,idx ecFEF25-75 % ecFEV1 (%),Prev day,Days elapsed
64,117,2020-03-15,1.6,98,1.22,1.6,1.22,Female,162.0,31,...,50.401515,50.401515,99.828496,,32,48,,,2019-07-22,"237 days, 0:00:00"
65,117,2020-03-16,1.7,98,0.97,1.7,1.22,Female,162.0,31,...,53.55161,53.55161,99.828496,,34,48,,,2020-03-15,"1 day, 0:00:00"
66,117,2020-03-17,1.84,98,0.61,1.84,1.27,Female,162.0,31,...,57.961743,57.961743,99.828496,,36,48,,,2020-03-16,"1 day, 0:00:00"
67,117,2020-03-18,1.79,99,1.27,1.79,1.27,Female,162.0,31,...,56.386695,56.386695,100.847154,,35,49,,,2020-03-17,"1 day, 0:00:00"
68,117,2020-03-19,1.87,99,0.76,1.87,1.27,Female,162.0,31,...,58.906771,58.906771,100.847154,,37,49,,,2020-03-18,"1 day, 0:00:00"


In [5]:
# Compare first, mid and last day ARs
fig = make_subplots(rows=4, cols=1, vertical_spacing=0.12)

ih.plot_histogram(
    fig, AR, AR_given_M_and_D[0], AR.a, AR.b, 1, 1, name="First day", colour="#636EFA"
)
fig.update_xaxes(title_text="Airway resistance on day 1 (%)", row=1, col=1)
day = 2
ih.plot_histogram(
    fig,
    AR,
    AR_given_M_and_D[day - 1],
    AR.a,
    AR.b,
    2,
    1,
    name="Mid day",
    colour="#636EFA",
)
fig.update_xaxes(title_text=f"Airway resistance on day {day} (%)", row=2, col=1)
day = 100
ih.plot_histogram(
    fig,
    AR,
    AR_given_M_and_D[day - 1],
    AR.a,
    AR.b,
    3,
    1,
    name="Last day",
    colour="#636EFA",
)
fig.update_xaxes(title_text=f"Airway resistance on day {day} (%)", row=3, col=1)
day = 1000
ih.plot_histogram(
    fig,
    AR,
    AR_given_M_and_D[day - 1],
    AR.a,
    AR.b,
    4,
    1,
    name="Last day",
    colour="#636EFA",
)
fig.update_xaxes(title_text=f"Airway resistance on day {day} (%)", row=4, col=1)

fig.update_layout(
    font=dict(size=10),
    height=550,
    width=600,
    bargap=0.01,
    title="Airway resistance through time, when no data is observed",
    showlegend=False,
)
fig.update_xaxes(title_standoff=6)
fig.show()

IndexError: index 99 is out of bounds for axis 0 with size 69

In [7]:
# plot a heatmap of log_p_D_given_M, using the same colour scale as above
df1 = pd.DataFrame(
    data=np.exp(log_p_D_given_M),
    columns=HFEV1.get_bins_str(),
    index=dftmp["Date Recorded"].apply(lambda date: date.strftime("%Y-%m-%d")),
)

idx = 30
# idx = 7
fig = make_subplots(
    rows=4,
    subplot_titles=[
        "P(data | model conditionned on HFEV1_bin)",
        f"P(data | model) for date {df1.index[idx]}",
        # 'Evolution of airway resistance with HFEV1 instantiation'
    ],
    vertical_spacing=0.05,
)
colorscale = [
    [0, "white"],
    [0.01, "red"],
    [0.05, "yellow"],
    [0.1, "cyan"],
    [0.6, "blue"],
    [1, "black"],
]

fig.add_trace(
    go.Heatmap(z=df1.T, x=df1.index, y=df1.columns, coloraxis="coloraxis1"),
    row=1,
    col=1,
)
fig.update_yaxes(title=f"HFEV1 state", row=1, col=1)

ih.plot_histogram(fig, HFEV1, np.exp(log_p_D_given_M[idx, :]), 1, 6, 2, 1, annot=True)
# x axis for row 2: HFEV1 posterior when HFEV1 instanciated to 10
fig.update_xaxes(title=f"", row=2, col=1)
fig.update_yaxes(title=f"P(data | model)", row=2, col=1)

# Add p(ecFEV1)
ecFEV1_dist_with_obs = res_dict["ecFEV1"]
for i, obs in enumerate(dftmp["idx ecFEV1 (L)"].values):
    ecFEV1_dist_with_obs[i, obs, :] = 0.4

df3 = pd.DataFrame(
    data=ecFEV1_dist_with_obs[idx, :, :].T,
    # data=AR_given_M_and_all_D[idx, :, :].T,
    columns=ecFEV1.midbins,
    index=HFEV1.midbins,
)

fig.add_trace(
    go.Heatmap(
        z=df3.T,
        x=df3.index,
        y=df3.columns,
        coloraxis="coloraxis2",
    ),
    row=3,
    col=1,
)
fig.update_xaxes(title=f"", nticks=20, row=3, col=1)
fig.update_yaxes(title=f"p(ecFEV1)", nticks=8, row=3, col=1)

# Add p(ecFEF2575%ecFEV1)
ecFEF2575_dist_with_obs = res_dict["ecFEF2575%ecFEV1"]
for i, obs in enumerate(dftmp["idx ecFEF2575%ecFEV1"].values):
    ecFEF2575_dist_with_obs[i, obs, :] = 0.4

df4 = pd.DataFrame(
    data=ecFEF2575_dist_with_obs[idx, :, :].T,
    columns=ecFEF2575prctecFEV1.midbins,
    index=HFEV1.midbins,
)

fig.add_trace(
    go.Heatmap(
        z=df4.T,
        x=df4.index,
        y=df4.columns,
        coloraxis="coloraxis3",
    ),
    row=4,
    col=1,
)
fig.update_xaxes(title=f"", nticks=20, row=4, col=1)
fig.update_yaxes(title=f"p(ecFEF2575%ecFEV1)", nticks=8, row=4, col=1)

# df2 = pd.DataFrame(
#     data=res_dict["vevidence_ar"][idx, :, :].T,
#     # data=AR_given_M_and_all_D[idx, :, :].T,
#     columns=AR.midbins,
#     index=HFEV1.midbins,
# )
# fig.add_trace(
#     go.Heatmap(z=df2.T, x=df2.index, y=df2.columns, coloraxis="coloraxis3"),
#     row=4,
#     col=1,
# )
# fig.update_xaxes(title=f"HFEV1 state", nticks=20, row=4, col=1)
# fig.update_yaxes(title=f"AR v evidence", nticks=10, row=4, col=1)

# Set colorscale for color axis
fig.update_layout(
    coloraxis1=dict(colorscale=colorscale),
    coloraxis2=dict(colorscale=colorscale),
    coloraxis3=dict(colorscale=colorscale),
)


title = f"ID {dftmp.ID.iloc[0]} - idx {idx}"
fig.update_layout(height=800, width=800, title=title, font=dict(size=10))

fig.show()

In [53]:
# plot a heatmap of log_p_D_given_M, using the same colour scale as above
df1 = pd.DataFrame(
    data=np.exp(log_p_D_given_M),
    columns=HFEV1.get_bins_str(),
    index=dftmp["Date Recorded"].apply(lambda date: date.strftime("%Y-%m-%d")),
)

idx = 30
# idx = 7
fig = make_subplots(
    rows=3,
    vertical_spacing=0.05,
)
colorscale = [
    [0, "white"],
    [0.01, "red"],
    [0.05, "yellow"],
    [0.1, "cyan"],
    [0.6, "blue"],
    [1, "black"],
]

ih.plot_histogram(fig, HFEV1, np.exp(log_p_D_given_M[idx, :]), 1, 6, 1, 1, annot=True)
fig.update_xaxes(title=f"", row=1, col=1)
fig.update_yaxes(title=f"P(data | model)", row=1, col=1)

# Add p(ecFEV1)
ecFEV1_dist_with_obs = res_dict["ecFEV1"]
for i, obs in enumerate(dftmp["idx ecFEV1 (L)"].values):
    ecFEV1_dist_with_obs[i, obs, :] = 0.4

df3 = pd.DataFrame(
    data=ecFEV1_dist_with_obs[idx, :, :].T,
    columns=ecFEV1.midbins,
    index=HFEV1.midbins,
)

fig.add_trace(
    go.Heatmap(
        z=df3.T,
        x=df3.index,
        y=df3.columns,
        coloraxis="coloraxis1",
    ),
    row=2,
    col=1,
)
fig.update_xaxes(title=f"", nticks=8, row=2, col=1)
fig.update_yaxes(title=f"p(ecFEV1)", nticks=8, row=2, col=1)

# Add p(ecFEF2575%ecFEV1)
ecFEF2575_dist_with_obs = res_dict["ecFEF2575%ecFEV1"]
for i, obs in enumerate(dftmp["idx ecFEF2575%ecFEV1"].values):
    ecFEF2575_dist_with_obs[i, obs, :] = 0.1

df4 = pd.DataFrame(
    data=ecFEF2575_dist_with_obs[idx, :, :].T,
    columns=ecFEF2575prctecFEV1.midbins,
    index=HFEV1.midbins,
)

fig.add_trace(
    go.Heatmap(
        z=df4.T,
        x=df4.index,
        y=df4.columns,
        coloraxis="coloraxis2",
    ),
    row=3,
    col=1,
)
fig.update_xaxes(title=f"", nticks=20, row=3, col=1)
fig.update_yaxes(title=f"p(ecFEF2575%ecFEV1)", nticks=4, row=3, col=1)

# Set colorscale for color axis
# fig.update_layout(
#     coloraxis1=dict(colorscale=colorscale),
#     coloraxis2=dict(colorscale=colorscale),
# )
# Set colorscale for coor axis and position colorscale on the same row
fig.update_layout(
    coloraxis2=dict(
        colorscale=colorscale,
        colorbar=dict(
            x=1.05,  # Position the colorbar to the right of the plot
            y=0.35,  # Position the colorbar in the middle of the plot vertically
            xanchor="left",
            yanchor="bottom",
            # color bar height
            len=0.3,
        ),
    ),
    coloraxis1=dict(
        colorscale=colorscale,
        colorbar=dict(
            x=1.05,  # Position the colorbar to the right of the plot
            y=0,  # Position the colorbar in the middle of the plot vertically
            xanchor="left",
            yanchor="bottom",
            # color bar height
            len=0.3,
        ),
    ),
)


title = f"ID {dftmp.ID.iloc[0]} - idx {idx}, date {df1.index[idx]}"
fig.update_layout(height=800, width=400, title=title, font=dict(size=10))

fig.show()

In [39]:
res_dict["ecFEV1"][:, 3, 3]

array([3.75601474e-02, 3.10259060e-26, 1.25722604e-27, 8.01532451e-28,
       9.49981792e-28, 7.18769612e-28, 6.72476324e-28, 6.72428323e-28,
       6.22114613e-28, 6.16962549e-28, 6.30880194e-28, 6.46070026e-28,
       6.30900341e-28, 6.22087312e-28, 6.30883988e-28, 6.46070236e-28,
       6.46082063e-28, 6.30900898e-28, 6.22087341e-28, 6.72379425e-28,
       6.72423650e-28, 7.18479149e-28, 6.72465409e-28, 6.46105062e-28,
       6.30901974e-28, 7.78703637e-28, 6.39688913e-28, 6.14459342e-28,
       6.46088066e-28, 6.30901322e-28, 6.30890597e-28, 6.72387350e-28,
       6.46101761e-28, 6.46083626e-28, 6.72401597e-28, 6.72424703e-28,
       6.30919535e-28, 6.46072197e-28, 6.22096573e-28, 6.46064409e-28,
       6.46081779e-28, 6.46082635e-28, 6.22097028e-28, 6.46064436e-28,
       6.30900089e-28, 6.46071120e-28, 6.22096520e-28, 6.72380028e-28,
       6.22112780e-28, 6.16962451e-28, 6.16957732e-28, 6.16957480e-28,
       6.13974707e-28, 6.16955608e-28, 6.12235413e-28, 6.12232674e-28,
      

In [35]:
ecFEV1.get_bin_for_value(2.35)

('[2.35; 2.40)', 47)