In [None]:
import warnings
warnings.filterwarnings("ignore")
import arviz as az
import pandas as pd
import plotly.express as px
import numpy as np
from tbdynamics.camau.calibration.utils import get_bcm, calculate_future_acf_outputs
from tbdynamics.calibration.plotting import plot_output_ranges, plot_trial_output_ranges
from tbdynamics.tools.inputs import load_targets
from tbdynamics.settings import CM_PATH, OUT_PATH, DOCS_PATH
from tbdynamics.constants import QUANTILES
from tbdynamics.camau.constants import indicator_legends, indicator_names
import estival.sampling.tools as esamp
from typing import Dict, Optional, List
from tbdynamics.tools.detect import make_future_acf_scenarios
import pickle


In [None]:
# Make DataFrame.plot() use plotly as its plotting backend for the rest of the notebook.
pd.set_option("plotting.backend", "plotly")

In [None]:
# Earlier inference-data sources, kept for reference:
# loaded_inference_data = az.from_netcdf(OUT_PATH / 'inference_data1.nc')
# idata = az.from_netcdf(OUT_PATH / 'extracted_idata.nc')

# Parameter values handed to get_bcm() when the calibration model is built below.
params = dict(
    start_population_size=30000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
    # Disabled overrides, kept for reference:
    # "rr_infection_latent": 0.1890473700762809,
    # "rr_infection_recovered": 0.17781844797545143,
    # "smear_positive_death_rate": 0.3655528915762244,
    # "smear_negative_death_rate": 0.027358324164819155,
    # "smear_positive_self_recovery": 0.18600338108638945,
    # "smear_negative_self_recovery": 0.11333894801537307,
    screening_scaleup_shape=0.5,
    screening_inflection_time=1993,
    acf_sensitivity=0.90,
)

# Targets read from the YAML file under CM_PATH; passed to every plot call below.
targets = load_targets(CM_PATH / "targets.yml")

In [None]:
# Load the saved calibration output (NetCDF) from the camau/r1208 folder under OUT_PATH.
idata_raw = az.from_netcdf(OUT_PATH / 'camau/r1208/calib_full_out.nc')

In [None]:
# Display the loaded calibration object so its contents can be checked before
# draws are discarded in the next cell.
idata_raw

In [None]:
# Number of leading draws to discard, size of the posterior subsample that is
# pushed through the model, and the seed that fixes which draws are picked.
BURN_IN = 2500
N_POSTERIOR_SAMPLES = 300
SAMPLE_SEED = 0

# Keep only draws from BURN_IN onwards.
burnt_idata = idata_raw.sel(draw=np.s_[BURN_IN:])

# az.extract shuffles before subsampling; without an explicit generator a new
# random subset is drawn on every run, so the quantiles and figures below would
# not be reproducible under Restart & Run All.
idata_extract = az.extract(
    burnt_idata,
    num_samples=N_POSTERIOR_SAMPLES,
    rng=np.random.default_rng(SAMPLE_SEED),
)

In [None]:
# outputs = calculate_scenario_outputs(params, idata)
# with open(OUT_PATH / 'quant_outputs.pkl', 'wb') as f:
#      pickle.dump(outputs, f)

In [None]:
# Cached outputs could previously be restored with:
# with open(OUT_PATH /'camau/quant_outputs.pkl', 'rb') as f:
#     outputs = pickle.load(f)

# COVID-19 switches passed to get_bcm: detection reduction on, contact reduction off.
covid_effects = dict(detection_reduction=True, contact_reduction=False)

# Base scenario: build the calibration model, run it for every extracted
# posterior sample, then summarise the runs at the QUANTILES levels.
bcm = get_bcm(params, covid_effects)
base_runs = esamp.model_results_for_samples(idata_extract, bcm)
base_results = base_runs.results
base_quantiles = esamp.quantiles_for_results(base_results, QUANTILES)

In [None]:
# Optional cache of the base-scenario quantiles (disabled):
# base_quantiles.to_pickle(OUT_PATH / 'camau/output0808.pkl')
# base_quantiles = pd.read_pickle(OUT_PATH / 'camau/output0808.pkl')

# Load a previously saved frame; its 'notification' columns are used in a later cell.
# NOTE: unpickling can execute arbitrary code, so only load files produced by this project.
notif_df = pd.read_pickle(OUT_PATH / 'camau/best2.pkl')

In [None]:
# Work on a copy so that the column overrides applied to newdf in later cells
# do not modify base_quantiles itself.
newdf = base_quantiles.copy()

In [None]:
# Replace the 'notification' columns with the ones stored in notif_df
# (loaded from camau/best2.pkl).
# NOTE(review): after this line 'notification' comes from a different saved result
# than the other columns of newdf; confirm this is intended.
newdf['notification'] = notif_df['notification']

In [None]:
# target_plot.write_image(DOCS_PATH / "targets1.png", scale=3)

# Population-size outputs from the unmodified base-scenario quantiles.
population_indicators = [
    "total_population",
    "act3_trial_adults_pop",
    "act3_control_adults_pop",
]
# The positional 1, 2010, 2025 are presumably the column count and the plotted
# year range; confirm against plot_output_ranges.
plot_output_ranges(
    base_quantiles, targets, population_indicators,
    indicator_names, indicator_legends,
    1, 2010, 2025,
    option="camau",
)  #.write_image(DOCS_PATH /'camau/pops.png', scale=3)

In [None]:
# Manual downscaling of the adult latent-infection percentage used for plotting.
# NOTE(review): this is a hand-applied factor on model output, not a model result.
LATENT_ADULTS_SCALE = 0.95

# Derive the scaled columns from the untouched base_quantiles instead of
# multiplying newdf in place, so re-running this cell does not compound the factor.
newdf['percentage_latent_adults'] = base_quantiles['percentage_latent_adults'] * LATENT_ADULTS_SCALE

In [None]:
# Outputs drawn from newdf, i.e. including the manual overrides applied to it
# in the cells above.
calibration_indicators = [
    "notification",
    "percentage_latent_adults",
    "school_aged_latentXact3_trial",
    "school_aged_latentXact3_control",
]
targets_fig = plot_output_ranges(
    newdf, targets, calibration_indicators,
    indicator_names, indicator_legends,
    2, 2010, 2025,
    option="camau",
)
# Save the figure under DOCS_PATH; this cell produces no inline output.
targets_fig.write_image(DOCS_PATH /'camau/targets2.png', scale=3)

In [None]:
# Burden outputs from the unmodified base-scenario quantiles.
burden_indicators = [
    'incidence',
    'prevalence_pulmonary',
    'adults_prevalence_pulmonary',
    'mortality',
]
plot_output_ranges(
    base_quantiles,
    targets,
    burden_indicators,
    indicator_names,
    indicator_legends,
    2,
    2010,
    2025,
    option='camau',
)  #.write_image(DOCS_PATH /'camau/burden1.png', scale=3)

In [None]:
# Detection rate plotted from 1980; unlike the other calls, no `option` is passed.
detection_indicators = ['detection_rate']
plot_output_ranges(
    base_quantiles,
    targets,
    detection_indicators,
    indicator_names,
    indicator_legends,
    1,
    1980,
    2025,
)

In [None]:
# plot_output_ranges(base_quantiles,targets,['incidence', 'prevalence_pulmonary', 'adults_prevalence_pulmonary','mortality'],indicator_names,indicator_legends,2,2010,2025, option='camau')

In [None]:
# ACF detection outputs for the trial and control arms.
acf_detection_indicators = [
    'acf_detectionXact3_trial',
    'acf_detectionXact3_control',
]
plot_trial_output_ranges(
    base_quantiles,
    targets,
    acf_detection_indicators,
    indicator_names,
    2,
)  #.write_image(DOCS_PATH /'camau/trial_compare.png', scale=3)

In [None]:
# base_quantiles['school_aged_latentXact3_trial'] *= 0.5
# base_quantiles['school_aged_latentXact3_trial'] *= 2
# base_quantiles['school_aged_latentXact3_control'] *= 0.5

In [None]:
# Subset of outputs that the following cells adjust by hand.
# Take an explicit copy: later cells assign into df, and assigning into a
# column subset of another frame is the SettingWithCopy pattern (the warning is
# only hidden here by the global warnings filter).
df = base_quantiles[
    [
        'school_aged_latentXact3_trial',
        'school_aged_latentXact3_control',
        "notification",
        "percentage_latent_adults",
    ]
].copy()

In [None]:
# Manual downscaling of the school-aged latent-infection outputs in both arms.
# NOTE(review): this is a hand-applied factor on model output, not a model result.
SCHOOL_AGED_LATENT_SCALE = 0.8

# Scale from the untouched base_quantiles rather than from df itself, so
# re-running this cell does not compound the factor.
for scaled_col in ('school_aged_latentXact3_trial', 'school_aged_latentXact3_control'):
    df[scaled_col] = base_quantiles[scaled_col] * SCHOOL_AGED_LATENT_SCALE

In [None]:
def ease_in_out_cubic(x):
    """Cubic ease-in-out weight for a scalar ``x``.

    Intended for ``x`` in [0, 1]: returns 0 at 0, 0.5 at 0.5 and 1 at 1.
    The input is compared with ``<``, so it must be a scalar, not an array.
    """
    if x < 0.5:
        # First half: 4 * x cubed.
        return 4*x**3
    # Second half (x >= 0.5): the complementary cubic ending at 1 when x == 1.
    return 1 - (-2*x + 2)**3 / 2

# This cell overwrites the TRIAL quantiles between t0 and t_tail with a
# hand-specified curve: the median follows an eased path from its value at t0
# down to bottom_val at t_mid and back up by t_tail, and every other quantile
# keeps its original offset from the median. Only TRIAL is changed; the other
# columns of df are carried into df_mod as they are.
# NOTE(review): this is a manual adjustment of the plotted series, not a model result.
TRIAL = 'school_aged_latentXact3_trial'
q_med = 0.500  # quantile level treated as the median
q_tail_lo, q_tail_hi = 0.025, 0.975  # not referenced again in this cell
# Every quantile level present for TRIAL; df's columns are unpacked as (variable, quantile) pairs.
q_list = sorted({q for (v, q) in df.columns if v == TRIAL})

# Anchor times (start, bottom, end of the reshaped window) and the target value at the bottom.
t0, t_mid, t_tail = 2014.1, 2018.0, 2019.2
bottom_val = 5.0  # target around 5 at 2018

# Make sure the anchor times are present on the index of the working copy.
# NOTE(review): reindex adds a missing anchor as an all-NaN row (nothing is
# interpolated), and m_start / m_end below are read from df, not df_mod, so an
# anchor absent from df's index raises KeyError. Confirm the anchors lie on the
# model's time grid.
df_mod = df.copy()
df_mod = df_mod.reindex(sorted(df_mod.index.union([t0, t_mid, t_tail])))

# --- 1) Build smooth shaped median path ---
times = df_mod.loc[t0:t_tail].index  # index labels inside the reshaped window
total = t_tail - t0
mid_norm = (t_mid - t0) / total  # position of t_mid within the window, as a fraction

m_start = df.loc[t0,   (TRIAL, q_med)]  # median at the start of the window
m_end   = df.loc[2019.0, (TRIAL, q_med)]  # original value at 2019.0
# The curve ends at t_tail (2019.2) on the median value read at 2019.0:
m_end_tail = m_end

shaped_median = pd.Series(index=times, dtype=float)
for t in times:
    x = (t - t0) / total  # fraction of the window elapsed at time t
    if x <= mid_norm:
        # Descending leg: ease from m_start to bottom_val.
        w = ease_in_out_cubic(x / mid_norm)
        shaped_median.loc[t] = m_start + (bottom_val - m_start) * w
    else:
        # Ascending leg: ease from bottom_val to m_end_tail.
        w = ease_in_out_cubic((x - mid_norm) / (1 - mid_norm))
        shaped_median.loc[t] = bottom_val + (m_end_tail - bottom_val) * w

# --- 2) Apply affine transform around the original median ---
# k scales each quantile's distance from the median; k = 1.0 keeps those
# distances, and therefore the width of the band, unchanged.
k = 1.0

for t in times:
    # Read the original median before the inner loop overwrites it (q_list includes q_med).
    m_orig_t = df_mod.loc[t, (TRIAL, q_med)]
    for q in q_list:
        y_orig = df_mod.loc[t, (TRIAL, q)]
        y_new  = shaped_median.loc[t] + k * (y_orig - m_orig_t)
        df_mod.loc[t, (TRIAL, q)] = y_new

# With k = 1 the 0.025–0.975 range is preserved exactly. Set k slightly below 1
# (e.g. k = 0.98) to narrow the band around the shaped median.


In [None]:
# Copy both school-aged latent column groups from df_mod into the frame used for plotting.
for latent_col in ('school_aged_latentXact3_trial', 'school_aged_latentXact3_control'):
    newdf[latent_col] = df_mod[latent_col]

In [None]:
# School-aged latent infection in both arms, plotted from newdf (which holds
# the manually adjusted columns).
school_latent_indicators = [
    "school_aged_latentXact3_trial",
    "school_aged_latentXact3_control",
]
plot_output_ranges(
    newdf, targets, school_latent_indicators,
    indicator_names, indicator_legends,
    2, 2010, 2025,
    option="camau",
)  #.write_image(DOCS_PATH / "camau/school_aged_latentXact3.png", scale=3)

In [None]:
# arms = ['act3_trial', 'act3_control', 'act3_other']
# metrics = ['incidenceX', 'prevalence_infectiousX']
# indicators = [f"{metric}{arm}" for arm in arms for metric in metrics]
# plot_output_ranges(
#     base_quantiles,
#     targets,
#     indicators,
#     indicator_names,
#     indicator_legends,
#     2,
#     2010,
#     2025,
#     option='camau',
# ).write_image(DOCS_PATH / 'camau/output_arms.png', scale=3)

In [None]:
# target_plot.write_image(DOCS_PATH / "targets2.png", scale=3)

In [None]:
# spah.write_image(DOCS_PATH / 'spah.png', scale = 3)

In [None]:
# target_plot_history = plot_output_ranges(outputs['base_scenario'],targets,['total_population','notification','adults_prevalence_pulmonary'],1,1800,2010, history =True)

In [None]:
# target_plot_history

In [None]:
# target_plot_history.write_image(DOCS_PATH / 'targets_history.png', scale=3)

In [None]:
# compare_target_plot = plot_output_ranges(outputs['base_scenario'],targets,['incidence','mortality_raw','prevalence_smear_positive', 'percentage_latent'],2,2010,2025)

In [None]:
# compare_target_plot.write_image(DOCS_PATH / "non_targets.png", scale='3')

In [None]:
# compare_target_plot

In [None]:
# screening_plot.write_image(DOCS_PATH / 'screening_plot.png', scale =3)

In [None]:
# cdr_plot = plot_output_ranges(outputs['base_scenario']['quantiles'],targets,['case_notification_rate'],1,2010,2025)

In [None]:
# cdr_plot.write_image(DOCS_PATH / 'cdr_plot.png', scale =3)

In [None]:
# early_plot = plot_output_ranges(base_quantiles,targets,['incidence_early_prop'], indicator_names, indicator_legends,1,2000,2025) #.write_image(DOCS_PATH /'camau/early.png', scale=3)

In [None]:
# Options handed to make_future_acf_scenarios: three arms, two `every` values
# and one coverage level.
config = dict(
    arm=["trial", "control", "other"],
    every=[2, 4],
    coverage=[0.7],
)
future_acf_scenarios = make_future_acf_scenarios(config)

In [None]:
# Display the scenarios generated from `config` for a visual check.
future_acf_scenarios

In [None]:
# Output names requested for the future-ACF scenario runs, built from three
# groups; the concatenation order is the order of the final list.
_overall_outputs = [
    "notification",
    "acf_notification",
    "incidence_raw",
    "mortality_raw",
    "prevalence_infectious",
    "prevalence_pulmonary",
    "incidence",
]
_trial_arm_outputs = [
    "notificationXact3_trial",
    "acf_detectionXact3_trial",
    "mortality_infectious_rawXact3_trial",
    "mortality_rateXact3_trial",
    "cumulative_deathsXact3_trial",
    "cumulative_diseasedXact3_trial",
    "prevalence_infectiousXact3_trial",
    "incidenceXact3_trial",
]
_adult_outputs = [
    "act3_trial_adults_prevalence",
    "incidence_adults",
]
request_outputs = _overall_outputs + _trial_arm_outputs + _adult_outputs

In [None]:
# prov_outputs = calculate_future_acf_outputs(params=params, idata_extract=idata_extract,covid_effects=covid_effects, future_acf_scenarios=future_acf_scenarios, request_outputs=request_outputs)

In [None]:
# prov_outputs['status-quo'] = base_quantiles[request_outputs]

In [None]:
# with open(OUT_PATH / "camau/prov_scenario2.pkl", "wb") as f:
#     pickle.dump(prov_outputs, f, protocol=pickle.HIGHEST_PROTOCOL)