In [None]:
# Echo the value of the last expression *or assignment* of every cell.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit records of level INFO and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import numpy as np
import pandas as pd
import torch
import pyarrow as pa

In [None]:
import logging

import numpy as np
import pandas as pd
import torch
from pytest import mark, skip

from tsdm.encoders.numerical import (
    BoundaryEncoder,
    LinearScaler,
    MinMaxScaler,
    StandardScaler,
    get_broadcast,
    get_reduced_axes,
)

# basicConfig only configures the root logger when it has no handlers yet, so
# this call does nothing if logging was already configured earlier in the session.
logging.basicConfig(level=logging.INFO)
# Logger named after the executing module (``__name__``).
__logger__ = logging.getLogger(__name__)

In [None]:
# Select the array library that the checks in this cell run against.
backend = "numpy"

_data = [-2.0, -1.1, -1.0, -0.9, 0.0, 0.3, 0.5, 1.0, 1.5, 2.0]

# Wrap the raw values in the container type of the selected backend.
if backend == "numpy":
    data = np.array(_data)
elif backend == "torch":
    data = torch.tensor(_data)
elif backend == "pandas":
    data = pd.Series(_data)
else:
    raise ValueError(f"Unknown backend {backend}")

# --- clip mode on the closed interval [-1, +1] -------------------------------
encoder = BoundaryEncoder(-1, +1, mode="clip")
encoder.fit(data)
encoded = encoder.encode(data)

# All encoded values must lie inside the interval ...
assert ((encoded >= -1) & (encoded <= 1)).all()
# ... and exactly the inputs at or beyond a boundary must sit on that boundary.
assert (encoded == -1).sum() == (data <= -1).sum()
assert (encoded == +1).sum() == (data >= +1).sum()

# The encoder must hand back the same container type (and metadata) it was given.
if isinstance(data, pd.Series):
    assert isinstance(encoded, pd.Series)
    assert encoded.shape == data.shape
    assert encoded.name == data.name
    assert encoded.index.equals(data.index)
if isinstance(data, torch.Tensor):
    assert isinstance(encoded, torch.Tensor)
    assert encoded.shape == data.shape
if isinstance(data, np.ndarray):
    assert isinstance(encoded, np.ndarray)
    assert encoded.shape == data.shape

# --- mask mode on the closed interval [-1, +1] -------------------------------
encoder = BoundaryEncoder(-1, +1, mode="mask")
encoder.fit(data)
encoded = encoder.encode(data)

# Each entry is either NaN or inside the interval, never both and never neither.
assert (np.isnan(encoded) ^ ((encoded >= -1) & (encoded <= 1))).all()
# The number of NaNs equals the number of inputs strictly outside the interval.
assert np.isnan(encoded).sum() == ((data < -1).sum() + (data > +1).sum())


# --- mask mode without explicit bounds (the encoder is fitted on `data`) ------
encoder = BoundaryEncoder(mode="mask")
encoder.fit(data)
encoded = encoder.encode(data)
decoded = encoder.decode(encoded)

# Encoding the data the encoder was fitted on must not mask or alter anything.
assert not np.isnan(encoded).any()
assert (data == encoded).all()
assert (data == decoded).all()

# Encode data that partly falls outside the range of the original `data`.
data2 = data * 2
encoded2 = encoder.encode(data2)
xmin = data.min()
xmax = data.max()
mask = (data2 >= xmin) & (data2 <= xmax)

# Values within [xmin, xmax] pass through unchanged; all others become NaN.
assert (encoded2[mask] == data2[mask]).all()
assert np.isnan(encoded2[~mask]).all()

In [None]:
# Half-open interval (lower bound 0, upper bound None) in clip mode.
encoder = BoundaryEncoder(0, None, mode="clip")
encoder.fit(data)
encoded = encoder.encode(data)

# No encoded value may fall below the lower bound ...
assert (encoded >= 0).all()
# ... and exactly the inputs at or below 0 must sit on the bound.
assert (encoded == 0).sum() == (data <= 0).sum()

In [None]:
# Re-fit and re-encode with whatever encoder is currently bound to `encoder`.
# The assignment is echoed as cell output because of the
# `last_expr_or_assign` interactivity setting configured in the first cell.
encoder.fit(data)
encoded = encoder.encode(data)

In [None]:
# Interactive inspection (no assertion): display the result of `upper_mask` on `data`.
# NOTE(review): presumably this flags entries relative to the upper bound — confirm
# against the BoundaryEncoder implementation.
encoder.upper_mask(data)

In [None]:
# Interactive inspection (no assertion): display the encoder's `upper_bound` after fitting.
encoder.upper_bound

In [None]:
# NOTE(review): the two tests below are commented out (disabled); re-enable or delete them.
# # test half-open unbounded interval + mask
# encoder = BoundaryEncoder(0, None, mode="mask")
# encoder.fit(data)
# encoded = encoder.encode(data)
# assert all(np.isnan(encoded) ^ (encoded >= 0))
# assert np.isnan(encoded).sum() == (data < 0).sum()

# # test half-open bounded interval + mask
# encoder = BoundaryEncoder(0, 1, mode="mask", lower_included=False)
# encoder.fit(data)
# encoded = encoder.encode(data)
# assert all(np.isnan(encoded) ^ (encoded > 0))
# assert np.isnan(encoded).sum() == ((data <= 0).sum() + (data > 1).sum())