In [None]:
# Notebook display / reload setup.
# 'last_expr_or_assign': echo the value of a cell's final expression or assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# autoreload mode 2: re-import edited modules before executing each cell.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from pandas import DataFrame, Index, Series
from torch.utils.data import Dataset as TorchDataset

import tsdm
from tsdm.datasets import TimeSeriesCollection
from tsdm.random.samplers import HierarchicalSampler, SlidingWindowSampler

In [None]:
import logging

# Configure the root logger so that records of level INFO and above are emitted.
logging.basicConfig(level=logging.INFO)

In [None]:
# Instantiate the KIWI_RUNS dataset provided by tsdm.
# NOTE(review): presumably downloads / preprocesses the data on first use — confirm.
ds = tsdm.datasets.KIWI_RUNS()

In [None]:
# Inspect the dtypes of the dataset's time-feature table.
ds.time_features.dtypes

In [None]:
# Re-wrap the components of the loaded dataset in a TimeSeriesCollection.
# Every keyword argument has the same name as the attribute it is read from
# on `ds`, so the arguments are gathered by attribute lookup (order preserved).
TSC = TimeSeriesCollection(
    **{
        component: getattr(ds, component)
        for component in (
            "index",
            "timeseries",
            "metadata",
            "time_features",
            "value_features",
            "metadata_features",
        )
    }
)

## Sliding-window sampler on a single time series

In [None]:
# Select a single time series from the dataset by its index key.
# NOTE(review): the meaning of the two key components (439, 15325) is not
# visible here — confirm against the dataset's index names.
ts = ds.timeseries.loc[(439, 15325)]

In [None]:
# Sliding-window sampler over the timestamps of the selected series.
# NOTE(review): presumably each item is a pair of consecutive windows of
# length "2h" and "1h", with the start advancing by "1h" — confirm against
# the SlidingWindowSampler documentation.
# The name `sampler` is rebound to a HierarchicalSampler in a later cell.
sampler = SlidingWindowSampler(ts.index, horizons=["2h", "1h"], stride="1h")

In [None]:
# Peek at the first item the sampler produces.
next(iter(sampler))

## Construct the hierarchical sampler over the whole collection

In [None]:
# Display the collection.
TSC

In [None]:
from torch.utils.data import RandomSampler, SubsetRandomSampler

In [None]:
# Draw one key at random from the collection index.
# No `generator` is passed, so this cell does not pin which key is drawn.
next(iter(SubsetRandomSampler(TSC.index)))

In [None]:
# Look up one entry of the collection, using the same key as for `ts` earlier.
TSC[(439, 15325)]

In [None]:
# One sliding-window sampler per entry of the collection, keyed like the
# collection itself. The comprehension variables are deliberately named so
# that they do not shadow the notebook-level `ds`.
subsamplers = {
    entry_key: SlidingWindowSampler(
        entry.timeseries.index,
        horizons=["2h", "1h"],
        stride="1h",
    )
    for entry_key, entry in TSC.items()
}
# Combine the collection with its per-entry samplers; shuffling is disabled.
sampler = HierarchicalSampler(TSC, subsamplers, shuffle=False)

In [None]:
outer_key, (forecasting_horizon, prediction_horizon) = next(iter(sampler))

In [None]:
TSC[outer_key][forecasting_horizon]

# Forecasting task on the collection (`TimeSeriesCollectionTask`)

In [None]:
# Display the collection again before building the task on it.
TSC

In [None]:
from tsdm.tasks import TimeSeriesCollectionTask

# Channel names passed to the task as `targets`.
targets = ["Base", "DOT", "Glucose", "OD600"]

# Channel names passed as `observables`: all targets followed by three more
# channels. Built as a new list, so `targets` itself is not aliased.
observables = [*targets, "Acetate", "Fluo_GFP", "pH"]

# Channel names passed to the task as `covariates`.
covariates = [
    "Cumulated_feed_volume_glucose",
    "Cumulated_feed_volume_medium",
    "InducerConcentration",
    "StirringSpeed",
    "Flow_Air",
    "Temperature",
    "Probe_Volume",
]

In [None]:
# Take the first key from the hierarchical sampler and unpack it into the
# outer (collection-level) key and the two windows of the inner sample.
key = next(iter(sampler))
outer_key, (observation_horizon, forecasting_horizon) = key
# Show the selected entry restricted to the observation window.
TSC[outer_key][observation_horizon]

In [None]:
# Build the task on the collection with the channel roles defined earlier,
# then fetch the sample that belongs to the previously drawn `key`.
# NOTE(review): sample_format=("sparse", "sparse") presumably selects the
# layout of the returned sample components — confirm against the
# TimeSeriesCollectionTask documentation.
task = TimeSeriesCollectionTask(
    TSC,
    targets=targets,
    observables=observables,
    covariates=covariates,
    sample_format=("sparse", "sparse"),
)
sample = task[key]

In [None]:
# Keep only those target query entries whose index label also occurs in the
# target values: labels present in `t_target` but absent from `targets.y`
# are dropped. Skipped entirely when the sample carries no `t_target`.
if sample.inputs.t_target is not None:
    # Labels of t_target that have no counterpart in targets.y.
    diff = sample.inputs.t_target.index.difference(sample.targets.y.index)
    # In-place on purpose: this modifies the object held by `sample` itself.
    # Re-running the cell is harmless because `diff` is then empty.
    sample.inputs.t_target.drop(diff, inplace=True)
sample

In [None]:
# A bare `raise` with no active exception fails with
# "RuntimeError: No active exception to reraise".
# NOTE(review): presumably a deliberate stop so that "Run All" halts here — confirm.
raise

In [None]:
# Drop every entry of `t_target` whose index label does not occur in the
# target values `targets.y`, then display the sample.
# The `is not None` check mirrors the guarded variant of this cell earlier in
# the notebook; without it a sample that has no `t_target` fails with
# AttributeError on `.index`.
if sample.inputs.t_target is not None:
    # Labels of t_target that have no counterpart in targets.y.
    diff = sample.inputs.t_target.index.difference(sample.targets.y.index)
    # In-place on purpose: this modifies the object held by `sample` itself.
    # Re-running the cell is harmless because `diff` is then empty.
    sample.inputs.t_target.drop(diff, inplace=True)
sample

In [None]:
# A bare `raise` with no active exception fails with
# "RuntimeError: No active exception to reraise".
# NOTE(review): presumably a deliberate stop so that "Run All" does not
# continue into the section below — confirm.
raise

# Mapping Dataset

In [None]:
# Toy data: map each position 0..8 to the character at that position.
d = {position: letter for position, letter in enumerate("asdfghjkl")}

In [None]:
from collections.abc import Mapping
from dataclasses import dataclass

from torch.utils.data import Dataset

In [None]:
@dataclass
class MyMapping(Dataset, Mapping):
    """Read-only mapping backed by a plain dict that also subclasses ``Dataset``.

    The three abstract ``Mapping`` methods are forwarded to ``internal_dict``;
    the remaining mapping methods come from the ``Mapping`` mixin.
    """

    # Backing dictionary that every lookup, iteration and length query hits.
    internal_dict: dict

    def __len__(self):
        """Return how many entries ``internal_dict`` holds."""
        backing = self.internal_dict
        return len(backing)

    def __getitem__(self, key):
        """Return the value stored under ``key`` in ``internal_dict``."""
        backing = self.internal_dict
        return backing[key]

    def __iter__(self):
        """Return an iterator over the keys of ``internal_dict``."""
        return self.internal_dict.__iter__()

In [None]:
from torch.utils.data import DataLoader

# Wrap the mapping-backed dataset in a DataLoader with default settings.
# NOTE(review): the default sampler presumably requests integer positions
# 0..len-1, which coincide with the enumerate() keys of `d` — confirm.
dataloader = DataLoader(MyMapping(d))