In [None]:
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from typing import Optional

import numpy as np

from tsdm.encoders import BaseEncoder

In [None]:
class BoxCoxEncoder(BaseEncoder):
    r"""Encode data on loggarithmic scale.


    .. math:: x ↦ \log(x+c)

    We consider multiple ideas for how to fit the parameter $c$

    1. Half the minimal non-zero value: `c = min(data[data>0])/2`
    2. Square of the first quartile divided by the third quartile (Stahle 2002)
    3. Value which minimizes the Wasserstein distance to a mean-0, variance-1 uniform distribution
    """

    AVAILABLE_METHODS = ["none", "wasserstein", "minimum", "quartile"]

    method: str
    param: np.ndarray

    def __init__(
        self, *, method: str = "", initial_param: Optional[np.ndarray] = None
    ) -> None:

        if method not in self.AVAILABLE_METHODS:
            raise ValueError(f"{method=} unknown. Available: {self.AVAILABLE_METHODS}")
        if method == "none" and initial_param is None:
            raise ValueError(f"Needs to provide initial param if no fitting.")

        self.method = method
        self.initial_param = initial_param
        super().__init__()

    @staticmethod
    def _wasserstein_uniform(x: np.ndarray, axis=-1) -> np.ndarray:
        r"""Signature: `[..., n] -> ...`."""
        n = x.shape[axis]
        k = np.arange(1, n + 1)
        r = x**2 + 2 * np.sqrt(3) * (1 - (2 * k - 1) / N) * x + 3
        return np.mean(r, axis=axis)
        
        
    def fit(self, data, /) -> None:

        assert np.all(data >= 0)
        method  = self.method
        
        match self.method:
            case "none":
                self.param = self.initial_param
            case "minimum":
                self.param = data[data>0].min() / 2
            case "quartile":
                self.param = ( np.quantile(data, 0.25) / np.quantile(data, 0.75) ) **2
            case "wasserstein":
                

    def encode(self, data, /):
        # TODO: Use copy on data.
        result = data.copy()
        mask = data <= 0
        result[:] = np.where(mask, self.replacement, np.log2(data))
        return result

    def decode(self, data, /):
        result = 2**data
        mask = result < self.threshold
        result[:] = np.where(mask, 0, result)
        return result

In [None]:
def _wasserstein_uniform(x: np.ndarray, axis=-1) -> np.ndarray:
    r"""Signature: `[..., n] -> ...`."""
    x = np.sort(x, axis=axis)
    n = x.shape[axis]
    k = np.arange(1, n + 1)
    r = x**2 + 2 * np.sqrt(3) * (1 - (2 * k - 1) / n) * x + 3
    return np.mean(r, axis=axis)

In [None]:
def wasserstein_uniform_generic(
    x: np.ndarray,
    axis: int = -1,
    *,
    low: float = -np.sqrt(3),
    high: float = np.sqrt(3),
) -> np.ndarray:
    r"""Squared 2-Wasserstein distance from the empirical distribution of ``x`` to U(low, high).

    .. Signature:: `[..., n] -> ...`.

    With $w = high - low$ the target quantile function is $G^{-1}(q) = low + wq$.
    Integrating $(x_{(k)} - G^{-1}(q))^2$ over each slice $((k-1)/n, k/n]$ of the
    sorted samples $x_{(k)}$ and summing gives

    .. math:: W_2^2 = \frac1n\sum_k \Bigl[x_{(k)}^2 - 2x_{(k)}\bigl(low + w\tfrac{2k-1}{2n}\bigr)\Bigr]
              + low^2 + low\,w + \tfrac{w^2}{3}

    The defaults give the mean-0, variance-1 uniform on $[-\sqrt3, \sqrt3]$.

    Raises:
        ValueError: if ``high < low``.
    """
    if high < low:
        raise ValueError(f"Need low <= high, got {low=}, {high=}.")
    # Sort along `axis`, then move it last so the slice index `k` broadcasts.
    xs = np.moveaxis(np.sort(x, axis=axis), axis, -1)
    n = xs.shape[-1]
    width = high - low
    k = np.arange(1, n + 1)
    # Mean of the target quantile function over the k-th slice.
    slice_mean = low + width * (2 * k - 1) / (2 * n)
    cross = np.mean(xs**2 - 2 * xs * slice_mean, axis=-1)
    # ∫₀¹ G⁻¹(q)² dq for the affine quantile function.
    return cross + low**2 + low * width + width**2 / 3




In [None]:
# 3×3 batch of 10 000-sample U(0, 1) series, each standardized along the last
# (sample) axis so that every series has exactly mean 0 and variance 1, matching
# the moments of the uniform reference used by `_wasserstein_uniform`.
x = np.random.rand(3, 3, 10000)
x = (x - x.mean(axis=-1, keepdims=True)) / x.std(axis=-1, keepdims=True)

In [None]:
# Distance of each of the 3×3 series to the mean-0, variance-1 uniform,
# reduced along the last (sample) axis.
_wasserstein_uniform(x, axis=-1)

In [None]:
# Scratch check: does np.arange accept a 0-d array as its bound, as in the
# `k = np.arange(1, n + 1)` line of the Wasserstein helper? Expected: [1 .. 7].
n = np.array(7)
np.arange(1, n + 1)

In [None]:
def iquantile(arr, q, axis=-1):
    r"""Empirical quantile function (the inverse of the empirical CDF).

    Uses the step-function representation

    .. math:: Q(q) = x_{(1)} + \sum_{k=1}^{N-1} (x_{(k+1)} - x_{(k)})\, H(q - k/N)

    where $x_{(1)} \le \dots \le x_{(N)}$ are the sorted values along ``axis`` and
    $H$ is the Heaviside step with $H(0) = 1$. This telescopes to the order
    statistic $x_{(j+1)}$ with $j = \#\{k \in 1..N-1 : q \ge k/N\}$. It is returned
    directly instead of summing the gaps, which avoids round-off. Levels below
    $1/N$ give the minimum; levels at or above $(N-1)/N$ give the maximum.

    .. Signature: `[(..., n), 1] -> ...`

    Args:
        arr: Sample values; reduced along ``axis``.
        q: Scalar quantile level, nominally in [0, 1].
        axis: Axis holding the samples.

    Raises:
        ValueError: if ``arr`` has no elements along ``axis``.
    """
    x = np.sort(np.asarray(arr), axis=axis)
    n = x.shape[axis]
    if n == 0:
        raise ValueError("Cannot compute a quantile of an empty sample.")
    # Number of jump points k/N (k = 1..N-1) at or below q.
    index = int(np.count_nonzero(q >= np.arange(1, n) / n))
    return np.take(x, index, axis=axis)

In [None]:
# 1 000 i.i.d. samples from U(0, 1) for comparing quantile implementations.
data = np.random.random_sample(size=1000)

In [None]:
# numpy's default (linearly interpolated) 85th percentile, for reference.
np.quantile(data, q=0.85)

In [None]:
# Same level with the step-function estimator, for comparison with np.quantile.
iquantile(data, q=0.85, axis=-1)