Skip to content

Commit

Permalink
BUG/TST: non-numeric EA reductions (#59234)
Browse files Browse the repository at this point in the history
* BUG/TST: non-numeric EA reductions

* whatsnew

* add keepdims keyword to StringArray._reduce
  • Loading branch information
lukemanley committed Jul 13, 2024
1 parent 2a9855b commit 39bd3d3
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 19 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ ExtensionArray
^^^^^^^^^^^^^^
- Bug in :meth:`.arrays.ArrowExtensionArray.__setitem__` which caused wrong behavior when using an integer array with repeated values as a key (:issue:`58530`)
- Bug in :meth:`api.types.is_datetime64_any_dtype` where a custom :class:`ExtensionDtype` would return ``False`` for array-likes (:issue:`57055`)
- Bug in various :class:`DataFrame` reductions for pyarrow temporal dtypes returning incorrect dtype when result was null (:issue:`59234`)

Styler
^^^^^^
Expand Down
2 changes: 0 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,8 +1706,6 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
if name == "median":
# GH 52679: Use quantile instead of approximate_median; returns array
result = result[0]
if pc.is_null(result).as_py():
return result

if name in ["min", "max", "sum"] and pa.types.is_duration(pa_type):
result = result.cast(pa_type)
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,10 @@ def _reduce(
)
result = meth(skipna=skipna, **kwargs)
if keepdims:
result = np.array([result])
if name in ["min", "max"]:
result = self._from_sequence([result], dtype=self.dtype)
else:
result = np.array([result])

return result

Expand Down
13 changes: 13 additions & 0 deletions pandas/core/arrays/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2275,6 +2275,19 @@ def to_julian_date(self) -> npt.NDArray[np.float64]:
# -----------------------------------------------------------------
# Reductions

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
result = super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
if keepdims and isinstance(result, np.ndarray):
if name == "std":
from pandas.core.arrays import TimedeltaArray

return TimedeltaArray._from_sequence(result)
else:
return self._from_sequence(result, dtype=self.dtype)
return result

def std(
self,
axis=None,
Expand Down
11 changes: 11 additions & 0 deletions pandas/core/arrays/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,17 @@ def _check_timedeltalike_freq_compat(self, other):
delta = delta.view("i8")
return lib.item_from_zerodim(delta)

# ------------------------------------------------------------------
# Reductions

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
result = super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
if keepdims and isinstance(result, np.ndarray):
return self._from_sequence(result, dtype=self.dtype)
return result


def raise_on_incompatible(left, right) -> IncompatibleFrequency:
"""
Expand Down
13 changes: 11 additions & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,19 @@ def astype(self, dtype, copy: bool = True):
return super().astype(dtype, copy)

def _reduce(
self, name: str, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs
self,
name: str,
*,
skipna: bool = True,
keepdims: bool = False,
axis: AxisInt | None = 0,
**kwargs,
):
if name in ["min", "max"]:
return getattr(self, name)(skipna=skipna, axis=axis)
result = getattr(self, name)(skipna=skipna, axis=axis)
if keepdims:
return self._from_sequence([result], dtype=self.dtype)
return result

raise TypeError(f"Cannot perform reduction '{name}' with string dtype")

Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/extension/base/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import pandas as pd
import pandas._testing as tm
from pandas.api.types import is_numeric_dtype


class BaseReduceTests:
Expand Down Expand Up @@ -119,8 +118,6 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna):
def test_reduce_frame(self, data, all_numeric_reductions, skipna):
op_name = all_numeric_reductions
ser = pd.Series(data)
if not is_numeric_dtype(ser.dtype):
pytest.skip(f"{ser.dtype} is not numeric dtype")

if op_name in ["count", "kurt", "sem"]:
pytest.skip(f"{op_name} not an array method")
Expand Down
47 changes: 36 additions & 11 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@

pa = pytest.importorskip("pyarrow")

from pandas.core.arrays.arrow.array import ArrowExtensionArray
from pandas.core.arrays.arrow.array import (
ArrowExtensionArray,
get_unit_from_pa_dtype,
)
from pandas.core.arrays.arrow.extension_types import ArrowPeriodType


Expand Down Expand Up @@ -505,6 +508,16 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
# behavior which does not support this.
return False

if pa.types.is_boolean(pa_dtype) and op_name in [
"median",
"std",
"var",
"skew",
"kurt",
"sem",
]:
return False

return True

def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
Expand Down Expand Up @@ -540,18 +553,9 @@ def test_reduce_series_numeric(self, data, all_numeric_reductions, skipna, reque
f"pyarrow={pa.__version__} for {pa_dtype}"
),
)
if all_numeric_reductions in {"skew", "kurt"} and (
dtype._is_numeric or dtype.kind == "b"
):
if all_numeric_reductions in {"skew", "kurt"} and dtype._is_numeric:
request.applymarker(xfail_mark)

elif pa.types.is_boolean(pa_dtype) and all_numeric_reductions in {
"sem",
"std",
"var",
"median",
}:
request.applymarker(xfail_mark)
super().test_reduce_series_numeric(data, all_numeric_reductions, skipna)

@pytest.mark.parametrize("skipna", [True, False])
Expand All @@ -574,15 +578,32 @@ def test_reduce_series_boolean(
return super().test_reduce_series_boolean(data, all_boolean_reductions, skipna)

def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
pa_type = arr._pa_array.type

if op_name in ["max", "min"]:
cmp_dtype = arr.dtype
elif pa.types.is_temporal(pa_type):
if op_name in ["std", "sem"]:
if pa.types.is_duration(pa_type):
cmp_dtype = arr.dtype
elif pa.types.is_date(pa_type):
cmp_dtype = ArrowDtype(pa.duration("s"))
elif pa.types.is_time(pa_type):
unit = get_unit_from_pa_dtype(pa_type)
cmp_dtype = ArrowDtype(pa.duration(unit))
else:
cmp_dtype = ArrowDtype(pa.duration(pa_type.unit))
else:
cmp_dtype = arr.dtype
elif arr.dtype.name == "decimal128(7, 3)[pyarrow]":
if op_name not in ["median", "var", "std"]:
cmp_dtype = arr.dtype
else:
cmp_dtype = "float64[pyarrow]"
elif op_name in ["median", "var", "std", "mean", "skew"]:
cmp_dtype = "float64[pyarrow]"
elif op_name in ["sum", "prod"] and pa.types.is_boolean(pa_type):
cmp_dtype = "uint64[pyarrow]"
else:
cmp_dtype = {
"i": "int64[pyarrow]",
Expand All @@ -598,6 +619,10 @@ def test_reduce_frame(self, data, all_numeric_reductions, skipna, request):
if data.dtype._is_numeric:
mark = pytest.mark.xfail(reason="skew not implemented")
request.applymarker(mark)
elif op_name == "std" and pa.types.is_date64(data._pa_array.type) and skipna:
# overflow
mark = pytest.mark.xfail(reason="Cannot cast")
request.applymarker(mark)
return super().test_reduce_frame(data, all_numeric_reductions, skipna)

@pytest.mark.parametrize("typ", ["int64", "uint64", "float64"])
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ def _get_expected_exception(self, op_name, obj, other):
return None
return super()._get_expected_exception(op_name, obj, other)

def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
if op_name == "std":
return "timedelta64[ns]"
return arr.dtype

def _supports_accumulation(self, ser, op_name: str) -> bool:
return op_name in ["cummin", "cummax"]

Expand Down

0 comments on commit 39bd3d3

Please sign in to comment.