# LogEncoder prototype on the KIWI_RUNS feed-volume data

In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # display the last expression or assignment of each cell
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload imported modules before executing each cell, so edits to library code are picked up.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np
import pandas

# Print arrays in fixed-point notation with four decimals instead of
# scientific notation.
np.set_printoptions(
    suppress=True,
    floatmode="fixed",
    precision=4,
)
# Unseeded generator: draws differ between notebook sessions.
rng = np.random.default_rng(seed=None)

In [None]:
from tsdm.datasets import KIWI_RUNS

# Instantiate the KIWI_RUNS dataset object.
ds = KIWI_RUNS()
# Disabled call, kept for reference; presumably forces the cleaned data to be
# regenerated -- confirm against the tsdm dataset API before re-enabling.
# ds.clean(force=True)
ts = ds.timeseries
# Last expression of the cell, so the notebook displays the units.
ds.units

In [None]:
# Select the two cumulated feed-volume columns as single-precision floats.
fluo = ts[
    ["Cumulated_feed_volume_glucose", "Cumulated_feed_volume_medium"]
].astype("float32")
# Share of entries that are exactly zero (displayed as the cell output).
fluo.eq(0).mean()

In [None]:
from tsdm.encoders import BaseEncoder

In [None]:
class LogEncoder(BaseEncoder):
    """Encode non-negative data on a logarithmic scale.

    Uses base 2 for lower numerical error and fast computation.

    Zeros have no logarithm, so they are mapped to a finite stand-in value:
    ``log2(threshold / 2)``, where ``threshold`` is the smallest strictly
    positive value seen during :meth:`fit`. This places encoded zeros exactly
    one unit below the smallest encoded positive value.

    Attributes
    ----------
    threshold:
        Smallest strictly positive value found by :meth:`fit`. For a
        ``DataFrame`` input this is a per-column minimum; for an ``ndarray``
        it is a single global minimum.
    replacement:
        Encoded stand-in for zero, ``log2(threshold / 2)``.
    """

    threshold: np.ndarray
    replacement: np.ndarray

    def __init__(self) -> None:
        super().__init__()

    def fit(self, data, /) -> None:
        """Determine the zero-replacement value from ``data``.

        Assumes ``data`` contains at least one strictly positive entry
        (per column for a ``DataFrame``); otherwise no threshold exists.

        Raises
        ------
        ValueError
            If ``data`` contains negative values or NaN.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        # NaN compares False against 0 and is therefore rejected as well.
        if not np.all(np.asarray(data >= 0)):
            raise ValueError("LogEncoder requires non-negative, NaN-free data.")

        positive = data[data != 0]
        self.threshold = positive.min()
        self.replacement = np.log2(self.threshold / 2)

    def encode(self, data, /):
        """Return a copy of ``data`` with ``log2`` applied element-wise.

        Entries ``<= 0`` are replaced by ``self.replacement``. The input is
        not modified; the result has the same container type as ``data``.
        """
        result = data.copy()
        mask = data <= 0
        # log2 is evaluated on every entry, including the masked ones; silence
        # the divide-by-zero / invalid warnings for values that are discarded.
        with np.errstate(divide="ignore", invalid="ignore"):
            logs = np.log2(data)
        result[:] = np.where(mask, self.replacement, logs)
        return result

    def decode(self, data, /):
        """Invert :meth:`encode`: return ``2**data`` with small values set to 0.

        Every entry whose code lies below ``log2(threshold)`` is decoded as 0.
        """
        result = 2**data
        # Compare in log space. ``2**log2(x)`` is not exact in floating point,
        # so testing ``result < threshold`` could zero out the smallest
        # positive value after a round trip.
        mask = data < np.log2(self.threshold)
        result[:] = np.where(mask, 0, result)
        return result

In [None]:
# Fit the encoder on the feed-volume data, then display the raw input.
encoder = LogEncoder()
encoder.fit(fluo)
fluo

In [None]:
encoded = encoder.encode(fluo)

In [None]:
decoded = encoder.decode(encoded)

In [None]:
# Round-trip check: decoding the encoded frame should reproduce the input.
pandas.testing.assert_frame_equal(fluo, decoded)

In [None]:
# Debugging probes for the non-negativity precondition checked in LogEncoder.fit.
# NOTE(review): the chained .all() calls look redundant -- confirm and simplify.
(fluo >= 0).all().all().all()

In [None]:
np.all(fluo >= 0)

In [None]:
# IPython help lookup for the `min` method of `fluo`.
?fluo.min