In [None]:
%config InlineBackend.figure_format = 'retina'
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Number of sample points used to trace the unit circle.
N = 10000

# Angles covering one full turn; the resolution is taken from `N` instead of a
# second, independent literal so the two can never drift apart.
t = np.linspace(-np.pi, +np.pi, num=N)
# Points on the unit circle stacked as rows: x[0] = cos(t), x[1] = sin(t).
x = np.stack([np.cos(t), np.sin(t)])

In [None]:
import logging
from typing import Callable, Final, Generic, TypeVar, Union

import numpy as np
import torch
from torch import Tensor

# Module-level logger named after the running module.
logger = logging.getLogger(__name__)
# Declared public names.
# NOTE(review): none of these names are defined in the visible cells, and
# "upack_sequence" looks like a typo for "unpack_sequence" — confirm against
# the module this cell was taken from.
__all__: Final[list[str]] = [
    "collate_list",
    "collate_packed",
    "collate_padded",
    "unpad_sequence",
    "upack_sequence",
]


T = TypeVar("T")  # element type parameter of ``Sample``


class Sample(Generic[T]):
    r"""Generic stub whose item access is annotated as ``T`` or a zero-argument callable returning ``T``."""

    # NOTE(review): the index parameter is itself named ``T``, which shadows the
    # type variable inside the method body.
    def __getitem__(self, T) -> Union[T, Callable[[], T]]:
        r"""Item-access placeholder: there is no implementation, so ``None`` is returned."""
        return None

In [None]:
def

In [None]:
# Root-mean-square of [1, 2, 3] computed by hand: sqrt(mean(x**2)).
np.mean(np.array([1, 2, 3]) ** 2) ** (1 / 2)

In [None]:
from tsdm.utils import scaled_norm

# Same input as the hand-computed value in the previous cell, now via the imported helper.
scaled_norm([1, 2, 3])

In [None]:
from functools import singledispatch
from typing import Iterable, Type

# from collections.abc import Iterable
import torch
from torch import Tensor


@singledispatch
def g(x):
    r"""Generic fallback: does nothing and returns ``None`` for unregistered types."""
    return None


# singledispatch dispatches on real classes only; a parameterized generic such
# as ``list[int]`` cannot serve as the dispatch class, so register on ``list``
# explicitly and keep the element type purely as documentation.
@g.register(list)
def f(x: list[int]) -> int:
    r"""Implementation of ``g`` for lists: return the sum of the elements."""
    return sum(x)


g([1, 2, 3])

In [None]:
# isinstance() rejects subscripted generics such as ``Iterable[Tensor]`` with a
# TypeError, so the container and its elements are checked separately.
# NOTE(review): the list holds the function ``torch.randn`` itself, not a tensor,
# so this evaluates to False — call it (e.g. ``torch.randn(3)``) if a tensor was meant.
candidates = [torch.randn]
is_tensor_iterable = isinstance(candidates, Iterable) and all(
    isinstance(item, Tensor) for item in candidates
)
is_tensor_iterable

In [None]:
# Side-by-side comparison of the curves x / ‖x‖ for several exponents p:
# left panel uses `scaled_norm`, right panel uses `np.linalg.norm`.
fig, ax = plt.subplots(
    figsize=(16, 10), ncols=2, sharex=True, sharey=True, subplot_kw=dict(box_aspect=1)
)
ax[0].set_title("scaled_norm")
ax[1].set_title("np.linalg.norm")
for panel in ax:
    panel.set_xlim([-1.5, +4.5])
    panel.set_ylim([-1.5, +4.5])

# For p = 0 the right panel does not call np.linalg.norm; a hand-drawn cross
# along the coordinate axes (one vertex per row) is plotted instead.
cross = np.array([
    [0, 0],
    [0, 1],
    [0, 0],
    [0, -1],
    [0, 0],
    [-1, 0],
    [0, 0],
    [1, 0],
    [0, 0],
])

for p in (ps := (np.inf, 4, 2, 1, 0.5, 0.25, 0)):
    ax[0].plot(*(x / scaled_norm(x, axis=0, p=p)), "-")
    if p:
        ax[1].plot(*(x / np.linalg.norm(x, axis=0, ord=p)), "-")
    else:
        ax[1].plot(*cross.T, "-")

# Build each legend once, after every curve exists, instead of rebuilding both
# legends on every loop iteration (the labels are the exponents in plot order).
for panel in ax:
    panel.legend(ps)

In [None]:
from unittest.mock import ANY


def area(p):
    r"""Signed area of a polygon via the shoelace formula.

    Parameters
    ----------
    p: array-like of shape (n, 2)
        Polygon vertices in order, one ``(x, y)`` pair per row. The polygon is
        closed implicitly (the last vertex connects back to the first).

    Returns
    -------
    float
        Half the sum of the cross products of consecutive vertices; positive for
        counter-clockwise vertex order, negative for clockwise.

    Raises
    ------
    AssertionError
        If ``p`` is not 2-dimensional with exactly two columns.
    """
    p = np.asarray(p)
    # The previous check called ``ANY()``; ``unittest.mock.ANY`` is not callable, so
    # every call failed with TypeError. Validate the shape explicitly instead.
    assert p.ndim == 2 and p.shape[1] == 2, f"expected shape (n, 2), got {p.shape}"
    x, y = p.T
    # Successor of every vertex, wrapping around from the last to the first.
    x_next, y_next = np.roll(x, -1), np.roll(y, -1)
    return np.sum(y_next * x - x_next * y) / 2

In [None]:
r"""Utility functions."""

from __future__ import annotations

import logging
from collections.abc import Mapping
from functools import singledispatch
from typing import Any, Final, Iterable, Type, Union

import numpy as np
import torch
from numpy import ndarray
from numpy.typing import ArrayLike, NDArray
from torch import Tensor, nn

# Module-level logger named after the running module.
logger = logging.getLogger(__name__)
# Declared public names of this cell.
# NOTE(review): `flatten_dict`, defined further down in this cell, is not
# listed here — confirm whether that omission is intentional.
__all__: Final[list[str]] = [
    "ACTIVATIONS",
    "deep_dict_update",
    "deep_kval_update",
    "relative_error",
    "scaled_norm",
]


# Lookup table from a class name to the `torch.nn` module class of the same
# name (every key is spelled exactly like the attribute it maps to).
ACTIVATIONS: Final[dict[str, Type[nn.Module]]] = {
    "AdaptiveLogSoftmaxWithLoss": nn.AdaptiveLogSoftmaxWithLoss,
    "ELU": nn.ELU,
    "Hardshrink": nn.Hardshrink,
    "Hardsigmoid": nn.Hardsigmoid,
    "Hardtanh": nn.Hardtanh,
    "Hardswish": nn.Hardswish,
    "LeakyReLU": nn.LeakyReLU,
    "LogSigmoid": nn.LogSigmoid,
    "LogSoftmax": nn.LogSoftmax,
    "MultiheadAttention": nn.MultiheadAttention,
    "PReLU": nn.PReLU,
    "ReLU": nn.ReLU,
    "ReLU6": nn.ReLU6,
    "RReLU": nn.RReLU,
    "SELU": nn.SELU,
    "CELU": nn.CELU,
    "GELU": nn.GELU,
    "Sigmoid": nn.Sigmoid,
    "SiLU": nn.SiLU,
    "Softmax": nn.Softmax,
    "Softmax2d": nn.Softmax2d,
    "Softplus": nn.Softplus,
    "Softshrink": nn.Softshrink,
    "Softsign": nn.Softsign,
    "Tanh": nn.Tanh,
    "Tanhshrink": nn.Tanhshrink,
    "Threshold": nn.Threshold,
}
r"""Utility dictionary, for use in model creation from Hyperparameter dicts."""


def _torch_is_float_dtype(x: Tensor) -> bool:
    return x.dtype in (
        torch.half,
        torch.float,
        torch.double,
        torch.bfloat16,
        torch.complex32,
        torch.complex64,
        torch.complex128,
    )


def deep_dict_update(d: dict, new_kvals: Mapping) -> dict:
    r"""Update nested dictionary recursively in-place with new dictionary.

    Non-empty mappings in ``new_kvals`` are merged into the corresponding
    sub-mapping of ``d``; every other value (including an empty mapping)
    overwrites the entry in ``d``.

    Reference: https://stackoverflow.com/a/30655448/9318372

    Parameters
    ----------
    d: dict
        Dictionary to update; it is modified in-place.
    new_kvals: Mapping
        Possibly nested mapping holding the new values.

    Returns
    -------
    dict
        The same object ``d``, after the update.
    """
    for key, value in new_kvals.items():
        if isinstance(value, Mapping) and value:
            current = d.get(key)
            # Merge into the existing sub-mapping. If the key is missing, or it
            # currently holds a non-mapping value (e.g. ``None`` or a number),
            # start from a fresh dict instead of trying to index into that value.
            target = current if isinstance(current, Mapping) else {}
            d[key] = deep_dict_update(target, value)
        else:
            d[key] = value
    return d


def deep_kval_update(d: dict, **new_kvals: dict) -> dict:
    r"""Overwrite matching keys at every nesting level of a dictionary, in-place.

    Walks ``d``; each non-empty mapping value is descended into, and every other
    entry whose key appears in ``new_kvals`` is replaced by the given value.
    Keys that do not already exist in ``d`` are never added.

    Reference: https://stackoverflow.com/a/30655448/9318372

    Parameters
    ----------
    d: dict
        Dictionary to update; it is modified in-place.
    new_kvals: dict
        Replacement values, given as keyword arguments.

    Returns
    -------
    dict
        The same object ``d``, after the update.
    """
    for key in d:
        current = d[key]
        if isinstance(current, Mapping) and bool(current):
            # Non-empty sub-mapping: recurse rather than replace it wholesale.
            d[key] = deep_kval_update(current, **new_kvals)
            continue
        if key in new_kvals:
            d[key] = new_kvals[key]
    return d


@singledispatch
def relative_error(
    xhat: Union[ArrayLike, Tensor], x_true: Union[ArrayLike, Tensor]
) -> Union[ArrayLike, Tensor]:
    r"""Element-wise relative error for :class:`~torch.Tensor` and :class:`~numpy.ndarray` inputs.

    .. math::
        r(x̂, x) = \tfrac{|x̂ - x|}{|x|+ϵ}

    The tolerance $ϵ$ is chosen from the dtype of ``xhat``: $2^{-11}$ for
    16-bit, $2^{-24}$ for 32-bit and $2^{-53}$ for 64-bit floats and integers
    ($2^{-8}$ for ``bfloat16`` tensors).

    Inputs that are neither tensors nor arrays are converted with
    :func:`numpy.asanyarray` and handled by the numpy implementation.

    Parameters
    ----------
    xhat: ArrayLike
        The estimation
    x_true:  ArrayLike
        The true value

    Returns
    -------
    ArrayLike

    Raises
    ------
    NotImplementedError
        If the dtype of ``xhat`` has no associated tolerance.
    """
    estimate = np.asanyarray(xhat)
    reference = np.asanyarray(x_true)
    return _numpy_relative_error(estimate, reference)


@relative_error.register
def _numpy_relative_error(xhat: ndarray, x_true: ndarray) -> ndarray:
    r"""Numpy implementation of :func:`relative_error`."""
    # (dtypes sharing a tolerance, tolerance) — scanned in order.
    tolerance_table = (
        ((np.float16, np.int16), 2**-11),
        ((np.float32, np.int32), 2**-24),
        ((np.float64, np.int64), 2**-53),
    )
    for dtypes, tolerance in tolerance_table:
        if xhat.dtype in dtypes:
            break
    else:
        raise NotImplementedError

    deviation = np.abs(xhat - x_true)
    scale = np.abs(x_true) + tolerance
    return deviation / scale


@relative_error.register
def _torch_relative_error(xhat: Tensor, x_true: Tensor) -> Tensor:
    r"""Torch implementation of :func:`relative_error`."""
    # (dtypes sharing a tolerance, tolerance) — scanned in order.
    tolerance_table = (
        ((torch.bfloat16,), 2**-8),
        ((torch.float16, torch.int16), 2**-11),
        ((torch.float32, torch.int32), 2**-24),
        ((torch.float64, torch.int64), 2**-53),
    )
    for dtypes, tolerance in tolerance_table:
        if xhat.dtype in dtypes:
            break
    else:
        raise NotImplementedError

    deviation = torch.abs(xhat - x_true)
    scale = torch.abs(x_true) + tolerance
    return deviation / scale


@singledispatch
def scaled_norm(
    x: Union[ArrayLike, Tensor],
    p: float = 2,
    axis: tuple[int] = (),
    keepdims: bool = False,
) -> Union[NDArray, Tensor]:
    r"""Scaled $ℓ^p$-norm, works with both :class:`torch.Tensor` and :class:`numpy.ndarray`.

    .. math::
        ‖x‖_p = (⅟ₙ ∑_{i=1}^n |x_i|^p)^{1/p}

    This naturally leads to

    .. math::
       ∥u⊕v∥_p^p = \frac{\dim U}{\dim U⊕V} ∥u∥_p^p + \frac{\dim V}{\dim U⊕V} ∥v∥_p^p

    This choice is consistent with associativity: $∥(u⊕v)⊕w∥ = ∥u⊕(v⊕w)∥$

    In particular, given $𝓤=⨁_{i=1:n} U_i$, then

    .. math::
        ∥u∥_p^p = ∑_{i=1:n} \frac{\dim U_i}{\dim 𝓤} ∥u_i∥_p^p

    This generic implementation handles inputs that have no registered
    implementation of their own (e.g. lists or scalars): the input is converted
    with :func:`numpy.asanyarray` and forwarded to the implementation registered
    for the resulting array type.

    Parameters
    ----------
    x: ArrayLike
    p: float, default=2
    axis: tuple[int], default=()
        Axes to reduce over. The empty tuple means "reduce over all axes" and is
        forwarded as ``axis=None``.
    keepdims: bool, default=False

    Returns
    -------
    ArrayLike

    Raises
    ------
    NotImplementedError
        If no implementation is registered for the converted array type.
    """
    arr = np.asanyarray(x)
    impl = scaled_norm.dispatch(type(arr))
    if impl is scaled_norm.dispatch(object):
        # Nothing is registered for arrays: re-dispatching would recurse forever.
        raise NotImplementedError(
            f"scaled_norm has no implementation registered for {type(arr).__name__}"
        )
    # numpy treats ``axis=()`` as "reduce over no axis"; translate this
    # function's default into numpy's "reduce over everything".
    reduce_all = isinstance(axis, tuple) and len(axis) == 0
    return impl(arr, p=p, axis=None if reduce_all else axis, keepdims=keepdims)


# Torch implementation of `scaled_norm`: (mean |x_i|^p)^(1/p) over `axis`.
# NOTE(review): decorators apply bottom-up, so `scaled_norm.register` receives the
# plain Python function and only the module-level name `torch_scaled_norm` is bound
# to the result of `torch.jit.script` — confirm that dispatching to the unscripted
# function is intended.
@torch.jit.script
@scaled_norm.register
def torch_scaled_norm(
    ## type: (Tensor, float, list[int], bool) -> Tensor
    x: Tensor,
    p: float = 2,
    # NOTE(review): the TODO below asks for the annotation that is already in place;
    # verify that TorchScript accepts a variadic tuple here.
    axis: tuple[int, ...] = (),  # TODO: use tuple[int, ...] once supported
    keepdims: bool = False,
) -> Tensor:
    # Inputs whose dtype is not in the float/complex list are cast to torch.float
    # before any reduction.
    if not _torch_is_float_dtype(x):
        x = x.to(dtype=torch.float)
    x = torch.abs(x)

    if p == 0:
        # Geometric mean of |x|, computed as exp(mean(log|x|)).
        # https://math.stackexchange.com/q/282271/99220
        return torch.exp(torch.mean(torch.log(x), dim=axis, keepdim=keepdims))
    if p == 1:
        # Mean absolute value.
        return torch.mean(x, dim=axis, keepdim=keepdims)
    if p == 2:
        # Root mean square.
        return torch.sqrt(torch.mean(x**2, dim=axis, keepdim=keepdims))
    if p == float("inf"):
        # Largest absolute value.
        return torch.amax(x, dim=axis, keepdim=keepdims)
    # other p
    return torch.mean(x**p, dim=axis, keepdim=keepdims) ** (1 / p)


@scaled_norm.register
def numpy_scaled_norm(
    x: ndarray,
    p: float = 2,
    axis: Union[int, tuple[int, ...]] = None,
    keepdims: bool = False,
) -> ndarray:
    r"""Numpy implementation of the scaled $ℓ^p$-norm: $(\text{mean}\,|x_i|^p)^{1/p}$ over ``axis``.

    Special cases: ``p == 0`` gives the geometric mean of $|x|$, ``p == 1`` the
    mean absolute value, ``p == 2`` the root mean square and ``p == inf`` the
    largest absolute value.
    """
    magnitude = np.abs(x)
    reduce_opts = {"axis": axis, "keepdims": keepdims}

    if p == 0:
        # https://math.stackexchange.com/q/282271/99220
        mean_log = np.mean(np.log(magnitude), **reduce_opts)
        return np.exp(mean_log)
    if p == 1:
        return np.mean(magnitude, **reduce_opts)
    if p == 2:
        mean_square = np.mean(magnitude**2, **reduce_opts)
        return np.sqrt(mean_square)
    if p == float("inf"):
        return np.max(magnitude, **reduce_opts)
    # generic exponent
    mean_power = np.mean(magnitude**p, **reduce_opts)
    return mean_power ** (1 / p)


def flatten_dict(
    d: dict[Any, Iterable[Any]], recursive: bool = True
) -> list[tuple[Any, ...]]:
    r"""Turn a dictionary of iterables into a flat list of ``(key, item)`` tuples.

    Parameters
    ----------
    d: dict
        Maps each key to an iterable of items.
    recursive: bool (default=True)
        If true, items that are themselves dicts are flattened the same way and
        their tuples are prefixed with the outer key, yielding
        list[tuple[key1, key2, ...., keyN, value]]

    Returns
    -------
    list[tuple[Any, ...]]
    """
    rows: list[tuple[Any, ...]] = []
    for key, items in d.items():
        for item in items:
            if recursive and isinstance(item, dict):
                # Prefix every tuple of the nested dict with the current key.
                rows.extend((key, *tail) for tail in flatten_dict(item, recursive=True))
            else:
                rows.append((key, item))
    return rows


# @torch.jit.script
# @scaled_norm.register
# def multi_scaled_norm(
#     x: list[Tensor],
#     p: float = 2,
# ) -> Tensor:
#     z = torch.stack([scaled_norm(z, p=p) for z in x])
#     w = torch.tensor([z.numel() for z in x], device=z.device, dtype=z.dtype)
#     return torch.dot(w, z)/torch.sum(w)


# How would you call tuples of tensors?
# hil-bor hil-tor hil-ber
# tup-lor
# poly-tor poly-sor
# mul-tor mul-sor
# n-dor en-dor

In [None]:
# axis=None averages over every axis of the (2, 3, 4) sample.
np.mean(np.random.randn(2, 3, 4), axis=None)

In [None]:
@singledispatch
def _process(x: Tensor) -> Tensor:
    return process_torch(x)


@torch.jit.script
@_process.register
def process_torch(x: Tensor) -> Tensor:
    return x


@_process.register
def process_numpy(x: NDArray) -> NDArray:
    return x


def process(x: Tensor) -> Tensor:
    return _process(x)

In [None]:
# NOTE(review): scripting `test` makes TorchScript compile the call to `process`,
# which forwards to a functools.singledispatch wrapper — presumably this cell
# probes whether that combination is supported; confirm the expected outcome.
@torch.jit.script
def test(x: Tensor) -> Tensor:
    return process(x)

In [None]:
# Call the (unscripted) dispatching wrapper directly with a random 4-element tensor.
process(torch.randn(4))

In [None]:
# A plain list is neither a Tensor nor an ndarray, so this call is handled by the
# generic (unregistered) `scaled_norm` implementation.
scaled_norm([1, 2, 3, 4])