In [None]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from pathlib import Path
import pymc as pm
import arviz as az
import multiprocessing as mp
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel
import nevergrad as ng
# from autumn.infrastructure.remote import springboard
from tbdynamics.vietnam.calibration.utils import get_bcm
from estival.utils.sample import SampleTypes

In [None]:
# COVID-19 scenarios for `calibrate_with_configs`: each entry maps a scenario name to
# the effect flags passed to `get_bcm`. Disabled scenarios are kept as comments so they
# can be switched back on by uncommenting.
covid_configs = {
    # "no_covid": {"detection_reduction": False, "contact_reduction": False},  # No reduction
    "detection": {
        "detection_reduction": True,
        "contact_reduction": False,
    },  # No contact reduction
    # "contact": {"detection_reduction": False, "contact_reduction": True},  # Only contact reduction
    "detection_and_contact": {
        "detection_reduction": True,
        "contact_reduction": True,
    },  # With detection + contact reduction
}

# Scenario used by the single-scenario `calibrate` run. Copied from `covid_configs`
# (not re-typed) so the two definitions cannot drift apart; `dict(...)` makes it an
# independent copy so mutating one does not silently change the other.
covid_effects = dict(covid_configs["detection_and_contact"])

# Base model parameters passed to `get_bcm` (presumably fixed values that are not
# estimated by the calibration — confirm in get_bcm).
params = {
    "start_population_size": 2_000_000.0,
    "seed_time": 1805.0,
    "seed_num": 1.0,
    "seed_duration": 1.0,
}

In [None]:
def calibrate(out_path, params, covid_effects, draws, tune, opt_budget=500, n_chains=8):
    """Calibrate the model for one COVID-effects scenario and save the raw trace.

    Pipeline:
      1. Draw Latin-hypercube samples from the priors and rank them by log-likelihood.
      2. Run a nevergrad TwoPointsDE optimisation starting from each LHS sample.
      3. Use the first ``n_chains`` optimised parameter sets as initial values for
         DEMetropolisZ MCMC (presumably still in LHS log-likelihood order, provided
         ``map_parallel`` preserves input order — TODO confirm).
      4. Write the full trace, tuning draws included, to ``out_path/calib_full_out.nc``.

    Args:
        out_path: Output directory (``Path`` or ``str``); created if it does not exist.
        params: Base model parameters forwarded to ``get_bcm``.
        covid_effects: Scenario flags forwarded to ``get_bcm``
            (``detection_reduction`` / ``contact_reduction``).
        draws: Number of post-tuning MCMC draws per chain.
        tune: Number of tuning draws per chain. These are kept in the saved trace
            because ``discard_tuned_samples=False``.
        opt_budget: Evaluation budget for each nevergrad optimisation. It is used both as
            the optimiser's declared budget and as the number of evaluations run.
        n_chains: Number of MCMC chains, and therefore of optimised starting points used.

    Returns:
        The raw ``arviz.InferenceData`` produced by ``pm.sample``.
    """
    out_path = Path(out_path)
    # Create the output directory before any expensive work. Discovering that it is
    # missing only at ``to_netcdf`` time would throw away hours of sampling.
    out_path.mkdir(parents=True, exist_ok=True)

    bcm = get_bcm(params, covid_effects)

    def optimize_ng_with_idx(item):
        # item is an (index, row) pair from DataFrame.iterrows(). The index is returned
        # alongside the optimum, presumably so bcm.sample.convert can keep track of
        # which LHS point each result came from.
        idx, sample = item
        opt = eng.optimize_model(
            bcm,
            budget=opt_budget,
            opt_class=ng.optimizers.TwoPointsDE,
            suggested=sample,
            num_workers=8,
        )
        rec = opt.minimize(opt_budget)
        # rec.value is presumably nevergrad's (args, kwargs) pair; the named
        # parameters are in the kwargs element.
        return idx, rec.value[1]

    # At least one LHS point per chain so there are enough starting points.
    # ci=0.67 presumably restricts sampling to the central 67% of each prior.
    n_lhs = max(16, n_chains)
    lhs_samples = bcm.sample.lhs(n_lhs, ci=0.67)
    lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
    lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)

    opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
    best_opt_samps = bcm.sample.convert(opt_samples_idx)
    # One starting point per chain. The slice is tied to n_chains so the two can't diverge.
    init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[:n_chains]

    with pm.Model():
        variables = epm.use_model(bcm)
        idata_raw = pm.sample(
            step=[pm.DEMetropolisZ(variables, proposal_dist=pm.NormalProposal)],
            draws=draws,
            cores=16,
            tune=tune,
            # Keep tuning draws in the trace; burn-in is applied in post-processing.
            discard_tuned_samples=False,
            chains=n_chains,
            progressbar=True,
            initvals=init_samps,
        )
    idata_raw.to_netcdf(str(out_path / "calib_full_out.nc"))

    # Post-processing (currently disabled). With discard_tuned_samples=False, the first
    # `tune` draws of each chain are tuning draws. Burn-in therefore starts at `tune`;
    # a fixed index larger than tune + draws would select no draws at all.
    # n_samples = 1000
    # burnt_idata = idata_raw.sel(draw=np.s_[tune:])
    # idata_extract = az.extract(burnt_idata, num_samples=n_samples)
    # bcm.sample.convert(idata_extract).to_hdf5(out_path / "calib_extract_out.h5")

    # spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
    # spaghetti_res.results.to_hdf(str(out_path / "results.hdf"), "spaghetti")

    # like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
    # like_df.to_hdf(str(out_path / "results.hdf"), "likelihood")

    return idata_raw


# Remote (springboard) entry point, currently disabled. Its arguments match the
# signature of `calibrate` above.
# def run_calibration(bridge: springboard.task.TaskBridge, params, covid_effects, draws, tune):
#     import multiprocessing as mp
#     mp.set_start_method("forkserver")
#     idata_raw = calibrate(bridge.out_path, params, covid_effects, draws, tune)
#     bridge.logger.info("Calibration complete")

In [None]:
def calibrate_with_configs(
    out_path,
    params,
    covid_configs,
    draws,
    tune,
    opt_budget=500,
    n_chains=8,
    proposal_dist=None,
):
    """Run the full calibration pipeline once per COVID-effects scenario.

    For each ``(config_name, covid_effects)`` entry in ``covid_configs``, this function:

    * builds the model with ``get_bcm``;
    * optimises Latin-hypercube starting points with nevergrad TwoPointsDE;
    * seeds ``n_chains`` DEMetropolisZ chains from the first ``n_chains`` optimised points;
    * writes the full trace, tuning draws included, to
      ``out_path/calib_full_out_{config_name}.nc``.

    Args:
        out_path: Output directory (``Path`` or ``str``); created if it does not exist.
        params: Base model parameters forwarded to ``get_bcm``.
        covid_configs: Mapping of scenario name -> covid_effects flags dict.
        draws: Number of post-tuning MCMC draws per chain.
        tune: Number of tuning draws per chain. These are kept in the saved trace
            because ``discard_tuned_samples=False``.
        opt_budget: Evaluation budget for each nevergrad optimisation. It is used both as
            the optimiser's declared budget and as the number of evaluations run, so
            the two cannot disagree.
        n_chains: Number of MCMC chains, and therefore of optimised starting points used.
        proposal_dist: Optional PyMC proposal class for DEMetropolisZ (e.g.
            ``pm.NormalProposal``, which ``calibrate`` uses). ``None`` keeps PyMC's
            default proposal.

    Returns:
        Dict mapping each config name to its raw ``arviz.InferenceData``.
    """
    out_path = Path(out_path)
    # Fail fast on the output location instead of after a multi-hour sampling run.
    out_path.mkdir(parents=True, exist_ok=True)

    results = {}
    for config_name, covid_effects in covid_configs.items():
        bcm = get_bcm(params, covid_effects)

        # Defined inside the loop so the closure captures this scenario's bcm.
        # map_parallel consumes it within the same iteration.
        def optimize_ng_with_idx(item):
            idx, sample = item
            opt = eng.optimize_model(
                bcm,
                budget=opt_budget,
                opt_class=ng.optimizers.TwoPointsDE,
                suggested=sample,
                num_workers=8,
            )
            rec = opt.minimize(opt_budget)
            # rec.value is presumably nevergrad's (args, kwargs) pair; the named
            # parameters are in the kwargs element.
            return idx, rec.value[1]

        # At least one LHS point per chain so there are enough starting points.
        n_lhs = max(16, n_chains)
        lhs_samples = bcm.sample.lhs(n_lhs, ci=0.67)
        lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
        lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
        opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
        best_opt_samps = bcm.sample.convert(opt_samples_idx)
        # One starting point per chain. The slice is tied to n_chains so the two can't diverge.
        init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[:n_chains]

        with pm.Model():
            variables = epm.use_model(bcm)
            idata_raw = pm.sample(
                step=[pm.DEMetropolisZ(variables, proposal_dist=proposal_dist)],
                draws=draws,
                cores=16,
                tune=tune,
                # Keep tuning draws in the trace; burn-in is applied in post-processing.
                discard_tuned_samples=False,
                chains=n_chains,
                progressbar=True,
                initvals=init_samps,
            )

        # Save results using the configuration key in the filenames.
        idata_raw.to_netcdf(str(out_path / f"calib_full_out_{config_name}.nc"))
        results[config_name] = idata_raw

        # Post-processing (currently disabled). The first `tune` draws per chain are
        # tuning draws (discard_tuned_samples=False), so burn-in starts at `tune`.
        # n_samples = 1000
        # burnt_idata = idata_raw.sel(draw=np.s_[tune:])
        # idata_extract = az.extract(burnt_idata, num_samples=n_samples)
        # bcm.sample.convert(idata_extract).to_hdf5(out_path / f"calib_extract_out_{config_name}.h5")

        # spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
        # spaghetti_res.results.to_hdf(str(out_path / f"results_{config_name}.hdf"), "spaghetti")

        # like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
        # like_df.to_hdf(str(out_path / f"results_{config_name}.hdf"), "likelihood")

    return results


In [None]:
# Calibration outputs go to runs/r1606, two directory levels above the kernel's
# current working directory.
OUT_PATH = Path.cwd().parent.parent / "runs" / "r1606"

In [None]:
# Display the output directory so the run location can be checked before the
# (long) calibration cell below is launched.
OUT_PATH

In [None]:
# MCMC settings: tuning draws per chain, and post-tuning draws per chain.
tune = 5_000
draws = 10_000
calibrate(OUT_PATH, params, covid_effects, draws=draws, tune=tune)