In [None]:
import warnings

warnings.filterwarnings("ignore")
from pathlib import Path
from tbdynamics.calibration.plotting import (
    plot_covid_configs_comparison_box,
    plot_outputs_for_covid,
    plot_scenario_output_ranges_by_col,
    plot_detection_scenarios_comparison_box
)
from tbdynamics.inputs import DATA_PATH, DOCS_PATH
from tbdynamics.calibration.utils import get_bcm
from tbdynamics.inputs import load_targets
from estival.sampling import tools as esamp
import arviz as az
import pickle
from typing import Dict, List
import pandas as pd
from tbdynamics.constants import quantiles


In [2]:
# Directory holding the saved calibration / scenario outputs used throughout.
OUT_PATH = Path(DATA_PATH) / "outputs"

# Fixed model parameters passed to get_bcm together with the posterior samples.
params = dict(
    start_population_size=2000000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

# Calibration posterior (sample source for the model runs below) and targets.
idata = az.from_netcdf(OUT_PATH / "inference_data1.nc")
targets = load_targets()

### Impacts of COVID-19

In [None]:
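# NOTE: `covid_diff` and `notif_covid_outputs` are required by the plotting cells
# further down. Uncomment and run these lines to create them. Note that
# `calculate_covid_diff_cum_quantiles` and `calculate_notifications_for_covid`
# are not imported in the first cell, so add their import before running.
# covid_diff = calculate_covid_diff_cum_quantiles(params, idata)
# notif_covid_outputs = calculate_notifications_for_covid(params, idata)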

In [3]:
# NOTE: unpickling can execute arbitrary code; only load output files you produced.
notif_outputs = pickle.loads((OUT_PATH / "notif_for_covid_with_ll.pkl").read_bytes())

In [None]:
# Inspect the stored log-density table for the "no_covid" configuration
# (its "logposterior", "logprior" and "loglikelihood" columns feed the LOO cells).
notif_outputs['no_covid']['ll_res']

In [6]:
import xarray as xr
import arviz as az

In [7]:
ds = xr.Dataset.from_dataframe(notif_outputs['no_covid']['ll_res'])

In [9]:
idata = az.from_dict(
    posterior={"logposterior": ds["logposterior"]},
    prior={"logprior": ds["logprior"]},
    log_likelihood={"total_loglikelihood": ds["loglikelihood"]}
)

In [None]:
idata

In [None]:
az.loo(idata)

### Notifications under different COVID-19 assumptions

In [None]:
# FIXME: `notif_covid_outputs` is only created by the commented-out cell under
# "Impacts of COVID-19"; on a fresh kernel this line raises NameError.
# NOTE(review): confirm whether the pickled `notif_outputs` has the structure
# plot_outputs_for_covid expects before substituting it here.
plot_outputs_for_covid(notif_covid_outputs,targets)

### Long-term impacts of COVID-19

In [None]:
# FIXME: `covid_diff` is only created by the commented-out cell under
# "Impacts of COVID-19"; on a fresh kernel this line raises NameError.
plot_covid_configs_comparison_box(covid_diff)

### Future projection

In [14]:
# NOTE: unpickling can execute arbitrary code; only load output files you produced.
outputs = pickle.loads((OUT_PATH / "quant_outputs.pkl").read_bytes())

In [None]:
# Plot the scenario output ranges stored in the quantile outputs loaded above.
plot_scenario_output_ranges_by_col(outputs)

In [11]:
# Output name -> model result column whose whole-year values are accumulated.
_CUMULATIVE_OUTPUTS = {
    "cumulative_diseased": "incidence_raw",
    "cumulative_deaths": "mortality_raw",
}


def _cumulative_outputs(
    results: pd.DataFrame, start_time: int
) -> Dict[str, pd.DataFrame]:
    """Cumulative sums of each _CUMULATIVE_OUTPUTS column over whole-year times >= start_time."""
    yearly = results.loc[(results.index >= start_time) & (results.index % 1 == 0)]
    return {name: yearly[column].cumsum() for name, column in _CUMULATIVE_OUTPUTS.items()}


def _quantile_table(values: pd.DataFrame, years: List[int], quantile_levels) -> pd.DataFrame:
    """Table of quantiles across samples: one row per year, one column per quantile level."""
    return pd.DataFrame(
        {q: [values.loc[year].quantile(q) for year in years] for q in quantile_levels},
        index=years,
    )


def calculate_scenario_diff_cum_quantiles(
    params: Dict[str, float],
    idata_extract: "az.InferenceData",
    detection_multipliers: List[float],
    cumulative_start_time: int = 2020,
    covid_choice: int = 2,
    years: List[int] = [2021, 2022, 2025, 2030, 2035],
) -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """
    Quantiles of cumulative incidence/death differences between improved-detection
    scenarios and a base scenario.

    The base model is built with ``get_bcm(params, covid_config)``. Each scenario is
    built with ``get_bcm(params, covid_config, multiplier)``. Both are run over the
    samples in ``idata_extract``. Cumulative values are sums of ``incidence_raw`` and
    ``mortality_raw`` over whole-year time points from ``cumulative_start_time``
    (inclusive).

    Args:
        params: Model parameters passed to ``get_bcm``.
        idata_extract: InferenceData holding the samples to run the model over.
        detection_multipliers: Case-detection multipliers, one scenario each.
        cumulative_start_time: First time point included in the cumulative sums.
        covid_choice: COVID-19 configuration. ``1`` applies detection and contact
            reduction; ``2`` applies detection reduction only.
        years: Years (time index values) at which quantiles are reported.

    Returns:
        A dict keyed by ``"increase_case_detection_by_<multiplier>"`` (with ``.``
        replaced by ``_``). Each value maps ``"abs"``/``"rel"`` to a dict keyed by
        ``"cumulative_diseased"``/``"cumulative_deaths"``, whose values are DataFrames
        indexed by ``years`` with one column per entry of ``quantiles``.
        Absolute difference is ``scenario - base``; relative difference is
        ``(scenario - base) / base``.

    Raises:
        ValueError: If ``covid_choice`` is not 1 or 2.
    """
    if covid_choice == 1:
        covid_config = {"detection_reduction": True, "contact_reduction": True}
    elif covid_choice == 2:
        covid_config = {"detection_reduction": True, "contact_reduction": False}
    else:
        raise ValueError("Invalid covid_choice. Choose 1 or 2.")

    # Base scenario (no detection multiplier) is computed once and reused.
    base_bcm = get_bcm(params, covid_config)
    base_results = esamp.model_results_for_samples(idata_extract, base_bcm).results
    base_cumulative = _cumulative_outputs(base_results, cumulative_start_time)

    detection_diff_results = {}
    for multiplier in detection_multipliers:
        scenario_bcm = get_bcm(params, covid_config, multiplier)
        scenario_results = esamp.model_results_for_samples(idata_extract, scenario_bcm).results
        scenario_cumulative = _cumulative_outputs(scenario_results, cumulative_start_time)

        diff_quantiles_abs = {}
        diff_quantiles_rel = {}
        for ind, base_values in base_cumulative.items():
            abs_diff = scenario_cumulative[ind] - base_values
            rel_diff = abs_diff / base_values
            diff_quantiles_abs[ind] = _quantile_table(abs_diff, years, quantiles)
            diff_quantiles_rel[ind] = _quantile_table(rel_diff, years, quantiles)

        # Dots in the multiplier (e.g. 2.0) are replaced so keys are identifier-like.
        scenario_key = f"increase_case_detection_by_{multiplier}".replace(".", "_")
        detection_diff_results[scenario_key] = {
            "abs": diff_quantiles_abs,
            "rel": diff_quantiles_rel,
        }

    return detection_diff_results

In [12]:
# Differences in cumulative incidence and deaths versus the base scenario,
# for detection multipliers 2, 5 and 12 (default COVID-19 configuration).
scenarios_diff = calculate_scenario_diff_cum_quantiles(
    params,
    idata,
    detection_multipliers=[2.0, 5.0, 12.0],
)

In [None]:
# Compare the detection-improvement scenarios computed in the previous cell.
plot_detection_scenarios_comparison_box(scenarios_diff)