# DataFrameSplitter

This encoder splits a DataFrame column-wise into multiple DataFrames / Series, one per group of columns, and can concatenate them back into the original frame.


In [None]:
# Display the value of the last expression or assignment in every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Enable IPython's autoreload extension in mode 2 so edited modules
# (e.g. tsdm) are reloaded before code is executed.
%load_ext autoreload
%autoreload 2
%matplotlib inline

# import logging
# logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np
import pandas as pd
from pandas import *

# Compact float display: fixed 4-decimal output, no scientific notation for small values.
np.set_printoptions(suppress=True, floatmode="fixed", precision=4)
# Unseeded generator backed by PCG64 (the same construction np.random.default_rng() performs).
rng = np.random.Generator(np.random.PCG64())

In [None]:
from tsdm.encoders import *
from tsdm.tasks import KIWI_FINAL_PRODUCT

# Load the task and sort its time series by row index, then by column labels.
task = KIWI_FINAL_PRODUCT()
ts = (
    task.timeseries
    .sort_index(axis=0)
    .sort_index(axis=1)
)

In [None]:
# Share of non-missing observations per channel, sorted ascending.
channel_freq = ts.notna().mean().sort_values()

# Channels observed in at least 10% of rows are "fast"; all others are "slow".
fast_channels = channel_freq.index[channel_freq >= 0.1]
slow_channels = channel_freq.index[channel_freq < 0.1]

# Per-group sub-frames, minus rows that are empty for every channel of the group.
FAST = ts.loc[:, fast_channels].dropna(how="all")
SLOW = ts.loc[:, slow_channels].dropna(how="all")

groups = dict(fast=fast_channels, slow=slow_channels)

In [None]:
from collections.abc import Iterable, Sequence
from typing import Any

from pandas import DataFrame, Series

In [None]:
class DataFrameSplitter(BaseEncoder):
    """Split a DataFrame column-wise into one frame per column group.

    ``encode`` selects each group's columns and drops the rows that are
    entirely missing within that group. ``decode`` concatenates the parts
    side by side, then restores the dtypes and column order recorded by
    ``fit``.

    Notes
    -----
    - Columns that belong to no group are dropped by ``encode``. ``decode``
      then cannot restore the fitted layout and raises ``KeyError``.
    - Rows that are all-NaN in *every* group are dropped by ``encode`` and
      are not recovered by ``decode``.
    """

    columns: Index  # column labels seen by ``fit``, in their original order
    dtypes: Series  # per-column dtypes seen by ``fit``
    groups: dict[Any, Sequence[Any]]  # group key -> column labels of that group

    @staticmethod
    def _pairwise_disjoint(groups: Iterable[Sequence[Any]]) -> bool:
        """Return whether no label occurs more than once across all ``groups``.

        A label repeated inside a single group also makes this ``False``.
        """
        # Materialize once: ``groups`` may be a one-shot iterator, and it is
        # traversed twice below (once for the union, once for the total size).
        groups = list(groups)
        union: set[Any] = set().union(*groups)
        total = sum(len(group) for group in groups)
        return total == len(union)

    def __init__(self, groups: dict[Any, Sequence[Any]]) -> None:
        """Create a splitter for the given column groups.

        Parameters
        ----------
        groups
            Mapping from group key to the column labels of that group. The
            mapping's iteration order defines the order of the encoded parts.

        Raises
        ------
        ValueError
            If a column label is assigned more than once.
        """
        super().__init__()
        self.groups = groups
        # Explicit check instead of ``assert``, which is removed under ``python -O``.
        if not self._pairwise_disjoint(self.groups.values()):
            raise ValueError("Column groups must be pairwise disjoint.")

    def fit(self, data: DataFrame) -> None:
        """Record the column order and dtypes of ``data`` for use by ``decode``."""
        self.columns = data.columns
        self.dtypes = data.dtypes

    def encode(self, data: DataFrame) -> tuple[DataFrame, ...]:
        """Split ``data`` into one frame per group, in ``self.groups`` order.

        Rows that are entirely missing within a group are dropped from that
        group's frame.
        """
        return tuple(
            data[columns].dropna(how="all") for columns in self.groups.values()
        )

    def decode(self, data: tuple[DataFrame, ...]) -> DataFrame:
        """Concatenate the group frames column-wise and restore the fitted layout.

        ``pd.concat`` aligns the parts on their index (outer join), so a row
        missing from one part is filled with NaN for that part's columns.
        The result is cast to the fitted dtypes and reordered to the fitted
        column order.
        """
        decoded = pd.concat(data, axis="columns")
        decoded = decoded.astype(self.dtypes)
        return decoded[self.columns]

In [None]:
# Round-trip check: splitting ``ts`` by ``groups`` and decoding the parts
# must reproduce ``ts`` exactly (values, dtypes, column order, index).
encoder = DataFrameSplitter(groups)
encoder.fit(ts)
encoded = encoder.encode(ts)
decoded = encoder.decode(encoded)
pd.testing.assert_frame_equal(ts, decoded)

In [None]:
# Display the tuple of per-group frames (one entry per key of ``groups``).
encoded

In [None]:
# Fit and apply a single TripletEncoder (imported from tsdm.encoders; not shown
# here) to the first split part. Because ``groups`` lists "fast" first, that
# part is the fast-channel frame.
enc = TripletEncoder()
enc.fit(encoded[0])
enc.encode(encoded[0])

In [None]:
# Chained encoder. Presumably ``|`` pairs the two TripletEncoders so each acts
# on one split part, and ``@`` applies them after DataFrameSplitter. Both
# operators are defined in tsdm.encoders (not shown here); TODO confirm.
# The round trip must reproduce ``ts`` exactly.
encoder = (TripletEncoder() | TripletEncoder()) @ DataFrameSplitter(groups)
encoder.fit(ts)
encoded = encoder.encode(ts)
decoded = encoder.decode(encoded)
pd.testing.assert_frame_equal(decoded, ts)