## Note
This notebook generates model outputs and saves them as pickle files.

In [None]:
import warnings

warnings.filterwarnings("ignore")
from tbdynamics.vietnam.calibration.utils import (
    calculate_scenario_outputs,
    calculate_covid_diff_cum_quantiles,
    calculate_scenario_diff_cum_quantiles,
    calculate_diff_cum_detection_reduction
)
from tbdynamics.settings import DATA_PATH, BASE_PATH, VN_PATH
from tbdynamics.constants import QUANTILES
from tbdynamics.tools.inputs import load_params, load_targets, matrix
from tbdynamics.vietnam.model import build_model
from tbdynamics.vietnam.calibration.utils import get_targets, get_all_priors, get_bcm
from pathlib import Path
import arviz as az
import pickle
import numpy as np
from typing import Dict
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from estival.sampling import tools as esamp
import estival.priors as esp
import estival.targets as est
from estival.model import BayesianCompartmentalModel


In [None]:
# Input (calibration run) and output folders, resolved relative to the
# notebook's working directory, which is assumed to sit two levels below the
# repository root — TODO confirm if the notebook is moved.
# The raw-chain extraction that follows only needs to run the first time.
RUN_PATH = Path.cwd().parent.parent.joinpath("runs", "r0205")
OUT_PATH = Path.cwd().parent.parent.joinpath("data", "outputs", "vietnam")

In [None]:
# Posterior draws used for all scenario runs below.
BURN_IN_DRAW = 50000  # draws labelled before this are discarded as burn-in
N_POSTERIOR_SAMPLES = 1000  # size of the random posterior subsample
# Fixed seed so the random subsample, and every output derived from it, is
# reproducible between runs of this notebook.
EXTRACT_SEED = 0

idata_raw = az.from_netcdf(RUN_PATH / 'calib_full_out.nc')
burnt_idata = idata_raw.sel(draw=np.s_[BURN_IN_DRAW:])
# az.extract stacks chain/draw into a single "sample" dimension and, with
# num_samples set, keeps a (seeded) random subset of them.
idata = az.extract(burnt_idata, num_samples=N_POSTERIOR_SAMPLES, rng=EXTRACT_SEED)
# Optionally persist the subsample so later sessions can skip this extraction:
# inference_data = az.convert_to_inference_data(idata.reset_index('sample'))
# az.to_netcdf(inference_data, OUT_PATH / 'extracted_data.nc')

In [None]:
# Alternative to the extraction above — load previously saved idata instead:
# idata = az.from_netcdf(BASE_PATH / 'idata/idata_detection.nc')

# Fixed (non-calibrated) parameter values handed to the model builders.
params = dict(
    start_population_size=2_000_000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

In [None]:
# Posterior summary of the calibrated detection_reduction parameter. Use the
# post-burn-in chains (burnt_idata), not idata_raw, so that burn-in draws do
# not bias the reported estimate and interval.
summary = az.summary(burnt_idata, var_names=["detection_reduction"], hdi_prob=0.95)

# Posterior mean and bounds of the 95% highest-density interval (reported as CrI).
mean = summary.loc["detection_reduction", "mean"]
lower = summary.loc["detection_reduction", "hdi_2.5%"]
upper = summary.loc["detection_reduction", "hdi_97.5%"]

In [None]:
def get_priors():
    """Return the priors for the parameters sampled by this notebook's models.

    The four smear-status natural-history rates share the same truncated-normal
    form, so they are generated from a small spec table. The overall ordering
    of the returned priors is unchanged.

    Returns:
        List of estival prior objects, one per sampled parameter.
    """
    # (parameter name, mean, standard deviation, (lower, upper) truncation)
    natural_history_specs = (
        ("smear_positive_death_rate", 0.389, 0.0276, (0.335, 0.449)),
        ("smear_negative_death_rate", 0.025, 0.0041, (0.017, 0.035)),
        ("smear_positive_self_recovery", 0.231, 0.0276, (0.177, 0.288)),
        ("smear_negative_self_recovery", 0.130, 0.0291, (0.073, 0.209)),
    )
    return [
        esp.UniformPrior("contact_rate", (0.001, 0.05)),
        esp.BetaPrior("rr_infection_latent", 3.0, 8.0),
        esp.BetaPrior("rr_infection_recovered", 3.0, 8.0),
        esp.GammaPrior.from_mode("progression_multiplier", 1.0, 2.0),
        *(
            esp.TruncNormalPrior(name, mu, sigma, bounds)
            for name, mu, sigma, bounds in natural_history_specs
        ),
        esp.UniformPrior("screening_scaleup_shape", (0.05, 0.5)),
        esp.TruncNormalPrior("screening_inflection_time", 2000, 3.5, (1986, 2010)),
        esp.GammaPrior.from_mode("time_to_screening_end_asymp", 2.0, 5.0),
        esp.UniformPrior("incidence_props_smear_positive_among_pulmonary", (0.1, 0.6)),
        esp.UniformPrior("incidence_props_pulmonary", (0.5, 0.9)),
    ]

In [None]:
def get_bcm_simple(params, covid_effects):
    """Build a BayesianCompartmentalModel for the Vietnam TB model.

    The model uses this notebook's priors (``get_priors``) and the calibration
    targets (``get_targets``).

    Args:
        params: Fixed parameter values for the BCM. ``None`` (or any falsy
            value) is treated as an empty dict.
        covid_effects: COVID-19 effect switches forwarded to ``build_model``.
            Callers in this notebook pass a dict with boolean
            ``detection_reduction`` and ``contact_reduction`` entries.

    Returns:
        BayesianCompartmentalModel wrapping the TB model built from
        ``VN_PATH / "params.yml"``, ``params``, the priors and the targets.
    """
    base_params = load_params(VN_PATH / "params.yml")
    # NOTE(review): the fourth build_model argument is passed as None; confirm
    # its meaning against tbdynamics.vietnam.model.build_model.
    tb_model = build_model(base_params, matrix, covid_effects, None)
    return BayesianCompartmentalModel(tb_model, params or {}, get_priors(), get_targets())

In [None]:
def _yearly_cumulative_outcomes(idata_extract, bcm, cumulative_start_time):
    """Run ``bcm`` for every posterior sample and accumulate outcomes by whole year.

    Returns:
        Tuple ``(cum_diseased, cum_deaths)``: running totals of
        ``incidence_raw`` and ``mortality_raw`` over the whole-year time points
        at or after ``cumulative_start_time``.
    """
    spaghetti = esamp.model_results_for_samples(idata_extract, bcm).results
    # Keep only integer-valued time points at or after the start year.
    yearly = spaghetti.loc[
        (spaghetti.index >= cumulative_start_time) & (spaghetti.index % 1 == 0)
    ]
    return yearly["incidence_raw"].cumsum(), yearly["mortality_raw"].cumsum()


def _quantile_table(diseased, deaths) -> pd.DataFrame:
    """Return a 2-row table (diseased, deaths) of QUANTILES across samples."""
    return pd.DataFrame({
        "cumulative_diseased": diseased.quantile(QUANTILES),
        "cumulative_deaths": deaths.quantile(QUANTILES),
    }).T


def calculate_diff_cum_detection_reduction(
    params,
    idata_extract: az.InferenceData,
    detection_reduction_values,
    cumulative_start_time: float = 2020.0,
    year: float = 2035.0,
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """
    Calculate absolute and relative differences in cumulative TB incidence and mortality
    by a target year for various detection reduction values, compared to a baseline.

    The baseline (no detection or contact reduction) and every scenario are built
    with the same constructor (``get_bcm_simple``), so the differences reflect
    only the detection reduction switch and its value.

    Args:
        params: Dictionary of fixed model parameters (``None`` is treated as empty).
        idata_extract: Posterior samples to run the models for.
        detection_reduction_values: Iterable of detection reduction values to test.
        cumulative_start_time: Year to start cumulative calculations.
        year: Target year to extract cumulative outcomes.

    Returns:
        ``{"abs": {...}, "rel": {...}}``. Each inner dict maps
        ``f"detection_reduction_{val}"`` to a DataFrame with rows
        ``cumulative_diseased`` / ``cumulative_deaths`` and one column per
        quantile in ``QUANTILES``. Relative differences are fractions of the
        baseline value.
    """
    base_params = dict(params or {})

    # Baseline: neither COVID-19 detection nor contact reduction.
    base_cum_diseased, base_cum_deaths = _yearly_cumulative_outcomes(
        idata_extract,
        get_bcm_simple(
            base_params, {"detection_reduction": False, "contact_reduction": False}
        ),
        cumulative_start_time,
    )
    base_diseased = base_cum_diseased.loc[year]
    base_deaths = base_cum_deaths.loc[year]

    output = {"abs": {}, "rel": {}}

    for val in detection_reduction_values:
        # Switch the detection reduction effect on (contact reduction stays off).
        # Its magnitude is supplied through the params dict; get_priors() has no
        # detection_reduction prior, so presumably this value is used as-is.
        scenario_params = {**base_params, "detection_reduction": val}
        cum_diseased, cum_deaths = _yearly_cumulative_outcomes(
            idata_extract,
            get_bcm_simple(
                scenario_params, {"detection_reduction": True, "contact_reduction": False}
            ),
            cumulative_start_time,
        )

        # Per-sample differences at the target year (paired on sample labels).
        abs_diff_diseased = cum_diseased.loc[year] - base_diseased
        abs_diff_deaths = cum_deaths.loc[year] - base_deaths
        rel_diff_diseased = abs_diff_diseased / base_diseased
        rel_diff_deaths = abs_diff_deaths / base_deaths

        scenario_key = f"detection_reduction_{val}"
        output["abs"][scenario_key] = _quantile_table(abs_diff_diseased, abs_diff_deaths)
        output["rel"][scenario_key] = _quantile_table(rel_diff_diseased, rel_diff_deaths)

    return output

In [None]:
# Cumulative outcome differences vs. baseline for detection reductions of 10%-60%.
covid_uncertainties = calculate_diff_cum_detection_reduction(
    params,
    idata,
    detection_reduction_values=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
)

In [None]:
def plot_abs_diff_boxplot(diff_output, *, year: float = 2035.0):
    """
    Plot boxplots of absolute differences in cumulative new TB episodes and
    TB-related deaths across detection reduction values.

    Args:
        diff_output: Output dictionary from calculate_diff_cum_detection_reduction.
            Only its ``"abs"`` entry is used. Scenario keys must end in
            ``_<reduction value>``.
        year: Year the cumulative differences refer to. It is used only in the
            axis label and title and should match the ``year`` used in the
            calculation (default 2035).

    Note:
        Plotly draws each box from the quantile summaries stored in
        ``diff_output`` (one point per quantile), not from the raw posterior
        samples. The box edges and whiskers therefore only approximate the
        posterior spread.
    """
    indicator_labels = {
        "cumulative_diseased": "Cumulative number of new TB episodes",
        "cumulative_deaths": "Cumulative TB-related deaths",
    }

    records = []
    for scenario_label, df in diff_output["abs"].items():
        # Keys look like "detection_reduction_0.3"; the value is after the last "_".
        reduction_value = float(scenario_label.rsplit("_", 1)[-1])
        for indicator, label in indicator_labels.items():
            for quantile, val in df.loc[indicator].items():
                records.append({
                    "reduction_value": reduction_value,
                    "indicator": label,
                    "quantile": quantile,
                    "value": val,
                })

    df_plot = pd.DataFrame(records)
    year_label = f"{year:g}"  # 2035.0 -> "2035"

    fig = px.box(
        df_plot,
        x="reduction_value",
        y="value",
        color="indicator",
        points=False,
        labels={
            "reduction_value": "Detection Reduction Value",
            "value": f"Absolute Difference ({year_label})",
            "indicator": "Outcome",
        },
        title=(
            "Absolute Differences in Cumulative TB Outcomes by Detection Reduction "
            f"({year_label})"
        ),
    )

    fig.update_layout(boxmode="group", legend_title_text="Outcome")
    fig.show()

In [None]:
# Visualise the absolute cumulative differences for each detection reduction value.
plot_abs_diff_boxplot(diff_output=covid_uncertainties)