Add cupy support #134

Draft: wants to merge 17 commits into main
15 changes: 12 additions & 3 deletions flox/aggregate_flox.py
@@ -14,7 +14,9 @@ def _prepare_for_flox(group_idx, array):
     if issorted:
         ordered_array = array
     else:
-        perm = group_idx.argsort(kind="stable")
+        kind = "stable" if isinstance(group_idx, np.ndarray) else None
+
+        perm = np.argsort(group_idx, kind=kind)
         group_idx = group_idx[..., perm]
         ordered_array = array[..., perm]
     return group_idx, ordered_array
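Reviewer note: this hunk leans on NEP 18 dispatch. numpy's `argsort` accepts `kind="stable"`, while cupy only accepts `kind=None`, so the branch picks the kwarg per backend and lets the module-level `np.argsort` route to whichever library owns `group_idx`. A minimal sketch of the behaviour being assumed (the cupy branch only runs if cupy is installed):

```python
import numpy as np

group_idx = np.array([2, 0, 1, 0])
kind = "stable" if isinstance(group_idx, np.ndarray) else None
perm = np.argsort(group_idx, kind=kind)  # numpy path: stable sort

try:
    import cupy as cp

    # np.argsort on a cupy array dispatches to cupy via __array_function__;
    # cupy rejects any kind other than None, hence the branch above.
    perm_gpu = np.argsort(cp.asarray(group_idx), kind=None)
    assert isinstance(perm_gpu, cp.ndarray)
except ImportError:
    pass  # CPU-only environment: the numpy path above already ran
```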
@@ -25,7 +27,9 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
     most of this code is from shoyer's gist
     https://gist.github.com/shoyer/f538ac78ae904c936844
     """
-    # assumes input is sorted, which I do in core._prepare_for_flox
+    # For numpy arrays, assumes input is sorted, which I do in _prepare_for_flox
+    # For cupy arrays, sorting is not needed
+
     aux = group_idx
 
     flag = np.concatenate((np.array([True], like=array), aux[1:] != aux[:-1]))
@@ -38,7 +42,12 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
         dtype = array.dtype
 
     if out is None:
-        out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
+        out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype, like=array)
+
+    # if isinstance(array, cupy_array_type):
+    #     op = cupy_ops[op]
+    #     op(out, group_idx, array)
+    #     return out
 
     if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
         # The previous version of this if condition
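The `like=array` arguments in this file use NEP 35: array-creation functions (`np.full`, `np.array`, `np.arange`, ...) return an array from the same library as the reference passed to `like`. A self-contained sketch with a numpy reference; with a cupy reference the same calls would allocate on the GPU:

```python
import numpy as np

ref = np.arange(4)  # stand-in for `array`; could equally be a cupy array
out = np.full((3,), fill_value=0, dtype="f8", like=ref)
flag = np.concatenate((np.array([True], like=ref), ref[1:] != ref[:-1]))
assert type(out) is type(ref) and type(flag) is type(ref)
```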
79 changes: 55 additions & 24 deletions flox/core.py
@@ -33,7 +33,13 @@
     generic_aggregate,
 )
 from .cache import memoize
-from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
+from .xrutils import (
+    is_duck_array,
+    is_duck_dask_array,
+    isnull,
+    module_available,
+    to_numpy,
+)
 
 HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")

@@ -145,42 +151,51 @@ def _collapse_axis(arr: np.ndarray, naxis: int) -> np.ndarray:
 
 @memoize
 def _get_optimal_chunks_for_groups(chunks, labels):
-    chunkidx = np.cumsum(chunks) - 1
+    chunks_array = np.asarray(chunks, like=labels)
+    chunkidx = np.cumsum(chunks_array) - 1
     # what are the groups at chunk boundaries
     labels_at_chunk_bounds = _unique(labels[chunkidx])
     # what's the last index of all groups
-    last_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="last")
+    last_indexes = npg.aggregate_numpy.aggregate(
+        labels, np.arange(len(labels), like=labels), func="last"
+    )
     # what's the last index of groups at the chunk boundaries.
     lastidx = last_indexes[labels_at_chunk_bounds]
 
     if len(chunkidx) == len(lastidx) and (chunkidx == lastidx).all():
         return chunks
 
-    first_indexes = npg.aggregate_numpy.aggregate(labels, np.arange(len(labels)), func="first")
+    first_indexes = npg.aggregate_numpy.aggregate(
+        labels, np.arange(len(labels), like=labels), func="first"
+    )
     firstidx = first_indexes[labels_at_chunk_bounds]
 
-    newchunkidx = [0]
+    newchunkidx = np.array([0], like=labels)
     for c, f, l in zip(chunkidx, firstidx, lastidx):  # noqa
         Δf = abs(c - f)
         Δl = abs(c - l)
         if c == 0 or newchunkidx[-1] > l:
             continue
         if Δf < Δl and f > newchunkidx[-1]:
-            newchunkidx.append(f)
+            newchunkidx = np.append(newchunkidx, f)
         else:
-            newchunkidx.append(l + 1)
+            newchunkidx = np.append(newchunkidx, l + 1)
     if newchunkidx[-1] != chunkidx[-1] + 1:
-        newchunkidx.append(chunkidx[-1] + 1)
+        newchunkidx = np.append(newchunkidx, chunkidx[-1] + 1)
     newchunks = np.diff(newchunkidx)
 
     assert sum(newchunks) == sum(chunks)
-    return tuple(newchunks)
+    # workaround cupy bug with tuple(array)
+    return tuple(newchunks.tolist())
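A sketch of the pattern this rewrite follows: build the index arrays on the same backend as `labels` via `like=`, do the arithmetic there, and only drop to host Python ints at the very end (`labels` is a numpy stand-in here; with cupy labels the intermediates stay on the GPU):

```python
import numpy as np

labels = np.array([0, 0, 1, 1, 2])            # stand-in group labels
chunks_array = np.asarray((2, 3), like=labels)
chunkidx = np.cumsum(chunks_array) - 1        # chunk boundaries -> [1, 4]
newchunkidx = np.append(np.array([0], like=labels), chunkidx + 1)
print(tuple(np.diff(newchunkidx).tolist()))   # (2, 3) as plain Python ints
```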


-def _unique(a: np.ndarray) -> np.ndarray:
+def _unique(a):
     """Much faster to use pandas unique and sort the results.
     np.unique sorts before uniquifying and is slow."""
-    return np.sort(pd.unique(a.reshape(-1)))
+    if isinstance(a, np.ndarray):
+        return np.sort(pd.unique(a.reshape(-1)))
+    else:
+        return np.unique(a.reshape(-1))
 
 
 @memoize
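`pd.unique` only handles host (numpy) data, so device arrays fall back to `np.unique`, which dispatches to cupy and already returns sorted uniques. The two branches agree on numpy input, as this sketch checks:

```python
import numpy as np
import pandas as pd

a = np.array([[3, 1], [3, 2]])
fast = np.sort(pd.unique(a.reshape(-1)))   # numpy-only fast path
slow = np.unique(a.reshape(-1))            # backend-agnostic fallback
np.testing.assert_array_equal(fast, slow)  # both give [1 2 3]
```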
@@ -210,7 +225,9 @@ def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
     import dask
 
     # To do this, we must have values in memory so casting to numpy should be safe
-    labels = np.asarray(labels)
+    if not is_duck_array(labels):
+        labels = np.asarray(labels)
+    labels = to_numpy(labels)
 
     # Build an array with the shape of labels, but where every element is the "chunk number"
     # 1. First subset the array appropriately
@@ -433,7 +450,7 @@ def reindex_(
         reindexed = np.full(array.shape[:-1] + (len(to),), fill_value, dtype=array.dtype)
         return reindexed
 
-    from_ = pd.Index(from_)
+    from_ = pd.Index(to_numpy(from_))
     # short-circuit for trivial case
     if from_.equals(to):
         return array
@@ -546,7 +563,7 @@ def factorize_(
             # this is important in shared-memory parallelism with dask
             # TODO: figure out how to avoid this
             idx = flat.copy()
-            found_groups.append(np.array(expect))
+            found_groups.append(np.array(expect, like=flat, copy=False))
             # TODO: fix by using masked integers
             idx[idx > expect[-1]] = -1

@@ -561,7 +578,11 @@
             right = expect.closed_right
             idx = np.digitize(
                 flat,
-                bins=bins.view(np.int64) if bins.dtype.kind == "M" else bins,
+                bins=np.array(
+                    bins.view(np.int64) if bins.dtype.kind == "M" else bins,
+                    like=flat,
+                    copy=False,
+                ),
                 right=right,
             )
             idx -= 1
@@ -574,7 +595,7 @@
         else:
             if expect is not None and reindex:
                 sorter = np.argsort(expect)
-                groups = expect[(sorter,)] if sort else expect
+                groups = np.array(expect[(sorter,)]) if sort else expect
                 idx = np.searchsorted(expect, flat, sorter=sorter)
                 mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
                 if not sort:
@@ -584,9 +605,16 @@
                     idx = sorter[(idx,)]
                 idx[mask] = -1
             else:
-                idx, groups = pd.factorize(flat, sort=sort)  # type: ignore[arg-type]
+                if isinstance(flat, np.ndarray):
+                    idx, groups = pd.factorize(flat, sort=sort)  # type: ignore[call-overload]
+                    groups = np.array(groups)
+                else:
+                    assert sort
+                    groups, idx = np.unique(flat, return_inverse=True)
+                    idx[np.isnan(flat)] = -1
+                    groups = groups[~np.isnan(groups)]  # type: ignore[call-overload,index]
 
-        found_groups.append(np.array(groups))
+        found_groups.append(groups)  # type: ignore[arg-type]
         factorized.append(idx.reshape(groupvar.shape))
 
     grp_shape = tuple(len(grp) for grp in found_groups)
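The new non-numpy branch substitutes `np.unique(..., return_inverse=True)` for `pd.factorize` (hence the `assert sort`: the result is always sorted) and then patches up NaNs by hand. A host-side sketch of that logic; the same calls dispatch to cupy for device input:

```python
import numpy as np

flat = np.array([3.0, 1.0, 3.0, np.nan])
groups, idx = np.unique(flat, return_inverse=True)  # NaN sorts last
idx[np.isnan(flat)] = -1            # NaN entries must not join a group
groups = groups[~np.isnan(groups)]  # drop the NaN "group" itself
print(groups, idx)                  # [1. 3.] [ 1  0  1 -1]
```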
@@ -945,7 +973,10 @@ def _find_unique_groups(x_chunk) -> np.ndarray:
     from dask.base import flatten
     from dask.utils import deepmap
 
-    unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
+    tup = tuple(flatten(deepmap(listify_groups, x_chunk)))
+    # passing like=None raises. Seems like a bug
+    kwargs = dict(like=tup[0]) if is_duck_array(tup[0]) else {}
+    unique_groups = _unique(np.asarray(tup, **kwargs))
     unique_groups = unique_groups[~isnull(unique_groups)]
 
     if len(unique_groups) == 0:
@@ -1017,12 +1048,11 @@ def _conc2(x_chunk, key1, key2=slice(None), axis: T_Axes | None = None) -> np.nd
 
 
 def reindex_intermediates(x: IntermediateDict, agg: Aggregation, unique_groups) -> IntermediateDict:
+    to = pd.Index(to_numpy(unique_groups))
     new_shape = x["groups"].shape[:-1] + (len(unique_groups),)
     newx: IntermediateDict = {"groups": np.broadcast_to(unique_groups, new_shape)}
     newx["intermediates"] = tuple(
-        reindex_(
-            v, from_=np.atleast_1d(x["groups"].squeeze()), to=pd.Index(unique_groups), fill_value=f
-        )
+        reindex_(v, from_=to_numpy(np.atleast_1d(x["groups"].squeeze())), to=to, fill_value=f)
         for v, f in zip(x["intermediates"], agg.fill_value["intermediate"])
     )
     return newx
@@ -1280,7 +1310,7 @@ def subset_to_blocks(
     layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
 
-    return dask.array.Array(graph, name, chunks, meta=array)
+    return dask.array.Array(graph, name, chunks, meta=array._meta)
 
 
 def _extract_unknown_groups(reduced, dtype) -> tuple[DaskArray]:
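The `meta` fix matters because dask expects `meta` to be a zero-sized array of the block type, not the dask collection itself; for cupy-backed dask arrays, `array._meta` is an empty `cupy.ndarray`, which keeps the rebuilt Array GPU-backed. A quick illustration on the numpy default:

```python
import dask.array as da

x = da.ones((4,), chunks=2)
print(type(x._meta), x._meta.shape)  # <class 'numpy.ndarray'> (0,)
```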
@@ -1521,6 +1551,7 @@ def dask_groupby_agg(
         reduced,
         inds,
         adjust_chunks=dict(zip(out_inds, output_chunks)),
+        meta=array._meta,
         dtype=agg.dtype["final"],
         key=agg.name,
         name=f"{name}-{token}",
@@ -2126,7 +2157,7 @@ def groupby_reduce(
         # now we get rid of them by reindexing
         # This also handles bins with no data
         result = reindex_(
-            result, from_=groups[0], to=expected_groups, fill_value=fill_value
+            result, from_=to_numpy(groups[0]), to=expected_groups, fill_value=fill_value
         ).reshape(result.shape[:-1] + grp_shape)
         groups = final_groups

14 changes: 14 additions & 0 deletions flox/xrutils.py
@@ -321,6 +321,20 @@ def nanlast(values, axis, keepdims=False):
     return result
 
 
+def to_numpy(a):
+    a_np = a
+    if is_duck_dask_array(a_np):
+        a_np = a_np.compute()
+
+    if module_available("cupy"):
+        import cupy as cp
+
+        cp_types = (cp.ndarray,)
+        if isinstance(a_np, cp_types):
+            a_np = a_np.get()
+    return a_np
+
+
 def module_available(module: str, minversion: Optional[str] = None) -> bool:
     """Checks whether a module is installed without importing it.
 
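Expected behaviour of the new helper, sketched under the assumption that dask and cupy are optional dependencies: numpy input passes through, dask input is computed, and cupy input is copied back with `.get()`:

```python
import numpy as np
from flox.xrutils import to_numpy

a = np.arange(3)
assert to_numpy(a) is a  # numpy passes through untouched
# to_numpy(dask_array) computes to its backing array first;
# to_numpy(cupy_ndarray) copies device memory to the host via .get()
```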
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -34,6 +34,7 @@ changelog = "https://github.com/xarray-contrib/flox/releases"
 
 [project.optional-dependencies]
 all = ["cachey", "dask", "numba", "numbagg", "xarray"]
+cupy = ["cupy>=12.1"]
 test = ["netCDF4"]
 
 [build-system]
@@ -117,6 +118,7 @@ module=[
     "asv_runner.*",
     "cachey",
     "cftime",
+    "cupy",
     "dask.*",
     "importlib_metadata",
     "numba",
24 changes: 22 additions & 2 deletions tests/__init__.py
@@ -24,6 +24,13 @@
 except ImportError:
     xr_types = ()  # type: ignore[assignment]
 
+try:
+    import cupy as cp
+
+    cp_types = (cp.ndarray,)
+except ImportError:
+    cp_types = ()  # type: ignore[assignment]
+
 
 def _importorskip(modname, minversion=None):
     try:
@@ -78,6 +85,15 @@ def raise_if_dask_computes(max_computes=0):
     return dask.config.set(scheduler=scheduler)
 
 
+def to_numpy(a):
+    a_np = a
+    if isinstance(a_np, dask_array_type):
+        a_np = a_np.compute()
+    if isinstance(a_np, cp_types):
+        a_np = a_np.get()
+    return a_np
+
+
 def assert_equal(a, b, tolerance=None):
     __tracebackhide__ = True
 
@@ -100,16 +116,20 @@ def assert_equal(a, b, tolerance=None):
     else:
         tolerance = {}
 
-    if has_dask and isinstance(a, dask_array_type) or isinstance(b, dask_array_type):
+    if has_dask and (isinstance(a, dask_array_type) or isinstance(b, dask_array_type)):
         # sometimes it's nice to see values and shapes
         # rather than being dropped into some file in dask
-        np.testing.assert_allclose(a, b, **tolerance)
+        np.testing.assert_allclose(to_numpy(a), to_numpy(b), **tolerance)
         # does some validation of the dask graph
        da.utils.assert_eq(a, b, equal_nan=True)
     else:
         if a.dtype != b.dtype:
             raise AssertionError(f"a and b have different dtypes: (a: {a.dtype}, b: {b.dtype})")
 
+        if isinstance(a, cp_types):
+            a = a.get()
+        if isinstance(b, cp_types):
+            b = b.get()
         np.testing.assert_allclose(a, b, equal_nan=True, **tolerance)

15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -14,3 +14,18 @@
 )
 def engine(request):
     return request.param
+
+
+@pytest.fixture(scope="module", params=["numpy", "cupy"])
+def array_module(request):
+    if request.param == "cupy":
+        try:
+            import cupy  # noqa
+
+            return cupy
+        except ImportError:
+            pytest.xfail()
+    elif request.param == "numpy":
+        import numpy
+
+        return numpy
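A sketch of how a test might consume the fixture (the test name and body are illustrative, not part of this diff); the parametrization runs it once per backend and xfails the cupy case when cupy is not installed:

```python
def test_sum_roundtrip(array_module):
    arr = array_module.arange(6).reshape(2, 3)
    assert int(arr.sum()) == 15  # same code path for numpy and cupy
```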