diff --git a/flox/xarray.py b/flox/xarray.py index 6b7c174b0..ec06b9161 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -257,8 +257,6 @@ def xarray_reduce( more_drop.update(idx_other_names) maybe_drop.update(more_drop) - ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) - if dim is Ellipsis: if nby > 1: raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.") @@ -275,6 +273,9 @@ def xarray_reduce( # broadcast to make sure grouper dimensions are present in the array. exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple) + if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): + raise ValueError(f"Cannot reduce over absent dimensions {dim}.") + try: xr.align(ds, *by_da, join="exact", copy=False) except ValueError as e: @@ -282,10 +283,13 @@ def xarray_reduce( "Object being grouped must be exactly aligned with every array in `by`." ) from e - ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] - - if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple): - raise ValueError(f"Cannot reduce over absent dimensions {dim}.") + needs_broadcast = any( + not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values() + ) + if needs_broadcast: + ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0] + else: + ds_broad = ds dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims) if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins): @@ -305,6 +309,8 @@ def xarray_reduce( else: return result + ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables]) + axis = tuple(range(-len(dim_tuple), 0)) # Set expected_groups and convert to index since we need coords, sizes @@ -432,7 +438,7 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): # restore non-dim coord variables without the core dimension # TODO: shouldn't apply_ufunc handle this? - for var in set(ds_broad.variables) - set(ds_broad._indexes) - set(ds_broad.dims): + for var in set(ds_broad._coord_names) - set(ds_broad._indexes) - set(ds_broad.dims): if all(d not in ds_broad[var].dims for d in dim_tuple): actual[var] = ds_broad[var] diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 6d99732b0..2fce2552c 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -343,6 +343,8 @@ def test_multi_index_groupby_sum(engine): expected = ds.sum("z") stacked = ds.stack(space=["x", "y"]) actual = xarray_reduce(stacked, "space", dim="z", func="sum", engine=engine) + expected_xarray = stacked.groupby("space").sum("z") + assert_equal(expected_xarray, actual) assert_equal(expected, actual.unstack("space")) actual = xarray_reduce(stacked.foo, "space", dim="z", func="sum", engine=engine)