In [9]:
from typing import List, Dict
import arviz as az
import estival.priors as esp
from estival.sampling import tools as esamp
import plotly.graph_objects as go
import plotly.express as px
from tbdynamics.tools.utils import round_sigfig
import pandas as pd
from tbdynamics.settings import OUT_PATH
from tbdynamics.vietnam.calibration.utils import get_bcm
from plotly.subplots import make_subplots

# Route pandas' built-in `.plot()` calls through plotly instead of the default backend.
pd.options.plotting.backend = "plotly"

from tbdynamics.settings import OUT_PATH

In [2]:
def calculate_covid_cum_diff(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    cumulative_start_time: float = 2020.0,
    years: List[float] = [2021.0, 2022.0, 2025.0, 2030.0, 2035.0],
) -> Dict[str, pd.DataFrame]:
    """
    Run the "no covid" and "detection reduction only" scenarios and return the difference
    (detection reduction only minus no covid) in cumulative diseased and cumulative deaths.

    For each scenario the model is run for the supplied samples, the results are restricted
    to rows whose time index is a whole number and not earlier than ``cumulative_start_time``,
    and ``incidence_raw`` / ``mortality_raw`` are cumulatively summed over those rows.

    Args:
        params: Dictionary containing model parameters (passed to ``get_bcm``).
        idata_extract: Samples passed to ``esamp.model_results_for_samples``.
        cumulative_start_time: Year to start calculating the cumulative values.
        years: Years for which to report the differences. Any iterable of years is
            accepted; it is copied into a list before use.

    Returns:
        A dictionary with keys ``"cumulative_diseased"`` and ``"cumulative_deaths"``, each
        holding the scenario difference at the requested years (presumably one column per
        sample -- confirm against the estival results layout).

    Raises:
        KeyError: If any requested year is not present among the whole-year time points
            at or after ``cumulative_start_time``.
    """
    # Work on a private list: the shared default is never aliased, and a tuple of years is
    # treated by `.loc` as a list of row labels rather than a (row, column) key.
    requested_years = list(years)

    # Define the scenarios with scenario names as keys
    covid_configs = {
        "no_covid": {"detection_reduction": False, "contact_reduction": False},  # No reduction
        "detection_reduction_only": {"detection_reduction": True, "contact_reduction": False},  # Detection reduction only
    }

    # Key in the returned dictionary -> model output that is accumulated for it
    indicator_outputs = {
        "cumulative_diseased": "incidence_raw",
        "cumulative_deaths": "mortality_raw",
    }

    scenario_results = {}

    for scenario_name, covid_effects in covid_configs.items():
        # Get the model results
        bcm = get_bcm(params, covid_effects)
        spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm).results

        # Keep only whole-year time points from the cumulative start onwards
        yearly_data = spaghetti_res.loc[
            (spaghetti_res.index >= cumulative_start_time) & (spaghetti_res.index % 1 == 0)
        ]

        # Fail early with a message that names the missing years
        missing_years = [year for year in requested_years if year not in yearly_data.index]
        if missing_years:
            raise KeyError(
                f"Requested years {missing_years} are not available in the '{scenario_name}' "
                f"results from {cumulative_start_time} onwards"
            )

        # Cumulative sum per column, then pick the requested years
        scenario_results[scenario_name] = {
            indicator: yearly_data[output_name].cumsum().loc[requested_years]
            for indicator, output_name in indicator_outputs.items()
        }

    # Calculate the difference between detection only and no covid scenarios
    diff_results = {
        indicator: scenario_results["detection_reduction_only"][indicator]
        - scenario_results["no_covid"][indicator]
        for indicator in indicator_outputs
    }

    return diff_results


In [3]:
def plot_covid_diff(scenario_results: Dict[str, Dict[str, pd.DataFrame]]):
    """
    Plot the difference in cumulative diseased and deaths between the 'detection only' and 'no covid' scenarios.

    One figure is shown per indicator, with one line per column of the indicator's DataFrame.

    Args:
        scenario_results: Either a dictionary holding the differences under the key
            ``"diff_detection_minus_no_covid"``, or the differences themselves, i.e. a
            dictionary keyed by ``"cumulative_diseased"`` and ``"cumulative_deaths"``
            (the layout returned by ``calculate_covid_cum_diff``).
    """
    # Accept both layouts: unwrap the nested key when present, otherwise treat the input
    # as the difference dictionary itself.
    if "diff_detection_minus_no_covid" in scenario_results:
        diff_data = scenario_results["diff_detection_minus_no_covid"]
    else:
        diff_data = scenario_results

    indicators = ["cumulative_diseased", "cumulative_deaths"]

    for indicator in indicators:
        fig = go.Figure()

        # Extract the DataFrame for the current indicator
        diff_df = diff_data[indicator]
        indicator_label = indicator.replace('_', ' ').capitalize()

        # Add each column as a separate line in the plot
        for col in diff_df.columns:
            fig.add_trace(
                go.Scatter(
                    x=diff_df.index,
                    y=diff_df[col],
                    mode='lines',
                    name=f'{indicator} - {col}'
                )
            )

        # Customize layout
        fig.update_layout(
            title=f"Difference for {indicator_label} (Detection Only - No Covid)",
            xaxis_title="Year",
            yaxis_title=indicator_label,
            legend=dict(title="Quantiles", x=0.01, y=0.99),
            template="plotly_white"
        )

        # Display the figure
        fig.show()

In [4]:
# Load the saved inference data (NetCDF file) from the project's output directory.
idata = az.from_netcdf(OUT_PATH / 'vietnam/idata/idata_detection.nc')

# Parameter values supplied to the model builder alongside the samples.
params = dict(
    start_population_size=2000000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

In [20]:
def plot_covid_cum_diff_extremes(diff_results: Dict[str, pd.DataFrame], idata: az.InferenceData):
    """
    Plot the extreme trajectories of the cumulative diseased and cumulative deaths differences
    between the 'detection only' and 'no covid' scenarios.

    For each indicator, only the columns with the lowest and highest values in the last row
    are drawn. Each line is labeled with its corresponding parameter set in the legend,
    with each parameter on a new line, and the legend is positioned at the bottom.

    Args:
        diff_results: Dictionary containing the difference DataFrames for cumulative diseased and deaths.
        idata: InferenceData object containing the posterior parameter values. The posterior
            may either already have a ``sample`` dimension or be indexed by ``chain`` and
            ``draw``, in which case the two are stacked into ``sample`` here.
    """
    indicators = ["cumulative_diseased", "cumulative_deaths"]

    posterior_samples = idata.posterior
    # `isel(sample=...)` below needs a flat "sample" dimension; a posterior indexed by
    # chain/draw is stacked chain-major to provide it.
    if "sample" not in posterior_samples.dims and {"chain", "draw"} <= set(posterior_samples.dims):
        posterior_samples = posterior_samples.stack(sample=("chain", "draw"))

    # Create a subplot figure with 1 row and 2 columns for the two indicators
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=[indicator.replace('_', ' ').capitalize() for indicator in indicators],
    )

    for idx, indicator in enumerate(indicators, start=1):
        # Get the DataFrame for the current indicator
        diff_df = diff_results[indicator]

        # Identify the columns (samples) with the lowest and highest values in the last row
        last_row = diff_df.iloc[-1]
        extreme_cols = (last_row.idxmin(), last_row.idxmax())

        # Plot only the min and max columns
        for sample_idx, col in enumerate(diff_df.columns):
            if col not in extreme_cols:
                continue

            # Extract parameter values for the current sample.
            # NOTE(review): assumes the column order of diff_df matches the order of the
            # (stacked) posterior samples -- confirm against how the results were generated.
            param_info = posterior_samples.isel(sample=sample_idx)

            # Create a legend label with each parameter on a new line
            legend_label = "<br>".join(
                f"{param}: {round(float(value.values), 3)}"
                for param, value in param_info.data_vars.items()
            )

            fig.add_trace(
                go.Scatter(
                    x=diff_df.index,
                    y=diff_df[col],
                    mode='lines',
                    name=legend_label,
                    line=dict(width=2)
                ),
                row=1,
                col=idx
            )

        # Customize each subplot
        fig.update_xaxes(title_text="Year", row=1, col=idx)
        fig.update_yaxes(title_text=f"{indicator.replace('_', ' ').capitalize()} Difference", row=1, col=idx)

    # Update the overall layout with the legend at the bottom
    fig.update_layout(
        title="Extreme Cumulative Diseased and Deaths Differences (Detection Only - No Covid)",
        showlegend=True,
        legend=dict(
            title="Parameter Sets",
            x=0.5,
            y=-0.3,
            xanchor="center",
            orientation="h",  # Horizontal orientation
            traceorder="normal"
        ),
        height=600,
        template="plotly_white"
    )

    # Display the combined figure
    fig.show()



In [6]:
# Scenario differences (detection reduction only minus no covid) at the default years.
covid_diff = calculate_covid_cum_diff(params=params, idata_extract=idata)

In [None]:
# Draw the lowest and highest difference trajectories for each indicator.
plot_covid_cum_diff_extremes(diff_results=covid_diff, idata=idata)