Skip to content

Commit

Permalink
Check method only for dask reductions. (#241)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed May 8, 2023
1 parent 622ddb2 commit 4164712
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
12 changes: 6 additions & 6 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1875,12 +1875,6 @@ def groupby_reduce(
axis_ = np.core.numeric.normalize_axis_tuple(axis, array.ndim) # type: ignore
nax = len(axis_)

if method in ["blockwise", "cohorts"] and nax != by_.ndim:
raise NotImplementedError(
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
f"Received method={method!r}"
)

# TODO: make sure expected_groups is unique
if nax == 1 and by_.ndim > 1 and expected_groups is None:
if not any_by_dask:
Expand Down Expand Up @@ -1949,6 +1943,12 @@ def groupby_reduce(
f"\n\n Received: {func}"
)

if method in ["blockwise", "cohorts"] and nax != by_.ndim:
raise NotImplementedError(
"Must reduce along all dimensions of `by` when method != 'map-reduce'."
f"Received method={method!r}"
)

# TODO: just do this in dask_groupby_agg
# we always need some fill_value (see above) so choose the default if needed
if kwargs["fill_value"] is None:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,3 +1347,37 @@ def test_expected_index_conversion_passthrough_range_index(sort):
expected_groups=(index,), isbin=(False,), sort=(sort,)
)
assert actual[0] is index


def test_method_check_numpy():
bins = [-2, -1, 0, 1, 2]
field = np.ones((5, 3))
by = np.array([[-1.5, -1.5, 0.5, 1.5, 1.5] * 3]).reshape(5, 3)
actual, _ = groupby_reduce(
field,
by,
expected_groups=pd.IntervalIndex.from_breaks(bins),
func="count",
method="cohorts",
fill_value=np.nan,
)
expected = np.array([6, np.nan, 3, 6])
assert_equal(actual, expected)

actual, _ = groupby_reduce(
field,
by,
expected_groups=pd.IntervalIndex.from_breaks(bins),
func="count",
fill_value=np.nan,
method="cohorts",
axis=0,
)
expected = np.array(
[
[2.0, np.nan, 1.0, 2.0],
[2.0, np.nan, 1.0, 2.0],
[2.0, np.nan, 1.0, 2.0],
]
)
assert_equal(actual, expected)

0 comments on commit 4164712

Please sign in to comment.