# Kiwi Runs - conceptual

Conceptually, this dataset is a `Mapping[TimeSeriesDataset]`, or even a `Mapping[Mapping[TimeSeriesDataset]]`, depending on how we treat the `run_id`, `experiment_id` Multi-Index.

Where a `TimeSeriesDataset` consists of 
  - time series data (typically `tuple[TimeTensor]`)
  - and static metadata (typically `tuple[Tensor]`)

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
rng = np.random.default_rng()

## imports

In [None]:
import logging
from collections.abc import *
from itertools import chain, count
from typing import Any, NamedTuple, Optional, TypeVar, Union

import numpy as np
import pandas
import pandas as pd
import torch
from pandas import DataFrame, Index, Interval, MultiIndex, Series, Timedelta, Timestamp
from torch import Tensor
from torch.utils.data import *
from torch.utils.data import Dataset as Torch_Dataset
from tqdm.auto import tqdm

from tsdm.datasets import KIWI_RUNS
from tsdm.datasets.torch import TimeSeriesDataset, TimeTensor
from tsdm.encoders import *
from tsdm.utils.strings import repr_mapping, repr_sequence

IndexedArray = Union[Series, DataFrame, TimeTensor]
r"""Type Hint for IndexedArrays."""

_IndexedArray = (Series, DataFrame, TimeTensor)
r"""TODO: replace with python 3.10"""

# Task Relevant Data

In [None]:
# Load the KIWI runs dataset and split it into its time series and metadata tables.
ds = KIWI_RUNS()
# The labels 355, 445 and 482 are dropped from both tables so that they stay aligned.
# NOTE(review): the reason for excluding these three labels is not recorded here.
ts = ds.timeseries.drop([355, 445, 482]).astype(float)
md = ds.metadata.drop([355, 445, 482])
# Name of the time series column that is used as the prediction target below.
target = "Fluo_GFP"

### Final Timestamp Table

In [None]:
def get_induction_time(s: Series) -> Timestamp:
    r"""Return the index label at which the inducer concentration changes.

    Parameters
    ----------
    s
        Data of a single experiment. ``s["InducerConcentration"]`` must yield the
        inducer concentration, indexed by time.

    Returns
    -------
    The index label of the single change in inducer concentration, or ``pd.NA`` if
    the difference between the last and the first concentration is zero or missing.

    Raises
    ------
    AssertionError
        If the number of detected concentration changes is not exactly one.
    """
    inducer = s["InducerConcentration"]
    # Positional access has to go through ``.iloc``: ``inducer[-1]`` is a *label*
    # lookup and only worked through pandas' deprecated positional fallback.
    total_induction = inducer.iloc[-1] - inducer.iloc[0]

    if pd.isna(total_induction) or total_induction == 0:
        return pd.NA

    # An induction is a time step whose concentration differs from the previous one.
    diff = inducer.diff()
    mask = pd.notna(diff) & (diff != 0.0)
    inductions = inducer[mask]
    assert len(inductions) == 1, "Multiple Inductions occur!"
    return inductions.first_valid_index()

In [None]:
def get_final_product(s: Series, target) -> Timestamp:
    r"""Return the index label of the last non-missing observation of ``target``.

    Raises an ``AssertionError`` if ``s[target]`` has no non-missing value at all.
    """
    column = s[target]
    # keep only the rows in which the target was actually observed
    targets = column[pd.notna(column)]
    assert len(targets) >= 1, f"not enough target observations {targets}"
    return targets.index[-1]

In [None]:
def get_time_table(ts: DataFrame, target="Fluo_GFP", t_min="0.6h", delta_t="5m"):
    r"""Build a table of characteristic time points for every experiment.

    ``ts`` is grouped by its first two index levels. For each group the row of
    the result holds:

    - ``t_min``: first index value of the group plus ``Timedelta(t_min)``
    - ``t_induction``: result of ``get_induction_time`` (may be ``pd.NA``)
    - ``t_target``: result of ``get_final_product`` for ``target``
    - ``t_max``: ``t_induction`` if available, otherwise the last ``target``
      observation strictly before ``t_target``
    - ``slice``: ``slice(t_min, t_max)``

    ``delta_t`` is accepted but not used by this function.
    """
    min_wait = Timedelta(t_min)
    index = ts.reset_index(level=[2]).index.unique()
    df = DataFrame(
        index=index,
        columns=["slice", "t_min", "t_induction", "t_max", "t_target"],
    )

    for idx, group in tqdm(ts.groupby(level=[0, 1])):
        # only the innermost index level remains inside a single experiment
        slc = group.reset_index(level=[0, 1], drop=True)
        t_induction = get_induction_time(slc)
        t_target = get_final_product(slc, target=target)

        if not pd.isna(t_induction):
            assert t_induction < t_target, f"{t_induction=} after {t_target}!"
            t_max = t_induction
        else:
            print(f"{idx}: no t_induction!")
            # fall back to the last target observation before the final one
            earlier = slc.loc[slc.index < t_target]
            t_max = get_final_product(earlier, target=target)
            assert t_max < t_target

        t_start = slc.index[0] + min_wait

        df.loc[idx, "t_max"] = t_max
        df.loc[idx, "t_min"] = t_start
        df.loc[idx, "t_induction"] = t_induction
        df.loc[idx, "t_target"] = t_target
        df.loc[idx, "slice"] = slice(t_start, t_max)
    return df

In [None]:
final_product_times = get_time_table(ts)

### Final Vector Table

In [None]:
# For every experiment in the metadata table, collect the row of the time series
# that was recorded at the experiment's ``t_target`` time.
final_vec = {}
for idx in md.index:
    t_target = final_product_times.loc[idx, "t_target"]
    # key: the experiment index extended by the target timestamp
    final_vec[(*idx, t_target)] = ts.loc[idx].loc[t_target]

# Turn the collected rows into a DataFrame that uses the index names of ``ts``,
# then keep only the target column.
final_vec = DataFrame.from_dict(final_vec, orient="index")
final_vec.index = final_vec.index.set_names(ts.index.names)
final_vec = final_vec[target]

# Implementation of Dataset Objects

## DatasetMapping

In [None]:
class DatasetMapping(Mapping, Torch_Dataset):
    r"""Represents a ``mapping[index → torch.Datasets]``.

    All tensors must have a shared index,
    in the sense that index.unique() is identical for all inputs.

    Lookup is hierarchical: ``self[key, *rest]`` resolves to
    ``self.dataset[key][rest]`` unless ``(key, *rest)`` itself is a key.
    """

    dataset: dict[Any, Torch_Dataset]
    r"""The wrapped datasets, keyed by their index."""

    def __init__(self, indexed_datasets: Mapping[Any, Torch_Dataset]):
        super().__init__()
        self.dataset = dict(indexed_datasets)
        self.index = self.dataset.keys()
        # Delegate the mapping views directly to the wrapped dict.
        self.keys = self.dataset.keys  # type: ignore[assignment]
        self.values = self.dataset.values  # type: ignore[assignment]
        self.items = self.dataset.items  # type: ignore[assignment]

    def __len__(self):
        r"""Length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, item):
        r"""Hierarchical lookup.

        Strings and bytes are always treated as atomic keys. A non-empty sequence
        whose first element is iterable (or a slice) is resolved hierarchically,
        except when the sequence itself is a key of the mapping.
        """
        is_composite = (
            isinstance(item, Sequence)
            and not isinstance(item, (str, bytes))  # "abc" must not become "a", "bc"
            and len(item) > 0
        )
        if is_composite:
            first, rest = item[0], item[1:]
            try:
                is_direct_key = item in self.dataset
            except TypeError:  # unhashable item, cannot be a key
                is_direct_key = False
            if isinstance(first, (Iterable, slice)) and not is_direct_key:
                # pass remaining indices to sub-object
                return self.dataset[first][rest]

        # no hierarchical indexing
        return self.dataset[item]

    def __iter__(self):
        r"""Iterate over the datasets.

        NOTE: this yields the *values*, not the keys as the ``Mapping`` protocol
        would suggest; kept as is because callers may rely on it.
        """
        for key in self.index:
            yield self.dataset[key]

    def __repr__(self):
        r"""Representation of the dataset."""
        return repr_mapping(self)

## TimeSeriesDataset

In [None]:
TimeLikeTensor = Any
StaticTensor = Any


def repr_array(obj, title: Optional[str] = None):
    r"""Return ``TypeName[shape]`` for array-likes, recursing into containers.

    Objects without a ``shape`` attribute are iterated and rendered as a
    bracketed, comma separated list. ``title`` is accepted but not used.
    """
    if not hasattr(obj, "shape"):
        parts = [repr_array(element) for element in obj]
        return f"[{', '.join(parts)}]"
    return f"{type(obj).__name__}[{obj.shape}]"


def repr_singleton_or_tuple(obj, repr_fun: Optional[Callable] = None):
    r"""Format ``obj`` with ``repr_fun`` (default: ``repr``).

    Tuples are handed to ``repr_sequence`` with the same formatter; everything
    else is formatted directly.
    """
    formatter = repr_fun if repr_fun is not None else repr

    if not isinstance(obj, tuple):
        return formatter(obj)
    return repr_sequence(obj, repr_fun=formatter)


class TimeSeriesSlice(NamedTuple):
    # Pair of (sliced time series, static metadata), as returned when indexing
    # a time series dataset.
    timeseries: tuple[TimeLikeTensor, ...]
    metadata: tuple[StaticTensor, ...]

    def __repr__(self):
        # One indented line per field, each rendered through ``repr_array``.
        pad = " " * 2
        fields = (("timeseries", self.timeseries), ("metadata", self.metadata))
        body = "".join(
            f"\n{pad}{name}=" + repr_singleton_or_tuple(value, repr_array)
            for name, value in fields
        )
        return f"{type(self).__name__}({body}\n)"


def tensor_info(x: Tensor) -> str:
    r"""Summarize ``x`` as ``ClassName[(dim0, dim1, ...)]``."""
    kind = x.__class__.__name__
    shape = tuple(x.shape)
    return f"{kind}[{shape}]"

In [None]:
class TimeSeriesDataset(TensorDataset):
    """A general Time Series Dataset.

    Consists of 2 things
    - timeseries: TimeTensor / tuple[TimeTensor]
    - metadata: Tensor / tuple[Tensor]

    When retrieving items, we generally use slices:

    - ds[timestamp] = ds[timestamp:timestamp]
    - ds[t₀:t₁] = tuple[X[t₀:t₁] for X in self.timeseries], metadata
    """

    timeseries: Union[TimeLikeTensor, tuple[TimeLikeTensor, ...]]
    metadata: Optional[Union[StaticTensor, tuple[StaticTensor, ...]]] = None

    def __init__(
        self,
        timeseries: Union[IndexedArray, Collection[IndexedArray]],
        *timetensors: IndexedArray,
        metadata: Optional[Union[Tensor, Collection[Tensor]]] = None,
    ):
        r"""Store the time series and the metadata.

        Parameters
        ----------
        timeseries
            A single indexed array or a collection of indexed arrays.
        timetensors
            Further indexed arrays, appended to ``timeseries``.
        metadata
            Static data that is returned unchanged with every slice.
        """
        super().__init__()

        ts_tensors = (
            [timeseries] if isinstance(timeseries, _IndexedArray) else list(timeseries)
        )
        ts_tensors.extend(timetensors)
        # a single series is stored unwrapped, several series as a tuple
        self.timeseries = ts_tensors[0] if len(ts_tensors) == 1 else tuple(ts_tensors)
        self.metadata = metadata

    def __repr__(self) -> str:
        r"""Pretty print."""
        pad = r"  "

        if isinstance(self.timeseries, tuple):
            ts_lines = [tensor_info(tensor) for tensor in self.timeseries]
        else:
            ts_lines = [tensor_info(self.timeseries)]

        if self.metadata is None:
            md_lines = [f"{None}"]
        elif isinstance(self.metadata, tuple):
            md_lines = [tensor_info(tensor) for tensor in self.metadata]
        else:
            md_lines = [tensor_info(self.metadata)]

        return (
            f"{self.__class__.__name__}("
            + "".join(["\n" + 2 * pad + line for line in ts_lines])
            + "\n"
            + pad
            + "metadata:"
            + "".join(["\n" + 2 * pad + line for line in md_lines])
            + "\n"
            + ")"
        )

    def __len__(self) -> int:
        r"""Return the length of the longest timeseries.

        The previous implementation returned ``max(index) - min(index)``, which is
        a ``Timedelta`` for time-indexed data and therefore made ``len()`` raise a
        ``TypeError``. The number of observations is always a valid length.
        """
        if isinstance(self.timeseries, tuple):
            series = self.timeseries
        else:
            series = (self.timeseries,)
        # ``default=0`` covers a dataset that was built from an empty collection
        return max((len(ts) for ts in series), default=0)

    def __getitem__(self, item):
        r"""Return corresponding slice from each tensor."""
        if isinstance(self.timeseries, tuple):
            ts = tuple(tensor.loc[item] for tensor in self.timeseries)
        else:
            ts = self.timeseries.loc[item]
        md = self.metadata
        return TimeSeriesSlice(ts, md)

## MappingDataset

In [None]:
class MappingDataset(Torch_Dataset, Mapping):
    r"""Represents a Mapping[Key, Dataset].

    Tuple keys are resolved hierarchically: ``self[k, *rest]`` first tries
    ``self.data[k][rest]`` and falls back to ``self.data[(k, *rest)]``.
    With ``prepend_key=True`` every lookup returns ``(key, value)``.
    """

    def __init__(self, data: Mapping, prepend_key: bool = False):
        super().__init__()
        # Explicit check instead of ``assert``: assertions vanish under ``-O``,
        # which would leave the instance without ``index`` and ``data``.
        if not isinstance(data, Mapping):
            raise TypeError(f"data must be a Mapping, got {type(data).__name__}")
        self.index = data.keys()
        self.data = data
        self.prepend_key = prepend_key

    def __iter__(self) -> Iterator:
        r"""Iterate over the keys."""
        return iter(self.index)

    def __len__(self):
        r"""Number of keys."""
        return len(self.index)

    def _lookup(self, key):
        r"""Resolve ``key`` directly or hierarchically (tuples only)."""
        if not isinstance(key, tuple) or not key:
            return self.data[key]
        try:
            outer = self.data[key[0]]
            return outer[key[1:]]
        except KeyError as hierarchical_error:
            try:
                return self.data[key]
            except TypeError:
                # ``key`` contains an unhashable part (e.g. a slice or list), so it
                # can not be a direct key; report the real lookup failure instead
                # of an unrelated "unhashable type" error.
                raise hierarchical_error from None

    def __getitem__(self, key):
        r"""Return the value for ``key``, optionally together with the key."""
        value = self._lookup(key)
        if self.prepend_key:
            return key, value
        return value

    @staticmethod
    def from_dataframe(df: DataFrame, levels: Optional[list[str]] = None):
        r"""Create a dataset with one ``df.loc[idx]`` entry per index value.

        If ``levels`` is given, the unique combinations of these index levels are
        used as keys, otherwise the full index of ``df``.
        """
        if levels is not None:
            mindex = df.index.to_frame()
            subidx = MultiIndex.from_frame(mindex[levels])
            index = subidx.unique()
        else:
            index = df.index

        return MappingDataset({idx: df.loc[idx] for idx in index})

    def __repr__(self):
        r"""Representation of the dataset."""
        return repr_mapping(self)  # , repr_fun=repr_array)

# Implementation of Sampler Objects

## MappingSampler

In [None]:
class MappingSampler(Sampler):
    r"""Yield the values of a mapping, optionally in random order.

    Parameters
    ----------
    data_source
        The mapping to sample from.
    shuffle
        If ``True`` (default) the values are visited in a random order,
        otherwise in the iteration order of the mapping.
    """

    def __init__(self, data_source: Mapping, shuffle: bool = True):
        # super().__init__(data_source)
        self.data = data_source
        self.shuffle = shuffle

    def __len__(self) -> int:
        r"""Number of entries of the mapping."""
        return len(self.data)

    def __iter__(self):
        r"""Yield ``self.data[key]`` for every key."""
        keys = list(self.data.keys())
        # Permute positions instead of the keys themselves: numpy would turn tuple
        # keys into rows of a 2-D array, which are unhashable and no longer keys.
        if self.shuffle:
            order = np.random.permutation(len(keys))
        else:
            order = range(len(keys))
        for position in order:
            yield self.data[keys[position]]

## IntervalSampler

In [None]:
TimedeltaLike = TypeVar("TimedeltaLike", int, float, Timedelta)
TimestampLike = TypeVar("TimestampLike", int, float, Timestamp)


def grid(
    xmin: TimestampLike,
    xmax: TimestampLike,
    delta: TimedeltaLike,
    xoffset: Optional[TimestampLike] = None,
) -> list[int]:
    r"""Computes `\{k∈ℤ∣ xₘᵢₙ ≤ x₀+k⋅Δ ≤ xₘₐₓ\}`.

    Special case: if Δ=0, returns [0]

    Parameters
    ----------
    xmin
        Lower bound of the interval.
    xmax
        Upper bound of the interval.
    delta
        Grid spacing Δ, must be non-negative.
    xoffset
        Grid origin x₀, defaults to ``xmin``. Must lie in ``[xmin, xmax]``.

    Returns
    -------
    All integers ``k`` with ``xmin ≤ x₀ + k⋅Δ ≤ xmax``, in increasing order.
    """
    # ``xo`` is the effective origin; ``xoffset`` itself may be None.
    xo = xmin if xoffset is None else xoffset
    zero = type(delta)(0)

    if delta == zero:
        return [0]

    assert delta > zero, "Assumption delta>0 violated!"
    assert xmin <= xo <= xmax, "Assumption: xmin≤xoffset≤xmax violated!"

    # smallest k with xo + k⋅Δ ≥ xmin  ⇔  k ≥ -(xo - xmin)/Δ  ⇔  k = -⌊(xo - xmin)/Δ⌋
    kmin = -int((xo - xmin) // delta)
    # largest k with xo + k⋅Δ ≤ xmax  ⇔  k = ⌊(xmax - xo)/Δ⌋
    kmax = int((xmax - xo) // delta)

    # sanity checks: the bounds are tight
    assert xmin <= xo + kmin * delta
    assert xmin > xo + (kmin - 1) * delta
    assert xmax >= xo + kmax * delta
    assert xmax < xo + (kmax + 1) * delta

    return list(range(kmin, kmax + 1))

In [None]:
# Generic value type for the "boxed" helper types below.
V = TypeVar("V")

# A level-indexed collection of values: looked up by position (Sequence),
# by key (Mapping) or computed from the level number (Callable).
Boxed = Union[
    Sequence[V],
    Mapping[int, V],
    Callable[[int], V],
]

# Either a single timedelta-like value or a level-indexed collection of them.
dt_type = Union[
    TimedeltaLike,
    Sequence[TimedeltaLike],
    Mapping[int, TimedeltaLike],
    Callable[[int], TimedeltaLike],
]

import numpy as np
from torch.utils.data import Sampler


class IntervalSampler(
    Sampler,
):
    """Returns all intervals `[a, b]` such that:

    - `a = t₀ + i⋅sₖ`
    - `b = t₀ + i⋅sₖ + Δtₖ`
    - `i, k ∈ ℤ`
    - `a ≥ t_min`
    - `b ≤ t_max`
    - `sₖ` is the stride corresponding to intervals of size `Δtₖ`

    ``deltax``, ``stride`` and ``offset`` may each be a single value or a
    level-indexed collection (sequence, mapping or callable of the level ``k``).
    The resulting intervals are stored in the DataFrame ``self.intervals`` with
    the columns ``left``, ``right``, ``delta`` and ``stride``; iterating over the
    sampler yields ``slice(left, right)`` objects.
    """

    def __init__(
        self,
        xmin,
        xmax,
        deltax: dt_type,
        stride: Optional[dt_type] = None,
        levels: Optional[Sequence[int]] = None,
        offset: Optional[dt_type] = None,
        multiples: bool = True,
        shuffle: bool = True,
    ) -> None:
        r"""Compute the table of valid intervals.

        Parameters
        ----------
        xmin, xmax
            Bounds that every interval has to respect.
        deltax
            Interval size(s) Δtₖ.
        stride
            Stride(s) sₖ; defaults to a zero stride, i.e. one interval per level.
        levels
            Levels ``k`` to consider. If omitted they are derived from ``deltax``.
        offset
            Grid origin t₀, defaults to ``xmin``.
        multiples
            Currently unused.
        shuffle
            Whether ``__iter__`` visits the intervals in random order.
        """
        # set stride and offset
        zero = 0 * (xmax - xmin)
        stride = zero if stride is None else stride
        offset = xmin if offset is None else offset

        # validate bounds
        assert xmin <= offset <= xmax, "Assumption: xmin≤xoffset≤xmax violated!"

        # determine delta_max
        delta_max = max(offset - xmin, xmax - offset)

        # determine levels
        if levels is None:
            if isinstance(deltax, Mapping):
                levels = [k for k in deltax.keys() if deltax[k] <= delta_max]
            elif isinstance(deltax, Sequence):
                levels = [k for k in range(len(deltax)) if deltax[k] <= delta_max]
            elif callable(deltax):
                # NOTE: terminates only if deltax(k) eventually exceeds delta_max
                levels = []
                for k in count():
                    dt = self._get_value(deltax, k)
                    if dt == zero:
                        continue
                    if dt > delta_max:
                        break
                    levels.append(k)
            else:
                levels = [0]
        else:
            levels = [k for k in levels if self._get_value(deltax, k) <= delta_max]

        # validate levels
        assert all(self._get_value(deltax, k) <= delta_max for k in levels)

        # compute valid intervals as (left, right, delta, stride) rows
        intervals: list[tuple] = []

        # for each level, get all intervals
        for k in levels:
            dt = self._get_value(deltax, k)
            st = self._get_value(stride, k)
            x0 = self._get_value(offset, k)

            # get valid interval bounds, probably there is an easier way to do it...
            stridesa = grid(xmin, xmax, st, x0)
            stridesb = grid(xmin, xmax, st, x0 + dt)
            valid_strides = set.intersection(set(stridesa), set(stridesb))

            if not valid_strides:
                break

            # sorted: keep the table independent of set iteration order
            intervals.extend([
                (x0 + i * st, x0 + i * st + dt, dt, st) for i in sorted(valid_strides)
            ])

        # set variables
        self.offset = offset
        self.deltax = deltax
        self.stride = stride
        self.shuffle = shuffle
        self.intervals = DataFrame(
            intervals, columns=["left", "right", "delta", "stride"]
        )

    def __iter__(self) -> Iterator:
        r"""Yield ``slice(left, right)`` for every interval."""
        if self.shuffle:
            perm = np.random.permutation(len(self))
        else:
            perm = np.arange(len(self))

        table = self.intervals
        for k in perm:
            yield slice(table.loc[k, "left"], table.loc[k, "right"])

    def __len__(self) -> int:
        r"""Number of intervals."""
        return len(self.intervals)

    def __getattr__(self, key):
        r"""Forward unknown attributes to the interval table."""
        # ``__getattr__`` is only called when normal lookup fails. If ``intervals``
        # itself is missing (failed ``__init__``, copying, unpickling), forwarding
        # would recurse forever.
        if key == "intervals":
            raise AttributeError(key)
        return getattr(self.intervals, key)

    def __getitem__(self, key):
        r"""Return ``self.intervals[key]``."""
        return self.intervals[key]

    @staticmethod
    def _get_value(obj: Union[V, Boxed[V]], k: int) -> V:
        r"""Return the value of ``obj`` for level ``k``.

        Callables are called with ``k``, mappings and sequences are indexed with
        ``k``, anything else is taken to be the same value for all levels.
        """
        if callable(obj):
            return obj(k)
        if isinstance(obj, (Mapping, Sequence)):
            return obj[k]
        # Fallback: multiple!
        return obj

## HierarchicalSampler

In [None]:
class HierarchicalSampler(Sampler):
    r"""Samples a single random dataset from a collection of dataset.

    Optionally, we can delegate a subsampler to then sample from the randomly drawn dataset.

    Iterating yields pairs ``(key, sample)``, where ``sample`` is the next item
    of the subsampler that belongs to ``key``.
    """

    idx: Index
    r"""The shared index."""
    subsamplers: Mapping[Any, Sampler]
    r"""The subsamplers to sample from the collection."""
    early_stop: bool = False
    r"""Whether to stop sampling when the index is exhausted."""
    shuffle: bool = True
    r"""Whether to sample in random order."""
    sizes: Series
    r"""The sizes of the subsamplers."""
    partition: Series
    r"""Contains each key a number of times equal to the size of the subsampler."""

    def __init__(
        self,
        data_source: Dataset,
        subsamplers: Mapping[Any, Sampler],
        shuffle: bool = True,
        early_stop: bool = False,
    ):
        r"""Set up the sampling plan.

        Parameters
        ----------
        data_source
            Collection of datasets; only its ``keys()`` are used here.
        subsamplers
            One sized, iterable sampler for every key of ``data_source``.
        shuffle
            Whether the keys are visited in random order.
        early_stop
            If ``True``, every key is drawn only as often as the smallest
            subsampler allows.
        """
        # ``Sampler.__init__`` is deliberately not called, like in the other samplers
        # of this file: it stores nothing and its ``data_source`` argument is
        # deprecated in recent torch versions.
        self.data = data_source
        self.idx = data_source.keys()
        self.subsamplers = dict(subsamplers)
        sizes = {key: len(self.subsamplers[key]) for key in self.idx}
        self.sizes = Series(sizes)
        self.shuffle = shuffle
        self.early_stop = early_stop

        if early_stop:
            shortest = min(sizes.values(), default=0)
            counts = {key: shortest for key in sizes}
        else:
            counts = sizes
        self.partition = Series(
            [key for key, number in counts.items() for _ in range(number)]
        )

    def __len__(self):
        r"""Return the number of samples produced by one pass of ``__iter__``."""
        # Derived from the partition, so that it always agrees with ``__iter__``,
        # even if ``subsamplers`` contains keys that are not part of the index.
        return len(self.partition)

    def __iter__(self):
        r"""Return indices of the samples.

        When ``early_stop=True``, it will sample precisely min() * len(subsamplers) samples.
        When ``early_stop=False``, it will sample all samples.
        """
        activate_iterators = {
            key: iter(sampler) for key, sampler in self.subsamplers.items()
        }

        if self.shuffle:
            perm = np.random.permutation(self.partition)
        else:
            perm = self.partition

        for key in perm:
            yield key, next(activate_iterators[key])

    def __getitem__(self, key: Any) -> Sampler:
        r"""Return the subsampler for the given key."""
        return self.subsamplers[key]

    def __repr__(self):
        r"""Representation of the subsamplers."""
        return repr_mapping(self.subsamplers)

# Plugging everything together

## Preprocessing

In [None]:
# d = {}
# for idx, slc in ts.groupby(["run_id", "experiment_id"]):
#     slc = slc.reset_index([0, 1], drop=True)
#     lower = slc.index[0]
#     slc.index = slc.index - slc.index[0]
#     d[idx] = slc
# ts = pandas.concat(d, names=["run_id", "experiment_id"])

## Dataset Object

In [None]:
outer_index = md.index
# One TimeSeriesDataset per experiment: the experiment's time series, together with
# its metadata row and its final target value as static metadata.
TSDs = {}
for idx in md.index:
    TSDs[idx] = TimeSeriesDataset(
        ts.loc[idx],
        metadata=(md.loc[idx], final_vec.loc[idx]),
    )

# With prepend_key=True every lookup returns (key, value) instead of just the value.
DS = MappingDataset(TSDs, prepend_key=True)

## Samplers 

### the subsampler dictionary

In [None]:
# Base interval length; level k of a subsampler uses intervals of length k * delta_t.
delta_t = Timedelta("5m")

subsamplers = {}

# One IntervalSampler per experiment, restricted to that experiment's [t_min, t_max].
for key in TSDs:
    subsampler = IntervalSampler(
        xmin=final_product_times.loc[key, "t_min"],
        xmax=final_product_times.loc[key, "t_max"],
        # offset=t_0,
        deltax=lambda k: k * delta_t,
        stride=None,
        shuffle=True,
    )
    subsamplers[key] = subsampler
# The hierarchical sampler yields (experiment key, interval slice) pairs.
sampler = HierarchicalSampler(TSDs, subsamplers, shuffle=True)

In [None]:
ts.loc[439].loc[
    15325
]  # .loc[Timedelta('0 days 00:32:18'): Timedelta('0 days 01:28:47')]

In [None]:
idx = next(iter(sampler))
ts.loc[(439, 15325)]

### store indices to .csv

In [None]:
# df = DataFrame(list(iter(sampler)), columns=["idx", "slice"])
# df["start"] = df["slice"].apply(lambda x: x.start) / Timedelta("1h")
# df["stop"] = df["slice"].apply(lambda x: x.stop) / Timedelta("1h")
# df["run_id"], df["experiment_id"] = zip(*df.idx)
# df = df[["run_id", "experiment_id", "start", "stop"]].round(2)
# df = df.sort_values(["run_id", "experiment_id", "start", "stop"])

# DataLoader

## Simple (no collate_fn)

In [None]:
def collate(x):
    r"""Identity collate function: return the list of samples unchanged."""
    batch = x
    return batch


dloader = DataLoader(DS, sampler=sampler, collate_fn=collate, batch_size=8)

In [None]:
batch = next(iter(dloader))

## With PreProcessing

## With Post-Processing

Batch ⟶ Post-Processed Batch

- convert things to `tensor`s if not already done
- combine `list[Tensor]` to `Tensor` / `PaddedTensor` / `PackedTensor`


Specific Example: Final Product Value

1. Data Lookup:
    - main data encoding: 
        - main data: Standardization 
        - index: MinMax
        - triplet encoding
    
    - target data lookup:
        - Separate Dataset?
        - TupleDataset? <= why not? i.e. lookup both inputs and target at once!
            - This one only needs the first level of index, so just drop the other?
            - Treat it as different kind of metadata? ✓ Sounds good!

2. Encoding
    - main data: just use the encoder as advertised.
    - target data: need a separate "target_encoder"
        - could be a slice of the regular encoder
        

In [None]:
preprocessor = ChainedEncoder(
    # ConcatEncoder(axis=-1),
    TensorEncoder(names=("time", "value", "index")),
    DataFrameEncoder(
        column_encoders={
            "value": IdentityEncoder(),
            tuple(ts.columns): FloatEncoder("float32"),
        },
        index_encoders=MinMaxScaler() @ DateTimeEncoder(unit="h"),
    ),
    TripletEncoder(sparse=True),
    # DataFrameEncoder(
    Standardizer(),
    # index_encoders = MinMaxScaler() @ DateTimeEncoder(unit="h"),
    # ),
)

encoder = preprocessor

In [None]:
original = ts.loc[439, 15325].round(2)

In [None]:
encoder.fit(ts.reset_index([0, 1], drop=True))

In [None]:
encoded = encoder.encode(ts.reset_index([0, 1], drop=True))

In [None]:
encoder.encode(original.iloc[[-10]])

In [None]:
final_vec.loc[439, 15325]

In [None]:
class Batch(NamedTuple):
    # Container for one collated batch.
    index: Tensor
    timeseries: Tensor
    metadata: Tensor

    def __repr__(self):
        # Render the fields as a mapping, titled with the class name.
        fields = self._asdict()
        name = self.__class__.__name__
        return repr_mapping(fields, title=name, repr_fun=repr_array)


def mycollate(batch: list):
    r"""Collate ``(key, (timeseries, metadata))`` samples into a ``Batch``.

    The first entry of every key is converted to an integer tensor and the
    tensors are stacked; time series and metadata are collected into plain
    lists without any conversion.
    """
    keys = [torch.tensor(idx[0], dtype=int) for idx, _ in batch]
    timeseries = [ts for _, (ts, _) in batch]
    metadata = [md for _, (_, md) in batch]
    return Batch(torch.stack(keys), timeseries, metadata)

In [None]:
dloader = DataLoader(DS, sampler=sampler, collate_fn=mycollate, batch_size=8)
batch = next(iter(dloader))

## Encode Target Only

In [None]:
target_encoder = TensorEncoder() @ FloatEncoder() @ Standardizer(axis=())

In [None]:
target_col = ts[target]
target_idx = ts.columns.get_loc(target)
target_encoder.fit(target_col)
result = target_encoder.encode(target_col)
mask = torch.isnan(result)
result[~mask]
# mask = pd.notna(result)
# result[mask]

In [None]:
# reuse the other encoder via slicing
target_encoder = TensorEncoder() @ FloatEncoder() @ encoder[-1][target_idx]

In [None]:
result = target_encoder.encode(target_col)
mask = torch.isnan(result)
result[~mask]

# The **Final** result

In [None]:
next(iter(sampler))

In [None]:
DS[next(iter(sampler))]

In [None]:
# Manual post-processing of one batch: encode every time series and every target.
timeseries = []
metadata = []
targets = []
encoded_targets = []

# NOTE: the loop variable ``target`` rebinds the module level ``target``
# (the column name defined at the top of the notebook).
for ts_data, (md_data, target) in zip(batch.timeseries, batch.metadata):
    timeseries.append(preprocessor.encode(ts_data))
    metadata.append(md_data)
    targets.append(target)
    encoded_targets.append(target_encoder.encode(target))

# Combine the per-sample targets into a single pandas object / tensor.
targets = pandas.concat(targets)
encoded_targets = torch.concat(encoded_targets)

In [None]:
def get_value_from_triplet(triplet, idx: int):
    r"""Return the first value of a triplet-encoded series that belongs to channel ``idx``.

    Parameters
    ----------
    triplet
        Tuple ``(time, value, index)``; the channel of each row of ``index`` is
        determined via ``argmax`` along dimension 1.
    idx
        The channel to look for.

    Raises
    ------
    ValueError
        If no observation of channel ``idx`` is present.
    """
    _, value, index = triplet
    # The previous code compared against the undefined name ``target_index``
    # instead of the ``idx`` parameter.
    matches = torch.argmax(index, dim=1) == idx
    if not bool(matches.any()):
        # argmax of an all-zero mask would silently select position 0
        raise ValueError(f"No observation found for index {idx}")
    # argmax of the 0/1 mask gives the position of the first match
    first = torch.argmax(matches.to(int))
    return value[first]

In [None]:
class Batch(NamedTuple):
    # Container for one collated batch, including raw and encoded targets.
    index: Tensor
    timeseries: Tensor
    metadata: Tensor
    targets: Tensor
    encoded_targets: Tensor

    def __repr__(self):
        # Render the fields as a mapping, titled with the class name.
        fields = self._asdict()
        name = self.__class__.__name__
        return repr_mapping(fields, title=name, repr_fun=repr_array)


def mycollate(batch: list):
    r"""Collate and encode ``(key, (timeseries, (metadata, target)))`` samples.

    Time series are encoded with the module level ``preprocessor``, targets
    with the module level ``target_encoder``. The raw targets are concatenated
    with pandas, the encoded targets with torch.
    """
    keys: list[Tensor] = []
    encoded_series = []
    static = []
    raw_targets = []
    enc_targets = []

    for idx, (ts_data, (md_data, target)) in batch:
        keys.append(torch.tensor(idx[0]))
        encoded_series.append(preprocessor.encode(ts_data))
        static.append(md_data)
        raw_targets.append(target)
        enc_targets.append(target_encoder.encode(target))

    return Batch(
        torch.stack(keys),
        encoded_series,
        static,
        pandas.concat(raw_targets),
        torch.concat(enc_targets),
    )

In [None]:
dloader = DataLoader(DS, sampler=sampler, collate_fn=mycollate, batch_size=8)
batch = next(iter(dloader))

In [None]:
# Draw a single key from the sampler and look the sample up by hand.
key = next(iter(dloader.sampler))
sample = dloader.dataset[key]

# NOTE: this unpacking rebinds the module level names ``key``, ``ts``, ``md`` and
# ``target``; cells that use the original tables must not be run after this one.
(key, slc), (ts, (md, target)) = sample

In [None]:
for _ in tqdm(dloader): ...

In [None]:
from tsdm.models import SetFuncTS

model = SetFuncTS(17, 1)

In [None]:
t, v, m = batch.timeseries[0]
model(t, v, m)

In [None]:
model.batch_forward(batch.timeseries)

# Modified DataFrameEncoder

In [None]:
class DataFrameEncoder(BaseEncoder):
    r"""Combine multiple encoders into a single one.

    It is assumed that the DataFrame Modality doesn't change.

    The assignment of encoders to columns is kept in ``self.spec``, a DataFrame
    with the columns ``col`` and ``encoder`` and a ``(section, partition)``
    MultiIndex, where ``section`` is either ``"index"`` or ``"columns"``.
    """

    column_encoders: Union[BaseEncoder, Mapping[Any, BaseEncoder]]
    r"""Encoders for the columns."""
    index_encoders: Optional[Union[BaseEncoder, Mapping[Any, BaseEncoder]]] = None
    r"""Optional Encoder for the index."""
    colspec: Series = None
    r"""The columns-specification of the DataFrame."""
    encode_index: bool
    r"""Whether to encode the index."""
    column_wise: bool
    r"""Whether to encode column-wise."""
    partitions: Optional[dict] = None
    r"""Contains partitions if used column wise."""

    def __init__(
        self,
        column_encoders: Union[BaseEncoder, Mapping[Any, BaseEncoder]],
        *,
        index_encoders: Optional[Union[BaseEncoder, Mapping[Any, BaseEncoder]]] = None,
    ):
        r"""Set up the individual encoders.

        Note: the same encoder instance can be used for multiple columns.

        Parameters
        ----------
        column_encoders
            Either a single encoder for the whole frame, or a mapping from
            column keys to encoders.
        index_encoders
            Optional encoder for the index. Mappings are not supported yet.

        Raises
        ------
        NotImplementedError
            If ``index_encoders`` is a mapping.
        """
        # Local import: ``defaultdict`` is not imported explicitly at the top of
        # the notebook.
        from collections import defaultdict

        super().__init__()
        self.column_encoders = column_encoders

        if isinstance(index_encoders, Mapping):
            raise NotImplementedError("Multi-Index encoders not yet supported")

        self.index_encoders = index_encoders
        self.column_wise: bool = isinstance(self.column_encoders, Mapping)
        self.encode_index: bool = index_encoders is not None

        index_spec = DataFrame(
            columns=["col", "encoder"],
            index=Index([], name="partition"),
        )

        if self.encode_index:
            if not isinstance(self.index_encoders, Mapping):
                _idxenc_spec = Series(
                    {
                        "col": pd.NA,
                        "encoder": self.index_encoders,
                    },
                    name=0,
                )
                # index_spec = index_spec.append(_idxenc_spec)
                index_spec.loc[0] = _idxenc_spec
            else:
                raise NotImplementedError(
                    "Multiple Index encoders are not supported yet."
                )

        if not isinstance(self.column_encoders, Mapping):
            # a single encoder for all columns; the column names are filled in by fit
            colenc_spec = DataFrame(
                columns=["col", "encoder"],
                index=Index([], name="partition"),
            )

            _colenc_spec = Series(
                {
                    "col": pd.NA,
                    "encoder": self.column_encoders,
                },
                name=0,
            )
            # colenc_spec = colenc_spec.append(_colenc_spec)
            # colenc_spec = pandas.concat([colenc_spec, _colenc_spec])
            colenc_spec.loc[0] = _colenc_spec
        else:
            keys = self.column_encoders.keys()
            assert len(set(keys)) == len(keys), "Some index are duplicates!"

            # one partition per distinct encoder instance
            _encoders = tuple(set(self.column_encoders.values()))
            encoders = Series(_encoders, name="encoder")
            partitions = Series(range(len(_encoders)), name="partition")

            # group the column keys by the encoder they share
            _columns = defaultdict(list)
            for key, encoder in self.column_encoders.items():
                _columns[encoder].append(key)

            columns = Series(_columns, name="col")

            colenc_spec = DataFrame(encoders, index=partitions)
            colenc_spec = colenc_spec.join(columns, on="encoder")

        self.spec = pandas.concat(
            [index_spec, colenc_spec],
            keys=["index", "columns"],
            names=["section", "partition"],
        ).astype({"col": object})

        self.spec.name = self.__class__.__name__

        # add extra repr options by cloning from spec.
        # for x in [
        #     "_repr_data_resource_",
        #     "_repr_fits_horizontal_",
        #     "_repr_fits_vertical_",
        #     "_repr_html_",
        #     "_repr_latex_",
        # ]:
        #     setattr(self, x, getattr(self.spec, x))

    def fit(self, df: DataFrame, /) -> None:
        r"""Fit to the data."""
        self.colspec = df.dtypes

        if self.index_encoders is not None:
            if isinstance(self.index_encoders, Mapping):
                raise NotImplementedError("Multiple index encoders not yet supported")
            self.index_encoders.fit(df.index)

        if isinstance(self.column_encoders, Mapping):
            # check if cols are a proper partition.
            keys = set(df.columns)
            _keys = set(self.column_encoders.keys())
            assert keys <= _keys, f"Missing encoders for columns {keys - _keys}!"
            assert (
                keys >= _keys
            ), f"Encoder given for non-existent columns {_keys - keys}!"

            for _, series in self.spec.loc["columns"].iterrows():
                encoder = series["encoder"]
                cols = series["col"]
                encoder.fit(df[cols])
        else:
            cols = list(df.columns)
            # Write the cell directly. The previous chained assignment
            # ``spec.loc["columns"].iloc[0]["col"] = cols`` modified a temporary
            # object and is not guaranteed to reach ``self.spec``.
            self.spec.at[("columns", 0), "col"] = cols
            encoder = self.spec.loc["columns", "encoder"].item()
            encoder.fit(df)

    def encode(self, df: DataFrame, /) -> DataFrame:
        r"""Encode the input."""
        encoded_frames: dict[Any, DataFrame] = {}
        for partition, (col_names, encoder) in self.spec.loc["columns"].iterrows():
            encoded_frame = encoder.encode(df[col_names])
            encoded_frames[partition] = encoded_frame

        if self.index_encoders is not None:
            if isinstance(self.index_encoders, Mapping):
                raise NotImplementedError("Multiple index encoders not yet supported")
            encoded_index = self.index_encoders.encode(df.index)
        else:
            encoded_index = df.index

        encoded = pandas.concat(
            encoded_frames, axis="columns", names=["partition", df.columns.name]
        )
        encoded = encoded.droplevel(
            "partition", axis="columns"
        )  # remove partition index
        encoded = encoded.set_index(encoded_index)
        encoded = encoded[df.columns]  # fix column order
        return encoded

    def decode(self, data: DataFrame, /) -> DataFrame:
        r"""Decode the input."""
        if self.encode_index:
            if isinstance(self.index_encoders, Mapping):
                raise NotImplementedError("Multiple index encoders not yet supported")
            encoder = self.spec.loc["index", "encoder"].item()
            decoded_index = encoder.decode(data.index)
        else:
            decoded_index = None

        decoded_frames: dict[Any, DataFrame] = {}
        for partition, (col_names, encoder) in self.spec.loc["columns"].iterrows():
            # col_names += col_names
            decoded_frame = encoder.decode(data[col_names])
            decoded_frames[partition] = decoded_frame

        # return decoded_frames
        decoded = pandas.concat(
            decoded_frames, axis="columns", names=["partition", data.columns.name]
        )
        decoded = decoded.droplevel(
            "partition", axis="columns"
        )  # remove partition index
        if decoded_index is not None:
            # without an index encoder the index is left untouched;
            # ``set_index(None)`` would raise
            decoded = decoded.set_index(decoded_index)
        decoded = decoded[self.colspec.index]  # fix column order
        decoded = decoded.astype(self.colspec)  # fix data types
        return decoded

    def __repr__(self) -> str:
        """Pretty print."""
        return f"{self.__class__.__name__}(" + self.spec.__repr__() + "\n)"

    def _repr_html_(self) -> str:
        """HTML representation."""
        html_repr = self.spec._repr_html_()  # pylint: disable=protected-access
        return f"<h3>{self.__class__.__name__}</h3> {html_repr}"

In [None]:
next(iter(dloader.sampler))

In [None]:
df = dloader.sampler[439, 15325].intervals

In [None]:
type(df)