Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

topk and argtopk #10086

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft

topk and argtopk #10086

wants to merge 3 commits into from

Conversation

Huite
Copy link
Contributor

@Huite Huite commented Feb 28, 2025

I've made a start with implementations of topk and argtopk.
This'll work uncontroversially for DataArrays with and without NaNs, although I'm getting mostly stuck on what skipna=True should entail.

There's a number of choices to make, however, which are probably best illustrated with this PR.

  • Numpy does not provide a topk implementation, dask does, as @dcherian helpfully pointed out. I've borrowed some parts of the dask implementation and put it in nputils.py. The question that arises here for me: I guess this would work with cupy too, but I don't quite oversee what's the best way to integrate it? Because dask provides topk of its own, it appears a bit exceptional.
  • Another question: bottleneck also provides partition and argpartition. This version works differs from numpy.partition in its handling of NaNs, however. Using numpy's partition has the benefit that it's consistent with dask.
  • In terms of API: topk feels mostly similar to quantile, since it shortens, but doesn't reduce the dimension entirely. argmin also supports an axis argument next to dim (though exclusive) -- is the axis argument desirable?
  • The code in variable.py borrows somewhat from quantile, since the result has an axis with dim k size instead of len(q). Unlike quantile, dask's topk and argtopk do not support tuple arguments for the axis (although topk accepts it and produces an unexpected result), so part of the stacking and unraveling functionality of _unravel_argminmax is required. I've currently duplicated the relevant lines to keep changes clearly visible.
  • I'm insufficiently familiar with apply_ufuncs to judge whether dask="allowed" works gracefully with the dask topk and argtopk functions, my guess is that it should.
  • quantile returns a result with a new dimension and coordinate called quantile, I've mimicked this and topk and argtopk return a result with a new topk or argtopk dimension respectively. I was thinking no labels are required for topk, but since both positive k values (for largest) and negative k values (for smallest) are possible, it's probably smart to return labels range(0, k) and range(-k, 0)?
  • There's also idxmin for the coordinate labels, and I suppose idxtopk would make sense too?
  • skipna=False is giving me some headaches. A (naive) implementation as in this PR is assymetric. Numpy partition (and thus the dask version too) sorts NaNs towards the end of the array, such that k > 0 will return NaNs, but k < 0 will not. For the testing, I figured da.topk(k=-1, skipna=False) should equal da.min(skipna=False) and da.topk(k=1, skipna=False), should equal da.max(), but this isn't the case. k=1 will return a NaN value since numpy partition moves the NaN to the end; k=-1 will not. I currently gravitate towards accepting this assymetry, since e.g. np.sort will also move NaNs to the back and it feels forced to fetch NaNs for k=-1 to match .min(skipna=False) . On the other hand, Python's sorted behaves differently, according to IEEE 754 NaNs are not orderable ... and I reckon you'd mostly use skipna=False when you want to ensure that no NaNs are present?
  • The NA-handling is currently in duck_array_ops, maybe it belong in nanops as it resembles _nan_argminmax_object and _nan_minmax_object, but is again slightly different. But I didn't like the circular imports that it seems to require; in duck_array_ops it decides whether to use dask or numpy (via nputils), but the masking of NaNs is required for both.

@Huite Huite marked this pull request as draft February 28, 2025 19:37
Copy link
Contributor

@dcherian dcherian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for working on this. I took a quick pass. I have one major suggestion on handling skipna and some minor suggestions for reducing scope so we can get this in earlier



def argtopk(values, k, axis=None, skipna=None):
if is_chunked_array(values):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_chunked_array(values):
if is_duck_dask_array(values):

@@ -6288,6 +6288,37 @@ def argmax(
else:
return self._replace_maybe_drop_dims(result)

def argtopk(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think generate_aggregations.py would make sense here so that we add it everywhere with the same docstring. That would get us groupby support for example, and I can eventually plug in flox when https://github.com/xarray-contrib/flox/pull/374/files is ready

def topk(values, k: int, axis: int):
"""Extract the k largest elements from a on the given axis.
If k is negative, extract the -k smallest elements instead.
The returned elements are sorted.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not sort. The user can do that if they really need to.

"""
# topk accepts only an integer axis like argmin or argmax,
# not tuples, so we need to stack multiple dimensions.
if dim is ... or dim is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. we have a infix_dims utility function for this.
  2. Can we punt the multiple dims case to later once someone asks for it? The stacking approach is bad with dask and will require some reshape_blockwise trickery, which isn't hard but we may as well do it in a followup (see polyfit for an example)


# Borrowed from nanops
xp = get_array_namespace(values)
if skipna or (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the way to do this is to use the fact that nans sort to the end. Then given k and count you know what to provide to partition.

See https://github.com/xarray-contrib/flox/blob/a5bcc5be642c0c0c825ccb536208a0b736d569e3/flox/aggregate_flox.py#L85-L92

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Generalized (arg)min, (arg)max: add nsmallest, nlargest, arg_nsmallest, arg_nlargest
2 participants