# Social-time encoder experiments on the Electricity dataset

In [None]:
# Notebook setup: display behaviour, module autoreload, inline plotting, logging.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload edited modules (e.g. tsdm) before running each cell.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Show INFO-level log messages in the notebook output.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np
from pandas import DataFrame, DatetimeIndex, Series

# Print arrays with 4 fixed decimals and without scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# Unseeded random generator; not used by any of the cells shown here.
rng = np.random.default_rng()

In [None]:
from tsdm.datasets import Electricity

# Load the tsdm Electricity dataset; the cells below work with its ``index``
# (presumably a DatetimeIndex, given the datetime attributes read from it).
ds = Electricity()

In [None]:
dt = ds.index
# Datetime components of the index, ordered from coarsest to finest unit.
time_resolution = [
    getattr(dt, component)
    for component in (
        "year",
        "month",
        "day",
        "hour",
        "minute",
        "second",
        "microsecond",
        "nanosecond",
    )
]

In [None]:
import pandas as pd

# Convert the index's inferred frequency string into a Timedelta (the sampling step).
# NOTE(review): ``inferred_freq`` is None when pandas cannot infer a regular
# frequency — confirm the index is regularly spaced.
pd.Timedelta(ds.index.inferred_freq)

In [None]:
class SocialTime:
    r"""Encode datetimes as calendar/clock component columns and decode them back.

    Each character of ``levels`` is a key of ``level_codes`` and selects one
    datetime attribute; the default ``"YMWDhms"`` yields year, month, weekday,
    day, hour, minute and second. Accepts a ``DatetimeIndex`` or a datetime
    ``Series``.
    """

    # Single-character level code -> name of the datetime attribute to extract.
    level_codes = {
        "Y": "year",
        "M": "month",
        "W": "weekday",
        "D": "day",
        "h": "hour",
        "m": "minute",
        "s": "second",
        "µ": "microsecond",
        "n": "nanosecond",
    }

    def __init__(self, levels: str = "YMWDhms") -> None:
        r"""Select the levels to encode; raises ``KeyError`` for an unknown code."""
        self.levels = [self.level_codes[k] for k in levels]

    def fit(self, x: "Series | DatetimeIndex", /) -> None:
        r"""Record the container type, name and dtype of ``x`` so ``decode`` can rebuild it."""
        self.original_type = type(x)
        self.original_name = x.name
        self.original_dtype = x.dtype
        # Weekday is redundant given year/month/day and is not a column key that
        # ``pd.to_datetime`` can assemble, so it is left out when decoding.
        self.rev_cols = [level for level in self.levels if level != "weekday"]

    def encode(self, x: "Series | DatetimeIndex", /) -> DataFrame:
        r"""Return a DataFrame with one column per selected level.

        A ``Series`` exposes its datetime fields only through the ``.dt``
        accessor; index-like inputs expose them as plain attributes.
        """
        source = x.dt if isinstance(x, Series) else x
        res = {level: getattr(source, level) for level in self.levels}
        return DataFrame.from_dict(res)

    def decode(self, x: DataFrame, /) -> "Series | DatetimeIndex":
        r"""Reassemble datetimes from encoded columns into the type seen by ``fit``.

        ``pd.to_datetime`` needs at least year, month and day columns, so the
        chosen levels should include ``"Y"``, ``"M"`` and ``"D"``.
        """
        components = x[self.rev_cols]
        stamps = pd.to_datetime(components)
        return self.original_type(
            stamps, name=self.original_name, dtype=self.original_dtype
        )

In [None]:
# Fit the local SocialTime encoder to the dataset index and split it into component columns.
enc = SocialTime()
enc.fit(ds.index)
encoded = enc.encode(ds.index)

In [None]:
# Reassemble the datetimes from the encoded components (compare with ds.index below).
enc.decode(encoded)

In [None]:
# Original index, for comparison with the decoded result.
ds.index

In [None]:
from tsdm.encoders import *

In [None]:
# Chain tsdm encoders with ``@``: the SocialTimeEncoder output is presumably fed
# into the FrameEncoder, which applies PeriodicEncoder to its columns.
# NOTE(review): ``duplicate=True`` presumably reuses the one PeriodicEncoder for
# every column — confirm against tsdm.encoders.FrameEncoder.
enc = FrameEncoder(PeriodicEncoder(), duplicate=True) @ SocialTimeEncoder()
enc.fit(ds.index)
enc.encode(ds.index)

In [None]:
# from collections.abc import Mapping

# class PeriodicSocialTimeEncoder(SocialTimeEncoder):
#     r"""Combines SocialTimeEncoder with PeriodicEncoder using the right frequencies."""

#     frequencies = {
#         "Y": 1,
#         "M": 12,
#         "W": 7,
#         "D": 365,
#         "h": 24,
#         "m": 60,
#         "s": 60,
#         "µ": 1000,
#         "n": 1000,
#     }
#     column_encoders: Mapping[str, PeriodicEncoder]
#     encoder: BaseEncoder

#     def __init__(self, *, levels: str = "YMWDhms") -> None:
#         super().__init__(levels=levels)
#         self.column_encoders = {
#             level: PeriodicEncoder(period=self.frequencies[level])
#             for level in self.level_codes
#         }
#         self.encoder = FrameEncoder(self.column_encoders) @ SocialTimeEncoder()

#     def fit(self, x: Series) -> None:
#         self.encoder.fit(x)

#     def encode(self, data: Series, /) -> DataFrame:
#         return self.encoder.encode(data)

#     def decode(self, data: DataFrame, /) -> Series:
#         return self.encoder.decode(data)

In [None]:
# class PeriodicSocialTimeEncoder(SocialTimeEncoder):
#     r"""Combines SocialTimeEncoder with PeriodicEncoder using the right frequencies."""

#     frequencies = {
#         "year": 1,
#         "month": 12,
#         "weekday": 7,
#         "day": 365,
#         "hour": 24,
#         "minute": 60,
#         "second": 60,
#         "microsecond": 1000,
#         "nanosecond": 1000,
#     }
#     column_encoders: Mapping[str, PeriodicEncoder]
#     encoder: BaseEncoder

#     def __new__(cls, levels: str = "YMWDhms") -> BaseEncoder:
#         self = super().__new__(cls)
#         self.__init__(levels)
#         column_encoders = {
#             level: PeriodicEncoder(period=self.frequencies[level])
#             for level in self.levels
#         }
#         return FrameEncoder(column_encoders) @ self

In [None]:
# Round-trip the weekday numbers of the index through a standalone PeriodicEncoder.
# NOTE(review): weekday has 7 distinct values, but 5 is passed here (presumably the
# period, cf. ``PeriodicEncoder(period=...)`` in the drafts above), so distinct
# weekdays would collide — confirm whether 7 was intended.
enc = PeriodicEncoder(5)
weekday = ds.index.weekday
enc.fit(weekday)
encoded = enc.encode(weekday)
decoded = enc.decode(encoded)

In [None]:
# PeriodicSocialTimeEncoder (presumably provided by ``from tsdm.encoders import *``)
# applied to the full dataset index.
enc = PeriodicSocialTimeEncoder()
enc.fit(ds.index)
encoded = enc.encode(ds.index)

In [None]:
# Decode only the first stage of the composed encoder.
# NOTE(review): enc[0] is presumably the periodic FrameEncoder, so this should
# yield the social-time component columns — confirm.
something = enc[0].decode(encoded)

In [None]:
# Reassemble datetimes from the decoded component columns. pd.to_datetime only
# accepts date/time-part columns, so drop the derived weekday column (if present),
# mirroring how SocialTime.decode excludes "weekday".
pd.to_datetime(something.drop(columns="weekday", errors="ignore"))

In [None]:
# Maximum of the decoded components (per column if a DataFrame) — quick range sanity check.
something.max()

In [None]:
# Apply the second stage's decoder to the partially decoded components.
enc[1].decode(something)

In [None]:
# Full decode through the composed encoder.
decoded = enc.decode(encoded)

In [None]:
# Number of distinct values in each encoded column.
[len(encoded[column].unique()) for column in encoded.columns]

In [None]:
# Prints an empty line; appears to be a leftover placeholder cell.
print()

In [None]:
from pandas.core.indexes.frozen import FrozenList

# Select the periodic year columns using a FrozenList key (imported from a
# pandas-internal module) instead of a plain list.
encoded[FrozenList(["cos_year", "sin_year"])]