Skip to content

Commit

Permalink
Set order='F' when raveling group_idx after broadcast (#286)
Browse files Browse the repository at this point in the history
* Set order='F' when raveling group_idx after broadcast

This majorly improves the dim=... case for engine="flox" at least.
xref #281

I'm not sure if it is a regression for engine="numpy"

We trade off a single bad reshape for array against argsorting both
array and group_idx for a ~10-20x speedup

```
ds = xr.tutorial.load_dataset('air_temperature')
ds.groupby('lon').count(..., engine="flox")
```

* This is an improvement only for engine=flox

* Update tests

* Fix benchmark

* type ignore
  • Loading branch information
dcherian authored Nov 8, 2023
1 parent 92f4780 commit 273d319
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion asv_bench/benchmarks/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def setup(self, *args, **kwargs):
class ChunkReduce2DAllAxes(ChunkReduce):
def setup(self, *args, **kwargs):
self.array = np.ones((N, N))
self.labels = np.repeat(np.arange(N // 5), repeats=5)
self.labels = np.repeat(np.arange(N // 5), repeats=5)[np.newaxis, :]
self.axis = None
setup_jit()

Expand Down
14 changes: 12 additions & 2 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,16 +782,25 @@ def chunk_reduce(
)
groups = grps[0]

order = "C"
if nax > 1:
needs_broadcast = any(
group_idx.shape[ax] != array.shape[ax] and group_idx.shape[ax] == 1
for ax in range(-nax, 0)
)
if needs_broadcast:
# This is the dim=... case, it's a lot faster to ravel group_idx
# in fortran order since group_idx is then sorted
# I'm seeing 400ms -> 23ms for engine="flox"
# Of course we are slower to ravel `array` but we avoid argsorting
# both `array` *and* `group_idx` in _prepare_for_flox
group_idx = np.broadcast_to(group_idx, array.shape[-by.ndim :])
if engine == "flox":
group_idx = group_idx.reshape(-1, order="F")
order = "F"
# always reshape to 1D along group dimensions
newshape = array.shape[: array.ndim - by.ndim] + (math.prod(array.shape[-by.ndim :]),)
array = array.reshape(newshape)
array = array.reshape(newshape, order=order) # type: ignore[call-overload]
group_idx = group_idx.reshape(-1)

assert group_idx.ndim == 1
Expand Down Expand Up @@ -1814,7 +1823,8 @@ def groupby_reduce(
Array to be reduced, possibly nD
*by : ndarray or DaskArray
Array of labels to group over. Must be aligned with ``array`` so that
``array.shape[-by.ndim :] == by.shape``
``array.shape[-by.ndim :] == by.shape`` or any disagreements in that
equality check are for dimensions of size 1 in `by`.
func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
"max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
"quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_groupby_reduce(
def gen_array_by(size, func):
by = np.ones(size[-1])
rng = np.random.default_rng(12345)
array = rng.random(size)
array = rng.random(tuple(6 if s == 1 else s for s in size))
if "nan" in func and "nanarg" not in func:
array[[1, 4, 5], ...] = np.nan
elif "nanarg" in func and len(size) > 1:
Expand All @@ -222,8 +222,8 @@ def gen_array_by(size, func):
pytest.param(4, marks=requires_dask),
],
)
@pytest.mark.parametrize("size", ((1, 12), (12,), (12, 9)))
@pytest.mark.parametrize("nby", [1, 2, 3])
@pytest.mark.parametrize("size", ((12,), (12, 9)))
@pytest.mark.parametrize("add_nan_by", [True, False])
@pytest.mark.parametrize("func", ALL_FUNCS)
def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
Expand Down

0 comments on commit 273d319

Please sign in to comment.