In [None]:
import arviz as az

import utils as ut
import model as md

from estival.sampling import tools as esamptools


In [None]:
# Global model settings; this mapping is handed to md.get_tb_model further down.
# NOTE(review): the *_time entries look like calendar years -- confirm in the model code.
config = dict(
    start_time=1850,
    end_time=2050,
    population=1e6,
    seed=100,
    intervention_time=2025,
)

# Settings for each available intervention, keyed by intervention name.
# The key order matters: the scenario loop later iterates over these keys in order.
# NOTE(review): "detection_rate_mutliplier" is misspelled, but it is a runtime key --
# presumably the model code looks it up under this exact spelling, so confirm there
# before renaming it.
intervention_params = dict(
    transmission_reduction=dict(rel_reduction=0.2),
    preventive_treatment=dict(rate=0.1, efficacy=0.8),
    faster_detection=dict(detection_rate_mutliplier=2.0),
    improved_treatment=dict(negative_outcomes_rel_reduction=0.5),
)

# Optimisation

### Find optimal parameter set, varying all parameters, using model with no intervention

In [None]:
# Build the model with an empty list of active interventions (the no-intervention case).
model = md.get_tb_model(
    config,
    intervention_params,
    active_interventions=[],
)

# Search for the optimal parameter set with ut.find_mle, using a budget of 300.
mle_params = ut.find_mle(
    model,
    opti_budget=300,
)

### Check optimal model fit

In [None]:
# Run the model with the defaults overridden by the optimised values
# (the right-hand operand of `|` wins on duplicate keys).
fitted_params = ut.default_params | mle_params
model.run(fitted_params)
do = model.get_derived_outputs_df()

# Plot the modelled output from index label 2010 onwards, then draw the target data
# as red dots with a second, independent plot call.
prevalence_from_2010 = do["tb_prevalence_per100k"].loc[2010:]
prevalence_from_2010.plot()
ut.target_data.plot(style=".", color="red")

# Main Analysis 

### Run Metropolis sampling 

In [None]:
# Sampling results keyed by the name of the parameter held fixed (None = nothing fixed).
idatas = {}

# Candidate cases: first the run with no fixed parameter, then one per prior in ut.all_priors.
fixed_param_options = [None] + [prior.name for prior in ut.all_priors]

for fixed_param in fixed_param_options:
    print(f"Running Metropolis sampling fixing {fixed_param}")
    idatas[fixed_param] = ut.run_sampling(
        model,
        mle_params,
        fixed_param,
        draws=1000,
        tune=500,
        cores=4,
        chains=4,
    )

    # NOTE(review): this unconditional exit stops the loop after its first pass, so only
    # the `None` case is ever sampled. Remove it to run every case listed above.
    break


### Perform full scenario runs for the sampled parameters

In [None]:
# Post-processing settings. They are the same for every sampling result, so they are
# defined once instead of on each loop iteration.
burn_in, n_samples = 100, 1000

# Scenarios to simulate: the no-intervention case (None) followed by each intervention alone.
scenario_names = [None] + list(intervention_params.keys())

# full_runs[fixed_param][intervention] -> output of esamptools.model_results_for_samples
full_runs = {}
for fixed_param, idata in idatas.items():
    # Drop the first `burn_in` draws, then take `n_samples` samples from what remains.
    chain_length = idata.sample_stats.sizes['draw']
    burnt_idata = idata.sel(draw=range(burn_in, chain_length))  # Discard burn-in
    full_run_param_samples = az.extract(burnt_idata, num_samples=n_samples)

    full_runs[fixed_param] = {}
    for intervention in scenario_names:
        active_interventions = [intervention] if intervention else []
        # Use a cell-local name here. Rebinding the notebook-level `model` would leave it
        # pointing at the last intervention scenario, so re-running the earlier fit-check
        # or sampling cells would silently use the wrong model.
        scenario_model = md.get_tb_model(config, intervention_params, active_interventions)
        bcm = ut.get_bcm_object(scenario_model, ut.default_params | mle_params, fixed_param)
        full_runs[fixed_param][intervention] = esamptools.model_results_for_samples(full_run_param_samples, bcm)


In [None]:
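# TODO(review): unfinished step -- `calculate_diff_output_quantiles`, `ref_full_runs` and `res`
# are not defined in the cells above, so this line cannot simply be uncommented.
# quants = calculate_diff_output_quantiles(ref_full_runs, res)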