# Monkey-patching


In [1]:
#| default_exp patch

In [2]:
# |hide
from nbdev.showdoc import *

In [3]:
#| hide
#| export

import numpy as np
import jax
import jax.numpy as jnp
from jax._src import array
from fastcore.foundation import patch_to
import matplotlib.pyplot as plt

from lovely_jax.repr_str import StrProxy
# from lovely_tensors.repr_rgb import RGBProxy
# from lovely_tensors.repr_plt import PlotProxy
# from lovely_tensors.repr_chans import ChanProxy

In [4]:
# |export
def _monkey_patch(cls):
    "Install the lovely `__repr__`/`__str__` and the `.p`, `.v`, `.deeper` properties on `cls`"

    # Remember the untouched repr/str exactly once: a second call (e.g. re-running
    # the cell) must not overwrite them with the already-patched versions.
    if not hasattr(cls, '_plain_repr'):
        cls._plain_repr, cls._plain_str = cls.__repr__, cls.__str__

    # print() goes through __str__ (normally the terser form), while Jupyter/VSCode
    # inspection goes through __repr__ - both get the same lovely summary.
    def __repr__(self: jax.Array):
        return str(StrProxy(self))

    def __str__(self: jax.Array):
        return str(StrProxy(self))

    # `.p` - plain: the original (pre-patch) formatting.
    def p(self: jax.Array):
        return StrProxy(self, plain=True)

    # `.v` - verbose: the summary plus the plain values.
    def v(self: jax.Array):
        return StrProxy(self, verbose=True)

    # `.deeper` - also summarise the sub-arrays one level down.
    def deeper(self: jax.Array):
        return StrProxy(self, depth=1)

    # patch_to installs each function under its own __name__, so the local
    # function names above are the attribute names on `cls`.
    for method in (__repr__, __str__):
        patch_to(cls)(method)
    for prop in (p, v, deeper):
        patch_to(cls, as_prop=True)(prop)

    # TODO: port the `.rgb` (RGBProxy), `.chans` (ChanProxy) and `.plt` (PlotProxy)
    # properties from lovely_tensors; they are not available for jax arrays yet.

def monkey_patch():
    "Patch the lovely repr/str and the `.p`/`.v`/`.deeper` properties into jax's array classes"
    _monkey_patch(array.ArrayImpl)
    # `DeviceArray` may not be exposed by `jax._src.array` in every jax release,
    # so patch it only when it is present and is not simply `ArrayImpl` again
    # (which was already patched above).
    device_array = getattr(array, "DeviceArray", None)
    if device_array is not None and device_array is not array.ArrayImpl:
        _monkey_patch(device_array)

In [5]:
# Apply the patch so every jax array displayed below uses the lovely summary.
monkey_patch()

In [6]:
# Example image used throughout the demos (shape [3, 196, 196] per the `.deeper` output below).
# NOTE(review): relative path - assumes `mysteryman.npy` sits next to this notebook.
image = jnp.load("mysteryman.npy")

In [7]:
# Take 12 values from `image[0, 0]` and plant extreme and non-finite values in
# them, so the summary has something to flag. jax arrays are immutable, hence
# the functional `.at[...]` updates.
spicy = (
    image[0, 0, :12].copy()
    .at[0].mul(10000)              # very large magnitude
    .at[1].divide(10000)           # very small magnitude
    .at[2].set(float('inf'))
    .at[3].set(float('-inf'))
    .at[4].set(float('nan'))
    .reshape((2, 6))
)

spicy

Array[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.180e+03 [31m+Inf![0m [31m-Inf![0m [31mNaN![0m gpu:0

In [8]:
# `.v` (verbose): the lovely summary line followed by the plain jax repr.
spicy.v # Verbose

Array[2, 6] n=12 x∈[-3.541e+03, -3.369e-05] μ=-393.776 σ=1.180e+03 [31m+Inf![0m [31m-Inf![0m [31mNaN![0m gpu:0
Array([[-3.5405432e+03, -3.3692959e-05,            inf,           -inf,
                   nan, -4.0542859e-01],
       [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01, -5.5955136e-01,
        -5.4242659e-01, -5.0817710e-01]], dtype=float32)

In [9]:
# `.p` (plain): the original jax repr, i.e. the behaviour before patching.
spicy.p # Plain

Array([[-3.5405432e+03, -3.3692959e-05,            inf,           -inf,
                   nan, -4.0542859e-01],
       [-4.2255333e-01, -4.9105233e-01, -5.0817710e-01, -5.5955136e-01,
        -5.4242659e-01, -5.0817710e-01]], dtype=float32)

In [10]:
# `.deeper`: the summary for the whole array, plus one indented summary per
# sub-array one level down (along the first axis).
image.deeper

Array[3, 196, 196] n=115248 x∈[-2.118, 2.640] μ=-0.388 σ=1.073 gpu:0
  Array[196, 196] n=38416 x∈[-2.118, 2.249] μ=-0.324 σ=1.036 gpu:0
  Array[196, 196] n=38416 x∈[-1.966, 2.429] μ=-0.274 σ=0.973 gpu:0
  Array[196, 196] n=38416 x∈[-1.804, 2.640] μ=-0.567 σ=1.178 gpu:0

In [11]:
# `.deeper` can also be called; the argument presumably sets how many levels to
# descend (StrProxy's call signature lives in lovely_jax.repr_str - confirm).
# A small crop keeps the nested output short.
dt = image[:3,:3,:5]
dt.deeper(3)

Array[3, 3, 5] n=45 x∈[-1.316, -0.197] μ=-0.593 σ=0.302 gpu:0
  Array[3, 5] n=15 x∈[-0.765, -0.337] μ=-0.492 σ=0.119 gpu:0
    Array[5] x∈[-0.440, -0.337] μ=-0.385 σ=0.037 gpu:0 [-0.354, -0.337, -0.405, -0.440, -0.388]
    Array[5] x∈[-0.662, -0.405] μ=-0.512 σ=0.097 gpu:0 [-0.405, -0.423, -0.491, -0.577, -0.662]
    Array[5] x∈[-0.765, -0.474] μ=-0.580 σ=0.112 gpu:0 [-0.474, -0.474, -0.542, -0.645, -0.765]
  Array[3, 5] n=15 x∈[-0.513, -0.197] μ=-0.321 σ=0.096 gpu:0
    Array[5] x∈[-0.303, -0.197] μ=-0.243 σ=0.049 gpu:0 [-0.197, -0.197, -0.303, -0.303, -0.215]
    Array[5] x∈[-0.408, -0.232] μ=-0.327 σ=0.075 gpu:0 [-0.250, -0.232, -0.338, -0.408, -0.408]
    Array[5] x∈[-0.513, -0.285] μ=-0.394 σ=0.091 gpu:0 [-0.303, -0.285, -0.390, -0.478, -0.513]
  Array[3, 5] n=15 x∈[-1.316, -0.672] μ=-0.964 σ=0.170 gpu:0
    Array[5] x∈[-0.985, -0.672] μ=-0.846 σ=0.110 gpu:0 [-0.672, -0.985, -0.881, -0.776, -0.916]
    Array[5] x∈[-1.212, -0.724] μ=-0.989 σ=0.160 gpu:0 [-0.724, -1.072, -0.968, -0.

In [12]:
# |hide
# image.rgb

In [13]:
# |hide
# in_stats = ( (0.485, 0.456, 0.406),     # mean 
#              (0.229, 0.224, 0.225) )    # std
# image.rgb(in_stats)

In [14]:
# |hide
# mean = torch.tensor(in_stats[0])[:,None,None]
# std = torch.tensor(in_stats[1])[:,None,None]

# (image*std + mean).chans # all pixels in [0, 1] range

In [15]:
# |hide
# (image*0.3+0.5) # Slightly outside of [0, 1] range

In [16]:
# |hide
# (image*0.3+0.5).chans # shows clipping (bright blue/red)

In [17]:
# |hide
# image.plt

In [18]:
# |hide
# image.plt(center="mean")

In [19]:
# |hide
# fig, ax = plt.subplots(figsize=(6, 2))
# plt.close(fig)
# image.plt(ax=ax)
# fig

In [20]:
# |hide
# Regenerate the library module from the `#| export` cells (target set by `#| default_exp patch`).
import nbdev; nbdev.nbdev_export()