In [None]:
# Standard library
from multiprocessing import cpu_count
from pathlib import Path

# Third-party
import arviz as az
import nevergrad as ng
import numpy as np
import pandas as pd
import pymc as pm

# Calibration tooling (estival)
from estival.sampling import tools as esamp
from estival.sampling.tools import SampleIterator, SampleTypes
from estival.utils.parallel import map_parallel
from estival.wrappers import nevergrad as eng
from estival.wrappers import pymc as epm

# Import our convenience wrapper
from estival.wrappers.nevergrad import optimize_model

# Project-local
from tbdynamics.calib_utils import get_bcm, load_targets
from tbdynamics.constants import (
    age_strata,
    compartments,
    organ_strata,
)
from tbdynamics.plotting import plot_model_vs_actual

In [None]:
# Render pandas `.plot()` calls with plotly until the backend is switched again.
pd.options.plotting.backend = "plotly"

## Define Model

### Params and calibration targets

In [None]:
# Example parameter overrides, kept for reference only: the block is commented
# out and `get_bcm()` below is called without arguments.
# params = {
#     'treatment_duration': 0.5, # 6 months
#     'screening_start_asymp': 0.,
# }

# Build the calibration model object that the later cells sample, optimise and run.
bcm = get_bcm()

### Running Optimization

In [None]:
def optimize_ng_with_idx(item):
    """Refine one starting sample with nevergrad and return it with its index.

    Args:
        item: An ``(idx, sample)`` pair, e.g. one element of ``iterrows()``.

    Returns:
        ``(idx, params)`` where ``params`` is element 1 of the nevergrad
        recommendation's ``value`` (presumably the keyword-argument dict of
        parameter values -- confirm against the estival wrapper).
    """
    idx, sample = item
    runner = eng.optimize_model(
        bcm,
        budget=1000,
        opt_class=ng.optimizers.TwoPointsDE,
        suggested=sample,
        num_workers=4,
    )
    recommendation = runner.minimize(1000)
    return idx, recommendation.value[1]

In [None]:
# Draw 16 starting samples via the sampler's `lhs` method
# (presumably Latin hypercube sampling -- confirm against estival).
lhs_samples = bcm.sample.lhs(16)

In [None]:
# Likelihood-related quantities for each LHS sample (includes a "loglikelihood" column).
lhs_lle = esamp.likelihood_extras_for_samples(lhs_samples, bcm)

In [None]:
# Rank the LHS samples from highest to lowest log-likelihood and show the
# resulting index order.
lhs_sorted = lhs_lle.sort_values(by="loglikelihood", ascending=False)
lhs_sorted.index

In [None]:
# Reorder the samples by likelihood rank and keep the top eight as optimiser starting points.
best8 = lhs_samples[lhs_sorted.index].iloc[0:8]

In [None]:
# Optimise every starting sample through `map_parallel`; each result is an
# (idx, params) pair as returned by `optimize_ng_with_idx`.
opt_samples_idx = map_parallel(optimize_ng_with_idx, best8.iterrows())

In [None]:
# Likelihood-related quantities for the optimised samples, displayed for inspection.
lle_samps = esamp.likelihood_extras_for_samples(opt_samples_idx, bcm)
lle_samps

In [None]:
# Convert the optimiser results into the sampler's container type (sliced with
# `.iloc` in the following cell) and display it.
best_opt_samps = bcm.sample.convert(opt_samples_idx)
best_opt_samps

In [None]:
# First four optimised samples as a list of dicts; used as per-chain initial
# values for the 4-chain MCMC run further down.
init_samps = best_opt_samps.iloc[0:4].convert("list_of_dicts")


In [None]:
# Parameter set used for the single illustrative model run and as the initial
# value of the main calibration.
# NOTE(review): this is simply the first optimised sample; nothing in this
# notebook orders `best_opt_samps` by post-optimisation likelihood, so confirm
# it really is the maximum-likelihood point.
mle_params = init_samps[0]

In [None]:
# Run the model once at `mle_params` and keep its derived outputs for the
# comparison plots in the Outputs section.
res = bcm.run(mle_params)
derived_df_0 = res.derived_outputs
# Calibration targets, looked up by key in the comparison plots.
targets = load_targets()

In [None]:
# Exploratory MCMC run with DEMetropolisZ: 4 chains of 1000 draws, each chain
# started from one of the optimised samples in `init_samps`.
with pm.Model() as model:
    variables = epm.use_model(bcm)
    idata = pm.sample(
        step=[pm.DEMetropolisZ(variables)],
        draws=1000,
        chains=4,
        initvals=init_samps,
    )

In [None]:
# ArviZ summary table for the exploratory run.
az.summary(idata)

In [None]:
# Likelihood-related quantities for the draws in `idata` (includes "logposterior").
lle = esamp.likelihood_extras_for_idata(idata, bcm)

In [None]:
# Log-posterior trace, one line per chain, smoothed with a 250-row rolling mean.
(
    lle["logposterior"]
    .unstack(["chain"])
    .rolling(250)
    .mean()
    .plot()
)

In [None]:

# Discard draws before 200 in every chain as burn-in.
burnt_idata = idata.sel(draw=np.s_[200:])

In [None]:
# Extract 100 samples from the burnt-in run for the spaghetti plots.
sds = az.extract(burnt_idata, num_samples=100)

In [None]:
# Model results for each of the extracted samples.
spaghetti_res = esamp.model_results_for_samples(sds, bcm)

In [None]:
# Switch pandas plotting to matplotlib for the remaining figures.
pd.options.plotting.backend = "matplotlib"

In [None]:
# Sampled notification trajectories, then the target data drawn as black dots.
spaghetti_res.results["notification"].plot(legend=False)
bcm.targets["notification"].data.plot(style=".", color="black")

In [None]:
# Sampled total-population trajectories, then the target data drawn as black dots.
spaghetti_res.results["total_population"].plot(legend=False)
bcm.targets["total_population"].data.plot(style=".", color="black")

In [None]:
# Sampled pulmonary-prevalence trajectories, then the target data drawn as black dots.
spaghetti_res.results["prevalence_pulmonary"].plot(legend=False)
bcm.targets["prevalence_pulmonary"].data.plot(style=".", color="black")

In [None]:
# Larger extract (500 samples) for the quantile summaries; rebinds `sds`.
sds = az.extract(burnt_idata, num_samples=500)

In [None]:
# Model results for the 500-sample extract.
samp_res = esamp.model_results_for_samples(sds, bcm)

In [None]:
# Summarise the sampled results at the 5/25/50/75/95% quantile levels.
quantiles = esamp.quantiles_for_results(
    samp_res.results, (0.05, 0.25, 0.5, 0.75, 0.95)
)

In [None]:
# Notification quantiles, then the target data drawn as black dots.
quantiles["notification"].plot()
bcm.targets["notification"].data.plot(style=".", color="black")

### Outputs

In [None]:
# Modelled total population against the 'pop' target.
plot_model_vs_actual(
    derived_df_0,
    targets["pop"],
    "total_population",
    "Population",
    "Modelled vs Data",
)

In [None]:
# Stacked-area chart of the `total_populationXage_<stratum>` outputs.
# The title previously read "populatation"; the spelling is corrected here.
derived_df_0[[f"total_populationXage_{age}" for age in age_strata]].plot(
    title="Modelled population by age group", kind="area"
)

In [None]:

# Modelled incidence against the 'incidence' target.
plot_model_vs_actual(
    derived_df_0,
    targets["incidence"],
    "incidence",
    "Incidence",
    "Modelled vs Data",
)

In [None]:
# Stacked-area chart of the `prop_<compartment>` outputs.
derived_df_0[[f"prop_{name}" for name in compartments]].plot(kind="area")

In [None]:
# Stacked-area chart of the `total_populationXorgan_<stratum>` outputs.
# The title previously read "populatation"; the spelling is corrected here.
derived_df_0[[f"total_populationXorgan_{organ}" for organ in organ_strata]].plot(
    title="Modelled population by organ status", kind="area"
)

In [None]:
# Modelled notifications against the 'notifs' target.
plot_model_vs_actual(
    derived_df_0,
    targets["notifs"],
    "notification",
    "Notification",
    "Modelled vs Data",
)

In [None]:
# Modelled latent percentage against the 'percentage_latent' target.
plot_model_vs_actual(
    derived_df_0,
    targets["percentage_latent"],
    "percentage_latent",
    "Percentage latent",
    "Modelled vs Data",
)

In [None]:
# Plot the `cdr` derived output (presumably the case detection rate -- confirm).
derived_df_0['cdr'].plot()

In [None]:
# Modelled pulmonary prevalence against the 'prevalence_pulmonary' target.
plot_model_vs_actual(
    derived_df_0,
    targets["prevalence_pulmonary"],
    "prevalence_pulmonary",
    "Infectious prevalence",
    "Modelled vs Estimation from 2017 prevalence survey",
)


In [None]:
# Directory that receives every file written by the calibration cells below.
# `Path` comes from `pathlib` (imported in the notebook's import cell).
OUT_PATH = Path.cwd() / 'runs'
# Create it up front so the later to_netcdf / to_hdf writes do not fail on a
# fresh checkout where `runs/` does not exist yet.
OUT_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Main calibration run.
n_chains = 8  # number of MCMC chains
n_samples = 100  # number of posterior samples extracted afterwards
with pm.Model() as pm_model:
    variables = epm.use_model(bcm)
    idata_raw = pm.sample(
        step=[pm.DEMetropolisZ(variables)],
        draws=1000,
        tune=200,
        # Run chains in parallel without requesting more worker processes than
        # there are chains or CPUs (this was a hard-coded 8).
        cores=min(n_chains, cpu_count()),
        discard_tuned_samples=False,
        chains=n_chains,
        progressbar=True,
        # A single parameter dict rather than one per chain.
        initvals=mle_params,
    )

In [None]:
# Persist the full, unthinned sampling output as NetCDF.
idata_raw.to_netcdf(str(OUT_PATH / 'calib_full_out.nc'))

In [None]:
# Discard draws before 200 in every chain; rebinds `burnt_idata` from the exploratory run.
# NOTE(review): `pm.sample` above was called with tune=200 and
# discard_tuned_samples=False -- confirm whether the tuning draws are stored
# separately, in which case this removes 200 additional post-tuning draws.
burnt_idata = idata_raw.sel(draw=np.s_[200:])

In [None]:
# Extract `n_samples` samples from the burnt-in run for saving and for result generation.
idata_extract = az.extract(burnt_idata, num_samples=n_samples)

In [None]:
# Persist the extracted samples to HDF5.
# A 'list_of_dicts' conversion yields a plain Python list (that is how
# `init_samps[0]` is obtained earlier), and a list has no `to_hdf5`. Convert to
# the sampler's default container instead, which is the object the writer is
# expected to live on.
bcm.sample.convert(idata_extract).to_hdf5(OUT_PATH / 'calib_extract_out.h5')

In [None]:
# ArviZ posterior plots for the main run (uses the unburnt `idata_raw`).
az.plot_posterior(idata_raw)

In [None]:
# Model results for each extracted sample; rebinds `spaghetti_res` from the exploratory run.
spaghetti_res = esamp.model_results_for_samples(idata_extract, bcm)

In [None]:
# Display the results object.
spaghetti_res

In [None]:
# Store the per-sample model outputs under the 'spaghetti' key.
# `key` is passed by keyword: pandas 2.x deprecates positional arguments to
# `to_hdf` other than the path.
spaghetti_res.results.to_hdf(str(OUT_PATH / 'results.hdf'), key='spaghetti')

In [None]:
# Likelihood-related quantities for every draw of the main run, stored in the
# same HDF file under the 'likelihood' key.
# `key` is passed by keyword: pandas 2.x deprecates positional arguments to
# `to_hdf` other than the path.
like_df = esamp.likelihood_extras_for_idata(idata_raw, bcm)
like_df.to_hdf(str(OUT_PATH / 'results.hdf'), key='likelihood')

In [None]:
# Quantile levels for the summary tables below.
# NOTE(review): this rebinds `quantiles`, which an earlier cell bound to the
# output of `esamp.quantiles_for_results`; re-running that earlier
# `quantiles["notification"]` plotting cell after this one would fail.
quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]

In [None]:
# Reload the stored per-sample outputs from disk.
spaghettis = pd.read_hdf(OUT_PATH / 'results.hdf', key='spaghetti')

In [None]:
# Summarise the reloaded outputs at the quantile levels defined above.
quantile_outputs = esamp.quantiles_for_results(spaghettis, quantiles)

In [None]:
# Display the incidence quantiles.
quantile_outputs['incidence']

In [None]:
# Display the latent-percentage quantiles.
quantile_outputs['percentage_latent']

In [None]:
# Show the column labels of the notification quantile table.
quantile_outputs['notification'].columns