Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: starting slicing #755

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ repos:
hooks:
- id: mypy
files: ^src
additional_dependencies: [numpy==1.22.4, pytest, uhi, types-dataclasses]
additional_dependencies: [numpy~=1.23.0, pytest, uhi, types-dataclasses]

- repo: https://github.com/mgedmin/check-manifest
rev: "0.48"
Expand Down
22 changes: 21 additions & 1 deletion src/boost_histogram/_internal/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,16 @@ def extent(self) -> int:
"""
return self._ax.extent # type: ignore[no-any-return]

def __getitem__(self, i: AxCallOrInt) -> Union[int, str, Tuple[float, float]]:
def __getitem__(
self: T, i: Union[AxCallOrInt, slice]
) -> Union[int, str, Tuple[float, float], T]:
"""
Access a bin, using normal Python syntax for wraparound.
"""
if isinstance(i, slice):
raise NotImplementedError(
f"Slicing not supported on {self.__class__.__name__}"
)
# UHI support
if callable(i):
i = i(self)
Expand All @@ -241,6 +247,7 @@ def __getitem__(self, i: AxCallOrInt) -> Union[int, str, Tuple[float, float]]:
f"Out of range access, {i} is more than {self._ax.size}"
)
assert not callable(i)
assert not isinstance(i, slice)
return self.bin(i)

@property
Expand Down Expand Up @@ -612,6 +619,9 @@ def _repr_args_(self) -> List[str]:
return ret


TStrC = TypeVar("TStrC", bound="StrCategory")


@set_module("boost_histogram.axis")
@register({ca.category_str_growth, ca.category_str})
class StrCategory(BaseCategory, family=boost_histogram):
Expand Down Expand Up @@ -660,6 +670,16 @@ def __init__(

super().__init__(ax, metadata, __dict__)

def __getitem__(
self: TStrC, i: Union[AxCallOrInt, slice]
) -> Union[int, str, Tuple[float, float], TStrC]:

if isinstance(i, slice):
new_cats = list(self)[i]
return self.__class__(new_cats, __dict__=self.__dict__) # type: ignore[arg-type]
else:
return super().__getitem__(i)

def index(self, value: Union[float, str]) -> int:
"""
Return the fractional index(es) given a value (or values) on the axis.
Expand Down
20 changes: 20 additions & 0 deletions tests/test_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,26 @@ def test_edges_centers_widths(self, ref, growth):
assert_allclose(a.centers, [0.5, 1.5, 2.5])
assert_allclose(a.widths, [1, 1, 1])

def test_slicing(self, growth):
Cat = bh.axis.StrCategory
ref = ["a", "b", "c", "d", "e"]

a = Cat(ref, growth=growth)
b = a[1:3]
assert list(a)[1:3] == list(b)
assert a.__dict__ == b.__dict__
assert a.traits.growth == b.traits.growth

def test_empty_slice(self, growth):
Cat = bh.axis.StrCategory
ref = ["a", "b", "c", "d", "e"]
a = Cat(ref, growth=growth)
if growth:
assert a[0:0] == Cat([], growth=True)
else:
with pytest.raises(RuntimeError):
a[0:0]


class TestBoolean:
def test_init(self):
Expand Down