# 🧾 View as a summary

In [None]:
#| default_exp repr_str

In [None]:
#| hide
# Export this notebook's `#| export` / `#| exporti` cells to the library module via nbdev.
import nbdev; nbdev.nbdev_export()

In [None]:
# |hide
import os
from nbdev.showdoc import *
from fastcore.test import test_eq

In [None]:
#| hide
# For testing, I want to see 8 CPU devices.
# Force the CPU backend and ask XLA for 8 host devices, so the sharding examples at the
# end of the notebook can build an 8-device mesh without accelerators.
# NOTE(review): these variables are presumably only honoured if set before JAX initializes
# its backends — keep this cell ahead of any JAX computation.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [None]:
#| hide
#| export

import warnings

import numpy as np
import jax, jax.numpy as jnp

from lovely_numpy import np_to_str_common, pretty_str, sparse_join, ansi_color, in_debugger, bytes_to_human
from lovely_numpy import config as lnp_config

from lovely_jax.utils.config import get_config, config
from lovely_jax.utils.misc import is_cpu, test_array_repr

In [None]:
# |hide
# Fixed PRNG key so the example outputs below are reproducible.
key = jax.random.PRNGKey(0)
# 100 standard-normal samples reused throughout the examples.
randoms = jax.random.normal(key, (100,))

In [None]:
# A 2x6 array built from the first 12 randoms, with extreme and non-finite values
# injected to exercise the inf/nan handling:
#   [0] multiplied by 1e4, [1] divided by 1e4, [3] = +inf, [4] = -inf, [5] = nan.
spicy = (randoms[:12].at[0].mul(10000)
                    .at[1].divide(10000)
                    .at[3].set(float('inf'))
                    .at[4].set(float('-inf'))
                    .at[5].set(float('nan'))
                    .reshape((2,6)))

In [None]:
# |exporti
# Compact dtype labels used in the summary line. float32 maps to "" because it is
# JAX's default dtype and therefore not worth printing; dtypes missing from this
# table are shown by their full name.
dtnames = {
    "float16": "f16",
    "float32": "",  # Default dtype in jax
    "float64": "f64",
    "bfloat16": "bf16",
    "uint8": "u8",
    "uint16": "u16",
    "uint32": "u32",
    "uint64": "u64",
    "int8": "i8",
    "int16": "i16",
    "int32": "i32",
    "int64": "i64",
}

def short_dtype(x: jax.Array) -> str:
    "Return the compact dtype label of `x` from `dtnames`, or its full dtype name if unlisted."
    name = x.dtype.name
    if name in dtnames:
        return dtnames[name]
    return str(x.dtype)

In [None]:
# |hide
# bfloat16 must get its compact label.
test_eq(short_dtype(jnp.array(1., dtype=jnp.bfloat16)), "bf16")

In [None]:
# | exporti
def plain_repr(x: jax.Array):
    """Return the ordinary (non-lovely) repr of `x`.

    Objects exposing a `_plain_repr()` method are asked for it (presumably these are
    objects whose `__repr__` has been replaced by the lovely summary); anything else
    falls back to the builtin `repr`. Sub-classes are accepted, so no type check is done.
    """
    if not hasattr(x, "_plain_repr"):
        return repr(x)
    return x._plain_repr()

# def plain_str(x: torch.Tensor):
#     "Pick the right function to get a plain str."
#     # assert isinstance(x, np.ndarray), f"expected np.ndarray but got {type(x)}"
#     return x._plain_str() if hasattr(type(x), "_plain_str") else str(x)

In [None]:
# | exporti
def is_nasty(x: jax.Array) -> bool:
    """Return True if any value of `x` is inf or nan.

    Only two reductions (min and max) are computed instead of an element-wise mask:
    a NaN propagates into the min/max, and an infinity of either sign ends up as
    the min (-inf) or the max (+inf). Empty arrays are never nasty.
    """
    if x.size == 0: return False # min/max don't accept zero-length arrays

    x_min = x.min()
    x_max = x.max()

    # Checking finiteness of both ends covers NaN and +/-inf. `bool(...)` hands callers
    # a plain Python bool rather than a 0-d JAX array.
    return not bool(jnp.isfinite(x_min) & jnp.isfinite(x_max))

In [None]:
#| hide
# NaN and inf are flagged, finite values are not, and empty arrays count as clean.
test_eq(is_nasty(jnp.array([1, 2, float("nan")])), True)
test_eq(is_nasty(jnp.array([1, 2, float("inf")])), True)
test_eq(is_nasty(jnp.array([1, 2, 3])), False)
test_eq(is_nasty(jnp.array([])), False)

In [None]:
# |exporti
def _format_spec(spec) -> str:
    "Render a PartitionSpec compactly: P('x', None) -> 'x,·', P(None, ('x', 'y')) -> '·,(x,y)'."
    parts = []
    for entry in spec:
        if entry is None:
            parts.append("·")  # Unsharded (replicated) array dimension
        elif isinstance(entry, (tuple, list)):
            # One array dimension sharded over several mesh axes.
            parts.append("(" + ",".join(str(axis) for axis in entry) + ")")
        else:
            parts.append(str(entry))
    return ",".join(parts)


def _format_device_ids(dev_ids) -> str:
    "Render sorted device ids as 'first-last' for a contiguous run of more than 2 ids, else comma-separated."
    if len(dev_ids) > 2 and dev_ids == list(range(dev_ids[0], dev_ids[-1] + 1)):
        return f"{dev_ids[0]}-{dev_ids[-1]}"
    return ",".join(map(str, dev_ids))


def format_sharding(sharding) -> str:
    """Format sharding information in a compact, informative way.

    Output shapes:
      - a single device:                   "cpu:0"
      - NamedSharding over a 1-D mesh:     "S[x,·] cpu:0-7"
      - NamedSharding over a 2-D mesh:     "S[y,x] 4×2 cpu:0-7"
      - any other sharding type:           "<TypeName> cpu:0-7"
    Device ids are shown as a range only when they are contiguous; otherwise they are
    listed (e.g. "cpu:0,2,5").
    """
    from jax.sharding import SingleDeviceSharding, NamedSharding

    devices = sorted(sharding.device_set, key=lambda d: d.id)
    platform = devices[0].platform

    if len(devices) == 1:
        return f"{platform}:{devices[0].id}"

    dev_range = _format_device_ids([d.id for d in devices])

    # Add sharding type info
    if isinstance(sharding, SingleDeviceSharding):
        shard_info = ""
    elif isinstance(sharding, NamedSharding):
        # Build the spec text from its entries instead of parsing str(spec), whose
        # format varies and which breaks on nested tuple entries.
        spec_str = _format_spec(sharding.spec)

        # Add mesh shape for multi-dimensional meshes.
        # mesh.shape is an ordered mapping of axis name -> size, in mesh axis order.
        mesh_shape = sharding.mesh.shape
        if len(mesh_shape) > 1:
            mesh_str = "×".join(str(size) for size in mesh_shape.values())
            shard_info = f"S[{spec_str}] {mesh_str} "
        else:
            shard_info = f"S[{spec_str}] "
    else:
        shard_info = f"{type(sharding).__name__} "

    return f"{shard_info}{platform}:{dev_range}"


In [None]:
# |export
def jax_to_str_common(x: jax.Array,  # Input
                        color=True,                     # ANSI color highlighting
                        ddof=0):                        # For "std" unbiasing
    """Summarize the values of `x` using JAX ops.

    Returns an "empty" tag for zero-sized arrays, an "all_zeros" tag for all-zero
    arrays with more than one element, and otherwise the range (more than 2 elements)
    and mean/std (2 or more elements). Scalars get no statistics.
    """
    if not x.size:
        return ansi_color("empty", "grey", color)

    is_all_zeros = x.size > 1 and bool(jnp.equal(x, 0.).all())
    zeros = ansi_color("all_zeros", "grey", color) if is_all_zeros else None

    if zeros or x.ndim == 0:
        summary = None
    else:
        minmax = None
        meanstd = None
        if x.size > 2:
            minmax = f"x∈[{pretty_str(x.min())}, {pretty_str(x.max())}]"
        if x.size >= 2:
            meanstd = f"μ={pretty_str(x.mean())} σ={pretty_str(x.std(ddof=ddof))}"
        summary = sparse_join([minmax, meanstd])

    return sparse_join([summary, zeros])

In [None]:
# |exporti

def _device_str(x) -> str:
    "Describe where `x` lives: its sharding summary, or `platform:ids` for older array types."
    # Check for sharding first, as sharded arrays also have .devices()
    if hasattr(x, "sharding"):
        return format_sharding(x.sharding)
    if hasattr(x, "devices"): # Unified Array (jax >= 0.4)
        devices = list(x.devices())
        ids = ",".join(map(str, sorted(d.id for d in devices)))
        return f"{devices[0].platform}:{ids}"
    if hasattr(x, "device"): # Old-style DeviceArray
        return f"{x.device().platform}:{x.device().id}"
    # An explicit raise (rather than `assert`) so the check survives `python -O`.
    raise TypeError(f"Weird input type={type(x)}, expected Array, DeviceArray, or ShardedDeviceArray")


def to_str(x: jax.Array,  # Input
            plain: bool=False,
            verbose: bool=False,
            depth=0,
            lvl=0,
            color=None) -> str:
    """Build the lovely summary string of `x`.

    plain:   return only the ordinary repr of `x`.
    verbose: append the ordinary repr on a new line after the summary.
    depth:   also summarize `x[i]` sub-arrays, recursively, this many levels deep
             (at most `deeper_width` rows per level, from the config).
    lvl:     current nesting level; sets the indentation of the deeper lines.
    color:   force ANSI colors on/off; None uses the config. Always off in a debugger.
    """
    if plain:
        return plain_repr(x)

    conf = get_config()

    tname = type(x).__name__.split(".")[-1]
    # JAX's concrete array class is `ArrayImpl`; show it under its public name.
    if tname == "ArrayImpl": tname = "Array"
    shape = str(list(x.shape)) if x.ndim else None
    type_str = sparse_join([tname, shape], sep="")

    dev = _device_str(x)
    dtype = short_dtype(x)

    # For complex arrays, just show the plain repr for now.
    if not jnp.iscomplexobj(x):
        if color is None: color = conf.color
        if in_debugger(): color = False
        # Stats are computed with `lovely-numpy` when the data is on CPU (numpy is
        # faster there) or contains inf/nan (presumably to avoid extra device-side
        # allocations for masking bad values); otherwise with JAX ops on the array.
        #
        # Temporarily set the numpy config to match our config for consistency.
        with lnp_config(precision=conf.precision,
                        threshold_min=conf.threshold_min,
                        threshold_max=conf.threshold_max,
                        sci_mode=conf.sci_mode):

            if is_cpu(x) or is_nasty(x):
                common = np_to_str_common(np.array(x), color=color)
            else:
                common = jax_to_str_common(x, color=color)

            # Element count for multi-dim arrays, and memory size above the configured threshold.
            numel = None
            if x.shape and max(x.shape) != x.size:
                numel = f"n={x.size}"
                if conf.show_mem_above <= x.nbytes:
                    numel = sparse_join([numel, f"({bytes_to_human(x.nbytes)})"])
            elif conf.show_mem_above <= x.nbytes:
                numel = bytes_to_human(x.nbytes)

            vals = pretty_str(x) if 0 < x.size <= 10 else None
            res = sparse_join([type_str, dtype, numel, common, dev, vals])
    else:
        res = plain_repr(x)

    if verbose:
        res += "\n" + plain_repr(x)

    if depth and x.ndim > 1:
        # Sub-array lines never show memory size.
        with config(show_mem_above=jnp.inf):
            deep_width = min(x.shape[0], conf.deeper_width) # Print at most this many lines
            indent = " " * conf.indent * (lvl + 1)
            deep_lines = [indent + to_str(x[i, :], depth=depth - 1, lvl=lvl + 1, color=color)
                          for i in range(deep_width)]

            # If we were limited by width, print ...
            if deep_width < x.shape[0]: deep_lines.append(indent + "...")

            res += "\n" + "\n".join(deep_lines)

    return res

In [None]:
# |exporti
def history_warning():
    """Issue a warning if we are running in IPython with its output cache enabled.

    `StrProxy` keeps a reference to its array, so results kept in IPython's output
    history can keep arrays alive. Python's default warning filters typically show
    this warning only once per call site.
    """
    import builtins

    # IPython exposes `get_ipython` in the namespace user code runs in and, normally,
    # as a builtin. Looking only at this module's globals() never finds it once this
    # code lives in an exported library module, so fall back to builtins.
    get_ip = globals().get("get_ipython") or getattr(builtins, "get_ipython", None)
    shell = get_ip() if get_ip is not None else None
    if shell is not None and getattr(shell, "cache_size", 0) > 0:
        warnings.warn("IPYthon has its output cache enabled. See https://xl0.github.io/lovely-tensors/history.html")

In [None]:
# |hide
# Enable IPython's output cache and check that the warning fires.
get_ipython().cache_size=1000
history_warning()



In [None]:
# |hide
# Disable the output cache again, so the StrProxy objects created below don't keep warning.
get_ipython().cache_size=0

In [None]:
#| exporti

class StrProxy():
    """Lazy lovely view of an array: the summary string is built by `to_str` on each repr().

    Holds a reference to `x` (so the array stays alive as long as the proxy does)
    together with the formatting options to pass to `to_str`.
    """
    def __init__(self, x: jax.Array, plain=False, verbose=False, depth=0, lvl=0, color=None):
        self.x = x
        self.plain = plain
        self.verbose = verbose
        self.depth = depth
        self.lvl = lvl
        self.color = color
        history_warning()

    def __repr__(self):
        return to_str(self.x, plain=self.plain, verbose=self.verbose,
                      depth=self.depth, lvl=self.lvl, color=self.color)

    # This is used for the .deeper attribute and for .deeper(depth=...);
    # the second one results in a __call__.
    def __call__(self, depth=1):
        """Return a proxy of the same array at the given `depth`.

        The `color` and `verbose` choices of this proxy are kept. `plain` is not,
        since a plain repr would ignore `depth` entirely.
        """
        return StrProxy(self.x, verbose=self.verbose, depth=depth, color=self.color)

In [None]:
# |export
def lovely(x: jax.Array, # Array of interest
            verbose=False,  # Also show the full plain repr
            plain=False,    # Just print it exactly as before
            depth=0,        # Show stats in depth
            color=None):    # Force color (True/False) or auto.
    "Return a lazy lovely summary of `x`; the text is produced when the result is repr'd."
    proxy = StrProxy(x, plain=plain, verbose=verbose, depth=depth, color=color)
    return proxy

### Examples

In [None]:
# Up to 10 elements: the values are shown after the summary.
print(lovely(randoms[0]))
print(lovely(randoms[:2]))
print(lovely(randoms[:6].reshape((2, 3)))) # More than 2 elements -> also show the range
print(lovely(randoms[:11]))           # More than 10 -> suppress data output


Array cpu:0 1.623
Array[2] μ=1.824 σ=0.201 cpu:0 [1.623, 2.025]
Array[2, 3] n=6 x∈[-0.972, 2.025] μ=0.390 σ=1.080 cpu:0 [[1.623, 2.025, -0.434], [-0.079, 0.176, -0.972]]
Array[11] x∈[-0.972, 2.180] μ=0.385 σ=1.081 cpu:0


In [None]:
# |hide
# Pin the summaries printed above (values depend on the fixed PRNG key).
test_array_repr(str(lovely(randoms[0])),                "Array cpu:0 1.623")
test_array_repr(str(lovely(randoms[:2])),               "Array[2] μ=1.824 σ=0.201 cpu:0 [1.623, 2.025]")
test_array_repr(str(lovely(randoms[:6].reshape(2, 3))), "Array[2, 3] n=6 x∈[-0.972, 2.025] μ=0.390 σ=1.080 cpu:0 [[1.623, 2.025, -0.434], [-0.079, 0.176, -0.972]]")
test_array_repr(str(lovely(randoms[:11])),              "Array[11] x∈[-0.972, 2.180] μ=0.385 σ=1.081 cpu:0")

In [None]:
# float16 scalars get the "f16" dtype tag.
# NOTE(review): the name `grad` looks like a leftover from the PyTorch version; no
# gradient information is shown for JAX arrays.
grad = jnp.array(1., dtype=jnp.float16)
print(lovely(grad)); print(lovely(grad+1))

Array f16 cpu:0 1.000
Array f16 cpu:0 2.000


In [None]:
# |hide
# test_eq(str(lovely(grad)), "tensor f64 grad 1.000")
# test_eq(str(lovely(grad+1)), "tensor f64 grad AddBackward0 2.000")

In [None]:
# if torch.cuda.is_available():
#     print(lovely(torch.tensor(1., device=torch.device("cuda:0"))))
#     test_eq(str(lovely(torch.tensor(1., device=torch.device("cuda:0")))), "tensor cuda:0 1.000")

Do we have __any__ floating-point nasties (inf / NaN)? Is the array __all__ zeros?

In [None]:
# Statistics and range are calculated on good values only, if there are at least 3 of them.
lovely(spicy)

Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 [31m+Inf![0m [31m-Inf![0m [31mNaN![0m cpu:0

In [None]:
# |hide
# Same summary, including the ANSI color codes.
# NOTE(review): expects "gpu:0" while this notebook runs on CPU; presumably
# test_array_repr ignores the device part — confirm.
test_array_repr(str(lovely(spicy)),
    'Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 \x1b[31m+Inf!\x1b[0m \x1b[31m-Inf!\x1b[0m \x1b[31mNaN!\x1b[0m gpu:0')

In [None]:
# color=False drops the ANSI escape codes.
lovely(spicy, color=False)

Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 +Inf! -Inf! NaN! cpu:0

In [None]:
# All values NaN: no statistics, only the NaN flag.
str(lovely(jnp.array([float("nan")]*11)))

'Array[11] \x1b[31mNaN!\x1b[0m cpu:0'

In [None]:
# |hide
# NOTE(review): expects "gpu:0" while this notebook runs on CPU; presumably
# test_array_repr ignores the device part — confirm.
test_array_repr(str(lovely(jnp.array([float("nan")]*11))),
        'Array[11] \x1b[31mNaN!\x1b[0m gpu:0')

In [None]:
# All-zero arrays get an "all_zeros" tag instead of statistics.
lovely(jnp.zeros(12))

Array[12] [38;2;127;127;127mall_zeros[0m cpu:0

In [None]:
# |hide
# NOTE(review): expects "gpu:0" while this notebook runs on CPU; presumably
# test_array_repr ignores the device part — confirm.
test_array_repr(str(lovely(jnp.zeros(12))),
        'Array[12] \x1b[38;2;127;127;127mall_zeros\x1b[0m gpu:0')

In [None]:
# Zero-sized arrays (any shape with a 0) get an "empty" tag.
lovely(jnp.array([], dtype=jnp.float16).reshape((0,0,0)))

Array[0, 0, 0] f16 [38;2;127;127;127mempty[0m cpu:0

In [None]:
# |hide
# NOTE(review): expects "gpu:0" while this notebook runs on CPU; presumably
# test_array_repr ignores the device part — confirm.
test_array_repr(str(lovely(jnp.array([], dtype=jnp.float16).reshape((0,0,0)))),
        'Array[0, 0, 0] f16 \x1b[38;2;127;127;127mempty\x1b[0m gpu:0')

In [None]:
# Integer arrays get a dtype tag ("i32") and still show statistics.
lovely(jnp.array([1,2,3], dtype=jnp.int32))

Array[3] i32 x∈[1, 3] μ=2.000 σ=0.816 cpu:0 [1, 2, 3]

In [None]:
# |hide
# NOTE(review): expects "gpu:0" while this notebook runs on CPU; presumably
# test_array_repr ignores the device part — confirm.
test_array_repr(str(lovely(jnp.array([1,2,3], dtype=jnp.int32))),
        'Array[3] i32 x∈[1, 3] μ=2.000 σ=0.816 gpu:0 [1, 2, 3]')

In [None]:
# verbose=True appends the plain repr, formatted with the print options set here.
jnp.set_printoptions(linewidth=120, precision=2)
lovely(spicy, verbose=True)

Array[2, 6] n=12 x∈[-1.955, 1.623e+04] μ=1.803e+03 σ=5.099e+03 [31m+Inf![0m [31m-Inf![0m [31mNaN![0m cpu:0
Array([[ 1.62e+04,  2.03e-04, -4.34e-01,       inf,      -inf,       nan],
       [-4.95e-01,  4.94e-01,  6.64e-01, -9.50e-01,  2.18e+00, -1.96e+00]], dtype=float32)

In [None]:
# plain=True returns the array's ordinary repr instead of the summary.
lovely(spicy, plain=True)

Array([[ 1.6226422e+04,  2.0252647e-04, -4.3359444e-01,            inf,
                  -inf,            nan],
       [-4.9529874e-01,  4.9437860e-01,  6.6434932e-01, -9.5016348e-01,
         2.1795304e+00, -1.9551506e+00]], dtype=float32)

In [None]:
image = jnp.load("mysteryman.npy")
# Inject one NaN so the NaN flag shows up at every level of the deep summary.
image = image.at[1,2,3].set(float('nan'))

lovely(image, depth=2) # Rows per level limited by set_config(deeper_width=N)

Array[3, 196, 196] n=115248 (0.4Mb) x∈[-2.118, 2.640] μ=-0.388 σ=1.073 [31mNaN![0m cpu:0
  Array[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 cpu:0
    Array[196] x∈[-1.912, 2.249] μ=-0.673 σ=0.521 cpu:0
    Array[196] x∈[-1.861, 2.163] μ=-0.738 σ=0.417 cpu:0
    Array[196] x∈[-1.758, 2.198] μ=-0.806 σ=0.396 cpu:0
    Array[196] x∈[-1.656, 2.249] μ=-0.849 σ=0.368 cpu:0
    Array[196] x∈[-1.673, 2.198] μ=-0.857 σ=0.356 cpu:0
    Array[196] x∈[-1.656, 2.146] μ=-0.848 σ=0.371 cpu:0
    Array[196] x∈[-1.433, 2.215] μ=-0.784 σ=0.396 cpu:0
    Array[196] x∈[-1.279, 2.249] μ=-0.695 σ=0.485 cpu:0
    Array[196] x∈[-1.364, 2.249] μ=-0.637 σ=0.538 cpu:0
    ...
  Array[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 [31mNaN![0m cpu:0
    Array[196] x∈[-1.861, 2.411] μ=-0.529 σ=0.555 cpu:0
    Array[196] x∈[-1.826, 2.359] μ=-0.562 σ=0.472 cpu:0
    Array[196] x∈[-1.756, 2.376] μ=-0.622 σ=0.458 [31mNaN![0m cpu:0
    Array[196] x∈[-1.633, 2.429] μ=-0.664 σ=0.429 cpu:0
    Array[1

In [None]:
# |hide
#### CUDA memory is not leaked

In [None]:
# |hide
# |eval: false
# def memstats():
#     allocated = int(torch.cuda.memory_allocated() // (1024*1024))
#     max_allocated = int(torch.cuda.max_memory_allocated() // (1024*1024))
#     return f"Allocated: {allocated} MB, Max: {max_allocated} Mb"

# if torch.cuda.is_available():
#     cudamem = torch.cuda.memory_allocated()
#     print(f"before allocation: {memstats()}")
#     numbers = torch.randn((3, 1024, 1024), device="cuda") # 12Mb image
#     torch.cuda.synchronize()

#     print(f"after allocation: {memstats()}")
#     # Note, the return value of lovely() is not a string, but a
#     # StrProxy that holds reference to 'numbers'. You have to del
#     # the references to it, but once it's gone, the reference to
#     # the tensor is gone too.
#     display(lovely(numbers) )
#     print(f"after repr: {memstats()}")

#     del numbers
#     # torch.cuda.memory.empty_cache()

#     print(f"after cleanup: {memstats()}")
#     test_eq(cudamem >= torch.cuda.memory_allocated(), True)

In [None]:
# We don't really support complex numbers yet: complex arrays fall back to the plain repr.
c = jnp.array([-0.4011-0.4035j,  1.1300+0.0788j, -0.0277+0.9978j, -0.4636+0.6064j, -1.1505-0.9865j])
lovely(c)

Array([-0.4011-0.4035j,  1.13  +0.0788j, -0.0277+0.9978j, -0.4636+0.6064j,
       -1.1505-0.9865j], dtype=complex64)

In [None]:
#| eval: false
# NOTE(review): aborts on a JAX 1.x release, presumably because the sharding APIs
# used below may change; confirm the intent.
assert jax.__version_info__[0] == 0
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
from jax.experimental import mesh_utils

# Needs 8 devices, provided by the XLA_FLAGS host-device setting at the top of the notebook.
print("=== Test 1: NamedSharding with 2D mesh (4,2) and P('y', 'x') ===")
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('y', 'x'))  # y has 4 devices, x has 2
sharding = NamedSharding(mesh, P('y', 'x'))  # Shard array dim 0 across y, dim 1 across x

x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
y = jax.device_put(x, sharding)

jax.debug.visualize_array_sharding(y)
print(lovely(y))

# NOTE(review): P('y', None) shards array dim 0 across 'y' and leaves dim 1 unsharded
# (each shard is replicated over 'x'); the label printed below reads the other way round.
print("\n=== Test 2: NamedSharding with P('y', None) - replicate first dim ===")
sharding2 = NamedSharding(mesh, P('y', None))
y2 = jax.device_put(x, sharding2)
jax.debug.visualize_array_sharding(y2)
print(lovely(y2))

# NOTE(review): P(None, 'x') leaves array dim 0 unsharded and shards dim 1 across 'x';
# the label printed below reads the other way round.
print("\n=== Test 3: NamedSharding with P(None, 'x') - replicate second dim ===")
sharding3 = NamedSharding(mesh, P(None, 'x'))
y3 = jax.device_put(x, sharding3)
jax.debug.visualize_array_sharding(y3)
print(lovely(y3))

# A 1-D mesh over all 8 devices: P('x', None) shards array dim 0 eight ways.
print("\n=== Test 4: 1D mesh with 8 devices ===")
devices_1d = mesh_utils.create_device_mesh((8,))
mesh_1d = Mesh(devices_1d, axis_names=('x',))
sharding_1d = NamedSharding(mesh_1d, P('x', None))
y4 = jax.device_put(x, sharding_1d)
jax.debug.visualize_array_sharding(y4)
print(lovely(y4))


=== Test 1: NamedSharding with 2D mesh (4,2) and P('y', 'x') ===


Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[y,x] 4×2 cpu:0-7

=== Test 2: NamedSharding with P('y', None) - replicate first dim ===


Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[y,·] 4×2 cpu:0-7

=== Test 3: NamedSharding with P(None, 'x') - replicate second dim ===


Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[·,x] 4×2 cpu:0-7

=== Test 4: 1D mesh with 8 devices ===


Array[8192, 8192] n=67108864 (0.2Gb) x∈[-5.420, 5.220] μ=1.508e-05 σ=1.000 S[x,·] cpu:0-7
