In [None]:
from typing import List, Dict
import arviz as az
import estival.priors as esp
from estival.sampling import tools as esamp
import plotly.graph_objects as go
import plotly.express as px
from tbdynamics.tools.utils import round_sigfig
import pandas as pd
from tbdynamics.settings import OUT_PATH
from tbdynamics.vietnam.calibration.utils import get_bcm
from plotly.subplots import make_subplots

# Use plotly as the backend for pandas' built-in plotting in this notebook.
pd.set_option("plotting.backend", "plotly")

from tbdynamics.settings import OUT_PATH

In [None]:
def calculate_covid_cum_diff(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    cumulative_start_time: float = 2020.0,
    years: List[float] = [2021.0, 2022.0, 2025.0, 2030.0, 2035.0],
) -> Dict[str, Dict[str, pd.DataFrame]]:
    """
    Run the "no_covid" and "detection" scenarios and accumulate yearly incidence and mortality.

    For each scenario a model is built with `get_bcm`, outputs are obtained for the supplied
    samples, and the rows at whole-number index values from `cumulative_start_time` onwards are
    kept. The "incidence_raw" and "mortality_raw" outputs of those rows are cumulatively summed
    down the index.

    Args:
        params: Dictionary of model parameters passed to `get_bcm`.
        idata_extract: Inference data passed to `esamp.model_results_for_samples`.
        cumulative_start_time: First index value (year) included in the cumulative sums.
        years: Not used in the returned results. Retained so that existing callers which
            pass it keep working. The default list is never mutated.

    Returns:
        A dictionary with two entries:
          - "scenario_results": for each scenario name ("no_covid", "detection"), a dictionary
            with "cumulative_diseased_yearly", "cumulative_deaths_yearly", "incidence_raw"
            and "mortality_raw".
          - "diff_results": "cumulative_diseased_yearly_diff" and
            "cumulative_deaths_yearly_diff", each computed as "detection" minus "no_covid".
    """
    # Scenario name -> COVID effect switches handed to get_bcm.
    covid_configs = {
        "no_covid": {"detection_reduction": False, "contact_reduction": False},  # No reduction
        "detection": {"detection_reduction": True, "contact_reduction": False},  # Detection reduction only
    }

    scenario_results = {}
    for scenario_name, covid_effects in covid_configs.items():
        bcm = get_bcm(params, covid_effects)
        spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm).results

        # Keep one row per whole year from the start time onwards.
        # NOTE(review): `% 1 == 0` relies on whole-year index values being exact floats - confirm
        # against the model's time step.
        is_yearly = (spaghetti_res.index >= cumulative_start_time) & (spaghetti_res.index % 1 == 0)
        yearly_data = spaghetti_res.loc[is_yearly]

        incidence_raw = yearly_data["incidence_raw"]
        mortality_raw = yearly_data["mortality_raw"]

        # The previous `.loc[years]` look-ups were discarded, so they are no longer performed:
        # they could only add cost or raise KeyError for a year missing from the index.
        scenario_results[scenario_name] = {
            "cumulative_diseased_yearly": incidence_raw.cumsum(),
            "cumulative_deaths_yearly": mortality_raw.cumsum(),
            "incidence_raw": incidence_raw,
            "mortality_raw": mortality_raw,
        }

    # Differences are taken for the cumulative series only: "detection" minus "no_covid".
    detection = scenario_results["detection"]
    baseline = scenario_results["no_covid"]
    diff_results = {
        f"{key}_diff": detection[key] - baseline[key]
        for key in ("cumulative_diseased_yearly", "cumulative_deaths_yearly")
    }

    return {"scenario_results": scenario_results, "diff_results": diff_results}


In [None]:
# Inference data for the detection-reduction calibration, read from the project output folder.
idata = az.from_netcdf(OUT_PATH / 'vietnam/idata/idata_detection.nc')

# Fixed model parameters handed to calculate_covid_cum_diff alongside the inference data.
params = dict(
    start_population_size=2000000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

In [None]:
# Run both scenarios once; the plotting cells below all read from this result.
covid_diff = calculate_covid_cum_diff(params=params, idata_extract=idata)

In [None]:
def plot_diff(results: Dict[str, Dict[str, pd.DataFrame]], idata: az.InferenceData):
    """
    Show the 'detection' minus 'no covid' differences in cumulative diseased and cumulative
    deaths as a two-panel figure.

    In each panel only two columns are drawn: the ones holding the smallest and the largest
    value in the final row of the difference DataFrame. Each trace is named with the
    parameter values looked up in `idata.posterior` for that column.

    Args:
        results: Dictionary whose "diff_results" entry holds the two difference DataFrames.
        idata: InferenceData object containing the posterior parameter values.
    """
    diff_frames = results["diff_results"]
    posterior = idata.posterior
    panel_keys = ("cumulative_diseased_yearly_diff", "cumulative_deaths_yearly_diff")

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Cumulative Diseased Yearly Difference", "Cumulative Deaths Yearly Difference"],
    )

    for panel, key in enumerate(panel_keys, start=1):
        frame = diff_frames[key]
        final_values = frame.iloc[-1]

        # Lowest and highest column at the final time point, in that order.
        for column in (final_values.idxmin(), final_values.idxmax()):
            # The column label is used as a position along the `sample` dimension
            # (assumes labels are integer sample positions - TODO confirm).
            draw = posterior.isel(sample=int(column))
            label_lines = [
                f"{param}: {round(float(value.values), 3)}"
                for param, value in draw.items()
            ]
            trace = go.Scatter(
                x=frame.index,
                y=frame[column],
                mode='lines',
                name="<br>".join(label_lines),
                line={"width": 2},
            )
            fig.add_trace(trace, row=1, col=panel)

    # Only the x axes get a title.
    for panel in (1, 2):
        fig.update_xaxes(title_text="Year", row=1, col=panel)

    # Horizontal legend centred below the panels.
    legend_settings = {
        "title": "Parameter Sets",
        "x": 0.5,
        "y": -0.5,
        "xanchor": "center",
        "orientation": "h",
    }
    fig.update_layout(
        title="COVID Scenario Differences: Cumulative Diseased and Deaths",
        showlegend=True,
        legend=legend_settings,
        height=680,
        template="plotly_white",
    )

    fig.show()



In [None]:
# Detection-only minus no-COVID differences in the cumulative outputs.
plot_diff(results=covid_diff, idata=idata)

In [None]:
def plot_raw(results: Dict[str, Dict[str, pd.DataFrame]], idata: az.InferenceData):
    """
    Draw raw incidence and raw mortality for the 'no covid' and 'detection' scenarios
    in a two-panel figure.

    For each scenario and indicator, two columns are drawn: those with the lowest and the
    highest value in the final row (`idxmin` / `idxmax` return the first match when values
    tie). Traces are named with the scenario heading followed by the parameter values read
    from `idata.posterior`. Both traces of a scenario share one legend group, and only the
    first of the two is listed in the legend.

    Args:
        results: Dictionary whose "scenario_results" entry holds, per scenario, the
            "incidence_raw" and "mortality_raw" DataFrames.
        idata: InferenceData object containing the posterior parameter values.
    """
    per_scenario = results["scenario_results"]
    posterior = idata.posterior
    indicators = ["incidence_raw", "mortality_raw"]

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Incidence (Raw)", "Mortality (Raw)"],
    )

    def add_extreme_traces(frame, extremes, heading, colour, panel):
        """Add one line per column in `extremes` to the given panel for one scenario."""
        for position, column in enumerate(extremes):
            # Column label doubles as the position along the `sample` dimension
            # (assumes integer sample positions - TODO confirm).
            draw = posterior.isel(sample=int(column))
            details = "<br>".join(
                f"{param}: {round(float(value.values), 3)}"
                for param, value in draw.items()
            )
            fig.add_trace(
                go.Scatter(
                    x=frame.index,
                    y=frame[column],
                    mode='lines',
                    name=f"<b>{heading}</b><br>{details}",
                    legendgroup=heading,
                    # Only the first trace of the pair gets a legend entry.
                    showlegend=(position == 0),
                    line={"color": colour, "width": 2},
                ),
                row=1,
                col=panel,
            )

    for panel, indicator in enumerate(indicators, start=1):
        baseline = per_scenario["no_covid"][indicator]
        detection = per_scenario["detection"][indicator]

        # Lowest / highest column at the final time point for each scenario.
        baseline_last = baseline.iloc[-1]
        baseline_extremes = (baseline_last.idxmin(), baseline_last.idxmax())
        detection_last = detection.iloc[-1]
        detection_extremes = (detection_last.idxmin(), detection_last.idxmax())

        add_extreme_traces(baseline, baseline_extremes, "No COVID", 'blue', panel)
        add_extreme_traces(detection, detection_extremes, "Detection Only", 'green', panel)

    # Axis titles: the y title is derived from the indicator key.
    for panel, indicator in enumerate(indicators, start=1):
        fig.update_xaxes(title_text="Year", row=1, col=panel)
        fig.update_yaxes(
            title_text=f"{indicator.replace('_', ' ').capitalize()} (Raw)",
            row=1,
            col=panel,
        )

    # Horizontal legend below the panels, entries grouped by scenario.
    legend_settings = {
        "title_text": "Scenario<br><span style='font-size:12px'>Grouped by scenario with parameter sets in columns</span>",
        "orientation": "h",
        "x": 0.5,
        "y": -0.5,
        "xanchor": "center",
        "font": {"size": 10},
        "traceorder": "normal",
        "itemwidth": 200,
    }
    fig.update_layout(
        title="COVID Scenario Raw Results: Diseased and Deaths",
        showlegend=True,
        legend=legend_settings,
        height=800,
        template="plotly_white",
    )

    fig.show()


In [None]:
# Raw yearly incidence and mortality for both scenarios.
plot_raw(results=covid_diff, idata=idata)

In [None]:
def plot_cumulative(results: Dict[str, Dict[str, pd.DataFrame]], idata: az.InferenceData):
    """
    Plot cumulative diseased and cumulative deaths for the 'no covid' and 'detection'
    scenarios in a two-panel figure.

    For each scenario and indicator, only two columns are drawn: those with the lowest and
    the highest value in the final row (`idxmin` / `idxmax` return the first match when
    values tie). Traces are named with the scenario heading followed by the parameter
    values read from `idata.posterior`. Both traces of a scenario share one legend group,
    and only the first of the two is listed in the legend.

    Args:
        results: Dictionary whose "scenario_results" entry holds, per scenario, the
            "cumulative_diseased_yearly" and "cumulative_deaths_yearly" DataFrames.
        idata: InferenceData object containing the posterior parameter values.
    """
    scenario_results = results["scenario_results"]
    indicators = ["cumulative_diseased_yearly", "cumulative_deaths_yearly"]
    posterior_samples = idata.posterior

    # (results key, legend heading, line colour), in the order the traces are added.
    scenarios = [
        ("no_covid", "No COVID", "blue"),
        ("detection", "Detection Only", "green"),
    ]

    # Panel titles describe the cumulative series actually plotted (they previously
    # carried the "Incidence (Raw)" / "Mortality (Raw)" titles of the raw-output figure).
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=["Cumulative Diseased", "Cumulative Deaths"]
    )

    for idx, indicator in enumerate(indicators, start=1):
        for scenario_key, scenario_label, colour in scenarios:
            cumulative_df = scenario_results[scenario_key][indicator]

            # Lowest / highest column at the final time point.
            final_values = cumulative_df.iloc[-1]
            extreme_cols = [final_values.idxmin(), final_values.idxmax()]

            for i, col in enumerate(extreme_cols):
                # The column label is used as a position along the `sample` dimension
                # (assumes labels are integer sample positions - TODO confirm).
                param_info = posterior_samples.isel(sample=int(col))
                legend_label = "<br>".join(
                    f"{param}: {round(float(value.values), 3)}"
                    for param, value in param_info.items()
                )
                fig.add_trace(
                    go.Scatter(
                        x=cumulative_df.index,
                        y=cumulative_df[col],
                        mode='lines',
                        name=f"<b>{scenario_label}</b><br>{legend_label}",
                        legendgroup=scenario_label,
                        # Only the first trace of the pair gets a legend entry.
                        showlegend=(i == 0),
                        line=dict(color=colour, width=2)
                    ),
                    row=1,
                    col=idx
                )

    # Customize axes
    for i, indicator in enumerate(indicators, start=1):
        fig.update_xaxes(title_text="Year", row=1, col=i)
        fig.update_yaxes(
            title_text=f"{indicator.replace('_', ' ').capitalize()} (Cumulative)",
            row=1,
            col=i
        )

    # Update layout with a structured legend, grouping by scenario name
    fig.update_layout(
        title="COVID Scenario Raw Results: Cumulative Diseased and Deaths",
        showlegend=True,
        legend=dict(
            title_text="Scenario<br><span style='font-size:12px'>Grouped by scenario with parameter sets in columns</span>",
            orientation="h",
            x=0.5,
            y=-0.5,
            xanchor="center",
            font=dict(size=10),
            traceorder="normal",
            itemwidth=200
        ),
        height=800,
        template="plotly_white"
    )

    fig.show()


In [None]:
# Cumulative diseased and deaths for both scenarios.
plot_cumulative(results=covid_diff, idata=idata)