In [None]:
from summer2 import CompartmentalModel
from summer2 import AgeStratification
from summer2 import Overwrite
from summer2.parameters import Parameter
from tbdynamics.tools.inputs import load_params
from tbdynamics.settings import CM_PATH
import pandas as pd
import plotly.graph_objects as go

In [None]:
# Baseline (unstratified) latency model run with the adult latency rates.
fixed_params = load_params(CM_PATH / "params.yml")
# age_latency maps flow name -> {age-group key: rate}; take the 15+ group's rate for each flow.
adult_latency = {
    flow_name: rates_by_age[15]
    for flow_name, rates_by_age in fixed_params["age_latency"].items()
}
compartments = ["early_latent", "late_latent", "infectious"]
model = CompartmentalModel(
    times=(0.0, 20.0),
    compartments=compartments,
    # A list of compartment names (as in run_model_for_start_group), not a bare string.
    infectious_compartments=["infectious"],
    timestep=0.01,
)
# A single cohort of size 1.0, all freshly infected at t=0.
model.set_initial_population({"early_latent": 1.0})
latency_flows = [
    ("stabilisation", Parameter("stabilisation"), "early_latent", "late_latent"),
    ("early_activation", Parameter("early_activation"), "early_latent", "infectious"),
    ("late_activation", Parameter("late_activation"), "late_latent", "infectious"),
]
for latency_flow in latency_flows:
    model.add_transition_flow(*latency_flow)
model.run(adult_latency)
model.get_outputs_df()["infectious"].plot(
    title="Adult latency model (single cohort)",
    xlabel="Time since infection",
    ylabel="Proportion of cohort infectious",
);

In [None]:
age_strata = [0, 5, 15]


def get_age_strat(fixed_params, start_group):
    """Build the age stratification for a cohort infected in a single age group.

    Each latency flow's rate is overwritten per age stratum with the values in
    ``fixed_params["age_latency"]``. The whole initial population is placed in
    ``start_group``.

    Args:
        fixed_params: Loaded model parameters. ``fixed_params["age_latency"]``
            maps each flow name to a mapping of age stratum -> rate.
        start_group: Age stratum (an element of ``age_strata``) that holds the
            entire cohort at time 0.

    Returns:
        The configured ``AgeStratification``.

    Raises:
        ValueError: If ``start_group`` is not one of ``age_strata``.
    """
    if start_group not in age_strata:
        raise ValueError(
            f"start_group must be one of {age_strata}, got {start_group!r}"
        )

    strat = AgeStratification("age", age_strata, compartments)

    # One call per flow suffices: each adjustment dict already covers every age stratum.
    for flow_name, latency_params in fixed_params["age_latency"].items():
        adjs = {str(age): Overwrite(rate) for age, rate in latency_params.items()}
        strat.set_flow_adjustments(flow_name, adjs)

    # Seed the cohort entirely in start_group; every other stratum starts empty.
    pop_splits = {str(age): 0 for age in age_strata}
    pop_splits[str(start_group)] = 1
    strat.set_population_split(pop_splits)

    return strat

In [None]:
def run_model_for_start_group(start_group, end_time=20.0, timestep=0.1):
    """Run the age-stratified latency model for a cohort infected in one age group.

    The cohort (total size 1.0) starts entirely in ``early_latent`` within the
    ``start_group`` age stratum. It then progresses through the stabilisation
    and activation flows.

    Args:
        start_group: Age stratum (one of ``age_strata``) holding the whole cohort at t=0.
        end_time: Final simulation time; the run always starts at 0.
        timestep: Time step passed to ``CompartmentalModel``.

    Returns:
        The DataFrame from ``model.get_outputs_df()``, with one column per
        stratified compartment (e.g. ``infectiousXage_0``).
    """
    model = CompartmentalModel(
        times=(0.0, end_time),
        compartments=compartments,
        infectious_compartments=["infectious"],
        timestep=timestep,
    )
    # Before stratification the whole cohort is in early_latent.
    model.set_initial_population({"early_latent": 1.0})

    # Base flows use named Parameters; get_age_strat overwrites their rates per age stratum.
    latency_flows = [
        ("stabilisation", Parameter("stabilisation"), "early_latent", "late_latent"),
        ("early_activation", Parameter("early_activation"), "early_latent", "infectious"),
        ("late_activation", Parameter("late_activation"), "late_latent", "infectious"),
    ]
    for flow_name, rate, source, dest in latency_flows:
        model.add_transition_flow(flow_name, rate, source, dest)

    model.stratify_with(get_age_strat(fixed_params, start_group))
    # NOTE(review): no parameter values are passed here. This relies on the Overwrite
    # adjustments replacing every Parameter(...) rate; confirm that summer2 does not
    # still require the base parameter values.
    model.run()
    return model.get_outputs_df()

# --- Run one cohort per starting age group and collect results in long format ---
# Only the start group is seeded, so summing infectious over *all* age strata gives the
# cohort's total infectious count. This stays correct even when ageing flows move
# people out of the start group into older strata.
infectious_cols = [f"infectiousXage_{age}" for age in age_strata]

cohort_frames = []
for start_group in age_strata:
    outputs = run_model_for_start_group(start_group)
    cohort_frames.append(
        pd.DataFrame(
            {
                "time": outputs.index,
                # .to_numpy() keeps the assignment positional, alongside "time".
                "infectious": outputs[infectious_cols].sum(axis=1).to_numpy(),
                "start_age": start_group,
            }
        )
    )

# Long format: one row per (start_age, time).
final_df = pd.concat(cohort_frames, ignore_index=True)

In [None]:
# One line per cohort; the cohort size is 1.0, so the y-values are proportions.
fig = go.Figure()
for start_group, cohort in final_df.groupby("start_age"):
    fig.add_trace(
        go.Scatter(
            x=cohort["time"],
            y=cohort["infectious"],
            mode="lines",
            name=f"Start age: {start_group}",
        )
    )

fig.update_layout(
    title="Infectious cohort members by age at infection",
    xaxis_title="Time since infection",
    yaxis_title="Proportion of cohort infectious",
)
fig.show()

In [None]:
# Preview only; the full frame holds one row per (start_age, time) pair.
final_df.head()