In [None]:
from __future__ import annotations

from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas as pd
import seaborn as sns
from numpy.polynomial.polynomial import Polynomial

from example_models.linear_chain import get_linear_chain_2v
from mxlpy import Model, Simulator, fns, npe, plot, scan, surrogates
from mxlpy.distributions import LogNormal, Normal, sample
from mxlpy.types import AbstractSurrogate, unwrap

# Mechanistic Learning

Mechanistic learning is the intersection of mechanistic modelling and machine learning.  
*mxlpy* currently supports two such approaches: surrogates and neural posterior estimation.  

In the following we will mostly use the `mxlpy.surrogates` and `mxlpy.npe` modules to learn about both approaches.  

## Surrogate models


**Surrogate models** replace whole parts of a mechanistic model (or even the entire model) with machine learning models.  

<img src="assets/surrogate.png" style="max-height: 300px;">

This allows combining together multiple models of arbitrary size, without having to worry about the internal state of each model.  
They are especially useful for improving the description of *boundary effects*, e.g. a dynamic description of downstream consumption. 

## Manual construction

Surrogates can return two kinds of values in `mxlpy`: `derived quantities` and `reactions`.  

We will start by defining a polynomial surrogate that will get the value of a variable `x` and output the derived quantity `y`.  
Note that due to their nature surrogates can take multiple inputs and return multiple outputs, so we will always use iterables when defining them.  

We then also add a derived value `z` that uses the output of our surrogate to see that we are getting the correct output.   

In [None]:
m = Model()
m.add_variable("x", 1.0)
# The surrogate maps the variable `x` to the derived quantity `y`
m.add_surrogate(
    "surrogate",
    surrogates.poly.Surrogate(
        # coef=[2] is a degree-0 polynomial, i.e. the constant 2 for any input
        model=Polynomial(coef=[2]),
        args=["x"],
        outputs=["y"],
    ),
)
# `z` consumes the surrogate output `y` like any other model value
m.add_derived("z", fns.add, args=["x", "y"])

# Check output: `y` should be 2 and, assuming `fns.add` sums its arguments,
# `z` should be 3
m.get_args()

Next we extend that idea to create a reaction.  
The only thing we need to change here is to also add the `stoichiometries` of the respective output variable.  

I've renamed the output to `v1` here to fit convention, but that is not technically necessary.  
`mxlpy` always infers from the structure of the surrogate what kind of value each output becomes: outputs that are given `stoichiometries` become reactions, all others become derived quantities.  


In [None]:
m = Model()
m.add_variable("x", 1.0)
m.add_surrogate(
    "surrogate",
    surrogates.poly.Surrogate(
        model=Polynomial(coef=[2]),
        args=["x"],
        outputs=["v1"],
        # Giving `v1` stoichiometries is the only change compared to the
        # derived-quantity surrogate
        stoichiometries={"v1": {"x": -1}},
    ),
)
m.add_derived("z", fns.add, args=["x", "v1"])

# Check output: the constant rate v1 = 2 has stoichiometry -1 on `x`,
# so dx/dt is presumably -2 — verify against the displayed result
m.get_right_hand_side()

Note that if you have **multiple outputs**, it is perfectly fine for them to mix between derived values and reactions.  

```python
Surrogate(
    model=...,
    args=["x", "y"],
    outputs=["d1", "v1"],               # outputs derived value d1 and rate v1
    stoichiometries={"v1": {"x": -1}},  # only rate v1 is given stoichiometries
)
```

## Training a surrogate from data and using it

We will start with a simple linear chain model

$$ \Large \varnothing \xrightarrow{v_1} x \xrightarrow{v_2} y \xrightarrow{v_3} \varnothing $$

where we want to read out the steady-state rate of $v_3$ dependent on the fixed concentration of $x$, while ignoring the inner state of the model.  


$$ \Large  x \xrightarrow{} ... \xrightarrow{v_3}$$

Since we need to fix a `variable` as a `parameter`, we can use the `make_variable_static` method to do that.

In [None]:
# Now "x" is a parameter
get_linear_chain_2v().make_variable_static("x").parameters

And we can already create a function to create a model, which will take our surrogate as an input.  

In [None]:
def get_model_with_surrogate(surrogate: AbstractSurrogate) -> Model:
    """Build a two-variable model whose only flux is supplied by `surrogate`.

    The model holds the variables `x` (initial value 1.0) and `z` (initial
    value 0.0). The surrogate is registered under the name "surrogate", reads
    `x` and exposes the single output `v2` with stoichiometry x: -1, z: +1.

    No other reaction is defined here, although more could be added.
    """
    surrogate_model = Model()
    surrogate_model.add_variables({"x": 1.0, "z": 0.0})

    # The surrogate output v2 converts x into z
    v2_stoichiometry = {"x": -1, "z": 1}
    surrogate_model.add_surrogate(
        "surrogate",
        surrogate,
        args=["x"],
        outputs=["v2"],
        stoichiometries={"v2": v2_stoichiometry},
    )
    return surrogate_model

### Create data

The surrogates used in the following will all use the **steady-state** fluxes depending on the inputs.  

We can thus create the necessary training data using `scan.steady_state`.  
Since this is usually a large amount of data, we recommend caching the results using `Cache`. 

In [None]:
# Fixed values of x to scan over, log-spaced between 1e-12 and 2.0
surrogate_features = pd.DataFrame({"x": np.geomspace(1e-12, 2.0, 21)})

# Steady-state flux v3 for every fixed value of x
surrogate_targets = scan.steady_state(
    get_linear_chain_2v().make_variable_static("x"),
    to_scan=surrogate_features,
).fluxes.loc[:, ["v3"]]

# It's always a good idea to check the inputs and outputs
fig, (ax1, ax2) = plot.two_axes(figsize=(6, 3), sharex=False)
# The feature is the concentration of x; only the target is a flux
_ = plot.violins(surrogate_features, ax=ax1)[1].set(
    title="Features", ylabel="Concentration / a.u."
)
_ = plot.violins(surrogate_targets, ax=ax2)[1].set(
    title="Targets", ylabel="Flux / a.u."
)
plt.show()

### Polynomial surrogate

We can train our polynomial surrogate using `surrogates.poly.train`.  
By default this will train polynomials for the degrees `(1, 2, 3, 4, 5, 6, 7)`, but you can change that by using the `degrees` argument.  
The function returns the trained surrogate and the training information for the different polynomial degrees.  

> **Currently the polynomial surrogates are limited to a single feature and a single target**


In [None]:
surrogate, info = surrogates.poly.train(
    surrogate_features["x"],
    surrogate_targets["v3"],
)

print("Model", surrogate.model, end="\n\n")
print(info["score"])

You can then insert the surrogate into the model using the function we defined earlier


In [None]:
concs, fluxes = unwrap(
    Simulator(get_model_with_surrogate(surrogate)).simulate(10).get_result()
)

fig, (ax1, ax2) = plot.two_axes(figsize=(8, 3))
plot.lines(concs, ax=ax1)
plot.lines(fluxes, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="flux / a.u.")
plt.show()

While polynomial regression can model nonlinear relationships between variables, it often struggles when the underlying relationship is more complex than a polynomial function.  
You will learn about using neural networks in the next section.  

### Neural network surrogate using PyTorch

Neural networks are designed to capture highly complex and nonlinear relationships.  
Through layers of neurons and activation functions, neural networks can learn intricate patterns that are not easily represented by e.g. a polynomial.  
They have the flexibility to approximate any continuous function, given sufficient depth and appropriate training.  

You can train a neural network surrogate based on the popular [PyTorch](https://pytorch.org/) library using `surrogates.torch.train`.  
That function takes the `features`, `targets` and the number of `epochs` as inputs for its training.  

`surrogates.torch.train` returns the trained surrogate, as well as the training `loss`.  
It is always a good idea to check whether that training loss approaches 0.   

In [None]:
surrogate, loss = surrogates.torch.train(
    features=surrogate_features,
    targets=surrogate_targets,
    batch_size=100,
    epochs=250,
)

ax = loss.plot(ax=plt.subplots(figsize=(4, 2.5))[1])
ax.set_ylim(0, None)
plt.show()

As before, you can then insert the surrogate into the model using the function we defined earlier


In [None]:
concs, fluxes = unwrap(
    Simulator(get_model_with_surrogate(surrogate)).simulate(10).get_result()
)

fig, (ax1, ax2) = plot.two_axes(figsize=(8, 3))
plot.lines(concs, ax=ax1)
plot.lines(fluxes, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="flux / a.u.")
plt.show()

### Re-entrant training

Quite often you don't know the number of epochs you are going to need in order to reach the required loss.  
In this case, you can directly use the `surrogates.torch.Trainer` class to continue training.  

In [None]:
trainer = surrogates.torch.Trainer(
    features=surrogate_features,
    targets=surrogate_targets,
)

# First training epochs
trainer.train(epochs=100)
trainer.get_loss().plot(figsize=(4, 2.5)).set_ylim(0, None)
plt.show()

# Decide to continue training: calling `train` again on the same trainer
# adds 150 further epochs (250 in total) instead of starting over
trainer.train(epochs=150)
trainer.get_loss().plot(figsize=(4, 2.5)).set_ylim(0, None)
plt.show()

# NOTE(review): the surrogate output is named "x" although the training
# target column is "v3" and "x" is already the feature name — confirm this
# is intended.
surrogate = trainer.get_surrogate(surrogate_outputs=["x"])

### Troubleshooting

It often can make sense to check specific predictions of the surrogate.  
For example, what does it predict when the inputs are all 0?  


In [None]:
print(surrogate.predict_raw(np.array([-0.1])))
print(surrogate.predict_raw(np.array([0.0])))
print(surrogate.predict_raw(np.array([0.1])))

## Using keras instead of torch

If you installed keras, you can use it with exactly the same interface as torch


In [None]:
surrogate, loss = surrogates.keras.train(
    features=surrogate_features,
    targets=surrogate_targets,
    batch_size=100,
    epochs=250,
)

ax = loss.plot(ax=plt.subplots(figsize=(4, 2.5))[1])
ax.set_ylim(0, None)
plt.show()

## Using equinox instead of torch

In [None]:
surrogate, loss = surrogates.equinox.train(
    features=surrogate_features,
    targets=surrogate_targets,
    batch_size=100,
    epochs=250,
    optimizer=optax.adamw(learning_rate=0.001),
)

ax = loss.plot(ax=plt.subplots(figsize=(4, 2.5))[1])
ax.set_ylim(0, None)
plt.show()

## Neural posterior estimation


**Neural posterior estimation** answers the question: **what parameters could have generated the data I measured?**  
Here you use an ODE model and prior knowledge about the parameters of interest to create *synthetic data*.  
You then use the generated synthetic data as the *features* and the input parameters as the *targets* to train an *inverse problem*.  
Once that training is successful, the neural network can now predict the input parameters for real world data.  

<img src="assets/npe.png" style="max-height: 175px;">

You can use this technique for both steady-state and time course data.  
For time course data, the only difference is that the synthetic data is generated with `scan.time_course` instead of `scan.steady_state`.  

> Take care here to save the targets as well in case you use cached data :)

In [None]:
# Note that now the parameters are the targets
# Note that now the parameters are the targets
npe_targets = sample(
    {
        "k1": LogNormal(mean=1.0, sigma=0.3),
    },
    n=1_000,
)

# And the generated data are the features
npe_features = (
    scan.steady_state(
        get_linear_chain_2v(),
        to_scan=npe_targets,
    )
    .get_args()
    .loc[:, ["y", "v2", "v3"]]
)

# It's always a good idea to check the inputs and outputs
fig, (ax1, ax2) = plot.two_axes(figsize=(6, 3), sharex=False)
# Features mix the concentration y with the fluxes v2 and v3,
# while the targets are values of the parameter k1
_ = plot.violins(npe_features, ax=ax1)[1].set(
    title="Features", ylabel="Concentration or flux / a.u."
)
_ = plot.violins(npe_targets, ax=ax2)[1].set(
    title="Targets", ylabel="Parameter value / a.u."
)
plt.show()

### Train NPE

You can then train your neural posterior estimator using `npe.torch.train_steady_state` (or the corresponding time course function if you have time course data).  


In [None]:
estimator, losses = npe.torch.train_steady_state(
    features=npe_features,
    targets=npe_targets,
    epochs=100,
    batch_size=100,
)

ax = losses.plot(figsize=(4, 2.5))
ax.set(xlabel="epoch", ylabel="loss")
ax.set_ylim(0, None)
plt.show()

### Sanity check: do prior and posterior match?

In [None]:
fig, (ax1, ax2) = plot.two_axes(figsize=(6, 2))

ax = sns.kdeplot(npe_targets, fill=True, ax=ax1)
ax.set_title("Prior")

posterior = estimator.predict(npe_features)
ax = sns.kdeplot(posterior, fill=True, ax=ax2)
ax.set_title("Posterior")
plt.show()

### Re-entrant training

As with the surrogates, you often don't know the number of epochs you are going to need in order to reach the required loss.  
For the neural posterior estimation you can use the `npe.torch.SteadyStateTrainer` and `npe.torch.TimeCourseTrainer` respectively to continue training.   

In [None]:
trainer = npe.torch.SteadyStateTrainer(
    features=npe_features,
    targets=npe_targets,
)

# Initial training
trainer.train(epochs=20, batch_size=100)
trainer.get_loss().plot(figsize=(4, 2.5)).set_ylim(0, None)
plt.show()

# Continue training on the same trainer for another 20 epochs
trainer.train(epochs=20, batch_size=100)
trainer.get_loss().plot(figsize=(4, 2.5)).set_ylim(0, None)
plt.show()

# Get estimator if loss is deemed suitable
estimator = trainer.get_estimator()

<div style="color: #ffffff; background-color: #04AA6D; padding: 3rem 1rem 3rem 1rem; box-sizing: border-box">
    <h2>First finish line</h2>
    With that you now know most of what you will need on a day-to-day basis about mechanistic learning in mxlpy.
    <br />
    <br />
    Congratulations!
</div>

## Custom loss function

You can use a custom loss function by simply injecting a function that takes the predicted tensor `x` and the data `y` and produces another tensor.  

In [None]:
from typing import TYPE_CHECKING

import torch

from mxlpy import LinearLabelMapper, Simulator
from mxlpy.distributions import sample
from mxlpy.fns import michaelis_menten_1s
from mxlpy.parallel import parallelise

if TYPE_CHECKING:
    from mxlpy.types import AbstractEstimator

In [None]:
def mean_abs(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute difference between prediction `x` and data `y`."""
    residual = x - y
    return residual.abs().mean()


# The same `loss_fn` keyword is passed to all three trainer classes below.
# NOTE: `trainer` is re-bound by every construction, so only the last trainer
# survives this cell; the calls only illustrate where `loss_fn` goes.
trainer = surrogates.torch.Trainer(
    features=surrogate_features,
    targets=surrogate_targets,
    loss_fn=mean_abs,
)

trainer = npe.torch.SteadyStateTrainer(
    features=npe_features,
    targets=npe_targets,
    loss_fn=mean_abs,
)

# NOTE(review): `npe_features` was generated with `scan.steady_state`; a time
# course trainer presumably expects time course features — confirm that this
# construction is valid with steady-state data.
trainer = npe.torch.TimeCourseTrainer(
    features=npe_features,
    targets=npe_targets,
    loss_fn=mean_abs,
)

## Label NPE

In [None]:
# FIXME: todo
# Show how to change Adam settings or use other optimizers
# Show how to change the surrogate network

In [None]:
def get_closed_cycle() -> tuple[Model, dict[str, int], dict[str, list[int]]]:
    """Build the closed-cycle example model together with its label metadata.

    | Reaction       | Labelmap |
    | -------------- | -------- |
    | x1 ->[v1] x2   | [0, 1]   |
    | x2 ->[v2a] x3  | [0, 1]   |
    | x2 ->[v2b] x3  | [1, 0]   |
    | x3 ->[v3] x1   | [0, 1]   |

    Every reaction uses `michaelis_menten_1s` with its substrate, a vmax and a
    km parameter; `v2a` and `v2b` share `km_2`.

    Returns:
        The model, the `label_variables` mapping (2 for every variable) and
        the `label_maps` mapping (one list per reaction).
    """
    parameters = {
        "vmax_1": 1.0,
        "km_1": 0.5,
        "vmax_2a": 1.0,
        "vmax_2b": 1.0,
        "km_2": 0.5,
        "vmax_3": 1.0,
        "km_3": 0.5,
    }
    initial_values = {"x1": 1.0, "x2": 0.0, "x3": 0.0}
    # (reaction name, substrate, product, vmax parameter, km parameter)
    reaction_specs = (
        ("v1", "x1", "x2", "vmax_1", "km_1"),
        ("v2a", "x2", "x3", "vmax_2a", "km_2"),
        ("v2b", "x2", "x3", "vmax_2b", "km_2"),
        ("v3", "x3", "x1", "vmax_3", "km_3"),
    )

    model = Model().add_parameters(parameters).add_variables(initial_values)
    for rxn_name, substrate, product, vmax, km in reaction_specs:
        model = model.add_reaction(
            rxn_name,
            michaelis_menten_1s,
            stoichiometry={substrate: -1, product: 1},
            args=[substrate, vmax, km],
        )

    label_variables: dict[str, int] = dict.fromkeys(("x1", "x2", "x3"), 2)
    label_maps: dict[str, list[int]] = {
        "v1": [0, 1],
        "v2a": [0, 1],
        "v2b": [1, 0],
        "v3": [0, 1],
    }
    return model, label_variables, label_maps

In [None]:
def _worker(
    x: tuple[tuple[int, pd.Series], tuple[int, pd.Series]],
    mapper: LinearLabelMapper,
    time: float,
    initial_labels: dict[str, int | list[int]],
) -> pd.Series:
    """Simulate the label model of one steady state and return its final state.

    Args:
        x: Pair of `(index, row)` tuples as produced by `DataFrame.iterrows`,
            first for the steady-state concentrations, then for the fluxes.
            The row indices are ignored.
        mapper: Label mapper used to build the label model.
        time: End time passed to `Simulator.simulate`.
        initial_labels: Forwarded to `mapper.build_model`.

    Returns:
        The last row of the simulated variables.
    """
    concs_item, fluxes_item = x
    y_ss = concs_item[1]
    v_ss = fluxes_item[1]

    label_model = mapper.build_model(y_ss, v_ss, initial_labels=initial_labels)
    result = unwrap(Simulator(label_model).simulate(time).get_result())
    return result.variables.iloc[-1]


def get_label_distribution_at_time(
    model: Model,
    label_variables: dict[str, int],
    label_maps: dict[str, list[int]],
    time: float,
    initial_labels: dict[str, int | list[int]],
    ss_concs: pd.DataFrame,
    ss_fluxes: pd.DataFrame,
) -> pd.DataFrame:
    """Simulate the label model for every steady state and collect the final states.

    Row `i` of `ss_concs` is paired with row `i` of `ss_fluxes` (both frames
    must have the same number of rows). Each pair is handed to `_worker` via
    `parallelise`.

    Returns:
        A float frame with one row per steady state, indexed by the running
        position of the steady state (not by the index of `ss_concs`).
    """
    label_mapper = LinearLabelMapper(
        model,
        label_variables=label_variables,
        label_maps=label_maps,
    )
    simulate_one = partial(
        _worker,
        mapper=label_mapper,
        time=time,
        initial_labels=initial_labels,
    )

    # Pair concentration and flux rows and key every pair by its position
    paired_rows = zip(ss_concs.iterrows(), ss_fluxes.iterrows(), strict=True)
    jobs = list(enumerate(paired_rows))

    results = parallelise(simulate_one, inputs=jobs, cache=None)  # type: ignore
    # The dict maps position -> final state, so transpose to get one row per state
    return pd.DataFrame(dict(results), dtype=float).T


def inverse_parameter_elasticity(
    estimator: AbstractEstimator,
    datum: pd.Series,
    *,
    normalized: bool = True,
    displacement: float = 1e-4,
) -> pd.DataFrame:
    """Estimate how the predicted parameters respond to each feature of `datum`.

    Every feature is perturbed by the relative `displacement` in both
    directions and the central difference of the estimator's predictions is
    taken.

    Args:
        estimator: Object whose `predict` returns a frame; its first row is
            used as the prediction for the given datum.
        datum: Feature values of a single observation.
        normalized: If True, scale each coefficient by feature value divided
            by reference prediction, giving relative coefficients.
        displacement: Relative perturbation applied to each feature.

    Returns:
        Frame with one row per predicted quantity and one column per feature.

    Note:
        A feature whose value is exactly 0 is not handled specially: the
        relative step is then 0 and the coefficient is a division by zero.
    """
    ref = estimator.predict(datum).iloc[0, :]
    base = datum.to_dict()

    coefs = {}
    for name, value in datum.items():
        # Relative perturbation value * (1 ± displacement): the two points are
        # 2 * displacement * value apart, which is the denominator used below.
        up = estimator.predict(
            pd.Series(base | {name: value * (1 + displacement)})
        ).iloc[0, :]
        down = estimator.predict(
            pd.Series(base | {name: value * (1 - displacement)})
        ).iloc[0, :]
        coefs[name] = (up - down) / (2 * displacement * value)

    coefs = pd.DataFrame(coefs)
    if normalized:
        # Columns are features and rows are predictions: scale column-wise by
        # the feature value and row-wise by the reference prediction. Aligning
        # by label also works when there is more than one predicted quantity.
        coefs = coefs.mul(datum, axis="columns").div(ref, axis="index")

    return coefs

In [None]:
model, label_variables, label_maps = get_closed_cycle()

# Steady state of the unlabelled model with vmax_2b lowered to 0.5
# (vmax_2a keeps the value 1.0 it already has in `get_closed_cycle`)
ss_concs, ss_fluxes = unwrap(
    Simulator(model)
    .update_parameters({"vmax_2a": 1.0, "vmax_2b": 0.5})
    .simulate_to_steady_state()
    .get_result()
)
mapper = LinearLabelMapper(
    model,
    label_variables=label_variables,
    label_maps=label_maps,
)

# Build the label model from the last row of the steady-state results and
# simulate it until t = 5
_, axs = plot.relative_label_distribution(
    mapper,
    unwrap(
        Simulator(
            mapper.build_model(
                ss_concs.iloc[-1], ss_fluxes.iloc[-1], initial_labels={"x1": 0}
            )
        )
        .simulate(5)
        .get_result()
    ).variables,
    sharey=True,
    n_cols=3,
)

# Label only the first panel's y axis (sharey=True) and the middle panel's x axis
axs[0, 0].set_ylabel("Relative label distribution")
axs[0, 1].set_xlabel("Time / s")
plt.show()

In [None]:
# Prior over vmax_2b. Despite the `surrogate_` prefix these are the targets of
# the label NPE; the name is kept because the following cells refer to it.
surrogate_targets = sample(
    {
        "vmax_2b": Normal(0.5, 0.1),
    },
    n=1000,
).clip(lower=0)  # negative draws are set to 0

ax = sns.kdeplot(surrogate_targets, fill=True)
ax.set_title("Prior")
# End with plt.show() like the other plotting cells, so that no text repr
# of the last expression is displayed
plt.show()

In [None]:
ss_concs, ss_fluxes = scan.steady_state(
    model,
    to_scan=surrogate_targets,
)

In [None]:
# Check the scanned steady states: concentrations left, fluxes right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
_, ax = plot.violins(ss_concs, ax=ax1)
ax.set_ylabel("Concentration / a.u.")
_, ax = plot.violins(ss_fluxes, ax=ax2)
ax.set_ylabel("Flux / a.u.")
# End with plt.show() like the other plotting cells, so that no text repr
# of the last expression is displayed
plt.show()

In [None]:
# Label distribution at t = 5 for every scanned steady state. Despite the
# `surrogate_` prefix these are the features of the label NPE; the name is
# kept because the following cells refer to it.
surrogate_features = get_label_distribution_at_time(
    model=model,
    label_variables=label_variables,
    label_maps=label_maps,
    time=5,
    ss_concs=ss_concs,
    ss_fluxes=ss_fluxes,
    initial_labels={"x1": 0},
)
_, ax = plot.violins(surrogate_features)
ax.set_ylabel("Relative label distribution")
# End with plt.show() like the other plotting cells, so that no text repr
# of the last expression is displayed
plt.show()

In [None]:
# Train the estimator that maps label distributions back to vmax_2b
estimator, losses = npe.torch.train_steady_state(
    features=surrogate_features,
    targets=surrogate_targets,
    batch_size=100,
    epochs=250,
)

ax = losses.plot()
# Same axis labels as the loss plot of the first NPE example
ax.set(xlabel="epoch", ylabel="loss")
ax.set_ylim(0, None)
# End with plt.show() like the other plotting cells, so that no text repr
# of the last expression is displayed
plt.show()

In [None]:
fig, (ax1, ax2) = plt.subplots(
    1,
    2,
    figsize=(8, 3),
    layout="constrained",
    sharex=True,
    sharey=False,
)

ax = sns.kdeplot(surrogate_targets, fill=True, ax=ax1)
ax.set_title("Prior")

posterior = estimator.predict(surrogate_features)

ax = sns.kdeplot(posterior, fill=True, ax=ax2)
ax.set_title("Posterior")
ax2.set_ylim(*ax1.get_ylim())
plt.show()

### Inverse parameter sensitivity

In [None]:
_ = plot.heatmap(inverse_parameter_elasticity(estimator, surrogate_features.iloc[0]))

In [None]:
elasticities = pd.DataFrame(
    {
        k: inverse_parameter_elasticity(estimator, i).loc["vmax_2b"]
        for k, i in surrogate_features.iterrows()
    }
).T

_ = plot.violins(elasticities)