From 68bbaa763b958e442a62cbb35421d7bed74709ab Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 6 Apr 2023 16:25:48 -0600 Subject: [PATCH 1/6] Optimize broadcasting xref https://github.com/pydata/xarray/issues/7730 --- flox/xarray.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index da5cf0e35..40605932e 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -257,8 +257,6 @@ def xarray_reduce( more_drop.update(idx_other_names) maybe_drop.update(more_drop) - ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) - if dim is Ellipsis: if nby > 1: raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.") @@ -275,18 +273,6 @@ def xarray_reduce( # broadcast to make sure grouper dimensions are present in the array. exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple) - try: - xr.align(ds, *by_da, join="exact", copy=False) - except ValueError as e: - raise ValueError( - "Object being grouped must be exactly aligned with every array in `by`." - ) from e - - ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] - - if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): - raise ValueError(f"Cannot reduce over absent dimensions {dim}.") - dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims) if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins): # reducing along a dimension along which groups do not vary @@ -299,12 +285,29 @@ def xarray_reduce( "func must be a string when reducing along a dimension not present in `by`" ) # TODO: skipna needs test - result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna) + result = getattr(ds, dsfunc)(dim=dim_tuple, skipna=skipna) if isinstance(obj, xr.DataArray): return obj._from_temp_dataset(result) else: return result + try: + xr.align(ds, *by_da, join="exact", copy=False) + except ValueError as e: + raise ValueError( + "Object being grouped must be exactly aligned with every array in `by`." + ) from e + + if set(ds.dims) < set(grouper_dims): + ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] + else: + ds_broad = ds + + if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): + raise ValueError(f"Cannot reduce over absent dimensions {dim}.") + + ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) + axis = tuple(range(-len(dim_tuple), 0)) # Set expected_groups and convert to index since we need coords, sizes From 6dc5f41e9411d0ec785cc69f707a8c7dc0ff1d2b Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 6 Apr 2023 16:27:51 -0600 Subject: [PATCH 2/6] reorder --- flox/xarray.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index 40605932e..ca74cf34e 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -291,6 +291,11 @@ def xarray_reduce( else: return result + if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): + raise ValueError(f"Cannot reduce over absent dimensions {dim}.") + + ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) + try: xr.align(ds, *by_da, join="exact", copy=False) except ValueError as e: @@ -303,11 +308,6 @@ def xarray_reduce( else: ds_broad = ds - if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): - raise ValueError(f"Cannot reduce over absent dimensions {dim}.") - - ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) - axis = tuple(range(-len(dim_tuple), 0)) # Set expected_groups and convert to index since we need coords, sizes From 66ed57c3ff801018eb56a84a287e40cd6c44fbd2 Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 6 Apr 2023 20:18:37 -0600 Subject: [PATCH 3/6] Fix tests --- flox/xarray.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index ca74cf34e..7ba544eef 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -273,6 +273,9 @@ def xarray_reduce( # broadcast to make sure grouper dimensions are present in the array. exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple) + if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): + raise ValueError(f"Cannot reduce over absent dimensions {dim}.") + dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims) if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins): # reducing along a dimension along which groups do not vary @@ -291,9 +294,6 @@ def xarray_reduce( else: return result - if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): - raise ValueError(f"Cannot reduce over absent dimensions {dim}.") - ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) try: @@ -303,7 +303,7 @@ def xarray_reduce( "Object being grouped must be exactly aligned with every array in `by`." ) from e - if set(ds.dims) < set(grouper_dims): + if not set(grouper_dims).issubset(set(ds.dims)): ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] else: ds_broad = ds From 208eaf4831f40842e03eb39dd7d7f6915e7eb5bb Mon Sep 17 00:00:00 2001 From: dcherian Date: Thu, 6 Apr 2023 20:18:59 -0600 Subject: [PATCH 4/6] Another optimization --- flox/xarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flox/xarray.py b/flox/xarray.py index 7ba544eef..9b18ffb6d 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -435,7 +435,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): # restore non-dim coord variables without the core dimension # TODO: shouldn't apply_ufunc handle this? - for var in set(ds_broad.variables) - set(ds_broad._indexes) - set(ds_broad.dims): + for var in set(ds_broad._coord_names) - set(ds_broad._indexes) - set(ds_broad.dims): if all(d not in ds_broad[var].dims for d in dim_tuple): actual[var] = ds_broad[var] From 5527f7a3772eb498e1e2cb5fa6d098eddf7d0d37 Mon Sep 17 00:00:00 2001 From: dcherian Date: Tue, 18 Apr 2023 21:24:58 -0600 Subject: [PATCH 5/6] fixes --- flox/xarray.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index 9b18ffb6d..d0f646d02 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -276,6 +276,23 @@ def xarray_reduce( if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): raise ValueError(f"Cannot reduce over absent dimensions {dim}.") + ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) + + try: + xr.align(ds, *by_da, join="exact", copy=False) + except ValueError as e: + raise ValueError( + "Object being grouped must be exactly aligned with every array in `by`." + ) from e + + needs_broadcast = any( + not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values() + ) + if needs_broadcast: + ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] + else: + ds_broad = ds + dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims) if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins): # reducing along a dimension along which groups do not vary @@ -288,26 +305,12 @@ def xarray_reduce( "func must be a string when reducing along a dimension not present in `by`" ) # TODO: skipna needs test - result = getattr(ds, dsfunc)(dim=dim_tuple, skipna=skipna) + result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna) if isinstance(obj, xr.DataArray): return obj._from_temp_dataset(result) else: return result - ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) - - try: - xr.align(ds, *by_da, join="exact", copy=False) - except ValueError as e: - raise ValueError( - "Object being grouped must be exactly aligned with every array in `by`." - ) from e - - if not set(grouper_dims).issubset(set(ds.dims)): - ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] - else: - ds_broad = ds - axis = tuple(range(-len(dim_tuple), 0)) # Set expected_groups and convert to index since we need coords, sizes From d479bbd9313200d56c0098ea53ba5e300c236701 Mon Sep 17 00:00:00 2001 From: dcherian Date: Wed, 3 May 2023 20:54:26 -0600 Subject: [PATCH 6/6] fix --- flox/xarray.py | 4 ++-- tests/test_xarray.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flox/xarray.py b/flox/xarray.py index 9a618dd38..ec06b9161 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -276,8 +276,6 @@ def xarray_reduce( if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): raise ValueError(f"Cannot reduce over absent dimensions {dim}.") - ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) - try: xr.align(ds, *by_da, join="exact", copy=False) except ValueError as e: @@ -311,6 +309,8 @@ def xarray_reduce( else: return result + ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) + axis = tuple(range(-len(dim_tuple), 0)) # Set expected_groups and convert to index since we need coords, sizes diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 6d99732b0..2fce2552c 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -343,6 +343,8 @@ def test_multi_index_groupby_sum(engine): expected = ds.sum("z") stacked = ds.stack(space=["x", "y"]) actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine) + expected_xarray = stacked.groupby("space").sum("z") + assert_equal(expected_xarray, actual) assert_equal(expected, actual.unstack("space")) actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine)