## Note
This notebook generates model outputs and saves them as pickle files.

In [24]:
import warnings
warnings.filterwarnings("ignore")

from pathlib import Path
from typing import Dict, List, Optional

import pickle
import numpy as np
import pandas as pd
import arviz as az
import plotly.express as px
import plotly.graph_objects as go

import estival.priors as esp
import estival.targets as est
from estival.model import BayesianCompartmentalModel
from estival.sampling import tools as esamp

from tbdynamics.constants import QUANTILES
from tbdynamics.settings import DATA_PATH, BASE_PATH, VN_PATH
from tbdynamics.tools.inputs import load_params, load_targets, matrix
from tbdynamics.vietnam.model import build_model
from tbdynamics.vietnam.constants import params_name
from tbdynamics.vietnam.calibration.utils import (
    calculate_scenario_outputs,
    calculate_covid_diff_cum_quantiles,
    calculate_scenario_diff_cum_quantiles,
    calculate_diff_cum_detection_reduction,
    get_targets,
    get_all_priors,
    get_bcm
)
from tbdynamics.calibration.plotting import plot_abs_diff_scatter_multi
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import List, Optional, Literal


In [2]:
# Input calibration run and output directory, both resolved two directory
# levels above the notebook's working directory. Extracting data from
# RUN_PATH is only needed the first time.
RUN_PATH, OUT_PATH = (
    Path.cwd().parent.parent / sub_dir
    for sub_dir in ("runs/r0205", "data/outputs/vietnam")
)

In [3]:
# Keep draws labelled >= BURN_IN_DRAWS (label-based `.sel`), then take a random
# subsample of N_POSTERIOR_SAMPLES posterior draws for running the model.
# az.extract shuffles before truncating when num_samples is given; a fixed rng
# seed keeps the selected draws (and every output saved from them) reproducible.
BURN_IN_DRAWS = 50000
N_POSTERIOR_SAMPLES = 300
SAMPLE_SEED = 0

idata_raw = az.from_netcdf(RUN_PATH / 'calib_full_out.nc')
burnt_idata = idata_raw.sel(draw=np.s_[BURN_IN_DRAWS:])
idata = az.extract(burnt_idata, num_samples=N_POSTERIOR_SAMPLES, rng=SAMPLE_SEED)
# To cache the extracted samples for later sessions:
# inference_data = az.convert_to_inference_data(idata.reset_index('sample'))
# az.to_netcdf(inference_data, OUT_PATH / 'extracted_data.nc')

In [4]:
# Load saved idata (alternative to extracting from the raw calibration run):
# idata = az.from_netcdf(BASE_PATH / 'idata/idata_detection.nc')

# Model parameters passed, together with the posterior samples, to the
# scenario-calculation helpers in the cells below.
params = dict(
    start_population_size=2_000_000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)
# COVID-19 effects for a single run; referenced only by the commented-out
# get_bcm call in the next cell.
scenario_config = dict(detection_reduction=True, contact_reduction=False)

In [5]:
# bcm = get_bcm(params, scenario_config, None)
# base_results = esamp.model_results_for_samples(idata, bcm).results

In [6]:
# full_params_df = burnt_idata.posterior.to_dataframe().reset_index()

In [7]:
def calculate_covid_diff_cum_merge(
    params: Dict[str, float],
    idata_extract: az.InferenceData,
    cumulative_start_time: float = 2020.0,
) -> pd.DataFrame:
    """
    Compare cumulative outcomes with and without COVID-19 detection disruption,
    per posterior sample, alongside that sample's parameter values.

    Two scenarios are simulated for every sample in ``idata_extract``:
    scenario 0 with no COVID-19 effects and scenario 1 with
    ``detection_reduction=True`` (contact reduction is off in both). For each,
    the ``incidence_raw`` and ``mortality_raw`` outputs at whole-year time
    points from ``cumulative_start_time`` onward are summed cumulatively.

    Args:
        params: Model parameters forwarded to ``get_bcm`` for both scenarios.
        idata_extract: Posterior samples as returned by ``az.extract()``; it is
            flattened with ``.to_dataframe()`` and must yield ``chain`` and
            ``draw`` columns after ``reset_index``.
        cumulative_start_time: First model time included in the cumulative sums.

    Returns:
        One row per (chain, draw, time) with ``cumulative_diseased_scen0/1``,
        ``cumulative_deaths_scen0/1``, ``abs_diff_cumulative_diseased`` and
        ``abs_diff_cumulative_deaths`` (scenario 1 minus scenario 0), joined
        with the posterior parameter columns of the matching sample.
    """
    sample_keys = ["chain", "draw"]

    # Posterior parameter table keyed by chain/draw. Any chain/draw data
    # columns are dropped first, presumably so reset_index can restore them
    # from the index without a name clash.
    posterior = idata_extract.to_dataframe()
    posterior = posterior.drop(
        columns=[name for name in sample_keys if name in posterior.columns],
        errors="ignore",
    ).reset_index()

    def cumulative_for(scenario: int, covid_effects: Dict[str, bool]) -> pd.DataFrame:
        # Run the model for every posterior sample under one COVID-19 setting.
        results = esamp.model_results_for_samples(
            idata_extract, get_bcm(params, covid_effects)
        ).results

        # Long format: one row per (time, output, chain, draw).
        long = (
            results[["incidence_raw", "mortality_raw"]]
            .stack(level=[0, 1, 2])
            .reset_index()
        )
        long.columns = ["time", "variable", "chain", "draw", "value"]

        # Whole-year points from the start time onward, accumulated along time
        # separately for each output within each sample.
        keep = (long["time"] >= cumulative_start_time) & (long["time"] % 1 == 0)
        yearly = long[keep].copy().sort_values(["variable", *sample_keys, "time"])
        yearly["cumulative"] = yearly.groupby(["variable", *sample_keys])["value"].cumsum()

        wide = yearly.pivot_table(
            index=[*sample_keys, "time"],
            columns="variable",
            values="cumulative",
        ).reset_index()
        return wide.rename(
            columns={
                "incidence_raw": f"cumulative_diseased_scen{scenario}",
                "mortality_raw": f"cumulative_deaths_scen{scenario}",
            }
        )

    no_covid = cumulative_for(0, {"detection_reduction": False, "contact_reduction": False})
    with_covid = cumulative_for(1, {"detection_reduction": True, "contact_reduction": False})

    merged = pd.merge(no_covid, with_covid, on=[*sample_keys, "time"])
    for outcome in ("cumulative_diseased", "cumulative_deaths"):
        merged[f"abs_diff_{outcome}"] = (
            merged[f"{outcome}_scen1"] - merged[f"{outcome}_scen0"]
        )

    return pd.merge(merged, posterior, on=sample_keys)

In [8]:
# Per-sample cumulative outcomes (COVID-19 detection disruption vs none) joined
# with posterior parameter values, using the default cumulative start time.
df = calculate_covid_diff_cum_merge(params=params, idata_extract=idata)

In [None]:
def plot_abs_diff_scatter_multi(
    df: pd.DataFrame,
    outcome: Literal["cumulative_diseased", "cumulative_deaths"] = "cumulative_diseased",
    params: Optional[List[str]] = None,
    year: float = 2035.0,
    n_cols: int = 3
) -> go.Figure:
    """
    Scatter an outcome's absolute difference against posterior parameters for one year.

    One subplot per parameter, arranged in a grid of ``n_cols`` columns. The
    y-axis is log-scaled, so non-positive differences are not drawn.
    NOTE: this notebook-local definition shadows the function of the same
    name imported from ``tbdynamics.calibration.plotting``.

    Args:
        df: DataFrame from `calculate_covid_diff_cum_merge`.
        outcome: Outcome to plot ('cumulative_diseased' or 'cumulative_deaths');
            the y values come from column ``abs_diff_<outcome>``.
        params: Posterior parameter columns to plot. If None, every numeric
            column that is not an identifier/outcome column and does not
            contain '_dispersion' is used.
        year: Model year to plot (matched after rounding to one decimal).
        n_cols: Number of subplot columns; rows are added as needed.

    Returns:
        Plotly Figure with scatter plots.

    Raises:
        ValueError: If ``n_cols`` < 1, no row of ``df`` matches ``year``, or
            there are no parameters to plot.
    """
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}.")

    # Round both sides so float noise in model times still matches the year.
    df_filtered = df[df["time"].round(1) == round(year, 1)]
    if df_filtered.empty:
        raise ValueError(
            f"No rows with time == {year} in df "
            f"(time spans {df['time'].min()} to {df['time'].max()})."
        )

    # Auto-select parameter names if not given
    if params is None:
        exclude = {
            "chain", "draw", "time",
            "cumulative_diseased_scen0", "cumulative_diseased_scen1",
            "cumulative_deaths_scen0", "cumulative_deaths_scen1",
            "abs_diff_cumulative_diseased", "abs_diff_cumulative_deaths",
        }
        params = [
            col for col in df.columns
            if col not in exclude and "_dispersion" not in col and df[col].dtype.kind in "fi"
        ]
    if not params:
        raise ValueError(
            "No parameters to plot: pass `params` explicitly or provide numeric "
            "posterior columns in `df`."
        )

    n_rows = (len(params) + n_cols - 1) // n_cols  # ceiling division
    # Plotly rejects vertical_spacing > 1 / (n_rows - 1); cap the total gap at
    # 30% of the figure height so tall grids neither error nor collapse.
    vertical_spacing = min(0.07, 0.3 / max(n_rows - 1, 1))
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=params,
        vertical_spacing=vertical_spacing,
    )

    y_values = df_filtered[f"abs_diff_{outcome}"]
    for i, param in enumerate(params):
        row = i // n_cols + 1
        col = i % n_cols + 1

        fig.add_trace(
            go.Scatter(
                x=df_filtered[param],
                y=y_values,
                mode="markers",
                marker=dict(size=4, color="#636efa"),  # Default Plotly blue
                showlegend=False,
            ),
            row=row,
            col=col,
        )

        fig.update_xaxes(title_text="", row=row, col=col)
        fig.update_yaxes(title_text="", row=row, col=col, type="log")

    fig.update_layout(
        height=150 * n_rows,
        title="",
        margin=dict(t=20, b=10),
    )

    return fig

In [42]:
# Show the grid for all auto-selected parameters (cumulative diseased).
# To save instead, append: .write_image('uncer.png', scale=3)
plot_abs_diff_scatter_multi(df, outcome="cumulative_diseased")

In [44]:
# Use a dedicated name for the plotted parameter list: rebinding `params` here
# would replace the model-parameter dict that the (currently commented-out)
# scenario cells below pass to the calculation helpers.
plot_params = [
    "contact_rate",
    "smear_positive_death_rate",
    "incidence_props_pulmonary",
    "incidence_props_smear_positive_among_pulmonary",
    "smear_negative_death_rate",
    "smear_positive_self_recovery",
    "smear_negative_self_recovery",
    "detection_reduction",
]
plot_abs_diff_scatter_multi(
    df, outcome="cumulative_diseased", params=plot_params, n_cols=2
).write_image('uncer.png', scale=3)

### Calculate the base-case outputs under scenarios of improved case detection

In [None]:
# basecase = calculate_scenario_outputs(params, idata)
# with open(OUT_PATH / 'quant_outputs3.pkl', 'wb') as f:
#      pickle.dump(basecase, f)

### Differences in cumulative diseased and cumulative deaths: COVID-19 vs. no COVID-19

In [None]:
# covid_cum_outputs = calculate_covid_diff_cum_quantiles(params, idata)
# with open(OUT_PATH / 'covid_diff_quantiles.pkl', 'wb') as f:
#      pickle.dump(covid_cum_outputs, f)

### TB notifications under different COVID-19 settings, with log-likelihood

In [None]:
# notif_covid_outputs = calculate_notifications_for_covid(params, idata)
# with open(OUT_PATH /'notif_for_covid_with_ll.pkl', 'wb') as f:
#      pickle.dump(notif_covid_outputs, f)

### Cumulative diseased and deaths under different case detection scenarios

In [None]:
# scenarios_diff_quantiles = calculate_scenario_diff_cum_quantiles(params, idata, [2.0, 5.0, 12.0], extreme_transmission=True)
# with open(OUT_PATH/ 'scenarios_diff_outputs2.pkl', 'wb') as f:
#     pickle.dump(scenarios_diff_quantiles,f)