COMPAT: Add keepdims and friends to validation #24356

Merged · 5 commits · Dec 21, 2018
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.24.0.rst
@@ -1398,6 +1398,7 @@ Numeric
- Added ``log10`` to the list of supported functions in :meth:`DataFrame.eval` (:issue:`24139`)
- Logical operations ``&, |, ^`` between :class:`Series` and :class:`Index` will no longer raise ``ValueError`` (:issue:`22092`)
- Checking PEP 3141 numbers in :func:`~pandas.api.types.is_scalar` function returns ``True`` (:issue:`22903`)
- Reduction methods like :meth:`Series.sum` now accept the default value of ``keepdims=False`` when called from a NumPy ufunc, rather than raising a ``TypeError``. Full support for ``keepdims`` has not been implemented (:issue:`24356`).

Conversion
^^^^^^^^^^
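For context, the user-facing behavior described in the whatsnew entry above is roughly the following (a sketch, not part of the diff; the exact error messages come from the validators changed below, and keepdims itself remains unsupported):

import numpy as np
import pandas as pd

s = pd.Series([1, 2, 3])

# With a recent NumPy, reductions forward their own default keyword
# arguments (e.g. keepdims=False) to the Series method; those defaults
# are now accepted instead of raising TypeError.
np.sum(s)                 # returns 6

# Asking for genuinely unsupported behavior is still rejected by the
# pandas-side validation added in pandas/compat/numpy/function.py below.
np.sum(s, keepdims=True)  # raises ValueError: the 'keepdims' parameter ...
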
23 changes: 20 additions & 3 deletions pandas/compat/numpy/function.py
@@ -189,15 +189,16 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
ALLANY_DEFAULTS = OrderedDict()
ALLANY_DEFAULTS['dtype'] = None
ALLANY_DEFAULTS['out'] = None
ALLANY_DEFAULTS['keepdims'] = False
validate_all = CompatValidator(ALLANY_DEFAULTS, fname='all',
method='both', max_fname_arg_count=1)
validate_any = CompatValidator(ALLANY_DEFAULTS, fname='any',
method='both', max_fname_arg_count=1)

LOGICAL_FUNC_DEFAULTS = dict(out=None)
LOGICAL_FUNC_DEFAULTS = dict(out=None, keepdims=False)
validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method='kwargs')

MINMAX_DEFAULTS = dict(out=None)
MINMAX_DEFAULTS = dict(out=None, keepdims=False)
validate_min = CompatValidator(MINMAX_DEFAULTS, fname='min',
method='both', max_fname_arg_count=1)
validate_max = CompatValidator(MINMAX_DEFAULTS, fname='max',
@@ -225,16 +226,32 @@ def validate_cum_func_with_skipna(skipna, args, kwargs, name):
STAT_FUNC_DEFAULTS = OrderedDict()
STAT_FUNC_DEFAULTS['dtype'] = None
STAT_FUNC_DEFAULTS['out'] = None

PROD_DEFAULTS = SUM_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
Comment (Contributor Author):

This is a bit complicated to get the order of the keywords correct.

SUM_DEFAULTS['keepdims'] = False
SUM_DEFAULTS['initial'] = None

MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
MEDIAN_DEFAULTS['overwrite_input'] = False
MEDIAN_DEFAULTS['keepdims'] = False

STAT_FUNC_DEFAULTS['keepdims'] = False

validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS,
method='kwargs')
validate_sum = CompatValidator(STAT_FUNC_DEFAULTS, fname='sort',
validate_sum = CompatValidator(SUM_DEFAULTS, fname='sum',
method='both', max_fname_arg_count=1)
validate_prod = CompatValidator(PROD_DEFAULTS, fname="prod",
method="both", max_fname_arg_count=1)
validate_mean = CompatValidator(STAT_FUNC_DEFAULTS, fname='mean',
method='both', max_fname_arg_count=1)
validate_median = CompatValidator(MEDIAN_DEFAULTS, fname='median',
method='both', max_fname_arg_count=1)

STAT_DDOF_FUNC_DEFAULTS = OrderedDict()
STAT_DDOF_FUNC_DEFAULTS['dtype'] = None
STAT_DDOF_FUNC_DEFAULTS['out'] = None
STAT_DDOF_FUNC_DEFAULTS['keepdims'] = False
validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS,
method='kwargs')

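As a rough illustration of the ordering concern in the comment above: the validators pair any extra positional arguments with these defaults in insertion order, so the OrderedDict keys have to line up with the order NumPy would pass them. A simplified sketch of that positional check (not the actual CompatValidator code):

from collections import OrderedDict

SUM_DEFAULTS = OrderedDict([('dtype', None), ('out', None),
                            ('keepdims', False), ('initial', None)])

def check_positional(fname, args, compat_args):
    # Simplified: each extra positional argument is matched to a compat key
    # by position and must equal the documented default, otherwise refuse it.
    for (key, default), value in zip(compat_args.items(), args):
        if value != default:
            raise ValueError("the '{}' parameter is not supported in the "
                             "pandas implementation of {}()".format(key, fname))

check_positional('sum', (None, None, False, None), SUM_DEFAULTS)  # fine: all defaults
# check_positional('sum', (None, None, True), SUM_DEFAULTS)       # would raise for 'keepdims'
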
12 changes: 10 additions & 2 deletions pandas/core/generic.py
@@ -10834,7 +10834,12 @@ def _make_min_count_stat_function(cls, name, name1, name2, axis_descr, desc,
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
min_count=0,
**kwargs):
nv.validate_stat_func(tuple(), kwargs, fname=name)
if name == 'sum':
nv.validate_sum(tuple(), kwargs)
elif name == 'prod':
nv.validate_prod(tuple(), kwargs)
Comment (Contributor):

can we make this more generic? IOW have nv.validate_stat_func dispatch based on fname?

Comment (Contributor Author):

I think that would make sense, but it isn't really written that way right now: we just have instances of CompatValidator sitting in function.py, and we choose the right one to call.

It'd be nice to have a decorator that did it for us

@validate_numpy
def mean(self, ...):
    ...

then we wouldn't duplicate the function name. But I think that's a decent-sized refactor of how things are done now.
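Purely for illustration, such a decorator could look roughly like this (a hypothetical sketch: validate_numpy does not exist in pandas, and a real version would need care around separating the pandas keywords from the NumPy-compat ones):

import functools
from pandas.compat.numpy import function as nv

def validate_numpy(func):
    # Hypothetical: pick the validator whose name matches the wrapped method
    # (e.g. 'sum' -> nv.validate_sum), falling back to the generic one.
    validator = getattr(nv, 'validate_{}'.format(func.__name__),
                        nv.validate_stat_func)

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Only hand the NumPy-compat keywords to the validator; pandas
        # keywords (axis, skipna, ...) pass straight through to the method.
        numpy_kwargs = {k: v for k, v in kwargs.items()
                        if k in validator.defaults}
        validator(tuple(), numpy_kwargs, fname=func.__name__)
        return func(self, *args, **kwargs)

    return wrapper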

Comment (Contributor):

ok, but this seems very hacky to hardcode when we already know the name.

Comment (Member):

> But I think that's a decent-sized refactor of how things are done now.

Agreed. At the time, I wrote it that way to keep things explicit, and it made it slightly easier to handle the details of each analogous numpy function.

That being said, I think refactoring to dispatch seems reasonable as well, but it would be best left for investigation and execution in a follow-up.

Comment (Contributor):

ok for now. Let's follow up and see if we can reduce the need to specify names like this.

else:
nv.validate_stat_func(tuple(), kwargs, fname=name)
if skipna is None:
skipna = True
if axis is None:
@@ -10855,7 +10860,10 @@ def _make_stat_function(cls, name, name1, name2, axis_descr, desc, f,
@Appender(_num_doc)
def stat_func(self, axis=None, skipna=None, level=None, numeric_only=None,
**kwargs):
nv.validate_stat_func(tuple(), kwargs, fname=name)
if name == 'median':
nv.validate_median(tuple(), kwargs)
else:
nv.validate_stat_func(tuple(), kwargs, fname=name)
if skipna is None:
skipna = True
if axis is None:
36 changes: 36 additions & 0 deletions pandas/tests/series/test_analytics.py
@@ -1641,6 +1641,42 @@ def test_value_counts_categorical_not_ordered(self):
tm.assert_series_equal(s.value_counts(normalize=True), exp)
tm.assert_series_equal(idx.value_counts(normalize=True), exp)

@pytest.mark.parametrize("func", [np.any, np.all])
@pytest.mark.parametrize("kwargs", [
dict(keepdims=True),
dict(out=object()),
])
@td.skip_if_np_lt_115
def test_validate_any_all_out_keepdims_raises(self, kwargs, func):
s = pd.Series([1, 2])
param = list(kwargs)[0]
name = func.__name__

msg = "the '{}' parameter .* {}".format(param, name)
with pytest.raises(ValueError, match=msg):
func(s, **kwargs)

@td.skip_if_np_lt_115
def test_validate_sum_initial(self):
s = pd.Series([1, 2])
with pytest.raises(ValueError, match="the 'initial' .* sum"):
np.sum(s, initial=10)

def test_validate_median_initial(self):
s = pd.Series([1, 2])
with pytest.raises(ValueError,
match="the 'overwrite_input' .* median"):
# It seems like np.median doesn't dispatch, so we use the
# method instead of the ufunc.
s.median(overwrite_input=True)

@td.skip_if_np_lt_115
def test_validate_stat_keepdims(self):
s = pd.Series([1, 2])
with pytest.raises(ValueError,
match="the 'keepdims'"):
np.sum(s, keepdims=True)


main_dtypes = [
'datetime',