## Note
This notebook is used to generate outputs and save them as pickle files.

In [None]:
import warnings

warnings.filterwarnings("ignore")
from tbdynamics.vietnam.calibration.utils import (
    calculate_scenario_outputs,
    calculate_covid_diff_cum_quantiles,
    calculate_scenario_diff_cum_quantiles,
    calculate_diff_cum_detection_reduction
)
from tbdynamics.settings import DATA_PATH, BASE_PATH
from pathlib import Path
import arviz as az
import pickle
import numpy as np
from typing import Dict

In [None]:
# Project-relative locations: the calibration run whose trace is loaded below,
# and the folder used by this notebook's (optional) pickle dumps.
RUN_PATH = Path.cwd().parent.parent / "runs" / "r0205"
OUT_PATH = Path.cwd().parent.parent / "data" / "outputs" / "vietnam"

In [None]:
# Load the full calibration trace and discard draws before index 50000 as burn-in.
idata_raw = az.from_netcdf(RUN_PATH / 'calib_full_out.nc')
burnt_idata = idata_raw.sel(draw=np.s_[50000:])
# az.extract draws a random subsample when num_samples is set. A fixed seed makes
# the subsample, and therefore every pickled output computed from it, reproducible.
idata = az.extract(burnt_idata, num_samples=1000, rng=42)
# Optional (first run only): persist the extracted sample for later reuse.
# inference_data = az.convert_to_inference_data(idata.reset_index('sample'))
# az.to_netcdf(inference_data, OUT_PATH / 'extracted_data.nc')

In [None]:
# Fixed model parameters passed to every scenario calculation below.
# Alternative to the extraction cell above: load a previously saved sample.
# idata = az.from_netcdf(BASE_PATH / 'idata/idata_detection.nc')
params = dict(
    start_population_size=2_000_000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
)

### Calculate the base-case outputs with scenarios of improved case detection

In [None]:
# basecase = calculate_scenario_outputs(params, idata)
# with open(OUT_PATH / 'quant_outputs3.pkl', 'wb') as f:
#      pickle.dump(basecase, f)

In [None]:
# Posterior summary of `detection_reduction`. It is computed on the post-burn-in
# draws (the same draws the scenario sample is extracted from), so the reported
# mean and interval are not contaminated by the discarded burn-in samples.
summary = az.summary(burnt_idata, var_names=["detection_reduction"], hdi_prob=0.95)

# Extract the mean and the lower/upper bounds of the 95% highest-density interval.
# ArviZ derives the HDI column names ("hdi_2.5%", "hdi_97.5%") from hdi_prob.
mean = summary.loc["detection_reduction", "mean"]
lower = summary.loc["detection_reduction", "hdi_2.5%"]
upper = summary.loc["detection_reduction", "hdi_97.5%"]

In [None]:
# Show the posterior mean and the 95% HDI bounds of detection_reduction.
(mean, lower, upper)

In [None]:
# Run the detection-reduction scenarios at the lower HDI bound, the mean,
# and the upper HDI bound of detection_reduction.
covid_uncertainties = calculate_diff_cum_detection_reduction(
    params,
    idata,
    [lower, mean, upper],
)

In [None]:
# Display the scenario outputs computed for the lower HDI bound, mean, and upper HDI bound.
covid_uncertainties

### Differences in cumulative diseased and cumulative deaths: COVID-19 vs. no COVID-19

In [None]:
# Quantiles of the differences in cumulative outcomes between the COVID-19 and
# no-COVID-19 runs, as computed by calculate_covid_diff_cum_quantiles.
covid_cum_outputs = calculate_covid_diff_cum_quantiles(
    params,
    idata,
)
# with open(OUT_PATH / 'covid_diff_quantiles.pkl', 'wb') as f:
#      pickle.dump(covid_cum_outputs, f)

In [None]:
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
def plot_abs_diff_boxplot(diff_output):
    """
    Show grouped boxplots of absolute differences in cumulative TB outcomes.

    One box is drawn per (detection reduction value, outcome) pair. The reduction
    value is parsed from the last ``_``-separated token of each scenario label.
    The outcomes are cumulative diseased and cumulative deaths. The "2035" in the
    axis label and title is text only; this function does not select a year.

    Args:
        diff_output: Output dictionary from calculate_diff_cum_detection_reduction.
            Its ``"abs"`` entry maps scenario labels to DataFrames that are indexed
            by indicator name; each indicator row is iterated as (quantile, value)
            pairs (presumably quantile level -> value; confirm against the helper).

    NOTE(review): px.box recomputes the box statistics (quartiles, whiskers) from
    the quantile values supplied, treating them as raw observations, so each box
    is a summary of summaries. Confirm this is the intended display.
    """
    outcome_names = {
        "cumulative_diseased": "Cumulative Diseased",
        "cumulative_deaths": "Cumulative Deaths",
    }

    def _rows():
        # One long-format row per (scenario, outcome, quantile), in input order.
        for label, frame in diff_output["abs"].items():
            reduction = float(label.split("_")[-1])
            for key, pretty in outcome_names.items():
                for quantile, value in frame.loc[key].items():
                    yield {
                        "reduction_value": reduction,
                        "indicator": pretty,
                        "quantile": quantile,
                        "value": value,
                    }

    long_frame = pd.DataFrame(list(_rows()))

    axis_labels = {
        "reduction_value": "Detection Reduction Value",
        "value": "Absolute Difference (2035)",
        "indicator": "Outcome",
    }
    figure = px.box(
        long_frame,
        x="reduction_value",
        y="value",
        color="indicator",
        points=False,
        labels=axis_labels,
        title="Absolute Differences in Cumulative TB Outcomes by Detection Reduction (2035)",
    )
    figure.update_layout(boxmode="group", legend_title_text="Outcome")
    figure.show()

In [None]:
# Plot the absolute differences for each detection_reduction scenario value.
plot_abs_diff_boxplot(diff_output=covid_uncertainties)

### TB notifications under different COVID-19 settings, with log-likelihood

In [None]:
# notif_covid_outputs = calculate_notifications_for_covid(params, idata)
# with open(OUT_PATH /'notif_for_covid_with_ll.pkl', 'wb') as f:
#      pickle.dump(notif_covid_outputs, f)

### Cumulative diseased and deaths under different case-detection scenarios

In [None]:
# scenarios_diff_quantiles = calculate_scenario_diff_cum_quantiles(params, idata, [2.0, 5.0, 12.0], extreme_transmission=True)
# with open(OUT_PATH/ 'scenarios_diff_outputs2.pkl', 'wb') as f:
#     pickle.dump(scenarios_diff_quantiles,f)