In [32]:
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, DerivedOutput, Function
from summer2.functions.time import get_sigmoidal_interpolation_function, get_linear_interpolation_function

import yaml
from typing import List, Dict
from pathlib import Path   

In [33]:
tv_data_path = Path.cwd() / 'data' / 'time_variant_params.yml'
# Time-variant inputs (get_tb_model reads 'crude_birth_rate' and 'life_expectancy' from here).
# NOTE: resolved relative to the kernel's working directory, not the notebook file.
# Decode explicitly as UTF-8 so loading does not depend on the platform's locale encoding.
with open(tv_data_path, 'r', encoding='utf-8') as file:
    tv_data = yaml.safe_load(file)

In [None]:
def get_tb_model(config: dict, time_variant_data: dict = None):
    """Build the TB compartmental model used throughout this notebook.

    Args:
        config: Model settings. Reads "start_time" and "end_time" (simulation
            window), "population" (total initial population) and "seed"
            (initial number of infectious people; everyone else starts
            susceptible).
        time_variant_data: Parsed time-variant inputs with "crude_birth_rate"
            and "life_expectancy" entries, each holding "times" and "values"
            lists. Defaults to the notebook-level ``tv_data`` loaded from
            data/time_variant_params.yml.

    Returns:
        A summer2 CompartmentalModel (not yet run) with its flows and derived
        outputs configured.
    """
    tv = tv_data if time_variant_data is None else time_variant_data

    # ---- Time-variant parameters ---------------------------------------------
    # Birth-rate values are divided by 1000 (presumably stored per 1,000 population).
    crude_birth_rate_func = get_linear_interpolation_function(
        x_pts=tv['crude_birth_rate']['times'],
        y_pts=[cbr / 1000. for cbr in tv['crude_birth_rate']['values']],
    )
    life_expectancy_func = get_linear_interpolation_function(
        x_pts=tv['life_expectancy']['times'], y_pts=tv['life_expectancy']['values']
    )
    # Passive case detection rises along a sigmoid from 0 in 1950 to the
    # 'current_passive_detection_rate' parameter in 2024.
    passive_detection_func = get_sigmoidal_interpolation_function(
        x_pts=[1950., 2024.], y_pts=[0., Parameter('current_passive_detection_rate')], curvature=16
    )

    # TODO: treatment parameters not yet wired into the model:
    #   'tx_duration': 0.5,
    #   'tx_prop_success': .9,
    #   'tx_prop_death_among_failure': .2

    # ---- Model structure -----------------------------------------------------
    compartments = (
        "susceptible",
        "latent_early",
        "latent_late",
        "infectious",
        "treatment",
        "recovered",
    )
    model = CompartmentalModel(
        times=(config["start_time"], config["end_time"]),
        compartments=compartments,
        infectious_compartments=("infectious",),
    )
    model.set_initial_population(
        distribution={
            "susceptible": config["population"] - config["seed"],
            "infectious": config["seed"],
        },
    )

    # Births enter the susceptible compartment; all-cause mortality is the
    # reciprocal of life expectancy.
    model.add_crude_birth_flow(
        name="birth",
        birth_rate=crude_birth_rate_func,
        dest="susceptible",
    )
    model.add_universal_death_flows(
        name="all_cause_mortality",
        death_rate=1. / life_expectancy_func,
    )

    # Infection of susceptibles and reinfection of late-latent / recovered people,
    # the latter scaled by relative-risk parameters.
    model.add_infection_frequency_flow(
        name="infection",
        contact_rate=Parameter("transmission_rate"),
        source="susceptible",
        dest="latent_early",
    )
    for reinfection_source in ["latent_late", "recovered"]:
        model.add_infection_frequency_flow(
            name=f"reinfection_{reinfection_source}",
            contact_rate=Parameter("transmission_rate") * Parameter(f"rr_reinfection_{reinfection_source}"),
            source=reinfection_source,
            dest="latent_early",
        )

    # Latency: early latents either stabilise to late latency or progress;
    # late latents can also progress to active disease.
    model.add_transition_flow(
        name="stabilisation",
        fractional_rate=Parameter("stabilisation_rate"),
        source="latent_early",
        dest="latent_late",
    )
    for progression_type in ["early", "late"]:
        model.add_transition_flow(
            name=f"progression_{progression_type}",
            fractional_rate=Parameter(f"activation_rate_{progression_type}"),
            source=f"latent_{progression_type}",
            dest="infectious",
        )

    # Natural recovery without treatment.
    model.add_transition_flow(
        name="self_recovery",
        fractional_rate=Parameter("self_recovery_rate"),
        source="infectious",
        dest="recovered",
    )

    # TB-specific death among active (untreated) cases.
    model.add_death_flow(
        name="active_tb_death",
        death_rate=Parameter("tb_death_rate"),
        source="infectious",
    )

    # Detection of active TB moves people into treatment.
    model.add_transition_flow(
        name="tb_detection",
        fractional_rate=passive_detection_func,
        source="infectious",
        dest="treatment",
    )

    # Treatment exit flows (disabled). NOTE: while these stay commented out,
    # people leave 'treatment' only through all-cause mortality.
    # model.add_transition_flow(
    #     name="tx_recovery",
    #     fractional_rate=Parameter("tx_success_rate"),
    #     source="treatment",
    #     dest="recovered",
    # )
    # model.add_transition_flow(
    #     name="tx_relapse",
    #     fractional_rate=Parameter("tx_relapse_rate"),
    #     source="treatment",
    #     dest="infectious",
    # )
    # model.add_death_flow(
    #     name="tx_death",
    #     death_rate=Parameter("tx_death_rate"),
    #     source="treatment",
    # )

    # ---- Derived outputs -----------------------------------------------------
    # Raw outputs
    model.request_output_for_compartments(name="raw_ltbi_prevalence", compartments=["latent_early", "latent_late"])
    model.request_output_for_compartments(name="raw_tb_prevalence", compartments=["infectious"])
    # Same quantity under the name the calibration cells use as their target.
    model.request_output_for_compartments(name="active_cases", compartments=["infectious"])

    model.request_output_for_flow(name="progression_early", flow_name="progression_early")
    model.request_output_for_flow(name="progression_late", flow_name="progression_late")
    model.request_aggregate_output(name="raw_tb_incidence", sources=["progression_early", "progression_late"])

    model.request_output_for_flow(name="raw_notifications", flow_name="tb_detection")

    model.request_output_for_flow(name="active_tb_death", flow_name="active_tb_death")
    # model.request_output_for_flow(name="tx_death", flow_name="tx_death")
    # model.request_aggregate_output(name="all_tb_deaths", sources=["active_tb_death", "tx_death"])

    # Outputs relative to population size
    model.request_output_for_compartments(name="population", compartments=compartments)
    model.request_function_output(name="ltbi_prop", func=DerivedOutput("raw_ltbi_prevalence") / DerivedOutput("population"))
    model.request_function_output(name="tb_prevalence_per100k", func=1.e5 * DerivedOutput("raw_tb_prevalence") / DerivedOutput("population"))
    model.request_function_output(name="tb_incidence_per100k", func=1.e5 * DerivedOutput("raw_tb_incidence") / DerivedOutput("population"))

    return model


# Simulation window (start/end time) and initial conditions for the TB model.
config = dict(
    start_time=1850,
    end_time=2050,
    population=1.e6,
    seed=100,  # number of people initially placed in 'infectious'
)
model = get_tb_model(config)

In [None]:
# Placeholder parameter set: every model rate is 1., just to check the model runs.
unit_rate_names = [
    'activation_rate_early',
    'activation_rate_late',
    'current_passive_detection_rate',
    'rr_reinfection_latent_late',
    'rr_reinfection_recovered',
    'self_recovery_rate',
    'stabilisation_rate',
    'tb_death_rate',
    'transmission_rate',
]
test_params = {param_name: 1. for param_name in unit_rate_names}
# Treatment inputs; get_tb_model does not read these yet (treatment flows are disabled).
test_params['tx_duration'] = 0.5
test_params['tx_prop_death'] = .04  # WHO
model.run(test_params)

In [None]:
# Total population over the simulation, as a sanity check of births vs. deaths.
population_ax = model.get_derived_outputs_df()['population'].plot(title="Total population")
population_ax.set_xlabel("Time (years)")
population_ax.set_ylabel("Population");

In [None]:
import numpy as np
import pandas as pd  # imported here too: the calibration imports cell runs later

# Synthetic calibration target: the model's "active_cases" output plus uniform noise.
# NOTE: "active_cases" must be a derived output requested by get_tb_model.
rng = np.random.default_rng(42)  # fixed seed so the synthetic target is reproducible
output = model.get_derived_outputs_df()['active_cases']
# Build the noise on the output's own (time) index. A default RangeIndex would not
# align with the model times, and the sum would be all-NaN.
noise = pd.Series(rng.uniform(-10000, 10000, len(output)), index=output.index)
data = (output + noise).iloc[50:150].rename("active_cases")
data.plot()

# Calibration 

In [None]:
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
import estival.priors as esp
import estival.targets as est
from estival.sampling import tools as esamptools

import pymc as pm
import pandas as pd

import arviz as az #For convergence stats and diagnostics
import time
import matplotlib .pyplot as plt

In [None]:
# Calibrated parameters. Prior names must match the Parameter(...) names used in
# get_tb_model, otherwise the sampled values never reach the model.
# TODO: bounds have not been tuned for the TB model — confirm before relying on them.
all_priors = [
    esp.UniformPrior("transmission_rate", [0.1, 0.3]),
    esp.UniformPrior("self_recovery_rate", [0.05, 0.2]),
]

targets = [
    # Normal likelihood on "active_cases"; its standard deviation ("dispersion")
    # is itself given a uniform prior.
    est.NormalTarget("active_cases", data, stdev=esp.UniformPrior("dispersion", [10., 10000.])),
]


def get_bcm_object(model, all_mle_params, fixed_param_indices=(), base_params=None):
    """Wrap ``model`` in an estival BayesianCompartmentalModel.

    Priors whose position in the module-level ``all_priors`` appears in
    ``fixed_param_indices`` are not calibrated; their values are fixed at
    ``all_mle_params[prior.name]`` instead.

    Args:
        model: summer2 CompartmentalModel to calibrate.
        all_mle_params: Mapping from prior name to point estimate (e.g. the
            optimiser output). Only entries for fixed priors are read.
        fixed_param_indices: Indices into ``all_priors`` to hold fixed.
        base_params: Optional default values for parameters that are not
            calibrated. Fixed MLE values override entries with the same name.

    Returns:
        BayesianCompartmentalModel built from the fixed parameters, the
        remaining priors and the module-level ``targets``.
    """
    fixed = set(fixed_param_indices)
    priors = [p for i, p in enumerate(all_priors) if i not in fixed]

    params = dict(base_params or {})
    params.update({p.name: all_mle_params[p.name] for i, p in enumerate(all_priors) if i in fixed})

    return BayesianCompartmentalModel(model, params, priors, targets)
    

### Optimisation

In [None]:
# Point estimate of the calibrated parameters via nevergrad, through estival's wrapper.
import nevergrad as ng
from estival.wrappers.nevergrad import optimize_model

OPT_BUDGET = 1000  # budget passed to minimize()

bcm = get_bcm_object(model, all_mle_params={}, fixed_param_indices=[])
orunner = optimize_model(bcm)
rec = orunner.minimize(OPT_BUDGET)
# presumably rec.value is (args, kwargs); element 1 holds the parameter dict
all_mle_params = rec.value[1]

In [None]:
# Run at the point estimate and compare the calibrated output with the target.
# Calibrated values override the placeholder test_params for everything else.
model.run({**test_params, **all_mle_params})
mle_outputs = model.get_derived_outputs_df()

fig, ax = plt.subplots()
mle_outputs['active_cases'].plot(ax=ax, label="model (MLE)")
data.plot(ax=ax, style='.', color='black', label="target")
ax.set_xlabel("Time (years)")
ax.set_ylabel("Active cases")
ax.legend();

### Sampling

In [None]:
def run_sampling(model, all_mle_params, fixed_param_indices=(), draws=10000, tune=1000, chains=4, cores=4):
    """Sample the posterior of the calibrated parameters with PyMC's DEMetropolis.

    Args:
        model: summer2 CompartmentalModel to calibrate.
        all_mle_params: Point estimates used for any priors held fixed.
        fixed_param_indices: Indices into ``all_priors`` to hold fixed.
        draws, tune, chains, cores: Forwarded to ``pm.sample``.

    Returns:
        (bcm, idata): the BayesianCompartmentalModel that was sampled and the
        InferenceData returned by ``pm.sample``.
    """
    bcm = get_bcm_object(model, all_mle_params, fixed_param_indices)
    # The PyMC model context is not bound to a name, so it no longer shadows
    # the summer2 ``model`` argument.
    with pm.Model():
        variables = epm.use_model(bcm)
        idata = pm.sample(
            step=[pm.DEMetropolis(variables)],
            draws=draws,
            tune=tune,
            cores=cores,
            chains=chains,
        )

    return bcm, idata

def check_sampling(bcm, idata, n_samples=1000, output_name="active_cases", target_data=None):
    """Plot trace diagnostics and posterior-predictive quantiles against the target.

    Args:
        bcm: BayesianCompartmentalModel that produced ``idata``.
        idata: InferenceData from sampling.
        n_samples: Number of posterior draws to re-run through the model.
        output_name: Derived output whose quantiles are plotted.
        target_data: Observations overlaid as black dots; defaults to the
            notebook-level ``data``.
    """
    observed = data if target_data is None else target_data

    az.plot_trace(idata)

    sample_idata = az.extract(idata, num_samples=n_samples)
    mres = esamptools.model_results_for_samples(sample_idata, bcm)
    quantiles = esamptools.quantiles_for_results(mres.results, [0.025, 0.25, 0.5, 0.75, 0.975])
    ax = quantiles[output_name].plot()
    observed.plot(ax=ax, style='.', color='black')

def get_output_values(bcm, idata, n_samples=1000, output_name="active_cases", time=100.):
    """Return a derived output's values at one time point across posterior draws.

    Args:
        bcm: BayesianCompartmentalModel that produced ``idata``.
        idata: InferenceData from sampling.
        n_samples: Number of posterior draws to re-run through the model.
        output_name: Derived output to extract.
        time: Model time at which to read the output.

    Returns:
        ``mres.results[output_name].loc[time]`` (presumably one value per draw).
    """
    sample_idata = az.extract(idata, num_samples=n_samples)
    mres = esamptools.model_results_for_samples(sample_idata, bcm)
    # NOTE(review): the default time=100. lies outside the 1850-2050 window set in
    # this notebook's config — TODO pass an in-range time explicitly.
    return mres.results[output_name].loc[time]



In [None]:
# Calibrate every prior (none held fixed) and inspect the posterior.
sampling_result = run_sampling(model, all_mle_params=all_mle_params, fixed_param_indices=[])
bcm, idata = sampling_result
check_sampling(bcm=bcm, idata=idata)

In [None]:
# Hold the first prior in all_priors at its point estimate and sample the rest.
sampling_result_2 = run_sampling(model, all_mle_params=all_mle_params, fixed_param_indices=[0])
bcm_2, idata_2 = sampling_result_2
check_sampling(bcm=bcm_2, idata=idata_2)