# FrameSplitter: encoding a DataFrame as groups of columns

In [None]:
# Notebook setup: show the value of the last expression (or assignment) of
# each cell, render inline figures as SVG, and enable module autoreloading.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np

# Compact, fixed-point printing of arrays (4 decimals, no scientific notation).
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded generator: draws differ between runs.
rng = np.random.default_rng(seed=None)

In [None]:
import logging
from collections.abc import Hashable, Iterable, Mapping
from typing import Any, Optional, Union

import numpy as np
import pandas as pd
import pandas.api.types
from pandas import DataFrame, Index, MultiIndex, Series
from pandas.core.indexes.frozen import FrozenList

from tsdm.datasets import KIWI_RUNS
from tsdm.encoders import *
from tsdm.encoders import BaseEncoder
from tsdm.tasks import KIWI_FINAL_PRODUCT

# Re-apply the print options and create a fresh unseeded generator.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
rng = np.random.default_rng(seed=None)

In [None]:
# Load the KIWI runs dataset and take its time-series table.
# NOTE(review): `ts` is rebound in the next cell (from the task), so this
# binding only serves interactive inspection of the raw dataset.
ds = KIWI_RUNS()
ts = ds.timeseries

In [None]:
# Load the task data with rows and columns in sorted order.
task = KIWI_FINAL_PRODUCT()
ts = task.timeseries.sort_index(axis="index")
ts = ts.sort_index(axis="columns")

# Fraction of observed (non-NaN) values per channel, in ascending order.
channel_freq = ts.notna().mean().sort_values()

# Channels observed in at least 10% of rows are "fast", below 10% are "slow".
fast_channels = FrozenList(channel_freq.index[channel_freq >= 0.1])
slow_channels = FrozenList(channel_freq.index[channel_freq < 0.1])

# Sub-frames restricted to each channel set, without all-NaN rows.
FAST = ts[fast_channels].dropna(how="all")
SLOW = ts[slow_channels].dropna(how="all")
groups = dict(fast=fast_channels, slow=slow_channels)

In [None]:
class FrameSplitter(BaseEncoder, Mapping):
    r"""Split a DataFrame into multiple groups.

    The special value ``...`` (:class:`Ellipsis`) can be used to indicate
    that all other columns belong to this group.

    This function can be used on index columns as well: in :meth:`fit` and
    :meth:`encode` the index levels are turned into ordinary columns (placed
    before the data columns), so a group may reference them by name, and
    :meth:`decode` restores them as the index.
    """

    column_columns: Index  # data column labels seen during fit
    column_dtypes: Series  # dtypes of the data columns seen during fit
    column_indices: list[int]  # positions of the data columns in the combined frame

    index_columns: Index  # index level names seen during fit
    index_dtypes: Series  # dtypes of the index levels seen during fit
    index_indices: list[int]  # positions of the index levels in the combined frame

    # FIXME: Union[types.EllipsisType, set[Hashable]] in 3.10
    groups: dict[Hashable, Union[Hashable, list[Hashable]]]
    group_indices: dict[Hashable, list[int]]

    indices: dict[int, Hashable]  # position -> label in the combined frame
    has_ellipsis: bool = False
    ellipsis: Optional[Hashable] = None  # key of the Ellipsis group, if any

    permutation: list[int]
    inverse_permutation: list[int]

    # @property
    # def names(self) -> set[Hashable]:
    #     r"""Return the union of all groups."""
    #     sets: list[set] = [
    #         set(obj) if isinstance(obj, Iterable) else {Ellipsis}
    #         for obj in self.groups.values()
    #     ]
    #     union: set[Hashable] = set.union(*sets)
    #     assert sum(len(u) for u in sets) == len(union), "Duplicate columns!"
    #     return union

    def __init__(
        self,
        groups: Union[Iterable[Hashable], Mapping[Hashable, Hashable]],
        /,
        keep_index: bool = True,
        dropna: bool = False,
    ) -> None:
        r"""Initialize the splitter.

        Args:
            groups: Either a mapping ``group_key -> columns`` or an iterable of
                column specifications (keys become ``0, 1, ...``). A column
                specification is ``...``, a single label (strings count as a
                single label), or an iterable of labels.
            keep_index: If ``False``, the index of the combined frame is reset
                to a default ``RangeIndex`` before splitting.
            dropna: Stored on the instance; not used by the methods below.
        """
        super().__init__()

        if not isinstance(groups, Mapping):
            groups = dict(enumerate(groups))

        self.groups = {}
        for key, obj in groups.items():
            if obj is Ellipsis:
                self.groups[key] = obj
                self.ellipsis = key
                self.has_ellipsis = True
            elif isinstance(obj, str) or not isinstance(obj, Iterable):
                self.groups[key] = [obj]
            else:
                self.groups[key] = list(obj)

        self.keep_index = keep_index
        self.dropna = dropna

    def __repr__(self):
        r"""Return a string representation of the object."""
        return repr_mapping(self)

    def __len__(self):
        r"""Return the number of groups."""
        return len(self.groups)

    def __iter__(self):
        r"""Iterate over the groups."""
        return iter(self.groups)

    def __getitem__(self, item):
        r"""Return the group."""
        return self.groups[item]

    def fit(self, original: DataFrame, /) -> None:
        r"""Fit the encoder.

        Records labels and dtypes of the data columns and index levels, and
        computes for each group the positions of its columns in the frame
        ``concat([index_as_columns, data_columns])``. Calling ``fit`` again
        recomputes the Ellipsis group from scratch.
        """
        columns = DataFrame(original).copy()
        index = columns.index.to_frame()

        self.column_dtypes = columns.dtypes
        self.column_columns = columns.columns
        self.index_columns = index.columns
        self.index_dtypes = index.dtypes

        overlap = set(self.index_columns) & set(self.column_columns)
        assert not overlap, f"index columns and data columns must be disjoint {overlap}!"

        data = pd.concat([index, columns], axis="columns")

        if not self.keep_index:
            data = data.reset_index(drop=True)

        def get_idx(cols: Iterable[Hashable]) -> list[int]:
            return [data.columns.get_loc(c) for c in cols]

        self.indices = dict(enumerate(data.columns))
        self.group_indices = {}
        self.column_indices = get_idx(self.column_columns)
        self.index_indices = get_idx(self.index_columns)

        # Resolve the Ellipsis group. Identify it by key rather than by
        # `is Ellipsis`: a previous fit replaced its value with a column list,
        # and counting those columns as "fixed" would empty the group on refit.
        if self.has_ellipsis:
            fixed_cols = set().union(
                *(
                    set(cols)  # type: ignore[arg-type]
                    for key, cols in self.groups.items()
                    if key != self.ellipsis and cols is not Ellipsis
                )
            )
            self.groups[self.ellipsis] = [
                c for c in data.columns if c not in fixed_cols
            ]

        # Positions of each group's columns; their concatenation is the
        # column order of the encoded output.
        self.permutation = []
        for group, cols in self.groups.items():
            if cols is Ellipsis:
                continue
            self.group_indices[group] = get_idx(cols)
            self.permutation += self.group_indices[group]
        self.inverse_permutation = np.argsort(self.permutation).tolist()

    def encode(self, original: DataFrame, /) -> tuple[DataFrame, ...]:
        r"""Encode the data into one frame (or Series, if single-column) per group."""
        # copy the frame and add index as columns.
        columns = DataFrame(original).copy()
        index = columns.index.to_frame()
        data = pd.concat([index, columns], axis="columns")

        if not self.keep_index:
            data = data.reset_index(drop=True)

        # `indices` maps position -> label, so the known labels are its values.
        unknown = set(data.columns) - set(self.indices.values())
        assert not unknown, (
            f"Unknown columns {unknown}. "
            "If you want to encode unknown columns add a group ``...`` (Ellipsis)."
        )

        return tuple(
            data[cols].squeeze(axis="columns") for cols in self.groups.values()
        )

    def decode(self, data: tuple[DataFrame, ...], /) -> DataFrame:
        r"""Decode the data, restoring column order, dtypes and index."""
        frames = [DataFrame(x) for x in data]
        joined = pd.concat(frames, axis="columns")

        # unshuffle the columns, restoring original order
        # (`:` rather than `...`: Ellipsis in `.iloc` is not supported by all pandas versions)
        joined = joined.iloc[:, self.inverse_permutation]

        # Assemble the data columns. Kept as a DataFrame: squeezing a single
        # column into a Series would break the index assignment below.
        columns = joined.iloc[:, self.column_indices]
        columns.columns = self.column_columns
        columns = columns.astype(self.column_dtypes)

        # assemble the index
        index = joined.iloc[:, self.index_indices]
        index.columns = self.index_columns
        index = index.astype(self.index_dtypes)

        if len(self.index_columns) == 1:
            new_index = Index(index.iloc[:, 0])
        else:
            new_index = MultiIndex.from_frame(index)
        return columns.set_axis(new_index, axis="index")

In [None]:
import warnings

from pandas.core.indexes.frozen import FrozenList


def pairwise_disjoint(sets: Iterable[set]) -> bool:
    r"""Return ``True`` if no element occurs in more than one of ``sets``.

    ``sets`` is materialized first. It is iterated twice (once for the union,
    once for the size sum), and a one-shot iterable such as a generator would
    otherwise be exhausted by the union and contribute nothing to the sum.
    """
    sets = list(sets)
    union = set().union(*sets)
    return len(union) == sum(len(s) for s in sets)


class FrameSplitter(BaseEncoder, Mapping):
    r"""Split a DataFrame into multiple groups of columns.

    The special value ``...`` (:class:`Ellipsis`) can be used to indicate
    that all other columns belong to this group. Without an Ellipsis group,
    columns that belong to no group are not encoded.

    The index is not split: every encoded frame keeps the original index
    (minus rows that are all-NaN within the group when ``dropna=True``).

    Index mapping
    -------------
    Original columns ``[a, b, c, d, e, f]`` with groups ``[c, a, b]`` and
    ``[f, e]`` and no Ellipsis group: column ``d`` is dropped in :meth:`fit`,
    leaving ``[a, b, c, e, f]`` at positions ``0..4``. The groups occupy
    positions ``[2, 0, 1]`` and ``[4, 3]``, hence

    - ``permutation = [2, 0, 1, 4, 3]`` (encoded position -> kept position),
    - ``inverse_permutation = argsort(permutation) = [1, 2, 0, 4, 3]``
      (kept position -> encoded position), which :meth:`decode` uses to
      restore the original column order.
    """

    original_columns: Index  # column labels seen during fit
    original_dtypes: Series  # column dtypes seen during fit

    # FIXME: Union[types.EllipsisType, FrozenList] in 3.10
    groups: dict[Hashable, Union[Hashable, FrozenList]]
    group_indices: dict[Hashable, list[int]]
    inverse_groups: dict[Hashable, Hashable]  # column -> group key
    fixed_columns: set[Hashable]  # union of all non-Ellipsis groups
    ellipsis_columns: list[Hashable]  # columns assigned to the Ellipsis group
    variable_indices: dict[Hashable, list[Hashable]]  # column -> group keys

    indices: dict[int, Hashable]  # position -> column label (after fit)
    has_ellipsis: bool = False
    ellipsis: Optional[Hashable] = None  # key of the Ellipsis group, if any

    permutation: list[int]
    inverse_permutation: list[int]

    def __init__(
        self,
        groups: Union[Iterable[Hashable], Mapping[Hashable, Hashable]],
        /,
        dropna: bool = False,
        fillna: bool = True,
    ) -> None:
        r"""Initialize the splitter.

        Args:
            groups: Either a mapping ``group_key -> columns`` or an iterable of
                column specifications (keys become ``0, 1, ...``). A column
                specification is ``...``, a single label (strings count as a
                single label), or an iterable of labels. Groups must be
                pairwise disjoint.
            dropna: If ``True``, :meth:`encode` drops rows that are all-NaN
                within a group, and :meth:`decode` sorts the result by index.
                :meth:`fit` then requires a monotonic increasing index.
            fillna: Stored on the instance; currently unused by fit/encode/decode.
        """
        super().__init__()

        if not isinstance(groups, Mapping):
            groups = dict(enumerate(groups))

        self.groups = {}
        for key, obj in groups.items():
            if obj is Ellipsis:
                self.groups[key] = obj
                self.ellipsis = key
                self.has_ellipsis = True
            elif isinstance(obj, str) or not isinstance(obj, Iterable):
                self.groups[key] = FrozenList([obj])
            else:
                self.groups[key] = FrozenList(obj)

        column_sets = [
            set(cols) for cols in self.groups.values() if cols is not Ellipsis
        ]
        self.fixed_columns = set().union(*column_sets)
        assert pairwise_disjoint(column_sets), "Groups must be pairwise disjoint!"

        # column -> key of the group it belongs to
        self.inverse_groups = {
            column: group
            for group, cols in self.groups.items()
            if cols is not Ellipsis
            for column in cols
        }

        self.dropna = dropna
        self.fillna = fillna

    def __repr__(self):
        r"""Return a string representation of the object."""
        return repr_mapping(self)

    def __len__(self):
        r"""Return the number of groups."""
        return len(self.groups)

    def __iter__(self):
        r"""Iterate over the groups."""
        return iter(self.groups)

    def __getitem__(self, item):
        r"""Return the group."""
        return self.groups[item]

    def fit(self, original: DataFrame, /) -> None:
        r"""Fit the encoder.

        Records the original column labels and dtypes, resolves the Ellipsis
        group (if any), and computes the column permutation that the encoded
        groups represent.

        Raises:
            ValueError: If ``dropna`` is set and the index is not monotonic
                increasing.
        """
        data = DataFrame(original).copy()

        if self.dropna and not data.index.is_monotonic_increasing:
            raise ValueError(f"If {self.dropna=}, Index must be monotonic increasing!")
        self.original_dtypes = data.dtypes
        self.original_columns = data.columns

        self.variable_indices = {col: [] for col in self.original_columns}
        for group, cols in self.groups.items():
            if cols is Ellipsis:
                continue
            for column in cols:
                self.variable_indices[column].append(group)

        other_columns = [c for c in data.columns if c not in self.fixed_columns]
        if self.has_ellipsis:
            self.ellipsis_columns = other_columns
        else:
            # Columns in no group are not encoded; drop them so positions
            # below refer only to encoded columns.
            data = data.drop(columns=other_columns)

        self.indices = dict(enumerate(data.columns))

        def get_idx(cols: Iterable[Hashable]) -> list[int]:
            return [data.columns.get_loc(c) for c in cols]

        # set column indices; their concatenation is the encoded column order
        self.group_indices = {}
        self.permutation = []
        for group, cols in self.groups.items():
            members = self.ellipsis_columns if cols is Ellipsis else cols
            self.group_indices[group] = get_idx(members)
            self.permutation += self.group_indices[group]

        # compute inverse permutation
        self.inverse_permutation = np.argsort(self.permutation).tolist()

    def encode(self, original: DataFrame, /) -> tuple[DataFrame, ...]:
        r"""Encode the data into one DataFrame per group."""
        data = DataFrame(original).copy()

        # Any column outside the fixed groups is unknown, even if some fixed
        # column happens to be missing from `data`.
        unknown = set(data.columns) - self.fixed_columns
        if not self.has_ellipsis and unknown:
            warnings.warn(
                f"Unknown columns {unknown}. "
                "If you want to encode unknown columns add a group ``...`` (Ellipsis).",
                stacklevel=2,
            )

        encoded_frames = []
        for cols in self.groups.values():
            encoded = data[self.ellipsis_columns if cols is Ellipsis else cols]
            if self.dropna:
                encoded = encoded.dropna(axis="index", how="all")
            encoded_frames.append(encoded)

        return tuple(encoded_frames)

    def decode(self, data: tuple[DataFrame, ...], /) -> DataFrame:
        r"""Decode the data back into a single DataFrame with original columns/dtypes."""
        frames = [DataFrame(x) for x in data]
        # Outer-joins on the index, so rows dropped by `dropna` come back as NaN.
        joined = pd.concat(frames, axis="columns")

        # bring columns in order (`:` rather than `...`, which `.iloc` does
        # not accept in all pandas versions)
        joined = joined.iloc[:, self.inverse_permutation]
        # Re-create the full original column layout; columns that were not
        # encoded come back filled with NaN.
        reconstructed = joined.reindex(columns=self.original_columns)
        reconstructed = reconstructed.astype(self.original_dtypes)

        if self.dropna:
            reconstructed = reconstructed.sort_index()
        return reconstructed

In [None]:
# Work on the first 200 rows only.
T = ts.head(200)

# Group 0 = slow channels, group 1 = fast channels; drop all-NaN rows per group.
# Alternative mapping form:
# {"D" : ["run_id","measurement_time"], "A": "Flow_Air", "B": ["StirringSpeed", "Temperature"], "C": Ellipsis}
encoder = FrameSplitter([slow_channels, fast_channels], dropna=True)
encoder.fit(T)

In [None]:
# Split the sample into per-group frames (all-NaN rows dropped, dropna=True).
encoded = encoder.encode(T)
# Display the first group, built from `slow_channels`.
encoded[0]

In [None]:
# Reassemble the per-group frames into a single frame.
decoded = encoder.decode(encoded)

In [None]:
# Round-trip check: decoding the encoded groups must reproduce the input.
pandas.testing.assert_frame_equal(T, decoded)

In [None]:
from abc import ABC, abstractmethod


class A(ABC):
    r"""Scratch example of an abstract base class with one abstract method.

    Inheriting from :class:`ABC` is required for ``@abstractmethod`` to take
    effect; without it, ``A`` could be instantiated despite ``f`` being abstract.
    """

    @abstractmethod
    def f(self, *args, **kwargs):
        r"""Abstract hook; concrete subclasses must override it.

        ``*args, **kwargs`` replaces the invalid ``...`` parameter, which is
        not legal in a ``def`` signature.
        """
        ...

In [None]:
# Repeat the round-trip check: decoded output must equal the input sample.
pandas.testing.assert_frame_equal(T, decoded)

In [None]:
# NOTE(review): neither visible FrameSplitter definition provides `names`
# (the property is commented out in the first version). Unless BaseEncoder
# defines it, this raises AttributeError; `encoder.fixed_columns` holds the
# union of the non-Ellipsis groups — confirm the intended attribute.
encoder.names