In [None]:
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import mxlpy as mb2
from example_models import (
    get_linear_chain_2v,
)
from mxlpy import make_protocol, plot, scan

# Parameter scans


Parameter scans let you systematically assess how the behaviour of your model depends on the value of one or more parameters.  
*mxlpy* has routines to scan over and visualise **time courses**, **protocol time courses**, and **steady states** for one or more parameters.  

<div>
    <img src="assets/time-course-by-parameter.png" 
         style="vertical-align:middle; max-height: 175px;" />
    <img src="assets/parameter-scan-2d.png" 
         style="vertical-align:middle; max-height: 175px;" />
</div>

For this, we import the `scan` and `plot` modules from `mxlpy`, which contain the respective routines.  


### Steady-state

The steady-state scan takes a `pandas.DataFrame` of parameters to be scanned as an input and returns the steady-states at the respective parameter values.  

The `DataFrame` can contain an arbitrary number of parameters, one per column, and should be in the following format:

|  n |   k1 |
|---:|-----:|
|  0 |  1   |
|  1 |  1.2 |
|  2 |  1.4 |
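
A minimal sketch of how such a frame could be built with plain `pandas` and `numpy`. Naming the index `n` is an assumption made here only to mirror the table; the scan routines may not require it.

```python
import numpy as np
import pandas as pd

# One column per scanned parameter, one row per parameter set.
to_scan = pd.DataFrame({"k1": np.linspace(1.0, 1.4, 3)})

# Assumption: the index label "n" is chosen only to match the table above.
to_scan.index.name = "n"

print(to_scan)
```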

As an example, we will use a linear chain with two variables, $S$ and $P$, connected by three reactions:

$$ \varnothing \xrightarrow{v_1} S \xrightarrow{v_2} P \xrightarrow{v_3} \varnothing$$
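
To build intuition for what a steady-state scan over `k1` should show, the chain can be solved by hand under an assumed set of rate laws. The rate laws of `get_linear_chain_2v` are not shown in this section, so the sketch below assumes a constant influx $v_1 = k_1$ and mass-action kinetics $v_2 = k_2 S$, $v_3 = k_3 P$. The function name is hypothetical.

```python
def linear_chain_steady_state(k1: float, k2: float, k3: float) -> tuple[float, float]:
    """Steady state of S and P, assuming v1 = k1, v2 = k2 * S, v3 = k3 * P."""
    s = k1 / k2  # from dS/dt = k1 - k2 * S = 0
    p = k1 / k3  # from dP/dt = k2 * S - k3 * P = 0
    return s, p


# Under these assumptions both steady states grow linearly with k1.
for k1 in (1.0, 2.0, 3.0):
    s, p = linear_chain_steady_state(k1, k2=2.0, k3=1.0)
    print(f"k1={k1}: S={s}, P={p}")
```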


In [None]:
res = scan.steady_state(
    get_linear_chain_2v(),
    to_scan=pd.DataFrame({"k1": np.linspace(1, 3, 11)}),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 3))
plot.lines(res.variables, ax=ax1)  # access concentrations by name
plot.lines(res.fluxes, ax=ax2)  # access fluxes by name

ax1.set(ylabel="Concentration / a.u.")
ax2.set(ylabel="Flux / a.u.")
plt.show()

All scans return a result object, which allows multiple access patterns for convenience. 

Namely, the concentrations and fluxes can be accessed by name, unpacked or combined into a single dataframe.

In [None]:
# Access by name
_ = res.variables
_ = res.fluxes

# scan can be unpacked
concs, fluxes = res

# combine concs and fluxes as single dataframe
_ = res.combined

#### Combinations

Often you want to scan over multiple parameters at the same time.  
The recommended way to do this is to use the `cartesian_product` function, which takes a `parameter_name: values` mapping and creates a `pandas.DataFrame` of their combinations from it (think nested for loop).  

If the parameters `DataFrame` contains more than one column, the `pandas.DataFrame` returned by the scan will have a `pandas.MultiIndex`.  
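
The "nested for loop" idea can be sketched with `itertools.product`. This is an illustration of the concept, not mxlpy's implementation; `cartesian_product_sketch` is a hypothetical name.

```python
from itertools import product

import pandas as pd


def cartesian_product_sketch(parameters: dict[str, list[float]]) -> pd.DataFrame:
    """Build every combination of the given values, one row per combination."""
    names = list(parameters)
    # product(...) iterates like nested for loops, last parameter varying fastest
    rows = list(product(*parameters.values()))
    return pd.DataFrame(rows, columns=names)


combos = cartesian_product_sketch({"k1": [1, 2], "k2": [3, 4]})
print(combos)
```

Two values for each of two parameters give four rows, one per combination.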

In [None]:
mb2.cartesian_product(
    {
        "k1": [1, 2],
        "k2": [3, 4],
    }
)

In [None]:
res = scan.steady_state(
    get_linear_chain_2v(),
    to_scan=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
        }
    ),
)

res.variables.head()

You can plot the results of a **single variable** of this scan using a heatmap.

In [None]:
plot.heatmap_from_2d_idx(res.variables, variable="x")
plt.show()

Or create heatmaps of all passed variables at once.  

In [None]:
plot.heatmaps_from_2d_idx(res.variables)
plt.show()

You can also combine more than two parameters; however, visualisation then becomes challenging.  
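
One way to cope is to fix all but two parameters and inspect the remaining two-dimensional slice. The sketch below uses synthetic data and plain `pandas`; it assumes the index levels are named after the parameters, which is not confirmed here for mxlpy results.

```python
import pandas as pd

# Synthetic stand-in for a scan result: one value per (k1, k2, k3) combination.
index = pd.MultiIndex.from_product(
    [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]],
    names=["k1", "k2", "k3"],  # assumed level names
)
values = pd.Series([k1 / k2 + k3 for k1, k2, k3 in index], index=index, name="x")

# Fix k3 at one value, which leaves a two-level (k1, k2) index.
at_k3 = values.xs(1.0, level="k3")

# Pivot k2 into the columns to obtain a k1-by-k2 grid.
grid = at_k3.unstack("k2")
print(grid)
```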

In [None]:
res = scan.steady_state(
    get_linear_chain_2v(),
    to_scan=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 3),
            "k2": np.linspace(1, 2, 4),
            "k3": np.linspace(1, 2, 4),
        }
    ),
)
res.variables.head()

### Time course

You can perform a time course for each of the parameter values, resulting in a **distribution of time courses**.    
The index now also contains the time, so even for one parameter a `pandas.MultiIndex` is used.
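
To get a feel for such an index, here is a small synthetic frame built with plain `pandas`. The level names `n` and `time` are assumptions for illustration, and the numbers are made up.

```python
import numpy as np
import pandas as pd

time_points = np.linspace(0, 1, 3)

# Outer level: parameter set; inner level: time.
index = pd.MultiIndex.from_product([[0, 1], time_points], names=["n", "time"])
frame = pd.DataFrame({"x": np.arange(len(index), dtype=float)}, index=index)

# All time points of parameter set n = 1
one_run = frame.xs(1, level="n")

# All parameter sets at the final time point
final = frame.xs(time_points[-1], level="time")

print(one_run)
print(final)
```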

In [None]:
tss = scan.time_course(
    get_linear_chain_2v(),
    to_scan=pd.DataFrame({"k1": np.linspace(1, 2, 11)}),
    time_points=np.linspace(0, 1, 11),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.variables, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)

ax1.set(xlabel="time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="Flux / a.u.")
plt.show()

Again, this works for an arbitrary number of parameters.

In [None]:
tss = scan.time_course(
    get_linear_chain_2v(),
    to_scan=mb2.cartesian_product(
        {
            "k1": np.linspace(1, 2, 11),
            "k2": np.linspace(1, 2, 4),
        }
    ),
    time_points=np.linspace(0, 1, 11),
)


fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tss.variables, ax=ax1)
plot.lines_mean_std_from_2d_idx(tss.fluxes, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="Flux / a.u.")
plt.show()

The scan object returned has a `pandas.MultiIndex` of `n x time`, where `n` is an index that references parameter combinations.  
You can access the referenced parameters using `.to_scan`.

In [None]:
tss.to_scan.head()

You can also easily access common aggregates like the mean (`mean`) and standard deviation (`std`) using `get_agg_per_time`.  
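
Conceptually, an aggregate per time point collapses the parameter level of the index and keeps the time level. The sketch below shows that idea with a plain `pandas` group-by on synthetic data; it is not necessarily how `get_agg_per_time` is implemented, and the level names are assumptions.

```python
import pandas as pd

# Three parameter sets, three time points each (synthetic numbers).
index = pd.MultiIndex.from_product(
    [[0, 1, 2], [0.0, 0.5, 1.0]], names=["n", "time"]
)
frame = pd.DataFrame(
    {"x": [0.0, 1.0, 2.0, 0.0, 2.0, 4.0, 0.0, 3.0, 6.0]},
    index=index,
)

# Aggregate over all parameter sets, separately for each time point.
mean_per_time = frame.groupby(level="time").mean()
std_per_time = frame.groupby(level="time").std()  # pandas default: sample std (ddof=1)

print(mean_per_time)
print(std_per_time)
```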

In [None]:
tss.get_agg_per_time("std").head()

### Protocol

The same can be done for protocols.  
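
A protocol is a sequence of steps, each with its own parameter values. The sketch below assumes that each step is a `(duration, parameters)` pair, a reading inferred from the example in this section, where the first tuple elements 1, 2 and 3 add up to the final time point 6. It uses plain `numpy` and a hypothetical helper `k1_at`, not mxlpy.

```python
import numpy as np

# Assumption: the first tuple element is the duration of the step.
steps = [(1, {"k1": 1}), (2, {"k1": 2}), (3, {"k1": 1})]

durations = np.array([duration for duration, _ in steps])
ends = np.cumsum(durations)  # end time of each step
starts = ends - durations  # start time of each step


def k1_at(t: float) -> float:
    """Value of k1 at time t under the assumed piecewise-constant schedule."""
    step = int(np.searchsorted(ends, t, side="left"))
    # Times beyond the last step keep the final value.
    return steps[min(step, len(steps) - 1)][1]["k1"]


for start, end, (_, pars) in zip(starts, ends, steps):
    print(f"t in [{start}, {end}]: {pars}")
```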

In [None]:
res = scan.protocol_time_course(
    get_linear_chain_2v(),
    to_scan=pd.DataFrame({"k2": np.linspace(1, 2, 11)}),
    time_points=np.linspace(0, 6, 21),
    protocol=make_protocol(
        [
            (1, {"k1": 1}),
            (2, {"k1": 2}),
            (3, {"k1": 1}),
        ]
    ),
)

fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(res.variables, ax=ax1)
plot.lines_mean_std_from_2d_idx(res.fluxes, ax=ax2)
ax1.set(xlabel="time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="time / a.u.", ylabel="Flux / a.u.")

for ax in (ax1, ax2):
    plot.shade_protocol(res.protocol["k1"], ax=ax, alpha=0.2)
plt.show()

<div style="color: #ffffff; background-color: #04AA6D; padding: 3rem 1rem 3rem 1rem; box-sizing: border-box">
    <h2>First finish line</h2>
    With that you now know most of what you will need on a day-to-day basis about parameter scans in mxlpy.
    <br />
    <br />
    Congratulations!
</div>

In [None]:
import pickle
from pathlib import Path
from typing import TYPE_CHECKING, Any

from mxlpy.parallel import Cache, parallelise

if TYPE_CHECKING:
    from collections.abc import Hashable

## Parallel execution

By default, all scans are executed in parallel.  
To do this, they internally use the `parallelise` function defined by `mxlpy`.  

> Tip: You can also use this function for other analyses as it is not specific to any `mxlpy` constructs.  

`parallelise` takes a function that accepts an input of type `T` and an iterable of `(key, T)` pairs.  
The key is used to map the results to a given input and for caching (see below).  
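
The keyed-input pattern can be sketched with the standard library. This is a rough stand-in, not mxlpy's implementation: it uses threads rather than processes, the name `keyed_parallel_map` is hypothetical, and returning a `dict` is an assumption about a convenient result shape.

```python
from collections.abc import Callable, Hashable, Iterable
from concurrent.futures import ThreadPoolExecutor
from typing import TypeVar

T = TypeVar("T")
R = TypeVar("R")


def keyed_parallel_map(
    fn: Callable[[T], R],
    inputs: Iterable[tuple[Hashable, T]],
) -> dict[Hashable, R]:
    """Apply fn to every value concurrently and return the results by key."""
    pairs = list(inputs)
    with ThreadPoolExecutor() as pool:
        # pool.map yields results in input order, so keys and results line up
        results = pool.map(fn, [value for _, value in pairs])
        return {key: result for (key, _), result in zip(pairs, results)}


def square(x: float) -> float:
    return x**2


print(keyed_parallel_map(square, [("a", 2), ("b", 3), ("c", 4)]))
```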


In [None]:
def square(x: float) -> float:
    return x**2


output = parallelise(square, [("a", 2), ("b", 3), ("c", 4)])
output

## Caching

In case the simulations take a significant amount of time to run, it makes sense to cache the results on disk.  
You can do that by adding a `cache` to the `parallelise` function (and thus to all `scan` functions as well).  

```python
parallelise(...,  cache=Cache())
```

The first time the scan is run, the calculations are done; every subsequent time, the results are loaded from the cache.  
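
The compute-once, load-afterwards behaviour can be sketched as follows. `cached_call` is a hypothetical helper written for illustration with `pickle` and a temporary directory; it is not part of mxlpy.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable


def cached_call(fn: Callable[[Any], Any], key: str, value: Any, cache_dir: Path) -> Any:
    """Return fn(value), loading it from cache_dir when a file for key exists."""
    file = cache_dir / f"{key}.p"
    if file.exists():
        with file.open("rb") as fp:
            return pickle.load(fp)
    result = fn(value)
    cache_dir.mkdir(parents=True, exist_ok=True)
    with file.open("wb") as fp:
        pickle.dump(result, fp)
    return result


calls = []


def slow_square(x: float) -> float:
    calls.append(x)  # record every real computation
    return x**2


with tempfile.TemporaryDirectory() as tmp:
    first = cached_call(slow_square, "a", 2, Path(tmp))  # computed and saved
    second = cached_call(slow_square, "a", 2, Path(tmp))  # loaded from disk

print(first, second, len(calls))
```

The second call returns the stored value without running `slow_square` again, which is why `calls` records only one computation.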

In [None]:
output = parallelise(
    square,
    [("a", 2), ("b", 3), ("c", 4)],
    cache=Cache(),
)

To avoid different analyses overwriting each other's cached results, we recommend saving each analysis in its own folder.  

In [None]:
_ = Cache(tmp_dir=Path(".cache") / "analysis-name")

By default, the `key` of `parallelise` is used to create a pickle file called `{k}.p`.  
You can customise this behaviour by changing the `name_fn`, `load_fn` and `save_fn` arguments.

In [None]:
def _pickle_name(k: Hashable) -> str:
    return f"{k}.p"


def _pickle_load(file: Path) -> Any:
    with file.open("rb") as fp:
        return pickle.load(fp)


def _pickle_save(file: Path, data: Any) -> None:
    with file.open("wb") as fp:
        pickle.dump(data, fp)


_ = Cache(
    name_fn=_pickle_name,
    load_fn=_pickle_load,
    save_fn=_pickle_save,
)