What is your issue?
xarray currently uses its own nanops.nansum when calling DataArray.sum(..., skipna=None), which relies on sum_where. This code path is very inefficient for sparse arrays, especially (and ironically) when operating on a sparse array with fill_value=np.nan; see pydata/sparse#908. Why doesn't xarray try to dispatch to a nansum implementation in the array's namespace, if one exists?
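For context, the where-based route can be sketched roughly like this (a simplified, numpy-only illustration; `nansum_via_where` is a hypothetical name, not xarray's actual nanops code):

```python
import numpy as np

def nansum_via_where(a, axis=None):
    # Hypothetical sketch of the where-based strategy: mask out NaNs,
    # substitute the additive identity, then reduce. For a sparse array
    # whose fill_value is NaN, the mask and the substituted copy touch
    # every (virtual) element, which is why this route is slow there.
    return np.where(np.isnan(a), 0, a).sum(axis=axis)

a = np.array([[1.0, np.nan], [2.0, 3.0]])
print(nansum_via_where(a, axis=1))  # -> [1. 5.]
```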
sparse offers its own nansum. Internally it also seems to use where, but it is much faster than xarray's nansum. I applied the following patch to duck_array_ops.py, which reduces the time for sums on a sparse array significantly:
```diff
--- duck_array_ops.py	2025-11-14 12:21:49
+++ duck_array_ops.py	2025-11-14 12:23:20
@@ -519,6 +519,15 @@
             nanname = "nan" + name
             func = getattr(nanops, nanname)
+
+            if "min_count" not in kwargs or kwargs["min_count"] is None:
+                try:
+                    kwargs.pop("min_count", None)
+                    xp = get_array_namespace(values)
+                    func = getattr(xp, name)
+                except AttributeError:
+                    pass
+
         else:
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)
```

Dispatching to sparse.nansum produces a more than 20x speedup:
```console
# Without patch
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim=['y', 'z'])"
1 loop, best of 5: 36.2 msec per loop

# With patch
$ python -m timeit -s "import sparse; import xarray as xr; import numpy as np; arr = xr.DataArray(sparse.random((100, 100, 100), density=0.1, fill_value=np.nan), dims=['x', 'y', 'z'])" "arr.sum(dim=['y', 'z'])"
200 loops, best of 5: 1.37 msec per loop
```
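The speedup hinges on the namespace lookup. A minimal standalone sketch of the pattern, using numpy as the stand-in namespace and a hypothetical `nan_agg` helper (the patch above uses xarray's `get_array_namespace` and only takes this route when `min_count` is unset):

```python
import sys
import numpy as np

def nan_agg(values, name="sum", axis=None):
    # Hypothetical helper illustrating the dispatch idea: prefer a
    # native nan-aware reduction from the array's own namespace
    # (sparse.nansum, numpy.nansum, ...) and only fall back to a
    # generic where-based implementation if the namespace lacks one.
    xp = sys.modules[type(values).__module__.split(".")[0]]
    try:
        func = getattr(xp, "nan" + name)
    except AttributeError:
        # Generic fallback, analogous to the where-based nanops route.
        func = lambda v, axis=None: np.where(np.isnan(v), 0, v).sum(axis=axis)
    return func(values, axis=axis)

print(nan_agg(np.array([1.0, np.nan, 2.0])))  # -> 3.0
```

For a `sparse.COO` input, `xp` would resolve to the `sparse` module and the reduction would hit `sparse.nansum` directly, which is where the benchmark's gain comes from.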