Skip to content

Commit

Permalink
CLN: Enforce deprecation of argmin/max and idxmin/max with NA values
Browse files Browse the repository at this point in the history
  • Loading branch information
rhshadrach committed Mar 23, 2024
1 parent a789288 commit c3d0cfa
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 220 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ Removal of prior version deprecations/changes
- Removed the :class:`Grouper` attributes ``ax``, ``groups``, ``indexer``, and ``obj`` (:issue:`51206`, :issue:`51182`)
- Removed deprecated keyword ``verbose`` on :func:`read_csv` and :func:`read_table` (:issue:`56556`)
- Removed the attribute ``dtypes`` from :class:`.DataFrameGroupBy` (:issue:`51997`)
- Enforced deprecation of ``argmin``, ``argmax``, ``idxmin``, and ``idxmax`` returning a result when ``skipna=False`` and an NA value is encountered or all values are NA values; these operations will now raise in such cases (:issue:`33941`, :issue:`51276`)

.. ---------------------------------------------------------------------------
.. _whatsnew_300.performance:
Expand Down
52 changes: 12 additions & 40 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
final,
overload,
)
import warnings

import numpy as np

Expand All @@ -35,7 +34,6 @@
cache_readonly,
doc,
)
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -736,28 +734,15 @@ def argmax(
nv.validate_minmax_axis(axis)
skipna = nv.validate_argmax_with_skipna(skipna, args, kwargs)

if skipna and len(delegate) > 0 and isna(delegate).all():
raise ValueError("Encountered all NA values")
elif not skipna and isna(delegate).any():
raise ValueError("Encountered an NA value with skipna=False")

if isinstance(delegate, ExtensionArray):
if not skipna and delegate.isna().any():
warnings.warn(
f"The behavior of {type(self).__name__}.argmax/argmin "
"with skipna=False and NAs, or with all-NAs is deprecated. "
"In a future version this will raise ValueError.",
FutureWarning,
stacklevel=find_stack_level(),
)
return -1
else:
return delegate.argmax()
return delegate.argmax()
else:
result = nanops.nanargmax(delegate, skipna=skipna)
if result == -1:
warnings.warn(
f"The behavior of {type(self).__name__}.argmax/argmin "
"with skipna=False and NAs, or with all-NAs is deprecated. "
"In a future version this will raise ValueError.",
FutureWarning,
stacklevel=find_stack_level(),
)
# error: Incompatible return value type (got "Union[int, ndarray]", expected
# "int")
return result # type: ignore[return-value]
Expand All @@ -770,28 +755,15 @@ def argmin(
nv.validate_minmax_axis(axis)
skipna = nv.validate_argmin_with_skipna(skipna, args, kwargs)

if skipna and len(delegate) > 0 and isna(delegate).all():
raise ValueError("Encountered all NA values")
elif not skipna and isna(delegate).any():
raise ValueError("Encountered an NA value with skipna=False")

if isinstance(delegate, ExtensionArray):
if not skipna and delegate.isna().any():
warnings.warn(
f"The behavior of {type(self).__name__}.argmax/argmin "
"with skipna=False and NAs, or with all-NAs is deprecated. "
"In a future version this will raise ValueError.",
FutureWarning,
stacklevel=find_stack_level(),
)
return -1
else:
return delegate.argmin()
return delegate.argmin()
else:
result = nanops.nanargmin(delegate, skipna=skipna)
if result == -1:
warnings.warn(
f"The behavior of {type(self).__name__}.argmax/argmin "
"with skipna=False and NAs, or with all-NAs is deprecated. "
"In a future version this will raise ValueError.",
FutureWarning,
stacklevel=find_stack_level(),
)
# error: Incompatible return value type (got "Union[int, ndarray]", expected
# "int")
return result # type: ignore[return-value]
Expand Down
28 changes: 8 additions & 20 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6976,16 +6976,10 @@ def argmin(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:

if not self._is_multi and self.hasnans:
# Take advantage of cache
mask = self._isnan
if not skipna or mask.all():
warnings.warn(
f"The behavior of {type(self).__name__}.argmax/argmin "
"with skipna=False and NAs, or with all-NAs is deprecated. "
"In a future version this will raise ValueError.",
FutureWarning,
stacklevel=find_stack_level(),
)
return -1
if self._isnan.all():
raise ValueError("Encountered all NA values")
elif not skipna:
raise ValueError("Encountered an NA value with skipna=False")
return super().argmin(skipna=skipna)

@Appender(IndexOpsMixin.argmax.__doc__)
Expand All @@ -6995,16 +6989,10 @@ def argmax(self, axis=None, skipna: bool = True, *args, **kwargs) -> int:

if not self._is_multi and self.hasnans:
# Take advantage of cache
mask = self._isnan
if not skipna or mask.all():
warnings.warn(
f"The behavior of {type(self).__name__}.argmax/argmin "
"with skipna=False and NAs, or with all-NAs is deprecated. "
"In a future version this will raise ValueError.",
FutureWarning,
stacklevel=find_stack_level(),
)
return -1
if self._isnan.all():
raise ValueError("Encountered all NA values")
elif not skipna:
raise ValueError("Encountered an NA value with skipna=False")
return super().argmax(skipna=skipna)

def min(self, axis=None, skipna: bool = True, *args, **kwargs):
Expand Down
15 changes: 8 additions & 7 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,17 +1441,18 @@ def _maybe_arg_null_out(
if axis is None or not getattr(result, "ndim", False):
if skipna:
if mask.all():
return -1
raise ValueError("Encountered all NA values")
else:
if mask.any():
return -1
raise ValueError("Encountered an NA value with skipna=False")
else:
if skipna:
na_mask = mask.all(axis)
else:
na_mask = mask.any(axis)
na_mask = mask.all(axis)
if na_mask.any():
result[na_mask] = -1
raise ValueError("Encountered all NA values")
elif not skipna:
na_mask = mask.any(axis)
if na_mask.any():
raise ValueError("Encountered an NA value with skipna=False")
return result


Expand Down
40 changes: 4 additions & 36 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2384,24 +2384,8 @@ def idxmin(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Hashab
nan
"""
axis = self._get_axis_number(axis)
with warnings.catch_warnings():
# TODO(3.0): this catching/filtering can be removed
# ignore warning produced by argmin since we will issue a different
# warning for idxmin
warnings.simplefilter("ignore")
i = self.argmin(axis, skipna, *args, **kwargs)

if i == -1:
# GH#43587 give correct NA value for Index.
warnings.warn(
f"The behavior of {type(self).__name__}.idxmin with all-NA "
"values, or any-NA and skipna=False, is deprecated. In a future "
"version this will raise ValueError",
FutureWarning,
stacklevel=find_stack_level(),
)
return self.index._na_value
return self.index[i]
iloc = self.argmin(axis, skipna, *args, **kwargs)
return self.index[iloc]

def idxmax(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Hashable:
"""
Expand Down Expand Up @@ -2467,24 +2451,8 @@ def idxmax(self, axis: Axis = 0, skipna: bool = True, *args, **kwargs) -> Hashab
nan
"""
axis = self._get_axis_number(axis)
with warnings.catch_warnings():
# TODO(3.0): this catching/filtering can be removed
# ignore warning produced by argmax since we will issue a different
# warning for argmax
warnings.simplefilter("ignore")
i = self.argmax(axis, skipna, *args, **kwargs)

if i == -1:
# GH#43587 give correct NA value for Index.
warnings.warn(
f"The behavior of {type(self).__name__}.idxmax with all-NA "
"values, or any-NA and skipna=False, is deprecated. In a future "
"version this will raise ValueError",
FutureWarning,
stacklevel=find_stack_level(),
)
return self.index._na_value
return self.index[i]
iloc = self.argmax(axis, skipna, *args, **kwargs)
return self.index[iloc]

def round(self, decimals: int = 0, *args, **kwargs) -> Series:
"""
Expand Down
18 changes: 7 additions & 11 deletions pandas/tests/extension/base/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def test_argmin_argmax_all_na(self, method, data, na_value):
("idxmin", True, 2),
("argmax", True, 0),
("argmin", True, 2),
("idxmax", False, np.nan),
("idxmin", False, np.nan),
("idxmax", False, -1),
("idxmin", False, -1),
("argmax", False, -1),
("argmin", False, -1),
],
Expand All @@ -179,17 +179,13 @@ def test_argreduce_series(
self, data_missing_for_sorting, op_name, skipna, expected
):
# data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
warn = None
msg = "The behavior of Series.argmax/argmin"
if op_name.startswith("arg") and expected == -1:
warn = FutureWarning
if op_name.startswith("idx") and np.isnan(expected):
warn = FutureWarning
msg = f"The behavior of Series.{op_name}"
ser = pd.Series(data_missing_for_sorting)
with tm.assert_produces_warning(warn, match=msg):
if expected == -1:
with pytest.raises(ValueError, match="Encountered an NA value"):
getattr(ser, op_name)(skipna=skipna)
else:
result = getattr(ser, op_name)(skipna=skipna)
tm.assert_almost_equal(result, expected)
tm.assert_almost_equal(result, expected)

def test_argmax_argmin_no_skipna_notimplemented(self, data_missing_for_sorting):
# GH#38733
Expand Down
56 changes: 30 additions & 26 deletions pandas/tests/frame/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,18 +1065,20 @@ def test_idxmin(self, float_frame, int_frame, skipna, axis):
frame.iloc[5:10] = np.nan
frame.iloc[15:20, -2:] = np.nan
for df in [frame, int_frame]:
warn = None
if skipna is False or axis == 1:
warn = None if df is int_frame else FutureWarning
msg = "The behavior of DataFrame.idxmin with all-NA values"
with tm.assert_produces_warning(warn, match=msg):
if (not skipna or axis == 1) and df is not int_frame:
if axis == 1:
msg = "Encountered all NA values"
else:
msg = "Encountered an NA value"
with pytest.raises(ValueError, match=msg):
df.idxmin(axis=axis, skipna=skipna)
with pytest.raises(ValueError, match=msg):
df.idxmin(axis=axis, skipna=skipna)
else:
result = df.idxmin(axis=axis, skipna=skipna)

msg2 = "The behavior of Series.idxmin"
with tm.assert_produces_warning(warn, match=msg2):
expected = df.apply(Series.idxmin, axis=axis, skipna=skipna)
expected = expected.astype(df.index.dtype)
tm.assert_series_equal(result, expected)
expected = expected.astype(df.index.dtype)
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.filterwarnings(r"ignore:PeriodDtype\[B\] is deprecated:FutureWarning")
Expand Down Expand Up @@ -1113,16 +1115,17 @@ def test_idxmax(self, float_frame, int_frame, skipna, axis):
frame.iloc[5:10] = np.nan
frame.iloc[15:20, -2:] = np.nan
for df in [frame, int_frame]:
warn = None
if skipna is False or axis == 1:
warn = None if df is int_frame else FutureWarning
msg = "The behavior of DataFrame.idxmax with all-NA values"
with tm.assert_produces_warning(warn, match=msg):
result = df.idxmax(axis=axis, skipna=skipna)
if (skipna is False or axis == 1) and df is frame:
if axis == 1:
msg = "Encountered all NA values"
else:
msg = "Encountered an NA value"
with pytest.raises(ValueError, match=msg):
df.idxmax(axis=axis, skipna=skipna)
return

msg2 = "The behavior of Series.idxmax"
with tm.assert_produces_warning(warn, match=msg2):
expected = df.apply(Series.idxmax, axis=axis, skipna=skipna)
result = df.idxmax(axis=axis, skipna=skipna)
expected = df.apply(Series.idxmax, axis=axis, skipna=skipna)
expected = expected.astype(df.index.dtype)
tm.assert_series_equal(result, expected)

Expand Down Expand Up @@ -2118,15 +2121,16 @@ def test_numeric_ea_axis_1(method, skipna, min_count, any_numeric_ea_dtype):
if method in ("prod", "product", "sum"):
kwargs["min_count"] = min_count

warn = None
msg = None
if not skipna and method in ("idxmax", "idxmin"):
warn = FutureWarning
# GH#57745 - EAs use groupby for axis=1 which still needs a proper deprecation.
msg = f"The behavior of DataFrame.{method} with all-NA values"
with tm.assert_produces_warning(warn, match=msg):
result = getattr(df, method)(axis=1, **kwargs)
with tm.assert_produces_warning(warn, match=msg):
expected = getattr(expected_df, method)(axis=1, **kwargs)
with tm.assert_produces_warning(FutureWarning, match=msg):
getattr(df, method)(axis=1, **kwargs)
with pytest.raises(ValueError, match="Encountered an NA value"):
getattr(expected_df, method)(axis=1, **kwargs)
return
result = getattr(df, method)(axis=1, **kwargs)
expected = getattr(expected_df, method)(axis=1, **kwargs)
if method not in ("idxmax", "idxmin"):
expected = expected.astype(expected_dtype)
tm.assert_series_equal(result, expected)

0 comments on commit c3d0cfa

Please sign in to comment.