In [None]:
# Echo the value of the last expression *or assignment* of every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Load the autoreload extension; mode 2 reloads all modules before executing code.
%load_ext autoreload
%autoreload 2
# Use matplotlib's inline backend so figures are shown in the cell output.
%matplotlib inline

In [None]:
from __future__ import annotations

In [None]:
from collections.abc import Collection, Iterator
from dataclasses import KW_ONLY, dataclass
from typing import Any, Generic, Literal, Optional, TypeAlias

from pandas import NA, DataFrame, Index
from torch.utils.data.dataset import Dataset as TorchDataset

from tsdm.datasets import TimeSeriesCollection, TimeSeriesDataset
from tsdm.utils.strings import repr_dataclass
from tsdm.utils.types import KeyVar

In [None]:
from typing import TypeVar, overload

# A horizon is selected either by explicit index labels or by a slice.
TimeSlice: TypeAlias = Index | slice
# Key layout: (outer key, (observation horizon, forecasting horizon)).
TSC_KeyV: TypeAlias = tuple[KeyVar, tuple[TimeSlice, TimeSlice]]
# Names of the per-axis sample formats.
FMT_OptionsV: TypeAlias = Literal["masked", "sparse"]
# TODO: Add type hint for ("masked", "masked") once mypy 983 is released
# FIXME: https://github.com/python/mypy/issues/13871

# Constrained type variable: either a single dataset or a collection of them.
TimeSeriesClass = TypeVar("TimeSeriesClass", TimeSeriesDataset, TimeSeriesCollection)


@dataclass
class TimeSeriesTask(TorchDataset, Generic[KeyVar, TimeSeriesClass]):
    r"""Creates sample from a TimeSeriesCollection.

    There are different modus operandi for creating samples from a TimeSeriesCollection.

    Format Specification
    ~~~~~~~~~~~~~~~~~~~~

    - column-sparse
    - separate x and u and y

    - masked: In this format, two equimodal copies of the data are stored with appropriate masking.
        - inputs = (t, s, m)
        - targets = (s', m')

    - dense: Here, the data is split into groups of equal length. (x, u, y) share the same time index.
        - inputs = (t, x, u, m_x)
        - targets = (y, m_y)

    - sparse: Here, the data is split into groups sparse tensors. ALl NAN-only rows are dropped.
        - inputs = (t_y, (t_x, x), (t_u, u), m_x)
        - targets = (y, m_y)

    This class is used inside DataLoader.

    +---------------+------------------+------------------+
    | variable      | observation-mask | forecasting-mask |
    +===============+==================+==================+
    | observables X | ✔                | ✘                |
    +---------------+------------------+------------------+
    | controls U    | ✔                | ✔                |
    +---------------+------------------+------------------+
    | targets Y     | ✘                | ✔                |
    +---------------+------------------+------------------+

    Examples
    --------
    - time series classification task: empty forecasting horizon, only metadata_targets set.
    - time series imputation task: observation horizon and forecasting horizon overlap
    - time series forecasting task: observation horizon and forecasting horizon
    - time series forecasting task (autoregressive): observables = targets
    - time series event forecasting: predict both event and time of event (tᵢ, yᵢ)_{i=1:n} given n
    - time series event forecasting++: predict both event and time of event (tᵢ, yᵢ)_{i=1:n} and n

    Technical Remark
    ----------------
    The option `pin_memory` of `torch.utils.data.DataLoader` recurses through
    Mappings and Sequence. However, it will cast the types. The only preserved Types are

    - dicts
    - tuples
    - namedtuples

    Dataclasses are currently not supported. Therefore, we preferably use namedtuples
    or dicts as containers.
    """

    dataset: Any
    _: KW_ONLY = NotImplemented
    targets: Index
    r"""Columns of the data that are used as targets."""
    observables: Index = NotImplemented
    r"""Columns of the data that are used as inputs."""
    covariates: Optional[Index] = None
    r"""Columns of the data that are used as controls."""
    metadata_targets: Optional[Index] = None
    r"""Columns of the metadata that are targets."""
    metadata_observables: Optional[Index] = NotImplemented
    r"""Columns of the metadata that are targets."""
    sample_format: str = "masked"

    def __post_init__(self):
        r"""Post init."""
        if self.observables is NotImplemented:
            self.observables = self.dataset.timeseries.columns
        if self.metadata_observables is NotImplemented:
            if self.dataset.metadata is None:
                self.metadata_observables = None
            else:
                self.metadata_observables = self.dataset.metadata.columns

    def __iter__(self) -> Iterator[TimeSeriesCollection]:
        return iter(self.dataset)

    def __len__(self) -> int:
        return len(self.dataset)

    def __repr__(self):
        return repr_dataclass(self, recursive=1)

    def __getitem__(self, key) -> Sample:
        match self.sample_format:
            case "masked" | ("masked", "masked"):
                return self._make_masked_sample(key)
            case ("sparse", "masked"):
                return self._make_sparse_index_sample(key)  # type: ignore[unreachable]
            case ("masked", "sparse"):
                return self._make_sparse_column_sample(key)  # type: ignore[unreachable]
            case "sparse" | ("sparse", "sparse"):
                return self._make_sparse_sample(key)
            case _:
                raise ValueError(f"Unknown sample format {self.sample_format=}")

    @overload
    def _make_masked_sample(self, key: tuple[TimeSlice, TimeSlice]) -> Sample: ...

    @overload
    def _make_masked_sample(
        self, key: tuple[KeyVar, tuple[TimeSlice, TimeSlice]]
    ) -> Sample: ...

    def _make_masked_sample(self, key):
        assert isinstance(key, tuple) and len(key) == 2
        if isinstance(self.dataset, TimeSeriesDataset):
            observation_horizon, forecasting_horizon = key
            tsd = self.dataset
        else:
            assert isinstance(key[1], Collection) and len(key[1]) == 2
            outer_key, (observation_horizon, forecasting_horizon) = key
            tsd = self.dataset[outer_key]

        # timeseries
        ts_observed = tsd[observation_horizon]
        ts_forecast = tsd[forecasting_horizon]
        horizon_index = ts_observed.index.union(ts_forecast.index)
        ts = tsd[horizon_index]

        x = ts.copy()
        # mask everything except covariates and observables
        columns = ts.columns.difference(self.covariates or [])
        x.loc[ts_observed.index, columns.difference(self.observables)] = NA
        x.loc[ts_forecast.index, columns] = NA

        y = ts.copy()
        # mask everything except targets in forecasting horizon
        y.loc[ts_observed.index] = NA
        y.loc[ts_forecast.index, ts.columns.difference(self.targets)] = NA
        t_target = y.index.to_series()

        # metadata
        md = tsd.metadata
        md_targets: Optional[DataFrame] = None
        if self.metadata_targets is not None:
            assert md is not None
            md_targets = md[self.metadata_targets]
            md = md.drop(columns=self.metadata_targets)

        inputs = Inputs(t_target=t_target, x=x, u=None, metadata=md)
        targets = Targets(y=y, metadata=md_targets)
        return Sample(key=key, inputs=inputs, targets=targets)

    def _make_sparse_column_sample(self, key) -> Sample:
        assert isinstance(key, tuple) and len(key) == 2
        assert isinstance(key[1], Collection) and len(key[1]) == 2
        outer_key, (observation_horizon, forecasting_horizon) = key

        tsd = self.dataset[outer_key]

        # timeseries
        ts_observed = tsd[observation_horizon]
        ts_forecast = tsd[forecasting_horizon]
        horizon_index = ts_observed.index.union(ts_forecast.index)
        ts = tsd[horizon_index]

        x = ts[self.observables].copy()
        x.loc[ts_forecast.index] = NA

        y = ts[self.targets].copy()
        y.loc[ts_observed.index] = NA
        t_target = y.index.to_series()

        u: Optional[DataFrame] = None
        if self.covariates is not None:
            u = ts[self.covariates]

        # metadata
        md = tsd.metadata
        md_targets: Optional[DataFrame] = None
        if self.metadata_targets is not None:
            assert md is not None
            md_targets = md[self.metadata_targets]
            md[
                (
                    md.columns.difference(self.metadata_targets)
                    if self.metadata_observables is None
                    else self.metadata_observables
                )
            ] = NA

            # md = md.drop(columns=self.metadata_targets)

        inputs = Inputs(t_target=t_target, x=x, u=u, metadata=md)
        targets = Targets(y=y, metadata=md_targets)
        return Sample(key=key, inputs=inputs, targets=targets)

    def _make_sparse_index_sample(self, key) -> Sample:
        sample = self._make_masked_sample(key)
        return sample.sparsify_index()

    def _make_sparse_sample(self, key) -> Sample:
        sample = self._make_sparse_column_sample(key)
        return sample.sparsify_index()