# ❤️ Lovely JAX

> After all, you are only human.

## Read full docs [here](https://xl0.github.io/lovely-jax/)

## Install

```sh
pip install lovely-jax
```

## How to use

In [None]:
import jax.numpy as jnp

In [None]:
# Load the sample array used in the examples below.
numbers = jnp.load("mysteryman.npy")
numbers1 = numbers  # Keep a reference to the original array.

How often do you find yourself debugging JAX code? You dump an array to the cell output, and see this:

In [None]:
numbers 

DeviceArray([[[-0.35405433, -0.33692956, -0.4054286 , ..., -0.55955136,
               -0.4739276 ,  2.2489083 ],
              [-0.4054286 , -0.42255333, -0.49105233, ..., -0.91917115,
               -0.8506721 ,  2.1632845 ],
              [-0.4739276 , -0.4739276 , -0.5424266 , ..., -1.0390445 ,
               -1.0390445 ,  2.1975338 ],
              ...,
              [-0.9020464 , -0.8335474 , -0.9362959 , ..., -1.4671633 ,
               -1.2959158 ,  2.2317834 ],
              [-0.8506721 , -0.78217316, -0.9362959 , ..., -1.6041614 ,
               -1.5014129 ,  2.1804092 ],
              [-0.8335474 , -0.81642264, -0.9705454 , ..., -1.6555357 ,
               -1.5527872 ,  2.11191   ]],

             [[-0.19747896, -0.19747896, -0.30252096, ..., -0.47759098,
               -0.37254897,  2.4110641 ],
              [-0.24999997, -0.23249297, -0.33753496, ..., -0.705182  ,
               -0.670168  ,  2.3585434 ],
              [-0.30252096, -0.28501397, -0.39005598, ..., -0.74019

Was it really useful for you, as a human, to see all these numbers?

What is the shape? The size?\
What are the statistics?\
Are any of the values `nan` or `inf`?\
Is it an image of a man holding a tench?

In [None]:
import lovely_jax as lj

In [None]:
lj.monkey_patch()
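
For readers unfamiliar with the term, monkey patching means replacing a method on an existing class at runtime, here so that arrays print a summary instead of raw values. The sketch below shows the general technique on a toy class. `Box`, `summary_repr`, and `plain_repr` are hypothetical names for illustration, and this is not how lovely-jax itself is implemented.

```python
class Box:
    """Toy stand-in for an array type (hypothetical, for illustration only)."""

    def __init__(self, values):
        self.values = list(values)

    def __repr__(self):
        return f"Box({self.values})"


def summary_repr(self):
    # Replacement __repr__: show a short summary instead of every value.
    n = len(self.values)
    if n == 0:
        return "Box[0] empty"
    return f"Box[{n}] x∈[{min(self.values)}, {max(self.values)}]"


def monkey_patch(cls=Box):
    # Keep the original __repr__ reachable so a plain view stays available.
    if not hasattr(cls, "plain_repr"):
        cls.plain_repr = cls.__repr__
    cls.__repr__ = summary_repr


b = Box(range(1000))
monkey_patch()
print(repr(b))              # Summary form
print(b.plain_repr()[:24])  # Start of the original, value-by-value form
```

Keeping the original method around is what makes a "plain" view possible after patching.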

## `__repr__`

In [None]:
numbers # DeviceArray

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073

Better, huh?

In [None]:
numbers[1,:6,1] # Still shows values if there are not too many.

DeviceArray[6] x∈[-0.443, -0.197] μ=-0.311 σ=0.083 [-0.197, -0.232, -0.285, -0.373, -0.443, -0.338]
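
The summary line packs the shape, element count, value range, mean, and standard deviation into one row, and lists the values when there are only a few. A minimal sketch of computing the same quantities by hand with plain `jax.numpy` follows. `summarize` and its `max_values` threshold are hypothetical names chosen for illustration, not the lovely-jax implementation, and the output format is only an approximation.

```python
import jax.numpy as jnp


def summarize(x, max_values=10):
    """Hypothetical helper: one-line summary of an array."""
    parts = [f"shape={list(x.shape)}", f"n={x.size}"]
    if x.size:
        parts.append(f"x∈[{float(x.min()):.3f}, {float(x.max()):.3f}]")
        parts.append(f"μ={float(x.mean()):.3f}")
        parts.append(f"σ={float(x.std()):.3f}")
    # Only list the values themselves when the array is small.
    if 0 < x.size <= max_values:
        values = ", ".join(f"{float(v):.3f}" for v in x.ravel())
        parts.append(f"[{values}]")
    return " ".join(parts)


# Rounded values copied from the `numbers[1,:6,1]` output above.
sample = jnp.array([-0.197, -0.232, -0.285, -0.373, -0.443, -0.338])
print(summarize(sample))
```

Because the sample uses the rounded values from the output, the last digit of the statistics can differ slightly from what the library prints for the full-precision array.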

In [None]:
spicy = numbers.flatten()[:12].copy()

spicy = (spicy  .at[0].multiply(10000)
                .at[1].divide(10000)
                .at[2].set(float('inf'))
                .at[3].set(float('-inf'))
                .at[4].set(float('nan'))
                .reshape((2,6)))
spicy # Spicy stuff

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
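
The range and mean in this summary are consistent with statistics taken over the finite entries only, with `inf` and `nan` flagged separately. The sketch below runs the same kind of check by hand with standard `jax.numpy` functions. The array literal is copied from the plain output of `spicy` in this README. This is an illustration, not the library's code.

```python
import jax.numpy as jnp

# Values copied from the plain `spicy` output in this README.
spicy = jnp.array([
    [-3.5405432e+03, -3.3692959e-05, jnp.inf, -jnp.inf, jnp.nan, -4.0542859e-01],
    [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
     -5.5955136e-01, -5.4242659e-01, -5.0817710e-01],
])

# Flag the non-finite values.
has_nan = bool(jnp.isnan(spicy).any())
has_pos_inf = bool(jnp.isposinf(spicy).any())
has_neg_inf = bool(jnp.isneginf(spicy).any())

# Statistics over the finite entries only.
# Boolean-mask indexing like this works outside of `jax.jit`.
finite = spicy[jnp.isfinite(spicy)]

print(has_nan, has_pos_inf, has_neg_inf)
print(float(finite.min()), float(finite.max()))
print(float(finite.mean()), float(finite.std()))
```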

In [None]:
jnp.zeros((10, 10)) # A zero array - make it obvious

DeviceArray[10, 10] n=100 all_zeros

In [None]:
spicy.v # Verbose

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
DeviceArray([[-3.5405432e+03, -3.3692959e-05,            inf,
                        -inf,            nan, -4.0542859e-01],
             [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
              -5.5955136e-01, -5.4242659e-01, -5.0817710e-01]],            dtype=float32)

In [None]:
spicy.p # The plain old way

DeviceArray([[-3.5405432e+03, -3.3692959e-05,            inf,
                        -inf,            nan, -4.0542859e-01],
             [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
              -5.5955136e-01, -5.4242659e-01, -5.0817710e-01]],            dtype=float32)

## Going `.deeper`

In [None]:
numbers.deeper

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178
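
The nested listing can be pictured as a summary of the whole array followed by a summary of each slice along the leading axis. A rough sketch of that recursion in plain `jax.numpy` follows. `describe` is a hypothetical helper, the random array merely stands in for `numbers`, and the output format is only an approximation of the library's.

```python
import jax
import jax.numpy as jnp


def describe(a, depth=0, indent=0):
    """Hypothetical helper: summarize `a`, then each slice along its leading axis."""
    line = (" " * indent
            + f"{list(a.shape)} n={a.size} "
            + f"x∈[{float(a.min()):.3f}, {float(a.max()):.3f}] "
            + f"μ={float(a.mean()):.3f} σ={float(a.std()):.3f}")
    lines = [line]
    if depth > 0 and a.ndim > 1:
        for sub in a:  # Iterating a JAX array yields slices along axis 0.
            lines.extend(describe(sub, depth - 1, indent + 2))
    return lines


# Stand-in data with the same shape as `numbers` (random values, for illustration).
x = jax.random.normal(jax.random.PRNGKey(0), (3, 196, 196))
print("\n".join(describe(x, depth=1)))
```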

In [None]:
# You can go deeper if you need to
numbers[:,:3,:5].deeper(2)

DeviceArray[3, 3, 5] n=45 x∈[-1.316, -0.197] μ=-0.593 σ=0.302
  DeviceArray[3, 5] n=15 x∈[-0.765, -0.337] μ=-0.492 σ=0.119
    DeviceArray[5] x∈[-0.440, -0.337] μ=-0.385 σ=0.037 [-0.354, -0.337, -0.405, -0.440, -0.388]
    DeviceArray[5] x∈[-0.662, -0.405] μ=-0.512 σ=0.097 [-0.405, -0.423, -0.491, -0.577, -0.662]
    DeviceArray[5] x∈[-0.765, -0.474] μ=-0.580 σ=0.112 [-0.474, -0.474, -0.542, -0.645, -0.765]
  DeviceArray[3, 5] n=15 x∈[-0.513, -0.197] μ=-0.321 σ=0.096
    DeviceArray[5] x∈[-0.303, -0.197] μ=-0.243 σ=0.049 [-0.197, -0.197, -0.303, -0.303, -0.215]
    DeviceArray[5] x∈[-0.408, -0.232] μ=-0.327 σ=0.075 [-0.250, -0.232, -0.338, -0.408, -0.408]
    DeviceArray[5] x∈[-0.513, -0.285] μ=-0.394 σ=0.091 [-0.303, -0.285, -0.390, -0.478, -0.513]
  DeviceArray[3, 5] n=15 x∈[-1.316, -0.672] μ=-0.964 σ=0.170
    DeviceArray[5] x∈[-0.985, -0.672] μ=-0.846 σ=0.110 [-0.672, -0.985, -0.881, -0.776, -0.916]
    DeviceArray[5] x∈[-1.212, -0.724] μ=-0.989 σ=0.160 [-0.724, -1.072, -0.968, -0.

## Without `.monkey_patch`

In [None]:
lj.lovely(spicy)

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!

In [None]:
lj.lovely(spicy, verbose=True)

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.113e+03 +inf! -inf! nan!
DeviceArray([[-3.5405432e+03, -3.3692959e-05,            inf,
                        -inf,            nan, -4.0542859e-01],
             [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01,
              -5.5955136e-01, -5.4242659e-01, -5.0817710e-01]],            dtype=float32)

In [None]:
lj.lovely(numbers, depth=1)

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178
