Generalized (arg)min, (arg)max: add nsmallest, nlargest, arg_nsmallest, arg_nlargest #10075

Open
Huite opened this issue Feb 24, 2025 · 1 comment · May be fixed by #10086
Comments

@Huite (Contributor) commented Feb 24, 2025

Is your feature request related to a problem?

With some regularity, I need the N largest or N smallest values along some dimension, or the indices of those values.

Describe the solution you'd like

Pandas provides nsmallest and nlargest:
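
For instance, on a Series (a small illustrative example):

    import pandas as pd

    s = pd.Series([3.0, 1.0, 4.0, 1.5, 9.0])
    s.nlargest(2)   # 9.0 and 4.0, keeping their index labels
    s.nsmallest(2)  # 1.0 and 1.5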

Something similar would be useful for Xarray, I reckon. And just as argmin and argmax exist next to min and max, having arg_nsmallest and arg_nlargest (or something along those lines) would be convenient as well.

It could match the existing method signatures, requiring an extra n argument:

    def nlargest(
        self,
        n: int,
        dim: Dims = None,
        *,
        skipna: bool | None = None,
        keep_attrs: bool | None = None,
        **kwargs: Any,
    ) -> Self:
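
With such a signature, usage could look something like this (hypothetical, mirroring min/max and argmin/argmax):

    # Hypothetical usage: the three largest values along "time",
    # and the corresponding integer indices.
    top3 = da.nlargest(3, dim="time")
    top3_idx = da.arg_nlargest(3, dim="time")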

The basic idea is to wrap numpy's or bottleneck's argpartition. I currently use this quick-and-dirty utility for a DataArray and a single dimension:

    import dask.array
    import numpy as np
    import xarray as xr


    def arg_nsmallest(da: xr.DataArray, dim: str, n: int) -> xr.DataArray:
        """
        Return the indices of the ``n`` smallest values along dimension ``dim``.

        Parameters
        ----------
        da: xr.DataArray
        dim: str
            Dimension over which to find the ``n`` smallest values.
        n: int
            The number of items to retrieve.

        Returns
        -------
        result: xr.DataArray
        """
        # Find the axis over which to apply the partition.
        axis = da.dims.index(dim)

        # Set up output coordinates: the partitioned dimension gets a new index of length n.
        dim_index = np.arange(n)
        coords = da.coords.copy()
        coords[dim] = dim_index
        shape = list(da.shape)
        shape[axis] = n
        template = xr.DataArray(
            data=dask.array.zeros(shape, dtype=int),
            coords=coords,
            dims=da.dims,
        )

        def _nsmallest(block: xr.DataArray) -> xr.DataArray:
            # NOTE: numpy (arg)partition moves NaNs to the back;
            # bottleneck partition does not!
            smallest = np.argpartition(block.to_numpy(), kth=n, axis=axis)
            return template.copy(data=np.take(smallest, indices=np.arange(n), axis=axis))

        return xr.map_blocks(_nsmallest, da, template=template)
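
For reference, a usage sketch with made-up data (the helper above assumes a dask-backed DataArray; keeping it in a single chunk keeps the template shapes aligned):

    da = xr.DataArray(
        np.random.rand(4, 100), dims=("station", "time")
    ).chunk({"station": -1, "time": -1})
    # Indices of the three smallest values along "time", per station.
    idx = arg_nsmallest(da, dim="time", n=3).compute()
    # Vectorized indexing gathers the corresponding values.
    smallest = da.compute().isel(time=idx)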

Describe alternatives you've considered

In principle, the same result can be achieved with e.g. xarray's argsort, but a full sort is much more costly when only, say, the three highest or lowest values are required. argsort also doesn't accept a dimension argument and isn't NaN-aware. Of the two, nsmallest is more straightforward to implement, since nlargest is complicated by the NaNs being moved to the end.
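
For comparison, the sort-based route looks roughly like this (illustrative sketch using numpy directly, since the injected argsort method takes an axis rather than a dimension):

    def nsmallest_via_sort(da: xr.DataArray, dim: str, n: int) -> xr.DataArray:
        # Sorts the full axis even though only ``n`` values are needed;
        # np.sort pushes NaNs to the end, so the first ``n`` entries are
        # the ``n`` smallest non-NaN values.
        axis = da.get_axis_num(dim)
        values = np.sort(da.to_numpy(), axis=axis)
        return xr.DataArray(
            np.take(values, np.arange(n), axis=axis),
            dims=da.dims,
            coords={k: v for k, v in da.coords.items() if dim not in v.dims},
        )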

Additional context

No response

@dcherian (Contributor) commented:

Agree that a wrapper around partition/argpartition would be useful. Note that the Python array community calls this topk (data-apis/array-api#629), so we should probably follow that.

At least Dask provides topk and argtopk so we need not use map_blocks.
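
A rough sketch of what wrapping it could look like (names and details hypothetical):

    import dask.array
    import xarray as xr

    def dask_argtopk(obj: xr.DataArray, dim: str, k: int) -> xr.DataArray:
        # dask.array.argtopk returns indices of the k largest elements along
        # the given axis (or of the -k smallest elements when k is negative).
        axis = obj.get_axis_num(dim)
        indices = dask.array.argtopk(dask.array.asarray(obj.data), k, axis=axis)
        # Coordinates along ``dim`` no longer line up with the reduced axis.
        coords = {name: c for name, c in obj.coords.items() if dim not in c.dims}
        return xr.DataArray(indices, dims=obj.dims, coords=coords)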

As an aside, that naming sadly clashes with finding the top-k most frequent elements in a stream (example) :/

Huite linked a pull request on Feb 28, 2025 that will close this issue.