Skip to content

Commit

Permalink
fix: repr limit and fallback on normal repr
Browse files Browse the repository at this point in the history
  • Loading branch information
henryiii committed Mar 10, 2022
1 parent 4e7d599 commit 7abccb6
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 76 deletions.
85 changes: 45 additions & 40 deletions src/hist/basehist.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .axis import AxisProtocol
from .quick_construct import MetaConstructor
from .storage import Storage
from .svgplots import html_hist, svg_hist_1d, svg_hist_1d_c, svg_hist_2d, svg_hist_nd
from .svgplots import html_hist, svg_hist_1d, svg_hist_1d_c, svg_hist_2d
from .typing import ArrayLike, Protocol, SupportsIndex

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -87,37 +87,37 @@ def __init__(
for a in args
]

if args:
if isinstance(storage, str):
storage_str = storage.title()
if storage_str == "Atomicint64":
storage_str = "AtomicInt64"
elif storage_str == "Weightedmean":
storage_str = "WeightedMean"
storage = getattr(bh.storage, storage_str)()
elif isinstance(storage, type):
msg = (
f"Please use '{storage.__name__}()' instead of '{storage.__name__}'"
if isinstance(storage, str):
storage_str = storage.title()
if storage_str == "Atomicint64":
storage_str = "AtomicInt64"
elif storage_str == "Weightedmean":
storage_str = "WeightedMean"
storage = getattr(bh.storage, storage_str)()
elif isinstance(storage, type):
msg = f"Please use '{storage.__name__}()' instead of '{storage.__name__}'"
warnings.warn(msg)
storage = storage()

super().__init__(*args, storage=storage, metadata=metadata) # type: ignore[call-overload]

disallowed_names = {"weight", "sample", "threads"}
for ax in self.axes:
if ax.name in disallowed_names:
disallowed_warning = (
f"{ax.name} is a protected keyword and cannot be used as axis name"
)
warnings.warn(msg)
storage = storage()
super().__init__(*args, storage=storage, metadata=metadata) # type: ignore[call-overload]

disallowed_names = {"weight", "sample", "threads"}
for ax in self.axes:
if ax.name in disallowed_names:
disallowed_warning = f"{ax.name} is a protected keyword and cannot be used as axis name"
warnings.warn(disallowed_warning)

valid_names = [ax.name for ax in self.axes if ax.name]
if len(valid_names) != len(set(valid_names)):
raise KeyError(
f"{self.__class__.__name__} instance cannot contain axes with duplicated names"
)
for i, ax in enumerate(self.axes):
# label will return name if label is not set, so this is safe
if not ax.label:
ax.label = f"Axis {i}"
warnings.warn(disallowed_warning)

valid_names = [ax.name for ax in self.axes if ax.name]
if len(valid_names) != len(set(valid_names)):
raise KeyError(
f"{self.__class__.__name__} instance cannot contain axes with duplicated names"
)
for i, ax in enumerate(self.axes):
# label will return name if label is not set, so this is safe
if not ax.label:
ax.label = f"Axis {i}"

if data is not None:
self[...] = data
Expand All @@ -130,19 +130,24 @@ def _generate_axes_(self) -> NamedAxesTuple:

return NamedAxesTuple(self._axis(i) for i in range(self.ndim))

def _repr_html_(self) -> str:
def _repr_html_(self) -> str | None:
if self.size == 0:
return str(self)
return None

if self.ndim == 1:
if self.axes[0].traits.circular:
return str(html_hist(self, svg_hist_1d_c))
return str(html_hist(self, svg_hist_1d))
if len(self.axes[0]) <= 1000:
return str(
html_hist(
self,
svg_hist_1d_c if self.axes[0].traits.circular else svg_hist_1d,
)
)

if self.ndim == 2:
return str(html_hist(self, svg_hist_2d))
if self.ndim > 2:
return str(html_hist(self, svg_hist_nd))
if len(self.axes[0]) <= 200 and len(self.axes[1]) <= 200:
return str(html_hist(self, svg_hist_2d))

return str(self)
return None

def _name_to_index(self, name: str) -> int:
"""
Expand Down
29 changes: 0 additions & 29 deletions src/hist/svgplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,32 +193,3 @@ def svg_hist_2d(h: hist.BaseHist) -> svg:
]

return svg(*texts, *boxes, viewBox=f"{-20} {-height - 20} {width+40} {height+40}")


def svg_hist_nd(h: hist.BaseHist) -> svg:
assert h.ndim > 2, "Must be more than 2D"

width = 200
height = 200

boxes = [
rect(
x=20 * i,
y=20 * i,
width=width - 40,
height=height - 40,
style="fill:white;opacity:.5;stroke-width:2;stroke:currentColor;",
)
for i in range(3)
]

nd = text(
f"{h.ndim}D",
x=height / 2 + 20,
y=width / 2 + 20,
style="font-size: 26pt; font-family: verdana; font-style: bold; fill: black;",
text_anchor="middle",
alignment_baseline="middle",
)

return svg(*boxes, nd, viewBox=f"-10 -10 {height + 20} {width + 20}")
19 changes: 12 additions & 7 deletions tests/test_reprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,18 @@ def test_ND_empty_repr(named_hist):
.Double()
)
html = h._repr_html_()
assert html
assert "name='x'" in repr(h)
assert "name='p'" in repr(h)
assert "name='a'" in repr(h)
assert "label='y'" in repr(h)
assert "label='q'" in repr(h)
assert "label='b'" in repr(h)
assert html is None


def test_empty_mega_repr(named_hist):

h = named_hist.new.Reg(1001, -1, 1, name="x").Double()
html = h._repr_html_()
assert html is None

h = named_hist.new.Reg(201, -1, 1, name="x").Reg(100, 0, 1, name="y").Double()
html = h._repr_html_()
assert html is None


def test_stack_repr(named_hist):
Expand Down

0 comments on commit 7abccb6

Please sign in to comment.