# Time Tensor

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from typing import Optional, Union

import numpy as np
import torch
from numpy.typing import ArrayLike
from pandas import DataFrame, Index, Series
from torch import Tensor

In [None]:
from tsdm.datasets import ETTh1

In [None]:
ds = ETTh1.dataset

In [None]:
Tensor(ds.values)

In [None]:
from torch.utils.data import Dataset, TensorDataset

In [None]:
class IndexMethodClone:
    r"""Expose an index accessor (``.loc``, ``.iloc``, ...) on top of other data.

    A lookup ``clone[item]`` is resolved in two steps: ``item`` is first looked
    up in ``index_method``, and the result of that lookup is then used as the
    key into ``data``.

    Parameters
    ----------
    data
        Any object supporting ``data[key]``.
    index_method
        Any object supporting ``index_method[item]`` whose results are valid
        keys for ``data``.
    """

    def __init__(self, data, index_method):
        self.data = data
        self.index_method = index_method

    def __getitem__(self, item):
        r"""Return ``data[index_method[item]]``."""
        # The lookup is intentionally silent: printing the key on every access
        # floods stdout as soon as the accessor is used inside a loop.
        return self.data[self.index_method[item]]

In [None]:
class TimeTensor(Tensor):
    r"""A ``Tensor`` subclass that carries a pandas index for its first axis.

    The attribute ``index`` is a ``Series`` whose index holds the labels and
    whose values are the integer positions ``0 .. len(x) - 1``. The accessors
    ``loc``, ``iloc``, ``at`` and ``iat`` first resolve a key through the
    matching accessor of that ``Series`` and then index the tensor with the
    positions obtained.

    Parameters
    ----------
    x
        The data. For a ``DataFrame`` / ``Series`` the values become the tensor
        data and its index becomes the tensor's index.
    index
        Labels for the first axis. Must be omitted when ``x`` is a
        ``DataFrame`` / ``Series``. When omitted for other inputs, the index of
        ``x`` is reused if ``x`` is itself an indexed ``TimeTensor``, otherwise
        the positions ``0 .. len(x) - 1`` are used as labels.
    """

    # ``__new__`` is implicitly a static method, no decorator is required.
    def __new__(
        cls,
        x: Union[Tensor, DataFrame, Series, ArrayLike],
        *args,
        index: Optional[Index] = None,
        **kwargs,
    ):
        r"""Create the tensor storage; the index is attached in ``__init__``."""
        if isinstance(x, (DataFrame, Series)):
            assert index is None, "Index given, but x is DataFrame/Series"
            x = x.values
        return super().__new__(cls, x, *args, **kwargs)

    def __init__(
        self,
        x: Union[Tensor, DataFrame, Series, ArrayLike],
        index: Optional[Index] = None,
    ):
        super().__init__()  # optional
        if isinstance(x, (DataFrame, Series)):
            index = x.index
        elif index is None:
            # Re-wrapping a TimeTensor must not throw its labels away. Tensors
            # produced by torch operations may lack the attribute, hence the
            # defensive lookup.
            inherited = getattr(x, "index", None) if isinstance(x, TimeTensor) else None
            if isinstance(inherited, Series):
                index = inherited.index
            else:
                index = Index(np.arange(len(x)))

        # label -> integer position along the first axis
        self.index = Series(np.arange(len(x)), index=index)
        self.loc = IndexMethodClone(self, self.index.loc)
        self.iloc = IndexMethodClone(self, self.index.iloc)
        self.at = IndexMethodClone(self, self.index.at)
        self.iat = IndexMethodClone(self, self.index.iat)

In [None]:
ts = TimeTensor(ds)

In [None]:
ts.loc["2016":"2017"]

### TimeTensor Type hint & Type Check

In [None]:
IndexedArray = Union[Series, DataFrame, TimeTensor]


def is_indexed_array(x) -> bool:
    r"""Return True if ``x`` is a ``Series``, a ``DataFrame`` or a ``TimeTensor``."""
    if isinstance(x, Series):
        return True
    if isinstance(x, DataFrame):
        return True
    return isinstance(x, TimeTensor)

## TimeSeriesDataset

In [None]:
def tensor_info(x: Tensor) -> str:
    r"""Describe a tensor as ``ClassName[shape-tuple, dtype, device-type]``."""
    class_name = x.__class__.__name__
    shape = tuple(x.shape)
    dtype = x.dtype
    device_type = x.device.type
    return f"{class_name}[{shape}, {dtype}, {device_type}]"

In [None]:
tensor_info(torch.randn(1, 2, 3))

In [None]:
from collections.abc import Collection, Iterable, Mapping

In [None]:
class _TupleIndexMethodClone:
    r"""Clone .loc and similar methods to a tuple of tensor-like objects.

    ``data[i]`` is paired with ``index[i]``: a lookup ``clone[item]`` resolves
    ``item`` through the chosen accessor of every index and then uses each
    result as the key into the data object at the same position.

    Parameters
    ----------
    data
        The objects to index; each must support ``data[i][key]``.
    index
        One index object per entry of ``data``. Every entry must expose the
        attribute named by ``method``, and that attribute must support
        ``[item]`` (for example a ``Series`` mapping labels to positions).
    method
        Name of the accessor attribute to clone, ``"loc"`` by default.

    Raises
    ------
    ValueError
        If ``data`` and ``index`` do not have the same number of entries.
    """

    def __init__(
        self, data: tuple[ArrayLike, ...], index: tuple[Index, ...], method: str = "loc"
    ):
        if len(data) != len(index):
            raise ValueError(
                f"Need exactly one index per data object, but got {len(data)} "
                f"data objects and {len(index)} indices."
            )
        self.data = data
        self.index = index
        self.method = tuple(getattr(idx, method) for idx in self.index)

    def __getitem__(self, item):
        r"""Return a tuple holding the selection from each data object."""
        indices = tuple(accessor[item] for accessor in self.method)
        # Every data object is indexed with its *own* lookup result; passing
        # the whole tuple of results to each one would be interpreted as a
        # multi-dimensional key.
        return tuple(data[idx] for data, idx in zip(self.data, indices))

In [None]:
class TimeSeriesDataset(TensorDataset):
    r"""A general Time Series Dataset.

    Consists of 2 things
    - timeseries: tuple[TimeTensor, ...]
    - metadata: Tensor / tuple[Tensor] / None

    When retrieving items, we generally use slices:

    - ds[timestamp] = ds[timestamp:timestamp]
    - ds[t₀:t₁] = tuple[X.loc[t₀:t₁] for X in self.timeseries]

    Note that ``__getitem__`` returns the time series selections only; the
    metadata is available through the ``metadata`` attribute.

    Parameters
    ----------
    *tensors
        Indexed arrays, each converted to a ``TimeTensor``.
    observations
        Additional time series, appended after ``tensors``. Accepted forms:
        a single indexed array, a single ``(index, values)`` tuple, a mapping
        from index to values, or an iterable of indexed arrays or of
        ``(index, values)`` tuples. An empty iterable adds nothing.
    metadata
        Passed through ``torch.Tensor`` (element-wise when given as a tuple).

    Raises
    ------
    ValueError
        If ``observations`` has none of the accepted forms.
    """

    def __init__(
        self,
        *tensors: IndexedArray,
        observations: Optional[
            Union[
                IndexedArray,
                tuple[Index, ArrayLike],
                dict[Index, ArrayLike],
                Collection[IndexedArray],
                Collection[tuple[Index, ArrayLike]],
            ]
        ] = None,
        metadata: Optional[Union[Tensor, tuple[Tensor]]] = None,
    ):
        # NOTE: the parent initialiser is not called, so the ``tensors``
        # attribute of the base class is never set on instances.
        ts_tensors = [TimeTensor(tensor) for tensor in tensors]
        if observations is not None:
            ts_tensors.extend(self._parse_observations(observations))
        self.timeseries = tuple(ts_tensors)

        if metadata is None:
            self.metadata = None
        elif isinstance(metadata, tuple):
            self.metadata = tuple(torch.Tensor(tensor) for tensor in metadata)
        else:
            self.metadata = torch.Tensor(metadata)

    @staticmethod
    def _is_index_pair(obj) -> bool:
        r"""Return True if ``obj`` is a 2-tuple whose first entry is an ``Index``."""
        return isinstance(obj, tuple) and len(obj) == 2 and isinstance(obj[0], Index)

    @classmethod
    def _parse_observations(cls, observations) -> list:
        r"""Convert any accepted ``observations`` form to a list of ``TimeTensor``."""
        # Case: IndexedArray
        if is_indexed_array(observations):
            return [TimeTensor(observations)]
        # Case: tuple[Index, ArrayLike]
        if cls._is_index_pair(observations):
            index, tensor = observations
            return [TimeTensor(tensor, index=index)]
        # Case: dict[Index, ArrayLike]
        if isinstance(observations, Mapping):
            return [
                TimeTensor(tensor, index=index)
                for index, tensor in observations.items()
            ]
        # Case: Iterable
        if isinstance(observations, Iterable):
            # Materialise once: the input may be a one-shot iterator and the
            # first element has to be inspected before converting the rest.
            observations = list(observations)
            if not observations:
                return []
            firstobs = observations[0]
            # Case: Iterable[IndexedArray]
            if is_indexed_array(firstobs):
                return [TimeTensor(tensor) for tensor in observations]
            # Case: Iterable[tuple(Index, ArrayLike)]
            if cls._is_index_pair(firstobs):
                return [
                    TimeTensor(tensor, index=index) for index, tensor in observations
                ]
        raise ValueError(f"{observations=} not understood")

    def __repr__(self) -> str:
        r"""Multi-line summary listing every time series and the metadata."""
        pad = r"  "

        if isinstance(self.timeseries, tuple):
            ts_lines = [tensor_info(tensor) for tensor in self.timeseries]
        else:
            ts_lines = [tensor_info(self.timeseries)]

        if self.metadata is None:
            md_lines = [f"{None}"]
        elif isinstance(self.metadata, tuple):
            md_lines = [tensor_info(tensor) for tensor in self.metadata]
        else:
            md_lines = [tensor_info(self.metadata)]

        ts_block = "".join("\n" + 2 * pad + line for line in ts_lines)
        md_block = "".join("\n" + 2 * pad + line for line in md_lines)
        return (
            f"{self.__class__.__name__}("
            + ts_block
            + "\n"
            + pad
            + "metadata:"
            + md_block
            + "\n"
            + ")"
        )

    def __len__(self) -> int:
        r"""The length of the longest timeseries (0 if there are none)."""
        if isinstance(self.timeseries, tuple):
            return max((len(tensor) for tensor in self.timeseries), default=0)
        return len(self.timeseries)

    def __getitem__(self, item):
        r"""Return corresponding slice from each tensor."""
        return tuple(tensor.loc[item] for tensor in self.timeseries)

In [None]:
a = TimeSeriesDataset(ds, ds, ds)

In [None]:
a = TimeSeriesDataset(ds, ds, ds)
b = 2

In [None]:
ds

In [None]:
torch.randn(5)

In [None]:
isinstance(ds, ArrayLike)

In [None]:
ds["OT"].shape

In [None]:
class TensorDataset(Dataset[tuple[Tensor, ...]]):
    r"""Dataset made of several tensors that share their first dimension.

    Sample ``i`` is the tuple obtained by indexing every wrapped tensor with
    ``i`` along the first dimension.

    Args:
        *tensors (Tensor): tensors whose first dimension has the same size.
    """

    tensors: tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        # All first-dimension sizes are equal iff they collapse to at most one
        # distinct value (no value at all when no tensor is given).
        assert (
            len({tensor.size(0) for tensor in tensors}) <= 1
        ), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        r"""Return the ``index``-th entry of every tensor as a tuple."""
        sample = [tensor[index] for tensor in self.tensors]
        return tuple(sample)

    def __len__(self):
        r"""Size of the first dimension of the first tensor."""
        first = self.tensors[0]
        return first.size(0)

In [None]:
ts.loc["2016-07-01 02:00:00":"2016-07-01 02:00:00"]  # "2016-07-01 03:00:00"]

In [None]:
ts.index["2016-07-01 02:00:00":"2016-07-01 02:00:00"]

In [None]:
from pandas import DataFrame

df = DataFrame(np.random.randn(7, 3), index=np.arange(7), columns=["A", "B", "C"])
df.loc[2]

In [None]:
tuple(range(3))

In [None]:
torch.Tensor((ds.values, ds.values))