-
Notifications
You must be signed in to change notification settings - Fork 3
/
patch.py
84 lines (66 loc) · 2.46 KB
/
patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/10_patch.ipynb.
# %% auto 0
__all__ = ['monkey_patch']
# %% ../nbs/10_patch.ipynb 5
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import array
from fastcore.foundation import patch_to
import matplotlib.pyplot as plt
from .repr_str import StrProxy
from .repr_rgb import RGBProxy
from .repr_plt import PlotProxy
from .repr_chans import ChanProxy
# %% ../nbs/10_patch.ipynb 6
def _monkey_patch(cls):
    """Monkey-patch lovely features into `cls`.

    Replaces `__repr__`, `__str__` and `__format__` with lovely summaries and
    adds the `p`, `v`, `deeper`, `rgb`, `chans` and `plt` properties.

    Before patching, the class's own `__repr__`, `__str__` and `__format__`
    are stashed as `_plain_repr`, `_plain_str` and `_plain_format`. This only
    happens on the first call (guarded by the presence of `_plain_repr`), so
    patching the same class again never overwrites the stashed originals.
    """
    if not hasattr(cls, '_plain_repr'):
        for dunder in ('repr', 'str', 'format'):
            setattr(cls, f'_plain_{dunder}', getattr(cls, f'__{dunder}__'))

    def _summary(arr):
        # Shared body of every string-conversion hook installed below.
        return str(StrProxy(arr))

    # `__repr__` is what Jupyter / VSCode show when inspecting an object, while
    # `__str__` is what `print()` uses; both get the lovely summary.
    @patch_to(cls)
    def __repr__(self: jax.Array):
        return _summary(self)

    @patch_to(cls)
    def __str__(self: jax.Array):
        return _summary(self)

    # Without this override, f-strings fall through to the native formatter
    # and print raw numbers. The format spec is currently ignored
    # (TODO: find a way to pass it through).
    @patch_to(cls)
    def __format__(self: jax.Array, tmp: str):
        return _summary(self)

    # `.p`: plain output, i.e. the pre-patch behaviour.
    @patch_to(cls, as_prop=True)
    def p(self: jax.Array):
        return StrProxy(self, plain=True)

    # `.v`: verbose output, showing both the stats and the plain values.
    @patch_to(cls, as_prop=True)
    def v(self: jax.Array):
        return StrProxy(self, verbose=True)

    # `.deeper`: summary built with `depth=1` (see StrProxy for what depth controls).
    @patch_to(cls, as_prop=True)
    def deeper(self: jax.Array):
        return StrProxy(self, depth=1)

    @patch_to(cls, as_prop=True)
    def rgb(self: jax.Array):
        return RGBProxy(self)

    @patch_to(cls, as_prop=True)
    def chans(self: jax.Array):
        return ChanProxy(self)

    # NOTE: this local name shadows the module-level matplotlib `plt` inside
    # this function only; nothing below relies on the import.
    @patch_to(cls, as_prop=True)
    def plt(self: jax.Array):
        return PlotProxy(self)
def monkey_patch():
    """Monkey-patch lovely features into every jax array class this jax version provides.

    `ArrayImpl` is always patched. Additional legacy array classes are patched
    only when they can be found, so this is safe to call on any jax version.
    """
    _monkey_patch(array.ArrayImpl)

    # `DeviceArray` is absent from newer jax releases (the original note ties
    # this to 0.4.14), so only patch it when it is present.
    if hasattr(array, "DeviceArray"):
        _monkey_patch(array.DeviceArray)

    # Some early jax 0.4.x releases exposed `_ShardedDeviceArray` through a
    # top-level `jax.pxla`. That module is not guaranteed to exist, so look it
    # up with `getattr` instead of plain attribute access (which would raise
    # AttributeError when it is missing).
    if not hasattr(jax, "interpreters"):
        pxla = getattr(jax, "pxla", None)
        sharded_cls = getattr(pxla, "_ShardedDeviceArray", None)
        if sharded_cls is not None:
            _monkey_patch(sharded_cls)