-
Notifications
You must be signed in to change notification settings - Fork 3
/
repr_rgb.py
87 lines (67 loc) · 3.05 KB
/
repr_rgb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_repr_rgb.ipynb.
# %% auto 0
# Names exported by `from <this module> import *`; `RGBProxy` is not listed.
__all__ = ['rgb']
# %% ../nbs/01_repr_rgb.ipynb 4
from typing import Any, Optional as O
from matplotlib import axes, figure
from IPython.core.pylabtools import print_figure
from PIL import Image
import jax, jax.numpy as jnp
from lovely_numpy.utils.utils import cached_property
from lovely_numpy.utils.pad import pad_frame_gutters
from lovely_numpy.utils.tile2d import hypertile
from lovely_numpy.repr_rgb import fig_rgb
from lovely_numpy import config as np_config
from .utils.misc import to_numpy
from .utils.config import get_config
# %% ../nbs/01_repr_rgb.ipynb 6
# This is here for the monkey-patched tensor use case.
# I want to be able to call both `tensor.rgb` and `tensor.rgb(stats)`. For the
# first case, the class defines `_repr_png_` to send the image to Jupyter. For
# the latter case, it defines `__call__`, which accepts the arguments.
class RGBProxy:
    """Lazy display wrapper that renders a JAX array as an RGB image figure.

    The matplotlib figure is produced by lovely_numpy's `fig_rgb` on first
    access to `fig` (or when the instance is called) and is shown in Jupyter
    through `_repr_png_`.
    """

    def __init__(self, x: jax.Array):
        """Wrap `x` with default rendering parameters.

        Raises:
            AssertionError: if `x` has fewer than 3 dimensions.
        """
        # Explicit raise instead of `assert` so the check is not stripped under
        # `python -O`; AssertionError is kept for existing callers.
        if x.ndim < 3:
            raise AssertionError(
                f"Expecting at least 3 dimensions, got shape {x.shape} (ndim={x.ndim})")
        self.x = x
        # Rendering options, passed as keyword arguments to `fig_rgb`.
        self.params = dict(denorm=None,
                           cl=True,
                           gutter_px=3,
                           frame_px=1,
                           scale=1,
                           view_width=966,
                           ax=None)

    def __call__(self,
                 denorm: Any = None,
                 cl: Any = None,
                 gutter_px: O[int] = None,
                 frame_px: O[int] = None,
                 scale: O[int] = None,
                 view_width: O[int] = None,
                 ax: O[axes.Axes] = None):
        """Update rendering options, build the figure and return `self`.

        An argument left as `None` keeps its current value in `self.params`.
        `cl` follows the same rule; a fresh proxy starts with `cl=True`.
        Returning `self` lets `tensor.rgb(...)` display directly in Jupyter.
        """
        requested = dict(denorm=denorm, cl=cl, gutter_px=gutter_px,
                         frame_px=frame_px, scale=scale,
                         view_width=view_width, ax=ax)
        supplied = {k: v for k, v in requested.items() if v is not None}
        if supplied:
            self.params.update(supplied)
            # Discard a figure cached by an earlier access so the new options
            # take effect. Assumes `cached_property` stores its value in the
            # instance __dict__ under the attribute name (as
            # functools.cached_property does); otherwise this is a no-op.
            self.__dict__.pop("fig", None)
        _ = self.fig  # Trigger figure generation
        return self

    @cached_property
    def fig(self) -> figure.Figure:
        """Render `self.x` with the current `self.params` (computed once, then cached)."""
        cfg = get_config()
        # Pass this package's fig_close / fig_show settings to lovely_numpy's
        # config for the duration of the `fig_rgb` call.
        with np_config(fig_close=cfg.fig_close, fig_show=cfg.fig_show):
            return fig_rgb(to_numpy(self.x), **self.params)

    def _repr_png_(self):
        """Jupyter display hook: return the figure encoded as PNG."""
        return print_figure(self.fig, fmt="png", pad_inches=0,
                            metadata={"Software": "Matplotlib, https://matplotlib.org/"})
# %% ../nbs/01_repr_rgb.ipynb 7
def rgb(x          :jax.Array,          # Tensor to display. [[...], C,H,W] or [[...], H,W,C]
        denorm     :Any =None,          # Reverse per-channel normalization
        cl         :Any =True,          # Channel-last
        gutter_px  :int =3,             # If more than one tensor -> tile with this gutter width
        frame_px   :int =1,             # If more than one tensor -> tile with this frame width
        scale      :int =1,             # Scale up. Can't scale down.
        view_width :int =966,           # Target width of the image
        ax         :O[axes.Axes] =None  # Use this Axes
       ) -> RGBProxy:
    "Wrap `x` in an `RGBProxy`, apply the display options and return the proxy."
    proxy = RGBProxy(x)
    # Calling the proxy stores the options and builds its figure.
    return proxy(denorm=denorm,
                 cl=cl,
                 gutter_px=gutter_px,
                 frame_px=frame_px,
                 scale=scale,
                 view_width=view_width,
                 ax=ax)