In [None]:
import arviz as az
from summer2 import CompartmentalModel
from summer2 import AgeStratification
from summer2 import Overwrite
from summer2.parameters import Parameter, Function
from tbdynamics.tools.utils import (
    interpolate_age_strata_values,
    adjust_latency_rates,
)
from estival.model import BayesianCompartmentalModel
from tbdynamics.constants import AGE_STRATA, QUANTILES, PLACEHOLDER_PARAM
from tbdynamics.tools.inputs import load_params

from tbdynamics.settings import CM_PATH
from tbdynamics.camau.calibration.utils import get_all_priors
import estival.sampling.tools as esamp
import estival.targets as est
import pandas as pd

import plotly.graph_objects as go

In [None]:
fixed_params = load_params(CM_PATH / "params.yml")

# Age stratum whose latency rates are used for this unstratified sanity check.
# It must be a key of every fixed_params["age_latency"][flow_name] mapping.
ADULT_AGE = 15
adult_latency = {
    flow_name: rates_by_age[ADULT_AGE]
    for flow_name, rates_by_age in fixed_params["age_latency"].items()
}

compartments = ["early_latent", "late_latent", "infectious"]
base_model = CompartmentalModel(
    times=(0.0, 20.0),
    compartments=compartments,
    infectious_compartments="infectious",
    timestep=0.01,
)
# Start everyone in early latency and watch progression to the infectious state.
base_model.set_initial_population({"early_latent": 1.0})
latency_flows = [
    ("stabilisation", Parameter("stabilisation"), "early_latent", "late_latent"),
    ("early_activation", Parameter("early_activation"), "early_latent", "infectious"),
    ("late_activation", Parameter("late_activation"), "late_latent", "infectious"),
]
for latency_flow in latency_flows:
    base_model.add_transition_flow(*latency_flow)
# Parameter names match the flow names, so the adult rates plug in directly.
base_model.run(adult_latency)

ax = base_model.get_outputs_df()["infectious"].plot()
ax.set(
    title=f"Unstratified latency model (age {ADULT_AGE} rates)",
    xlabel="Time",
    ylabel="Infectious",
);

In [None]:
 age_strata = [0,5,15]
def get_age_strat(fixed_params):
    strat = AgeStratification("age", age_strata, compartments)
    for age in age_strata:
        for flow_name, latency_params in fixed_params["age_latency"].items():
            adjs = {str(k): Overwrite(v) for k, v in latency_params.items()}
            strat.set_flow_adjustments(flow_name, adjs)
    return strat

In [None]:
model = CompartmentalModel(
    times=(0.0, 20.0),
    compartments=compartments,
    infectious_compartments="infectious",
    timestep=0.01,
)
# Seed one unit of early-latent infection per age stratum. The unstratified
# population gets split across age strata by stratify_with(); presumably the split
# is even under summer2's default, which is why 3.0 was used for 3 strata — confirm.
model.set_initial_population({"early_latent": 1.0 * len(age_strata)})
latency_flows = [
    ("stabilisation", Parameter("stabilisation"), "early_latent", "late_latent"),
    ("early_activation", Parameter("early_activation"), "early_latent", "infectious"),
    ("late_activation", Parameter("late_activation"), "late_latent", "infectious"),
]
for latency_flow in latency_flows:
    model.add_transition_flow(*latency_flow)
# Replace the shared Parameter rates with age-specific overwrites.
model.stratify_with(get_age_strat(fixed_params))

In [None]:
# Run the age-stratified model.
# NOTE(review): no parameter values are passed, even though the latency flows were
# built from Parameter("...") objects. This relies on the age-specific Overwrite
# adjustments replacing every one of those rates — confirm.
model.run()

In [None]:
# Model outputs as a DataFrame: one column per stratified compartment (e.g. "infectiousXage_0").
res = model.get_outputs_df()

In [None]:
# Infectious compartment size over time, one line per age stratum.
infectious_by_age = res[[f"infectiousXage_{age}" for age in age_strata]]
ax = infectious_by_age.plot()
ax.set(title="Infectious by age stratum", xlabel="Time", ylabel="Infectious");