In [None]:
from model import get_tb_model

In [None]:
# Settings handed to get_tb_model (keys as it expects them — presumably the
# simulation window in years, population size, and an RNG seed).
config = dict(
    start_time=1850,
    end_time=2050,
    population=1e6,
    seed=100,
)

model = get_tb_model(config)

In [None]:
test_params = dict(
    # Parameters to be calibrated (each has a prior in all_priors); the values
    # here are placeholders for a test run.
    transmission_rate=1.0,

    activation_rate_early=1.0,
    activation_rate_late=1.0,
    stabilisation_rate=1.0,

    rr_reinfection_latent_late=1.0,
    rr_reinfection_recovered=1.0,

    self_recovery_rate=1.0,
    tb_death_rate=1.0,

    current_passive_detection_rate=1.0,

    # Parameters held fixed (no prior): these values are always used.
    tx_duration=0.5,
    tx_prop_death=0.04,  # WHO figure, among new treatment cases
)
model.run(test_params)

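# Calibration

Calibrate the free parameters against a single target, a TB prevalence of 1000 per 100,000 in 2024. First find best-fitting values with nevergrad, then sample the posterior with PyMC.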

In [None]:
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
import estival.priors as esp
import estival.targets as est
from estival.sampling import tools as esamptools

import pymc as pm
import pandas as pd

import arviz as az #For convergence stats and diagnostics
import time
import matplotlib .pyplot as plt

In [None]:
# Uniform priors for the calibrated parameters. Order matters: get_bcm_object
# refers to priors by their position in this list (index 0 = transmission_rate).
all_priors = [
    esp.UniformPrior(param_name, bounds)
    for param_name, bounds in {
        "transmission_rate": [0.1, 10.0],
        "activation_rate_early": [0.0, 1.0],
        "activation_rate_late": [0.0, 1.0],
        "stabilisation_rate": [0.0, 1.0],
        "rr_reinfection_latent_late": [0.2, 0.5],
        "rr_reinfection_recovered": [0.5, 1.0],
        "self_recovery_rate": [0.0, 0.5],
        "tb_death_rate": [0.0, 0.5],
        "current_passive_detection_rate": [0.2, 2.0],
    }.items()
]

# Single calibration target: TB prevalence of 1000 per 100k in 2024.
target_data = pd.Series([1000.0], index=[2024])

# Normal likelihood around the target with a standard deviation of 100 per 100k.
targets = [est.NormalTarget("tb_prevalence_per100k", target_data, stdev=100.0)]


def get_bcm_object(model, all_mle_params, fixed_param_indices=None):
    """Build an estival BayesianCompartmentalModel, optionally fixing some priors.

    Args:
        model: The compartmental model returned by ``get_tb_model``.
        all_mle_params: Mapping from parameter name to optimised value. Only the
            entries for fixed parameters are read, so it may be empty when
            nothing is fixed.
        fixed_param_indices: Positions in ``all_priors`` of parameters to hold
            at their value from ``all_mle_params`` instead of calibrating them.
            ``None`` (default) or an empty sequence keeps every prior free.

    Returns:
        A BayesianCompartmentalModel whose base parameters are ``test_params``
        overridden by the fixed MLE values, with the remaining priors and the
        module-level ``targets``.

    Raises:
        IndexError: If an index in ``fixed_param_indices`` is not a valid
            position in ``all_priors`` (it would otherwise fix nothing).
        KeyError: If a fixed parameter is missing from ``all_mle_params``.
    """
    # None instead of a shared mutable default; a set gives O(1) membership tests.
    fixed = set(fixed_param_indices or ())
    out_of_range = sorted(i for i in fixed if not 0 <= i < len(all_priors))
    if out_of_range:
        raise IndexError(f"fixed_param_indices out of range for all_priors: {out_of_range}")

    priors = [prior for i, prior in enumerate(all_priors) if i not in fixed]
    # Fixed parameters are dropped from the priors, so pin them at their MLE.
    mle_params = {
        prior.name: all_mle_params[prior.name]
        for i, prior in enumerate(all_priors)
        if i in fixed
    }

    return BayesianCompartmentalModel(model, test_params | mle_params, priors, targets)
    

### Optimisation

In [None]:
# Import nevergrad
import nevergrad as ng

# Import our convenience wrapper
from estival.wrappers.nevergrad import optimize_model
# Argument to orunner.minimize — presumably the number of objective evaluations.
OPTIMISATION_BUDGET = 1000

# Optimise every prior-governed parameter (none fixed, so no MLEs are needed yet).
bcm = get_bcm_object(model, {}, [])
orunner = optimize_model(bcm)
# NOTE(review): no seed is set for the optimiser here, so the recommended
# parameters may differ between runs — confirm whether seeding is needed.
rec = orunner.minimize(OPTIMISATION_BUDGET)
# Presumably nevergrad's (args, kwargs) value; index 1 is the dict of named
# parameter values used downstream.
all_mle_params = rec.value[1]

In [None]:
# Re-run the model at the optimised parameters (overriding test_params) and
# collect its derived outputs.
model.run({**test_params, **all_mle_params})
do = model.get_derived_outputs_df()

In [None]:
# Compare the optimised prevalence trajectory with the calibration target,
# drawing both on one explicit Axes rather than pyplot's implicit current axes.
fig, ax = plt.subplots()
do['tb_prevalence_per100k'].plot(ax=ax, label='model (optimised)')
target_data.plot(ax=ax, style='.', color='red', label='target')
ax.set(xlabel='Year', ylabel='TB prevalence per 100k', title='Optimised fit to prevalence target')
ax.legend();

### Sampling

In [None]:
def run_sampling(model, all_mle_params, fixed_param_indices=None, draws=10000, tune=1000,
                 chains=4, cores=4, random_seed=None):
    """Sample the calibration posterior with PyMC's DEMetropolisZ sampler.

    Args:
        model: The compartmental model returned by ``get_tb_model``.
        all_mle_params: Optimised parameter values; entries for fixed
            parameters are used as their fixed values.
        fixed_param_indices: Positions in ``all_priors`` to hold at their MLE.
            ``None`` (default) keeps every prior free.
        draws, tune, chains, cores: Passed through to ``pm.sample``.
        random_seed: Passed to ``pm.sample`` so chains can be reproduced;
            ``None`` (default) leaves sampling unseeded.

    Returns:
        Tuple ``(bcm, idata)``: the BayesianCompartmentalModel used and the
        result of ``pm.sample``.
    """
    # Fresh list per call instead of a shared mutable default argument.
    fixed = [] if fixed_param_indices is None else list(fixed_param_indices)
    bcm = get_bcm_object(model, all_mle_params, fixed)

    # The PyMC context is left unnamed so it does not shadow the `model` argument.
    with pm.Model():
        variables = epm.use_model(bcm)
        idata = pm.sample(
            step=[pm.DEMetropolisZ(variables)],
            draws=draws,
            tune=tune,
            cores=cores,
            chains=chains,
            random_seed=random_seed,
        )

    return bcm, idata

def check_sampling(bcm, idata, n_samples=1000, start_year=2010,
                   quantiles=(0.025, 0.25, 0.5, 0.75, 0.975)):
    """Plot sampler traces and the posterior fit to the prevalence target.

    Draws the ArviZ trace plot. It then runs the model for ``n_samples``
    posterior draws and plots the requested quantiles of
    ``tb_prevalence_per100k`` from ``start_year`` onward, overlaid with
    ``target_data`` in red.

    Args:
        bcm: The BayesianCompartmentalModel used for sampling.
        idata: Sampling output from ``run_sampling``.
        n_samples: Number of posterior draws to run through the model.
        start_year: First time point shown in the fit plot.
        quantiles: Quantiles of the model output to plot.
    """
    az.plot_trace(idata)

    sample_idata = az.extract(idata, num_samples=n_samples)
    mres = esamptools.model_results_for_samples(sample_idata, bcm)
    prevalence_quantiles = esamptools.quantiles_for_results(
        mres.results, list(quantiles)
    )["tb_prevalence_per100k"]

    # Use a dedicated Axes: az.plot_trace leaves its own figure current, and
    # pandas Series plots otherwise reuse whatever axes pyplot considers current.
    _, ax = plt.subplots()
    prevalence_quantiles.loc[start_year:].plot(ax=ax)
    target_data.loc[start_year:].plot(ax=ax, style='.', color='red')
    ax.set(xlabel='Year', ylabel='tb_prevalence_per100k', title='Posterior fit to prevalence target')

def get_output_values(bcm, idata, n_samples=1000, output_name='active_cases', at_time=100.):
    """Return posterior samples of one model output at a single time point.

    Args:
        bcm: The BayesianCompartmentalModel used for sampling.
        idata: Sampling output from ``run_sampling``.
        n_samples: Number of posterior draws to run through the model.
        output_name: Column of the model results to extract.
        at_time: Time-index label to select with ``.loc``.

    Returns:
        ``mres.results[output_name].loc[at_time]`` — presumably one value per
        sampled parameter set (TODO confirm against estival's result layout).
    """
    sample_idata = az.extract(idata, num_samples=n_samples)
    mres = esamptools.model_results_for_samples(sample_idata, bcm)

    # NOTE(review): the default at_time=100. lies outside the simulated window in
    # `config` (start_time=1850, end_time=2050), so the default lookup will
    # presumably raise KeyError; pass a year such as 2024 explicitly. Also confirm
    # that 'active_cases' is an output the TB model actually produces.
    return mres.results[output_name].loc[at_time]



In [None]:
# Baseline: sample every prior (no parameters fixed), then inspect traces and fit.
bcm, idata = run_sampling(model=model, all_mle_params=all_mle_params, fixed_param_indices=[])
check_sampling(bcm=bcm, idata=idata)

In [None]:
# Number of optimised parameters returned by nevergrad (presumably one per entry
# of all_priors, since nothing was fixed during optimisation).
len(all_mle_params)

In [None]:
# Hold all_priors[0] (transmission_rate) at its optimised value and sample the rest.
bcm_2, idata_2 = run_sampling(model=model, all_mle_params=all_mle_params, fixed_param_indices=[0])
check_sampling(bcm=bcm_2, idata=idata_2)