In [None]:
import warnings

# NOTE(review): this installs a blanket "ignore" filter for every warning category, so
# deprecation and numerical warnings raised by anything executed after this line are hidden.
warnings.simplefilter("ignore")
from tbdynamics.vietnam.calibration.utils import (
    calculate_scenario_outputs,
    calculate_covid_diff_cum_quantiles,
    calculate_scenario_diff_cum_quantiles,
    calculate_diff_cum_detection_reduction
)
from tbdynamics.calibration.plotting import plot_sensitivity_subplots
from tbdynamics.settings import DATA_PATH, BASE_PATH, VN_PATH
from tbdynamics.constants import QUANTILES
from tbdynamics.tools.inputs import load_params, load_targets, matrix
from tbdynamics.vietnam.model import build_model
from pathlib import Path
import arviz as az
import numpy as np
from typing import Union, List
from scipy.stats import qmc
import pandas as pd
import plotly.express as px
from typing import Union

In [None]:
# Route pandas' `.plot` accessor through plotly instead of the default backend.
pd.set_option("plotting.backend", "plotly")

In [None]:
# Both locations are resolved two directories above the current working directory,
# so they depend on where the notebook kernel was started.
RUN_PATH = Path.cwd().parent.parent / "runs" / "r0205"
OUT_PATH = Path.cwd().parent.parent / "data" / "outputs" / "vietnam"

In [None]:
# Load the stored calibration output for this run.
idata_raw = az.from_netcdf(RUN_PATH / "calib_full_out.nc")
# Keep only draws whose `draw` coordinate is 50000 or later (earlier draws are discarded as burn-in).
burnt_idata = idata_raw.sel(draw=slice(50000, None))
# Extract 1000 samples from the retained draws.
# NOTE(review): no `rng` is passed, so the subset may differ between runs — confirm whether that matters.
idata = az.extract(burnt_idata, num_samples=1000)

In [None]:
# Parameters supplied directly rather than taken from the calibration summary.
init_params = dict(
    start_population_size=2000000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

# Fixed model parameters read from the Vietnam parameter file.
fixed_params = load_params(VN_PATH / "params.yml")

# COVID-19 effect switches handed to `build_model`: only the detection reduction is enabled.
covid_effects = dict(
    detection_reduction=True,
    contact_reduction=False,
)

In [None]:
# summary = az.summary(idata_raw, hdi_prob=0.95)
# Extract mean, lower, and upper bounds of 95% CrI
# mean = summary.loc["detection_reduction", "mean"]
# lower = summary.loc["detection_reduction", "hdi_2.5%"]
# upper = summary.loc["detection_reduction", "hdi_97.5%"]

In [None]:
# mle_params = summary["mean"].to_dict()
# params = init_params | mle_params
# model=build_model(fixed_params, matrix,covid_effects)
# model.run(params)
# base_results = model.get_derived_outputs_df()
# yearly_base = base_results.loc[
#         (base_results.index >= 2020) & (base_results.index % 1 == 0)
#     ]
# base_cum_diseased = yearly_base["incidence_raw"].cumsum()
# base_cum_deaths = yearly_base["mortality_raw"].cumsum()

In [None]:
def extract_param_ranges(idata, hdi_prob=0.95):
    """Summarise posterior samples into per-parameter point values and HDI bounds.

    Args:
        idata: Posterior samples in any form accepted by ``az.summary``.
        hdi_prob: Probability mass of the highest-density interval requested
            from ``az.summary`` (0.95 by default).

    Returns:
        A ``(mle_params, ranges)`` tuple where

        * ``mle_params`` maps each row label of the summary to the value in
          its ``"mean"`` column. Note that, despite the name, these are
          posterior means rather than maximum-likelihood estimates.
        * ``ranges`` maps each row label to a ``(lower, upper)`` tuple taken
          from the two HDI columns of the summary.

    Raises:
        KeyError: If the summary does not contain the expected ``"mean"`` or
            HDI columns.
    """
    summary = az.summary(idata, hdi_prob=hdi_prob)
    mle_params = summary["mean"].to_dict()

    # ArviZ names the HDI columns with "%g"-style percentages ("hdi_3%", "hdi_2.5%"),
    # so a fixed one-decimal format ("hdi_3.0%") misses whenever the bound is a whole
    # percentage (e.g. hdi_prob=0.94 or 0.9). Build the labels the same way ArviZ does.
    alpha = 1 - hdi_prob
    lower_col = f"hdi_{100 * alpha / 2:g}%"
    upper_col = f"hdi_{100 * (1 - alpha / 2):g}%"

    ranges = {
        param: (summary.loc[param, lower_col], summary.loc[param, upper_col])
        for param in summary.index
    }
    return mle_params, ranges

In [None]:
def _cumulative_outcomes_at(detection_multiplier, run_params, fixed_params, target_year):
    """Build and run one model, returning cumulative (incidence, mortality) at ``target_year``."""
    model = build_model(fixed_params, matrix, covid_effects, detection_multiplier)
    model.run(run_params)
    # Fetch the derived outputs once and reuse the frame for both the mask and the lookup.
    outputs = model.get_derived_outputs_df()
    # Keep only whole-number index values from 2020 onwards before accumulating.
    yearly = outputs.loc[(outputs.index >= 2020) & (outputs.index % 1 == 0)]
    return (
        yearly["incidence_raw"].cumsum().loc[target_year],
        yearly["mortality_raw"].cumsum().loc[target_year],
    )


def run_sensitivity_analysis(
    params_to_vary: Union[str, List[str]],
    init_params: dict,
    fixed_params: dict,
    mle_params: dict,
    param_ranges: dict,
    improved_detection_multiplier=None,
    n_samples: int = 100,
    target_year: int = 2035,
    seed: Union[int, None] = None,
):
    """One-at-a-time sensitivity of the improved-detection effect to single parameters.

    For every parameter in ``params_to_vary``, ``n_samples`` values are drawn by
    Latin hypercube sampling between the parameter's bounds. For each value a
    baseline model and an improved-detection model are run with all other
    parameters held at ``init_params | mle_params``, and the difference in
    cumulative ``incidence_raw`` and ``mortality_raw`` (summed over whole-year
    rows from 2020 up to ``target_year``) is recorded.

    The models are built with the module-level ``matrix`` and ``covid_effects``
    objects, which must be defined before this function is called.

    Args:
        params_to_vary: A parameter name or a list of parameter names.
        init_params: Parameters merged first into every run.
        fixed_params: Passed unchanged to ``build_model``.
        mle_params: Point values for the calibrated parameters; they override
            ``init_params`` and are in turn overridden by the sampled value.
        param_ranges: Maps each parameter name to a ``(low, high)`` pair.
        improved_detection_multiplier: Passed to ``build_model`` for the
            improved-detection run. The baseline run always passes ``None``
            (this function's own default for the argument), which is assumed
            to mean "no improved detection" — confirm against ``build_model``.
            If this argument is itself ``None`` both runs are identical and
            every difference is zero.
        n_samples: Number of sampled values per parameter.
        target_year: Index value at which the cumulative sums are read.
        seed: Optional seed for the Latin hypercube sampler. The same seed is
            used for every parameter. ``None`` (default) keeps the draws
            non-deterministic.

    Returns:
        A dict mapping each varied parameter name to a DataFrame with columns
        ``value``, ``diff_cum_diseased`` and ``diff_cum_deaths``; differences
        are computed as improved minus baseline.

    Raises:
        KeyError: If a parameter in ``params_to_vary`` is missing from
            ``param_ranges`` (raised before any model is run).
    """
    if isinstance(params_to_vary, str):
        params_to_vary = [params_to_vary]

    # Resolve all bounds up front so a missing parameter fails before any model run.
    sub_ranges = {param: param_ranges[param] for param in params_to_vary}
    results_dict = {}

    for param in params_to_vary:
        low, high = sub_ranges[param]
        sampler = qmc.LatinHypercube(d=1, seed=seed)
        samples = qmc.scale(sampler.random(n=n_samples), [low], [high]).flatten()

        param_results = []
        for val in samples:
            sample_params = init_params | mle_params | {param: val}

            # Baseline: no improved-detection multiplier. Building it with the same
            # multiplier as the scenario run would make both runs identical.
            cum_diseased_base, cum_deaths_base = _cumulative_outcomes_at(
                None, sample_params, fixed_params, target_year
            )

            # Scenario: improved detection.
            cum_diseased_improved, cum_deaths_improved = _cumulative_outcomes_at(
                improved_detection_multiplier, sample_params, fixed_params, target_year
            )

            # Differences are improved minus baseline.
            param_results.append({
                "value": val,
                "diff_cum_diseased": cum_diseased_improved - cum_diseased_base,
                "diff_cum_deaths": cum_deaths_improved - cum_deaths_base,
            })

        results_dict[param] = pd.DataFrame(param_results)

    return results_dict

In [None]:
# mle_params, full_param_ranges = extract_param_ranges(burnt_idata)

In [None]:
# mle_params

In [None]:
# df = run_sensitivity_analysis(
#     params_to_vary=["contact_rate", "smear_positive_death_rate"],
#     init_params=init_params,
#     fixed_params=fixed_params,
#     mle_params=mle_params,
#     param_ranges=full_param_ranges,
#     improved_detection_multiplier = 5.0,
#     n_samples=20, 
# )

In [None]:
# The cached results were written with:
# pd.to_pickle(df, DATA_PATH / "outputs/vietnam/sensitivity_results.csv")
# NOTE(review): despite the ".csv" suffix the file is a pickle (it is read with
# `pd.read_pickle`). Unpickling can execute arbitrary code, so only load a file
# produced by a trusted run.
df_dict = pd.read_pickle(DATA_PATH / "outputs" / "vietnam" / "sensitivity_results.csv")

In [None]:
# Entries of the loaded results that should be left out of the plot.
keys_to_remove = ["smear_positive_death_rate"]
# `df` is a dict keyed like `df_dict`, minus the excluded entries (not a single DataFrame).
df = {
    name: frame
    for name, frame in df_dict.items()
    if name not in keys_to_remove
}

In [None]:
# Plot the retained sensitivity results; `df` is the dict built from `df_dict`
# with the entries listed in `keys_to_remove` dropped. The call is the cell's
# last expression, so whatever it returns is displayed as the cell output.
plot_sensitivity_subplots(df)

In [None]:
# df_smear = run_sensitivity_analysis(
#     params_to_vary=["contact_rate"],
#     init_params=init_params,
#     fixed_params=fixed_params,
#     mle_params=mle_params,
#     param_ranges=full_param_ranges,
#     improved_detection_multiplier = 5.0,
#     n_samples=20, 
# )