# Kiwi Runs - conceptual

Conceptually, this dataset is a `Mapping[TimeSeriesDataset]`, or even a `Mapping[Mapping[TimeSeriesDataset]]`, depending on how we treat the (`run_id`, `exp_id`) MultiIndex.

Here, a `TimeSeriesDataset` consists of:
  - time series data (typically `tuple[TimeTensor]`)
  - static metadata (typically `tuple[Tensor]`)

In [1]:
# Notebook setup cell: IPython display/reload options and logging.
# Echo the value of a cell's last expression or assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension; mode 2 reloads modules before running code.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook.
%matplotlib inline

import logging
# Emit log records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [2]:
import numpy as np
import matplotlib.pyplot as plt

# Print NumPy arrays in fixed-point form with four decimals and
# without scientific notation.
np.set_printoptions(
    suppress=True,
    floatmode="fixed",
    precision=4,
)
# Random generator for the notebook; no seed is supplied.
rng = np.random.default_rng()

In [10]:
from tsdm.datasets import KIWI_RUNS
from tsdm.datasets.base import TimeTensor, TimeSeriesDataset, DatasetCollection

In [25]:
# Walk over the rows of the metadata `md`; each row label `idx` is used to
# select the matching entry of the time series `ts`.
# (`md` and `ts` are assigned in the cell that instantiates KIWI_RUNS.)
for idx, md_slc in md.iterrows():
    # Overwritten on every pass: after the loop only the last slice remains.
    ts_slc = ts.loc[idx]    

In [30]:
# Build a TimeSeriesDataset from the time series slice `ts_slc`; no metadata is passed.
tsd = TimeSeriesDataset(timeseries = ts_slc)

In [40]:
import logging
from typing import Any, Iterable, Mapping, Sequence

from torch import Tensor
from torch.utils.data import Dataset as Torch_Dataset

from tsdm.util.strings import repr_mapping
from collections.abc import Mapping
from torch.utils.data import Dataset as Torch_Dataset

In [38]:

class DatasetCollection:
    r"""Placeholder for a collection of datasets; nothing is implemented yet."""

class DatasetSequence:
    r"""Placeholder for a sequence of datasets; nothing is implemented yet."""

class DatasetMapping(Mapping, Torch_Dataset):
    r"""Represents a ``mapping[index → torch.Datasets]``.

    All tensors must have a shared index,
    in the sense that index.unique() is identical for all inputs.
    (This requirement is not validated by the class itself.)

    Notes
    -----
    - ``keys``, ``values`` and ``items`` are bound to the wrapped dict's methods.
    - Iterating over the object yields the stored datasets (the values),
      not the keys.
    """

    dataset: dict[Any, Torch_Dataset]
    """The dataset"""

    def __init__(self, indexed_datasets: Mapping[Any, Torch_Dataset]):
        r"""Store a shallow ``dict`` copy of ``indexed_datasets``.

        Parameters
        ----------
        indexed_datasets
            Mapping from index/key to dataset.
        """
        super().__init__()
        self.dataset = dict(indexed_datasets)
        # Live view of the keys; also drives the iteration order of __iter__.
        self.index = self.dataset.keys()
        # Delegate the mapping views straight to the underlying dict.
        self.keys = self.dataset.keys  # type: ignore[assignment]
        self.values = self.dataset.values  # type: ignore[assignment]
        self.items = self.dataset.items  # type: ignore[assignment]

    def __len__(self):
        r"""Length of the dataset."""
        return len(self.dataset)

    def __getitem__(self, item):
        r"""Hierarchical lookup.

        A non-empty sequence (other than ``str``/``bytes``) whose first element
        is iterable or a slice is split into ``first, *rest``: ``first`` selects
        a stored dataset and ``rest`` is forwarded to that dataset's own
        ``__getitem__``. Every other ``item`` is used directly as a key.

        Raises
        ------
        KeyError
            If the selected key is not present.
        """
        is_compound = (
            isinstance(item, Sequence)
            # A string is a Sequence of characters, yet it names a single key;
            # splitting it would look up its first character instead.
            and not isinstance(item, (str, bytes))
            # An empty sequence has no leading element to split off.
            and len(item) > 0
        )
        if is_compound:
            first, rest = item[0], item[1:]
            if isinstance(first, (Iterable, slice)):
                # pass remaining indices to sub-object
                value = self.dataset[first]
                return value[rest]

        # no hierarchical indexing
        return self.dataset[item]

    def __iter__(self):
        r"""Iterate over the dataset, yielding the stored values in key order."""
        for key in self.index:
            yield self.dataset[key]

    def __repr__(self):
        r"""Representation of the dataset."""
        return repr_mapping(self)


In [32]:
class TimeSeriesDataset(TensorDataset):
    """A general Time Series Dataset.

    Consists of 2 things
    - timeseries: TimeTensor / tuple[TimeTensor]
    - metadata: Tensor / tuple[Tensor]

    When retrieving items, we generally use slices:

    - ds[timestamp] = ds[timestamp:timestamp]
    - ds[t₀:t₁] = tuple[X[t₀:t₁] for X in self.timeseries]

    Note that ``__getitem__`` returns only the time series slices;
    the metadata is available through the ``metadata`` attribute.
    """

    timeseries: Union[IndexedArray, tuple[IndexedArray, ...]]
    metadata: Optional[Union[IndexedArray, tuple[IndexedArray, ...]]] = None

    def __init__(
        self,
        *tensors: IndexedArray,
        timeseries: Optional[
            Union[
                IndexedArray,
                tuple[Index, Sized],
                dict[Index, Sized],
                Collection[IndexedArray],
                Collection[tuple[Index, Sized]],
            ]
        ] = None,
        metadata: Optional[Union[Tensor, tuple[Tensor]]] = None,
    ):
        r"""Collect all given time series into a tuple of ``TimeTensor``.

        Parameters
        ----------
        *tensors
            Indexed arrays; each one is wrapped as ``TimeTensor(tensor)``.
        timeseries
            Additional time series, given in one of these forms:

            - a single indexed array,
            - a pair ``(index, values)``,
            - a mapping ``{index: values}``,
            - an iterable of indexed arrays,
            - an iterable of ``(index, values)`` pairs.

            An empty iterable contributes no time series.
        metadata
            Static metadata; a tuple is converted element-wise to ``Tensor``,
            anything else is converted as a whole.

        Raises
        ------
        ValueError
            If ``timeseries`` matches none of the supported forms.
        """
        super().__init__()

        # Positional arguments come first in the resulting tuple.
        ts_tensors = [TimeTensor(tensor) for tensor in tensors]

        if timeseries is not None:
            # Case: IndexedArray
            if is_indexed_array(timeseries):
                ts_tensors.append(TimeTensor(timeseries))
            # Case: tuple[Index, ArrayLike],
            elif (
                isinstance(timeseries, tuple)
                and len(timeseries) == 2
                and isinstance(timeseries[0], Index)
            ):
                index, tensor = timeseries
                ts_tensors.append(TimeTensor(tensor, index=index))
            # Case: dict[Index, ArrayLike],
            elif isinstance(timeseries, Mapping):
                for index, tensor in timeseries.items():
                    ts_tensors.append(TimeTensor(tensor, index=index))
            # Case: Iterable
            elif isinstance(timeseries, Iterable):
                # Materialize once so the first element can be inspected
                # without consuming a one-shot iterator.
                timeseries = list(timeseries)
                # An empty iterable simply adds nothing (previously IndexError).
                if timeseries:
                    # The first element decides how all elements are read.
                    firstobs = timeseries[0]
                    # Case: Iterable[IndexedArray]
                    if is_indexed_array(firstobs):
                        for tensor in timeseries:
                            ts_tensors.append(TimeTensor(tensor))
                    # Case: Iterable[tuple(Index, ArrayLike)]
                    elif (
                        isinstance(firstobs, tuple)
                        and len(firstobs) == 2
                        and isinstance(firstobs[0], Index)
                    ):
                        for index, tensor in timeseries:
                            ts_tensors.append(TimeTensor(tensor, index=index))
                    else:
                        raise ValueError(f"{timeseries=} not understood")
            else:
                raise ValueError(f"{timeseries=} not understood")

        self.timeseries = tuple(ts_tensors)

        # When no metadata is given, the class-level default (None) stays in effect.
        if metadata is not None:
            if isinstance(metadata, tuple):
                self.metadata = tuple(Tensor(tensor) for tensor in metadata)
            else:
                self.metadata = Tensor(metadata)

    def __repr__(self) -> str:
        r"""Pretty print: one line per time series, then one per metadata tensor."""
        pad = r"  "

        if isinstance(self.timeseries, tuple):
            ts_lines = [tensor_info(tensor) for tensor in self.timeseries]
        else:
            ts_lines = [tensor_info(self.timeseries)]

        if self.metadata is None:
            md_lines = [f"{None}"]
        elif isinstance(self.metadata, tuple):
            md_lines = [tensor_info(tensor) for tensor in self.metadata]
        else:
            md_lines = [tensor_info(self.metadata)]

        return (
            f"{self.__class__.__name__}("
            + "".join(["\n" + 2 * pad + line for line in ts_lines])
            + "\n"
            + pad
            + "metadata:"
            + "".join(["\n" + 2 * pad + line for line in md_lines])
            + "\n"
            + ")"
        )

    def __len__(self) -> int:
        r"""Return the length of the longest timeseries (0 if there are none)."""
        if isinstance(self.timeseries, tuple):
            # default=0 keeps len() usable on a dataset without any time series.
            return max((len(tensor) for tensor in self.timeseries), default=0)
        return len(self.timeseries)

    def __getitem__(self, item):
        r"""Return corresponding slice from each tensor.

        ``item`` is passed to the ``.loc`` indexer of every time series;
        the result is a tuple with one entry per time series.
        """
        return tuple(tensor.loc[item] for tensor in self.timeseries)

In [9]:
# Instantiate the KIWI_RUNS dataset.
ds = KIWI_RUNS()
# Time series part of the dataset (indexed with `.loc[...]` in another cell).
ts = ds.timeseries
# Metadata part of the dataset (iterated with `.iterrows()` in another cell).
md = ds.metadata