In [None]:
import pickle
import warnings
from pathlib import Path
from typing import Dict, List, Optional

# Suppress all warnings; kept ahead of the third-party imports so it also
# covers anything they emit while being imported.
warnings.filterwarnings("ignore")

import arviz as az
import pandas as pd
from estival.sampling import tools as esamp

from tbdynamics.calibration.plotting import plot_output_ranges
from tbdynamics.calibration.utils import get_bcm
from tbdynamics.constants import quantiles
from tbdynamics.inputs import DATA_PATH, DOCS_PATH, load_targets


In [2]:
# Folder under the package data directory where calibration outputs are stored.
OUT_PATH = DATA_PATH / "outputs"

# Inference data read back from the NetCDF file in OUT_PATH.
idata = az.from_netcdf(OUT_PATH / "inference_data1.nc")

# Parameter values handed to get_bcm alongside the sampled ones.
params = dict(
    start_population_size=2000000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

# Target data, passed to plot_output_ranges by the plotting cells.
targets = load_targets()

In [3]:
def calculate_scenario_outputs(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    indicators: Optional[List[str]] = None,
    detection_multipliers: Optional[List[float]] = None,
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """
    Run the model over the supplied samples for the baseline and for each
    improved-detection scenario, and return quantile outputs per scenario.

    Args:
        params: Dictionary of model parameters passed to ``get_bcm``.
        idata_extract: InferenceData object holding the samples to run.
        indicators: Output names to keep for the non-baseline scenarios.
            Defaults to ``["incidence", "mortality_raw"]`` when ``None``.
        detection_multipliers: Multipliers for improved detection; each one is
            passed to ``get_bcm`` as its third argument and yields one scenario.
            Defaults to ``[2.0, 5.0, 12.0]`` when ``None``.

    Returns:
        A dictionary keyed by scenario name. ``"base_scenario"`` holds the
        quantiles for every output; each
        ``"increase_case_detection_by_<multiplier>"`` entry (with ``.`` replaced
        by ``_``, e.g. ``increase_case_detection_by_2_0``) holds only the
        requested ``indicators``.
    """
    # Defaults are resolved here rather than in the signature so that no
    # mutable list object is shared between calls.
    if indicators is None:
        indicators = ["incidence", "mortality_raw"]
    if detection_multipliers is None:
        detection_multipliers = [2.0, 5.0, 12.0]

    # Always index with a real list: pandas treats a tuple as a single
    # (multi-level) column key instead of a collection of column names.
    indicator_columns = list(indicators)

    # Fixed scenario configuration shared by the baseline and every scenario.
    scenario_config = {"detection_reduction": True, "contact_reduction": False}

    def _quantiles_for(bcm):
        """Run the model for every sample and summarise the results as quantiles."""
        sample_results = esamp.model_results_for_samples(idata_extract, bcm).results
        return esamp.quantiles_for_results(sample_results, quantiles)

    # Baseline keeps every output (no indicator filtering).
    scenario_outputs = {
        "base_scenario": _quantiles_for(get_bcm(params, scenario_config)),
    }

    for multiplier in detection_multipliers:
        scenario_key = f"increase_case_detection_by_{multiplier}".replace(".", "_")
        scenario_quantiles = _quantiles_for(
            get_bcm(params, scenario_config, multiplier)
        )
        # Scenarios only keep the requested indicators.
        scenario_outputs[scenario_key] = scenario_quantiles[indicator_columns]

    return scenario_outputs

In [4]:
# Baseline plus the default detection-multiplier scenarios for the loaded samples.
outputs = calculate_scenario_outputs(params=params, idata_extract=idata)

# Optional: persist the quantile outputs so they can be reloaded later.
# with open(OUT_PATH / 'quant_outputs.pkl', 'wb') as f:
#     pickle.dump(outputs, f)

In [5]:
# with open(OUT_PATH /'quant_outputs.pkl', 'rb') as f:
#     outputs = pickle.load(f)

In [6]:
# Baseline quantiles against the targets for three outputs.
# NOTE(review): the trailing positionals (1, 2010, 2025) are presumably the
# column count and the start/end year — confirm against plot_output_ranges.
target_plot = plot_output_ranges(
    outputs["base_scenario"],
    targets,
    ["total_population", "notification", "adults_prevalence_pulmonary"],
    1,
    2010,
    2025,
)

In [None]:
# Bare last expression: the notebook renders the object returned by plot_output_ranges.
target_plot

In [8]:
# target_plot.write_image(DOCS_PATH / "targets1.png", scale=3)

In [9]:
# Same three outputs as the targets plot, but with bounds 1800 and 2010 and
# the history flag switched on.
# NOTE(review): the trailing positionals (1, 1800, 2010) are presumably the
# column count and the start/end year — confirm against plot_output_ranges.
target_plot_history = plot_output_ranges(
    outputs["base_scenario"],
    targets,
    ["total_population", "notification", "adults_prevalence_pulmonary"],
    1,
    1800,
    2010,
    history=True,
)

In [None]:
# Bare last expression: the notebook renders the object returned by plot_output_ranges.
target_plot_history

In [11]:
# target_plot_history.write_image(DOCS_PATH / 'targets_history.png', scale=3)

In [12]:
# Baseline quantiles for four further outputs.
# NOTE(review): the trailing positionals (2, 2010, 2025) are presumably the
# column count and the start/end year — confirm against plot_output_ranges.
compare_target_plot = plot_output_ranges(
    outputs["base_scenario"],
    targets,
    [
        "incidence",
        "mortality_raw",
        "prevalence_smear_positive",
        "percentage_latent",
    ],
    2,
    2010,
    2025,
)

In [13]:
# compare_target_plot.write_image(DOCS_PATH / "non_targets.png", scale=3)

In [None]:
# Bare last expression: the notebook renders the object returned by plot_output_ranges.
compare_target_plot

In [15]:
# Baseline detection_rate output, drawn with show_title=False.
# NOTE(review): the trailing positionals (1, 1981, 2025) are presumably the
# column count and the start/end year — confirm against plot_output_ranges.
screening_plot = plot_output_ranges(
    outputs["base_scenario"],
    targets,
    ["detection_rate"],
    1,
    1981,
    2025,
    show_title=False,
)

In [16]:
# screening_plot.write_image(DOCS_PATH / 'screening_plot.png', scale =3)

In [None]:
# Bare last expression: the notebook renders the object returned by plot_output_ranges.
screening_plot