Allow for All-NaN in argmax, argmin #3884

Open
christian-oreilly opened this issue Mar 24, 2020 · 6 comments

Comments

@christian-oreilly

christian-oreilly commented Mar 24, 2020

Background: In data analysis, multidimensional datasets with missing conditions are a common occurrence. For example, given a data array of power measurements from a multi-channel recording device with dimensions nb_channel X nb_subjects X [...], it is common for some channels to be missing for some subjects, in which case the array contains only NaN for that condition. Functions like Dataset.mean() handle this situation well: they emit a "RuntimeWarning: Mean of empty slice" and set the mean value of the all-NaN slice to NaN, which is the expected result for such a use case. Depending on the use case, the user can filter these warnings to either ignore them or raise them as errors. This is all fine.

Problem: In the case of Dataset.argmax(), however, there is no such option. The function raises a "ValueError: All-NaN slice encountered" exception. I think it would be better if the behaviour of Dataset.argmax() were modelled on that of Dataset.mean(): raise a warning and return NaN. It seems fair to consider that the index maximizing the value of an all-NaN slice is NaN.

The implementation of such a feature may (but need not) depend on numpy/numpy#12352

MCVE Code Sample

import numpy as np
import xarray as xr

dat = xr.DataArray(np.ones((3, 2)), coords=dict(x=[0.1, 0.2, 0.3], y=[1, 2]), dims=["x", "y"])
dat[:, 1] = np.nan
dat.mean("x")
dat.argmax("x")

output

<xarray.DataArray (y: 2)>
array([ 1., nan])
Coordinates:
  * y        (y) int64 1 2

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-56-b53e860d1e36> in <module>
      2 dat[:, 1] = np.nan
      3 print(dat.mean("x"))
----> 4 print(dat.argmax("x"))

~/anaconda3/lib/python3.7/site-packages/xarray/core/common.py in wrapped_func(self, dim, axis, skipna, **kwargs)
     44 
     45             def wrapped_func(self, dim=None, axis=None, skipna=None, **kwargs):
---> 46                 return self.reduce(func, dim, axis, skipna=skipna, **kwargs)
     47 
     48         else:

~/anaconda3/lib/python3.7/site-packages/xarray/core/dataarray.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   2233         """
   2234 
-> 2235         var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   2236         return self._replace_maybe_drop_dims(var)
   2237 

~/anaconda3/lib/python3.7/site-packages/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, allow_lazy, **kwargs)
   1533 
   1534         if axis is not None:
-> 1535             data = func(input_data, axis=axis, **kwargs)
   1536         else:
   1537             data = func(input_data, **kwargs)

~/anaconda3/lib/python3.7/site-packages/xarray/core/duck_array_ops.py in f(values, axis, skipna, **kwargs)
    305 
    306         try:
--> 307             return func(values, axis=axis, **kwargs)
    308         except AttributeError:
    309             if isinstance(values, dask_array_type):

~/anaconda3/lib/python3.7/site-packages/xarray/core/nanops.py in nanargmax(a, axis)
    105 
    106     module = dask_array if isinstance(a, dask_array_type) else nputils
--> 107     return module.nanargmax(a, axis=axis)
    108 
    109 

~/anaconda3/lib/python3.7/site-packages/xarray/core/nputils.py in f(values, axis, **kwargs)
    213             result = bn_func(values, axis=axis, **kwargs)
    214         else:
--> 215             result = getattr(npmodule, name)(values, axis=axis, **kwargs)
    216 
    217         return result

<__array_function__ internals> in nanargmax(*args, **kwargs)

~/anaconda3/lib/python3.7/site-packages/numpy/lib/nanfunctions.py in nanargmax(a, axis)
    549         mask = np.all(mask, axis=axis)
    550         if np.any(mask):
--> 551             raise ValueError("All-NaN slice encountered")
    552     return res
    553 

ValueError: All-NaN slice encountered

Expected Output

<xarray.DataArray (y: 2)>
array([ 1., nan])
Coordinates:
  * y        (y) int64 1 2
<xarray.DataArray (y: 2)>
array([0, nan])
Coordinates:
  * y        (y) int64 1 2

Versions

Output of `xr.show_versions()`

INSTALLED VERSIONS

commit: None
python: 3.7.4 (default, Aug 13 2019, 20:35:49)
[GCC 7.3.0]
python-bits: 64
OS: Linux
OS-release: 5.3.0-42-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_CA.UTF-8
LOCALE: en_CA.UTF-8
libhdf5: 1.10.4
libnetcdf: 4.6.3

xarray: 0.15.0
pandas: 1.0.1
numpy: 1.18.2
scipy: 1.4.1
netCDF4: 1.5.3
pydap: None
h5netcdf: None
h5py: 2.9.0
Nio: None
zarr: None
cftime: 1.0.4.2
nc_time_axis: None
PseudoNetCDF: None
rasterio: None
cfgrib: None
iris: None
bottleneck: None
dask: 2.11.0
distributed: None
matplotlib: 3.1.2
cartopy: None
seaborn: 0.9.0
numbagg: None
setuptools: 45.2.0.post20200210
pip: 20.0.2
conda: 4.8.2
pytest: None
IPython: 7.12.0
sphinx: 2.3.1

@max-sixty
Collaborator

I think this would be a reasonable change

@shoyer
Member

shoyer commented Mar 26, 2020

The main concern here is type stability. Normally the return value of argmax is an integer dtype array, but NaN isn't a valid integer :(

My suggestion would be to add an optional fill_value argument, similar to the API under discussion for idxmax in #3871. If fill_value is specified (e.g., with fill_value=np.nan or fill_value=-1) then missing values are returned with the fill value instead of raising an error.
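
For illustration, here is a sketch of how the proposed API might be called. The fill_value argument is hypothetical at this point (it is not part of xarray 0.15), so the calls are shown as comments:

import numpy as np
import xarray as xr

dat = xr.DataArray(np.ones((3, 2)), coords=dict(x=[0.1, 0.2, 0.3], y=[1, 2]), dims=["x", "y"])
dat[:, 1] = np.nan

# Hypothetical calls under the proposed API (not implemented at the time of writing):
# dat.argmax("x", fill_value=np.nan)  # float result; all-NaN slice along x -> NaN
# dat.argmax("x", fill_value=-1)      # integer result; all-NaN slice along x -> -1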

@toddrjen
Contributor

The problem I had when implementing idxmin and idxmax is that this behavior is defined by numpy, not by xarray, and bottleneck follows the same behavior, with xarray generally delegating the computation to one of these. So you would need to somehow work around the behavior of numpy in xarray or get a fix implemented both in numpy and bottleneck.

@shoyer
Member

shoyer commented Mar 26, 2020

NumPy implements nanargmax, including raising the error, in Python. It would be very doable to copy a modified version into xarray.
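
As a rough illustration of what such a modified copy could look like, assuming the fill_value approach discussed above (nanargmax_filled is a hypothetical helper, not an existing numpy or xarray function):

import numpy as np

def nanargmax_filled(a, axis=None, fill_value=None):
    # Sketch of a nanargmax variant modelled on numpy.lib.nanfunctions.nanargmax,
    # but returning fill_value for all-NaN slices instead of raising ValueError.
    a = np.asarray(a, dtype=float)
    mask = np.isnan(a)
    # Replace NaNs with -inf so argmax ignores them within partially-NaN slices.
    res = np.argmax(np.where(mask, -np.inf, a), axis=axis)
    all_nan = np.all(mask, axis=axis)
    if np.any(all_nan):
        if fill_value is None:
            raise ValueError("All-NaN slice encountered")
        # A NaN fill_value forces a float result dtype (the type-stability
        # trade-off noted above); an integer fill_value keeps the result integral.
        res = np.where(all_nan, fill_value, res)
    return res

# Example:
# nanargmax_filled([[1.0, np.nan], [np.nan, np.nan]], axis=1, fill_value=np.nan)
# -> array([ 0., nan])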

@toddrjen
Contributor

@shoyer xarray uses bottleneck for that, when it can, in xarray.nputils, so copying the numpy method would result in a performance hit. However, xarray maintains a wrapper around the numpy/bottleneck version in xarray.nanops, where this could perhaps be implemented.

@shoyer
Member

shoyer commented Mar 27, 2020

I wouldn’t worry too much about reusing bottleneck here, unless we really think these functions will be the bottleneck in user code :)

@dcherian dcherian changed the title Allow for All-NaN in Dataset.argmax() Allow for All-NaN in argmax, argmin Oct 3, 2020