In [None]:
import warnings

warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import arviz as az
from tbdynamics.calibration.az_aux import (
    tabulate_calib_results,
    plot_post_prior_comparison,
    plot_trace,
)
from tbdynamics.camau.calibration.utils import get_bcm
from tbdynamics.camau.constants import params_name
from tbdynamics.calibration.az_aux import (
    process_idata_for_derived_metrics,
    process_priors_for_derived_metrics,
    plot_derived_comparison,
)
from tbdynamics.tools.inputs import get_death_rate, process_universal_death_rate

In [None]:
# Calibration outputs are resolved relative to the current working directory
# (assumes the notebook is launched two levels below the repo root — TODO confirm).
RUN_PATH = Path.cwd().parent.parent / "data" / "outputs" / "camau"
# Which calibration run to analyse.
RUN_ID = "r0204"
# Leading draws (along the `draw` dimension) discarded as burn-in.
BURN_IN_DRAWS = 50_000

idata = az.from_netcdf(RUN_PATH / RUN_ID / "calib_full_out.nc")
burnt_idata = idata.sel(draw=np.s_[BURN_IN_DRAWS:])

In [None]:
# Fixed (non-calibrated) model inputs passed to get_bcm. Entries left as comments
# are not fixed here — presumably they are sampled from the calibration priors
# (check against bcm.priors); their old point values are kept for reference.
params = dict(
    # Population size and infection seeding
    start_population_size=30000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
    # contact_rate=0.02,
    # rr_infection_latent=0.1890473700762809,
    # rr_infection_recovered=0.17781844797545143,
    # Death and self-recovery rates by smear status
    smear_positive_death_rate=0.3655528915762244,
    smear_negative_death_rate=0.027358324164819155,
    smear_positive_self_recovery=0.18600338108638945,
    smear_negative_self_recovery=0.11333894801537307,
    # Screening scale-up curve and ACF sensitivity
    screening_scaleup_shape=0.5,
    screening_inflection_time=1993,
    # time_to_screening_end_asymp=2.1163556520843936,
    acf_sensitivity=0.90,
    # prop_mixing_same_stratum=0.6920672992582717,
    # early_prop_adjuster=-0.017924441638418186,
    # late_reactivation_adjuster=1.1083422207175728,
    # Size of the detection reduction — presumably only applied when
    # covid_effects["detection_reduction"] is True (confirm in get_bcm).
    detection_reduction=0.30,
    # total_population_dispersion=3644.236227852164,
    # notif_dispersion=88.37092488550051,
    # latent_dispersion=7.470896188551709,
)
# Boolean switches for the COVID-19 effects passed to get_bcm.
covid_effects = dict(detection_reduction=True, contact_reduction=False)
bcm = get_bcm(params, covid_effects)

In [None]:
# Show the priors attached to the BCM (displayed as this cell's output).
bcm.priors

In [None]:
# Compare the post-burn-in posterior with the priors for the parameters in params_name.
# NOTE(review): the trailing positional 2 is unexplained — presumably a layout
# setting (e.g. number of subplot columns); confirm against plot_post_prior_comparison.
fig = plot_post_prior_comparison(burnt_idata, bcm.priors, params_name, 2)

In [None]:
# Processed universal death rates. Must be defined here: the derived-metric cells
# (process_idata_for_derived_metrics / process_priors_for_derived_metrics) index it
# with universal_death[2023] and otherwise fail with NameError on Restart & Run All.
universal_death = process_universal_death_rate(get_death_rate())

In [None]:
# Subset of the BCM priors (smear-status death and self-recovery rates) used by
# process_priors_for_derived_metrics. Must be defined, or that cell raises NameError.
priors = {
    name: bcm.priors[name]
    for name in (
        "smear_positive_death_rate",
        "smear_positive_self_recovery",
        "smear_negative_death_rate",
        "smear_negative_self_recovery",
    )
}

In [None]:
# Tabulate calibration results from the post-burn-in draws for the parameters in params_name.
tabulate_calib_results(burnt_idata, params_name)

In [None]:
# Trace plots are built from the full `idata` (burn-in included), not `burnt_idata`.
# NOTE(review): presumably deliberate so the burn-in period is visible; confirm.
tracing = plot_trace(idata, params_name)

In [None]:
# Display the object returned by plot_trace.
tracing

In [None]:
# tracing.savefig('../docs/param_traces.png', dpi=300, bbox_inches='tight', format='png', pad_inches=0)

In [None]:
# Derived metrics computed from the post-burn-in posterior draws.
# Requires `universal_death` to be defined earlier in the notebook.
# NOTE(review): 2023 is presumably a calendar year selecting the reference death
# rate; consider a named constant shared with the prior-metrics cell.
posterior_metrics = process_idata_for_derived_metrics(burnt_idata, universal_death[2023])

In [None]:
# Same derived metrics computed from the prior distributions, for comparison.
# Requires `priors` and `universal_death` to be defined earlier in the notebook.
# NOTE(review): keep the 2023 index in sync with the posterior-metrics cell.
prior_metrics = process_priors_for_derived_metrics(priors, universal_death[2023])

In [None]:
# Prior vs posterior comparison of the derived metrics; the trailing ';'
# suppresses the return value's repr in the notebook output.
plot_derived_comparison(
    prior_metrics,
    posterior_metrics,
);