In [None]:
import matplotlib.pyplot as plt
from ipywidgets import FloatSlider, interact

from mxlpy import Model, Simulator, fns, plot
from mxlpy.types import unwrap


def simulate_and_plot(model: Model, beta: float, gamma: float) -> None:
    """Simulate ``model`` with overridden rates and plot the time courses.

    The ``beta`` and ``gamma`` parameters are replaced with the given values.
    The model is simulated up to time 100, and the result is passed through
    mxlpy's ``unwrap``. Variables and fluxes are then drawn on two side-by-side
    axes and shown.

    Args:
        model: Model to simulate; its ``beta`` and ``gamma`` parameters are
            overridden for this run.
        beta: Value assigned to the ``beta`` parameter.
        gamma: Value assigned to the ``gamma`` parameter.
    """
    overrides = {"beta": beta, "gamma": gamma}
    simulator = Simulator(model).update_parameters(overrides)
    result = unwrap(simulator.simulate(100).get_result())

    # Left panel: variable trajectories; right panel: reaction fluxes.
    _, (var_ax, flux_ax) = plot.two_axes(figsize=(7.5, 3.5))
    _ = plot.lines(result.variables, ax=var_ax)
    _ = plot.lines(result.fluxes, ax=flux_ax)
    var_ax.set(xlabel="Time / a.u.", ylabel="Relative Population")
    flux_ax.set(xlabel="Time / a.u.", ylabel="Rate of change")
    plt.show()


def sir() -> Model:
    """Build a basic SIR (susceptible / infected / recovered) model.

    Initial values are s=0.9, i=0.1 and r=0.0; default parameters are
    beta=0.2 and gamma=0.1. The model has two reactions:

    - ``infection``: consumes one ``s`` and produces one ``i``, with rate
      ``fns.mass_action_2s`` applied to ``s``, ``i`` and ``beta``.
    - ``recovery``: consumes one ``i`` and produces one ``r``, with rate
      ``fns.mass_action_1s`` applied to ``i`` and ``gamma``.
    """
    model = Model()
    model = model.add_variables({"s": 0.9, "i": 0.1, "r": 0.0})
    model = model.add_parameters({"beta": 0.2, "gamma": 0.1})
    model = model.add_reaction(
        "infection",
        fns.mass_action_2s,
        args=["s", "i", "beta"],
        stoichiometry={"s": -1, "i": 1},
    )
    model = model.add_reaction(
        "recovery",
        fns.mass_action_1s,
        args=["i", "gamma"],
        stoichiometry={"i": -1, "r": 1},
    )
    return model


def wrapper(beta: float, gamma: float) -> None:
    """Build a fresh SIR model and simulate/plot it for the given rates.

    ``sir()`` is called on every invocation, so each run starts from the
    model's original initial state before ``beta``/``gamma`` are applied.
    """
    simulate_and_plot(sir(), beta, gamma)


# Interactive controls: one slider per rate, both spanning [0.0, 1.0].
# With continuous_update=False the callback fires on slider release rather
# than on every intermediate drag position.
_ = interact(
    wrapper,
    beta=FloatSlider(value=0.2, min=0.0, max=1.0, continuous_update=False),
    gamma=FloatSlider(value=0.1, min=0.0, max=1.0, continuous_update=False),
)