Skip to content

Commit

Permalink
Sourcery refactored develop branch (#544)
Browse files Browse the repository at this point in the history
* 'Refactored by Sourcery'

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: cleanup tests and sets

Co-authored-by: Sourcery AI <>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
3 people committed Mar 23, 2021
1 parent 871fbc5 commit a2a7205
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 53 deletions.
26 changes: 18 additions & 8 deletions src/boost_histogram/_internal/axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,13 @@ def __init__(
ax = ca.regular_uflow(bins, start, stop)
elif options == {"overflow"}:
ax = ca.regular_oflow(bins, start, stop)
elif options == {"circular", "underflow", "overflow"} or options == {
"circular",
"overflow",
}:
elif options in [
{"circular", "underflow", "overflow"},
{
"circular",
"overflow",
},
]:
# growth=True, underflow=False is also correct
ax = ca.regular_circular(bins, start, stop)

Expand Down Expand Up @@ -449,10 +452,17 @@ def __init__(
ax = ca.variable_uflow(edges)
elif options == {"overflow"}:
ax = ca.variable_oflow(edges)
elif options == {"circular", "underflow", "overflow",} or options == {
"circular",
"overflow",
}:
elif options in [
{
"circular",
"underflow",
"overflow",
},
{
"circular",
"overflow",
},
]:
# growth=True, underflow=False is also correct
ax = ca.variable_circular(edges)
elif options == set():
Expand Down
36 changes: 16 additions & 20 deletions src/boost_histogram/_internal/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def _fill_cast(value: T, *, inner: bool = False) -> Union[T, np.ndarray, Tuple[T


def _arg_shortcut(item: Union[Tuple[int, float, float], Axis, CppAxis]) -> CppAxis:
msg = "Developer shortcut: will be removed in a future version"
if isinstance(item, tuple) and len(item) == 3:
msg = "Developer shortcut: will be removed in a future version"
warnings.warn(msg, FutureWarning)
return _core.axis.regular_uoflow(item[0], item[1], item[2]) # type: ignore
elif isinstance(item, Axis):
Expand Down Expand Up @@ -364,10 +364,10 @@ def _compute_inplace_op(
len(other.shape), self.ndim
)
)
elif all((a == b or a == 1) for a, b in zip(other.shape, self.shape)):
elif all(a in {b, 1} for a, b in zip(other.shape, self.shape)):
view = self.view(flow=False)
getattr(view, name)(other)
elif all((a == b or a == 1) for a, b in zip(other.shape, self.axes.extent)):
elif all(a in {b, 1} for a, b in zip(other.shape, self.axes.extent)):
view = self.view(flow=True)
getattr(view, name)(other)
else:
Expand Down Expand Up @@ -494,13 +494,11 @@ def __str__(self) -> str:
"""
# TODO check the terminal width and adjust the presentation
# only use for 1D, fall back to repr for ND
if self._hist.rank() == 1:
s = str(self._hist)
# get rid of first line and last character
s = s[s.index("\n") + 1 : -1]
else:
s = repr(self)
return s
if self._hist.rank() != 1:
return repr(self)
s = str(self._hist)
# get rid of first line and last character
return s[s.index("\n") + 1 : -1]

def _axis(self, i: int = 0) -> Axis:
"""
Expand Down Expand Up @@ -547,15 +545,14 @@ def __setstate__(self, state: Any) -> None:
msg = "Cannot open boost-histogram pickle v{}".format(state[0])
raise RuntimeError(msg)

self.axes = self._generate_axes_()

else: # Classic (0.10 and before) state
self._hist = state["_hist"]
self._variance_known = True
self.metadata = state.get("metadata", None)
for i in range(self._hist.rank()):
self._hist.axis(i).metadata = {"metadata": self._hist.axis(i).metadata}
self.axes = self._generate_axes_()

self.axes = self._generate_axes_()

def __repr__(self) -> str:
newline = "\n "
Expand Down Expand Up @@ -779,14 +776,13 @@ def __getitem__( # noqa: C901

if not integrations:
return self._new_hist(reduced)
else:
projections = [i for i in range(self.ndim) if i not in integrations]
projections = [i for i in range(self.ndim) if i not in integrations]

return (
self._new_hist(reduced.project(*projections))
if projections
else reduced.sum(flow=True)
)
return (
self._new_hist(reduced.project(*projections))
if projections
else reduced.sum(flow=True)
)

def __setitem__(
self, index: IndexingExpr, value: Union[ArrayLike, Accumulator]
Expand Down
15 changes: 9 additions & 6 deletions src/boost_histogram/_internal/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,16 @@ def __array_ufunc__(
return ufunc(np.asarray(inputs[0]), np.asarray(inputs[1]), **kwargs) # type: ignore

# Support unary + and -
if method == "__call__" and len(inputs) == 1:
if ufunc in {np.negative, np.positive}:
(result,) = kwargs.pop("out", [np.empty(self.shape, self.dtype)])
if (
method == "__call__"
and len(inputs) == 1
and ufunc in {np.negative, np.positive}
):
(result,) = kwargs.pop("out", [np.empty(self.shape, self.dtype)])

ufunc(inputs[0]["value"], out=result["value"], **kwargs)
result["variance"] = inputs[0]["variance"]
return result.view(self.__class__) # type: ignore
ufunc(inputs[0]["value"], out=result["value"], **kwargs)
result["variance"] = inputs[0]["variance"]
return result.view(self.__class__) # type: ignore

if method == "__call__" and len(inputs) == 2:
input_0 = np.asarray(inputs[0])
Expand Down
16 changes: 8 additions & 8 deletions src/boost_histogram/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ def histogram2d(
threads=threads,
)

if isinstance(result, tuple):
data, (edgesx, edgesy) = result
return data, edgesx, edgesy
else:
if not isinstance(result, tuple):
return result

data, (edgesx, edgesy) = result
return data, edgesx, edgesy


def histogram(
a: ArrayLike,
Expand Down Expand Up @@ -162,12 +162,12 @@ def histogram(
storage=storage,
threads=threads,
)
if isinstance(result, tuple):
data, (edges,) = result
return data, edges
else:
if not isinstance(result, tuple):
return result

data, (edges,) = result
return data, edges


# Process docstrings
for f, n in zip(
Expand Down
10 changes: 5 additions & 5 deletions tests/test_accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def test_sum_mean(list1, list2):

ab = a + b
assert ab.value == approx(c.value)
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
assert ab.count == approx(c.count)

a += b
assert a.value == approx(c.value)
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
assert a.count == approx(c.count)


Expand All @@ -129,7 +129,7 @@ def test_sum_mean(list1, list2):
st.lists(float_st, min_size=n, max_size=n),
st.lists(
st.floats(
allow_nan=False, allow_infinity=False, min_value=1e-4, max_value=1e5
allow_nan=False, allow_infinity=False, min_value=1e-2, max_value=1e3
),
min_size=n,
max_size=n,
Expand All @@ -151,12 +151,12 @@ def test_sum_weighed_mean(pair1, pair2):

ab = a + b
assert ab.value == approx(c.value)
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
assert ab.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
assert ab.sum_of_weights == approx(c.sum_of_weights)
assert ab.sum_of_weights_squared == approx(c.sum_of_weights_squared)

a += b
assert a.value == approx(c.value)
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-9, rel=1e-9)
assert a.variance == approx(c.variance, nan_ok=True, abs=1e-7, rel=1e-3)
assert a.sum_of_weights == approx(c.sum_of_weights)
assert a.sum_of_weights_squared == approx(c.sum_of_weights_squared)
12 changes: 6 additions & 6 deletions tests/test_minihist_title.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ class NamedAxesTuple(bh.axis.AxesTuple):
__slots__ = ()

def _get_index_by_name(self, name):
if isinstance(name, str):
for i, ax in enumerate(self):
if ax.name == name:
return i
raise KeyError(f"{name} not found in axes")
else:
if not isinstance(name, str):
return name

for i, ax in enumerate(self):
if ax.name == name:
return i
raise KeyError(f"{name} not found in axes")

def __getitem__(self, item):
if isinstance(item, slice):
item = slice(
Expand Down

0 comments on commit a2a7205

Please sign in to comment.