In [None]:
from summer2 import CompartmentalModel
from summer2.parameters import Parameter
from typing import List, Dict

In [None]:
def get_sir_model(config: dict):
    """Build a three-compartment SIR model from a small configuration mapping.

    Args:
        config: must provide ``"end_time"`` (end of the simulated time span, which
            starts at 0.0), ``"population"`` (total starting population) and
            ``"seed"`` (number of people placed in the infectious compartment at start).

    Returns:
        A summer2 ``CompartmentalModel``. ``contact_rate`` and ``recovery_rate`` are left
        as named ``Parameter`` objects to be supplied when the model is run. An
        ``"active_cases"`` derived output tracks the infectious compartment.
    """
    sir_model = CompartmentalModel(
        times=(0.0, config["end_time"]),
        compartments=("susceptible", "infectious", "recovered"),
        infectious_compartments=("infectious",),
    )

    # Everyone not seeded as infectious starts susceptible; "recovered" is not listed.
    # NOTE(review): presumably summer2 then starts "recovered" at zero — confirm.
    initial_distribution = {
        "susceptible": config["population"] - config["seed"],
        "infectious": config["seed"],
    }
    sir_model.set_initial_population(distribution=initial_distribution)

    # susceptible -> infectious, scaled by the run-time "contact_rate" parameter
    sir_model.add_infection_frequency_flow(
        name="infection",
        contact_rate=Parameter("contact_rate"),
        source="susceptible",
        dest="infectious",
    )
    # infectious -> recovered at the run-time "recovery_rate" fractional rate
    sir_model.add_transition_flow(
        name="recovery",
        fractional_rate=Parameter("recovery_rate"),
        source="infectious",
        dest="recovered",
    )

    sir_model.request_output_for_compartments(name="active_cases", compartments=["infectious"])
    return sir_model

# Settings consumed by get_sir_model: simulate from time 0 to 200 over a population
# of one million, 100 of whom start infectious.
config = dict(
    end_time=200.0,
    population=1e6,
    seed=100,
)
model = get_sir_model(config)

In [None]:
# "True" parameter values; the output of this run becomes the synthetic observations below
test_params = dict(contact_rate=0.2, recovery_rate=0.1)
model.run(test_params)

In [None]:
import numpy as np

# Synthetic "observed" data: the model's active_cases trajectory plus uniform noise,
# keeping rows 50-149 only. The fixed seed makes the noise (and everything calibrated
# against it) reproducible on Restart & Run All.
NOISE_SEED = 42
NOISE_HALF_WIDTH = 10000
noise_rng = np.random.default_rng(NOISE_SEED)

output = model.get_derived_outputs_df()["active_cases"]
# Add a plain ndarray rather than a Series with a fresh RangeIndex. The noise is then
# applied positionally and cannot misalign with the model's (float) time index. This also
# avoids needing pandas here, since it is only imported in a later cell.
noise = noise_rng.uniform(-NOISE_HALF_WIDTH, NOISE_HALF_WIDTH, len(output))
# NOTE(review): positional rows 50-149 presumably equal times 50-149 with a unit time step.
data = (output + noise).iloc[50:150]
data.name = "active_cases"
data.plot();

# Calibration 

In [None]:
from estival.model import BayesianCompartmentalModel
from estival.wrappers import pymc as epm
import estival.priors as esp
import estival.targets as est
from estival.sampling import tools as esamptools

import pymc as pm
import pandas as pd

import arviz as az #For convergence stats and diagnostics
import time
import matplotlib .pyplot as plt

In [None]:
# Uniform priors for the two epidemiological parameters. Their positions in this list
# (0 = contact_rate, 1 = recovery_rate) are what get_bcm_object's fixed_param_indices
# refers to.
all_priors = [
    esp.UniformPrior("contact_rate", [0.1, 0.3]),
    esp.UniformPrior("recovery_rate", [0.05, 0.2]),
]

# A single normal target on the synthetic active_cases series. Its standard deviation
# is given a "dispersion" prior instead of a fixed value.
targets = [
    est.NormalTarget(
        "active_cases",
        data,
        stdev=esp.UniformPrior("dispersion", [10.0, 10000.0]),
    ),
]


def get_bcm_object(model, all_mle_params, fixed_param_indices=()):
    """Wrap ``model`` in a BayesianCompartmentalModel, optionally fixing some parameters.

    Args:
        model: the summer2 model to calibrate.
        all_mle_params: parameter values keyed by prior name (e.g. the optimiser's
            result). Only the entries for the fixed priors are read, so ``{}`` is fine
            when nothing is fixed.
        fixed_param_indices: positions in the module-level ``all_priors`` list whose
            parameters are held at their ``all_mle_params`` value instead of being
            given a prior. Defaults to an empty tuple (everything free).

    Returns:
        A ``BayesianCompartmentalModel`` built from the remaining (free) priors and the
        module-level ``targets``.

    Raises:
        ValueError: if an index does not refer to an entry of ``all_priors``. Previously
            such indices were silently ignored, leaving the parameter free.
        KeyError: if ``all_mle_params`` has no value for a fixed prior.
    """
    fixed = set(fixed_param_indices)
    out_of_range = sorted(i for i in fixed if not 0 <= i < len(all_priors))
    if out_of_range:
        raise ValueError(
            f"fixed_param_indices {out_of_range} out of range for {len(all_priors)} priors"
        )

    free_priors = [p for i, p in enumerate(all_priors) if i not in fixed]
    # The fixed values go in the BCM's parameter dict (second positional argument)
    # instead of appearing as priors.
    fixed_params = {p.name: all_mle_params[p.name] for i, p in enumerate(all_priors) if i in fixed}

    return BayesianCompartmentalModel(model, fixed_params, free_priors, targets)
    

### Optimisation

In [None]:
# Import nevergrad
import nevergrad as ng

# Import our convenience wrapper
from estival.wrappers.nevergrad import optimize_model
# Every prior is free here, so no previously optimised values are needed yet
bcm = get_bcm_object(model, all_mle_params={}, fixed_param_indices=[])

# Search the prior space with nevergrad through estival's wrapper, budget 1000.
# NOTE(review): no optimiser seed is set, so the result can differ between runs.
orunner = optimize_model(bcm)
rec = orunner.minimize(1000)

# NOTE(review): value[1] is presumably the kwargs half of nevergrad's (args, kwargs)
# Instrumentation value, i.e. a {parameter name: value} dict — confirm.
all_mle_params = rec.value[1]

In [None]:
# Re-run the model at the optimised parameters and overlay the synthetic data on one
# explicitly created set of axes (rather than relying on pyplot's current axes).
model.run(all_mle_params)
_, ax = plt.subplots()
model.get_derived_outputs_df().plot(ax=ax)
# A distinct label keeps the legend from showing two identical "active_cases" entries
data.plot(ax=ax, style=".", color="black", label="synthetic data")
ax.set(title="Optimised model output vs. synthetic data", xlabel="time")
ax.legend();

### Sampling

In [None]:
def run_sampling(model, all_mle_params, fixed_param_indices=(), draws=10000, tune=1000,
                 chains=4, cores=4, random_seed=None):
    """Build a BCM (see ``get_bcm_object``) and sample its posterior with PyMC DEMetropolis.

    Args:
        model: the summer2 model to calibrate.
        all_mle_params: values used for any fixed parameters, keyed by prior name.
        fixed_param_indices: positions in ``all_priors`` to hold fixed (default: none).
        draws, tune, chains, cores: forwarded to ``pm.sample``. The defaults reproduce
            the original hard-coded settings.
        random_seed: forwarded to ``pm.sample``. ``None`` (the default) leaves sampling
            unseeded, as before; pass an int for reproducible chains.

    Returns:
        ``(bcm, idata)``: the BayesianCompartmentalModel and the InferenceData from
        ``pm.sample``.
    """
    bcm = get_bcm_object(model, all_mle_params, fixed_param_indices)
    # No ``as model`` here. That name used to shadow the summer2 ``model`` argument.
    with pm.Model():
        variables = epm.use_model(bcm)
        idata = pm.sample(
            step=[pm.DEMetropolis(variables)],
            draws=draws,
            tune=tune,
            cores=cores,
            chains=chains,
            random_seed=random_seed,
        )

    return bcm, idata

def check_sampling(bcm, idata, n_samples=1000, quantiles=(0.025, 0.25, 0.5, 0.75, 0.975)):
    """Plot a trace plot and a posterior-predictive fit check for a sampling run.

    Draws the ArviZ trace plot for ``idata``. It then runs ``bcm`` for ``n_samples``
    extracted draws and plots the requested quantiles of the ``"active_cases"`` output.
    The module-level synthetic ``data`` is overlaid as black dots.

    Args:
        bcm: the BayesianCompartmentalModel that produced ``idata``.
        idata: InferenceData returned by ``pm.sample`` (see ``run_sampling``).
        n_samples: number of draws passed to ``az.extract``.
        quantiles: quantile levels to plot. The default gives the median plus 50% and
            95% bands, matching the previously hard-coded list.
    """
    az.plot_trace(idata)

    sample_idata = az.extract(idata, num_samples=n_samples)
    mres = esamptools.model_results_for_samples(sample_idata, bcm)
    output_quantiles = esamptools.quantiles_for_results(mres.results, list(quantiles))

    # Draw quantiles and data on the same explicit axes instead of relying on plt.gca()
    _, ax = plt.subplots()
    output_quantiles["active_cases"].plot(ax=ax)
    data.plot(ax=ax, style=".", color="black")

def get_output_values(bcm, idata, n_samples=1000, at_time=100.0, output_name="active_cases"):
    """Return one model output at one time across a subsample of posterior draws.

    Args:
        bcm: the BayesianCompartmentalModel that produced ``idata``.
        idata: InferenceData returned by ``pm.sample``.
        n_samples: number of draws passed to ``az.extract``.
        at_time: time index label to read (default 100.0, the previously hard-coded value).
        output_name: model output to read (default ``"active_cases"``).

    Returns:
        ``mres.results[output_name].loc[at_time]``. NOTE(review): presumably one value
        per extracted sample — confirm against estival's result layout.

    Raises:
        KeyError: if ``output_name`` or ``at_time`` is not present in the results.
    """
    sample_idata = az.extract(idata, num_samples=n_samples)
    mres = esamptools.model_results_for_samples(sample_idata, bcm)
    return mres.results[output_name].loc[at_time]



In [None]:
# Calibration with every entry of all_priors free (nothing held at its optimised value)
bcm, idata = run_sampling(
    model,
    all_mle_params,
    fixed_param_indices=[],
)
check_sampling(bcm, idata)

In [None]:
# Same calibration, but all_priors[0] (contact_rate) is held at its optimised value
bcm_2, idata_2 = run_sampling(
    model,
    all_mle_params,
    fixed_param_indices=[0],
)
check_sampling(bcm_2, idata_2)