In [None]:
import pickle
import warnings
from typing import Dict, List, Optional

# Silence warnings before the third-party imports so import-time warnings are hidden too.
warnings.filterwarnings("ignore")

import arviz as az
import estival.sampling.tools as esamp
import numpy as np
import pandas as pd
import plotly.express as px

from tbdynamics.calibration.plotting import (
    plot_detection_scenarios_comparison_box,
    plot_output_ranges,
    plot_trial_output_ranges,
)
from tbdynamics.camau.calibration.utils import calculate_future_acf_outputs, get_bcm
from tbdynamics.camau.constants import indicator_legends, indicator_names
from tbdynamics.constants import QUANTILES
from tbdynamics.settings import CM_PATH, DOCS_PATH, OUT_PATH
from tbdynamics.tools.detect import make_future_acf_scenarios
from tbdynamics.tools.inputs import load_targets

In [None]:
# Render pandas `.plot()` output with plotly.
pd.options.plotting.backend = "plotly"

In [None]:
# loaded_inference_data = az.from_netcdf(OUT_PATH / 'inference_data1.nc')
# idata = az.from_netcdf(OUT_PATH / 'extracted_idata.nc')
# Fixed model parameters passed to get_bcm for every model run in this notebook.
# NOTE(review): presumably these are the non-calibrated values and the calibrated
# parameters come from the posterior samples — TODO confirm against get_bcm.
params = {
    "start_population_size": 30000.0,
    "seed_time": 1805.0,
    "seed_num": 1.0,
    "seed_duration": 1.0,
    "rr_infection_latent": 0.1890473700762809,
    "rr_infection_recovered": 0.17781844797545143,
    "smear_positive_death_rate": 0.3655528915762244,
    "smear_negative_death_rate": 0.027358324164819155,
    "smear_positive_self_recovery": 0.18600338108638945,
    "smear_negative_self_recovery": 0.11333894801537307,
    "screening_scaleup_shape": 0.3,
    "screening_inflection_time": 1993,
    "acf_sensitivity": 0.90,
}
# Calibration targets, read from targets.yml under CM_PATH.
targets = load_targets(CM_PATH / "targets.yml")

In [None]:
idata_raw = az.from_netcdf(OUT_PATH / 'camau/best/calib_full_out.nc')

In [None]:
burnt_idata = idata_raw.sel(draw=np.s_[50000:])
idata_extract = az.extract(burnt_idata, num_samples=300)

In [None]:
# outputs = calculate_scenario_outputs(params, idata)
# with open(OUT_PATH / 'quant_outputs.pkl', 'wb') as f:
#      pickle.dump(outputs, f)

In [None]:
# with open(OUT_PATH /'camau/quant_outputs.pkl', 'rb') as f:
#     outputs = pickle.load(f)
covid_effects = {"detection_reduction": True, "contact_reduction": False}

#     # Base scenario (calculate outputs for all indicators)
bcm = get_bcm(params, covid_effects)
base_results = esamp.model_results_for_samples(idata_extract, bcm).results
base_quantiles = esamp.quantiles_for_results(base_results, QUANTILES)

In [None]:
# base_quantiles.to_pickle(OUT_PATH / 'camau/output0304.pkl')
# base_quantiles = pd.read_pickle(OUT_PATH / 'camau/output0304.pkl')

In [None]:
# target_plot.write_image(DOCS_PATH / "targets1.png", scale=3)
# Calibration check: modelled total and ACT3-arm adult populations against targets.
# NOTE(review): the trailing positional args (1, 2010, 2025) presumably set the
# column count and plotted time window — confirm against plot_output_ranges.
plot_output_ranges(base_quantiles,targets,["total_population","act3_trial_adults_pop", "act3_control_adults_pop"],indicator_names,indicator_legends,1,2010,2025, option = 'camau') #.write_image(DOCS_PATH /'camau/pops.png', scale=3)

In [None]:
# Calibration fit to the notification and adult latent-infection targets.
plot_output_ranges(base_quantiles,targets,['notification','percentage_latent_adults'],indicator_names,indicator_legends,2,2010,2025, option='camau') #.write_image(DOCS_PATH /'camau/targets.png', scale=3)


In [None]:
# Non-target outputs (incidence, pulmonary prevalence, mortality) from the base run.
plot_output_ranges(base_quantiles,targets,['incidence', 'prevalence_pulmonary', 'adults_prevalence_pulmonary','mortality'],indicator_names,indicator_legends,2,2010,2025, option='camau') #.write_image(DOCS_PATH /'camau/compare.png', scale=3)

In [None]:
# plot_output_ranges(base_quantiles,targets,['detection_rate'],indicator_names,indicator_legends,1,1980,2025)

In [None]:
# plot_output_ranges(base_quantiles,targets,['incidence', 'prevalence_pulmonary', 'adults_prevalence_pulmonary','mortality'],indicator_names,indicator_legends,2,2010,2025, option='camau')

In [None]:
# plot_trial_output_ranges(base_quantiles,targets,['acf_detectionXact3_trial','acf_detectionXact3_control'],indicator_names,2) #.write_image(DOCS_PATH /'camau/trial.png', scale=3)

In [None]:
# Incidence and infectious prevalence for each ACT3 arm, grouped arm by arm
# (e.g. "incidenceXact3_trial", "prevalence_infectiousXact3_trial", ...).
arms = ['act3_trial', 'act3_control', 'act3_other']
metrics = ['incidenceX', 'prevalence_infectiousX']
indicators = [
    metric_prefix + arm_name
    for arm_name in arms
    for metric_prefix in metrics
]
plot_output_ranges(
    base_quantiles, targets, indicators,
    indicator_names, indicator_legends,
    2, 2010, 2025,
    option='camau',
)  # .write_image(DOCS_PATH / 'camau/output_arms.png', scale=3)

In [None]:
# target_plot.write_image(DOCS_PATH / "targets2.png", scale=3)

In [None]:
# spah.write_image(DOCS_PATH / 'spah.png', scale = 3)

In [None]:
# target_plot_history = plot_output_ranges(outputs['base_scenario'],targets,['total_population','notification','adults_prevalence_pulmonary'],1,1800,2010, history =True)

In [None]:
# target_plot_history

In [None]:
# target_plot_history.write_image(DOCS_PATH / 'targets_history.png', scale=3)

In [None]:
# compare_target_plot = plot_output_ranges(outputs['base_scenario'],targets,['incidence','mortality_raw','prevalence_smear_positive', 'percentage_latent'],2,2010,2025)

In [None]:
# compare_target_plot.write_image(DOCS_PATH / "non_targets.png", scale='3')

In [None]:
# compare_target_plot

In [None]:
# screening_plot.write_image(DOCS_PATH / 'screening_plot.png', scale =3)

In [None]:
# cdr_plot = plot_output_ranges(outputs['base_scenario']['quantiles'],targets,['case_notification_rate'],1,2010,2025)

In [None]:
# cdr_plot.write_image(DOCS_PATH / 'cdr_plot.png', scale =3)

In [None]:
# early_plot = plot_output_ranges(base_quantiles,targets,['incidence_early_prop'], indicator_names, indicator_legends,1,2000,2025) #.write_image(DOCS_PATH /'camau/early.png', scale=3)

In [None]:
# Options for future active-case-finding (ACF) scenarios.
# NOTE(review): make_future_acf_scenarios presumably expands these lists into one
# scenario per combination; "every" looks like the interval between screening
# rounds and "coverage" the screened fraction — both inferred from key names, confirm.
config = {
    "arm": ["trial", "control", "other"],
    "every": [2,4],
    "coverage": [0.8],
}
future_acf_scenarios = make_future_acf_scenarios(config)

In [None]:
# Inspect the generated scenario definitions.
future_acf_scenarios

In [None]:
request_outputs = [
    "notification",
    "acf_notification",
    "incidence_raw",
    "mortality_raw",
    "prevalence_infectious",
    "prevalence_pulmonary",
    "incidence",
    "notificationXact3_trial",
    "acf_detectionXact3_trial",
    "mortality_infectious_rawXact3_trial",
    "mortality_rateXact3_trial",
    "cumulative_deathsXact3_trial",
    "cumulative_diseasedXact3_trial",
    "prevalence_infectiousXact3_trial",
    "incidenceXact3_trial",
    "act3_trial_adults_prevalence",
    "incidence_adults"
]

In [None]:
prov_outputs = calculate_future_acf_outputs(params=params, idata_extract=idata_extract,covid_effects=covid_effects, future_acf_scenarios=future_acf_scenarios, request_outputs=request_outputs)

In [None]:
prov_outputs['status-quo'] = base_quantiles[request_outputs]

In [None]:
# with open(OUT_PATH / "camau/prov_scenario.pkl", "wb") as f:
#     pickle.dump(prov_outputs, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
def calculate_diff_cumulative_from_output_dict(
    output_dict: Dict[str, pd.DataFrame],
    cumulative_start_time: float = 2020.0,
    years: List[int] = [2030, 2035],
    quantiles: Optional[List[float]] = None,
    indicators: List[str] = ["incidence_raw", "mortality_raw"],
    base_scenario_key: str = "status-quo",
) -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """
    Quantiles of cumulative differences between each scenario and a base scenario.

    For every scenario other than ``base_scenario_key``, indicator values at
    whole-year time points on or after ``cumulative_start_time`` are cumulatively
    summed. The base scenario's cumulative sum is then subtracted column by column
    (sample by sample). Finally, quantiles of the absolute and relative (%)
    differences are taken across columns at each requested year.

    Args:
        output_dict: Scenario name -> model outputs. ``df[indicator]`` must be a
            DataFrame indexed by time with one column per sample (presumably the
            per-sample results from estival — confirm for new callers).
        cumulative_start_time: First time point included in the cumulative sums.
        years: Years at which the cumulative differences are summarised; each must
            be a whole-year time point >= ``cumulative_start_time``.
        quantiles: Quantile levels to report. ``None`` (default) uses ``QUANTILES``.
        indicators: Indicators to compare. ``incidence_raw`` and ``mortality_raw``
            are renamed to ``cumulative_diseased`` and ``cumulative_deaths`` in the
            output for downstream plotting compatibility.
        base_scenario_key: Key of the reference scenario in ``output_dict``.

    Returns:
        ``{scenario: {"abs": {indicator: df}, "rel": {indicator: df}}}``. Each
        ``df`` is indexed by ``years`` and has one column per quantile level.

    Raises:
        KeyError: If ``base_scenario_key``, an indicator or a requested year is missing.
        ValueError: If a scenario's sample columns differ from the base scenario's.
            Subtraction aligns on columns, so mismatched columns would otherwise
            silently yield all-NaN differences.
    """
    # Resolved at call time so the notebook-level constant is not frozen into the signature.
    if quantiles is None:
        quantiles = QUANTILES

    # Map raw indicators to standardized names
    rename_map = {
        "incidence_raw": "cumulative_diseased",
        "mortality_raw": "cumulative_deaths"
    }

    def _whole_years(df: pd.DataFrame) -> pd.DataFrame:
        """Rows at whole-year times on or after ``cumulative_start_time``."""
        return df.loc[(df.index >= cumulative_start_time) & (df.index % 1 == 0)]

    def _quantile_table(values: pd.DataFrame) -> pd.DataFrame:
        """Quantiles across columns (samples) of ``values`` at each requested year."""
        return pd.DataFrame(
            {q: [values.loc[year].quantile(q) for year in years] for q in quantiles},
            index=years,
        )

    yearly_base = _whole_years(output_dict[base_scenario_key])
    base_cum = {ind: yearly_base[ind].cumsum() for ind in indicators}

    diff_results = {}
    for scenario, df in output_dict.items():
        if scenario == base_scenario_key:
            continue

        yearly_data = _whole_years(df)
        abs_diff = {}
        rel_diff = {}

        for ind in indicators:
            cum = yearly_data[ind].cumsum()
            if set(cum.columns) != set(base_cum[ind].columns):
                raise ValueError(
                    f"Scenario {scenario!r} and base scenario {base_scenario_key!r} have "
                    f"different sample columns for {ind!r}; the cumulative differences would "
                    "be all-NaN. Check both hold per-sample results for the same draws."
                )
            abs_ = cum - base_cum[ind]
            rel_ = abs_ / base_cum[ind] * 100

            renamed = rename_map.get(ind, ind)
            abs_diff[renamed] = _quantile_table(abs_)
            rel_diff[renamed] = _quantile_table(rel_)

        diff_results[scenario] = {"abs": abs_diff, "rel": rel_diff}

    return diff_results

In [None]:
temp = calculate_diff_cumulative_from_output_dict(prov_outputs, years=[2035])

In [None]:
temp

In [None]:
from tbdynamics.calibration.plotting import plot_detection_scenarios_comparison_box

In [None]:
plot_detection_scenarios_comparison_box(temp)

In [None]:
def calculate_act3_effect_from_outputs(
    output_dict: Dict[str, pd.DataFrame],
    indicators: List[str] = ["incidence_raw", "mortality_raw"],
    base_scenario_key: str = "status-quo",
    act3_scenario_key: str = "act3_trial",
    quantiles: Optional[List[float]] = None,
) -> Dict[str, pd.DataFrame]:
    """
    Quantiles of the total cumulative ACT3-minus-base difference from precomputed outputs.

    Previously also named ``calculate_act3_effect``. That earlier definition was
    immediately shadowed by the model-running ``calculate_act3_effect`` defined in
    the next cell, so it was unreachable; it is renamed here to end the duplicate.

    Args:
        output_dict: Scenario name -> outputs. ``df[indicator]`` is assumed to be a
            DataFrame indexed by time with one column per sample — TODO confirm.
        indicators: Indicators to compare.
        base_scenario_key: Key of the reference scenario.
        act3_scenario_key: Key of the ACT3 scenario (previously hard-coded).
        quantiles: Quantile levels to report. ``None`` (default) uses ``QUANTILES``.

    Returns:
        ``{indicator: df}``, where each ``df`` has a single row
        ``f"{indicator}_act3_effect"`` and one column per quantile level. The
        quantiles are taken across samples of the cumulative difference at the
        final time point.
    """
    if quantiles is None:
        quantiles = QUANTILES

    base_df = output_dict[base_scenario_key]
    act3_df = output_dict[act3_scenario_key]

    results = {}
    for ind in indicators:
        diff = act3_df[ind].cumsum() - base_df[ind].cumsum()
        # Last row = per-sample total difference over the whole horizon. The old
        # code called diff.quantile(q) on the full time-by-sample frame, which
        # returns a Series per quantile (a quantile over time), not a scalar.
        total_diff = diff.iloc[-1]

        results[ind] = pd.DataFrame(
            {q: [total_diff.quantile(q)] for q in quantiles},
            index=[f"{ind}_act3_effect"]
        )

    return results

In [None]:
def calculate_act3_effect(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    covid_effects: Dict[str, bool],
    cumulative_start_time: float = 2020.0,
    years: List[int] = [2030, 2035],
    quantiles: List[float] = QUANTILES,
    indicators: List[str] = ["incidence_raw", "mortality_raw"],
) -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """
    Compare ACT3-style ACF (acf=True) vs no ACF (acf=False) under same COVID effects.

    Runs the model over the posterior samples in ``idata_extract`` twice:

    - as configured by ``get_bcm(params, covid_effects)`` — ACF on, the status quo;
    - with ``implement_act3=False`` — ACF off.

    It returns quantiles of the absolute and relative (%) differences in
    cumulative indicators, ACF-off minus ACF-on. These are computed by the shared
    ``calculate_diff_cumulative_from_output_dict`` rather than a copy of its logic.

    Args:
        params: Fixed model parameters passed to ``get_bcm``.
        idata_extract: Posterior samples passed to ``model_results_for_samples``.
            NOTE(review): ``az.extract`` returns an xarray object, not
            ``InferenceData``; the annotation is kept for compatibility.
        covid_effects: COVID-19 effect switches passed to ``get_bcm``.
        cumulative_start_time: First time point included in the cumulative sums.
        years: Years at which the cumulative differences are summarised.
        quantiles: Quantile levels to report.
        indicators: Indicators to compare.

    Returns:
        ``{"no_acf": {"abs": {name: df}, "rel": {name: df}}}``. Here
        ``incidence_raw`` and ``mortality_raw`` are reported as
        ``cumulative_diseased`` and ``cumulative_deaths``. Each ``df`` is indexed
        by ``years`` and has one column per quantile.
    """
    # ACF ON (status-quo)
    bcm_acf = get_bcm(params, covid_effects)
    results_acf = esamp.model_results_for_samples(idata_extract, bcm_acf).results

    # ACF OFF
    bcm_no_acf = get_bcm(params, covid_effects, implement_act3=False)
    results_no_acf = esamp.model_results_for_samples(idata_extract, bcm_no_acf).results

    # With the ACF-on run as the base, the shared routine yields exactly
    # {"no_acf": {"abs": ..., "rel": ...}} (differences = no_acf - acf).
    return calculate_diff_cumulative_from_output_dict(
        {"status-quo": results_acf, "no_acf": results_no_acf},
        cumulative_start_time=cumulative_start_time,
        years=years,
        quantiles=quantiles,
        indicators=indicators,
        base_scenario_key="status-quo",
    )

In [None]:
# Effect of ACT3: cumulative differences (no ACF minus ACF) in incidence_raw and
# mortality_raw, accumulated over whole years from 2014 and summarised at 2020.
act3_effect = calculate_act3_effect(
    params=params,
    idata_extract=idata_extract,
    covid_effects=covid_effects,
    indicators=["incidence_raw", "mortality_raw"],
    quantiles=QUANTILES,
    cumulative_start_time=2014,
    years=[2020],
)