diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index e9d4225c3dbd9..3dff5eed8a81a 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -653,6 +653,7 @@ Reshaping - Bug in :meth:`Series.where` and :meth:`DataFrame.where` with ``datetime64[ns, tz]`` dtype (:issue:`21546`) - Bug in :meth:`Series.mask` and :meth:`DataFrame.mask` with ``list`` conditionals (:issue:`21891`) - Bug in :meth:`DataFrame.replace` raises RecursionError when converting OutOfBounds ``datetime64[ns, tz]`` (:issue:`20380`) +- :func:`pandas.core.groupby.GroupBy.rank` now raises a ``ValueError`` when an invalid value is passed for argument ``na_option`` (:issue:`22124`) - Build Changes diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 4b0143b3e1ced..3f84fa0f0670e 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1705,6 +1705,9 @@ def rank(self, method='average', ascending=True, na_option='keep', ----- DataFrame with ranking of values within each group """ + if na_option not in {'keep', 'top', 'bottom'}: + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + raise ValueError(msg) return self._cython_transform('rank', numeric_only=False, ties_method=method, ascending=ascending, na_option=na_option, pct=pct, axis=axis) diff --git a/pandas/tests/groupby/test_rank.py b/pandas/tests/groupby/test_rank.py index 0628f9c79a154..f0dcf768e3607 100644 --- a/pandas/tests/groupby/test_rank.py +++ b/pandas/tests/groupby/test_rank.py @@ -172,35 +172,35 @@ def test_infs_n_nans(grps, vals, ties_method, ascending, na_option, exp): [3., 3., np.nan, 1., 3., 2., np.nan, np.nan]), ('dense', False, 'keep', True, [3. / 3., 3. / 3., np.nan, 1. / 3., 3. / 3., 2. / 3., np.nan, np.nan]), - ('average', True, 'no_na', False, [2., 2., 7., 5., 2., 4., 7., 7.]), - ('average', True, 'no_na', True, + ('average', True, 'bottom', False, [2., 2., 7., 5., 2., 4., 7., 7.]), + ('average', True, 'bottom', True, [0.25, 0.25, 0.875, 0.625, 0.25, 0.5, 0.875, 0.875]), - ('average', False, 'no_na', False, [4., 4., 7., 1., 4., 2., 7., 7.]), - ('average', False, 'no_na', True, + ('average', False, 'bottom', False, [4., 4., 7., 1., 4., 2., 7., 7.]), + ('average', False, 'bottom', True, [0.5, 0.5, 0.875, 0.125, 0.5, 0.25, 0.875, 0.875]), - ('min', True, 'no_na', False, [1., 1., 6., 5., 1., 4., 6., 6.]), - ('min', True, 'no_na', True, + ('min', True, 'bottom', False, [1., 1., 6., 5., 1., 4., 6., 6.]), + ('min', True, 'bottom', True, [0.125, 0.125, 0.75, 0.625, 0.125, 0.5, 0.75, 0.75]), - ('min', False, 'no_na', False, [3., 3., 6., 1., 3., 2., 6., 6.]), - ('min', False, 'no_na', True, + ('min', False, 'bottom', False, [3., 3., 6., 1., 3., 2., 6., 6.]), + ('min', False, 'bottom', True, [0.375, 0.375, 0.75, 0.125, 0.375, 0.25, 0.75, 0.75]), - ('max', True, 'no_na', False, [3., 3., 8., 5., 3., 4., 8., 8.]), - ('max', True, 'no_na', True, + ('max', True, 'bottom', False, [3., 3., 8., 5., 3., 4., 8., 8.]), + ('max', True, 'bottom', True, [0.375, 0.375, 1., 0.625, 0.375, 0.5, 1., 1.]), - ('max', False, 'no_na', False, [5., 5., 8., 1., 5., 2., 8., 8.]), - ('max', False, 'no_na', True, + ('max', False, 'bottom', False, [5., 5., 8., 1., 5., 2., 8., 8.]), + ('max', False, 'bottom', True, [0.625, 0.625, 1., 0.125, 0.625, 0.25, 1., 1.]), - ('first', True, 'no_na', False, [1., 2., 6., 5., 3., 4., 7., 8.]), - ('first', True, 'no_na', True, + ('first', True, 'bottom', False, [1., 2., 6., 5., 3., 4., 7., 8.]), + ('first', True, 'bottom', True, [0.125, 0.25, 0.75, 0.625, 0.375, 0.5, 0.875, 1.]), - ('first', False, 'no_na', False, [3., 4., 6., 1., 5., 2., 7., 8.]), - ('first', False, 'no_na', True, + ('first', False, 'bottom', False, [3., 4., 6., 1., 5., 2., 7., 8.]), + ('first', False, 'bottom', True, [0.375, 0.5, 0.75, 0.125, 0.625, 0.25, 0.875, 1.]), - ('dense', True, 'no_na', False, [1., 1., 4., 3., 1., 2., 4., 4.]), - ('dense', True, 'no_na', True, + ('dense', True, 'bottom', False, [1., 1., 4., 3., 1., 2., 4., 4.]), + ('dense', True, 'bottom', True, [0.25, 0.25, 1., 0.75, 0.25, 0.5, 1., 1.]), - ('dense', False, 'no_na', False, [3., 3., 4., 1., 3., 2., 4., 4.]), - ('dense', False, 'no_na', True, + ('dense', False, 'bottom', False, [3., 3., 4., 1., 3., 2., 4., 4.]), + ('dense', False, 'bottom', True, [0.75, 0.75, 1., 0.25, 0.75, 0.5, 1., 1.]) ]) def test_rank_args_missing(grps, vals, ties_method, ascending, @@ -252,14 +252,24 @@ def test_rank_object_raises(ties_method, ascending, na_option, with tm.assert_raises_regex(TypeError, "not callable"): df.groupby('key').rank(method=ties_method, ascending=ascending, - na_option='bad', pct=pct) + na_option=na_option, pct=pct) - with tm.assert_raises_regex(TypeError, "not callable"): - df.groupby('key').rank(method=ties_method, - ascending=ascending, - na_option=True, pct=pct) - with tm.assert_raises_regex(TypeError, "not callable"): +@pytest.mark.parametrize("na_option", [True, "bad", 1]) +@pytest.mark.parametrize("ties_method", [ + 'average', 'min', 'max', 'first', 'dense']) +@pytest.mark.parametrize("ascending", [True, False]) +@pytest.mark.parametrize("pct", [True, False]) +@pytest.mark.parametrize("vals", [ + ['bar', 'bar', 'foo', 'bar', 'baz'], + ['bar', np.nan, 'foo', np.nan, 'baz'], + [1, np.nan, 2, np.nan, 3] +]) +def test_rank_naoption_raises(ties_method, ascending, na_option, pct, vals): + df = DataFrame({'key': ['foo'] * 5, 'val': vals}) + msg = "na_option must be one of 'keep', 'top', or 'bottom'" + + with tm.assert_raises_regex(ValueError, msg): df.groupby('key').rank(method=ties_method, ascending=ascending, na_option=na_option, pct=pct)