In [None]:
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np

from example_models import get_linear_chain_2v
from mxlpy import Simulator, fit, make_protocol, plot

## Fitting

Almost every model at some point needs to be fitted to experimental data to be **validated**.  

<img src="assets/fitting.png" style="max-height: 175px;" />

*mxlpy* offers highly customisable local and global routines for fitting either **time series** or **steady-states**.  

The full set of currently supported routines is:

Single model, single data routines

- `steady_state`
- `time_course`
- `protocol_time_course`

Multiple model, single data routines

- `ensemble_steady_state`
- `ensemble_time_course`
- `ensemble_protocol_time_course`

A carousel is a special case of an ensemble, where the general
structure (e.g. stoichiometries) is the same, while the reaction kinetics
can vary.

- `carousel_steady_state`
- `carousel_time_course`
- `carousel_protocol_time_course`

Multiple model, multiple data

- `joint_steady_state`
- `joint_time_course`
- `joint_protocol_time_course`

Multiple model, multiple data, multiple methods

Here, a different method (e.g. steady-state vs. time course) can be run
for each model:data combination.

- `joint_mixed`

Minimizers
----------
- LocalScipyMinimizer, including common methods such as Nelder-Mead or L-BFGS-B
- GlobalScipyMinimizer, including common methods such as basin hopping or dual annealing
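
The practical difference between local and global strategies can be illustrated without mxlpy or SciPy. The sketch below is a toy, standard-library-only example: `gradient_descent` stands in for a local method, and a naive multi-start loop stands in for a global strategy. Multi-start is not basin hopping or dual annealing; it is used here only because it is short. The objective function and all names are hypothetical.

```python
def objective(x: float) -> float:
    """Toy objective with a local minimum at x > 0 and the global one at x < 0."""
    return x**4 - 3 * x**2 + x


def gradient(x: float) -> float:
    """Analytic derivative of the toy objective."""
    return 4 * x**3 - 6 * x + 1


def gradient_descent(x0: float, lr: float = 0.01, n_steps: int = 5000) -> float:
    """Local search: follow the negative gradient from a single start point."""
    x = x0
    for _ in range(n_steps):
        x -= lr * gradient(x)
    return x


# A local method started at x0 = 2 ends in the nearest basin
x_local = gradient_descent(2.0)

# Naive multi-start: run the local method from several starts, keep the best
starts = [-2.0, -1.0, 0.0, 1.0, 2.0]
x_global = min((gradient_descent(x0) for x0 in starts), key=objective)

print(f"local:  x = {x_local:.3f}, f(x) = {objective(x_local):.3f}")
print(f"global: x = {x_global:.3f}, f(x) = {objective(x_global):.3f}")
```

A local minimizer is usually cheaper, so it tends to be a reasonable first choice when the initial guess is already close to the expected parameter values.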

For this tutorial we are going to use the `fit` module to optimise our parameter values and the `plot` module to plot some results.  

Let's get started!

## Creating synthetic data

Normally, you would fit your model to experimental data.  
Here, for the sake of simplicity, we will generate some synthetic data.  

Check out the [basics tutorial](basics.ipynb) if you need a refresher on building and simulating models.  

In [None]:
# As a small trick, let's define a variable for the model function
# That way, we can re-use it all over the file and easily replace
# it with another model
model_fn = get_linear_chain_2v

res = (
    Simulator(model_fn())
    .update_parameters({"k1": 1.0, "k2": 2.0, "k3": 1.0})
    .simulate_time_course(np.linspace(0, 10, 101))
    .get_result()
    .unwrap_or_err()
).get_combined()

fig, ax = plot.lines(res)
ax.set(xlabel="time / a.u.", ylabel="Conc. & Flux / a.u.")
plt.show()

### Steady-states

For the steady-state fit we need two inputs:

1. the steady state data, which we supply as a `pandas.Series`
2. an initial parameter guess

The fitting routine will compare all data contained in that series to the model output.  

> Note that the data contains both concentrations and fluxes!  

In [None]:
data = res.iloc[-1]
data.head()

In [None]:
fit_result = fit.steady_state(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res.iloc[-1],
    minimizer=fit.LocalScipyMinimizer(),
).unwrap_or_err()

fit_result.best_pars

If only some of the data is required, you can use a subset of it.  
The fitting routine will only try to fit concentrations and fluxes contained in that series.
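
To make the effect of subsetting concrete, here is a minimal standalone sketch in which the comparison only runs over the names present in the data. It uses plain dictionaries instead of a `pandas.Series`, and the function name and all numbers are hypothetical; it is not how mxlpy implements the comparison.

```python
from __future__ import annotations


def squared_error_on_subset(prediction: dict[str, float], data: dict[str, float]) -> float:
    """Mean squared error over the names present in `data` only."""
    return sum((prediction[name] - value) ** 2 for name, value in data.items()) / len(data)


# Hypothetical model output: two concentrations and three fluxes
prediction = {"x": 1.0, "y": 2.0, "v1": 1.0, "v2": 2.0, "v3": 2.0}

# Hypothetical measurements for all five quantities
full_data = {"x": 1.0, "y": 2.5, "v1": 1.0, "v2": 2.0, "v3": 3.0}

# Keeping only the concentrations means the fluxes no longer affect the loss
subset = {name: full_data[name] for name in ("x", "y")}

loss_full = squared_error_on_subset(prediction, full_data)
loss_subset = squared_error_on_subset(prediction, subset)
print(loss_full, loss_subset)
```

In this toy case the mismatch in `v3` contributes to `loss_full` but not to `loss_subset`, so a fit on the subset is free to ignore it.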

In [None]:
fit_result = fit.steady_state(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=data.loc[["x", "y"]],
    minimizer=fit.LocalScipyMinimizer(),
).unwrap_or_err()
fit_result.best_pars

> By default, mxlpy will apply standard scaling to all fitting functions.  

Specifically, it will calculate `loss_fn((data - data.mean()) / data.std(), (pred - data.mean()) / data.std())`.

To turn off this behaviour, set `standard_scale=False` in the fit functions.
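
As a rough illustration of standard scaling, the sketch below centres and scales both the data and the prediction with the mean and standard deviation of the *data*, then passes both to a loss function. It uses only the standard library; `standard_scale` and `mean_squared_error` are hypothetical helpers, and the numbers are made up. `statistics.stdev` is the sample standard deviation, which matches the pandas default (`ddof=1`).

```python
from statistics import mean, stdev


def standard_scale(values: list[float], reference: list[float]) -> list[float]:
    """Centre and scale `values` using the mean and sample std of `reference`."""
    mu = mean(reference)
    sigma = stdev(reference)
    return [(v - mu) / sigma for v in values]


def mean_squared_error(a: list[float], b: list[float]) -> float:
    """Plain mean squared error between two equally long lists."""
    return mean((x - y) ** 2 for x, y in zip(a, b))


data = [1.0, 2.0, 3.0]  # hypothetical measurements
pred = [1.5, 2.0, 2.5]  # hypothetical model output

# Both inputs are scaled with the statistics of the data, not their own
scaled_data = standard_scale(data, data)
scaled_pred = standard_scale(pred, data)

loss = mean_squared_error(scaled_data, scaled_pred)
print(scaled_data, scaled_pred, loss)
```

Scaling of this kind keeps quantities with large absolute values from dominating the loss, which matters when concentrations and fluxes differ by orders of magnitude.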

In [None]:
fit_result = fit.steady_state(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=data.loc[["x", "y"]],
    minimizer=fit.LocalScipyMinimizer(),
    standard_scale=False,  # opt-out of standard scaling
).unwrap_or_err()
fit_result.best_pars

### Time course

For the time course fit we again need two inputs:

1. the time course data, which we supply as a `pandas.DataFrame`
2. an initial parameter guess

The fitting routine will create data at every time point specified in the `DataFrame` and compare all of them.  

Other than that, the same rules as for steady-state fitting apply.  

In [None]:
fit_result = fit.time_course(
    model_fn(),
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    data=res,
    minimizer=fit.LocalScipyMinimizer(),
).unwrap_or_err()

fit_result.best_pars

## Protocol time courses


Normally, you would fit your model to experimental data.  
Here, again, for the sake of simplicity, we will generate some synthetic data.  

In [None]:
protocol = make_protocol(
    [
        (1, {"k1": 1.0}),
        (1, {"k1": 2.0}),
        (1, {"k1": 1.0}),
    ]
)

res_protocol = (
    Simulator(model_fn())
    .update_parameters({"k1": 1.0, "k2": 2.0, "k3": 1.0})
    .simulate_protocol(
        protocol,
        time_points_per_step=10,
    )
    .get_result()
    .unwrap_or_err()
).get_combined()

fig, ax = plot.lines(res_protocol)
ax.set(xlabel="time / a.u.", ylabel="Conc. & Flux / a.u.")
plt.show()

For the protocol time course fit we need three inputs:

1. an initial parameter guess
2. the time course data, which we supply as a `pandas.DataFrame`
3. the protocol, which we supply as a `pandas.DataFrame`

> Note that the parameter given by the protocol cannot be fitted anymore  
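
One way to picture why a protocol parameter cannot be fitted is to treat the protocol as a piecewise-constant schedule that overrides the candidate parameters at every time point. The sketch below is a plain-Python picture of that idea. The list-of-tuples layout, the `parameters_at` helper, and the rule that a boundary time belongs to the earlier step are all assumptions made for illustration; mxlpy itself represents the protocol as a `pandas.DataFrame`.

```python
from __future__ import annotations

from itertools import accumulate

# (duration, parameter overrides) for each protocol step
steps = [
    (1.0, {"k1": 1.0}),
    (1.0, {"k1": 2.0}),
    (1.0, {"k1": 1.0}),
]

# Cumulative end time of each step: [1.0, 2.0, 3.0]
end_times = list(accumulate(duration for duration, _ in steps))


def parameters_at(t: float, candidate: dict[str, float]) -> dict[str, float]:
    """Return the parameters in effect at time t; protocol values take priority."""
    for t_end, (_, overrides) in zip(end_times, steps):
        if t <= t_end:
            return {**candidate, **overrides}
    msg = f"t = {t} lies outside the protocol"
    raise ValueError(msg)


# Even if a minimizer proposed a value for k1, the protocol would replace it
candidate = {"k1": 99.0, "k2": 1.87, "k3": 1.093}
print(parameters_at(0.5, candidate))
print(parameters_at(1.5, candidate))
```

Because the value of `k1` is replaced in every step, the loss cannot depend on the candidate `k1`, which is why it is left out of `p0`.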


In [None]:
fit_result = fit.protocol_time_course(
    model_fn(),
    p0={"k2": 1.87, "k3": 1.093},  # note that k1 is given by the protocol
    data=res_protocol,
    protocol=protocol,
    minimizer=fit.LocalScipyMinimizer(),
).unwrap_or_err()

fit_result.best_pars

## Ensemble fitting

`mxlpy` supports ensemble fitting, which is a **multi-model single data** approach, where shared parameters will be applied to all models at the same time.

Here you supply an iterable of models instead of just one, otherwise the API stays the same.  

In [None]:
ensemble_fit = fit.ensemble_steady_state(
    [
        model_fn(),
        model_fn(),
    ],
    data=res.iloc[-1],
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

To get the best-fitting model, you can use `get_best_fit` on the ensemble fit.

In [None]:
fit_result = ensemble_fit.get_best_fit()

You can of course also access all other fits.

In [None]:
[i.loss for i in ensemble_fit.fits]
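
If "best" is taken to mean "lowest loss" (an assumption; the selection rule is not spelled out here), comparing the fits of an ensemble reduces to a minimum or a sort over the losses. The sketch below shows this with a hypothetical `ToyFit` record and made-up loss values rather than real mxlpy result objects.

```python
from dataclasses import dataclass


@dataclass
class ToyFit:
    """Hypothetical stand-in for a single fit result."""

    name: str
    loss: float


# Made-up losses for three candidate models
fits = [ToyFit("model_a", 0.12), ToyFit("model_b", 0.03), ToyFit("model_c", 0.40)]

# Lowest loss wins
best = min(fits, key=lambda f: f.loss)

# Full ranking, useful for checking how close the runners-up are
ranked = sorted(fits, key=lambda f: f.loss)

print(best.name, [f.name for f in ranked])
```

Looking at the full ranking rather than only the winner helps to judge whether the data can actually discriminate between the candidate models.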

### Time course

Time course fits are adjusted just the same

In [None]:
ensemble_fit = fit.ensemble_time_course(
    [
        model_fn(),
        model_fn(),
    ],
    data=res,
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

### Protocol time course

As are protocol time courses

In [None]:
ensemble_fit = fit.ensemble_protocol_time_course(
    [
        model_fn(),
        model_fn(),
    ],
    data=res_protocol,
    protocol=protocol,
    p0={"k2": 1.87, "k3": 1.093},  # note that k1 is given by the protocol
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

## Joint fitting

Next, we support joint fitting, a combined **multi-model multi-data** approach in which shared parameters are applied to all models at the same time.
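
Conceptually, a joint fit evaluates every model:data pair with the same candidate parameters and combines the individual errors into a single number for the minimizer. The sketch below illustrates this with two toy algebraic models, a shared parameter `k`, and a simple grid scan. Summing the per-pair errors is an assumption made for illustration, and all names and data are hypothetical.

```python
def linear(k: float, x: float) -> float:
    """Toy model 1: y = k * x."""
    return k * x


def quadratic(k: float, x: float) -> float:
    """Toy model 2: y = k * x**2."""
    return k * x * x


# Each entry pairs one model with its own (x, y) data set
settings = [
    (linear, [(1.0, 2.0), (2.0, 4.0)]),
    (quadratic, [(1.0, 2.0), (2.0, 8.0)]),
]


def joint_loss(k: float) -> float:
    """Sum of squared errors over all model:data pairs for one shared k."""
    return sum((model(k, x) - y) ** 2 for model, points in settings for x, y in points)


# Crude grid scan over k in [1.00, 3.00] as a stand-in for a real minimizer
candidates = [i / 100 for i in range(100, 301)]
best_k = min(candidates, key=joint_loss)

print(best_k, joint_loss(best_k))
```

Both toy data sets were generated with `k = 2`, so the scan recovers that value with a joint loss of zero.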

In [None]:
fit.joint_steady_state(
    [
        fit.FitSettings(model=model_fn(), data=res.iloc[-1]),
        fit.FitSettings(model=model_fn(), data=res.iloc[-1]),
    ],
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

In [None]:
fit.joint_time_course(
    [
        fit.FitSettings(model=model_fn(), data=res),
        fit.FitSettings(model=model_fn(), data=res),
    ],
    p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

In [None]:
fit.joint_protocol_time_course(
    [
        fit.FitSettings(model=model_fn(), data=res_protocol, protocol=protocol),
        fit.FitSettings(model=model_fn(), data=res_protocol, protocol=protocol),
    ],
    p0={"k2": 1.87, "k3": 1.093},
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

## Mixed joint fitting

Lastly, we support mixed-joint fitting, where each analysis takes its own residual function to allow fitting both time series and steady-state data for multiple models at the same time.
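
The benefit of mixing residual functions can be seen in a toy example. The sketch below uses the analytic solution of `dx/dt = k_in - k_out * x` with `x(0) = 0`, one residual for a steady-state measurement, and one for a time course. The two `toy_*` functions are hypothetical and unrelated to mxlpy's own residual functions, and summing the residuals is an assumption made for illustration.

```python
from __future__ import annotations

import math


def toy_steady_state_residual(pars: dict[str, float], data: float) -> float:
    """Squared error between the analytic steady state and one measured value."""
    return (pars["k_in"] / pars["k_out"] - data) ** 2


def toy_time_course_residual(pars: dict[str, float], data: list[tuple[float, float]]) -> float:
    """Mean squared error between the analytic time course and (t, x) samples."""
    x_ss = pars["k_in"] / pars["k_out"]
    errors = [(x_ss * (1 - math.exp(-pars["k_out"] * t)) - x) ** 2 for t, x in data]
    return sum(errors) / len(errors)


# Synthetic data generated with k_in = 2.0 and k_out = 0.5
steady_state_data = 4.0
time_course_data = [(t, 4.0 * (1 - math.exp(-0.5 * t))) for t in (0.0, 1.0, 2.0, 4.0)]

# Each analysis brings its own residual function and its own data
analyses = [
    (toy_steady_state_residual, steady_state_data),
    (toy_time_course_residual, time_course_data),
]


def mixed_loss(pars: dict[str, float]) -> float:
    """Combine all analyses for one shared parameter set."""
    return sum(residual_fn(pars, data) for residual_fn, data in analyses)


true_pars = {"k_in": 2.0, "k_out": 0.5}
same_ratio = {"k_in": 4.0, "k_out": 1.0}  # same steady state, faster dynamics

print(mixed_loss(true_pars), mixed_loss(same_ratio))
```

The steady-state residual alone cannot tell `true_pars` and `same_ratio` apart because only the ratio `k_in / k_out` enters it; adding the time course removes that ambiguity.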

In [None]:
fit.joint_mixed(
    [
        fit.MixedSettings(
            model=model_fn(),
            data=res.iloc[-1],
            residual_fn=fit.steady_state_residual,
        ),
        fit.MixedSettings(
            model=model_fn(),
            data=res,
            residual_fn=fit.time_course_residual,
        ),
        fit.MixedSettings(
            model=model_fn(),
            data=res_protocol,
            protocol=protocol,
            residual_fn=fit.protocol_time_course_residual,
        ),
    ],
    p0={"k2": 1.87, "k3": 1.093},
    minimizer=fit.LocalScipyMinimizer(tol=1e-6),
)

<div style="color: #ffffff; background-color: #04AA6D; padding: 3rem 1rem 3rem 1rem; box-sizing: border-box">
    <h2>First finish line</h2>
    With that you now know most of what you will need on a day-to-day basis about fitting in mxlpy.
    <br />
    <br />
    Congratulations!
</div>

## Advanced topics / customisation


Internally, all fitting routines are built so that they call a tree of functions.

- `minimizer`
  - `residual_fn`
    - `integrator`
    - `loss_fn`
  

You can therefore use dependency injection to overwrite the minimisation function, the loss function, the residual function and the integrator if need be.  
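
The call tree and the idea of swapping its parts can be sketched in a few lines of plain Python. Everything below is a hypothetical toy (an explicit Euler integrator for `dx/dt = -k * x`, two loss functions, and a golden-section search as the minimizer); it mirrors the structure described above but is not mxlpy's implementation.

```python
from __future__ import annotations

from functools import partial
from typing import Callable


def euler_decay(k: float, time_points: list[float], *, x0: float = 1.0, substeps: int = 200) -> list[float]:
    """Toy integrator: explicit Euler for dx/dt = -k * x, sampled at time_points."""
    x, t, out = x0, 0.0, []
    for t_end in time_points:
        dt = (t_end - t) / substeps
        for _ in range(substeps):
            x += dt * (-k * x)
        t = t_end
        out.append(x)
    return out


def mean_squared_error(pred: list[float], data: list[float]) -> float:
    return sum((p - d) ** 2 for p, d in zip(pred, data)) / len(data)


def mean_absolute_error(pred: list[float], data: list[float]) -> float:
    return sum(abs(p - d) for p, d in zip(pred, data)) / len(data)


def residual(k: float, *, data: list[float], time_points: list[float], integrator: Callable, loss_fn: Callable) -> float:
    """Residual function: integrate, then score the prediction against the data."""
    return loss_fn(integrator(k, time_points), data)


def golden_section(fn: Callable[[float], float], lower: float, upper: float, tol: float = 1e-8) -> float:
    """Toy minimizer: golden-section search for a unimodal function on [lower, upper]."""
    inv_phi = (5**0.5 - 1) / 2
    a, b = lower, upper
    while b - a > tol:
        c = b - inv_phi * (b - a)
        d = a + inv_phi * (b - a)
        if fn(c) < fn(d):
            b = d  # minimum lies in [a, d]
        else:
            a = c  # minimum lies in [c, b]
    return (a + b) / 2


time_points = [0.0, 0.5, 1.0, 1.5, 2.0]
data = euler_decay(2.0, time_points)  # synthetic data generated with k = 2


def fit_k(*, integrator: Callable = euler_decay, loss_fn: Callable = mean_squared_error, minimizer: Callable = golden_section) -> float:
    """Assemble minimizer -> residual -> (integrator, loss_fn) from injected parts."""
    target = partial(residual, data=data, time_points=time_points, integrator=integrator, loss_fn=loss_fn)
    return minimizer(target, 0.1, 5.0)


k_default = fit_k()
k_mae = fit_k(loss_fn=mean_absolute_error)  # swap the loss function
k_coarse = fit_k(integrator=partial(euler_decay, substeps=20))  # swap the integrator
print(k_default, k_mae, k_coarse)
```

Swapping the loss leaves the estimate essentially unchanged here, while the coarser integrator shifts it slightly because the fitted `k` compensates for the larger integration error.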

In [None]:
from functools import partial
from typing import TYPE_CHECKING, cast

from mxlpy.integrators import Scipy

if TYPE_CHECKING:
    import pandas as pd

### Parameterising scipy optimise

In [None]:
optimizer = fit.LocalScipyMinimizer(tol=1e-6, method="Nelder-Mead")

### Custom loss function

You can change the loss function that is passed to the minimisation function using the `loss_fn` keyword.  
Depending on the use case (time course vs steady state), this function will be passed two pandas `DataFrame`s or two `Series`.

In [None]:
def mean_absolute_error(
    x: pd.DataFrame | pd.Series,
    y: pd.DataFrame | pd.Series,
) -> float:
    """Mean absolute error between two dataframes."""
    return cast(float, np.mean(np.abs(x - y)))


(
    fit.time_course(
        model_fn(),
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        data=res,
        loss_fn=mean_absolute_error,
        minimizer=fit.LocalScipyMinimizer(),
    )
    .unwrap_or_err()
    .best_pars
)

### Custom integrator

You can replace the default integrator with one of your choice by partially applying the class of any of the existing ones.  

Here, for example, we choose the `Scipy` solver suite and set both the relative and absolute tolerances to `1e-6`.

In [None]:
(
    fit.time_course(
        model_fn(),
        p0={"k1": 1.038, "k2": 1.87, "k3": 1.093},
        data=res,
        integrator=partial(Scipy, rtol=1e-6, atol=1e-6),
        minimizer=fit.LocalScipyMinimizer(),
    )
    .unwrap_or_err()
    .best_pars
)