In [None]:
import warnings
warnings.filterwarnings("ignore")
import arviz as az
import pandas as pd
import plotly.express as px
import numpy as np
from tbdynamics.camau.calibration.utils import get_bcm
from tbdynamics.calibration.plotting import plot_output_ranges, plot_trial_output_ranges
from tbdynamics.tools.inputs import load_targets
from tbdynamics.settings import CM_PATH, OUT_PATH, DOCS_PATH
from tbdynamics.constants import COMPARTMENTS, QUANTILES
from tbdynamics.camau.constants import indicator_legends, indicator_names
import estival.sampling.tools as esamp


In [None]:
# loaded_inference_data = az.from_netcdf(OUT_PATH / 'inference_data1.nc')
# idata = az.from_netcdf(OUT_PATH / 'extracted_idata.nc')
# Parameter values held fixed for this notebook's model run. Entries that are
# commented out below are deliberately not part of the dict.
params = dict(
    # Population seeding
    start_population_size=30000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
    # contact_rate=0.02,
    # Relative infection risks and natural-history rates
    rr_infection_latent=0.1890473700762809,
    rr_infection_recovered=0.17781844797545143,
    smear_positive_death_rate=0.3655528915762244,
    smear_negative_death_rate=0.027358324164819155,
    smear_positive_self_recovery=0.18600338108638945,
    smear_negative_self_recovery=0.11333894801537307,
    # Screening / detection
    screening_scaleup_shape=0.3,
    screening_inflection_time=1993,
    # time_to_screening_end_asymp=2.1163556520843936,
    acf_sensitivity=0.90,
    # prop_mixing_same_stratum=0.6920672992582717,
    # early_prop_adjuster=-0.017924441638418186,
    # late_reactivation_adjuster=1.1083422207175728,
    detection_reduction=0.30,
    # total_population_dispersion=3644.236227852164,
    # notif_dispersion=88.37092488550051,
    # latent_dispersion=7.470896188551709,
)
# Targets that the plotting cells overlay on the model output ranges.
targets_path = CM_PATH / "targets.yml"
targets = load_targets(targets_path)

In [None]:
# Read the stored calibration output (NetCDF) into an ArviZ object.
calib_output_path = OUT_PATH / 'camau/best/calib_full_out.nc'
idata_raw = az.from_netcdf(calib_output_path)

In [None]:
# Display the loaded calibration output (groups, chains, draws) for a quick sanity check.
idata_raw

In [None]:
# Drop the first 2500 draws of every chain as burn-in.
burnt_idata = idata_raw.sel(draw=np.s_[2500:])

# Sub-sample 300 draws for the model runs. az.extract shuffles the samples when
# num_samples is given, so a fixed seed is passed to make the selection (and
# therefore every downstream plot) reproducible across kernel restarts.
EXTRACT_SEED = 0
idata_extract = az.extract(burnt_idata, num_samples=300, rng=EXTRACT_SEED)

In [None]:
# outputs = calculate_scenario_outputs(params, idata)
# with open(OUT_PATH / 'quant_outputs.pkl', 'wb') as f:
#      pickle.dump(outputs, f)

In [None]:
# Scenario switches for the base run.
scenario_config = {"detection_reduction": True, "contact_reduction": False}

# `base_quantiles` is required by every plotting cell below. Running the model
# for all sampled parameter sets is slow, so the quantiles are cached on disk:
# the cache is read when present and rebuilt (then saved) when it is missing.
# Delete the cache file to force a recomputation after changing params/samples.
# assumes OUT_PATH is a pathlib.Path (it is combined with `/` elsewhere) — TODO confirm
base_quantiles_path = OUT_PATH / 'camau/output0304.pkl'

if base_quantiles_path.exists():
    # NOTE: pickle files must only be loaded when they were produced locally.
    base_quantiles = pd.read_pickle(base_quantiles_path)
else:
    # Base scenario (calculate outputs for all indicators)
    bcm = get_bcm(params, scenario_config)
    base_results = esamp.model_results_for_samples(idata_extract, bcm).results
    base_quantiles = esamp.quantiles_for_results(base_results, QUANTILES)
    base_quantiles.to_pickle(base_quantiles_path)


In [None]:
# base_quantiles.to_pickle(OUT_PATH / 'camau/output0304.pkl')
# base_quantiles = pd.read_pickle(OUT_PATH / 'camau/output0304.pkl')

In [None]:
# target_plot.write_image(DOCS_PATH / "targets1.png", scale=3)
# Population-size outputs plotted against their targets for 2010-2025.
population_indicators = [
    "total_population",
    "act3_trial_adults_pop",
    "act3_control_adults_pop",
]
plot_output_ranges(
    base_quantiles,
    targets,
    population_indicators,
    indicator_names,
    indicator_legends,
    1,
    2010,
    2025,
)  # .write_image(DOCS_PATH /'camau/pops.png', scale=3)

In [None]:
# Notifications and adult latent-infection percentage for 2010-2025.
target_indicators = ['notification', 'percentage_latent_adults']
plot_output_ranges(
    base_quantiles,
    targets,
    target_indicators,
    indicator_names,
    indicator_legends,
    2,
    2010,
    2025,
    option='camau',
)  # .write_image(DOCS_PATH /'camau/targets.png', scale=3)


In [None]:
# Burden outputs (incidence, prevalence, mortality) for 2010-2025.
burden_indicators = [
    'incidence',
    'prevalence_pulmonary',
    'adults_prevalence_pulmonary',
    'mortality',
]
plot_output_ranges(
    base_quantiles,
    targets,
    burden_indicators,
    indicator_names,
    indicator_legends,
    2,
    2010,
    2025,
    option='camau',
)  # .write_image(DOCS_PATH /'camau/compare.png', scale=3)

In [None]:
# Detection rate over the longer 1980-2025 window.
detection_indicators = ['detection_rate']
plot_output_ranges(
    base_quantiles,
    targets,
    detection_indicators,
    indicator_names,
    indicator_legends,
    1,
    1980,
    2025,
    option='camau',
)

In [None]:
# NOTE(review): this call repeats the earlier incidence/prevalence/mortality
# plot with identical arguments — consider removing one of the two cells.
repeated_burden_indicators = [
    'incidence',
    'prevalence_pulmonary',
    'adults_prevalence_pulmonary',
    'mortality',
]
plot_output_ranges(
    base_quantiles,
    targets,
    repeated_burden_indicators,
    indicator_names,
    indicator_legends,
    2,
    2010,
    2025,
    option='camau',
)

In [None]:
# Build one indicator name per (arm, metric) pair, arm-major, e.g.
# 'incidenceXact3_trial', 'prevalence_infectiousXact3_trial', ...
arms = ['act3_trial', 'act3_control', 'act3_other']
metrics = ['incidenceX', 'prevalence_infectiousX']

indicators = []
for arm in arms:
    for metric in metrics:
        indicators.append(f"{metric}{arm}")

plot_output_ranges(
    base_quantiles, targets, indicators,
    indicator_names, indicator_legends,
    2, 2010, 2025,
)  # .write_image(DOCS_PATH /'camau/burden_area.png', scale=3)

In [None]:
# ACF detection outputs for the trial and control arms.
trial_indicators = [
    'acf_detectionXact3_trialXorgan_pulmonary_rate1',
    'acf_detectionXact3_controlXorgan_pulmonary_rate1',
]
plot_trial_output_ranges(
    base_quantiles,
    targets,
    trial_indicators,
    indicator_names,
    2,
)  # .write_image(DOCS_PATH /'camau/trial.png', scale=3)

In [None]:
# target_plot.write_image(DOCS_PATH / "targets2.png", scale=3)

In [None]:
# spah.write_image(DOCS_PATH / 'spah.png', scale = 3)

In [None]:
# target_plot_history = plot_output_ranges(outputs['base_scenario'],targets,['total_population','notification','adults_prevalence_pulmonary'],1,1800,2010, history =True)

In [None]:
# target_plot_history

In [None]:
# target_plot_history.write_image(DOCS_PATH / 'targets_history.png', scale=3)

In [None]:
# compare_target_plot = plot_output_ranges(outputs['base_scenario'],targets,['incidence','mortality_raw','prevalence_smear_positive', 'percentage_latent'],2,2010,2025)

In [None]:
# compare_target_plot.write_image(DOCS_PATH / "non_targets.png", scale='3')

In [None]:
# compare_target_plot

In [None]:
# screening_plot.write_image(DOCS_PATH / 'screening_plot.png', scale =3)

In [None]:
# cdr_plot = plot_output_ranges(outputs['base_scenario']['quantiles'],targets,['case_notification_rate'],1,2010,2025)

In [None]:
# cdr_plot

In [None]:
# cdr_plot.write_image(DOCS_PATH / 'cdr_plot.png', scale =3)

In [None]:
# Keep the figure itself in `early_plot`; chaining .write_image() onto the
# assignment would store its return value (None) instead of the figure.
early_plot = plot_output_ranges(
    base_quantiles,
    targets,
    ['incidence_early_prop'],
    indicator_names,
    indicator_legends,
    1,
    2000,
    2025,
)
early_plot.write_image(DOCS_PATH / 'camau/early.png', scale=3)

In [None]:
# Stacked-area plot of the median share of the population in each compartment.

# Work on a copy so the quantiles used by the other cells stay untouched.
adjusted_df = base_quantiles.copy()

# Median (0.5 quantile) columns addressed as (indicator, quantile) tuples.
# assumes base_quantiles has two-level (indicator, quantile) columns, as the
# `df[indicator][0.5]` access pattern below implies — TODO confirm
susceptible_median = ('prop_susceptible', 0.5)
late_latent_median = ('prop_late_latent', 0.5)

# Confirm required columns exist
if 'prop_late_latent' in adjusted_df.columns and 'prop_susceptible' in adjusted_df.columns:
    # Transfer 20% from prop_late_latent to prop_susceptible, keeping 80% in
    # prop_late_latent. Assign through .loc with the full column key: a chained
    # `df[a][b] += ...` may write to a temporary copy and silently do nothing.
    original_late_latent = adjusted_df[late_latent_median].copy()
    adjusted_df.loc[:, susceptible_median] = (
        adjusted_df[susceptible_median] + original_late_latent * 0.2
    )
    adjusted_df.loc[:, late_latent_median] = original_late_latent * 0.8

# Extract the median for each compartment, label it, and store it in a list
data_frames = []
for compartment in COMPARTMENTS:
    column_name = f'prop_{compartment}'
    if column_name in adjusted_df.columns:
        # reset_index turns the index into a regular column (named 'time' for the plot below)
        df = adjusted_df[column_name][0.5].reset_index()  # Adjust this if your quantile structure is different
        df[0.5] *= 100  # Convert to percentage
        df['type'] = compartment.replace('_', ' ').capitalize()
        data_frames.append(df)

# Combine all compartment data into one DataFrame for plotting
combined_data = pd.concat(data_frames)

# Plot using Plotly Express
fig = px.area(combined_data, x='time', y=0.5, color='type',
              labels={'0.5': 'Proportion (%)', 'time': 'Time'},
              title='', range_x=[1980, 2025], range_y=[0, 100])

# Update layout
fig.update_layout(
    xaxis=dict(title='', title_font=dict(size=12)),
    yaxis=dict(title='<b>Proportion (%)</b>', title_font=dict(size=12), title_standoff=0),
    legend_title_text='',
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.2,
        xanchor="center",
        x=0.5,
        font=dict(size=12)
    ),
    height=320,  # Set the figure height
    margin=dict(l=10, r=5, t=10, b=40),
    font=dict(family="Arial, sans-serif", size=12, color="black")
)
fig.write_image(DOCS_PATH / "compartments.png", scale=3)
# Show the plot
fig.show()