diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py
index 4df3f77a4..7380e70a3 100644
--- a/flox/aggregate_flox.py
+++ b/flox/aggregate_flox.py
@@ -14,7 +14,9 @@ def _prepare_for_flox(group_idx, array):
     if issorted:
         ordered_array = array
     else:
-        perm = group_idx.argsort(kind="stable")
+        kind = "stable" if isinstance(group_idx, np.ndarray) else None
+
+        perm = np.argsort(group_idx, kind=kind)
         group_idx = group_idx[..., perm]
         ordered_array = array[..., perm]
     return group_idx, ordered_array
@@ -25,7 +27,9 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
     most of this code is from shoyer's gist
     https://gist.github.com/shoyer/f538ac78ae904c936844
     """
-    # assumes input is sorted, which I do in core._prepare_for_flox
+    # For numpy arrays, assumes input is sorted, which I do in _prepare_for_flox
+    # For cupy arrays, sorting is not needed
+
     aux = group_idx

     flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
@@ -38,7 +42,12 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
         dtype = array.dtype

     if out is None:
-        out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
+        out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype, like=array)
+
+    # if isinstance(array, cupy_array_type):
+    #     op = cupy_ops[op]
+    #     op(out, group_idx, array)
+    #     return out

     if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
         # The previous version of this if condition
diff --git a/flox/core.py b/flox/core.py
index 66e39b6f1..27e2c58f7 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -33,7 +33,13 @@
     generic_aggregate,
 )
 from .cache import memoize
-from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
+from .xrutils import (
+    is_duck_array,
+    is_duck_dask_array,
+    isnull,
+    module_available,
+    to_numpy,
+)

 HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")

@@ -145,42 +151,51 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:

 @memoize
 def _get_optimal_chunks_for_groups(chunks, labels):
-    chunkidx = np.cumsum(chunks) - 1
+    chunks_array = np.asarray(chunks, like=labels)
+    chunkidx = np.cumsum(chunks_array) - 1
     # what are the groups at chunk boundaries
     labels_at_chunk_bounds = _unique(labels[chunkidx])
     # what's the last index of all groups
-    last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
+    last_indexes = npg.aggregate_numpy.aggregate(
+        labels, np.arange(len(labels), like=labels), func="last"
+    )
     # what's the last index of groups at the chunk boundaries.
     lastidx = last_indexes[labels_at_chunk_bounds]

     if len(chunkidx) == len(lastidx) and (chunkidx == lastidx).all():
         return chunks

-    first_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="first")
+    first_indexes = npg.aggregate_numpy.aggregate(
+        labels, np.arange(len(labels), like=labels), func="first"
+    )
     firstidx = first_indexes[labels_at_chunk_bounds]

-    newchunkidx = [0]
+    newchunkidx = np.array([0], like=labels)
     for c, f, l in zip(chunkidx, firstidx, lastidx):  # noqa
         Δf = abs(c - f)
         Δl = abs(c - l)
         if c == 0 or newchunkidx[-1] > l:
             continue
         if Δf < Δl and f > newchunkidx[-1]:
-            newchunkidx.append(f)
+            newchunkidx = np.append(newchunkidx, f)
         else:
-            newchunkidx.append(l + 1)
+            newchunkidx = np.append(newchunkidx, l + 1)
     if newchunkidx[-1] != chunkidx[-1] + 1:
-        newchunkidx.append(chunkidx[-1] + 1)
+        newchunkidx = np.append(newchunkidx, chunkidx[-1] + 1)
     newchunks = np.diff(newchunkidx)
     assert sum(newchunks) == sum(chunks)
-    return tuple(newchunks)
+    # workaround cupy bug with tuple(array)
+    return tuple(newchunks.tolist())


-def _unique(a: np.ndarray) -> np.ndarray:
+def _unique(a):
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
-    return np.sort(pd.unique(a.reshape(-1)))
+    if isinstance(a, np.ndarray):
+        return np.sort(pd.unique(a.reshape(-1)))
+    else:
+        return np.unique(a.reshape(-1))


 @memoize
@@ -210,7 +225,9 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     import dask

     # To do this, we must have values in memory so casting to numpy should be safe
-    labels = np.asarray(labels)
+    if not is_duck_array(labels):
+        labels = np.asarray(labels)
+    labels = to_numpy(labels)

     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
@@ -433,7 +450,7 @@ def reindex_(
         reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
         return reindexed

-    from_ = pd.Index(from_)
+    from_ = pd.Index(to_numpy(from_))
     # short-circuit for trivial case
     if from_.equals(to):
         return array
@@ -546,7 +563,7 @@ def factorize_(
             # this is important in shared-memory parallelism with dask
             # TODO: figure out how to avoid this
             idx = flat.copy()
-            found_groups.append(np.array(expect))
+            found_groups.append(np.array(expect, like=flat, copy=False))
             # TODO: fix by using masked integers
             idx[idx > expect[-1]] = -1

@@ -561,7 +578,11 @@ def factorize_(
                 right = expect.closed_right
                 idx = np.digitize(
                     flat,
-                    bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins,
+                    bins=np.array(
+                        bins.view(np.int64) if bins.dtype.kind == "M" else bins,
+                        like=flat,
+                        copy=False,
+                    ),
                     right=right,
                 )
                 idx -= 1
@@ -574,7 +595,7 @@ def factorize_(
         else:
             if expect is not None and reindex:
                 sorter = np.argsort(expect)
-                groups = expect[(sorter,)] if sort else expect
+                groups = np.array(expect[(sorter,)]) if sort else expect
                 idx = np.searchsorted(expect, flat, sorter=sorter)
                 mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
                 if not sort:
@@ -584,9 +605,16 @@ def factorize_(
                     idx = sorter[(idx,)]
                 idx[mask] = -1
             else:
-                idx, groups = pd.factorize(flat, sort=sort)  # type: ignore[arg-type]
+                if isinstance(flat, np.ndarray):
+                    idx, groups = pd.factorize(flat, sort=sort)  # type: ignore[call-overload]
+                    groups = np.array(groups)
+                else:
+                    assert sort
+                    groups, idx = np.unique(flat, return_inverse=True)
+                    idx[np.isnan(flat)] = -1
+                    groups = groups[~np.isnan(groups)]  # type: ignore[call-overload,index]

-            found_groups.append(np.array(groups))
+            found_groups.append(groups)  # type: ignore[arg-type]
         factorized.append(idx.reshape(groupvar.shape))

     grp_shape = tuple(len(grp) for grp in found_groups)
@@ -945,7 +973,10 @@ def _find_unique_groups(x_chunk) -> np.ndarray:
     from dask.base import flatten
     from dask.utils import deepmap

-    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    tup = tuple(flatten(deepmap(listify_groups, x_chunk)))
+    # passing like=None raises. Seems like a bug
+    kwargs = dict(like=tup[0]) if is_duck_array(tup[0]) else {}
+    unique_groups = _unique(np.asarray(tup, **kwargs))
     unique_groups = unique_groups[~isnull(unique_groups)]

     if len(unique_groups) == 0:
@@ -1017,12 +1048,11 @@ def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes | None = None) -> np.nd


 def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups) -> IntermediateDict:
+    to = pd.Index(to_numpy(unique_groups))
     new_shape = x["groups"].shape[:-1] + (len(unique_groups),)
     newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)}
     newx["intermediates"] = tuple(
-        reindex_(
-            v, from_=np.atleast_1d(x["groups"].squeeze()), to=pd.Index(unique_groups), fill_value=f
-        )
+        reindex_(v, from_=to_numpy(np.atleast_1d(x["groups"].squeeze())), to=to, fill_value=f)
         for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
     )
     return newx
@@ -1280,7 +1310,7 @@ def subset_to_blocks(
     layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

-    return dask.array.Array(graph, name, chunks, meta=array)
+    return dask.array.Array(graph, name, chunks, meta=array._meta)


 def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
@@ -1521,6 +1551,7 @@ def dask_groupby_agg(
         reduced,
         inds,
         adjust_chunks=dict(zip(out_inds, output_chunks)),
+        meta=array._meta,
         dtype=agg.dtype["final"],
         key=agg.name,
         name=f"{name}-{token}",
@@ -2126,7 +2157,7 @@ def groupby_reduce(
             # now we get rid of them by reindexing
             # This also handles bins with no data
             result = reindex_(
-                result, from_=groups[0], to=expected_groups, fill_value=fill_value
+                result, from_=to_numpy(groups[0]), to=expected_groups, fill_value=fill_value
             ).reshape(result.shape[:-1] + grp_shape)
             groups = final_groups

diff --git a/flox/xrutils.py b/flox/xrutils.py
index 497cd7b24..c5a9a9952 100644
--- a/flox/xrutils.py
+++ b/flox/xrutils.py
@@ -321,6 +321,20 @@ def nanlast(values, axis, keepdims=False):
         return result


+def to_numpy(a):
+    a_np = a
+    if is_duck_dask_array(a_np):
+        a_np = a_np.compute()
+
+    if module_available("cupy"):
+        import cupy as cp
+
+        cp_types = (cp.ndarray,)
+        if isinstance(a_np, cp_types):
+            a_np = a_np.get()
+    return a_np
+
+
 def module_available(module: str, minversion: Optional[str] = None) -> bool:
     """Checks whether a module is installed without importing it.

diff --git a/pyproject.toml b/pyproject.toml
index 9cf0b4ca4..0a7b1dc46 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ changelog = "https://github.com/xarray-contrib/flox/releases"

 [project.optional-dependencies]
 all = ["cachey", "dask", "numba", "numbagg", "xarray"]
+cupy = ["cupy>=12.1"]
 test = ["netCDF4"]

 [build-system]
@@ -117,6 +118,7 @@ module=[
     "asv_runner.*",
     "cachey",
     "cftime",
+    "cupy",
     "dask.*",
     "importlib_metadata",
     "numba",
diff --git a/tests/__init__.py b/tests/__init__.py
index 4af6bc87e..7e216e615 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -24,6 +24,13 @@
 except ImportError:
     xr_types = ()  # type: ignore[assignment]

+try:
+    import cupy as cp
+
+    cp_types = (cp.ndarray,)
+except ImportError:
+    cp_types = ()  # type: ignore[assignment]
+

 def _importorskip(modname, minversion=None):
     try:
@@ -78,6 +85,15 @@ def raise_if_dask_computes(max_computes=0):
     return dask.config.set(scheduler=scheduler)


+def to_numpy(a):
+    a_np = a
+    if isinstance(a_np, dask_array_type):
+        a_np = a_np.compute()
+    if isinstance(a_np, cp_types):
+        a_np = a_np.get()
+    return a_np
+
+
 def assert_equal(a, b, tolerance=None):
     __tracebackhide__ = True

@@ -100,16 +116,20 @@ def assert_equal(a, b, tolerance=None):
     else:
         tolerance = {}

-    if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
+    if has_dask and (isinstance(a, dask_array_type) or isinstance(b, dask_array_type)):
         # sometimes it's nice to see values and shapes
         # rather than being dropped into some file in dask
-        np.testing.assert_allclose(a, b, **tolerance)
+        np.testing.assert_allclose(to_numpy(a), to_numpy(b), **tolerance)
         # does some validation of the dask graph
         da.utils.assert_eq(a, b, equal_nan=True)
     else:
         if a.dtype != b.dtype:
             raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")

+        if isinstance(a, cp_types):
+            a = a.get()
+        if isinstance(b, cp_types):
+            b = b.get()
         np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)

diff --git a/tests/conftest.py b/tests/conftest.py
index 504564b5a..078f9a437 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -14,3 +14,18 @@
 )
 def engine(request):
     return request.param
+
+
+@pytest.fixture(scope="module", params=["numpy", "cupy"])
+def array_module(request):
+    if request.param == "cupy":
+        try:
+            import cupy  # noqa
+
+            return cupy
+        except ImportError:
+            pytest.xfail()
+    elif request.param == "numpy":
+        import numpy
+
+        return numpy
diff --git a/tests/test_core.py b/tests/test_core.py
index 99f181255..21a3c4317 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -200,14 +200,34 @@ def test_groupby_reduce(
     assert_equal(expected_result, result)


-def gen_array_by(size, func):
-    by = np.ones(size[-1])
-    rng = np.random.default_rng(12345)
+def maybe_skip_cupy(array_module, func, engine):
+    if array_module is np:
+        return
+
+    import cupy
+
+    assert array_module is cupy
+
+    if engine == "numba":
+        pytest.skip()
+
+    if engine == "numpy" and ("prod" in func or "first" in func or "last" in func):
+        pytest.xfail()
+    elif engine == "flox" and not (
+        "sum" in func or "mean" in func or "std" in func or "var" in func
+    ):
+        pytest.xfail()
+
+
+def gen_array_by(size, func, array_module):
+    xp = array_module
+    by = xp.ones(size[-1])
+    rng = xp.random.default_rng(12345)
     array = rng.random(size)
     if "nan" in func and "nanarg" not in func:
-        array[[1, 4, 5], ...] = np.nan
+        array[[1, 4, 5], ...] = xp.nan
     elif "nanarg" in func and len(size) > 1:
-        array[[1, 4, 5], 1] = np.nan
+        array[[1, 4, 5], 1] = xp.nan
     if func in ["any", "all"]:
         array = array > 0.5
     return array, by
@@ -224,15 +244,17 @@ def gen_array_by(size, func):
 )
 @pytest.mark.parametrize("nby", [1, 2, 3])
 @pytest.mark.parametrize("size", ((12,), (12, 9)))
-@pytest.mark.parametrize("add_nan_by", [True, False])
 @pytest.mark.parametrize("func", ALL_FUNCS)
-def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
+@pytest.mark.parametrize("add_nan_by", [True, False])
+def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine, array_module):
     if ("arg" in func and engine in ["flox", "numbagg"]) or (
         func in BLOCKWISE_FUNCS and chunks != -1
     ):
         pytest.skip()

-    array, by = gen_array_by(size, func)
+    maybe_skip_cupy(array_module, func, engine)
+
+    array, by = gen_array_by(size, func, array_module)
     if chunks:
         array = dask.array.from_array(array, chunks=chunks)
     by = (by,) * nby
@@ -289,10 +311,12 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         assert expected.ndim == (array.ndim + nby - 1)
         expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
         for actual_group, expect in zip(groups, expected_groups):
-            assert_equal(actual_group, expect)
+            assert_equal(actual_group, array_module.asarray(expect))
         if "arg" in func:
            assert actual.dtype.kind == "i"
-        assert_equal(actual, expected, tolerance)
+        if chunks is not None:
+            assert isinstance(actual._meta, type(array._meta))
+        assert_equal(actual, array_module.asarray(expected), tolerance)

         if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
             continue
@@ -322,6 +346,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 assert_equal(actual_group, expect, tolerance)
             if "arg" in func:
                 assert actual.dtype.kind == "i"
+            if chunks is not None:
+                assert isinstance(actual._meta, type(array._meta))
             assert_equal(actual, expected, tolerance)


@@ -348,18 +374,18 @@ def test_arg_reduction_dtype_is_int(size, func):
     assert actual.dtype.kind == "i"


-def test_groupby_reduce_count():
-    array = np.array([0, 0, np.nan, np.nan, np.nan, 1, 1])
-    labels = np.array(["a", "b", "b", "b", "c", "c", "c"])
+def test_groupby_reduce_count(array_module):
+    array = array_module.array([0, 0, np.nan, np.nan, np.nan, 1, 1])
+    labels = array_module.array(["a", "b", "b", "b", "c", "c", "c"])
     result, _ = groupby_reduce(array, labels, func="count")
     assert_equal(result, np.array([1, 1, 2], dtype=np.intp))


-def test_func_is_aggregation():
+def test_func_is_aggregation(array_module):
     from flox.aggregations import mean

-    array = np.array([0, 0, np.nan, np.nan, np.nan, 1, 1])
-    labels = np.array(["a", "b", "b", "b", "c", "c", "c"])
+    array = array_module.array([0, 0, np.nan, np.nan, np.nan, 1, 1])
+    labels = array_module.array(["a", "b", "b", "b", "c", "c", "c"])
     expected, _ = groupby_reduce(array, labels, func="mean")
     actual, _ = groupby_reduce(array, labels, func=mean)
     assert_equal(actual, expected)
@@ -821,8 +847,8 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
         [(10,), (10,)],
     ],
 )
-def test_rechunk_for_blockwise(inchunks, expected):
-    labels = np.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
+def test_rechunk_for_blockwise(inchunks, expected, array_module):
+    labels = array_module.array([1, 1, 1, 2, 2, 3, 3, 5, 5, 5])
     assert _get_optimal_chunks_for_groups(inchunks, labels) == expected
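Usage sketch (not part of the diff above; the array shapes and values are illustrative): with the optional cupy dependency installed, the engine="flox" path touched by these changes can be exercised directly through groupby_reduce. The reduction is expected to stay on the GPU, while the new to_numpy helper moves only the small group-label metadata to the host where pandas needs it.

    # Hypothetical example, not taken from the repository's tests.
    import cupy as cp
    import numpy as np

    from flox.core import groupby_reduce

    # 12 elements along the grouped axis, 4 groups of 3 labels each
    array = cp.asarray(np.random.default_rng(0).random((4, 12)))
    labels = cp.asarray(np.repeat(np.arange(4), 3))

    # "mean" is one of the aggregations the cupy + engine="flox" path covers per the tests above
    result, groups = groupby_reduce(array, labels, func="mean", engine="flox")
    print(type(result))  # expected to be cupy.ndarray when the cupy path is taken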