In [None]:
import contextlib
from collections.abc import Generator

import matplotlib.patches as mpatches
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from cycler import cycler
from matplotlib import patches, patheffects
from matplotlib.axes import Axes

from mxlpy import distributions, plot
from mxlpy.distributions import sample
from mxlpy.fns import michaelis_menten_1s

# Palette shared by the figures in this notebook.
C_ODE = "#1d5f66"  # base line colour; myxkcd() cycles it plus a darker and a lighter variant
C_ML = "#6b134b"  # presumably the machine-learning accent (same hex as the "Neural Network" circles)
C_GRAY = "#656565"  # axes edge colour applied by myxkcd()


def scale_hex_lightness(name: str, scale: float) -> tuple[float, float, float]:
    """Return ``name`` as an RGB triple with its HLS lightness multiplied by ``scale``.

    Args:
        name: Any colour specification matplotlib understands (hex string,
            named colour, ...).
        scale: Factor applied to the lightness channel; > 1 lightens, < 1
            darkens. The scaled lightness is capped at 1.

    Returns:
        ``(r, g, b)`` floats as produced by ``colorsys.hls_to_rgb``.
    """
    import colorsys

    from matplotlib.colors import to_rgb

    red, green, blue = to_rgb(name)
    hue, lightness, saturation = colorsys.rgb_to_hls(red, green, blue)
    lightness = min(1, lightness * scale)
    return colorsys.hls_to_rgb(hue, lightness, s=saturation)


@contextlib.contextmanager
def myxkcd(
    scale: float = 1,
    length: float = 100,
    randomness: float = 2,
) -> Generator[None, None, None]:
    """Temporarily apply the notebook's hand-drawn figure style.

    The settings are applied through ``plt.rc_context`` and restored when the
    ``with`` block exits.

    Args:
        scale: ``path.sketch`` scale (amplitude of the wiggle).
        length: ``path.sketch`` length (wiggle length along the line).
        randomness: ``path.sketch`` randomness factor.
    """
    # Line colour cycle: the ODE base colour, then a darker and a lighter variant.
    palette = [
        C_ODE,
        scale_hex_lightness(C_ODE, 0.6),
        scale_hex_lightness(C_ODE, 1.4),
    ]
    style = {
        "path.sketch": (scale, length, randomness),
        # White halo behind every path so crossing lines stay readable.
        "path.effects": [patheffects.withStroke(linewidth=4, foreground="w")],
        "lines.linewidth": 2.0,
        "axes.unicode_minus": False,
        "axes.edgecolor": C_GRAY,
        "axes.linewidth": 1.5,
        "axes.prop_cycle": cycler(color=palette),
        "xtick.major.size": 8,
        "xtick.major.width": 2,
        "ytick.major.size": 8,
        "ytick.major.width": 2,
    }
    # Font family/size, figure facecolor and grid (which plt.xkcd() would also
    # change) are deliberately left at their current values.
    with plt.rc_context(style):
        yield


def savefig(name: str, dpi: int = 100) -> None:
    """Save the current pyplot figure as ``<name>.png``.

    Args:
        name: Output path without extension, relative to the working directory.
        dpi: Output resolution; defaults to the 100 dpi used for all figures
            in this notebook.
    """
    plt.savefig(f"{name}.png", dpi=dpi)

## Time course

In [None]:
# Generic "time course": the Michaelis-Menten rate of x stands in for a
# concentration trace.
with myxkcd():
    x = np.linspace(0, 1, 101)
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    ax.plot(x, michaelis_menten_1s(x, 1, 0.1))
    ax.set_title("Time course")
    ax.set_xlabel("Time")
    ax.set_ylabel("Concentration")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("time-course")
    plt.show()

## Protocol time course

In [None]:
# Protocol of "k" values indexed by time (0.33 s, 0.66 s, 1 s); it is shaded
# behind the curve by plot.shade_protocol.
protocol = pd.DataFrame(
    {
        pd.Timedelta(seconds=0.33): {"k": 1},
        pd.Timedelta(seconds=0.66): {"k": 2},
        pd.Timedelta(seconds=1): {"k": 1},
    }
).T

# Segment labels and their positions in data coordinates.
light_labels = [
    ("Light off", (0.075, 0.2)),
    ("Light on", (0.4, 0.3)),
    ("Light off", (0.725, 0.2)),
]

with myxkcd():
    x = np.linspace(0, 1, 101)
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    ax.plot(x, michaelis_menten_1s(x, 1, 0.1))
    for label, position in light_labels:
        ax.annotate(label, position, fontsize=8)

    plot.shade_protocol(protocol["k"], ax=ax, alpha=0.1, add_legend=False)
    ax.set(title="Protocol time course", xlabel="Time", ylabel="Concentration")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("protocol-time-course")
    plt.show()

## Steady-state

In [None]:
# Steady state: with km = 0.01 the curve levels off almost immediately
# (assuming the usual vmax * s / (km + s) form), and the plateau is annotated.
with myxkcd():
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    ax.plot(
        x := np.linspace(0, 1, 101),
        michaelis_menten_1s(x, 1.0, 0.01),
    )
    ax.set_ylim(0, 1)
    # Bracket marking the flat region; the text label sits directly below it.
    ax.annotate(
        "",
        (0.8, 0.9),
        xytext=(0.8, 0.8),
        arrowprops={"arrowstyle": "-[, widthB=3.0, lengthB=0.5"},
    )
    ax.annotate(
        "Concentration\ndoesn't change",
        (0.8, 0.8),
        ha="center",
        va="top",
    )

    ax.set(
        title="Steady state",
        xlabel="Time",
        ylabel="Concentration",
    )
    ax.spines[["right", "top"]].set_visible(False)
    savefig("steady-state")
    plt.show()

## Parameter scan

In [None]:
# Parameter scan: the same curve shape, relabelled as steady-state
# concentration versus parameter value.
with myxkcd():
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    x = np.linspace(0, 1, 101)
    steady_states = michaelis_menten_1s(x, 1, 0.1)
    ax.plot(x, steady_states)
    ax.set_title("Parameter scan")
    ax.set_xlabel("Parameter value")
    ax.set_ylabel("Steady-state \nConcentration")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("parameter-scan")
    plt.show()

## Time course by parameters

In [None]:
# One curve per value of the first parameter ("p"), each labelled at its
# right-hand end in the matching cycle colour (C0, C1, C2 follow plot order).
curves_by_p = [
    ("p = 1", 1.0, "C0"),
    ("p = 0.9", 0.9, "C1"),
    ("p = 0.8", 0.8, "C2"),
]

with myxkcd():
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    x = np.linspace(0, 1, 101)
    for label, p_value, color in curves_by_p:
        ax.plot(x, michaelis_menten_1s(x, p_value, 0.1))
    for label, p_value, color in curves_by_p:
        ax.annotate(
            label,
            (1.0, michaelis_menten_1s(1.0, p_value, 0.1)),
            va="center",
            color=color,
        )

    ax.set(
        title="Time course by parameter",
        xlabel="Time",
        ylabel="Concentration",
    )
    ax.spines[["right", "top"]].set_visible(False)
    savefig("time-course-by-parameter")
    plt.show()

## Parameter scan 2D

In [None]:
# Pre-computed values of "x" on a 3 x 4 grid of (p1, p2). The float noise in
# the values suggests they were copied from a simulation.
data = pd.DataFrame(
    {
        "x": {
            (1.0, 1.0): 1.0,
            (1.0, 1.3333333333333333): 0.75,
            (1.0, 1.6666666666666665): 0.6000000000000001,
            (1.0, 2.0): 0.5,
            (1.5, 1.0): 1.4999999999999811,
            (1.5, 1.3333333333333333): 1.1249999999999927,
            (1.5, 1.6666666666666665): 0.8999999999999069,
            (1.5, 2.0): 0.75,
            (2.0, 1.0): 2.000000000000164,
            (2.0, 1.3333333333333333): 1.5,
            (2.0, 1.6666666666666665): 1.200000000000353,
            (2.0, 2.0): 1.0,
        }
    }
).rename_axis(["p1", "p2"])

# Drawn in the default style rather than inside myxkcd().
_ = plot.heatmaps_from_2d_idx(data)
savefig("parameter-scan-2d")
plt.show()

## Elasticities

In [None]:
# Variable elasticity: a small concentration step (horizontal arrow) and the
# resulting flux change along the rate curve (vertical arrow).
with myxkcd():
    km = 1.0
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    x = np.linspace(0, 1, 101)
    ax.plot(x, michaelis_menten_1s(x, 1, km))

    # From here on `x` is the scalar operating point, no longer the grid.
    x, dx = 0.5, 0.1
    y = michaelis_menten_1s(x, 1, km)
    dy = michaelis_menten_1s(x + dx, 1, km) - y
    ax.arrow(x, y, dx, 0)  # step in concentration
    ax.arrow(x + dx, y, 0, dy)  # resulting step in flux

    label_style = {"fontsize": 8}
    ax.annotate(r"$\Delta$ C", (x, y - 0.05), **label_style)
    ax.annotate(r"$\Delta$ Flux", (x + dx * 1.2, y + dy / 4), **label_style)

    ax.set(title="Variable elasticity", xlabel="Concentration", ylabel="Flux")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("variable-elasticity")
    plt.show()

In [None]:
# Parameter elasticity: same geometry as the variable elasticity figure, but
# the horizontal step is read as a change in a parameter.
with myxkcd():
    km = 1.0
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    x = np.linspace(0, 1, 101)
    ax.plot(x, michaelis_menten_1s(x, 1, km))

    # From here on `x` is the scalar operating point, no longer the grid.
    x, dx = 0.5, 0.1
    y = michaelis_menten_1s(x, 1, km)
    dy = michaelis_menten_1s(x + dx, 1, km) - y
    ax.arrow(x, y, dx, 0)  # step in the parameter
    ax.arrow(x + dx, y, 0, dy)  # resulting step in flux

    label_style = {"fontsize": 8}
    ax.annotate(r"$\Delta$ p", (x, y - 0.05), **label_style)
    ax.annotate(r"$\Delta$ Flux", (x + dx * 1.2, y + dy / 4), **label_style)

    ax.set(title="Parameter elasticity", xlabel="Parameter", ylabel="Flux")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("parameter-elasticity")
    plt.show()

## Response coefficients

In [None]:
# Response coefficient: change in the end-point value when the first
# parameter goes from 0.5 to 0.55, drawn as a double arrow at x = 1.
with myxkcd():
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    x = np.linspace(0, 1, 101)
    ax.plot(x, michaelis_menten_1s(x, 0.5, 0.05))
    ax.set_ylim(0, 1)

    x = 1.0
    y = michaelis_menten_1s(x, 0.5, 0.05)
    dy = michaelis_menten_1s(x, 0.55, 0.05) - y
    arrow_head = {"head_width": 0.04, "head_length": 0.01}
    for offset in (dy, -dy):
        ax.arrow(x, y, 0, offset, **arrow_head)
    ax.annotate(
        r"$\Delta$ C caused by $\Delta$ p",
        (x * 1.05, y),
        fontsize=8,
    )
    ax.set(title="Response coefficient", xlabel="Time", ylabel="Concentration")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("response-coefficient")
    plt.show()

## Fitting

In [None]:
# Fitting: left, residual arrows between a prediction and the data curve;
# right, the same idea for per-species bar values.
with myxkcd():
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 2.5), layout="constrained")

    x = np.linspace(0, 1, 101)
    ax = ax1
    ax.plot(x, michaelis_menten_1s(x, 1, 0.1), linestyle="dashed", label="Data")
    ax.plot(x, michaelis_menten_1s(x, 0.8, 0.1), label="Prediction")

    # annotate draws from xytext (prediction) to xy (data).
    for xs in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        ax.annotate(
            "",
            (xs, michaelis_menten_1s(xs, 1, 0.1)),
            (xs, michaelis_menten_1s(xs, 0.8, 0.1)),
            arrowprops={"arrowstyle": "->"},
        )
    ax.spines[["right", "top"]].set_visible(False)
    ax.set(
        title="Fitting",
        xlabel="Time",
        ylabel="Concentration",
    )
    ax.legend()

    ax = ax2
    # Distinct name so the 2D-scan `data` frame is not silently replaced by a
    # differently shaped table.
    fit_data = pd.DataFrame(
        {
            "S1": {"model": 1.0, "data": 0.8},
            "S2": {"model": 1.2, "data": 0.7},
            "S3": {"model": 1.3, "data": 1.1},
        }
    ).T
    sns.barplot(fit_data["model"], ax=ax)
    sns.barplot(fit_data["data"], ax=ax)
    ax.set(ylabel="Concentration")

    # One arrow per bar, from the data value toward the model value. The
    # endpoints come from `fit_data`, so they always match the bars.
    for position, (model_value, data_value) in enumerate(
        zip(fit_data["model"], fit_data["data"])
    ):
        ax.annotate(
            "",
            (position, model_value),
            (position, data_value),
            arrowprops={"arrowstyle": "->"},
        )
    ax.spines[["right", "top"]].set_visible(False)

    savefig("fitting")
    plt.show()

## Parameter distributions

In [None]:
# KDE of 1000 draws from distributions.Normal(5, 1) as an example parameter
# distribution.
# NOTE(review): no RNG seed is set, so the exact shape changes between runs.
with myxkcd():
    draws = distributions.Normal(5, 1).sample(1000)
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    sns.kdeplot(draws, fill=True, ax=ax)
    ax.set_title("Parameter distribution")
    ax.set_xlabel("Parameter value")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("parameter-distribution")
    plt.show()

In [None]:
# Split violins for two sampled parameters (1000 draws each).
with myxkcd():
    parameter_samples = sample(
        {"p1": distributions.Normal(3, 1), "p2": distributions.Normal(6, 1)}, 1000
    )
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")
    sns.violinplot(
        parameter_samples,
        fill=True,
        alpha=0.25,
        ax=ax,
        split=True,
    )
    ax.set_xticks([])
    ax.set(xlabel="Parameters", ylabel="Value")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("violins")
    plt.show()

## MC time course

In [None]:
# Monte-Carlo time course: one rate curve per sampled (vmax, km) pair,
# summarised as mean +/- std by plot.line_mean_std.
x = np.linspace(0, 1, 101)
pars = sample(
    {
        "vmax": distributions.Normal(1, 0.12),
        "km": distributions.Uniform(0.1, 0.2),
    },
    10,
)

# Index the frame by x so the horizontal axis shows time on [0, 1], matching
# the other time-course figures, instead of the default 0..100 row positions.
# (Presumably line_mean_std plots against the index.)
mm_rate = pd.DataFrame(
    {k: michaelis_menten_1s(x, v["vmax"], v["km"]) for k, v in pars.iterrows()},
    index=x,
)

with myxkcd():
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

    plot.line_mean_std(mm_rate, ax=ax, grid=False)
    ax.set(title="MC time course", xlabel="Time", ylabel="Concentration")
    ax.spines[["right", "top"]].set_visible(False)
    savefig("mc-time-course")
    plt.show()

In [None]:
# Monte-Carlo time course with the light protocol shaded behind it.
x = np.linspace(0, 1, 101)
pars = sample(
    {
        "vmax": distributions.Normal(1, 0.12),
        "km": distributions.Uniform(0.1, 0.2),
    },
    10,
)
# Rows are indexed by x (time), one column per parameter sample.
mm_rate = pd.DataFrame(
    {k: michaelis_menten_1s(x, v["vmax"], v["km"]) for k, v in pars.iterrows()},
    index=x,
)

protocol = pd.DataFrame(
    {
        pd.Timedelta(seconds=0.33): {"k": 1},
        pd.Timedelta(seconds=0.66): {"k": 2},
        pd.Timedelta(seconds=1): {"k": 1},
    }
).T
light_labels = [
    ("Light off", (0.075, 0.2)),
    ("Light on", (0.4, 0.3)),
    ("Light off", (0.725, 0.2)),
]

with myxkcd():
    fig, ax = plt.subplots(figsize=(3, 2.5), layout="constrained")

    plot.line_mean_std(mm_rate, ax=ax, grid=False)
    ax.set(title="MC protocol time course", xlabel="Time", ylabel="Concentration")
    ax.spines[["right", "top"]].set_visible(False)

    for label, position in light_labels:
        ax.annotate(label, position, fontsize=8)
    plot.shade_protocol(protocol["k"], ax=ax, alpha=0.1, add_legend=False)

    savefig("mc-protocol-time-course")
    plt.show()

## Surrogate

In [None]:
def circle_with_text(
    ax: Axes,
    x: float,
    y: float,
    text: str,
    color: str,
    radius: float = 0.1,
    fontsize: int = 10,
) -> None:
    """Draw a translucent filled circle on ``ax`` with ``text`` centred on it.

    Args:
        ax: Target axes; ``x``/``y``/``radius`` are in its data coordinates.
        x: Centre x coordinate.
        y: Centre y coordinate.
        text: Label drawn at the centre.
        color: Fill colour of the circle (drawn at alpha 0.2).
        radius: Circle radius.
        fontsize: Font size of the label.
    """
    disc = patches.Circle((x, y), radius=radius, color=color, alpha=0.2)
    ax.add_patch(disc)
    ax.annotate(text, xy=(x, y), fontsize=fontsize, ha="center", va="center")


# Surrogate model sketch: S1 feeds a neural network that predicts S2 and S3.
with myxkcd():
    fig, ax = plt.subplots(figsize=(4, 4), layout="constrained")
    ax.set_axis_off()
    ax.set_aspect("equal", adjustable="box")

    circle_with_text(ax, 0.1, 0.5, "S1", "C0")
    circle_with_text(ax, 0.5, 0.5, "Neural\nNetwork", C_ML, radius=0.17)
    circle_with_text(ax, 0.9, 0.8, "S2", "C2")
    circle_with_text(ax, 0.9, 0.2, "S3", "C1")

    # (tail, head) pairs in data coordinates; annotate draws from xytext to xy.
    for tail, head in [
        ((0.2, 0.5), (0.33, 0.5)),
        ((0.6, 0.66), (0.75, 0.75)),
        ((0.6, 0.33), (0.75, 0.25)),
    ]:
        ax.annotate("", head, tail, arrowprops={"arrowstyle": "->"})
    savefig("surrogate")
    plt.show()

In [None]:
# Stand-alone neural-network icon.
with myxkcd():
    fig, ax = plt.subplots(figsize=(2, 2), layout="constrained")
    ax.set_axis_off()
    ax.set_aspect("equal", adjustable="box")
    circle_with_text(
        ax, 0.5, 0.5, "Neural\nNetwork", C_ML, radius=0.5, fontsize=20
    )
    savefig("neural-network")
    # Shown inside the style context, like every other figure cell.
    plt.show()

## Neural posterior estimation

In [None]:
# Neural posterior estimation: data -> neural network -> parameter
# distribution, with connector arrows between the panels.
with myxkcd():
    fig, (ax1, ax2, ax3) = plt.subplots(
        1,
        3,
        figsize=(7, 2.5),
        gridspec_kw={"width_ratios": [1, 0.5, 1]},
        layout="constrained",
    )
    ax = ax1
    x = np.linspace(0, 1, 101)
    for vmax in (1, 0.9, 0.8):
        ax.plot(x, michaelis_menten_1s(x, vmax, 0.1))
    ax.set(
        title="Data",
        xlabel="Time",
        ylabel="Concentration",
    )
    ax.spines[["right", "top"]].set_visible(False)

    ax = ax2
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    circle_with_text(ax, 0.5, 0.5, "Neural\nNetwork", C_ML, radius=0.45)
    ax.set_axis_off()
    ax.set_aspect("equal", adjustable="box")

    ax = ax3
    sns.kdeplot(distributions.Normal(5, 1).sample(1000), fill=True, ax=ax)
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")
    ax.set(title="Parameter distribution", xlabel="Parameter value")
    ax.spines[["left", "top"]].set_visible(False)

    # Connector arrows, positioned in the axes coordinates of the anchoring
    # panel: near ax1's right edge and from ax3's left edge. fig.add_artist is
    # the documented way to attach an artist to a figure; unlike appending to
    # fig.patches directly, it also sets the artist's figure.
    for anchor_ax, start, end in (
        (ax1, (0.8, 0.5), (1.0, 0.5)),
        (ax3, (0.0, 0.5), (0.2, 0.5)),
    ):
        fig.add_artist(
            patches.FancyArrowPatch(
                posA=start,
                posB=end,
                color="black",
                arrowstyle="->",
                linewidth=2,
                mutation_scale=15,
                transform=anchor_ax.transAxes,
            )
        )

    savefig("npe")
    plt.show()

In [None]:
# Quick look at exp(x) on [0, 1]. The cell builds its own x and axes, so it no
# longer depends on whichever earlier cell last assigned `x`.
x = np.linspace(0, 1, 101)
fig_exp, ax_exp = plt.subplots()
ax_exp.plot(x, np.exp(x))
plt.show()

In [None]:
# Identifiability: sketch of the error landscape over one parameter for three
# cases, sharing the y axis.
x = np.linspace(0.1, 1, 101)

with myxkcd():
    fig, (ax1, ax2, ax3) = plt.subplots(
        1,
        3,
        figsize=(7, 2.5),
        sharey=True,
        layout="constrained",
    )

    ax1.set_ylim(0, 1)
    ax1.set_ylabel("Error")

    # A single well-defined minimum.
    ax1.set(title="Identifiable")
    ax1.plot(x, (x - 0.6) ** 2 + 0.3)

    # Constant error: every parameter value fits equally well.
    ax2.set(title="Structurally \nunidentifiable")
    ax2.plot(x, np.full_like(x, fill_value=0.3))

    # Shallow valley drawn as one cubic Bezier from (0, 0.8) to (1.0, 0.35)
    # with control points (0.4, 0.05) and (0.85, 0.36). Matplotlib expects the
    # CURVE4 code repeated for both control points and the end point.
    ax3.set(title="Practically \nunidentifiable")
    valley = mpath.Path(
        [
            (0, 0.8),
            (0.4, 0.05),
            (0.85, 0.36),
            (1.0, 0.35),
        ],
        [
            mpath.Path.MOVETO,
            mpath.Path.CURVE4,
            mpath.Path.CURVE4,
            mpath.Path.CURVE4,
        ],
    )
    ax3.add_patch(
        mpatches.PathPatch(
            valley,
            facecolor="none",
            edgecolor=C_ODE,
            linewidth=2,
            # Use ax3's own data transform, not a stray `ax` left over from an
            # earlier cell (that would place the curve using another figure).
            transform=ax3.transData,
        )
    )

    for ax in (ax1, ax2, ax3):
        ax.spines[["right", "top"]].set_visible(False)
    savefig("identifiability")
    plt.show()

## Stability

In [None]:
from example_models import get_phase_plane
from mxlpy import Simulator, unwrap

# Initial conditions for the overlaid trajectories: a 4 x 4 grid with
# s1 in [0, 1] and s2 in [0, 2].
initial_conditions = [
    {"s1": s1_start, "s2": s2_start}
    for s1_start in np.linspace(0, 1, 4)
    for s2_start in np.linspace(0, 2, 4)
]

with myxkcd():
    # Background from plot.trajectories_2d over a 10 x 10 grid of s1/s2 values.
    fig, ax = plot.trajectories_2d(
        get_phase_plane(),
        x1=("s1", np.linspace(0, 2, 10)),
        x2=("s2", np.linspace(0, 2, 10)),
    )

    # A fresh model instance per simulation, run to t = 1.5.
    for y0 in initial_conditions:
        result = Simulator(get_phase_plane(), y0=y0).simulate(1.5).get_result()
        c = unwrap(result).variables
        ax.plot(c["s1"], c["s2"], linewidth=3)

    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.spines[["right", "top"]].set_visible(False)
    savefig("phase-plane")
    plt.show()