# BioProcess Task

We update the BioProcess task:

- 5-fold cross-validation, with each fold split into train/valid/test.
- On a single time series, the task is to forecast over a sliding window. The number of datapoints in each slice can vary due to irregular sampling.


A timeseries dataset consists of 4 components:

1. Observational Data $(𝐭^\text{obs},𝐗)$ which is available in the observation horizon
2. Target Data $(𝐭^\text{pred},𝐘)$, which are the values we should predict in the forecasting horizon. <br> We distinguish 2 subcases:
    1. Autoregressive case: $𝐘$ is also an observational variable
    2. Non-autoregressive case: $𝐘$ is not observed / provided over the observation horizon
    More specifically, we can split up $𝐘$ into $(𝐭^\text{auto},𝐘^\text{auto})$ and $(𝐭^\text{pred},𝐘^\text{pred})$ for the autoregressive and non-autoregressive part.
3. Covariate Data $(𝐭^\text{cov},𝐔)$
4. Time-Independent Metadata $𝓜$

In summary, a timeseries is described by a 4-tuple or, when the targets are split into an autoregressive and a non-autoregressive part, by a 5-tuple:

$$ 𝓣𝓢 = \Big( (𝐭^\text{obs},𝐗)，(𝐭^\text{cov},𝐔)，(𝐭^\text{pred},𝐘)，𝓜\Big)
\qquad 𝓣𝓢 =  \Big( (𝐭^\text{obs},𝐗)，(𝐭^\text{cov},𝐔)，(𝐭^\text{auto},𝐘^\text{auto})，(𝐭^\text{pred}, 𝐘^\text{pred})，𝓜\Big)
$$

To shorten the writing, by some abuse of notation we write $𝐗_𝐭≔(𝐭^\text{obs},𝐗)$ and similarly for the other variables

$$ 𝓣𝓢 = \Big(𝐗_𝐭，𝐔_𝐭，𝐘_𝐭，𝓜\Big)
\qquad 𝓣𝓢 =  \Big(𝐗_𝐭，𝐔_𝐭，𝐘_𝐭=(𝐘_𝐭^\text{auto}，𝐘_𝐭^\text{pred})，𝓜\Big)
$$
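As a rough illustration of the 4-tuple above, the components could be held in a small container like the one below. This is only a sketch: the class name `TimeSeriesTuple`, its field names, and the toy values are assumptions made for illustration and are not part of `tsdm`. The point it shows is that each component carries its own time index, so the indices need not align.

```python
from typing import Any, NamedTuple

import pandas as pd


class TimeSeriesTuple(NamedTuple):
    """Hypothetical container for the 4-tuple (X_t, U_t, Y_t, M)."""

    observations: pd.DataFrame  # X, indexed by t_obs
    covariates: pd.DataFrame  # U, indexed by t_cov
    targets: pd.DataFrame  # Y, indexed by t_pred
    metadata: dict[str, Any]  # M, time-independent


# Toy values: channel names are borrowed from the task, the numbers are made up.
ts_example = TimeSeriesTuple(
    observations=pd.DataFrame(
        {"pH": [7.0, 6.9]}, index=pd.Index([0.0, 0.7], name="t_obs")
    ),
    covariates=pd.DataFrame(
        {"Temperature": [37.0, 37.0, 36.5]},
        index=pd.Index([0.0, 1.0, 2.0], name="t_cov"),
    ),
    targets=pd.DataFrame(
        {"Glucose": [4.2]}, index=pd.Index([1.5], name="t_pred")
    ),
    metadata={"run_id": 1},
)
```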

Further, we introduce the notations:

- $𝐗_{S}≔\{ (tᵢ, X_{tᵢ}) ∣ tᵢ∈S\}$ for any set $S⊂𝓣$, in particular intervals:
    - $𝐗_{[t,t']}≔\{ (tᵢ, X_{tᵢ}) ∣ t≤tᵢ≤t'\}$ (closed interval)
    - $𝐗_{[t,t')}≔\{ (tᵢ, X_{tᵢ}) ∣ t≤tᵢ<t'\}$ (half-open interval)
- We allow set addition $t+S ≔ \{t+s∣s∈S\}$, hence
    - $𝐗_{[t, t'] + ∆t}≔\{ (tᵢ, X_{tᵢ}) ∣ (t+∆t) ≤ tᵢ≤ (t'+∆t)\}$
- $𝐗_{≤t}≔ \{ (tᵢ, X_{tᵢ}) ∣ tᵢ≤t\}$, and analogously $𝐗_{<t}$, $𝐗_{>t}$, $𝐗_{≥t}$
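The interval notation above maps naturally onto boolean indexing of a time-indexed pandas object. The helper names `closed`, `half_open` and `shifted` below are hypothetical and the data is a toy example; the sketch only illustrates that the number of points in a window depends on the sampling, not on the window length.

```python
import pandas as pd

# Irregularly sampled toy series; the index plays the role of the timestamps t_i.
X = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0], index=[0.0, 0.3, 1.0, 1.7, 2.0])


def closed(series: pd.Series, t: float, t2: float) -> pd.Series:
    """X_[t, t']: keep all points with t <= t_i <= t'."""
    return series[(series.index >= t) & (series.index <= t2)]


def half_open(series: pd.Series, t: float, t2: float) -> pd.Series:
    """X_[t, t'): keep all points with t <= t_i < t'."""
    return series[(series.index >= t) & (series.index < t2)]


def shifted(series: pd.Series, t: float, t2: float, dt: float) -> pd.Series:
    """X_{[t, t'] + dt}: the closed interval moved by dt."""
    return closed(series, t + dt, t2 + dt)


window_closed = closed(X, 0.3, 1.7)  # contains t_i = 0.3, 1.0, 1.7
window_half_open = half_open(X, 0.3, 1.7)  # contains t_i = 0.3, 1.0
window_shifted = shifted(X, 0.0, 1.0, 1.0)  # contains t_i = 1.0, 1.7, 2.0
```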

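The task is, given an observation horizon $I^\text{obs}$ and a forecasting horizon $I^\text{pred}$, <br>
typically $I^\text{obs} = [t, t+∆t^\text{obs}) = [t, s)$ and $I^\text{pred}= [t+∆t^\text{obs}, t+∆t^\text{obs}+∆t^\text{pred}) = [s,r)$, to predict the targets on $I^\text{pred}$ such that the following loss is small. Here we abbreviate $I≔I^\text{obs}$ and $J≔I^\text{pred}$; the covariates are available on the whole interval $I∪J$:

$$ℓ(𝐘_{J}, \hat{𝐘}(𝐗_{I}，𝐔_{I∪J}，𝐘_{I}^\text{auto}，𝓜))$$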

We define initial intervals: 
- Initial Observation interval $I=[t_S, t_O)=[t_\text{start}, t_\text{start}+∆t^\text{obs})$
- Initial Forecasting interval $J=[t_O, t_P)=[t_\text{start}+∆t^\text{obs}, t_\text{start}+∆t^\text{obs}+∆t^\text{pred})$
- Initial Complete Interval $K=I∪J=[t_S, t_P) = [t_\text{start}, t_\text{start}+∆t^\text{obs}+∆t^\text{pred})$

Given a stride, i.e. a time delta $∆t$, we define the set of all integer multiples of $∆t$ that can be added to $t₀$ such that it stays inside the bounds of $[t_\min, t_\max]$:

$$W(t_\min, t_\max, ∆t, t₀) ≔ \{k⋅∆t ∣ k∈ℤ ∧ (t₀+k⋅∆t) ∈ [t_\min，t_\max]\}$$

Moreover, let $W^+$ and $W^-$ be the sets that only take non-negative / non-positive $k$.

Then $G_𝓣≔W^+(T_\min, T_\max, ∆s, t_P)$ is the set of all increments we can apply to $I$ and $J$ such that they stay within the bounds
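The set $W$ and its one-sided variants can be enumerated directly from the bounds on $k$. The function `stride_offsets` and its `sign` argument below are hypothetical names used only for this sketch. It assumes a positive stride, and with floating-point timestamps the `ceil`/`floor` bounds may be off by one near the interval ends, so integer or exact time types are the safer choice.

```python
import math
from typing import Optional


def stride_offsets(
    t_min: float, t_max: float, dt: float, t0: float, sign: Optional[str] = None
) -> list:
    """Enumerate all k*dt (k integer) with t_min <= t0 + k*dt <= t_max.

    sign="+" restricts to k >= 0 (W^+), sign="-" restricts to k <= 0 (W^-).
    """
    if dt <= 0:
        raise ValueError("The stride dt must be positive.")
    k_lo = math.ceil((t_min - t0) / dt)  # smallest admissible k
    k_hi = math.floor((t_max - t0) / dt)  # largest admissible k
    if sign == "+":
        k_lo = max(k_lo, 0)
    elif sign == "-":
        k_hi = min(k_hi, 0)
    return [k * dt for k in range(k_lo, k_hi + 1)]


# Toy numbers: T_min=0, T_max=10, stride 2, end of the initial forecasting interval t_P=5.
W_all = stride_offsets(0, 10, 2, 5)
W_plus = stride_offsets(0, 10, 2, 5, sign="+")  # plays the role of G_T
W_minus = stride_offsets(0, 10, 2, 5, sign="-")
```

With these toy numbers, `W_plus` contains the shifts `0, 2, 4`: shifting by 6 would move $t_P$ to 11, outside the bounds.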

\begin{equation}
𝓛(θ) = 
\frac{1}{|G_𝓣|}\sum_{∆t∈G_𝓣} 
\frac{1}{|𝐘_{J+∆t}|}\sum_{(t,y) ∈ 𝐘_{J+∆t}} 
ℓ\big( y, \hat{y}_θ(t ∣ 𝐗_{I+∆t}，𝐔_{K+∆t}，𝐘_{I+∆t}^\text{auto}，𝓜)\big)
\end{equation}
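The double average in the loss above can be sketched for a single channel as follows. This is a simplified illustration under several assumptions: the name `sliding_window_loss` and the `predict` callback are hypothetical, the predictor is purely autoregressive (covariates and metadata are ignored), the pointwise loss is the squared error, and windows whose forecasting slice is empty are skipped, so the outer average runs over non-empty windows rather than over all of $G_𝓣$.

```python
import numpy as np


def sliding_window_loss(t, y, predict, t_start, dt_obs, dt_pred, offsets):
    """Average the per-window mean squared error over all shifted windows."""
    window_losses = []
    for shift in offsets:
        t_s = t_start + shift  # start of the observation interval I + shift
        t_o = t_s + dt_obs  # end of I + shift, start of J + shift
        t_p = t_o + dt_pred  # end of the forecasting interval J + shift
        obs = (t >= t_s) & (t < t_o)
        pred = (t >= t_o) & (t < t_p)
        if not pred.any():
            continue  # assumption: skip windows without any target points
        y_hat = predict(t[obs], y[obs], t[pred])
        window_losses.append(np.mean((y[pred] - y_hat) ** 2))  # inner average
    return float(np.mean(window_losses))  # outer average


def last_value(t_obs, y_obs, t_pred):
    """Toy predictor: carry the last observed value forward."""
    return np.full(len(t_pred), y_obs[-1])


t = np.arange(8, dtype=float)
y = t.copy()
loss_value = sliding_window_loss(
    t, y, last_value, t_start=0.0, dt_obs=2.0, dt_pred=2.0, offsets=[0.0, 2.0, 4.0]
)
```

Each window contributes equally here, regardless of how many target points fall into it, which matches the $1/|𝐘_{J+∆t}|$ normalization.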

Note that, in order to achieve good training across all channels even if they are sampled at vastly different rates, we normalize with respect to the observation frequency.

$$
\widetilde{\sum} =  \frac{\sum_t m_t ℓ(y_t, \hat{y}_t)}{\sum_t m_t}
$$

Note that, in order to avoid division by zero errors, we formally set the sum to zero if $∑_t m_t = 0$. In practice this is achieved by simply applying a `nan-mean` function.
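A minimal NumPy sketch of this normalization is given below. It assumes that $m_t$ is a binary mask marking observed target values (encoded as non-NaN entries), that the pointwise loss is the squared error, and that the average is taken per channel; the function name `masked_mean_loss` is hypothetical.

```python
import numpy as np


def masked_mean_loss(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Per-channel mean squared error over observed entries only.

    Both arrays have shape (time, channels); NaN in y_true marks a missing value.
    Channels without any observation get a loss of exactly zero.
    """
    mask = ~np.isnan(y_true)  # m_t: True where a target was observed
    sq_err = np.where(mask, (y_true - y_pred) ** 2, 0.0)
    sums = sq_err.sum(axis=0)  # numerator: sum_t m_t * loss
    counts = mask.sum(axis=0)  # denominator: sum_t m_t
    # Divide only where counts > 0; elsewhere keep the zero from `out`.
    return np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)


y_true = np.array([[1.0, np.nan, np.nan], [3.0, 2.0, np.nan]])
y_pred = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
per_channel = masked_mean_loss(y_true, y_pred)
```

Applying `np.nanmean` per channel gives the same values for channels with at least one observation, but returns NaN (with a warning) for an empty channel; a second nan-mean across channels then ignores it.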


- Error Metric: Weighted L2 / L1 / ND score.

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'svg'

%matplotlib inline

import logging

logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
rng = np.random.default_rng()

In [None]:
from pandas import DataFrame, Series

from tsdm.utils.data import *

In [None]:
from tsdm.datasets import KIWI_RUNS

In [None]:
ds = KIWI_RUNS()

## Generate the Folds

In [None]:
md = ds.metadata.drop([355, 482])
groups = md.groupby(["run_id", "color"], sort=False).ngroup()
folds = folds_from_groups(groups, train=7, valid=1, test=2)
assert all(sum([fold["test"] for fold in folds]) == 1)
assert all(sum([fold["valid"] for fold in folds]) <= 1)
splits = folds_as_frame(folds)

In [None]:
from collections.abc import Callable
from functools import cached_property
from itertools import product
from typing import Any, Literal, NamedTuple, Optional

import torch
from pandas import DataFrame, Series
from torch import Tensor, jit
from torch.utils.data import DataLoader

from tsdm.datasets import KIWI_RUNS
from tsdm.encoders import BaseEncoder
from tsdm.metrics import WRMSE
from tsdm.random.samplers import HierarchicalSampler, SequenceSampler
from tsdm.tasks.base import BaseTask
from tsdm.utils.data import (
    MappingDataset,
    TimeSeriesDataset,
    folds_as_frame,
    folds_from_groups,
)
from tsdm.utils.strings import repr_namedtuple


class Sample(NamedTuple):
    r"""A sample of the data."""

    key: tuple[tuple[int, int], slice]
    inputs: tuple[DataFrame, DataFrame]
    targets: DataFrame
    originals: Optional[tuple[DataFrame, DataFrame]] = None

    def __repr__(self) -> str:
        return repr_namedtuple(self, recursive=1)


class Kiwi_BioProcessTask(BaseTask):
    r"""A collection of bioreactor runs.

    For this task we do several simplifications

    - drop run_ids 355 and 482
    - drop almost all metadata
    - restrict timepoints to start_time & end_time given in metadata.

    - timeseries for each run_id and experiment_id
    - metadata for each run_id and experiment_id

    We first do a train/valid/test split.
    Then the goal is to learn a model in a multi-task fashion on all the ts.

    To train, we sample
    1. random TS from the dataset
    2. random snippets from the sampled TS

    Questions:
    - Should each batch contain only snippets from a single TS, or is there merit to sampling
    snippets from multiple TS in each batch?

    Divide 'Glucose' by 10, 'OD600' by 20, 'DOT' by 100, 'Base' by 200, then use RMSE.
    """

    index: list[tuple[int, str]] = list(product(range(5), ("train", "valid", "test")))
    r"""Available index."""
    KeyType = tuple[Literal[0, 1, 2, 3, 4], Literal["train", "valid", "test"]]
    r"""Type Hint for Keys."""
    timeseries: DataFrame
    r"""The whole timeseries data."""
    metadata: DataFrame
    r"""The metadata."""
    observation_horizon: int = 96
    r"""The number of datapoints observed during prediction."""
    forecasting_horizon: int = 24
    r"""The number of datapoints the model should forecast."""
    preprocessor: BaseEncoder
    r"""Encoder for the observations."""
    controls: Series
    r"""The control variables."""
    targets: Series
    r"""The target variables."""
    observables: Series
    r"""The observable variables."""

    def __init__(
        self,
        *,
        forecasting_horizon: int = 24,
        observation_horizon: int = 96,
    ):
        self.forecasting_horizon = forecasting_horizon
        self.observation_horizon = observation_horizon
        self.horizon = self.observation_horizon + self.forecasting_horizon

        self.timeseries = ts = self.dataset.timeseries
        self.metadata = self.dataset.metadata
        self.units = self.dataset.units

        self.targets = targets = Series(["Base", "DOT", "Glucose", "OD600"])
        self.targets.index = self.targets.apply(ts.columns.get_loc)

        self.controls = controls = Series([
            "Cumulated_feed_volume_glucose",
            "Cumulated_feed_volume_medium",
            "InducerConcentration",
            "StirringSpeed",
            "Flow_Air",
            "Temperature",
            "Probe_Volume",
        ])
        controls.index = controls.apply(ts.columns.get_loc)

        self.observables = observables = Series([
            "Base",
            "DOT",
            "Glucose",
            "OD600",
            "Acetate",
            "Fluo_GFP",
            "Volume",
            "pH",
        ])
        observables.index = observables.apply(ts.columns.get_loc)

        assert (
            set(controls.values) | set(targets.values) | set(observables.values)
        ) == set(ts.columns)

    @cached_property
    def test_metric(self) -> Callable[..., Tensor]:
        r"""The metric to be used for evaluation."""
        ts = self.timeseries
        weights = DataFrame.from_dict(
            {
                "Base": 200,
                "DOT": 100,
                "Glucose": 10,
                "OD600": 20,
            },
            orient="index",
            columns=["inverse_weight"],
        )
        weights["col_index"] = weights.index.map(lambda x: (ts.columns == x).argmax())
        weights["weight"] = 1 / weights["inverse_weight"]
        weights["normalized"] = weights["weight"] / weights["weight"].sum()
        weights.index.name = "col"
        w = torch.tensor(weights["weight"])
        return jit.script(WRMSE(w))

    @cached_property
    def dataset(self) -> KIWI_RUNS:
        r"""Return the cached dataset."""
        dataset = KIWI_RUNS()
        dataset.metadata.drop([355, 482], inplace=True)
        dataset.timeseries.drop([355, 482], inplace=True)
        return dataset

    @cached_property
    def folds(self) -> DataFrame:
        r"""Return the folds."""
        md = self.dataset.metadata
        groups = md.groupby(["run_id", "color"], sort=False).ngroup()
        folds = folds_from_groups(
            groups, seed=2022, num_folds=5, train=7, valid=1, test=2
        )
        return folds_as_frame(folds)

    @cached_property
    def splits(self) -> dict[Any, tuple[DataFrame, DataFrame]]:
        r"""Return the subsets of the data corresponding to each split.

        Returns
        -------
        dict[Any, tuple[DataFrame, DataFrame]]
            Maps each key ``(fold, partition)`` to ``(timeseries, metadata)``.
        """
        splits = {}
        for key in self.index:
            assert key in self.index, f"Wrong {key=}. Only {self.index} work."
            split, data_part = key

            mask = self.folds[split] == data_part
            idx = self.folds[split][mask].index
            timeseries = self.timeseries.reset_index(level=2).loc[idx]
            timeseries = timeseries.set_index("measurement_time", append=True)
            metadata = self.metadata.loc[idx]
            splits[key] = (timeseries, metadata)
        return splits

    @cached_property
    def dataloader_kwargs(self) -> dict:
        r"""Return the kwargs for the dataloader."""
        return {
            "batch_size": 1,
            "shuffle": False,
            "sampler": None,
            "batch_sampler": None,
            "num_workers": 0,
            "collate_fn": lambda *x: x,
            "pin_memory": False,
            "drop_last": False,
            "timeout": 0,
            "worker_init_fn": None,
            "prefetch_factor": 2,
            "persistent_workers": False,
        }

    def get_dataloader(
        self, key: KeyType, /, shuffle: bool = False, **dataloader_kwargs: Any
    ) -> DataLoader:
        r"""Return a dataloader for the given split.

        Parameters
        ----------
        key: KeyType,
        shuffle: bool, default False
        dataloader_kwargs: Any,

        Returns
        -------
        DataLoader
        """
        # Construct the dataset object
        ts, md = self.splits[key]
        dataset = _Dataset(
            ts,
            md,
            observables=self.observables.index,
            observation_horizon=self.observation_horizon,
            targets=self.targets.index,
        )

        TSDs = {}
        for idx in md.index:
            TSDs[idx] = TimeSeriesDataset(
                ts.loc[idx],
                metadata=md.loc[idx],
            )
        DS = MappingDataset(TSDs)

        # construct the sampler
        subsamplers = {
            key: SequenceSampler(ds, seq_len=self.horizon, shuffle=shuffle)
            for key, ds in DS.items()
        }
        sampler = HierarchicalSampler(DS, subsamplers, shuffle=shuffle)

        # construct the dataloader
        kwargs: dict[str, Any] = {"collate_fn": lambda *x: x} | dataloader_kwargs
        return DataLoader(dataset, sampler=sampler, **kwargs)


class _Dataset(torch.utils.data.Dataset):
    def __init__(self, ts, md, *, observables, targets, observation_horizon):
        super().__init__()
        self.timeseries = ts
        self.metadata = md
        self.observables = observables
        self.targets = targets
        self.observation_horizon = observation_horizon

    def __getitem__(self, item: tuple[tuple[int, int], slice]) -> Sample:
        key, slc = item
        ts = self.timeseries.loc[key].iloc[slc].copy(deep=True)
        md = self.metadata.loc[key].copy(deep=True)
        originals = (ts.copy(deep=True), md.copy(deep=True))
        targets = ts.iloc[self.observation_horizon :, self.targets].copy(deep=True)
        ts.iloc[self.observation_horizon :, self.targets] = float("nan")
        ts.iloc[self.observation_horizon :, self.observables] = float("nan")
        return Sample(key=item, inputs=(ts, md), targets=targets, originals=originals)

In [None]:
self = Kiwi_BioProcessTask()

In [None]:
self.splits.keys()

In [None]:
key = (0, "train")
shuffle = False

####################


ts, md = self.splits[key]
dataset = _Dataset(
    ts,
    md,
    observables=self.observables.index,
    observation_horizon=self.observation_horizon,
    targets=self.targets.index,
)

TSDs = {}
for idx in md.index:
    TSDs[idx] = TimeSeriesDataset(
        ts.loc[idx],
        metadata=md.loc[idx],
    )
DS = MappingDataset(TSDs)

In [None]:
key_, ds = next(iter(DS.items()))

In [None]:
SequenceSampler(ds, seq_len=120)

In [None]:
# construct the sampler
subsamplers = {
    key_: SequenceSampler(ds, seq_len=self.horizon, shuffle=shuffle)
    for key_, ds in DS.items()
}
sampler = HierarchicalSampler(DS, subsamplers, shuffle=shuffle)

# construct the dataloader
dataloader_kwargs: dict[str, Any] = {}  # not defined outside get_dataloader; no extra options here
kwargs: dict[str, Any] = {"collate_fn": lambda *x: x} | dataloader_kwargs
DataLoader(dataset, sampler=sampler, **kwargs)