In [None]:
from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import mxlpy as mb2
from example_models import (
    get_lin_chain_two_circles,
    get_linear_chain_2v,
    get_upper_glycolysis,
)
from mxlpy import make_protocol, mc, mca, plot

# Monte Carlo methods

Almost every parameter in biology is better described with a distribution than a single value.  
Monte Carlo methods allow you to capture the **range of possible behaviour** your model can exhibit.  
This is especially useful when you want to understand the **uncertainty** in your model's predictions.  

*mxlpy* offers these Monte Carlo methods for all *scans*  ...

<div>
    <img src="assets/time-course.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/mc-time-course.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
</div>

and even for *metabolic control analysis*

<div>
    <img src="assets/parameter-elasticity.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/violins.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
</div>

In this tutorial we will mostly use the `mxlpy.distributions` and `mxlpy.mc` modules, which contain the functionality to sample from distributions and run distributed analyses.  

## Sample values

To do any Monte-Carlo analysis, we first need to be able to sample values.  

For that, you can use the `sample` function and distributions supplied by mxlpy.  
These are mostly thin wrappers around the `numpy` and `scipy` sampling methods.  

In [None]:
from mxlpy.distributions import LogNormal, Uniform, sample

# Draw n=5 values for each listed parameter; the result of `sample` is the
# last expression and is therefore shown as the cell output.
distributions = {
    "k2": Uniform(1.0, 2.0),
    "k3": LogNormal(mean=1.0, sigma=1.0),
}
sample(distributions, n=5)

## Steady-state

Using `mc.steady_state` you can calculate the steady-state distribution given the monte-carlo parameters.  

This works analogously to the `scan.steady_state` function, except the index of the dataframes is always just an integer.  

The parameters used can be obtained by `result.parameters`.


We will use a linear chain of reactions with two circles as an example model for this notebook. 

$$
\begin{array}{c|c}
    \mathrm{Reaction} & \mathrm{Stoichiometry} \\
    \hline
    v_0 & \varnothing \rightarrow{} \mathrm{x_1} \\
    v_1 & \mathrm{x_1} \rightarrow{} \mathrm{x_2} \\
    v_2 & \mathrm{x_1} \rightarrow{} \mathrm{x_3} \\
    v_3 & \mathrm{x_1} \rightarrow{} \mathrm{x_4} \\
    v_4 & \mathrm{x_4} \rightarrow{} \varnothing \\
    v_5 & \mathrm{x_2} \rightarrow{} \mathrm{x_1} \\
    v_6 & \mathrm{x_3} \rightarrow{} \mathrm{x_1} \\
\end{array}
$$

In [None]:
# Build the model first, then draw ten parameter sets for k1, k2 and k3
ss_model = get_linear_chain_2v()
ss_samples = sample(
    {
        "k1": Uniform(0.9, 1.1),
        "k2": Uniform(1.0, 1.3),
        "k3": LogNormal(mean=1.0, sigma=0.2),
    },
    n=10,
)
ss = mc.steady_state(ss_model, mc_to_scan=ss_samples)

# One violin panel for the variables, one for the fluxes
fig, (ax1, ax2) = plot.two_axes(figsize=(6, 2.5), sharex=False)
panels = (
    (ax1, ss.variables, "Variables", "Concentration / a.u."),
    (ax2, ss.fluxes, "Reactions", "Flux / a.u."),
)
for ax, data, xlabel, ylabel in panels:
    plot.violins(data, ax=ax)
    ax.set(xlabel=xlabel, ylabel=ylabel)
plt.show()

## Time course

Using `mc.time_course` you can calculate time courses for sampled parameters.  

<div>
    <img src="assets/time-course.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/mc-time-course.png" 
          style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
</div>

This function works analogously to `scan.time_course`. 

The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  

The corresponding parameters can be found in `result.parameters`

In [None]:
# Simulate every sampled (k1, k2, k3) parameter set on the same time grid
# (11 points between t=0 and t=1).
tc = mc.time_course(
    get_linear_chain_2v(),
    time_points=np.linspace(0, 1, 11),
    mc_to_scan=sample(
        {
            "k1": Uniform(0.9, 1.1),
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Left panel: variables, right panel: fluxes.
# NOTE(review): judging by its name, `lines_mean_std_from_2d_idx` summarises the
# samples as mean and standard deviation per time point — confirm in the plot API.
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tc.variables, ax=ax1)
plot.lines_mean_std_from_2d_idx(tc.fluxes, ax=ax2)
# The unit is spelled "a.u." on every axis (the time axis previously read "a.u")
ax1.set(xlabel="Time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="Time / a.u.", ylabel="Flux / a.u.")
plt.show()

## Protocol time course

Using `mc.protocol_time_course` you can calculate time courses for sampled parameters given a discrete protocol.  

<div>
    <img src="assets/protocol-time-course.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/mc-protocol-time-course.png" 
         style='vertical-align:middle; max-height: 175px; max-width: 20%'/>
</div>



The `pandas.DataFrame`s for concentrations and fluxes have a `n x time` `pandas.MultiIndex`.  
The corresponding parameters can be found in `result.parameters`.

In [None]:
# Simulate ten sampled (k2, k3) parameter sets while k1 follows a three-step protocol.
# NOTE(review): the first tuple entry is presumably the step duration
# (1 + 2 + 3 = 6 matches the end of the time grid) — confirm against `make_protocol`.
tc = mc.protocol_time_course(
    get_linear_chain_2v(),
    time_points=np.linspace(0, 6, 21),
    protocol=make_protocol(
        [
            (1, {"k1": 1}),
            (2, {"k1": 2}),
            (3, {"k1": 1}),
        ]
    ),
    mc_to_scan=sample(
        {
            "k2": Uniform(1.0, 1.3),
            "k3": LogNormal(mean=1.0, sigma=0.2),
        },
        n=10,
    ),
)

# Left panel: variables, right panel: fluxes
fig, (ax1, ax2) = plot.two_axes(figsize=(7, 4))
plot.lines_mean_std_from_2d_idx(tc.variables, ax=ax1)
plot.lines_mean_std_from_2d_idx(tc.fluxes, ax=ax2)
# Shade both panels using the k1 column of the protocol
for ax in (ax1, ax2):
    plot.shade_protocol(tc.protocol["k1"], ax=ax, alpha=0.1)

# The unit is spelled "a.u." on every axis (the time axis previously read "a.u")
ax1.set(xlabel="Time / a.u.", ylabel="Concentration / a.u.")
ax2.set(xlabel="Time / a.u.", ylabel="Flux / a.u.")
plt.show()

## Metabolic control analysis

*mxlpy* further has routines for monte-carlo distributed metabolic control analysis.  
This allows you to quantify whether the coefficients obtained from the MCA are robust against parameter changes or merely an artifact of a particular choice of parameters.  

### Variable elasticities

<div>
    <img src="assets/variable-elasticity.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%;'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%;'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/violins.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%;'/>
</div>

The returned `pandas.DataFrame` has a `pd.MultiIndex` of shape `n x reaction`.  

In [None]:
var_elas_model = get_upper_glycolysis()

# Variable values handed to the elasticity calculation
var_elas_state = {
    "GLC": 0.3,
    "G6P": 0.4,
    "F6P": 0.5,
    "FBP": 0.6,
    "ATP": 0.4,
    "ADP": 0.6,
}

# Only k3 is sampled here. The remaining rate constants could be sampled the
# same way, e.g. LogNormal(mean=np.log(x), sigma=1.0) with x = 0.25 for k1,
# x = 1.0 for k2, k4, k5, k6 and x = 2.5 for k7.
var_elas_samples = sample(
    {"k3": LogNormal(mean=np.log(1.0), sigma=1.0)},
    n=5,
)

mc_elas = mc.variable_elasticities(
    var_elas_model,
    variables=var_elas_state,
    to_scan=["GLC", "F6P"],
    mc_to_scan=var_elas_samples,
)

_ = plot.violins_from_2d_idx(mc_elas)
plt.show()

### Parameter elasticities

<div>
    <img src="assets/parameter-elasticity.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/violins.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%'/>
</div>

In [None]:
par_elas_model = get_upper_glycolysis()

# Variable values handed to the elasticity calculation
par_elas_state = {
    "GLC": 0.3,
    "G6P": 0.4,
    "F6P": 0.5,
    "FBP": 0.6,
    "ATP": 0.4,
    "ADP": 0.6,
}

# Five Monte Carlo draws of k3
par_elas_samples = sample(
    {"k3": LogNormal(mean=np.log(0.25), sigma=1.0)},
    n=5,
)

elas = mc.parameter_elasticities(
    par_elas_model,
    variables=par_elas_state,
    to_scan=["k1", "k2", "k3"],
    mc_to_scan=par_elas_samples,
)

_ = plot.violins_from_2d_idx(elas)
plt.show()

### Response coefficients

<div>
    <img src="assets/response-coefficient.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%;'/>
    <span style='padding: 0 1rem; font-size: 2rem'>+</span>
    <img src="assets/parameter-distribution.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%;'/>
    <span style='padding: 0 1rem; font-size: 2rem'>=</span>
    <img src="assets/violins.png" 
         style='vertical-align:middle; max-height: 150px; max-width: 20%;'/>
</div>

In [None]:
# Reference: response coefficients from the plain (non Monte Carlo) MCA routine,
# computed for the model's own parameter values
rc = mca.response_coefficients(
    get_lin_chain_two_circles(),
    to_scan=["vmax_1", "vmax_2", "vmax_3", "vmax_5", "vmax_6"],
)
_ = plot.heatmap(rc.variables)
plt.show()

# Monte Carlo version: the same coefficients for ten sampled (k0, k4) sets
mrc = mc.response_coefficients(
    get_lin_chain_two_circles(),
    to_scan=["vmax_1", "vmax_2", "vmax_3", "vmax_5", "vmax_6"],
    mc_to_scan=sample(
        {
            "k0": LogNormal(np.log(1.0), 1.0),
            "k4": LogNormal(np.log(0.5), 1.0),
        },
        n=10,
    ),
)

# n_cols is set to the number of columns of the result
_ = plot.violins_from_2d_idx(mrc.variables, n_cols=len(mrc.variables.columns))
plt.show()

<div style="color: #ffffff; background-color: #04AA6D; padding: 3rem 1rem 3rem 1rem; box-sizing: border-box">
    <h2>First finish line</h2>
    With that you now know most of what you will need on a day-to-day basis about Monte Carlo methods in mxlpy.
    <br />
    <br />
    Congratulations!
</div>

## Advanced topics

### Parameter scans

Vary Monte Carlo parameters **and**, at the same time, systematically scan other parameters.

In [None]:
scan_model = get_linear_chain_2v()

# Systematic part: three k1 values between 0 and 1
k1_grid = pd.DataFrame({"k1": np.linspace(0, 1, 3)})

# Monte Carlo part: ten draws of k2 and k3
k2_k3_samples = sample(
    {
        "k2": Uniform(1.0, 1.3),
        "k3": LogNormal(mean=1.0, sigma=0.2),
    },
    n=10,
)

mcss = mc.scan_steady_state(
    scan_model,
    to_scan=k1_grid,
    mc_to_scan=k2_k3_samples,
)

_ = plot.violins_from_2d_idx(mcss.variables)
plt.show()

In [None]:
# FIXME: no idea how to plot this yet. Ridge plots?
# Maybe it's just a bit much :D

scan_model_2d = get_linear_chain_2v()

# Systematic part: every combination of three k1 and three k2 values
k1_k2_grid = mb2.cartesian_product(
    {
        "k1": np.linspace(0, 1, 3),
        "k2": np.linspace(0, 1, 3),
    }
)

# Monte Carlo part: ten draws of k3
k3_samples = sample({"k3": LogNormal(mean=1.0, sigma=0.2)}, n=10)

mcss = mc.scan_steady_state(
    scan_model_2d,
    to_scan=k1_k2_grid,
    mc_to_scan=k3_samples,
)

# No plot yet, so only preview the first rows of the result
mcss.variables.head()

### Custom distributions

If you want to create custom distributions, all you need to do is to create a class that follows the `Distribution` protocol, i.e. implements a `sample` method.  

For API consistency, the `sample` method has to take an `rng` argument, which can be ignored if not applicable.  

In [None]:
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from mxlpy.types import Array


@dataclass
class MyOwnDistribution:
    """Example custom distribution that draws normally distributed values.

    Attributes:
        loc: Forwarded as ``loc`` to ``Generator.normal``.
        scale: Forwarded as ``scale`` to ``Generator.normal``.
    """

    loc: float = 0.0
    scale: float = 1.0

    def sample(
        self,
        num: int,
        rng: np.random.Generator | None = None,
    ) -> Array:
        """Draw ``num`` values from the distribution.

        Args:
            num: Number of values to draw.
            rng: Generator to draw from. When omitted, a new default
                generator is created for this call.

        Returns:
            Array with ``num`` draws.
        """
        generator = np.random.default_rng() if rng is None else rng
        return generator.normal(loc=self.loc, scale=self.scale, size=num)


# The custom class is passed to `sample` in the same way as the built-in distributions
custom_distributions = {"p1": MyOwnDistribution()}
sample(custom_distributions, n=5)