In [1]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from pathlib import Path
import pymc as pm
import arviz as az
import multiprocessing as mp
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel
import nevergrad as ng
# from autumn.infrastructure.remote import springboard
from tbdynamics.camau.calibration.utils import get_bcm
from estival.utils.sample import SampleTypes



In [2]:
# COVID-19 disruption scenarios passed to get_bcm. Each scenario switches the
# detection reduction and/or the contact reduction on or off.
covid_configs = {
    name: {"detection_reduction": detection_on, "contact_reduction": contact_on}
    for name, detection_on, contact_on in (
        ("no_covid", False, False),  # neither reduction
        ("detection", True, False),  # detection reduction only
        ("contact", False, True),  # contact reduction only
        ("detection_and_contact", True, True),  # both reductions
    )
}

# Scenario used for the single-scenario calibration below: an independent copy
# of the "detection" entry (detection reduction on, contact reduction off).
covid_effects = dict(covid_configs["detection"])

# Fixed parameter values passed to get_bcm.
# NOTE(review): contact_rate, time_to_screening_end_asymp, prop_mixing_same_stratum,
# early_prop_adjuster, late_reactivation_adjuster and the *_dispersion terms were
# commented out here; presumably they are calibrated via priors in get_bcm — TODO confirm.
params = dict(
    start_population_size=30000.0,
    seed_time=1805.0,
    seed_num=1.0,
    seed_duration=1.0,
    rr_infection_latent=0.1890473700762809,
    rr_infection_recovered=0.17781844797545143,
    smear_positive_death_rate=0.3655528915762244,
    smear_negative_death_rate=0.027358324164819155,
    smear_positive_self_recovery=0.18600338108638945,
    smear_negative_self_recovery=0.11333894801537307,
    screening_scaleup_shape=0.5,
    screening_inflection_time=1993,
    acf_sensitivity=0.90,
    detection_reduction=0.30,
)

# Number of MCMC chains (also the number of optimised starting points used).
N_CHAINS = 4

In [None]:
def calibrate(
    out_path,
    params,
    covid_effects,
    draws,
    tune,
    n_chains=None,
    n_lhs_samples=16,
    opt_budget=1000,
    cores=16,
):
    """Calibrate the model for one COVID-effects scenario and save the raw trace.

    Pipeline:
      1. Build the model with ``get_bcm(params, covid_effects)``.
      2. Draw ``n_lhs_samples`` Latin hypercube parameter sets
         (``bcm.sample.lhs(..., ci=0.67)``) and rank them by log-likelihood.
      3. Refine every set with nevergrad's TwoPointsDE optimiser (in parallel).
      4. Start ``n_chains`` DEMetropolisZ chains from the first ``n_chains``
         optimised sets and write the trace to ``out_path / "calib_full_out.nc"``.

    Args:
        out_path: Output directory (``Path`` or str); created if missing.
        params: Fixed model parameters forwarded to ``get_bcm``.
        covid_effects: COVID switch dict forwarded to ``get_bcm``.
        draws: Post-tuning MCMC draws per chain.
        tune: Tuning iterations per chain (kept in the saved trace).
        n_chains: Number of chains; defaults to the module-level ``N_CHAINS``.
        n_lhs_samples: Number of Latin hypercube starting points to optimise.
        opt_budget: Nevergrad evaluation budget per starting point.
        cores: Forwarded to ``pm.sample(cores=...)``.

    Returns:
        The InferenceData produced by ``pm.sample`` (also written to disk).

    Raises:
        ValueError: If ``n_chains`` exceeds ``n_lhs_samples``, i.e. there would
            not be enough starting points to initialise every chain.
    """
    if n_chains is None:
        n_chains = N_CHAINS
    if n_chains > n_lhs_samples:
        raise ValueError(
            f"n_chains={n_chains} needs at least as many starting points, "
            f"but n_lhs_samples={n_lhs_samples}"
        )

    out_path = Path(out_path)
    # Create the directory before the (long) optimisation and sampling, so a
    # missing folder cannot discard a finished run at the final save.
    out_path.mkdir(parents=True, exist_ok=True)

    bcm = get_bcm(params, covid_effects)

    def optimize_ng_with_idx(item):
        # item is an (index, row) pair from DataFrame.iterrows(); the index is
        # returned with the optimised values so they stay paired.
        idx, sample = item
        opt = eng.optimize_model(
            bcm,
            budget=opt_budget,
            opt_class=ng.optimizers.TwoPointsDE,
            suggested=sample,
            num_workers=8,
        )
        rec = opt.minimize(opt_budget)
        # rec.value[1] holds the optimised parameter values (presumably the
        # kwargs half of a nevergrad Instrumentation value — TODO confirm).
        return idx, rec.value[1]

    lhs_samples = bcm.sample.lhs(n_lhs_samples, ci=0.67)
    lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
    # Highest-likelihood starting points first, so (assuming map_parallel
    # preserves input order) the chains start from the most promising ones.
    lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
    opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
    best_opt_samps = bcm.sample.convert(opt_samples_idx)
    # One initial-value dict per chain.
    init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[:n_chains]

    with pm.Model():
        variables = epm.use_model(bcm)
        idata_raw = pm.sample(
            step=[pm.DEMetropolisZ(variables, scaling=0.1, proposal_dist=pm.NormalProposal)],
            draws=draws,
            cores=cores,
            tune=tune,
            # Keep the tuning iterations in the trace so burn-in can be inspected.
            discard_tuned_samples=False,
            chains=n_chains,
            progressbar=True,
            initvals=init_samps,
        )

    idata_raw.to_netcdf(str(out_path / "calib_full_out.nc"))
    return idata_raw

In [4]:
def calibrate_with_configs(
    out_path,
    params,
    covid_configs,
    draws,
    tune,
    n_chains=None,
    n_lhs_samples=16,
    opt_budget=500,
    cores=16,
):
    """Calibrate the model separately for every COVID-effects scenario.

    For each ``(config_name, covid_effects)`` pair in ``covid_configs``:
      * build the model with ``get_bcm(params, covid_effects)``;
      * optimise ``n_lhs_samples`` Latin hypercube starting points with
        nevergrad's TwoPointsDE;
      * run ``n_chains`` DEMetropolisZ chains from the first ``n_chains``
        optimised points;
      * write the trace to ``out_path / f"calib_full_out_{config_name}.nc"``.

    Args:
        out_path: Output directory (``Path`` or str); created if missing.
        params: Fixed model parameters forwarded to ``get_bcm``.
        covid_configs: Mapping of scenario name -> COVID switch dict.
        draws: Post-tuning MCMC draws per chain.
        tune: Tuning iterations per chain (kept in the saved traces).
        n_chains: Number of chains; defaults to the module-level ``N_CHAINS``.
        n_lhs_samples: Number of Latin hypercube starting points per scenario.
        opt_budget: Nevergrad evaluation budget per starting point.
        cores: Forwarded to ``pm.sample(cores=...)``.

    Raises:
        ValueError: If ``n_chains`` exceeds ``n_lhs_samples``, i.e. there would
            not be enough starting points to initialise every chain.
    """
    if n_chains is None:
        n_chains = N_CHAINS
    if n_chains > n_lhs_samples:
        raise ValueError(
            f"n_chains={n_chains} needs at least as many starting points, "
            f"but n_lhs_samples={n_lhs_samples}"
        )

    out_path = Path(out_path)
    # Create the directory before any scenario runs, so a missing folder
    # cannot discard a finished scenario at save time.
    out_path.mkdir(parents=True, exist_ok=True)

    for config_name, covid_effects in covid_configs.items():
        # This re-implements the pipeline inline rather than calling calibrate().
        bcm = get_bcm(params, covid_effects)

        def optimize_ng_with_idx(item):
            # item is an (index, row) pair from DataFrame.iterrows(); the index
            # is returned with the optimised values so they stay paired.
            idx, sample = item
            opt = eng.optimize_model(
                bcm,
                budget=opt_budget,
                opt_class=ng.optimizers.TwoPointsDE,
                suggested=sample,
                num_workers=8,
            )
            rec = opt.minimize(opt_budget)
            return idx, rec.value[1]

        lhs_samples = bcm.sample.lhs(n_lhs_samples, ci=0.67)
        lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
        lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
        opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
        best_opt_samps = bcm.sample.convert(opt_samples_idx)
        # One initial-value dict per chain: the slice must follow n_chains (it
        # was hard-coded to 4, which mismatched chains whenever N_CHAINS != 4).
        init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[:n_chains]

        with pm.Model():
            variables = epm.use_model(bcm)
            # NOTE(review): unlike calibrate(), this uses DEMetropolisZ's default
            # scaling/proposal (calibrate passes scaling=0.1, NormalProposal) —
            # confirm the difference is intended.
            idata_raw = pm.sample(
                step=[pm.DEMetropolisZ(variables)],
                draws=draws,
                cores=cores,
                tune=tune,
                discard_tuned_samples=False,
                chains=n_chains,
                progressbar=True,
                initvals=init_samps,
            )

        # One output file per scenario, keyed by the configuration name.
        idata_raw.to_netcdf(str(out_path / f"calib_full_out_{config_name}.nc"))


In [5]:
# Calibration outputs live two directories above the current working directory.
# NOTE(review): this depends on the kernel's cwd — confirm the notebook is run from its own folder.
OUT_PATH = Path.cwd().parent.parent.joinpath("data", "outputs", "camau", "r1103")

In [6]:
# Echo the output directory so the save location is recorded in the notebook output.
OUT_PATH 

WindowsPath('c:/Users/vbui0010/tbdynamics/data/outputs/camau/r1103')

In [7]:
# MCMC settings for the single-scenario run (per chain).
tune = 1000
draws = 2000
# NOTE(review): the previous run of this cell warned rhat > 1.01 and fewer than
# 100 effective samples per chain for some parameters; treat the posterior as
# unconverged until longer tuning/draws confirm it.
calibrate(OUT_PATH, params, covid_effects, draws=draws, tune=tune);

Multiprocess sampling (4 chains in 16 jobs)
DEMetropolisZ: [contact_rate, rr_infection_latent, rr_infection_recovered, time_to_screening_end_asymp, early_prop_adjuster, late_reactivation_adjuster, detection_reduction, total_population_dispersion, notif_dispersion, latent_dispersion, act3_trial_dispersion, act3_control_dispersion]


Sampling 4 chains for 1_000 tune and 2_000 draw iterations (4_000 + 8_000 draws total) took 786 seconds.
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
