In [None]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
from pathlib import Path
import pymc as pm
import arviz as az
import multiprocessing as mp
from estival.wrappers import pymc as epm
from estival.sampling import tools as esamp
from estival.wrappers import nevergrad as eng
from estival.utils.parallel import map_parallel
import nevergrad as ng

from tbdynamics.constants import BURN_IN, OPTI_DRAWS
from autumn.infrastructure.remote import springboard
from tbdynamics.calib_utils import get_bcm
import pandas as pd
from estival.utils.sample import SampleTypes


In [None]:
# Model parameters passed to ``calibrate`` (and from there to ``get_bcm``).
# Entries that are commented out are NOT passed to the model; they appear to be
# point estimates from an earlier run, kept for reference.
# NOTE(review): confirm those parameters are handled (e.g. given priors) by get_bcm.
params = dict(
    start_population_size=2_300_000.0,
    seed_time=1830.0,
    seed_num=100.0,
    seed_duration=20.0,
    # contact_rate=0.02977583831288669,
    # rr_infection_latent=0.20344010763518713,
    # rr_infection_recovered=0.40580870889350107,
    # progression_multiplier=0.8810860029360905,
    # smear_positive_death_rate=0.4313851033562638,
    # smear_negative_death_rate=0.03350161278620193,
    # smear_positive_self_recovery=0.28604824197753914,
    # smear_negative_self_recovery=0.15805647865361552,
    screening_scaleup_shape=0.25,
    # screening_inflection_time=1995.1487440977369,
    # time_to_screening_end_asymp=1.111127871433536,
    # detection_reduction=0.24351558358481182,
    # contact_reduction=0.3813306382411676,
)

In [None]:
def calibrate(out_path, params, draws, tune, n_chains=8, burn_in=50000, n_samples=500):
    """Run optimisation-seeded MCMC calibration and write all outputs to ``out_path``.

    Pipeline:
      1. Build the calibration model with ``get_bcm(params)``.
      2. Draw 16 Latin-hypercube samples (``ci=0.67``), rank them by
         log-likelihood (highest first) and refine each one with a nevergrad
         ``TwoPointsDE`` optimisation, in parallel via ``map_parallel``.
      3. Use the first ``n_chains`` refined samples as per-chain initial values
         for ``pm.DEMetropolisZ`` sampling.
      4. Save the raw trace (``calib_full_out.nc``), a burnt-in subsample
         (``calib_extract_out.h5``), model outputs for that subsample
         (``results.hdf`` key ``spaghetti``) and per-draw likelihoods
         (``results.hdf`` key ``likelihood``).

    Args:
        out_path: Output directory (``str`` or ``Path``); created if missing.
        params: Parameter dict forwarded to ``get_bcm``.
        draws: Draws per chain requested from ``pm.sample``.
        tune: Tuning steps per chain. Tuned samples are not discarded
            (``discard_tuned_samples=False``).
        n_chains: Number of MCMC chains. Also used as the number of sampling
            cores and as the number of refined samples taken as initial values.
        burn_in: Draws whose ``draw`` label is below this value are dropped
            before extracting the subsample.
        n_samples: Number of samples extracted with ``az.extract``.

    Returns:
        The raw ``InferenceData`` returned by ``pm.sample``.

    Raises:
        ValueError: If ``burn_in`` is not in ``[0, draws)``. Checked before
            any expensive work, because otherwise the burn-in slice would be
            empty and extraction would fail only after sampling finished.
    """
    out_path = Path(out_path)
    if not 0 <= burn_in < draws:
        raise ValueError(f"burn_in ({burn_in}) must be in [0, draws={draws})")
    # Create the directory up front so the writes at the end cannot fail on a missing path.
    out_path.mkdir(parents=True, exist_ok=True)

    bcm = get_bcm(params)
    opt_budget = 1000

    def optimize_ng_with_idx(item):
        """Refine one ``(index, sample)`` row with nevergrad; return ``(index, rec.value[1])``."""
        idx, sample = item
        opt = eng.optimize_model(
            bcm,
            budget=opt_budget,
            opt_class=ng.optimizers.TwoPointsDE,
            suggested=sample,
            num_workers=8,
        )
        rec = opt.minimize(opt_budget)
        # NOTE(review): presumably value[1] is the kwargs dict of parameter values.
        return idx, rec.value[1]

    # Starting points for optimisation, best (highest log-likelihood) first.
    lhs_samples = bcm.sample.lhs(16, ci=0.67)
    lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)
    lhs_sorted = lhs_lle.sort_values("loglikelihood", ascending=False)
    opt_samples_idx = map_parallel(optimize_ng_with_idx, lhs_sorted.iterrows())
    best_opt_samps = bcm.sample.convert(opt_samples_idx)
    # One initial-value dict per chain.
    init_samps = best_opt_samps.convert(SampleTypes.LIST_OF_DICTS)[0:n_chains]

    with pm.Model():
        variables = epm.use_model(bcm)
        idata_raw = pm.sample(
            step=[pm.DEMetropolisZ(variables, proposal_dist=pm.NormalProposal)],
            draws=draws,
            cores=n_chains,
            tune=tune,
            discard_tuned_samples=False,
            chains=n_chains,
            progressbar=True,
            initvals=init_samps,
        )
    idata_raw.to_netcdf(str(out_path / "calib_full_out.nc"))

    # Label-based slice on the `draw` coordinate: keep draws >= burn_in.
    burnt_idata = idata_raw.sel(draw=np.s_[burn_in:])
    idata_extract = az.extract(burnt_idata, num_samples=n_samples)
    bcm.sample.convert(idata_extract).to_hdf5(out_path / "calib_extract_out.h5")

    spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)
    spaghetti_res.results.to_hdf(str(out_path / "results.hdf"), "spaghetti")

    # Likelihoods are computed on the full (un-burnt) trace.
    like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
    like_df.to_hdf(str(out_path / "results.hdf"), "likelihood")

    return idata_raw


def run_calibration(bridge: springboard.task.TaskBridge, bcm, draws, tune):
    """Springboard task entry point: run ``calibrate`` and log completion.

    Args:
        bridge: Springboard task bridge. Outputs go to ``bridge.out_path``;
            completion is logged via ``bridge.logger``.
        bcm: Forwarded unchanged as the ``params`` argument of ``calibrate``,
            which passes it to ``get_bcm``. NOTE(review): despite its name this
            appears to need the parameter dict, not an already-built model.
            The commented-out launch cell passes ``bcm``, so confirm this
            before launching.
        draws: MCMC draws per chain.
        tune: MCMC tuning steps per chain.

    Returns:
        Whatever ``calibrate`` returns.
    """
    # Local import kept as in the original; presumably so the dependency is
    # resolved wherever the task function is executed.
    import multiprocessing as mp

    # force=True: without it set_start_method raises RuntimeError when a start
    # method has already been set in this process (e.g. on a repeated call).
    mp.set_start_method("forkserver", force=True)
    idata_raw = calibrate(bridge.out_path, bcm, draws, tune)
    bridge.logger.info("Calibration complete")
    return idata_raw

In [None]:
# def smc_calib():
#     with pm.Model() as model:
#         variables = epm.use_model(bcm)
#         print(variables)
#         idata_smc = pm.sample_smc(
#             kernel=pm.smc.IMH,
#             compute_convergence_checks=False,
#             start=None,
#             draws=15000,
#             chains=4,
#             threshold=0.1,
#             correlation_threshold=0.6,
#         )
#     idata_smc.to_netcdf(str("runs/r1207/calib_full_out.nc"))

In [None]:
# smc_calib()

In [None]:
# Run directory for this calibration, relative to the notebook's working directory.
OUT_PATH = Path.cwd().joinpath("runs", "r1207")
# MCMC settings: draws and tuning steps per chain.
draws = 100_000
tune = 50_000
calibrate(OUT_PATH, params, draws=draws, tune=tune)

In [None]:
# draws = 200000
# tune = 100000

# commands = [
#     'git clone --branch tb-covid https://github.com/longbui/tbdynamics.git',
#     'pip install -e ./tbdynamics',
# ]

# mspec = springboard.EC2MachineSpec(8, 2, 'compute')
# run_str = f'd{int(draws / 1000)}k-t{int(tune / 1000)}k-b{int(BURN_IN / 1000)}k'
# tspec_args = {'bcm': bcm,'draws': draws, 'tune': tune}
# tspec = springboard.TaskSpec(run_calibration, tspec_args)
# run_path =  springboard.launch.get_autumn_project_run_path('tbdynamics', 's3107_calibration', run_str)
# runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path, branch=None, extra_commands=commands)


In [None]:
# runner.wait()

In [None]:
# runner.get_log("crash")

In [None]:
# from autumn.infrastructure.remote import springboard
# rts = springboard.task.RemoteTaskStore()
# rts.cd('projects/tbdynamics/s3107_calibration')
# rts.ls()



In [None]:
# mt = rts.get_managed_task('2024-08-01T1645-d100k-t50k-b50k')

In [None]:
# mt.download_all()

In [None]:
# print(runner.get_log("crash"))