In [None]:
import arviz as az
from pathlib import Path
from datetime import datetime
import yaml

from estival.sampling import tools as esamptools

import utils as ut
import model as md

# Root directory for everything this notebook writes, resolved against the
# directory the kernel was started in.
output_folder = Path.cwd().joinpath("outputs")

In [None]:
# Computational budgets for the analysis. Both presets expose the same keys:
#   opti_budget        -> passed to ut.find_mle as `opti_budget`
#   mcmc_chains/cores  -> passed to ut.run_sampling as `chains` / `cores`
#   mcmc_tune/samples  -> passed to ut.run_sampling as `tune` / `draws`
#   full_runs_burnin   -> number of leading draws discarded before the full runs
#   full_runs_samples  -> number of posterior samples extracted for the full runs

# Small budget for a quick end-to-end check of the notebook.
TEST_ANALYSIS_CONFIG = dict(
    opti_budget=100,
    mcmc_chains=4,
    mcmc_cores=4,
    mcmc_tune=100,
    mcmc_samples=100,
    full_runs_burnin=50,
    full_runs_samples=100,
)

# Larger budget for the full analysis.
FULL_ANALYSIS_CONFIG = dict(
    opti_budget=5000,
    mcmc_chains=4,
    mcmc_cores=4,
    mcmc_tune=1000,
    mcmc_samples=5000,
    full_runs_burnin=2000,
    full_runs_samples=1000,
)

# Pick the budget preset for this run and the label used in the output folder
# name. NOTE(review): the two are set independently - keep the label in step
# with the preset chosen here.
analysis_name = 'test'
analysis_config = TEST_ANALYSIS_CONFIG

# Every execution of this cell gets its own timestamped folder under
# `output_folder`, so earlier results are not overwritten.
run_timestamp = datetime.now().strftime('%Y_%m_%d@%H_%M_%S')
folder_path = output_folder / f"{run_timestamp}_{analysis_name}"
folder_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Settings handed to md.get_tb_model as its first argument.
# NOTE(review): the time values look like calendar years - confirm against
# the model module.
config = dict(
    start_time=1850,
    end_time=2050,
    population=1_000_000.0,
    seed=100,
    intervention_time=2025,
)

# Parameters of each intervention, keyed by intervention name. The key order
# matters: the scenario loop iterates over these names in this order.
# NOTE(review): "detection_rate_mutliplier" is spelled exactly as in the
# original; it is a runtime key, so only rename it together with whatever
# code reads it.
intervention_params = dict(
    transmission_reduction=dict(rel_reduction=0.2),
    preventive_treatment=dict(rate=0.1, efficacy=0.8),
    faster_detection=dict(detection_rate_mutliplier=2.0),
    improved_treatment=dict(negative_outcomes_rel_reduction=0.5),
)

# Optimisation

### Find the optimal parameter set, varying all parameters, using the model with no intervention

In [None]:
# Build the model with no active interventions and search for the
# best-fitting parameter set within the configured optimisation budget.
model = md.get_tb_model(config, intervention_params, active_interventions=[])
mle_params = ut.find_mle(model, opti_budget=analysis_config['opti_budget'])

# Persist the optimum next to the other outputs of this run.
# NOTE(review): yaml.dump emits Python-specific tags for non-builtin scalar
# types - confirm ut.find_mle returns plain Python numbers.
mle_path = folder_path / "mle_params.yml"
with mle_path.open('w') as mle_file:
    yaml.dump(mle_params, mle_file, default_flow_style=False)

### Check optimal model fit

In [None]:
# Run the model with the default parameters overridden by the fitted values.
fitted_params = ut.default_params | mle_params
model.run(fitted_params)
do = model.get_derived_outputs_df()

# Plot modelled prevalence from index label 2010 onwards, then ut.target_data
# as red dots (presumably the calibration targets - confirm in utils).
do.loc[2010:, 'tb_prevalence_per100k'].plot()
ut.target_data.plot(style='.', color='red')

# Main Analysis 

### Run Metropolis sampling 

In [None]:
# One sampling run per entry: first with fixed_param=None, then once for each
# prior name in ut.all_priors. Every result is written to
# <folder_path>/idatas/idata_<fixed_param>.nc and kept in `idatas`.
idata_dir = folder_path / "idatas"
idata_dir.mkdir(parents=True, exist_ok=True)

# Sampler settings shared by every run.
sampling_kwargs = dict(
    draws=analysis_config['mcmc_samples'],
    tune=analysis_config['mcmc_tune'],
    cores=analysis_config['mcmc_cores'],
    chains=analysis_config['mcmc_chains'],
)

idatas = {}
for fixed_param in [None, *(p.name for p in ut.all_priors)]:
    print(f"Running Metropolis sampling fixing {fixed_param}")
    idatas[fixed_param] = ut.run_sampling(model, mle_params, fixed_param, **sampling_kwargs)
    idatas[fixed_param].to_netcdf(idata_dir / f"idata_{fixed_param}.nc")

### Perform full scenario runs for the sampled parameters

In [None]:
# For every sampling result in `idatas`: discard the burn-in draws, extract a
# subsample of the remaining draws, and run it through the model with no
# intervention and then with each intervention activated on its own.
# Outputs:
#   full_runs[fixed_param][intervention]       -> result of model_results_for_samples
#   diff_output_dfs[fixed_param][intervention] -> result of ut.calculate_diff_output_quantiles
#                                                 (no-intervention run vs. intervention run)
# Results are also written under <folder_path>/full_runs (parquet) and
# <folder_path>/diff_output_dfs (csv).
full_runs, diff_output_dfs = {}, {}
full_runs_dir = folder_path / "full_runs"
diff_outputs_dir = folder_path / "diff_output_dfs"
full_runs_dir.mkdir(parents=True, exist_ok=True)
diff_outputs_dir.mkdir(parents=True, exist_ok=True)

burnin = analysis_config['full_runs_burnin']
n_full_run_samples = analysis_config['full_runs_samples']

for fixed_param, idata in idatas.items():
    print(f"Running full runs fixing {fixed_param}")
    chain_length = idata.sample_stats.sizes['draw']
    if burnin >= chain_length:
        # Without this check the draw selection below would be empty.
        raise ValueError(
            f"full_runs_burnin ({burnin}) must be smaller than the number of "
            f"draws per chain ({chain_length})"
        )
    burnt_idata = idata.sel(draw=range(burnin, chain_length))  # Discard burn-in
    full_run_param_samples = az.extract(burnt_idata, num_samples=n_full_run_samples)

    full_runs[fixed_param] = {}
    diff_output_dfs[fixed_param] = {}
    # None must come first: the intervention runs are compared against it below.
    for intervention in [None] + list(intervention_params.keys()):
        active_interventions = [intervention] if intervention else []
        # Deliberately NOT bound to `model`: earlier cells use the global `model`
        # (built with no interventions) for the fit check and for sampling.
        # Rebinding it here would make those cells use an intervention model
        # when re-executed.
        scenario_model = md.get_tb_model(config, intervention_params, active_interventions)
        bcm = ut.get_bcm_object(scenario_model, ut.default_params | mle_params, fixed_param)
        full_runs[fixed_param][intervention] = esamptools.model_results_for_samples(full_run_param_samples, bcm)
        full_runs[fixed_param][intervention].results.to_parquet(full_runs_dir / f"fullruns_{fixed_param}_{intervention}.parquet")

        if intervention:
            diff_output_dfs[fixed_param][intervention] = ut.calculate_diff_output_quantiles(full_runs[fixed_param][None], full_runs[fixed_param][intervention])
            diff_output_dfs[fixed_param][intervention].to_csv(diff_outputs_dir / f"diff_outputs_{fixed_param}_{intervention}.csv")
