In [None]:
import warnings

warnings.filterwarnings("ignore")
from pathlib import Path
from tbdynamics.calibration.plotting import (
    plot_covid_configs_comparison_box,
    plot_outputs_for_covid,
    plot_scenario_output_ranges_by_col,
    plot_detection_scenarios_comparison_box
)
from tbdynamics.inputs import DATA_PATH, DOCS_PATH
from tbdynamics.calibration.utils import get_bcm
from tbdynamics.inputs import load_targets
from estival.sampling import tools as esamp
import arviz as az
import pickle
from typing import Dict, List
import pandas as pd
from tbdynamics.constants import quantiles


In [3]:
# Directory holding the saved outputs that this notebook reads.
OUT_PATH = Path(DATA_PATH / "outputs")

# Model parameters handed unchanged to every `get_bcm(...)` call below.
params = dict(
    start_population_size=2_000_000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

# Inference data read from netCDF; it is the sample source for all model runs below.
idata = az.from_netcdf(OUT_PATH / "inference_data1.nc")
# Targets are only used for plotting (passed to `plot_outputs_for_covid`).
targets = load_targets()

### Impacts of COVID-19

In [4]:
def calculate_covid_diff_cum_quantiles(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    cumulative_start_time: float = 2020.0,
    covid_analysis: int = 2,
    years: List[float] = [2021.0, 2022.0, 2025.0, 2030.0, 2035.0],
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """
    Compare cumulative disease and death counts between a COVID-19 scenario and the
    no-COVID baseline, summarised as quantiles across samples.

    The baseline (no detection or contact reduction) and the scenario selected by
    ``covid_analysis`` are each run for the samples in ``idata_extract``. For both runs
    the ``incidence_raw`` and ``mortality_raw`` outputs at whole-year time points from
    ``cumulative_start_time`` onwards are cumulatively summed, and the scenario-minus-
    baseline difference is reported both in absolute terms and relative to the baseline.

    Args:
        params: Dictionary containing model parameters (passed to ``get_bcm``).
        idata_extract: InferenceData object providing the samples to run.
        cumulative_start_time: First time point included in the cumulative sums.
        covid_analysis: Which scenario is compared against the baseline:
            1 = detection reduction and contact reduction,
            2 = detection reduction only.
        years: Time points at which the difference quantiles are evaluated. Each must
            be present in the filtered yearly results, otherwise ``.loc`` raises KeyError.

    Returns:
        ``{"abs": {...}, "rel": {...}}`` where each inner dictionary maps
        ``"cumulative_diseased"`` and ``"cumulative_deaths"`` to a DataFrame indexed by
        ``years`` with one column per entry of ``quantiles``.

    Raises:
        ValueError: If ``covid_analysis`` is not 1 or 2.
    """

    # Validate up front so a bad argument fails before any model run.
    if covid_analysis not in [1, 2]:
        raise ValueError("Invalid value for covid_analysis. Must be 1 or 2.")

    # Scenario definitions; the list position is what `covid_analysis` indexes into.
    covid_configs = [
        # 0: baseline, no COVID-19 effects
        {"detection_reduction": False, "contact_reduction": False},
        # 1: detection and contact reduction
        {"detection_reduction": True, "contact_reduction": True},
        # 2: detection reduction only
        {"detection_reduction": True, "contact_reduction": False},
    ]
    indicators = ("cumulative_diseased", "cumulative_deaths")

    def _cumulative_outputs(covid_effects: Dict[str, bool]) -> Dict[str, pd.DataFrame]:
        """Run one scenario and return its cumulative yearly incidence and mortality."""
        bcm = get_bcm(params, covid_effects)
        spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm).results

        # Keep only whole-year time points (exact float test) from the start time onwards.
        yearly_data = spaghetti_res.loc[
            (spaghetti_res.index >= cumulative_start_time)
            & (spaghetti_res.index % 1 == 0)
        ]
        return {
            "cumulative_diseased": yearly_data["incidence_raw"].cumsum(),
            "cumulative_deaths": yearly_data["mortality_raw"].cumsum(),
        }

    def _quantiles_by_year(frame: pd.DataFrame) -> pd.DataFrame:
        """Summarise each requested year's row of `frame` at every quantile."""
        return pd.DataFrame(
            {
                quantile: [frame.loc[year].quantile(quantile) for year in years]
                for quantile in quantiles
            },
            index=years,
        )

    # Only the baseline and the selected scenario are ever compared, so only those two
    # are run; the remaining configuration would be a wasted model run.
    baseline = _cumulative_outputs(covid_configs[0])
    scenario = _cumulative_outputs(covid_configs[covid_analysis])

    abs_diff = {ind: scenario[ind] - baseline[ind] for ind in indicators}
    rel_diff = {ind: abs_diff[ind] / baseline[ind] for ind in indicators}

    return {
        "abs": {ind: _quantiles_by_year(abs_diff[ind]) for ind in indicators},
        "rel": {ind: _quantiles_by_year(rel_diff[ind]) for ind in indicators},
    }

In [6]:
def calculate_notifications_for_covid(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """
    Run the model under four COVID-19 assumptions and collect notification quantiles.

    The four scenarios are every on/off combination of detection reduction and contact
    reduction. Each is run for the samples in ``idata_extract`` and summarised at the
    module-level ``quantiles``.

    Args:
        params: Dictionary containing model parameters (passed to ``get_bcm``).
        idata_extract: InferenceData object providing the samples to run.

    Returns:
        A dictionary keyed by scenario name. Each value is a dictionary with:
        ``"indicator_outputs"`` - a dictionary holding the ``"notification"`` quantiles
        (empty if that output is absent from the quantile results), and
        ``"ll_res"`` - the ``extras`` attribute of the sampled model results
        (presumably log-likelihood values, judging by the name; confirm against estival).
    """
    target_indicator = "notification"

    # (scenario name, detection_reduction, contact_reduction), in output order.
    scenario_flags = (
        ("no_covid", False, False),
        ("case_detection_reduction_only", True, False),
        ("contact_reduction_only", False, True),
        ("detection_and_contact_reduction", True, True),
    )

    outputs_by_scenario = {}
    for scenario_name, detection_on, contact_on in scenario_flags:
        effects = {
            "detection_reduction": detection_on,
            "contact_reduction": contact_on,
        }

        # One full run over the samples for this scenario.
        sampled = esamp.model_results_for_samples(
            idata_extract, get_bcm(params, effects)
        )
        per_sample_results = sampled.results
        extras = sampled.extras
        summary = esamp.quantiles_for_results(per_sample_results, quantiles)

        # Keep the notification output only, and only when it is present.
        kept = {}
        if target_indicator in summary:
            kept[target_indicator] = summary[target_indicator]

        outputs_by_scenario[scenario_name] = {
            "indicator_outputs": kept,
            "ll_res": extras,
        }

    return outputs_by_scenario

In [None]:
# Both calls rely on the functions' default settings; each triggers full model runs.
covid_diff = calculate_covid_diff_cum_quantiles(params=params, idata_extract=idata)
notif_covid_outputs = calculate_notifications_for_covid(
    params=params, idata_extract=idata
)

### Notifications under different COVID-19 assumptions

In [None]:
# Plot the per-scenario notification quantiles together with the loaded targets.
plot_outputs_for_covid(notif_covid_outputs, targets)

### Long-term impacts of COVID-19

In [None]:
# Visualise the absolute/relative difference quantiles held in `covid_diff`.
plot_covid_configs_comparison_box(covid_diff)

### Future projection

In [14]:
# NOTE: unpickling can execute arbitrary code; only load files produced by this project.
with (OUT_PATH / "quant_outputs.pkl").open("rb") as pickle_file:
    outputs = pickle.load(pickle_file)

In [None]:
# Plot the pre-computed outputs loaded from `quant_outputs.pkl`.
plot_scenario_output_ranges_by_col(outputs)

In [11]:
def calculate_scenario_diff_cum_quantiles(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    detection_multipliers: List[float],
    cumulative_start_time: int = 2020,
    covid_choice: int = 2,
    years: List[int] = [2021, 2022, 2025, 2030, 2035],
) -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
    """
    Compare cumulative disease and death counts under improved case detection against
    a base run, summarised as quantiles across samples.

    A base model (no detection multiplier) and one model per entry of
    ``detection_multipliers`` are run for the samples in ``idata_extract``, all under the
    COVID-19 configuration selected by ``covid_choice``. For every run the
    ``incidence_raw`` and ``mortality_raw`` outputs at whole-year time points from
    ``cumulative_start_time`` onwards are cumulatively summed; each scenario is then
    expressed as a difference from the base run, in absolute terms and relative to it.

    Args:
        params: Dictionary containing model parameters (passed to ``get_bcm``).
        idata_extract: InferenceData object providing the samples to run.
        detection_multipliers: Multipliers for improved detection; each is passed as the
            third argument to ``get_bcm``. An empty sequence yields an empty result
            without running any model.
        cumulative_start_time: First time point included in the cumulative sums.
        covid_choice: COVID-19 configuration applied to every run:
            1 = detection reduction and contact reduction,
            2 = detection reduction only.
        years: Time points at which the difference quantiles are evaluated. Each must
            be present in the filtered yearly results, otherwise ``.loc`` raises KeyError.

    Returns:
        A dictionary keyed by ``"increase_case_detection_by_<multiplier>"`` (with ``.``
        replaced by ``_``). Each value is ``{"abs": {...}, "rel": {...}}`` whose inner
        dictionaries map ``"cumulative_diseased"`` and ``"cumulative_deaths"`` to a
        DataFrame indexed by ``years`` with one column per entry of ``quantiles``.

    Raises:
        ValueError: If ``covid_choice`` is not 1 or 2.
    """

    # Select the COVID-19 configuration shared by the base run and all scenarios.
    if covid_choice == 1:
        covid_config = {"detection_reduction": True, "contact_reduction": True}
    elif covid_choice == 2:
        covid_config = {"detection_reduction": True, "contact_reduction": False}
    else:
        # The message names the actual parameter (it previously said `scenario_choice`).
        raise ValueError("Invalid covid_choice. Choose 1 or 2.")

    # Nothing to compare: skip the expensive base run entirely.
    multipliers = list(detection_multipliers)
    if not multipliers:
        return {}

    indicators = ("cumulative_diseased", "cumulative_deaths")

    def _cumulative_outputs(bcm) -> Dict[str, pd.DataFrame]:
        """Run `bcm` for the samples and return cumulative yearly incidence and mortality."""
        results = esamp.model_results_for_samples(idata_extract, bcm).results

        # Keep only whole-year time points (exact float test) from the start time onwards.
        yearly_data = results.loc[
            (results.index >= cumulative_start_time) & (results.index % 1 == 0)
        ]
        return {
            "cumulative_diseased": yearly_data["incidence_raw"].cumsum(),
            "cumulative_deaths": yearly_data["mortality_raw"].cumsum(),
        }

    def _quantiles_by_year(frame: pd.DataFrame) -> pd.DataFrame:
        """Summarise each requested year's row of `frame` at every quantile."""
        return pd.DataFrame(
            {
                quantile: [frame.loc[year].quantile(quantile) for year in years]
                for quantile in quantiles
            },
            index=years,
        )

    # Base scenario (without improved detection).
    base = _cumulative_outputs(get_bcm(params, covid_config))

    detection_diff_results = {}
    for multiplier in multipliers:
        # Improved detection scenario.
        scenario = _cumulative_outputs(get_bcm(params, covid_config, multiplier))

        abs_diff = {ind: scenario[ind] - base[ind] for ind in indicators}
        rel_diff = {ind: abs_diff[ind] / base[ind] for ind in indicators}

        # Dots are replaced so that e.g. 2.0 becomes "..._by_2_0".
        scenario_key = f"increase_case_detection_by_{multiplier}".replace(".", "_")
        detection_diff_results[scenario_key] = {
            "abs": {ind: _quantiles_by_year(abs_diff[ind]) for ind in indicators},
            "rel": {ind: _quantiles_by_year(rel_diff[ind]) for ind in indicators},
        }

    return detection_diff_results

In [12]:
# Compare 2x, 5x and 12x detection multipliers against the base run (default settings).
scenarios_diff = calculate_scenario_diff_cum_quantiles(
    params=params,
    idata_extract=idata,
    detection_multipliers=[2.0, 5.0, 12.0],
)

In [None]:
# Visualise the per-multiplier difference quantiles held in `scenarios_diff`.
plot_detection_scenarios_comparison_box(scenarios_diff)