In [1]:
import torch
from torch import Tensor, jit

In [2]:
# Random standard-normal test tensor with shape (3, 4, 5).
A = torch.randn((3, 4, 5))

In [10]:
def _is_float_dtype(x: Tensor) -> bool:
    return x.dtype in (torch.half, torch.float, torch.double, torch.bfloat16,
                       torch.complex32, torch.complex64, torch.complex128)

In [71]:
# Scratch check: how does torch.mean treat an empty list of dims?
torch.mean(A, dim=[])

In [19]:
@jit.script
def jmean(x: Tensor, dim: list[int] = (), keepdim: bool=False):
    """TorchScript mean of ``x`` over ``dim``.

    Inputs whose dtype is rejected by ``_is_float_dtype`` are cast to
    ``torch.float`` before the reduction; all other inputs are reduced as-is.

    Args:
        x: Input tensor.
        dim: Dimensions handed to ``torch.mean``.
        keepdim: Forwarded to ``torch.mean``.
    """
    if _is_float_dtype(x):
        return torch.mean(x, dim, keepdim=keepdim)
    return torch.mean(x.to(dtype=torch.float), dim, keepdim=keepdim)

In [83]:
@jit.script
def torch_scaled_norm(x: Tensor, p: float = 2, dim: list[int] = (), keepdim: bool = False) -> Tensor:
    """Scaled (mean-normalised) p-norm of ``x``.

    Computes ``mean(|x| ** p) ** (1 / p)`` over ``dim``, i.e. the usual
    p-norm with the sum replaced by a mean, with dedicated branches for
    ``p`` in ``{0, 1, 2, inf}``.

    Args:
        x: Input tensor. Dtypes rejected by ``_is_float_dtype`` are cast to
            ``torch.float`` first.
        p: Order of the norm. ``p == 0`` yields the geometric mean of
            ``|x|`` and ``p == inf`` the maximum of ``|x|``.
        dim: Dimensions forwarded to the torch reductions.
        keepdim: Whether the reduced dimensions are retained.

    Returns:
        The scaled p-norm of ``x``.
    """
    if not _is_float_dtype(x):
        # Tensor.to is out-of-place: the converted tensor must be rebound,
        # otherwise the mean below still sees the original dtype.
        x = x.to(dtype=torch.float)

    # A norm is defined on magnitudes; without this, negative entries break
    # log / fractional powers and give wrong results for p = 1 and p = inf.
    x = torch.abs(x)

    if p == 0:
        # https://math.stackexchange.com/q/282271/99220
        return torch.exp(torch.mean(torch.log(x), dim=dim, keepdim=keepdim))
    if p == 1:
        return torch.mean(x, dim=dim, keepdim=keepdim)
    if p == 2:
        return torch.sqrt(torch.mean(x ** 2, dim=dim, keepdim=keepdim))
    if p == float('inf'):
        return torch.amax(x, dim=dim, keepdim=keepdim)
    # other p
    return torch.mean(x ** p, dim=dim, keepdim=keepdim) ** (1 / p)

In [70]:
# Scratch check: compare a bfloat16 tensor's torch.dtype with Python's
# built-in ``float`` type.
torch.tensor([1, 2]).to(torch.bfloat16).dtype == float

In [22]:
# Smoke test: default scaled 2-norm of an integer tensor.
torch_scaled_norm(torch.as_tensor([1, 2]))

In [65]:
import numba
from numpy import ndarray
import numpy as np

def numpy_scaled_norm(x: ndarray, p: float = 2, axis = (), keepdims: bool = False) -> ndarray:
    """Scaled (mean-normalised) p-norm of ``x``.

    Computes ``mean(|x| ** p) ** (1 / p)`` over ``axis``, i.e. the usual
    p-norm with the sum replaced by a mean, with dedicated branches for
    ``p`` in ``{0, 1, 2, inf}``.

    Args:
        x: Input array.
        p: Order of the norm. ``p == 0`` yields the geometric mean of
            ``|x|`` and ``p == inf`` the maximum of ``|x|``.
        axis: Axis or axes to reduce over. The empty default ``()`` (or an
            empty list) means "reduce over all axes".
        keepdims: Whether the reduced axes are kept with length one.

    Returns:
        The scaled p-norm of ``x`` along ``axis``.
    """
    # NumPy interprets axis=() as "reduce over no axes", which would make
    # every branch below return |x| element-wise. Map the empty default to
    # None so that it reduces over the whole array instead.
    if isinstance(axis, (tuple, list)) and len(axis) == 0:
        axis = None

    x = np.abs(x)

    # The mean (not the sum) is what makes the norm "scaled"; the p == 0
    # geometric-mean limit only holds for the mean-normalised form.
    if p == 0:
        # https://math.stackexchange.com/q/282271/99220
        return np.exp(np.mean(np.log(x), axis=axis, keepdims=keepdims))
    if p == 1:
        return np.mean(x, axis=axis, keepdims=keepdims)
    if p == 2:
        return np.sqrt(np.mean(x ** 2, axis=axis, keepdims=keepdims))
    if p == np.inf:
        return np.amax(x, axis=axis, keepdims=keepdims)
    # other p
    return np.mean(x ** p, axis=axis, keepdims=keepdims) ** (1 / p)

In [66]:
# Smoke test: default scaled 2-norm of a small float32 vector.
numpy_scaled_norm(np.asarray([1, 2], dtype=np.float32))

In [51]:
# Scratch check: cast an integer array to NumPy's default float type.
np.array([1, 2]).astype(np.float64)

In [72]:
# Random standard-normal test array with shape (3, 4, 5).
B = np.random.standard_normal(size=(3, 4, 5))

In [75]:
# Scratch check: how does NumPy's mean treat an empty tuple of axes?
B.mean(axis=())

In [77]:
# Scratch check: converting an empty tuple gives an empty list.
list(tuple())