# Kiwi Final Product

In [None]:
# Notebook setup: display the value of the last expression *or assignment* of every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload imported modules (autoreload mode 2) so library edits take effect.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Show log records of level INFO and above in the notebook.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

# Compact float output: fixed notation with 4 decimals, no scientific notation.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Fresh, unseeded random generator for the notebook.
rng = np.random.default_rng()

In [None]:
from functools import cached_property
from types import SimpleNamespace
from typing import Any, Literal, Union

import pandas
import pandas as pd
from pandas import DataFrame, Interval, Series, Timedelta, Timestamp
from tqdm.auto import tqdm

from tsdm.datasets import KIWI_RUNS
from tsdm.tasks import KIWI_RUNS_TASK, BaseTask

# Runs excluded from the analysis.
EXCLUDED_RUNS = [355, 445, 482]

dataset = KIWI_RUNS()
ts = dataset.timeseries.drop(EXCLUDED_RUNS)
# The metadata table must be filtered from `dataset.metadata`, not from the timeseries.
md = dataset.metadata.drop(EXCLUDED_RUNS)
task = KIWI_RUNS_TASK()

In [None]:
# Inspect the task's `split_idx_sparse` table (layout defined by tsdm's KIWI_RUNS_TASK).
df = task.split_idx_sparse

## Get the induction time and the time/value of the final product

In [None]:
# Task hyper-parameters (same defaults as the task class's constructor).
target: Literal["OD600", "Fluo_GFP"] = "Fluo_GFP"
t_min: Union[str, Timedelta] = "0.6h"
delta_t: Union[str, Timedelta] = "5m"
eval_batch_size: int = 128
train_batch_size: int = 32

# Stand-in for `self`, so method-style code (`self.t_min`, `self.delta_t`, ...)
# can run at notebook level.
self = SimpleNamespace(
    target=target,
    delta_t=Timedelta(delta_t),
    t_min=Timedelta(t_min),
    eval_batch_size=eval_batch_size,
    train_batch_size=train_batch_size,
)

# Load the dataset and drop runs 355, 445 and 482 from both tables.
self.dataset = KIWI_RUNS()
self.dataset.timeseries = self.dataset.timeseries.drop([355, 445, 482])
self.dataset.metadata = self.dataset.metadata.drop([355, 445, 482])
self.units: DataFrame = self.dataset.units
self.metadata: DataFrame = self.dataset.metadata
self.timeseries: DataFrame = self.dataset.timeseries

#### Start with empty table

In [None]:
def get_induction_time(s: DataFrame) -> Timestamp:
    """Return the time at which the inducer concentration changes.

    Args:
        s: Observations of a single run/experiment, indexed by time, containing
            an ``InducerConcentration`` column that may have missing values.

    Returns:
        The index label of the observation at which the observed inducer value
        first differs from the previously observed value. Returns ``pd.NA`` in
        two cases: the first and last observed values coincide, or the inducer
        was never observed.

    Raises:
        ValueError: if the observed inducer concentration changes more than once.
    """
    # Compare only actual observations. Missing values must neither hide a
    # change nor be counted as one.
    inducer = s["InducerConcentration"].dropna()

    if inducer.empty:
        return pd.NA

    # Explicit positional access. `inducer[-1]` on a time index relies on a
    # deprecated positional fallback.
    total_induction = inducer.iloc[-1] - inducer.iloc[0]
    if total_induction == 0:
        return pd.NA

    # The first diff has no predecessor (NaN), so only inspect the remaining ones.
    changed = inducer.diff().iloc[1:] != 0
    inductions = changed[changed].index

    if len(inductions) != 1:
        raise ValueError("Multiple Inductions occur!")
    return inductions[0]

In [None]:
def get_final_product(s: DataFrame, column: Union[str, None] = None) -> Timestamp:
    """Return the time stamp of the last observed value of the target column.

    Args:
        s: Observations of a single run/experiment, indexed by time.
        column: Name of the target column. Defaults to the notebook-level
            ``target`` setting, looked up at call time.

    Raises:
        ValueError: if the target column has no observed (non-missing) value.
    """
    column = target if column is None else column
    targets = s[column].dropna()
    if targets.empty:
        raise ValueError(f"not enough target observations {targets}")
    return targets.index[-1]

In [None]:
def get_time_table(ts: DataFrame) -> DataFrame:
    """Tabulate evaluation bounds and final product for every time series in ``ts``.

    Returns one row per ``(level 0, level 1)`` index pair of ``ts``, with columns:

    - ``slice``: ``slice(t_min, t_max)``, the admissible observation window;
    - ``delta_t``: the global step ``self.delta_t``;
    - ``t_min``: first time stamp of the series plus ``self.t_min``;
    - ``t_induction``: induction time, or ``pd.NA`` if there was none;
    - ``t_max``: ``t_induction`` if present, otherwise the last target
      observation strictly before ``t_final``;
    - ``t_final``: time of the last target observation;
    - ``y_final``: target value at ``t_final``.

    Reads the notebook globals ``self`` (``t_min``, ``delta_t``) and ``target``.
    Raises AssertionError if a series is induced at or after ``t_final``.
    """
    columns = [
        "slice",
        "delta_t",
        "t_min",
        "t_induction",
        "t_max",
        "t_final",
        "y_final",
    ]
    # assumes a 3-level index whose third level is time — TODO confirm for other inputs
    index = ts.reset_index(level=[2]).index.unique()
    df = DataFrame(index=index, columns=columns)

    df["delta_t"] = self.delta_t
    # Keep the placeholder as *object* dtype. Every row below overwrites it with
    # a Timestamp, and writing Timestamps into a timedelta64 column is an
    # incompatible-dtype assignment that recent pandas deprecates.
    df["t_min"] = Series(self.t_min, index=df.index, dtype=object)

    for idx, slc in tqdm(ts.groupby(level=[0, 1])):
        slc = slc.reset_index(level=[0, 1], drop=True)
        t_induction = get_induction_time(slc)
        t_final = get_final_product(slc)

        if pd.isna(t_induction):
            # tqdm.write keeps the progress bar intact, unlike print.
            tqdm.write(f"{idx}: no t_induction!")
            # Without induction, the window ends at the last target observation before t_final.
            t_max = get_final_product(slc.loc[slc.index < t_final])
            assert t_max < t_final
        else:
            assert t_induction < t_final, f"{t_induction=} after {t_final}!"
            t_max = t_induction

        t_min = slc.index[0] + self.t_min
        df.loc[idx, "t_max"] = t_max
        df.loc[idx, "t_min"] = t_min
        df.loc[idx, "t_induction"] = t_induction
        df.loc[idx, "t_final"] = t_final
        df.loc[idx, "y_final"] = slc.loc[t_final, target]
        df.loc[idx, "slice"] = slice(t_min, t_max)
    return df

In [None]:
# Per-series time bounds (t_min, t_induction, t_max, t_final) and final target value.
final_product_times = get_time_table(ts)

In [None]:
# Observation window `slice(t_min, t_max)` of the first series in the table.
slc = final_product_times.slice.iloc[0]

In [None]:
# NOTE(review): hard-coded ids, presumably those of the first row of `final_product_times` — confirm.
(439, 15325, slc)

In [None]:
# NOTE(review): `task` was already created in an earlier cell; this rebinds it to a fresh instance.
task = KIWI_RUNS_TASK()

## The sampling

For each time series, we create a sampler that generates time slices.



### IntervalSampler

Returns all intervals `[a, b]` such that

- `a = t₀ + i⋅sₖ`
- `b = t₀ + i⋅sₖ + Δtₖ`
- `i, k ∈ ℤ`
- `a ≥ t_min`
- `b ≤ t_max`
- `sₖ` is the stride corresponding to intervals of size `Δtₖ`
- interval sizes can be provided by one of:
   - single value -> `Δtₖ` will be integer multiples of it
   - `Sequence[type]`
   - `Mapping[int, type]`
   - `Callable[[int], type]`
- stride sizes can be provided via one of:
   - single value -> `sₖ` will be integer multiples of it
   - `Sequence[type]`
   - `Mapping[int, type]`
   - `Callable[[int], type]`

**Mandatory Inputs**

- `t_min: Timestamp`
- `t_max: Timestamp`


**Optional: Exactly one of the following**
- `num_slices: int`
- `delta_t: Timedelta`
- `grid: Sequence[Timestamp]`

**Optional Inputs**
- `t_offset: Timestamp = t_min` The base point of the grid. Can also be randomly generated, if required.
- `min_length: Timedelta | int = 0`
  If int, the minimum multiple of `Δt` allowed.
  If Timedelta, the lower bound for multiples of `Δt`.
- `max_length: Timedelta | int = t_max-t_min`
  If int, the maximum multiple of `Δt` allowed.
  If Timedelta, the upper bound for multiples of `Δt`.
- `shuffle: bool = True` Whether to randomly order the generated slices.

# Implementing the Samplers

## Helper function

Return all integers `k` for which

`t_min ≤ t_0 + k⋅Δt ≤ t_max`

In [None]:
from collections.abc import Callable, Iterator, Mapping, Sequence
from itertools import count
from typing import Optional, TypeVar

TimedeltaLike = TypeVar("TimedeltaLike", int, float, Timedelta)
TimestampLike = TypeVar("TimestampLike", int, float, Timestamp)


def grid(
    xmin: TimestampLike,
    xmax: TimestampLike,
    delta: TimedeltaLike,
    xoffset: Optional[TimestampLike] = None,
) -> list[int]:
    r"""Compute `\{k∈ℤ ∣ xₘᵢₙ ≤ x₀+k⋅Δ ≤ xₘₐₓ\}` with `x₀ = xoffset` (default: `xmin`).

    The offset need not lie inside `[xmin, xmax]`. If no grid point falls into
    the range, the result is empty.

    Special case: if Δ=0, every k yields the same point x₀. The representative
    `[0]` is returned if `xmin ≤ x₀ ≤ xmax`, otherwise `[]`.

    Raises:
        ValueError: if Δ < 0.
    """
    x0 = xmin if xoffset is None else xoffset
    zero = type(delta)(0)

    if delta == zero:
        return [0] if xmin <= x0 <= xmax else []

    if delta < zero:
        raise ValueError("Assumption delta>0 violated!")

    # Smallest admissible k is ceil((xmin - x0)/Δ), written as -floor((x0 - xmin)/Δ)
    # so that `//` works uniformly for ints, floats and Timedeltas.
    kmin = -((x0 - xmin) // delta)
    # Largest admissible k is floor((xmax - x0)/Δ).
    kmax = (xmax - x0) // delta

    return list(range(int(kmin), int(kmax) + 1))

In [None]:
# Fraction of non-missing observations per column of the task's timeseries, ascending.
pd.notna(task.timeseries).mean().sort_values()

In [None]:
# Inspect the task's `controls` attribute.
task.controls

In [None]:
# All rows of the example series (run 439, experiment 15325).
ts.loc[439, 15325]

In [None]:
# Unordered categorical dtype with the four categories "a", "b", "c", "d".
cats = pandas.CategoricalDtype(categories=["a", "b", "c", "d"], ordered=False)

In [None]:
# Toy example: one-hot encode a categorical Series that uses the `cats` dtype.
s = pd.Series(list("abca")).astype(cats)
pd.get_dummies(s)

In [None]:
# Generic element type of the per-level ("boxed") containers below.
V = TypeVar("V")

# A value given per level k: as a sequence indexed by k, as a mapping
# k -> value, or as a callable k -> value.
Boxed = Union[
    Sequence[V],
    Mapping[int, V],
    Callable[[int], V],
]

# Type of interval sizes / strides: either a single timedelta-like value or
# one of the boxed per-level forms with timedelta-like elements.
dt_type = Union[
    TimedeltaLike,
    Sequence[TimedeltaLike],
    Mapping[int, TimedeltaLike],
    Callable[[int], TimedeltaLike],
]

import numpy as np
from torch.utils.data import Sampler


class IntervalSampler(
    Sampler,
):
    r"""Returns all intervals `[a, b]` such that:

    - `a = t₀ + i⋅sₖ`
    - `b = t₀ + i⋅sₖ + Δtₖ`
    - `i, k ∈ ℤ`
    - `a ≥ t_min`
    - `b ≤ t_max`
    - `sₖ` is the stride corresponding to intervals of size `Δtₖ`

    Here `t₀ = offset`, `t_min = xmin` and `t_max = xmax`.

    ``deltax`` (the sizes ``Δtₖ``) and ``stride`` (the strides ``sₖ``) can each
    be a single value used for every level. Alternatively, they can be given per
    level ``k`` as a sequence, a mapping ``k -> value`` or a callable
    ``k -> value``. A missing stride means zero stride, i.e. at most one interval
    per level, starting at ``offset``.

    The generated intervals are stored in ``self.intervals``, a DataFrame with
    the columns ``left``, ``right``, ``delta`` and ``stride``. Attribute lookups
    that fail on the sampler are forwarded to that DataFrame. Iterating yields
    ``(left, right)`` pairs, in random order if ``shuffle`` is set.
    """

    def __init__(
        self,
        xmin,
        xmax,
        deltax: dt_type,
        stride: Optional[dt_type] = None,
        levels: Optional[Sequence[int]] = None,
        offset: Optional[dt_type] = None,
        multiples: bool = True,
        shuffle: bool = True,
    ) -> None:
        """Enumerate all admissible intervals.

        Args:
            xmin: Lower bound that every interval must respect.
            xmax: Upper bound that every interval must respect.
            deltax: Interval size(s) ``Δtₖ``, either a single value or per level.
            stride: Stride(s) ``sₖ``; defaults to zero.
            levels: Levels ``k`` to consider. By default they are derived from
                ``deltax``: all levels whose size does not exceed the largest
                size that can fit.
            offset: Grid base point ``t₀``; defaults to ``xmin``.
            multiples: Currently unused.
            shuffle: Whether iteration visits the intervals in random order.
        """
        # `0 * (xmax - xmin)` is a zero of the difference type (0, 0.0, Timedelta(0)).
        zero = 0 * (xmax - xmin)
        stride = zero if stride is None else stride
        offset = xmin if offset is None else offset

        # validate bounds
        assert xmin <= offset <= xmax, "Assumption: xmin≤xoffset≤xmax violated!"

        # largest interval size that can possibly fit around the offset
        delta_max = max(offset - xmin, xmax - offset)

        # determine levels
        if levels is None:
            if isinstance(deltax, Mapping):
                levels = [k for k in deltax if deltax[k] <= delta_max]
            elif isinstance(deltax, Sequence):
                levels = [k for k in range(len(deltax)) if deltax[k] <= delta_max]
            elif isinstance(deltax, Callable):
                # Probe k = 0, 1, 2, ... until the size exceeds delta_max; zero
                # sizes are skipped.
                # NOTE: this loops forever if deltax(k) never exceeds delta_max.
                levels = []
                for k in count():
                    dt = self._get_value(deltax, k)
                    if dt == zero:
                        continue
                    if dt > delta_max:
                        break
                    levels.append(k)
            else:
                levels = [0]
        else:
            levels = [k for k in levels if self._get_value(deltax, k) <= delta_max]

        # validate levels
        assert all(self._get_value(deltax, k) <= delta_max for k in levels)

        # (left, right, delta, stride) records of all valid intervals
        intervals: list[tuple] = []

        for k in levels:
            dt = self._get_value(deltax, k)
            st = self._get_value(stride, k)
            x0 = self._get_value(offset, k)

            # Grid indices i for which the left end x0 + i*st, resp. the right
            # end x0 + i*st + dt, lies in [xmin, xmax]; both must hold.
            left_ok = grid(xmin, xmax, st, x0)
            right_ok = grid(xmin, xmax, st, x0 + dt)
            # sorted: deterministic interval order, independent of set iteration order
            valid_strides = sorted(set(left_ok).intersection(right_ok))

            if not valid_strides:
                # Other levels may still admit intervals (levels need not be sorted by size).
                continue

            intervals.extend(
                (x0 + i * st, x0 + i * st + dt, dt, st) for i in valid_strides
            )

        # set variables
        self.offset = offset
        self.deltax = deltax
        self.stride = stride
        self.shuffle = shuffle
        self.intervals = DataFrame(
            intervals, columns=["left", "right", "delta", "stride"]
        )

    def __iter__(self) -> Iterator:
        """Yield ``(left, right)`` for every interval, shuffled if requested."""
        if self.shuffle:
            perm = np.random.permutation(len(self))
        else:
            perm = np.arange(len(self))

        left = self.intervals["left"]
        right = self.intervals["right"]
        for k in perm:
            yield left.iloc[k], right.iloc[k]

    def __len__(self) -> int:
        return len(self.intervals)

    def __getattr__(self, key):
        # Only reached when normal lookup fails. Guard `intervals` itself, so a
        # partially initialised instance (e.g. during copy/unpickling) raises
        # AttributeError instead of recursing forever.
        if key == "intervals":
            raise AttributeError(key)
        return getattr(self.intervals, key)

    def __getitem__(self, key):
        return self.intervals[key]

    @staticmethod
    def _get_value(obj: Union[V, Boxed[V]], k: int) -> V:
        """Return the level-``k`` entry of a boxed value; plain values are returned as-is."""
        if isinstance(obj, Callable):
            return obj(k)
        if isinstance(obj, (Sequence, Mapping)):
            return obj[k]
        # Fallback: a single value shared by all levels.
        return obj

In [None]:
# Pick one example series (run 439, experiment 15325) and derive sampler inputs from it.
# NOTE(review): this rebinds the notebook-level `t_min` / `delta_t` that held the
# task hyper-parameters.
idx = (439, 15325)
s = ts.loc[idx]
t_min = s.index[0]  # first time stamp of the series
t_max = s.index[-1]  # last time stamp of the series
t_0 = t_min
delta_t = Timedelta("5m")
stride = Timedelta("5m")
t_0, t_min, t_max, delta_t, stride

In [None]:
# Zero stride (stride=None) gives one interval per level k ≥ 1, namely
# [t_min, t_min + k*delta_t], for every k with k*delta_t ≤ t_max - t_min.
# NOTE(review): the `stride` defined in the previous cell is not used here.
sampler = IntervalSampler(
    xmin=t_min,
    xmax=t_max,
    # offset=t_0,
    deltax=lambda k: k * delta_t,
    stride=None,
    shuffle=True,
)

sampler

In [None]:
# NOTE(review): same computation as `final_product_times` above; the cells below use this name.
final_product = get_time_table(ts)

In [None]:
# Time bounds of the example series `idx`.
final_product.loc[idx]

In [None]:
# One IntervalSampler and one time-indexed frame per (run_id, experiment_id).
datasets = {}
subsamplers = {}

for idx, slc in tqdm(ts.groupby(["run_id", "experiment_id"])):
    # T, X = self.preprocessor.encode(slc.reset_index(level=[0, 1], drop=True))
    # Look the bounds up by column name. `final_product` starts with a "slice"
    # column, so positional unpacking of the row would shift every field by one.
    bounds = final_product.loc[idx]
    subsamplers[idx] = IntervalSampler(
        xmin=bounds["t_min"],
        xmax=bounds["t_max"],
        # offset=t_0,
        # Bind this row's delta_t now (default argument) rather than closing over
        # a loop variable that changes on the next iteration.
        deltax=lambda k, dt=bounds["delta_t"]: k * dt,
        stride=None,
        shuffle=True,
    )
    datasets[idx] = slc.reset_index(level=[0, 1], drop=True)

In [None]:
# NOTE(review): `CollectionSampler` is not imported or defined in the cells shown
# here, so this help lookup will likely not find it.
?CollectionSampler

In [None]:
# NOTE(review): scratch cell — `obj` is not defined in the cells shown here.
obj.__repr__()

In [None]:
# NOTE(review): scratch cell — stores the *string* returned by `__repr__()` as an
# instance attribute named `__next__`. This does not make `obj` an iterator,
# because `next()` looks `__next__` up on the type.
obj.__next__ = obj.__repr__()

In [None]:
# NOTE(review): scratch cell — `range()` without arguments raises TypeError.
range()

In [None]:
# NOTE(review): scratch cell — `obj` is not defined in the cells shown here.
list(iter(obj))

In [None]:
# NOTE(review): the lines below appear to be an excerpt of pandas' internal
# `pprint_thing` (pandas/io/formats/printing.py), presumably pasted for reference
# while looking at how pandas renders objects with `__next__`. They cannot run
# as a cell:
#   - `return` outside a function is a SyntaxError;
#   - `thing`, `_nest_lvl`, `_pprint_dict`, etc. are undefined here.
# They are therefore kept as comments.
#
# if hasattr(thing, "__next__"):
#     return str(thing)
# elif isinstance(thing, dict) and _nest_lvl < get_option("display.pprint_nest_depth"):
#     result = _pprint_dict(
#         thing, _nest_lvl, quote_strings=True, max_seq_items=max_seq_items
#     )
# elif is_sequence(thing) and _nest_lvl < get_option("display.pprint_nest_depth"):
#     result = _pprint_seq(
#         thing,
#         _nest_lvl,
#         escape_chars=escape_chars,
#         quote_strings=quote_strings,
#         max_seq_items=max_seq_items,
#     )
# elif isinstance(thing, str) and quote_strings:
#     result = f"'{as_escaped_string(thing)}'"
# else:
#     result = as_escaped_string(thing)

In [None]:
# Type of the first entry of `df.dtypes`.
# NOTE(review): `df.dtypes[0]` is a label lookup if the column labels are
# integers, and a deprecated positional fallback otherwise; `.iloc[0]` is explicit.
type(df.dtypes[0])

In [None]:
# Display the filtered timeseries table.
ts

In [None]:
# Materialize all (left, right) pairs of the example sampler (shuffled order).
list(sampler)

## Implementing the SliceSampler

TODOS:

- Create a **`DataLoader`**

- Need a working **`TimeSeriesCollection`** object (tuple[TimeTensor] + tuple[MetaData])
    - Need a working **`TimeTensor`** object that indexes a `torch.tensor` with a `pandas.Index` or `pandas.MultiIndex`
- Need a working **`CollectionSampler`** object (⇝ currently does not return idx or metadata!)
    - Return `NamedTuple` object (timeseries: list[tensor], metadata: list[tensor], index: list[tensor])
- Implement custom **`collate_fn`** functions
    - Just return `list[Tensor]`: Pro: simplest option. Con: the model must handle variable-length inputs itself.
    - Padded tensor: Pro: simple; the model need not support it specifically. Con: very wasteful when sequence lengths vary strongly.
        - Batch by size! (Issue: could lead to varying batch sizes.)
    - `PackedSequence`: Pro: likely the fastest. Con: the model must explicitly support it!

In [None]:
from torch.utils.data import Dataset


class SliceSampler(Sampler):
    r"""Sample slices from data.

    For every ``(left, right)`` pair produced by ``interval_sampler``, yields
    ``dataset[left:right]``.
    """

    def __init__(
        self,
        dataset: Dataset,
        interval_sampler: Sampler,
    ):
        self.dataset = dataset
        self.interval_sampler = interval_sampler

    def __iter__(self) -> Iterator:
        for left, right in self.interval_sampler:
            yield self.dataset[left:right]

    def __len__(self) -> int:
        # One slice per interval produced by the interval sampler.
        return len(self.interval_sampler)

## Implementing the Task Object

In [None]:
class KIWI_FINAL_PRODUCT(BaseTask):
    r"""Predict the final Biomass.

    The goal is to forecast the final product/biomass value only.
    This means the problem can be viewed both as time-series forecasting
    and as time-series regression, if one ignores the final time stamp.

    The evaluation protocol considers initial segments of the time series `TS[t≤k*Δt]`,
    where `k` ranges over all integers satisfying `t_{min} ≤ k*Δt ≤ t_{max}`.

    Here, `t_{min}` is a global constant (0.6h by default), while `t_{max}` is
    chosen per time series:

    - If there was an induction, `t_{max} = t_{induction}`.
    - Otherwise, `t_{max} = \max\{ t < t_{final}\}` (the last target observation before `t_{final}`).

    Thus, for each time series one obtains a set of admissible slices

    .. math::
        J_i = \{ k∈ℤ ∣ t_{min}(TS_i) ≤ k*Δt ≤ t_{max}(TS_i) \}
        S_i = \{ TS_i[t≤k*Δt] ∣ k∈J_i \}

    The target metric is averaged over these slices, and each time series' weight
    is normalized by the number of slices.

    .. math::
        ℒ(θ) = 𝔼_i 𝔼_{S∈S_i} ℓ( ̂y(S, θ), y(S) )
    """

    def __init__(
        self,
        target: Literal["OD600", "Fluo_GFP"] = "Fluo_GFP",
        t_min: Union[str, Timedelta] = "0.6h",
        delta_t: Union[str, Timedelta] = "5m",
        eval_batch_size: int = 128,
        train_batch_size: int = 32,
    ) -> None:
        self.target = target
        self.delta_t = Timedelta(delta_t)
        self.t_min = Timedelta(t_min)
        self.eval_batch_size = eval_batch_size
        self.train_batch_size = train_batch_size

        # Set up the dataset. Runs 355, 445 and 482 are excluded, consistent with
        # the other cells of this notebook.
        excluded_runs = [355, 445, 482]
        self.dataset: Dataset = KIWI_RUNS()
        self.units: DataFrame = self.dataset.units
        self.metadata: DataFrame = self.dataset.metadata.drop(excluded_runs)
        self.timeseries: DataFrame = self.dataset.timeseries.drop(excluded_runs)

        # compute t_max, t_induction and t_final for each time series

    @cached_property
    def index(self) -> None: ...

    @cached_property
    def split_idx(self) -> DataFrame:
        """Return one column per shuffle split, labelling every metadata row "train" or "test".

        All rows sharing a ``(color, run_id)`` group are assigned to the same side.
        """
        # NOTE(review): `ShuffleSplit` (presumably sklearn.model_selection.ShuffleSplit)
        # is not imported in this notebook.
        splitter = ShuffleSplit(n_splits=5, random_state=0, test_size=0.25)
        groups = self.metadata.groupby(["color", "run_id"])
        group_idx = groups.ngroup()

        splits = DataFrame(index=self.metadata.index)
        # The splitter shuffles the group numbers 0..n_groups-1 (relies on
        # `len(groups)` being the number of groups).
        for i, (train, _) in enumerate(splitter.split(groups)):
            splits[i] = group_idx.isin(train).map({False: "test", True: "train"})

        splits.columns.name = "split"
        return splits.astype("string").astype("category")

    @cached_property
    def splits(self) -> dict[Any, tuple[DataFrame, DataFrame]]: ...

    def get_dataloader(self):
        # TODO: not implemented yet.
        ...