# TripletDecoder = Inverse of TripletEncoder

In [None]:
# Notebook-wide display and reload configuration (IPython magics, not plain Python).
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Enable the autoreload extension in mode 2, so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Emit log records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

# Print arrays with 4 fixed decimals and without scientific notation.
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)
# No seed is passed, so values drawn from `rng` differ between runs.
rng = np.random.default_rng()

In [None]:
import logging
from typing import Literal, Optional

import numpy as np
import pandas as pd
import pandas.api.types
from pandas import DataFrame, Index, MultiIndex, Series

import tsdm
from tsdm.encoders import BaseEncoder

In [None]:
class TripletDecoder(BaseEncoder):
    r"""Pivot triplet data (index, variable, value) into a dense table.

    `encode` turns the long triplet format into a wide table with one column
    per variable; `decode` melts such a table back into triplets.

    Two triplet layouts are supported, selected by ``sparse``:

    - ``sparse=True``: one value column plus one indicator column per variable;
      a row belongs to a variable when the indicator equals ``1``.
    - ``sparse=False``: exactly two columns, one holding the values and one
      holding the variable identifier.
    """

    categories: pd.CategoricalDtype
    r"""The stored categories."""
    original_dtypes: Series
    r"""The original dtypes."""
    original_columns: Index
    r"""The original columns."""

    def __init__(
        self,
        sparse: bool = True,
        values: Optional[str] = None,
        variables: Optional[str] = None,
        check_categoricals: bool = True,
    ) -> None:
        r"""Initialize the encoder.

        Parameters
        ----------
        sparse: bool = True
            Whether the triplet data uses indicator columns (``True``) or a
            single variable column (``False``).
        values: Optional[str] = None
            Name of the value column. If not given, it is chosen during `fit`.
        variables: Optional[str] = None
            Name of the variable column. In sparse mode this names the stacked
            indicator level; in dense mode it is used to pick the value column
            when ``values`` is not given.
        check_categoricals: bool = True
            If true, `fit` raises when the detected categories are floats.
        """
        super().__init__()
        self.sparse = sparse
        self.value_name = values
        self.var_name = variables
        self.check_categoricals = check_categoricals

    def fit(self, data: DataFrame, /) -> None:
        r"""Fit the encoder.

        Records the columns and dtypes of the triplet data, determines the
        value column and the variable column(s) and stores the sorted
        categories.

        Parameters
        ----------
        data
            Triplet data in the layout selected by ``sparse``.

        Raises
        ------
        ValueError
            If ``data`` has no columns, if the value column is not a column of
            ``data``, if dense data does not have exactly one non-value
            column, or if the categories are floats while
            ``check_categoricals`` is set.
        """
        if len(data.columns) == 0:
            raise ValueError("Cannot fit on a DataFrame without columns!")

        self.original_dtypes = data.dtypes
        self.original_columns = data.columns

        # Candidate value columns. In dense mode a user-specified variable
        # column must not be picked as the value column by default.
        candidates = data.columns
        if (
            not self.sparse
            and self.var_name is not None
            and self.var_name in candidates
            and len(candidates) > 1
        ):
            candidates = candidates.drop(self.var_name)

        # Compare against None so falsy labels (e.g. ``0``) are honoured.
        if self.value_name is not None:
            self.value_column = self.value_name
        else:
            self.value_column = candidates[0]
        self.value_name = self.value_column
        logging.getLogger(__name__).debug(
            "Using %r as value column.", self.value_column
        )

        if self.value_column not in data.columns:
            raise ValueError(
                f"Value column {self.value_column!r} not found in the data!"
            )

        remaining_cols = data.columns.drop(self.value_column)

        if self.sparse:
            self.channel_columns = remaining_cols
            categories = self.channel_columns
            # Only fall back to a default name when none was specified.
            if self.var_name is None:
                self.var_name = self.channel_columns.name or "variable"
        else:
            if len(remaining_cols) != 1:
                raise ValueError(
                    "Dense triplet data must have exactly one column besides"
                    f" the value column, but found {list(remaining_cols)}!"
                )
            self.channel_columns = remaining_cols.item()
            categories = data[self.channel_columns].unique()
            self.var_name = self.channel_columns

        if self.check_categoricals and pd.api.types.is_float_dtype(categories):
            raise ValueError(
                f"channel_ids found in '{self.var_name}' does no look like a"
                " categoricals!\n Please specify `values` and/or `variables`!\n Or,"
                " silence this error with `check_categoricals=False`."
            )

        self.categories = pd.CategoricalDtype(np.sort(categories))

    def encode(self, data: DataFrame, /) -> DataFrame:
        r"""Pivot triplet data into a wide table with one column per category.

        Categories seen during `fit` but absent from ``data`` are added as
        all-NaN columns. Columns follow the order of the fitted categories and
        the result is sorted by index.
        """
        if self.sparse:
            # Collapse the indicator columns into a single variable column:
            # stack them, keep the active entries and move the stacked level
            # into a column named ``var_name``.
            df = data.loc[:, self.channel_columns].stack()
            df = df[df == 1]
            df.index = df.index.rename(self.var_name, level=-1)
            df = df.reset_index(level=-1)
            df[self.value_name] = data[self.value_column]
        else:
            df = data

        df = df.pivot_table(
            # TODO: FIX with https://github.com/pandas-dev/pandas/pull/45994
            # simply use df.index.names instead then.
            index=df.index,
            columns=self.var_name,
            values=self.value_name,
            dropna=False,
        )

        # Restore the MultiIndex with the level names of the input.
        if isinstance(data.index, MultiIndex):
            df.index = MultiIndex.from_tuples(df.index, names=data.index.names)

        # re-add missing columns
        for cat in self.categories.categories:
            if cat not in df.columns:
                df[cat] = float("nan")  # TODO: replace with pd.NA when supported

        result = df[self.categories.categories]  # fix column order
        return result.sort_index()

    def decode(
        self,
        data: DataFrame,
        /,
        encoded_names: Optional[dict[Literal["index", "channel", "value"], str]] = None,
    ) -> DataFrame:
        r"""Melt a wide table back into the triplet layout seen during `fit`.

        Parameters
        ----------
        data
            Wide table with one column per variable.
        encoded_names
            Currently unused; accepted for interface compatibility.

        Returns
        -------
        DataFrame
            Triplet data with the columns (in the fitted order) and dtypes
            recorded by `fit`, sorted by index. Missing values are dropped.
        """
        result = data.melt(
            ignore_index=False,
            var_name=self.var_name,
            value_name=self.value_name,
        ).dropna()

        if self.sparse:
            result = pd.get_dummies(
                result, columns=[self.var_name], sparse=True, prefix="", prefix_sep=""
            )
            # Variables without any observation produce no indicator column;
            # add them as all-zero so the dtype cast below does not fail.
            for col in self.original_columns:
                if col not in result.columns:
                    result[col] = 0

        # Restore the column order recorded during fit (melt emits the
        # variable column first, regardless of the original layout).
        result = result[self.original_columns]
        result = result.astype(self.original_dtypes)
        result = result.sort_index()

        return result

In [None]:
# Load the MIMIC_III dataset from tsdm and index its observations
# by (UNIQUE_ID, TIME_STAMP).
ds = tsdm.datasets.MIMIC_III()
ts = ds.observations.set_index(["UNIQUE_ID", "TIME_STAMP"])

In [None]:
# Fit a dense-mode (sparse=False) TripletDecoder on the full table and
# apply its `encode` to the first 1000 rows.
# NOTE: the result of `encode` is stored under the name `decoded`.
decoder = TripletDecoder(sparse=False)
decoder.fit(ts)
decoded = decoder.encode(ts[:1000])

In [None]:
# Apply `decode` to the result above; the value is shown as cell output.
decoder.decode(decoded)

In [None]:
# Same dataset load as in the earlier cell: MIMIC_III observations
# indexed by (UNIQUE_ID, TIME_STAMP).
ds = tsdm.datasets.MIMIC_III()
ts = ds.observations.set_index(["UNIQUE_ID", "TIME_STAMP"])

In [None]:
# Round trip with a dense-mode TripletDecoder: fit on the full table,
# `encode` the first 1000 rows, then `decode` the result.
enc = TripletDecoder(sparse=False)
enc.fit(ts)
encoded = enc.encode(ts[:1000])
decoded = enc.decode(encoded)

In [None]:
from tsdm.encoders import TripletEncoder

# Fit tsdm's TripletEncoder in sparse mode on the Electricity dataset and
# encode the first 100 rows.
ds = tsdm.datasets.Electricity()
ts = ds.dataset
encoder = TripletEncoder(sparse=True)
encoder.fit(ts)
encoded = encoder.encode(ts[:100])

In [None]:
# Build a two-level index ("index", "time"): the first reset_index moves the
# existing index into a column (presumably named "time" - verify against the
# dataset), the second adds a column "index" holding the default integer index.
ts = ts.reset_index().reset_index().set_index(["index", "time"])

In [None]:
# Fit a dense-mode TripletEncoder on the re-indexed table and encode the
# first 1000 rows.
enc = TripletEncoder(sparse=False)
enc.fit(ts)
encoded = enc.encode(ts[:1000])

In [None]:
# Decode the encoded rows again.
decoded = enc.decode(encoded)