# View as a summary

In [None]:
#| default_exp repr_str

In [None]:
# |hide
from nbdev.showdoc import *
from fastcore.test import test_eq

In [None]:
#| hide
#| export
from typing import Optional, Union
import jax.numpy as jnp
from jax import random

In [None]:
# |exports
class __PrinterOptions:
    "Module-wide settings that control how values are formatted"
    precision: int = 3  # Number of digits after the decimal point
    threshold_max: int = 3  # abs(value) larger than 10**3 -> scientific notation
    threshold_min: int = -4  # abs(value) smaller than 10**-4 -> scientific notation
    sci_mode: Optional[bool] = None  # None = decide per value. True/False = force it on/off.
    indent: int = 2  # Spaces per nesting level when sub-arrays are shown in depth
    color: bool = True  # Default for ANSI color output when no explicit choice is given

PRINT_OPTS = __PrinterOptions()

In [None]:
# |hide
# |exporti

# Do we want this float in decimal or scientific mode?
def sci_mode(f: float):
    "Tell whether `f` lies outside the magnitude range that is shown in plain decimal notation"
    magnitude = abs(f)
    lower_bound = 10**(PRINT_OPTS.threshold_min)
    upper_bound = 10**PRINT_OPTS.threshold_max
    # NaN fails both comparisons, so it is reported as "not scientific".
    too_small = magnitude < lower_bound
    return too_small or magnitude > upper_bound

In [None]:
# |hide
test_eq(sci_mode(1.), False)
test_eq(sci_mode(0.00001), True)
test_eq(sci_mode(10000000), True)

# It would be fine either way, both `e` and `f` formats handle those.
test_eq(sci_mode(float('nan')), False)
test_eq(sci_mode(float('inf')), True) 

In [None]:
# |hide

# What's happening in the cell below
fmt = f"{{:.{4}{'e'}}}"
fmt, fmt.format(1.23)

('{:.4e}', '1.2300e+00')

In [None]:
# |export

# Convert a tensor or scalar into a string.
# This only looks good for small tensors, which is how it's intended to be used.
def pretty_str(x: Union[jnp.DeviceArray, float, int]):
    """Format a scalar or a (small) array as a compact string"""

    # Integers are printed as they are.
    if isinstance(x, int):
        return f"{x}"

    if isinstance(x, float):
        if x == 0.:
            return "0."

        # An explicit PRINT_OPTS.sci_mode wins; None means "decide per value".
        forced = PRINT_OPTS.sci_mode
        use_sci = sci_mode(x) if forced is None else forced
        # Builds a format spec such as ".3e" or ".3f".
        spec = f".{PRINT_OPTS.precision}{'e' if use_sci else 'f'}"
        return format(x, spec)

    # 0-dim arrays are unwrapped into a Python scalar and formatted as such.
    if x.ndim == 0:
        return pretty_str(x.item())

    # Otherwise recurse over the leading axis and bracket the results.
    parts = (pretty_str(x[i]) for i in range(x.shape[0]))
    return "[" + ", ".join(parts) + "]"

In [None]:
# Fixed PRNG key, so the samples below are the same on every run.
key = random.PRNGKey(0)
# 100 samples from random.normal, reused by the examples and tests that follow.
randoms: jnp.DeviceArray = random.normal(key, (100,))

In [None]:
# Take the first 12 samples and plant awkward values in them: element 0 is
# multiplied by 10000, element 1 is divided by 10000, elements 3, 4 and 5
# become +inf, -inf and nan. The result is reshaped to 2x6.
spicy = (randoms[:12].at[0].mul(10000)
                    .at[1].divide(10000)
                    .at[3].set(float('inf'))
                    .at[4].set(float('-inf'))
                    .at[5].set(float('nan'))
                    .reshape((2,6)))

In [None]:
pretty_str(spicy)

'[[-1.981e+04, 0.000, 0.890, inf, -inf, nan], [0.031, -0.390, 0.013, -0.421, -1.234, -1.252]]'

In [None]:
# |hide
test_eq(pretty_str(spicy), '[[-1.981e+04, 0.000, 0.890, inf, -inf, nan], [0.031, -0.390, 0.013, -0.421, -1.234, -1.252]]')

In [None]:
# |exporti
# |hide
def space_join(lst: list):
    "Join the truthy elements of `lst` into a space-separated string"
    # filter(None, ...) drops None, empty strings and any other falsy entry.
    return " ".join(filter(None, lst))

In [None]:
# |hide
test_eq(space_join(["Hello", None, "World"]), 'Hello World')

In [None]:
# |exporti

# Abbreviations for common dtypes. float32 maps to an empty string.
dtnames = {
    jnp.dtype(full_name): abbreviation
    for full_name, abbreviation in (
        ("float32", ""),
        ("float16", "f16"),
        ("float64", "f64"),
        ("uint8", "u8"),
        ("uint16", "u16"),
        ("uint32", "u32"),
        ("int8", "i8"),
        ("int16", "i16"),
        ("int32", "i32"),
    )
}


def short_dtype(x):
    "Abbreviated name of `x.dtype`; dtypes without an abbreviation keep their full name"
    dt = x.dtype
    return dtnames[dt] if dt in dtnames else str(dt)

In [None]:
# |exporti

def plain_repr(x):
    "Return `x._plain_repr()` when the class of `x` defines it, otherwise `x.__repr__()`"
    # The lookup is done on the class: that is where a monkey-patched
    # __repr__ would have stored the original one.
    if hasattr(x.__class__, "_plain_repr"):
        return x._plain_repr()
    return x.__repr__()

In [None]:
#| exporti

class StrProxy():
    """Wrapper around an array whose `repr` is a compact summary of the array.

    The summary lists the type name, shape, element count, value range,
    mean/std, short dtype name, warnings (all zeros / +inf / -inf / nan) and,
    for arrays of up to 10 elements, the values themselves.
    """
    def __init__(self, x: jnp.DeviceArray, plain=False, verbose=False, depth=0, lvl=0, color=None):
        self.x = x              # The array to describe
        self.plain = plain      # Return the plain repr instead of the summary
        self.verbose = verbose  # Append the plain repr below the summary
        self.depth = depth      # How many levels of sub-arrays to summarize as well
        self.lvl = lvl          # Current nesting level (sets the indentation)
        self.color = color      # True/False to force color, None to follow PRINT_OPTS.color

    def to_str(self):
        "Build the summary string"
        x = self.x

        # Complex arrays are not summarized. The decision is made on the dtype:
        # a value-based test (`jnp.iscomplex(x).any()`) would let a complex array
        # whose imaginary parts are all zero fall through to the real-valued
        # statistics below.
        if self.plain or jnp.iscomplexobj(x):
            return plain_repr(x)

        color = PRINT_OPTS.color if self.color is None else self.color

        # ANSI escape sequences, or empty strings when color is off.
        grey_style = "\x1b[38;2;127;127;127m" if color else ""
        red_style = "\x1b[31m" if color else ""
        end_style = "\x1b[0m" if color else ""

        tname = "DeviceArray" if type(x) is jnp.DeviceArray else type(x).__name__
        dtype = short_dtype(x)
        shape = str(list(x.shape))

        # Placeholders for fields that are not filled in yet. `space_join`
        # drops them from the output while they are None.
        dev = None
        grad_fn = None
        grad = None

        # "all_zeros" is only reported for arrays with more than one element.
        zeros = grey_style+"all_zeros"+end_style if not x.any() and x.size > 1 else None
        pinf = red_style+"+inf!"+end_style if jnp.isposinf(x).any() else None
        ninf = red_style+"-inf!"+end_style if jnp.isneginf(x).any() else None
        nan = red_style+"nan!"+end_style if jnp.isnan(x).any() else None

        attention = space_join([zeros, pinf, ninf, nan])

        vals = ""
        # The element count is shown only when the shape does not already
        # spell it out, i.e. when no single dimension holds all elements.
        numel = f"n={x.size}" if x.size > 5 and max(x.shape) != x.size else None
        summary = None
        if not zeros:
            if x.size <= 10:
                vals = pretty_str(x)

            # Statistics are computed in float32, on the finite values only.
            ft = jnp.extract(jnp.isfinite(x), x).astype(jnp.float32)

            # The range needs at least 3 finite values, mean/std at least 2.
            minmax = f"x∈[{pretty_str(ft.min())}, {pretty_str(ft.max())}]" if ft.size > 2 else None
            meanstd = f"μ={pretty_str(ft.mean())} σ={pretty_str(ft.std())}" if ft.size >= 2 else None

            summary = space_join([minmax, meanstd])

        # The type name is glued directly to the shape, e.g. "DeviceArray[2, 3]".
        res = tname + space_join([  shape,
                                    numel,
                                    summary,
                                    dtype,
                                    grad,
                                    grad_fn,
                                    dev,
                                    attention,
                                    vals if not self.verbose else None])

        if self.verbose:
            res += "\n" + plain_repr(x)

        if self.depth and x.ndim > 1:
            # One indented summary line per slice along the first axis. The
            # color choice is handed down, so that color=False also applies
            # to the nested lines.
            res += "\n" + "\n".join([
                " "*PRINT_OPTS.indent*(self.lvl+1) +
                str(StrProxy(x[i,:], depth=self.depth-1, lvl=self.lvl+1, color=self.color))
                for i in range(x.shape[0])])

        return res

    def __repr__(self):
        return self.to_str()

    def __call__(self, depth=0):
        "Return a new proxy for the same array with the given `depth`"
        # NOTE(review): plain, verbose and color are reset to their defaults
        # here - confirm whether that is intended.
        return StrProxy(self.x, depth=depth)


Would be _lovely_ if you could see all the important stats too!

In [None]:
# |export
def lovely(x: jnp.DeviceArray, # Array to summarize
            verbose=False,  # Also show the full array below the summary
            plain=False,    # Just print it exactly as before
            depth=0,        # Number of levels of sub-arrays to summarize too
            color=None):    # Force color (True/False) or auto.
    "Wrap `x` in a `StrProxy`, whose repr is the summary of the array"
    proxy = StrProxy(x, plain=plain, verbose=verbose, depth=depth, color=color)
    return proxy

In [None]:
print(lovely(randoms[0]))
print(lovely(randoms[:2]))
print(lovely(randoms[:6].reshape((2, 3)))) # More than 2 elements -> show statistics
print(lovely(randoms[:11])) # More than 10 -> suppress data output

DeviceArray[] -1.981
DeviceArray[2] μ=-0.466 σ=1.515 [-1.981, 1.048]
DeviceArray[2, 3] n=6 x∈[-1.981, 1.048] μ=-0.017 σ=1.113 [[-1.981, 1.048, 0.890], [0.035, -0.947, 0.851]]
DeviceArray[11] x∈[-1.981, 1.048] μ=-0.191 σ=0.899


In [None]:
# |hide
test_eq(str(lovely(randoms[0])), "DeviceArray[] -1.981")
test_eq(str(lovely(randoms[:2])), "DeviceArray[2] μ=-0.466 σ=1.515 [-1.981, 1.048]")
test_eq(str(lovely(randoms[:6].reshape((2, 3)))), "DeviceArray[2, 3] n=6 x∈[-1.981, 1.048] μ=-0.017 σ=1.113 [[-1.981, 1.048, 0.890], [0.035, -0.947, 0.851]]")
test_eq(str(lovely(randoms[:11])), "DeviceArray[11] x∈[-1.981, 1.048] μ=-0.191 σ=0.899")

In [None]:
# |hide
# grad = torch.tensor(1., requires_grad=True)
# print(lovely(grad)); print(lovely(grad+1))

In [None]:
# |hide
# test_eq(str(lovely(grad)), "tensor[] grad 1.000")
# test_eq(str(lovely(grad+1)), "tensor[] grad AddBackward0 2.000")

In [None]:
# |hide
# if torch.cuda.is_available():
#     print(lovely(torch.tensor(1., device=torch.device("cuda:0"))))
#     test_eq(str(lovely(torch.tensor(1., device=torch.device("cuda:0")))), "tensor[] cuda:0 1.000")

Do we have __any__ floating point nasties? Is the tensor __all__ zeros?

In [None]:
# Statistics and range are calculated on good values only, if there are at least 3 of them.
lovely(spicy)

DeviceArray[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 [31m+inf![0m [31m-inf![0m [31mnan![0m

In [None]:
lovely(spicy, color=False)

DeviceArray[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 +inf! -inf! nan!

In [None]:
lovely(jnp.array([float("nan")]*11))

DeviceArray[11] [31mnan![0m

In [None]:
lovely(jnp.zeros(12))

DeviceArray[12] [38;2;127;127;127mall_zeros[0m

In [None]:
test_eq(str(lovely(spicy)),
    'DeviceArray[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 \x1b[31m+inf!\x1b[0m \x1b[31m-inf!\x1b[0m \x1b[31mnan!\x1b[0m')
test_eq(str(lovely(jnp.array([float("nan")]*11))), 'DeviceArray[11] \x1b[31mnan!\x1b[0m')
test_eq(str(lovely(jnp.zeros(12))), 'DeviceArray[12] \x1b[38;2;127;127;127mall_zeros\x1b[0m')

In [None]:
# torch.set_printoptions(linewidth=120)
lovely(spicy, verbose=True)

DeviceArray[2, 6] n=12 x∈[-1.981e+04, 0.890] μ=-2.201e+03 σ=6.226e+03 [31m+inf![0m [31m-inf![0m [31mnan![0m
DeviceArray([[-1.9810703e+04,  1.0481724e-04,  8.8981909e-01,
                         inf,           -inf,            nan],
             [ 3.1245498e-02, -3.8968593e-01,  1.3208009e-02,
              -4.2052191e-01, -1.2335656e+00, -1.2524313e+00]],            dtype=float32)

In [None]:
lovely(spicy, plain=True)

DeviceArray([[-1.9810703e+04,  1.0481724e-04,  8.8981909e-01,
                         inf,           -inf,            nan],
             [ 3.1245498e-02, -3.8968593e-01,  1.3208009e-02,
              -4.2052191e-01, -1.2335656e+00, -1.2524313e+00]],            dtype=float32)

In [None]:
# Load a saved array from disk and plant a single NaN at index [1, 100, 100].
numbers = jnp.load("mysteryman.npy")
numbers=  numbers.at[1,100,100].set(float('nan'))

lovely(numbers, depth=1)

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 [31mnan![0m
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 [31mnan![0m
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178

In [None]:
# We don't really support complex numbers yet
c = random.normal(key, (10,), dtype=jnp.complex64)
c

DeviceArray([-1.8459435 -0.27444658j,  0.02393756-0.03172904j,
              0.7681536 -1.4444252j , -1.0467294 +0.0560899j ,
              0.3457446 +0.23581952j,  0.75131226+0.5628553j ,
              0.38307393-1.0190806j ,  0.01203694-1.1971303j ,
              0.1925229 -0.26424018j,  0.21582629-1.089025j  ],            dtype=complex64)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()