-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
topk and argtopk #10086
base: main
Are you sure you want to change the base?
topk and argtopk #10086
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for working on this. I took a quick pass. I have one major suggestion on handling skipna
and some minor suggestions for reducing scope so we can get this in earlier
|
||
|
||
def argtopk(values, k, axis=None, skipna=None): | ||
if is_chunked_array(values): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if is_chunked_array(values): | |
if is_duck_dask_array(values): |
@@ -6288,6 +6288,37 @@ def argmax( | |||
else: | |||
return self._replace_maybe_drop_dims(result) | |||
|
|||
def argtopk( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think generate_aggregations.py
would make sense here so that we add it everywhere with the same docstring. That would get us groupby support for example, and I can eventually plug in flox when https://github.com/xarray-contrib/flox/pull/374/files is ready
def topk(values, k: int, axis: int): | ||
"""Extract the k largest elements from a on the given axis. | ||
If k is negative, extract the -k smallest elements instead. | ||
The returned elements are sorted. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not sort. The user can do that if they really need to.
""" | ||
# topk accepts only an integer axis like argmin or argmax, | ||
# not tuples, so we need to stack multiple dimensions. | ||
if dim is ... or dim is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- we have a
infix_dims
utility function for this. - Can we punt the multiple dims case to later once someone asks for it? The stacking approach is bad with dask and will require some
reshape_blockwise
trickery, which isn't hard but we may as well do it in a followup (seepolyfit
for an example)
|
||
# Borrowed from nanops | ||
xp = get_array_namespace(values) | ||
if skipna or ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the way to do this is to use the fact that nan
s sort to the end. Then given k
and count
you know what to provide to partition.
whats-new.rst
api.rst
I've made a start with implementations of
topk
andargtopk
.This'll work uncontroversially for DataArrays with and without NaNs, although I'm getting mostly stuck on what
skipna=True
should entail.There's a number of choices to make, however, which are probably best illustrated with this PR.
nputils.py
. The question that arises here for me: I guess this would work withcupy
too, but I don't quite oversee what's the best way to integrate it? Because dask providestopk
of its own, it appears a bit exceptional.partition
andargpartition
. This version works differs fromnumpy.partition
in its handling of NaNs, however. Using numpy's partition has the benefit that it's consistent with dask.topk
feels mostly similar to quantile, since it shortens, but doesn't reduce the dimension entirely. argmin also supports anaxis
argument next todim
(though exclusive) -- is the axis argument desirable?variable.py
borrows somewhat fromquantile
, since the result has an axis with dimk
size instead oflen(q)
. Unlikequantile
, dask'stopk
andargtopk
do not support tuple arguments for the axis (althoughtopk
accepts it and produces an unexpected result), so part of the stacking and unraveling functionality of_unravel_argminmax
is required. I've currently duplicated the relevant lines to keep changes clearly visible.apply_ufuncs
to judge whetherdask="allowed"
works gracefully with the dasktopk
andargtopk
functions, my guess is that it should.quantile
returns a result with a new dimension and coordinate calledquantile
, I've mimicked this andtopk
andargtopk
return a result with a newtopk
orargtopk
dimension respectively. I was thinking no labels are required fortopk
, but since both positive k values (for largest) and negative k values (for smallest) are possible, it's probably smart to return labelsrange(0, k)
andrange(-k, 0)
?idxtopk
would make sense too?skipna=False
is giving me some headaches. A (naive) implementation as in this PR is assymetric. Numpy partition (and thus the dask version too) sorts NaNs towards the end of the array, such that k > 0 will return NaNs, but k < 0 will not. For the testing, I figuredda.topk(k=-1, skipna=False)
should equalda.min(skipna=False)
andda.topk(k=1, skipna=False)
, should equalda.max()
, but this isn't the case. k=1 will return a NaN value since numpy partition moves the NaN to the end; k=-1 will not. I currently gravitate towards accepting this assymetry, since e.g.np.sort
will also move NaNs to the back and it feels forced to fetch NaNs for k=-1 to match.min(skipna=False)
. On the other hand, Python'ssorted
behaves differently, according to IEEE 754 NaNs are not orderable ... and I reckon you'd mostly useskipna=False
when you want to ensure that no NaNs are present?duck_array_ops
, maybe it belong innanops
as it resembles_nan_argminmax_object
and_nan_minmax_object
, but is again slightly different. But I didn't like the circular imports that it seems to require; in duck_array_ops it decides whether to use dask or numpy (via nputils), but the masking of NaNs is required for both.