# Title

In [None]:
# Notebook setup: echo the value of a trailing expression or assignment,
# render inline figures as SVG, and enable the autoreload extension (mode 2).
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

# Print arrays in fixed-point notation with 4 decimals and no scientific notation.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded generator: results differ between runs.
rng = np.random.default_rng()

In [None]:
import pandas as pd

# Load the hourly electricity table from CSV.
# NOTE(review): absolute, machine-specific path — this cell only runs where
# that file exists.
df = pd.read_csv(
    "/home/rscholz/Projects/KIWI/Baselines/TemporalFusionTransformer/electricity/data/electricity/hourly_electricity.csv"
)

In [None]:
# Values of the "days_from_start" column at which validation / test begin.
valid_boundary = 1315
test_boundary = 1339
index = df["days_from_start"]
# Validation and test each reach 7 days back past their boundary, so
# neighbouring splits overlap by that amount.
train = df[index.lt(valid_boundary)]
valid = df[index.ge(valid_boundary - 7) & index.lt(test_boundary)]
test = df[index.ge(test_boundary - 7)]

In [None]:
import tsdm

# Load the Electricity dataset through tsdm and resample it to hourly means.
DS = tsdm.datasets.Electricity()
ds = DS.dataset.resample("1h").mean()

In [None]:
# Keep only rows whose index lies in the half-open range [2014-01-01, 2014-09-08).
ds = ds[("2014-01-01" <= ds.index) & (ds.index < "2014-09-08")]

In [None]:
# Split boundaries. Every split is the half-open interval [start, stop).
t_start = pd.Timestamp("2014-01-01")
t_train = pd.Timestamp("2014-08-08")
t_valid = pd.Timestamp("2014-09-01")
t_score = pd.Timestamp("2014-09-08")

# Boolean masks over ds.index, derived from the boundaries above so that each
# date is written down exactly once.
train_index = (t_start <= ds.index) & (ds.index < t_train)
valid_index = (t_train <= ds.index) & (ds.index < t_valid)
score_index = (t_valid <= ds.index) & (ds.index < t_score)
joint_index = train_index | valid_index
total_index = train_index | valid_index | score_index

# Sanity check: the three splits are pairwise disjoint, i.e. every timestamp
# belongs to at most one of them (the per-row count equals the union mask).
assert (
    (train_index.astype(int) + valid_index.astype(int) + score_index.astype(int))
    == total_index
).all()

## Standardization

We need a sampler object that samples indices

t_start, t_mid, t_stop

These mark the boundaries of the observation horizon and the forecasting horizon.
Furthermore, we need to know the stride, i.e. how much to advance this window in time.

Finally, we need to know what to do with the final slice of the data, which may not accommodate a full window.

- starting index
- stopping index
   - alternatively a dataset!

- strides

In [None]:
from collections.abc import Iterator, Sequence
from typing import Any, Optional, overload

from tsdm.random.samplers import BaseSampler

# Short aliases for the NumPy datetime / timedelta scalar types, used in the
# overload annotations of SequenceSampler below.
dt = np.datetime64
td = np.timedelta64


class SequenceSampler(BaseSampler):
    r"""Sample sliding windows of length ``seq_len`` between ``xmin`` and ``xmax``.

    The k-th window is ``[x₀ + k⋅stride, x₀ + k⋅stride + seq_len)`` for
    ``k = 0, …, k_max``, where ``k_max`` is the largest ``k`` such that
    ``xmin + seq_len + k⋅stride ≤ xmax``. If not even the first window fits,
    the sampler is empty.

    Parameters
    ----------
    data_source
        Ordered values the windows refer to. Its first / last element is used
        when ``xmin`` / ``xmax`` is not given, and with ``return_mask=True`` it
        is compared element-wise against the window bounds.
    xmin, xmax
        Limits of the sampled range. Strings are parsed with ``pd.Timestamp``.
    stride
        Offset between consecutive windows. Strings are parsed with
        ``pd.Timedelta``.
    seq_len
        Length of each window. Strings are parsed with ``pd.Timedelta``.
    return_mask
        If true, each sample is the boolean mask
        ``(x <= data_source) & (data_source < y)``; otherwise the pair ``[x, y]``.
    shuffle
        If true, windows are iterated in random order.
    """

    @overload
    def __init__(
        self,
        data_source: Optional[Sequence] = ...,
        *,
        xmin: Optional[dt] = ...,
        xmax: Optional[dt] = ...,
        stride: td,
        seq_len: td,
        return_mask: bool = ...,
        shuffle: bool = ...,
    ) -> None: ...

    @overload
    def __init__(
        self,
        data_source: Optional[Sequence] = ...,
        *,
        xmin: Optional[int] = ...,
        xmax: Optional[int] = ...,
        stride: int,
        seq_len: int,
        return_mask: bool = ...,
        shuffle: bool = ...,
    ) -> None: ...

    @overload
    def __init__(
        self,
        data_source: Optional[Sequence] = ...,
        *,
        xmin: Optional[float] = ...,
        xmax: Optional[float] = ...,
        stride: float,
        seq_len: float,
        return_mask: bool = ...,
        shuffle: bool = ...,
    ) -> None: ...

    def __init__(
        self,
        data_source: Optional[Sequence] = None,
        *,
        xmin: Optional[Any] = None,
        xmax: Optional[Any] = None,
        stride: Any,
        seq_len: Any,
        return_mask: bool = False,
        shuffle: bool = False,
    ) -> None:
        super().__init__(data_source)

        # Fall back to the extent of the data when no explicit limits are given.
        xmin = data_source[0] if xmin is None else xmin
        xmax = data_source[-1] if xmax is None else xmax

        self.data_source = data_source

        # Strings are accepted for convenience and parsed by pandas.
        self.xmin = pd.Timestamp(xmin) if isinstance(xmin, str) else xmin
        self.xmax = pd.Timestamp(xmax) if isinstance(xmax, str) else xmax
        self.stride = pd.Timedelta(stride) if isinstance(stride, str) else stride
        self.seq_len = pd.Timedelta(seq_len) if isinstance(seq_len, str) else seq_len

        # k_max = max {k∈ℕ ∣ x_min + seq_len + k⋅stride ≤ x_max}
        # Computed from the *parsed* attributes: the raw arguments may be strings,
        # which do not support arithmetic.
        self.k_max = int((self.xmax - self.xmin - self.seq_len) // self.stride)
        self.return_mask = return_mask
        self.shuffle = shuffle

        # Materialize every window once; iteration only indexes into this array.
        windows = list(self._iter_tuples())
        if self.return_mask:
            self.samples = np.array(
                [(x <= self.data_source) & (self.data_source < y) for x, y in windows]
            )
        else:
            self.samples = np.array([[x, y] for x, y in windows])

    def _iter_tuples(self) -> Iterator[tuple[Any, Any]]:
        r"""Yield the ``(start, stop)`` bounds of all ``len(self)`` windows in order."""
        x = self.xmin
        y = x + self.seq_len
        x, y = min(x, y), max(x, y)  # allows nice handling of negative seq_len

        for _ in range(len(self)):
            yield x, y
            # Rebind instead of ``+=`` so that self.xmin can never be mutated in place.
            x, y = x + self.stride, y + self.stride

    def __len__(self) -> int:
        r"""Return the number of windows, i.e. ``k_max + 1`` (0 if none fits)."""
        # Windows are indexed k = 0, …, k_max, hence the +1.
        return max(self.k_max + 1, 0)

    def __iter__(self) -> Iterator:
        r"""Iterate over all windows, in random order if ``shuffle`` is set."""
        if self.shuffle:
            perm = np.random.permutation(len(self))
        else:
            perm = np.arange(len(self))

        return iter(self.samples[perm])

    def __repr__(self) -> str:
        r"""Return ``ClassName[stride, seq_len]``."""
        return f"{self.__class__.__name__}[{self.stride}, {self.seq_len}]"

## Apply Encoder

In [None]:
from tsdm.encoders import Standardizer

In [None]:
# Fit the standardizer on the training rows only, then encode the whole dataset.
encoder = Standardizer()
encoder.fit(ds[train_index])
encoded = encoder.encode(ds)

In [None]:
# 8-day windows advanced in 1-day steps between t_start and t_valid; each
# sample is a boolean mask over ds.index and the order is shuffled.
sampler = SequenceSampler(
    ds.index,
    xmin=t_start,
    xmax=t_valid,
    seq_len="8d",
    stride="1d",
    return_mask=True,
    shuffle=True,
)

In [None]:
%%timeit
# Materialize one full pass over the sampler (timed by the cell magic above).
list(sampler);

## Task object

In [None]:
import torch

# Convert the encoded values to a float32 tensor.
e_torch = torch.tensor(encoded.values, dtype=torch.float32)

# DataLoader over that tensor, drawing samples via `sampler` in batches of 32.
dloader = torch.utils.data.DataLoader(e_torch, sampler=sampler, batch_size=32)

In [None]:
# Shape of the first batch produced by the dataloader.
next(iter(dloader)).shape

In [None]:
from torch.utils.data import DataLoader

from tsdm.tasks import BaseTask

In [None]:
class ElectricityTFT(BaseTask):
    r"""Placeholder task on top of BaseTask; nothing is implemented yet."""

    ...

    # Stub: fixes the signature only (positional-only split key, shuffle flag,
    # extra keyword arguments) and has no body.
    def get_dataloader(
        self,
        key,
        /,
        shuffle: bool = False,
        **dataloader_kwargs: Any,
    ) -> DataLoader: ...


# Instantiate the task; as the last expression of the cell, the result is displayed.
ElectricityTFT()

In [None]:
def get_dataloader(
    self,
    key: "KeyType",
    /,
    shuffle: bool = False,
    **dataloader_kwargs: Any,
) -> DataLoader:
    r"""Return a dataloader over sliding windows of the encoded dataset.

    Parameters
    ----------
    key: KeyType,
        Split identifier. Currently unused: the windows always cover the range
        from ``t_start`` to ``t_valid``, whatever key is passed.
    shuffle: bool, default False
        Whether the sampler yields the windows in random order.
    dataloader_kwargs: Any,
        Extra keyword arguments forwarded unchanged to ``DataLoader``.

    Returns
    -------
    DataLoader
    """
    # Construct the dataset object
    dataset = self.encoded_dataset

    # 8-day windows advanced in 1-day steps, returned as boolean masks over the index.
    sampler = SequenceSampler(
        dataset.index,
        xmin=t_start,
        xmax=t_valid,
        seq_len="8d",
        stride="1d",
        return_mask=True,
        shuffle=shuffle,  # honour the caller's choice instead of always shuffling
    )
    return DataLoader(dataset, sampler=sampler, **dataloader_kwargs)

In [None]:
## TFT preproc: social-time + time since start + ...
# Crucial here: weekly & daily frequency.
# Can't we just use time2vec with 24h / 7d freq?
# Probably.

# Need many2many FrameEncoder?
#

# o-time
# - social time features (append)
# o-time replace with time since start

In [None]:
# Display the (filtered, hourly) dataset.
ds