In [1]:
import src.inference.long_inf_slicing as slicing
import src.models.builders as mb
import src.inference.helpers as ih
from plotly.subplots import make_subplots
import logging

# Raise the root logger's threshold to CRITICAL so that lower-severity records are dropped
# (presumably to keep library log noise out of the cell outputs - confirm).
# NOTE(review): this runs before the pgmpy imports below; keep that order unless it is
# confirmed that nothing relies on it.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)
from pgmpy.inference.ExactInference import VariableElimination
from pgmpy.factors.discrete import TabularCPD


import pandas as pd
import numpy as np

In [2]:
# Mock input for the validation cells: three consecutive daily records for one patient
# (ID "1"). Height, Age and Sex are scalars broadcast to every row. The "idx ..." columns
# are hard-coded bin indices for the matching measurements
# (assumed to agree with the model's discretisation - TODO confirm).
# "Date Recorded" starts as a day offset and is converted to a timestamp counted from
# 2020-01-01, so the rows fall on 2020-01-02, 2020-01-03 and 2020-01-04.
df_mock = pd.DataFrame(
    {
        "ID": ["1"] * 3,
        "Date Recorded": [1, 2, 3],
        "Height": 180,
        "Age": 35,
        "Sex": "Male",
        "ecFEV1": [1.8, 3.5, 0.1],
        "ecFEF2575%ecFEV1": [12, 120, 150],
        "idx ecFEV1 (L)": [1, 3, 0],
        "idx ecFEF2575%ecFEV1": [0, 6, 7],
    }
).assign(
    **{
        "Date Recorded": lambda frame: pd.to_datetime(
            frame["Date Recorded"], unit="D", origin="2020-01-01"
        )
    }
)
df_mock

Unnamed: 0,ID,Date Recorded,Height,Age,Sex,ecFEV1,ecFEF2575%ecFEV1,idx ecFEV1 (L),idx ecFEF2575%ecFEV1
0,1,2020-01-02,180,35,Male,1.8,12,1,0
1,1,2020-01-03,180,35,Male,3.5,120,3,6
2,1,2020-01-04,180,35,Male,0.1,150,0,7


## Validate query_forwardly_across_days() with precomputed messages

In [3]:
# Check slicing.query_forwardly_across_days() when a precomputed message is supplied.
# The three days of ecFEV1 observations in df_mock are run through:
#   1. the forward-across-days ("approx") algorithm on the point-in-time model, and
#   2. pgmpy's VariableElimination on the explicit n_days model,
# so that the two sets of posteriors can be compared in the plotting cell that follows.
# Uses the model with FEV1 (with its noise model) and FEF25-75.

# The same CPT suffix is passed to both model builders below.
ecfev1_noise_model_cpt_suffix = "_std_0.7"
n_days = 3

# Patient constants, read from the first row of the mock data
height = df_mock.Height.iloc[0]
age = df_mock.Age.iloc[0]
sex = df_mock.Sex.iloc[0]

# Point-in-time model: the first returned element is discarded, the second is the
# inference object handed to query_forwardly_across_days().
(
    _,
    inf_alg_approx,
    HFEV1,
    uecFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
    ecFEF2575prctecFEV1,
) = mb.o2sat_fev1_fef2575_point_in_time_model_noise_shared_healthy_vars_light(
    height,
    age,
    sex,
    ecfev1_noise_model_cpt_suffix=ecfev1_noise_model_cpt_suffix,
)

# Inputs for the approximate inference:
# HFEV1 and HO2Sat are passed as the shared variables, AR as the per-day variable.
shared_vars_approx = [HFEV1, HO2Sat]
vars_approx = [AR]
# Only ecFEV1 is observed in this check (O2Sat.name is not included).
obs_var_names_approx = [ecFEV1.name]

# APPROXIMATE INFERENCE
# Precomputed message to speed up inference: a uniform vector over the AR.card bins,
# stored under the message key below.
# NOTE(review): judging by the variable name, this presumably stands in for the message
# reaching AR from the O2-saturation side of the model - confirm the key format against
# query_forwardly_across_days().
arr = np.ones(AR.card)
arr /= arr.sum()
uniform_from_o2_side = {
    "['O2 saturation if fully functional alveoli (%)', 'Healthy O2 saturation (%)', 'Airway resistance (%)'] -> Airway resistance (%)": arr
}
precomp_messages = uniform_from_o2_side
df_query_res, df_res_before_convergence, shared_vars_final = (
    slicing.query_forwardly_across_days(
        df_mock,
        inf_alg_approx,
        shared_vars_approx,
        vars_approx,
        obs_var_names_approx,
        # NOTE(review): positional 1e-8 is presumably a convergence tolerance - confirm
        # against the function signature.
        1e-8,
        days_specific_evidence=[],
        # A shallow copy of the dict is passed rather than precomp_messages itself
        # (presumably because the callee may add entries - confirm).
        precomp_messages=precomp_messages.copy(),
        debug=False,
        auto_reset_shared_vars=True,
    )
)
# Row 0 of the result holds the entries for the shared variables read here.
# These two names are not used by the cells visible in this part of the notebook.
hfev1_approx = df_query_res.loc[0, HFEV1.name]
ho2sat_approx = df_query_res.loc[0, HO2Sat.name]

# EXACT INFERENCE
# An earlier variant built this model with
# mb.o2_sat_fev1_n_days_model_light(n_days, height, age, sex), which also returned an
# inf_alg_exact object; the builder used now does not, so VariableElimination is
# constructed explicitly below.
# NOTE: this unpacking rebinds HFEV1, HO2Sat and ecFEF2575prctecFEV1 to the objects
# returned by the n_days builder; the *_vars names are indexed by day (0-based).
(
    model_exact,
    HFEV1,
    HO2Sat,
    AR_vars,
    uecFEV1_vars,
    ecFEV1_vars,
    O2SatFFA_vars,
    IA_vars,
    UO2Sat_vars,
    O2Sat_vars,
    ecFEF2575prctecFEV1,
) = mb.o2sat_fev1_fef2575_n_days_model_noise_shared_healthy_vars_light(
    n_days,
    height,
    age,
    sex,
    ecfev1_noise_model_cpt_suffix=ecfev1_noise_model_cpt_suffix,
)
var_elim = VariableElimination(model_exact)

# shared_vars_exact is not used by the cells visible in this part of the notebook.
shared_vars_exact = [HFEV1.name, HO2Sat.name]
# Evidence for the exact query: each day's ecFEV1 variable is set to the bin index
# stored in df_mock["idx ecFEV1 (L)"]. O2Sat observations are not added in this check.
obs_vars_exact = {}
for j in range(n_days):
    ecfev1_obs = df_mock.loc[j, "idx ecFEV1 (L)"]
    obs_vars_exact[ecFEV1_vars[j].name] = ecfev1_obs

# joint=False: the result is indexed per queried variable name (see the plotting cell,
# which reads res_exact[<name>].values).
# The three AR variables are listed explicitly, which matches n_days = 3 above.
res_exact = var_elim.query(
    variables=[HFEV1.name, AR_vars[0].name, AR_vars[1].name, AR_vars[2].name],
    evidence=obs_vars_exact,
    joint=False,
)

In [4]:
# Overlay the exact and approximate posteriors: row 1 is HFEV1, rows 2-4 are AR on
# days 1-3. The trace order matters, because the colouring loop below addresses
# fig.data by position (exact results first, approximate results second).
fig = make_subplots(rows=4, cols=1, vertical_spacing=0.13)

# Traces 0-3: variable elimination (exact) results
ih.plot_histogram(
    fig, HFEV1, res_exact[HFEV1.name].values, HFEV1.a, HFEV1.b, 1, 1, annot=False
)
for day_idx in range(3):
    ih.plot_histogram(
        fig,
        AR,
        res_exact[AR_vars[day_idx].name].values,
        AR.a,
        AR.b,
        day_idx + 2,
        1,
        annot=False,
    )

# Traces 4-7: results of query_forwardly_across_days(); these calls also set the
# subplot titles. The HFEV1 call leaves `annot` at its default.
ih.plot_histogram(
    fig,
    HFEV1,
    df_query_res.loc[0, HFEV1.name],
    HFEV1.a,
    HFEV1.b,
    1,
    1,
    title=HFEV1.name,
)
for day_idx in range(3):
    ih.plot_histogram(
        fig,
        AR,
        df_query_res.loc[day_idx, AR.name],
        AR.a,
        AR.b,
        day_idx + 2,
        1,
        title=f"{AR.name} day {day_idx + 1}",
        annot=False,
    )

# Colour the first four traces blue (exact) and the next four red (approximate)
exact_colour = "#636EFA"
approx_colour = "#EF553B"
for i in range(4):
    fig.data[i].marker.color = exact_colour
    fig.data[i + 4].marker.color = approx_colour

# Reduce x axis title font size
fig.update_xaxes(title_font={"size": 12}, title_standoff=7)

# Set the figure title and size, and hide the legend
title = "Cutset conditioning (red) vs variable elimination (blue)"
fig.update_layout(showlegend=False, height=550, width=800, title=title)
fig.show()

## Validate query_forwardly_across_days() with day-specific evidence

In [5]:
# Check slicing.query_forwardly_across_days() when day-specific evidence is supplied:
# AR is given as evidence on the second and third mock days, in addition to the daily
# ecFEV1 observations. The result is compared with pgmpy's VariableElimination on the
# explicit n_days model, where the same AR bins are set as evidence.
# Uses the model with FEV1 (with its noise model) and FEF25-75.

# The same CPT suffix is passed to both model builders below.
ecfev1_noise_model_cpt_suffix = "_std_0.7"
n_days = 3

# Patient constants, read from the first row of the mock data
height = df_mock.Height.iloc[0]
age = df_mock.Age.iloc[0]
sex = df_mock.Sex.iloc[0]

# Point-in-time model: the first returned element is discarded, the second is the
# inference object handed to query_forwardly_across_days().
(
    _,
    inf_alg_approx,
    HFEV1,
    uecFEV1,
    ecFEV1,
    AR,
    HO2Sat,
    O2SatFFA,
    IA,
    UO2Sat,
    O2Sat,
    ecFEF2575prctecFEV1,
) = mb.o2sat_fev1_fef2575_point_in_time_model_noise_shared_healthy_vars_light(
    height,
    age,
    sex,
    ecfev1_noise_model_cpt_suffix=ecfev1_noise_model_cpt_suffix,
)

# Inputs for the approximate inference:
# HFEV1 and HO2Sat are passed as the shared variables, AR as the per-day variable.
shared_vars_approx = [HFEV1, HO2Sat]
vars_approx = [AR]
# Only ecFEV1 is listed as an observed variable (O2Sat.name is not included).
obs_var_names_approx = [ecFEV1.name]

# APPROXIMATE INFERENCE
# Precomputed message to speed up inference: a uniform vector over the AR.card bins,
# stored under the message key below.
# NOTE(review): judging by the variable name, this presumably stands in for the message
# reaching AR from the O2-saturation side of the model - confirm the key format against
# query_forwardly_across_days().
arr = np.ones(AR.card)
arr /= arr.sum()
uniform_from_o2_side = {
    "['O2 saturation if fully functional alveoli (%)', 'Healthy O2 saturation (%)', 'Airway resistance (%)'] -> Airway resistance (%)": arr
}
precomp_messages = uniform_from_o2_side

# Day-specific evidence
# NOTE: the statements below add two columns to df_mock IN PLACE (the AR value and its
# bin index), so df_mock no longer matches the table displayed when it was created.
# Rows 1 and 2 get AR values of 5 and 12, with the bin index taken from
# AR.get_bin_idx_for_value(). Row 0 keeps the initial values -1 and 1000, which look
# like placeholders for "no AR evidence" (presumably never read, since only the dates
# of rows 1 and 2 are listed in days_specific_evidence - confirm).

df_mock[AR.name] = -1
df_mock[f"idx {AR.name}"] = 1000
df_mock.loc[1, AR.name] = 5
ar_obs_idx_day_1 = AR.get_bin_idx_for_value(df_mock.loc[1, AR.name])
df_mock.loc[1, f"idx {AR.name}"] = ar_obs_idx_day_1
df_mock.loc[2, AR.name] = 12
ar_obs_idx_day_2 = AR.get_bin_idx_for_value(df_mock.loc[2, AR.name])
df_mock.loc[2, f"idx {AR.name}"] = ar_obs_idx_day_2

# One (variable name, list of dates) tuple per variable carrying day-specific evidence
days_specific_evidence = [
    (AR.name, [df_mock.loc[1, "Date Recorded"], df_mock.loc[2, "Date Recorded"]])
]
print(f"Day specific evidence: {days_specific_evidence}")

df_query_res, df_res_before_convergence, shared_vars_final = (
    slicing.query_forwardly_across_days(
        df_mock,
        inf_alg_approx,
        shared_vars_approx,
        vars_approx,
        obs_var_names_approx,
        # NOTE(review): positional 1e-8 is presumably a convergence tolerance - confirm
        # against the function signature.
        1e-8,
        days_specific_evidence,
        # A shallow copy of the dict is passed rather than precomp_messages itself
        # (presumably because the callee may add entries - confirm).
        precomp_messages=precomp_messages.copy(),
        debug=False,
        auto_reset_shared_vars=True,
    )
)
# Row 0 of the result holds the entries for the shared variables read here.
# These two names are not used by the cells visible in this part of the notebook.
hfev1_approx = df_query_res.loc[0, HFEV1.name]
ho2sat_approx = df_query_res.loc[0, HO2Sat.name]

# EXACT INFERENCE
# An earlier variant built this model with
# mb.o2_sat_fev1_n_days_model_light(n_days, height, age, sex), which also returned an
# inf_alg_exact object; the builder used now does not, so VariableElimination is
# constructed explicitly below.
# NOTE: this unpacking rebinds HFEV1, HO2Sat and ecFEF2575prctecFEV1 to the objects
# returned by the n_days builder; the *_vars names are indexed by day (0-based).
(
    model_exact,
    HFEV1,
    HO2Sat,
    AR_vars,
    uecFEV1_vars,
    ecFEV1_vars,
    O2SatFFA_vars,
    IA_vars,
    UO2Sat_vars,
    O2Sat_vars,
    ecFEF2575prctecFEV1,
) = mb.o2sat_fev1_fef2575_n_days_model_noise_shared_healthy_vars_light(
    n_days,
    height,
    age,
    sex,
    ecfev1_noise_model_cpt_suffix=ecfev1_noise_model_cpt_suffix,
)
var_elim = VariableElimination(model_exact)

# shared_vars_exact is not used by the cells visible in this part of the notebook.
shared_vars_exact = [HFEV1.name, HO2Sat.name]
# Evidence for the exact query: each day's ecFEV1 variable is set to the bin index
# stored in df_mock["idx ecFEV1 (L)"]. O2Sat observations are not added in this check.
obs_vars_exact = {}
for j in range(n_days):
    ecfev1_obs = df_mock.loc[j, "idx ecFEV1 (L)"]
    obs_vars_exact[ecFEV1_vars[j].name] = ecfev1_obs
# AR on the second and third days is fixed to the bin indices computed above, mirroring
# the day-specific evidence given to the approximate algorithm.
obs_vars_exact[AR_vars[1].name] = ar_obs_idx_day_1
obs_vars_exact[AR_vars[2].name] = ar_obs_idx_day_2

# Only HFEV1 and the first day's AR are queried; the other two AR variables are evidence.
# joint=False: the result is indexed per queried variable name (see the plotting cell,
# which reads res_exact[<name>].values).
res_exact = var_elim.query(
    variables=[HFEV1.name, AR_vars[0].name], evidence=obs_vars_exact, joint=False
)

Day specific evidence: [('Airway resistance (%)', [Timestamp('2020-01-03 00:00:00'), Timestamp('2020-01-04 00:00:00')])]


In [7]:
# Overlay the exact and approximate results: row 1 is HFEV1, row 2 is AR on day 1.
# Rows 3 and 4 show the point message built from the AR value stored in df_mock rows 1
# and 2, for both colour groups, so only rows 1-2 compare inferred posteriors.
# The trace order matters, because the colouring loop below addresses fig.data by
# position (exact results first, approximate results second).
fig = make_subplots(rows=4, cols=1, vertical_spacing=0.1)

# Traces 0-3: variable elimination (exact) results, then the AR point messages
ih.plot_histogram(
    fig, HFEV1, res_exact[HFEV1.name].values, HFEV1.a, HFEV1.b, 1, 1, annot=False
)
ih.plot_histogram(
    fig, AR, res_exact[AR_vars[0].name].values, AR.a, AR.b, 2, 1, annot=False
)
for obs_row in (1, 2):
    ih.plot_histogram(
        fig,
        AR,
        AR.get_point_message(df_mock.loc[obs_row, AR.name]),
        AR.a,
        AR.b,
        obs_row + 2,
        1,
        annot=False,
    )

# Traces 4-7: results of query_forwardly_across_days(), then the same AR point
# messages; these calls also set the subplot titles. The HFEV1 call leaves `annot` at
# its default.
ih.plot_histogram(
    fig,
    HFEV1,
    df_query_res.loc[0, HFEV1.name],
    HFEV1.a,
    HFEV1.b,
    1,
    1,
    title=HFEV1.name,
)
ih.plot_histogram(
    fig,
    AR,
    df_query_res.loc[0, AR.name],
    AR.a,
    AR.b,
    2,
    1,
    title=f"{AR.name} day 1",
    annot=False,
)
for obs_row in (1, 2):
    ih.plot_histogram(
        fig,
        AR,
        AR.get_point_message(df_mock.loc[obs_row, AR.name]),
        AR.a,
        AR.b,
        obs_row + 2,
        1,
        title=f"{AR.name} day {obs_row + 1}",
        annot=False,
    )

# Colour the first four traces blue (exact) and the next four red (approximate)
exact_colour = "#636EFA"
approx_colour = "#EF553B"
for i in range(4):
    fig.data[i].marker.color = exact_colour
    fig.data[i + 4].marker.color = approx_colour

# Reduce x axis title font size
fig.update_xaxes(title_font={"size": 12}, title_standoff=7)

# Set the figure title and size, and hide the legend
title = "Cutset conditioning (red) vs variable elimination (blue)"
fig.update_layout(showlegend=False, height=550, width=800, title=title)
fig.show()

## Validate sample_jointly_from_AR()