In [None]:
# Notebook-wide IPython configuration.
# Display the value of a cell's last expression *or* last assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload imported modules automatically before executing code (mode 2 = all modules).
%load_ext autoreload
%autoreload 2
# Draw matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from typing import Optional

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from numpy import pi as π
from scipy.optimize import minimize
from scipy.special import erfinv

from tsdm.encoders import BaseEncoder

In [None]:
class BoxCoxEncoder(BaseEncoder):
    r"""Encode data on logarithmic scale.

    .. math:: x ↦ \log(x+c)

    We consider multiple ideas for how to fit the parameter $c$

    1. Half the minimal non-zero value: `c = min(data[data>0])/2`
    2. Square of the first quartile divided by the third quartile (Stahle 2002)
    3. Value which minimizes the Wasserstein distance to a mean-0, variance-1 uniform distribution
    """

    AVAILABLE_METHODS = ["none", "wasserstein", "minimum", "quartile"]

    method: str
    param: np.ndarray

    def __init__(
        self, *, method: str = "", initial_param: Optional[np.ndarray] = None
    ) -> None:
        if method not in self.AVAILABLE_METHODS:
            raise ValueError(f"{method=} unknown. Available: {self.AVAILABLE_METHODS}")
        if method == "none" and initial_param is None:
            raise ValueError("Needs to provide initial param if no fitting.")

        self.method = method
        self.initial_param = initial_param
        super().__init__()

    @staticmethod
    def _wasserstein_uniform(x: np.ndarray, axis=-1) -> np.ndarray:
        r"""Signature: `[..., n] -> ...`."""
        n = x.shape[axis]
        k = np.arange(1, n + 1)
        r = x**2 + 2 * np.sqrt(3) * (1 - (2 * k - 1) / N) * x + 3
        return np.mean(r, axis=axis)

    def fit(self, data, /) -> None:
        assert np.all(data >= 0)

        match self.method:
            case "none":
                self.param = self.initial_param
            case "minimum":
                self.param = data[data > 0].min() / 2
            case "quartile":
                self.param = (np.quantile(data, 0.25) / np.quantile(data, 0.75)) ** 2
            case "wasserstein": ...

    def encode(self, data, /):
        # TODO: Use copy on data.
        result = data.copy()
        mask = data <= 0
        result[:] = np.where(mask, self.replacement, np.log2(data))
        return result

    def decode(self, data, /):
        result = 2**data
        mask = result < self.threshold
        result[:] = np.where(mask, 0, result)
        return result

In [None]:
def construct_loss_wasserstein_uniform(x, model, a=-np.sqrt(3), b=+np.sqrt(3)):
    r"""Construct the squared 2-Wasserstein loss against the Uniform(a, b) distribution.

    The empirical distribution of ``x`` is represented by its sorted unique values
    with relative frequencies αₖ. For transformed and standardized values xₖ:

    .. math::
        W₂² = ∑ₖ [αₖxₖ² -2βₖxₖ + αₖC] = ∑ₖ αₖ[xₖ² -2(βₖ/αₖ)xₖ + C]
        F^{-1}(q) &= a + (b-a)q
        β &= ∫ F^{-1}(q)dq = aq + ½(b-a)q²
        C &= ∫_0^1 F^{-1}(q)^2 dq = ⅓(a^2 + ab + b^2)

    where βₖ is the integral of the quantile function over the k-th probability bin.

    Parameters
    ----------
    x:
        Observed data; duplicates are merged and weighted by their frequency.
    model:
        Callable ``model(values, c)`` applied to the sorted unique values; its
        last axis must index the values.
        NOTE(review): the formula presumably requires ``model`` to be increasing
        in ``values`` so that the ordering is preserved — confirm.
    a, b:
        End points of the target uniform distribution. The defaults give mean 0
        and variance 1.

    Returns
    -------
    Callable ``fun(c)`` returning the loss for the parameter(s) ``c``.
    """
    if (a, b) == (-np.sqrt(3), +np.sqrt(3)):
        C = 1

        def integrate_quantile(q):
            return np.sqrt(3) * q * (q - 1)

    else:
        C = (a**2 + a * b + b**2) / 3

        def integrate_quantile(q):
            return a * q + (b - a) * q**2 / 2

    unique, counts = np.unique(x, return_counts=True)
    α = counts / np.sum(counts)
    # Bin edges in probability space; clip guards against cumsum round-off above 1.
    p = np.insert(np.cumsum(α), 0, 0).clip(0, 1)
    β = integrate_quantile(p[1:]) - integrate_quantile(p[:-1])
    μ = (b + a) / 2
    σ = abs(b - a) / np.sqrt(12)

    def fun(c):
        u = model(unique, c)
        # Transform to the target loc-scale. The statistics are weighted by α so
        # that they describe the full sample rather than just its unique values.
        mean = np.sum(u * α, axis=-1, keepdims=True)
        stdv = np.sqrt(np.sum((u - mean) ** 2 * α, axis=-1, keepdims=True))
        # Rescale first, then shift: the result has exactly mean μ and std σ.
        y = (u - mean) * (σ / stdv) + μ
        return np.einsum("...i, i -> ...", y**2 - 2 * (β / α) * y + C, α)

    return fun

In [None]:
def construct_loss_wasserstein_normal(x, model, μ=0.0, σ=1.0):
    r"""Construct the squared 2-Wasserstein loss against the Normal(μ, σ²) distribution.

    The empirical distribution of ``x`` is represented by its sorted unique values
    with relative frequencies αₖ. For transformed and standardized values xₖ:

    .. math::
        W₂² = ∑ₖ [αₖxₖ² -2βₖxₖ + αₖC] = ∑ₖ αₖ[xₖ² -2(βₖ/αₖ)xₖ + C]
        F^{-1}(q) &= μ + σ√2\erf^{-1}(2q-1)
        β &= ∫_a^b F^{-1}(q)dq = (b-a)μ - σ/√(2π) (e^{-\erf^{-1}(2b-1)^2} - e^{-\erf^{-1}(2a-1)^2})
        C &= ∫_0^1 F^{-1}(q)^2 dq = μ^2 + σ^2

    where βₖ is the integral of the quantile function over the k-th probability bin.

    Parameters
    ----------
    x:
        Observed data; duplicates are merged and weighted by their frequency.
    model:
        Callable ``model(values, c)`` applied to the sorted unique values; its
        last axis must index the values.
        NOTE(review): the formula presumably requires ``model`` to be increasing
        in ``values`` so that the ordering is preserved — confirm.
    μ, σ:
        Mean and standard deviation of the target normal distribution.

    Returns
    -------
    Callable ``fun(c)`` returning the loss for the parameter(s) ``c``.
    """
    if (μ, σ) == (0, 1):
        C = 1

        def integrate_quantile(q):
            return -np.exp(-erfinv(2 * q - 1) ** 2) / np.sqrt(2 * π)

    else:
        C = μ**2 + σ**2

        def integrate_quantile(q):
            return μ * q - σ * np.exp(-erfinv(2 * q - 1) ** 2) / np.sqrt(2 * π)

    unique, counts = np.unique(x, return_counts=True)
    α = counts / np.sum(counts)
    # Bin edges in probability space; clip guards against cumsum round-off above 1.
    p = np.insert(np.cumsum(α), 0, 0).clip(0, 1)
    β = integrate_quantile(p[1:]) - integrate_quantile(p[:-1])

    def fun(c):
        u = model(unique, c)
        # Transform to the target loc-scale. The statistics are weighted by α so
        # that they describe the full sample rather than just its unique values.
        mean = np.sum(u * α, axis=-1, keepdims=True)
        stdv = np.sqrt(np.sum((u - mean) ** 2 * α, axis=-1, keepdims=True))
        # Rescale first, then shift: the result has exactly mean μ and std σ.
        y = (u - mean) * (σ / stdv) + μ
        return np.einsum("...i, i -> ...", y**2 - 2 * (β / α) * y + C, α)

    return fun

In [None]:
from tsdm.datasets import KIWI_RUNS

# Instantiate the dataset object and take its `timeseries` attribute.
dataset = KIWI_RUNS()

ts = dataset.timeseries

# Keep only the non-missing entries of `ts.Glucose`, cast them to float and
# store them as a NumPy array; this is the sample used in all cells below.
data = np.array(ts.Glucose[pd.notna(ts.Glucose)].astype(float))

In [None]:
def model(x, c):
    r"""Shifted logarithm :math:`\log(x + c)` for every pairing of shift and value.

    The result has shape ``np.shape(c) + np.shape(x)``: the leading axes index
    the shifts ``c``, the trailing axes index the values ``x``.
    """
    shifted = np.add.outer(c, x)
    return np.log(shifted)

# Match Uniform

In [None]:
%matplotlib inline
# Build the loss of the shifted-log `model` on `data` against the default
# uniform target (a=-√3, b=+√3, i.e. mean 0 and variance 1).
fun = construct_loss_wasserstein_uniform(data, model)
# Evaluate the loss on a log-spaced grid of shifts from 1e-3 to 1e2 and plot it
# with a logarithmic x-axis.
c = np.logspace(-3, 2)
plt.semilogx(c, fun(c))

In [None]:
# Minimize `fun` (as defined by the most recently executed loss cell) over the
# shift c, starting from c = 1 and restricted to c >= 0 through `bounds`.
x0 = np.array([1.0])
sol = minimize(
    fun,
    x0,
    method="trust-constr",
    # jac=jac,
    # hess=hess,
    bounds=[(0, np.inf)],
    options={"disp": True},
)

# Candidate shifts that are compared against each other in the plots below.
values = [
    data[data > 0].min() / 2,  # half of the smallest positive observation
    np.quantile(data, 0.25) ** 2 / np.quantile(data, 0.75),  # q25² / q75
    np.quantile(data, 0.25) ** 2 / np.quantile(data, 0.75) ** 2,  # (q25 / q75)²
    sol.x.squeeze(),  # minimizer returned by the optimizer above
]

In [None]:
# Inverse of the 2×2 matrix [[1, 1], [-1, 1]]; shown as the cell output.
# NOTE(review): looks like a scratch computation that nothing else in the
# notebook uses — confirm whether it can be removed.
np.linalg.inv(np.array([[1, 1], [-1, 1]]))

In [None]:
%matplotlib inline

from scipy.stats import uniform

# One panel per candidate shift; axes are shared so the panels are comparable.
fig, axes = plt.subplots(
    ncols=len(values),
    constrained_layout=True,
    figsize=(3 * len(values), 3),
    sharey=True,
    sharex=True,
)

# Grid on which the reference density is drawn.
t = np.linspace(-6, +6, 1024)
for val, ax in zip(values, axes):
    # Transformed data for this candidate shift and its mean / standard deviation.
    z = model(data, val)
    μ = z.mean()
    σ = z.std()
    # End points of the uniform distribution with mean μ and standard deviation σ.
    a = μ - np.sqrt(3) * σ
    b = μ + np.sqrt(3) * σ
    # Histogram of the transformed data with the moment-matched uniform density on top.
    ax.hist(z, density=True, bins=50)
    ax.plot(t, uniform.pdf(t, loc=a, scale=b - a))
    # ax.set_yscale("log")

# Match Normal

In [None]:
# Build the loss of the shifted-log `model` on `data` against the default
# normal target (μ=0, σ=1). This rebinds the notebook-level name `fun`.
fun = construct_loss_wasserstein_normal(data, model)
# Evaluate the loss on a log-spaced grid of shifts from 1e-3 to 1e2 and plot it
# with a logarithmic x-axis.
c = np.logspace(-3, 2)
plt.semilogx(c, fun(c))

In [None]:
# Minimize `fun` (as defined by the most recently executed loss cell) over the
# shift c, starting from c = 1 and restricted to c >= 0 through `bounds`.
x0 = np.array([1.0])
sol = minimize(
    fun,
    x0,
    method="trust-constr",
    # jac=jac,
    # hess=hess,
    bounds=[(0, np.inf)],
    options={"disp": True},
)

# Candidate shifts that are compared against each other in the plots below.
values = [
    data[data > 0].min() / 2,  # half of the smallest positive observation
    np.quantile(data, 0.25) ** 2 / np.quantile(data, 0.75),  # q25² / q75
    np.quantile(data, 0.25) ** 2 / np.quantile(data, 0.75) ** 2,  # (q25 / q75)²
    sol.x.squeeze(),  # minimizer returned by the optimizer above
]

In [None]:
%matplotlib inline

from scipy.stats import norm as normal

# One panel per candidate shift; axes are shared so the panels are comparable.
fig, axes = plt.subplots(
    ncols=len(values),
    constrained_layout=True,
    figsize=(3 * len(values), 3),
    sharey=True,
    sharex=True,
)

# Grid on which the reference density is drawn.
t = np.linspace(-6, +6, 1024)
for val, ax in zip(values, axes):
    # Transformed data for this candidate shift and its mean / standard deviation.
    z = model(data, val)
    μ = z.mean()
    σ = z.std()
    # Histogram of the transformed data with the moment-matched normal density on top.
    ax.hist(z, density=True, bins=50)
    ax.plot(t, normal.pdf(t, loc=μ, scale=σ))
    # ax.set_yscale("log")
    # ax.set_yscale("log")