Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enh/transform py fallback #58639

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Other enhancements
- :meth:`DataFrame.fillna` and :meth:`Series.fillna` can now accept ``value=None``; for non-object dtype the corresponding NA value will be used (:issue:`57723`)
- :meth:`Series.cummin` and :meth:`Series.cummax` now supports :class:`CategoricalDtype` (:issue:`52335`)
- :meth:`Series.plot` now correctly handle the ``ylabel`` parameter for pie charts, allowing for explicit control over the y-axis label (:issue:`58239`)
- Cythonized transformations now supports python fallback (:issue:`49758`)
- Support reading Stata 110-format (Stata 7) dta files (:issue:`47176`)
-

Expand Down
75 changes: 67 additions & 8 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
all_indexes_same,
default_index,
)
from pandas.core.internals.blocks import ensure_block_shape
from pandas.core.series import Series
from pandas.core.sorting import get_group_index
from pandas.core.util.numba_ import maybe_use_numba
Expand Down Expand Up @@ -535,17 +536,28 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
)

def _cython_transform(self, how: str, numeric_only: bool = False, **kwargs):
def _cython_transform(
self,
how: str,
alt: Callable | None = None,
numeric_only: bool = False,
**kwargs,
):
obj = self._obj_with_exclusions
values = obj._values

try:
result = self._grouper._cython_operation(
"transform", obj._values, how, 0, **kwargs
"transform", values, how, 0, **kwargs
)
except NotImplementedError as err:
# e.g. test_groupby_raises_string
raise TypeError(f"{how} is not supported for {obj.dtype} dtype") from err
except NotImplementedError:
if alt is None:
raise
else:
return obj._constructor(result, index=self.obj.index, name=obj.name)

assert alt is not None
result = self._transform_py_fallback(how, values, alt=alt)
return obj._constructor(result, index=self.obj.index, name=obj.name)

def _transform_general(
Expand Down Expand Up @@ -583,6 +595,23 @@ def _transform_general(
result.name = self.obj.name
return result

def _transform_py_fallback(
self, how: str, values: ArrayLike, alt: Callable
) -> ArrayLike:
assert alt is not None

series = Series(values, copy=False)
try:
res_values = self._grouper.transform_series(series, alt)
except Exception as err:
msg = f"transform function failed [how->{how},dtype->{series.dtype}]"
# preserve the kind of exception that raised
raise type(err)(msg) from err

if series.dtype == object:
res_values = res_values.astype(object, copy=False)
return res_values

def filter(self, func, dropna: bool = True, *args, **kwargs):
"""
Filter elements from groups that don't satisfy a criterion.
Expand Down Expand Up @@ -1742,6 +1771,7 @@ def _wrap_applied_output_series(
def _cython_transform(
self,
how: str,
alt: Callable | None = None,
numeric_only: bool = False,
**kwargs,
) -> DataFrame:
Expand All @@ -1753,9 +1783,17 @@ def _cython_transform(
)

def arr_func(bvalues: ArrayLike) -> ArrayLike:
return self._grouper._cython_operation(
"transform", bvalues, how, 1, **kwargs
)
try:
return self._grouper._cython_operation(
"transform", bvalues, how, 1, **kwargs
)
except NotImplementedError:
if alt is None:
raise

assert alt is not None
result = self._transform_py_fallback(how, bvalues, alt=alt)
return result

res_mgr = mgr.apply(arr_func)

Expand Down Expand Up @@ -1866,6 +1904,27 @@ def _transform_general(self, func, engine, engine_kwargs, *args, **kwargs):
"""
)

def _transform_py_fallback(
self, how: str, values: ArrayLike, alt: Callable
) -> ArrayLike:
print("IN TRANSFORM PY FALLBACK")
assert alt is not None

df = DataFrame(values.T, dtype=values.dtype)
assert df.shape[1] == 1
series = df.iloc[:, 0]

try:
res_values = self._grouper.transform_series(series, alt)
except Exception as err:
msg = f"transform function failed [how->{how},dtype->{series.dtype}]"
# preserve the kind of exception that raised
raise type(err)(msg) from err

if series.dtype == object:
res_values = res_values.astype(object, copy=False)
return ensure_block_shape(res_values, ndim=2)

@Substitution(klass="DataFrame", example=__examples_dataframe_doc)
@Appender(_transform_template)
def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
Expand Down
21 changes: 17 additions & 4 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -4656,6 +4656,13 @@ def rank(
return self._cython_transform(
"rank",
numeric_only=False,
alt=lambda x: Series(x, copy=False).rank(
method=method,
numeric_only=False,
na_option=na_option,
ascending=ascending,
pct=pct,
),
**kwargs,
)

Expand Down Expand Up @@ -4716,7 +4723,7 @@ def cumprod(self, *args, **kwargs) -> NDFrameT:
bull 6 9
"""
nv.validate_groupby_func("cumprod", args, kwargs, ["numeric_only", "skipna"])
return self._cython_transform("cumprod", **kwargs)
return self._cython_transform("cumprod", alt=np.cumprod, **kwargs)

@final
@Substitution(name="groupby")
Expand Down Expand Up @@ -4775,7 +4782,7 @@ def cumsum(self, *args, **kwargs) -> NDFrameT:
lion 6 9
"""
nv.validate_groupby_func("cumsum", args, kwargs, ["numeric_only", "skipna"])
return self._cython_transform("cumsum", **kwargs)
return self._cython_transform("cumsum", alt=np.cumsum, **kwargs)

@final
@Substitution(name="groupby")
Expand Down Expand Up @@ -4845,7 +4852,10 @@ def cummin(
"""
skipna = kwargs.get("skipna", True)
return self._cython_transform(
"cummin", numeric_only=numeric_only, skipna=skipna
"cummin",
numeric_only=numeric_only,
skipna=skipna,
alt=np.minimum.accumulate,
)

@final
Expand Down Expand Up @@ -4916,7 +4926,10 @@ def cummax(
"""
skipna = kwargs.get("skipna", True)
return self._cython_transform(
"cummax", numeric_only=numeric_only, skipna=skipna
"cummax",
numeric_only=numeric_only,
skipna=skipna,
alt=np.maximum.accumulate,
)

@final
Expand Down
68 changes: 68 additions & 0 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,14 @@
ensure_uint64,
is_1d_only_ea_dtype,
)
from pandas.core.dtypes.dtypes import ArrowDtype
from pandas.core.dtypes.missing import (
isna,
maybe_fill,
)

from pandas.core.arrays import Categorical
from pandas.core.arrays.arrow.array import ArrowExtensionArray
from pandas.core.frame import DataFrame
from pandas.core.groupby import grouper
from pandas.core.indexes.api import (
Expand Down Expand Up @@ -910,6 +912,72 @@ def _cython_operation(
**kwargs,
)

@final
def transform_series(
self, obj: Series, func: Callable, preserve_dtype: bool = False
) -> ArrayLike:
"""
Parameters
----------
obj : Series
func : function taking a Series and returning a Series
preserve_dtype : bool
Whether the aggregation is known to be dtype-preserving.

Returns
-------
np.ndarray or ExtensionArray
"""
# GH#58129
result = self._transform_series_pure_python(obj, func)
npvalues = lib.maybe_convert_objects(result, try_float=False)

if isinstance(obj._values, ArrowExtensionArray):
out = maybe_cast_pointwise_result(
npvalues, obj.dtype, numeric_only=True, same_dtype=preserve_dtype
)
import pyarrow as pa

if isinstance(out.dtype, ArrowDtype) and pa.types.is_struct(
out.dtype.pyarrow_dtype
):
out = npvalues

elif not isinstance(obj._values, np.ndarray):
out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
else:
out = npvalues

return out

@final
def _transform_series_pure_python(
self, obj: Series, func: Callable
) -> npt.NDArray[np.object_]:
splitter = self._get_splitter(obj)
res_by_group = []

for group in splitter:
res = func(group)
if hasattr(res, "_values"):
res = res._values

res_by_group.append(res)

res_by_group_pointers = np.zeros(self.ngroups, dtype=np.int64)
series_len = len(obj._values)
result = np.empty(series_len, dtype="O")

for i in range(series_len):
label = splitter.labels[i]
group_res = res_by_group[label]
pointer = res_by_group_pointers[label]
result[i] = group_res[pointer]

res_by_group_pointers[label] = pointer + 1

return result

@final
def agg_series(
self, obj: Series, func: Callable, preserve_dtype: bool = False
Expand Down
48 changes: 34 additions & 14 deletions pandas/tests/groupby/test_numeric_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,25 +159,35 @@ def _check(self, df, method, expected_columns, expected_columns_numeric):

# object dtypes for transformations are not implemented in Cython and
# have no Python fallback
exception = NotImplementedError if method.startswith("cum") else TypeError
exception = TypeError

if method in ("min", "max", "cummin", "cummax", "cumsum", "cumprod"):
if method in ("min", "max"):
# The methods default to numeric_only=False and raise TypeError
msg = "|".join(
[
"Categorical is not ordered",
f"Cannot perform {method} with non-ordered Categorical",
re.escape(f"agg function failed [how->{method},dtype->object]"),
# cumsum/cummin/cummax/cumprod
"function is not implemented for this dtype",
]
)
with pytest.raises(exception, match=msg):
getattr(gb, method)()
elif method in ("sum", "mean", "median", "prod"):
elif method in (
"sum",
"mean",
"median",
"prod",
"cummin",
"cummax",
"cumsum",
"cumprod",
):
msg = "|".join(
[
"category type does not support sum operations",
re.escape(f"category type does not support {method} operations"),
re.escape(
f"transform function failed [how->{method},dtype->object]"
),
re.escape(f"agg function failed [how->{method},dtype->object]"),
re.escape(f"agg function failed [how->{method},dtype->string]"),
]
Expand All @@ -195,6 +205,9 @@ def _check(self, df, method, expected_columns, expected_columns_numeric):
"category type does not support",
"function is not implemented for this dtype",
f"Cannot perform {method} with non-ordered Categorical",
re.escape(
f"transform function failed [how->{method},dtype->object]"
),
re.escape(f"agg function failed [how->{method},dtype->object]"),
re.escape(f"agg function failed [how->{method},dtype->string]"),
]
Expand Down Expand Up @@ -276,9 +289,7 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys):
assert numeric_only is not True
# kernels that are successful on any dtype were above; this will fail

# object dtypes for transformations are not implemented in Cython and
# have no Python fallback
exception = NotImplementedError if kernel.startswith("cum") else TypeError
exception = TypeError

msg = "|".join(
[
Expand All @@ -289,6 +300,7 @@ def test_numeric_only(kernel, has_arg, numeric_only, keys):
"unsupported operand type",
"function is not implemented for this dtype",
re.escape(f"agg function failed [how->{kernel},dtype->object]"),
re.escape(f"transform function failed [how->{kernel},dtype->object]"),
]
)
if kernel == "idxmin":
Expand Down Expand Up @@ -334,10 +346,6 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
fails_on_numeric_object = (
"corr",
"cov",
"cummax",
"cummin",
"cumprod",
"cumsum",
"quantile",
)
# ops that give an object result on object input
Expand All @@ -358,6 +366,11 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
"max",
"prod",
"skew",
"cummax",
"cummin",
"cumsum",
# cumprod does not fail for object dtype, if element are numeric
"cumprod",
)

# Test default behavior; kernels that fail may be enabled in the future but kernels
Expand All @@ -376,6 +389,13 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
expected = expected.astype(object)
tm.assert_series_equal(result, expected)

valid_func_has_numeric_only = (
"cummin",
"cummax",
"cumsum",
# cumprod does not fail for object dtype, if element are numeric
"cumprod",
)
has_numeric_only = (
"first",
"last",
Expand All @@ -399,7 +419,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
msg = "got an unexpected keyword argument 'numeric_only'"
with pytest.raises(TypeError, match=msg):
method(*args, numeric_only=True)
elif dtype is object:
elif dtype is object and groupby_func not in valid_func_has_numeric_only:
msg = "|".join(
[
"SeriesGroupBy.sem called with numeric_only=True and dtype object",
Expand Down
Loading
Loading