# 💘 Lovely JAX

> After all, you are only human.

::: {.content-visible when-format="markdown"}
## [Read full docs](https://xl0.github.io/lovely-jax) | ❤️ [Lovely Tensors](https://github.com/xl0/lovely-tensors) | 💟 [Lovely `NumPy`](https://github.com/xl0/lovely-numpy) | [Discord](https://discord.gg/4NxRV7NH)
:::

::: {.content-visible when-format="html"}
<h2><a href="https://github.com/xl0/lovely-jax">Source code</a> | ❤️ <a href="https://xl0.github.io/lovely-tensors"> Lovely Tensors</a> | 💟 <a href="https://xl0.github.io/lovely-numpy"> Lovely <tt>NumPy</tt></a> | <a href="https://discord.gg/4NxRV7NH">Discord</a>
</h2>
:::


## Note: I'm pretty new to JAX.
If something does not make sense, shoot me an [Issue](https://github.com/xl0/lovely-jax/issues) and let me know how it's supposed to work!

## Install

```sh
pip install lovely-jax
```

## How to use

In [1]:
# |hide
import numpy as np
from PIL import Image
import jax, jax.numpy as jnp
from fastcore.test import test_eq

In [2]:
# |hide
np.set_printoptions(linewidth=120, precision=3)

In [3]:
# |hide
# Don't depend on torchvision
# numbers = VF.normalize(VF.center_crop(VF.to_tensor(Image.open("tenchman.jpg")), 196), **in_stats)
numbers = jnp.load("mysteryman.npy")
numbers1 = numbers

How often do you find yourself debugging JAX code? You dump an array to the cell output and see this:

In [4]:
# |hide
import lovely_jax as lj

In [5]:
# |hide
# A trick to make sure README.md shows the plain version.
lj.monkey_patch()
numbers = numbers.p


In [6]:
numbers

DeviceArray([[[-0.354, -0.337, -0.405, ..., -0.56 , -0.474,  2.249],
              [-0.405, -0.423, -0.491, ..., -0.919, -0.851,  2.163],
              [-0.474, -0.474, -0.542, ..., -1.039, -1.039,  2.198],
              ...,
              [-0.902, -0.834, -0.936, ..., -1.467, -1.296,  2.232],
              [-0.851, -0.782, -0.936, ..., -1.604, -1.501,  2.18 ],
              [-0.834, -0.816, -0.971, ..., -1.656, -1.553,  2.112]],

             [[-0.197, -0.197, -0.303, ..., -0.478, -0.373,  2.411],
              [-0.25 , -0.232, -0.338, ..., -0.705, -0.67 ,  2.359],
              [-0.303, -0.285, -0.39 , ..., -0.74 , -0.81 ,  2.376],
              ...,
              [-0.425, -0.232, -0.373, ..., -1.09 , -1.02 ,  2.429],
              [-0.39 , -0.232, -0.425, ..., -1.23 , -1.23 ,  2.411],
              [-0.408, -0.285, -0.478, ..., -1.283, -1.283,  2.341]],

             [[-0.672, -0.985, -0.881, ..., -0.968, -0.689,  2.396],
              [-0.724, -1.072, -0.968, ..., -1.247, -1.02 ,  

In [7]:
# | hide
numbers = numbers1

Was it really useful for you, as a human, to see all these numbers?

What is the shape? The size?\
What are the statistics?\
Are any of the values `nan` or `inf`?\
Is it an image of a man holding a tench?
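
Answering these by hand takes one call per question. The sketch below uses NumPy, which also accepts JAX arrays via `np.asarray` (copying them to host memory); the `manual_report` helper is hypothetical and not part of lovely-jax.

```python
import numpy as np

def manual_report(x):
    """Hypothetical helper: answer each question above with a separate call."""
    a = np.asarray(x)  # NumPy arrays pass through; JAX arrays are copied to host
    return {
        "shape": a.shape,
        "size": a.size,
        "min": float(a.min()),
        "max": float(a.max()),
        "mean": float(a.mean()),
        "std": float(a.std()),  # population std (ddof=0)
        "has_nan": bool(np.isnan(a).any()),
        "has_inf": bool(np.isinf(a).any()),
    }

report = manual_report(np.arange(6, dtype=np.float32).reshape(2, 3))
print(report)
```

Eight separate expressions for one array, and none of them tell you what the image looks like.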

In [8]:
import lovely_jax as lj

In [9]:
lj.monkey_patch()
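
`monkey_patch()` changes how arrays print from then on. The general Python technique is to replace a class's `__repr__` at runtime. The toy below shows that on a hypothetical `Box` class; lovely-jax's own patching of JAX array types is not shown here and may work differently.

```python
class Box:
    """Hypothetical container standing in for an array type."""
    def __init__(self, values):
        self.values = list(values)

    def __repr__(self):
        return f"Box({self.values!r})"

def lovely_box_repr(self):
    # Replacement repr: a short summary instead of every value.
    return f"Box[{len(self.values)}] min={min(self.values)} max={max(self.values)}"

def monkey_patch_box():
    """Swap Box.__repr__ for the summary version, keeping the original reachable."""
    if "plain_repr" not in Box.__dict__:  # patch only once
        Box.plain_repr = Box.__repr__     # hypothetical name, similar in spirit to `.p`
        Box.__repr__ = lovely_box_repr

b = Box([3, 1, 2])
print(repr(b))         # Box([3, 1, 2])
monkey_patch_box()
print(repr(b))         # Box[3] min=1 max=3
print(b.plain_repr())  # Box([3, 1, 2])
```

Because the method is replaced on the class, every existing and future instance picks up the new repr.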

## `__repr__`

In [10]:
numbers # a JAX array

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 gpu:0

Better, huh?
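
A minimal sketch of how such a one-line summary can be assembled, assuming plain NumPy statistics and fixed three-decimal formatting. The real lovely repr also reports the device and switches to scientific notation for very large or small values; `summary` is a hypothetical name.

```python
import numpy as np

def summary(x, name="Array"):
    """Hypothetical one-line summary in the spirit of the lovely repr."""
    a = np.asarray(x)
    shape = ", ".join(str(d) for d in a.shape)
    return (f"{name}[{shape}] n={a.size} "
            f"x∈[{a.min():.3f}, {a.max():.3f}] μ={a.mean():.3f} σ={a.std():.3f}")

print(summary(np.array([[1.0, 2.0], [3.0, 4.0]])))
# Array[2, 2] n=4 x∈[1.000, 4.000] μ=2.500 σ=1.118
```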

In [11]:
numbers[1,:6,1] # Still shows values if there are not too many.

DeviceArray[6] x∈[-0.443, -0.197] μ=-0.311 σ=0.083 gpu:0 [-0.197, -0.232, -0.285, -0.373, -0.443, -0.338]
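
Listing the values only for small arrays amounts to a size threshold. The cutoff of 10 below is an assumption for illustration; lovely-jax's actual threshold may differ, and `values_suffix` is a hypothetical helper.

```python
import numpy as np

def values_suffix(x, max_values=10):
    """Return the element list for small arrays, or '' for large ones (hypothetical)."""
    a = np.asarray(x)
    if a.size > max_values:
        return ""
    return "[" + ", ".join(f"{v:.3f}" for v in a.ravel()) + "]"

print(values_suffix(np.array([-0.197, -0.232, -0.285])))  # [-0.197, -0.232, -0.285]
print(repr(values_suffix(np.zeros((3, 196, 196)))))       # ''
```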

In [12]:
spicy = numbers.flatten()[:12].copy()

spicy = (spicy  .at[0].mul(10000)
                .at[1].divide(10000)
                .at[2].set(float('inf'))
                .at[3].set(float('-inf'))
                .at[4].set(float('nan'))
                .reshape((2,6)))
spicy # Spicy stuff

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.180e+03 +Inf! -Inf! NaN! gpu:0
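
In the output above, the range, μ and σ stay finite even though the array holds `inf` and `nan`; the bad values are flagged separately. One way to get that behaviour, sketched with NumPy under the assumption that statistics are taken over finite values only (`finite_stats` is a hypothetical helper):

```python
import numpy as np

def finite_stats(x):
    """Stats over finite entries plus flags for +inf, -inf and nan (hypothetical)."""
    a = np.asarray(x, dtype=np.float64)
    finite = a[np.isfinite(a)]  # boolean mask drops inf, -inf and nan
    flags = []
    if np.isposinf(a).any():
        flags.append("+Inf!")
    if np.isneginf(a).any():
        flags.append("-Inf!")
    if np.isnan(a).any():
        flags.append("NaN!")
    stats = None
    if finite.size:  # an all-bad array has no finite statistics
        stats = (finite.min(), finite.max(), finite.mean(), finite.std())
    return stats, flags

stats, flags = finite_stats([1.0, np.inf, -np.inf, np.nan, 3.0])
print(stats, flags)
```

Without the mask, a single `nan` would turn the min, max, mean and std into `nan` and hide everything else.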

In [13]:
jnp.zeros((10, 10)) # An all-zero array: make it obvious

DeviceArray[10, 10] all_zeros gpu:0
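
When every element is zero, the range and moments carry no information, and the output above shows a single tag instead. A check for that case could look like this (a sketch; `is_all_zeros` is a hypothetical name):

```python
import numpy as np

def is_all_zeros(x):
    # True when no element is nonzero (nan counts as nonzero here).
    return not np.any(np.asarray(x))

print(is_all_zeros(np.zeros((10, 10))))  # True
print(is_all_zeros(np.eye(3)))           # False
```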

In [14]:
spicy.v # Verbose

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.180e+03 +Inf! -Inf! NaN! gpu:0
DeviceArray([[-3.541e+03, -3.369e-05,        inf,       -inf,        nan, -4.054e-01],
             [-4.226e-01, -4.911e-01, -5.082e-01, -5.596e-01, -5.424e-01, -5.082e-01]], dtype=float32)

In [15]:
spicy.p # The plain old way

DeviceArray([[-3.541e+03, -3.369e-05,        inf,       -inf,        nan, -4.054e-01],
             [-4.226e-01, -4.911e-01, -5.082e-01, -5.596e-01, -5.424e-01, -5.082e-01]], dtype=float32)

## Going `.deeper`

In [16]:
numbers.deeper

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 gpu:0
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 gpu:0
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 gpu:0
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178 gpu:0
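
`.deeper` repeats the summary for each slice along the first axis, indented one level per step. A recursive sketch of that idea with NumPy; `deeper` and `summary_line` here are hypothetical stand-ins, not lovely-jax's implementation.

```python
import numpy as np

def summary_line(a):
    # Compact stats for one (sub-)array.
    shape = ", ".join(str(d) for d in a.shape)
    return f"Array[{shape}] n={a.size} x∈[{a.min():.3f}, {a.max():.3f}] μ={a.mean():.3f}"

def deeper(x, depth=1, indent=0):
    """Summarize x, then recurse into slices along the first axis (hypothetical)."""
    a = np.asarray(x)
    lines = ["  " * indent + summary_line(a)]
    if depth > 0 and a.ndim > 0:
        for sub in a:  # iterating an array walks its first axis
            lines.extend(deeper(sub, depth - 1, indent + 1))
    return lines

print("\n".join(deeper(np.arange(12.0).reshape(2, 2, 3), depth=1)))
# Array[2, 2, 3] n=12 x∈[0.000, 11.000] μ=5.500
#   Array[2, 3] n=6 x∈[0.000, 5.000] μ=2.500
#   Array[2, 3] n=6 x∈[6.000, 11.000] μ=8.500
```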

In [17]:
# You can go deeper if you need to
numbers[:,:3,:5].deeper(2)

DeviceArray[3, 3, 5] n=45 x∈[-1.316, -0.197] μ=-0.593 σ=0.302 gpu:0
  DeviceArray[3, 5] n=15 x∈[-0.765, -0.337] μ=-0.492 σ=0.119 gpu:0
    DeviceArray[5] x∈[-0.440, -0.337] μ=-0.385 σ=0.037 gpu:0 [-0.354, -0.337, -0.405, -0.440, -0.388]
    DeviceArray[5] x∈[-0.662, -0.405] μ=-0.512 σ=0.097 gpu:0 [-0.405, -0.423, -0.491, -0.577, -0.662]
    DeviceArray[5] x∈[-0.765, -0.474] μ=-0.580 σ=0.112 gpu:0 [-0.474, -0.474, -0.542, -0.645, -0.765]
  DeviceArray[3, 5] n=15 x∈[-0.513, -0.197] μ=-0.321 σ=0.096 gpu:0
    DeviceArray[5] x∈[-0.303, -0.197] μ=-0.243 σ=0.049 gpu:0 [-0.197, -0.197, -0.303, -0.303, -0.215]
    DeviceArray[5] x∈[-0.408, -0.232] μ=-0.327 σ=0.075 gpu:0 [-0.250, -0.232, -0.338, -0.408, -0.408]
    DeviceArray[5] x∈[-0.513, -0.285] μ=-0.394 σ=0.091 gpu:0 [-0.303, -0.285, -0.390, -0.478, -0.513]
  DeviceArray[3, 5] n=15 x∈[-1.316, -0.672] μ=-0.964 σ=0.170 gpu:0
    DeviceArray[5] x∈[-0.985, -0.672] μ=-0.846 σ=0.110 gpu:0 [-0.672, -0.98

In [19]:
print(repr(numbers))

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 gpu:0


## Without `monkey_patch()`

In [32]:
lj.lovely(spicy)

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.180e+03 +Inf! -Inf! NaN! gpu:0

In [33]:
lj.lovely(spicy, verbose=True)

DeviceArray[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.180e+03 +Inf! -Inf! NaN! gpu:0
DeviceArray([[-3.541e+03, -3.369e-05,        inf,       -inf,        nan, -4.054e-01],
             [-4.226e-01, -4.911e-01, -5.082e-01, -5.596e-01, -5.424e-01, -5.082e-01]], dtype=float32)

In [34]:
lj.lovely(numbers, depth=1)

DeviceArray[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 gpu:0
  DeviceArray[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 gpu:0
  DeviceArray[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 gpu:0
  DeviceArray[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178 gpu:0
