In [None]:
import arviz as az
from summer2 import CompartmentalModel
from summer2 import AgeStratification
from summer2 import Overwrite
from summer2.parameters import Parameter, Function
from tbdynamics.tools.utils import (
    interpolate_age_strata_values,
    adjust_latency_rates,
)
from estival.model import BayesianCompartmentalModel
from tbdynamics.constants import AGE_STRATA, QUANTILES, PLACEHOLDER_PARAM
from tbdynamics.tools.inputs import load_params

from tbdynamics.settings import CM_PATH
from tbdynamics.camau.calibration.utils import get_all_priors
import estival.sampling.tools as esamp
import estival.targets as est
import pandas as pd

import plotly.graph_objects as go

In [None]:
# Names of the model compartments; read by both build_model and get_age_strat.
compartments = ["early_latent", "late_latent", "infectious"]

def build_model(fixed_params) -> CompartmentalModel:
    """Build an age-stratified, latency-only model and request its outputs.

    The model runs from 2000.0 to 2005.0 with a timestep of 1.0, places the
    whole initial population (6.0) in ``early_latent`` and adds only the three
    latency transition flows (stabilisation, early activation, late
    activation). The flow rates are placeholders that the age stratification
    returned by ``get_age_strat`` replaces per age group.

    Derived outputs requested for every stratum ``a`` in ``AGE_STRATA``:

    * ``total_populationXage_{a}``: size of all compartments in that stratum
    * ``infectious_sizeXage_{a}``: size of the infectious compartment
    * ``cummulative_infectiousXage_{a}``: cumulative output of the
      infectious size, accumulated from 2000.0

    plus the aggregates ``infectious_adults`` (``AGE_STRATA[2:]``) and
    ``infectious_children`` (``AGE_STRATA[:2]``).

    Args:
        fixed_params: Parameter mapping forwarded unchanged to
            ``get_age_strat`` (which reads its ``"age_latency"`` entry).

    Returns:
        The stratified ``CompartmentalModel`` with all outputs requested.
    """
    # summer2 documents compartment arguments as lists of names; a bare string
    # would be iterated character by character (or substring-matched).
    infectious_compartments = ["infectious"]

    model = CompartmentalModel(
        times=(2000.0, 2005.0),
        compartments=compartments,
        infectious_compartments=infectious_compartments,
        timestep=1.0,
    )
    model.set_initial_population({"early_latent": 6.0})

    # (flow name, rate, source compartment, destination compartment)
    latency_flows = [
        ("stabilisation", PLACEHOLDER_PARAM, "early_latent", "late_latent"),
        ("early_activation", PLACEHOLDER_PARAM, "early_latent", "infectious"),
        ("late_activation", PLACEHOLDER_PARAM, "late_latent", "infectious"),
    ]
    for latency_flow in latency_flows:
        model.add_transition_flow(*latency_flow)

    age_strat = get_age_strat(fixed_params)
    model.stratify_with(age_strat)

    for age_stratum in AGE_STRATA:
        infectious_output = f"infectious_sizeXage_{age_stratum}"
        model.request_output_for_compartments(
            f"total_populationXage_{age_stratum}",
            compartments,
            strata={"age": str(age_stratum)},
        )
        model.request_output_for_compartments(
            infectious_output,
            infectious_compartments,
            strata={"age": str(age_stratum)},
        )
        # The output name keeps its original spelling because downstream
        # cells look it up by this exact key.
        model.request_cumulative_output(
            f"cummulative_infectiousXage_{age_stratum}",
            infectious_output,
            2000.0,
        )

    # The first two age strata are grouped as children, the rest as adults.
    adults_pop = [
        f"infectious_sizeXage_{adults_stratum}" for adults_stratum in AGE_STRATA[2:]
    ]
    children_pop = [
        f"infectious_sizeXage_{children_stratum}"
        for children_stratum in AGE_STRATA[:2]
    ]
    model.request_aggregate_output("infectious_adults", adults_pop)
    model.request_aggregate_output("infectious_children", children_pop)
    return model


def get_age_strat(fixed_params):
    """Create the age stratification carrying age-specific latency rates.

    For each of the three latency flows, the values under
    ``fixed_params["age_latency"]`` are passed through
    ``interpolate_age_strata_values`` and indexed by age stratum. For every
    stratum in ``AGE_STRATA`` a ``Function`` wrapping ``adjust_latency_rates``
    is built from, in order: the early-activation, stabilisation and
    late-activation values for that age, a literal ``0.0``, and the
    ``early_prop_adjuster`` and ``late_reactivation_adjuster`` parameters.
    Elements 0, 1 and 2 of that function's result are wrapped in ``Overwrite``
    and used as the adjustments for the early-activation, stabilisation and
    late-activation flows respectively.

    Args:
        fixed_params: Mapping with an ``"age_latency"`` entry holding the
            ``"early_activation"``, ``"stabilisation"`` and
            ``"late_activation"`` inputs.

    Returns:
        The configured ``AgeStratification`` named ``"age"``.
    """
    # Order matters: the position of each flow name matches the index of its
    # rate in the result of adjust_latency_rates.
    flow_names = ("early_activation", "stabilisation", "late_activation")

    age_stratification = AgeStratification("age", AGE_STRATA, compartments)

    latency_inputs = fixed_params["age_latency"]
    rates_by_flow = {
        flow_name: interpolate_age_strata_values(latency_inputs[flow_name])
        for flow_name in flow_names
    }

    adjustments_by_flow = {flow_name: {} for flow_name in flow_names}
    for age in AGE_STRATA:
        adjusted_rates = Function(
            adjust_latency_rates,
            [
                rates_by_flow["early_activation"][age],
                rates_by_flow["stabilisation"][age],
                rates_by_flow["late_activation"][age],
                0.0,
                Parameter("early_prop_adjuster"),
                Parameter("late_reactivation_adjuster"),
            ],
        )
        stratum_key = str(age)
        for position, flow_name in enumerate(flow_names):
            adjustments_by_flow[flow_name][stratum_key] = Overwrite(
                adjusted_rates[position]
            )

    # Register the per-age adjustments, one flow at a time.
    for flow_name in flow_names:
        age_stratification.set_flow_adjustments(
            flow_name, adjustments_by_flow[flow_name]
        )
    return age_stratification

In [None]:
# Parameter values loaded from params.yml under CM_PATH.
# NOTE(review): get_bcm loads the same file itself and does not receive this
# module-level copy.
fixed_params = load_params(CM_PATH / "params.yml")
# Flags passed to get_bcm, which forwards them to get_all_priors.
covid_effects = {"detection_reduction": True, "contact_reduction": False}

In [None]:
def get_targets():
    """Return the list of targets handed to the Bayesian model wrapper.

    The list holds a single ``NormalTarget`` on the ``infectious_sizeXage_0``
    output, with one observation (value 0 at index 2000) and ``0.0`` as the
    third argument (presumably the standard deviation; confirm against
    estival's ``NormalTarget`` signature).
    """
    observed = pd.Series([0], index=[2000])
    infectious_target = est.NormalTarget("infectious_sizeXage_0", observed, 0.0)
    return [infectious_target]


def get_bcm(params, covid_effects):
    """Assemble the ``BayesianCompartmentalModel`` for the latency model.

    Args:
        params: Parameter values given to the Bayesian model wrapper; any
            falsy value (``None``, empty dict, ...) is replaced by ``{}``.
        covid_effects: Passed unchanged to ``get_all_priors``.

    Returns:
        A ``BayesianCompartmentalModel`` combining the model from
        ``build_model``, ``params``, the priors and the targets.
    """
    if not params:
        params = {}

    # The parameter file is read here on every call rather than taken from
    # the caller.
    loaded_params = load_params(CM_PATH / "params.yml")
    prior_list = get_all_priors(covid_effects)
    target_list = get_targets()
    latency_model = build_model(loaded_params)
    return BayesianCompartmentalModel(latency_model, params, prior_list, target_list)

In [None]:
# Read inference data from a NetCDF file; the path is relative, so it is
# resolved against the current working directory.
idata_extract = az.from_netcdf('idata_extracted.nc')
# Empty parameter dict passed to get_bcm.
params = {}

In [None]:
# Build the Bayesian model wrapper, obtain model results for the samples in
# idata_extract, then summarise those results at the levels in QUANTILES.
bcm = get_bcm(params, covid_effects)
base_results = esamp.model_results_for_samples(idata_extract, bcm).results
base_quantiles = esamp.quantiles_for_results(base_results, QUANTILES)

In [None]:
# Plot the 0.5-quantile column of the cumulative infectious output for every
# age group on a single figure, with time measured from the first model time.
fig = go.Figure()
time_index = base_quantiles[f"cummulative_infectiousXage_{AGE_STRATA[0]}"].loc[:, 0.500].index
t_min = time_index.min()
t_max = time_index.max()

# The x-values are shared by all traces, so compute them once.
elapsed_time = time_index - t_min

# One line per age group; a plain loop is used because add_trace is called
# for its side effect on the figure only.
for age in AGE_STRATA:
    median_series = base_quantiles[f"cummulative_infectiousXage_{age}"].loc[:, 0.500]
    fig.add_trace(
        go.Scatter(
            x=elapsed_time,
            y=median_series.values,
            mode="lines",
            name=f"Age {age}",
        )
    )

fig.update_layout(
    title="Infectious by Age Group",
    xaxis_title="Time from Infection",
    yaxis_title="Cumulative infectious",
    legend_title="Age group",
    template="simple_white",
    width=800,
    height=600,
)

fig.show()