In [None]:
# Echo the value of a cell's final expression *or* assignment, so no explicit print is needed.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension; mode 2 re-imports modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures inside the notebook output.
%matplotlib inline

In [None]:
from pandas import DataFrame, Index, Series
from torch.utils.data import Dataset as TorchDataset

from tsdm.datasets import TimeSeriesCollection
from tsdm.random.samplers import HierarchicalSampler, SlidingWindowSampler

In [None]:
import logging

# Configure the root logger at INFO level (basicConfig does nothing if handlers already exist).
logging.basicConfig(level=logging.INFO)

In [None]:
import tsdm

In [None]:
# Instantiate the KIWI_RUNS dataset; every later cell reads its attributes from `ds`.
ds = tsdm.datasets.KIWI_RUNS()

In [None]:
# Inspect the dtypes of the dataset's `time_features` attribute.
ds.time_features.dtypes

In [None]:
# Assemble the collection: each keyword argument is the dataset attribute of the same name.
TSC = TimeSeriesCollection(
    **{
        field: getattr(ds, field)
        for field in (
            "index",
            "timeseries",
            "metadata",
            "time_features",
            "value_features",
            "metadata_features",
        )
    }
)

## Sampler

In [None]:
# Pick a single series out of the dataset by its two-part key.
# NOTE(review): what the two integers identify is not visible here -- confirm against the dataset docs.
ts = ds.timeseries.loc[(439, 15325)]

In [None]:
# Sliding-window sampler over this one series' index, with two horizons ("2h", "1h") and stride "1h".
# Presumably each sample is a pair of consecutive windows of those lengths -- TODO confirm.
# Note: the name `sampler` is rebound to a HierarchicalSampler in a later cell.
sampler = SlidingWindowSampler(ts.index, horizons=["2h", "1h"], stride="1h")

In [None]:
# Peek at the first item the sampler yields.
next(iter(sampler))

## Construct the Sampler

In [None]:
# Display the collection assembled above.
TSC

In [None]:
from torch.utils.data import RandomSampler, SubsetRandomSampler

In [None]:
# Draw one entry of the collection index in random order.
# No generator/seed is passed, so the result can differ between runs.
next(iter(SubsetRandomSampler(TSC.index)))

In [None]:
# Look up one member of the collection, using the same key as the single series selected earlier.
TSC[(439, 15325)]

In [None]:
# Build one sliding-window sampler per member of the collection, keyed like the collection itself.
subsamplers = {
    outer: SlidingWindowSampler(
        entry.timeseries.index,
        stride="1h",
        horizons=["2h", "1h"],
    )
    for outer, entry in TSC.items()
}
# Combine them into a two-level sampler; shuffling is switched off.
# This rebinds `sampler`, which previously held the single-series sampler.
sampler = HierarchicalSampler(TSC, subsamplers, shuffle=False)

In [None]:
# Take the first item from the hierarchical sampler: an outer key plus a pair of windows.
# NOTE(review): TimeSeriesCollectionForecastingTask.__getitem__ further down unpacks the same
# pair as (observation_horizon, forecasting_horizon); the names used here are shifted by one
# position relative to that -- confirm which naming is intended.
outer_key, (forecasting_horizon, prediction_horizon) = next(iter(sampler))

In [None]:
# Select the sampled member of the collection and slice it with the first window of the pair.
TSC[outer_key][forecasting_horizon]

# TimeSeriesCollectionForecastingTask

## TimeSeriesCollection Sample

In [None]:
from typing import NamedTuple

from tsdm.utils.strings import *


class Inputs(NamedTuple):
    """Input half of a forecasting sample.

    A four-field named tuple; positional order is ``t``, ``x``, ``t_target``,
    ``metadata``.
    """

    # NOTE(review): the ``int`` annotations look like placeholders. NamedTuple
    # does not enforce them at runtime -- confirm the intended types.
    t: int
    x: int
    t_target: int
    metadata: int

    def __repr__(self) -> str:
        # Rendering is delegated to the tsdm string helper, called with recursive=False.
        rendered = repr_namedtuple(self, recursive=False)
        return rendered


class Sample(NamedTuple):
    """One forecasting sample: the key it was drawn for, its inputs, and its targets."""

    # NOTE(review): the ``int`` annotations on ``key`` and ``targets`` look like
    # placeholders. NamedTuple does not enforce them at runtime -- confirm the intended types.
    key: int
    inputs: Inputs
    targets: int
    # originals: int

    def __repr__(self) -> str:
        # Rendering is delegated to the tsdm string helper, called with recursive=False.
        rendered = repr_namedtuple(self, recursive=False)
        return rendered

In [None]:
from dataclasses import dataclass


@dataclass
class TimeSeriesCollectionForecastingTask(TorchDataset):
    r"""Map-style dataset that turns hierarchical sampler keys into `Sample` tuples.

    Fields:
        dataset: the collection that is indexed with the outer key.
        targets, controls, observables: stored on the instance; `__getitem__`
            does not read them.
    """

    dataset: TimeSeriesCollection
    targets: Index
    controls: Index
    observables: Index

    def __getitem__(self, key) -> Sample:
        r"""Build the sample addressed by ``key``.

        Args:
            key: a 2-tuple ``(outer_key, inner_key)``. ``outer_key`` selects a
                member of ``self.dataset``; ``inner_key`` is a 2-element list
                ``[observation_horizon, forecasting_horizon]``.

        Returns:
            A `Sample` whose ``targets`` is the forecasting slice and whose
            ``inputs`` carry the combined slice (as both ``t`` and
            ``t_target``), the observation slice (``x``) and the metadata of
            the selected member.

        Raises:
            AssertionError: if ``key`` or ``inner_key`` has the wrong type or length.
        """
        assert isinstance(key, tuple) and len(key) == 2, (
            "key must be a tuple (outer_key, inner_key)"
        )
        outer_key, inner_key = key
        assert isinstance(inner_key, list) and len(inner_key) == 2, (
            "inner_key must be a list [observation_horizon, forecasting_horizon]"
        )
        observation_horizon, forecasting_horizon = inner_key

        tsd = self.dataset[outer_key]
        md = tsd.metadata

        obs = tsd[observation_horizon]
        pre = tsd[forecasting_horizon]
        # The two windows are combined with `|` before slicing.
        # NOTE(review): presumably a union covering both windows -- confirm for the window type in use.
        horizon = observation_horizon | forecasting_horizon
        ts = tsd[horizon]

        # Keyword arguments make the field mapping explicit. `metadata` carries the
        # member's metadata; the forecast slice belongs only in `targets`.
        # NOTE(review): `t` and `t_target` both receive the combined slice -- confirm this is intended.
        return Sample(
            key=outer_key,
            inputs=Inputs(t=ts, x=obs, t_target=ts, metadata=md),
            targets=pre,
        )

In [None]:
# Take the first key from the hierarchical sampler, to be used as a test input for the task.
key = next(iter(sampler))

In [None]:
# The three Index fields are passed as None: __getitem__ only reads `dataset`.
task = TimeSeriesCollectionForecastingTask(
    TSC, targets=None, controls=None, observables=None
)
# Build one sample from the key drawn in the previous cell.
task[key]

# Mapping Dataset

In [None]:
# Toy mapping for the experiment below: position in the string -> character.
d = {position: letter for position, letter in enumerate("asdfghjkl")}

In [None]:
from collections.abc import Mapping
from dataclasses import dataclass

from torch.utils.data import Dataset

In [None]:
@dataclass
class MyMapping(Dataset, Mapping):
    """Torch ``Dataset`` that also implements the ``Mapping`` interface over a wrapped dict."""

    # Backing store; length, lookup and iteration are all forwarded to it.
    internal_dict: dict

    def __len__(self):
        """Return the number of entries in the wrapped dict."""
        backing = self.internal_dict
        return len(backing)

    def __getitem__(self, key):
        """Return the value stored under ``key`` in the wrapped dict."""
        backing = self.internal_dict
        return backing[key]

    def __iter__(self):
        """Return an iterator over the keys of the wrapped dict."""
        backing = self.internal_dict
        return iter(backing)

In [None]:
from torch.utils.data import DataLoader

# Wrap the toy mapping in a DataLoader with default settings.
# NOTE(review): the default sampler for a map-style dataset indexes with 0..len-1; that lines up
# here only because `d` is keyed by consecutive integers starting at 0 -- other key sets would
# need a custom sampler.
dataloader = DataLoader(MyMapping(d))