# TripletEncoder round-trip check

In [None]:
# Display the value of the last expression *or assignment* in every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload all imported modules before executing code, so edits to library
# sources take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Show matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import numpy as np

# Print floats with 4 fixed decimals and without scientific notation for small values.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded random generator; kept as the cell's last statement so it is displayed.
rng = np.random.default_rng()

In [None]:
import numpy as np
import pandas as pd
import pandas.api.types
from pandas import DataFrame, MultiIndex, Series

from tsdm.encoders import BaseEncoder

In [None]:
class TripletEncoder(BaseEncoder):
    r"""Encode wide data into long-format triplets ``(index, variable, value)``.

    Every non-missing entry of the input becomes one row. If ``sparse`` is True,
    the ``variable`` column is one-hot encoded into sparse indicator columns
    (one per fitted column); otherwise it is kept as a single categorical column.

    Column labels are assumed to be strings: ``encode`` casts them through
    ``pd.StringDtype`` and ``pd.get_dummies`` names the indicator columns with
    string labels.
    """

    categories: pd.CategoricalDtype
    r"""The stored categories."""
    dtypes: Series
    r"""The original dtypes."""

    def __init__(self, sparse: bool = True) -> None:
        r"""Initialize the encoder.

        Parameters
        ----------
        sparse: bool = True
            If True, ``encode`` returns one sparse indicator column per variable;
            otherwise it returns a single categorical ``variable`` column.
        """
        super().__init__()
        self.sparse = sparse

    def fit(self, data: DataFrame) -> None:
        r"""Fit the encoder.

        Stores the column labels (as an unordered categorical dtype, in column
        order) and the column dtypes so that ``decode`` can restore them.

        Parameters
        ----------
        data: DataFrame
            Wide-format data, one column per variable.
        """
        self.categories = pd.CategoricalDtype(data.columns)
        self.dtypes = data.dtypes

    def encode(self, df: DataFrame) -> DataFrame:
        r"""Encode the data.

        Parameters
        ----------
        df: DataFrame
            Wide-format data with the columns the encoder was fitted on.

        Returns
        -------
        DataFrame
            Long-format data sorted by index, with missing values dropped.
            If ``sparse`` is False, the columns are ``variable`` (categorical)
            and ``value``. Otherwise ``value`` is followed by one sparse
            indicator column per fitted category.
        """
        # melt names the label column after ``df.columns.name`` (falling back
        # to "variable"), so look it up by position rather than by name.
        result = df.melt(ignore_index=False).dropna()
        variable = result.columns[0]
        # Labels that were not seen during ``fit`` become NaN here.
        result[variable] = (
            result[variable].astype(pd.StringDtype()).astype(self.categories)
        )
        result = result.rename(columns={variable: "variable"})
        # Stable sort: for a single-level index, rows sharing a label keep
        # their melt order (the column order of ``df``).
        result = result.sort_index(kind="stable")
        if not self.sparse:
            return result
        return pd.get_dummies(
            result, columns=["variable"], sparse=True, prefix="", prefix_sep=""
        )

    def decode(self, data: DataFrame, /) -> DataFrame:
        r"""Decode the data.

        Parameters
        ----------
        data: DataFrame
            Output of ``encode`` produced with the same ``sparse`` setting.

        Returns
        -------
        DataFrame
            Wide-format data with the fitted columns (in fitted order), cast to
            the fitted dtypes. Columns without any observation are all-NaN.
            Duplicate ``(index, variable)`` pairs are averaged, because
            ``pivot_table`` is called with its default aggregation.
        """
        df = self._indicators_to_long(data) if self.sparse else data
        df = df.pivot_table(
            # TODO: FIX with https://github.com/pandas-dev/pandas/pull/45994
            # simply use df.index.names instead then.
            index=df.index,
            columns="variable",
            values="value",
            dropna=False,
        )
        if isinstance(data.index, MultiIndex):
            df.index = MultiIndex.from_tuples(df.index, names=data.index.names)

        # Re-add fitted columns that had no observations at all.
        for cat in self.categories.categories:
            if cat not in df.columns:
                df[cat] = float("nan")  # TODO: replace with pd.NA when supported

        result = df[self.categories.categories]  # fix column order
        return result.astype(self.dtypes)

    @staticmethod
    def _indicators_to_long(data: DataFrame) -> DataFrame:
        r"""Turn sparse one-hot ``encode`` output into ``variable``/``value`` columns."""
        # Select indicator columns by name (not by position) and stack them over
        # a positional row index, so each match maps back to its source row.
        indicators = data.drop(columns="value").reset_index(drop=True)
        stacked = indicators.stack()
        stacked = stacked[stacked == 1]
        positions = stacked.index.get_level_values(0).to_numpy()
        variables = stacked.index.get_level_values(-1).to_numpy()
        # Gather values by position: the triplet index normally contains
        # duplicate labels, so label-based alignment is unreliable.
        value = data["value"].iloc[positions]
        return DataFrame(
            {"variable": variables, "value": value.array},
            index=value.index,
        )

In [None]:
from tsdm.tasks import KIWI_FINAL_PRODUCT

# Load the task data with rows and columns both in sorted order.
task = KIWI_FINAL_PRODUCT()
ts = task.timeseries.sort_index().sort_index(axis=1)
# Fraction of rows in which each channel is observed, in ascending order.
channel_freq = ts.notna().mean().sort_values()

# Channels observed in at least 10% of rows count as "fast", the rest as "slow".
fast_channels = channel_freq.index[channel_freq >= 0.1]
slow_channels = channel_freq.index[channel_freq < 0.1]
FAST = ts[fast_channels].dropna(how="all")
SLOW = ts[slow_channels].dropna(how="all")
groups = {"fast": fast_channels, "slow": slow_channels}

In [None]:
# Select by the key tuple (439, 15325). If the row axis is a MultiIndex, pandas
# first tries this tuple against the leading row levels.
# NOTE(review): presumably this picks one series out of the collection — confirm.
ts = ts.loc[(439, 15325)]

In [None]:
# Fit on the selected slice, then encode it as sparse one-hot triplets.
enc = TripletEncoder(sparse=True)
enc.fit(ts)
encoded = enc.encode(ts)

In [None]:
# Sanity check: encode sorts by index, so this should be True.
encoded.index.is_monotonic_increasing

In [None]:
# Round trip: decoding must reproduce the original frame exactly;
# assert_frame_equal raises AssertionError otherwise.
decoded = enc.decode(encoded)
pd.testing.assert_frame_equal(ts, decoded)

In [None]:
# The decoded index is expected to be sorted as well.
decoded.index.is_monotonic_increasing

In [None]:
# Display the selected slice for inspection.
ts