diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 448ceffdaa1eb..575db2338c6c7 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -196,6 +196,7 @@ Other enhancements - :class:`ArrowDtype` now supports ``pyarrow.JsonType`` (:issue:`60958`) - :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` methods ``sum``, ``mean``, ``median``, ``prod``, ``min``, ``max``, ``std``, ``var`` and ``sem`` now accept ``skipna`` parameter (:issue:`15675`) - :class:`Easter` has gained a new constructor argument ``method`` which specifies the method used to calculate Easter — for example, Orthodox Easter (:issue:`61665`) +- :class:`ExtensionScalarOpsMixin` now supports having a combination of manually defined operators and element-wise implementation fallback for the others (:issue:`50767`) - :class:`Holiday` constructor argument ``days_of_week`` will raise a ``ValueError`` when type is something other than ``None`` or ``tuple`` (:issue:`61658`) - :class:`Holiday` has gained the constructor argument and field ``exclude_dates`` to exclude specific datetimes from a custom holiday calendar (:issue:`54382`) - :class:`Rolling` and :class:`Expanding` now support ``nunique`` (:issue:`26958`) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index e8ca51ef92a94..937955eab3cf2 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -2740,24 +2740,33 @@ def _create_arithmetic_method(cls, op): @classmethod def _add_arithmetic_ops(cls) -> None: - setattr(cls, "__add__", cls._create_arithmetic_method(operator.add)) - setattr(cls, "__radd__", cls._create_arithmetic_method(roperator.radd)) - setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub)) - setattr(cls, "__rsub__", cls._create_arithmetic_method(roperator.rsub)) - setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul)) - setattr(cls, "__rmul__", cls._create_arithmetic_method(roperator.rmul)) - setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow)) - setattr(cls, "__rpow__", cls._create_arithmetic_method(roperator.rpow)) - setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod)) - setattr(cls, "__rmod__", cls._create_arithmetic_method(roperator.rmod)) - setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv)) - setattr( - cls, "__rfloordiv__", cls._create_arithmetic_method(roperator.rfloordiv) - ) - setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv)) - setattr(cls, "__rtruediv__", cls._create_arithmetic_method(roperator.rtruediv)) - setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod)) - setattr(cls, "__rdivmod__", cls._create_arithmetic_method(roperator.rdivmod)) + """ + Add arithmetic operator overloads to the class. + + Only operators not already defined on the class will be added, + allowing custom implementations to coexist with mixin fallbacks. + """ + ops = [ + ("__add__", operator.add), + ("__radd__", roperator.radd), + ("__sub__", operator.sub), + ("__rsub__", roperator.rsub), + ("__mul__", operator.mul), + ("__rmul__", roperator.rmul), + ("__pow__", operator.pow), + ("__rpow__", roperator.rpow), + ("__mod__", operator.mod), + ("__rmod__", roperator.rmod), + ("__floordiv__", operator.floordiv), + ("__rfloordiv__", roperator.rfloordiv), + ("__truediv__", operator.truediv), + ("__rtruediv__", roperator.rtruediv), + ("__divmod__", divmod), + ("__rdivmod__", roperator.rdivmod), + ] + for name, op in ops: + if not hasattr(cls, name): + setattr(cls, name, cls._create_arithmetic_method(op)) @classmethod def _create_comparison_method(cls, op): @@ -2765,12 +2774,23 @@ def _create_comparison_method(cls, op): @classmethod def _add_comparison_ops(cls) -> None: - setattr(cls, "__eq__", cls._create_comparison_method(operator.eq)) - setattr(cls, "__ne__", cls._create_comparison_method(operator.ne)) - setattr(cls, "__lt__", cls._create_comparison_method(operator.lt)) - setattr(cls, "__gt__", cls._create_comparison_method(operator.gt)) - setattr(cls, "__le__", cls._create_comparison_method(operator.le)) - setattr(cls, "__ge__", cls._create_comparison_method(operator.ge)) + """ + Add comparison operator overloads to the class. + + Only operators not already defined on the class will be added, + allowing custom implementations to coexist with mixin fallbacks. + """ + ops = [ + ("__eq__", operator.eq), + ("__ne__", operator.ne), + ("__lt__", operator.lt), + ("__gt__", operator.gt), + ("__le__", operator.le), + ("__ge__", operator.ge), + ] + for name, op in ops: + if not hasattr(cls, name): + setattr(cls, name, cls._create_comparison_method(op)) @classmethod def _create_logical_method(cls, op): @@ -2797,13 +2817,14 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin): ----- If you have defined a subclass MyExtensionArray(ExtensionArray), then use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to - get the arithmetic operators. After the definition of MyExtensionArray, - insert the lines + get the arithmetic and comparison operators. After the definition of + MyExtensionArray, insert the lines MyExtensionArray._add_arithmetic_ops() MyExtensionArray._add_comparison_ops() - to link the operators to your class. + to link the operators to your class, allowing you to define custom operator methods + for performance, with fallback to element-wise implementations for others. .. note:: diff --git a/pandas/tests/extension/test_scalar_ops_mixin.py b/pandas/tests/extension/test_scalar_ops_mixin.py new file mode 100644 index 0000000000000..86d151f0c3eba --- /dev/null +++ b/pandas/tests/extension/test_scalar_ops_mixin.py @@ -0,0 +1,105 @@ +import numpy as np + +import pandas as pd +from pandas.core.arrays.base import ExtensionScalarOpsMixin +import pandas.testing as tm + + +class CustomAddArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin): + def __init__(self, values): + self._data = np.array(values) + + @property + def dtype(self): + return self._data.dtype + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + def __array__(self): + return self._data + + def _from_sequence(self, scalars, dtype=None): + return type(self)(scalars) + + def _cast_pointwise_result(self, arr): + return type(self)(arr) + + # Custom __add__ implementation + def __add__(self, other): + for i in range(len(self._data)): + self._data[i] += other * 2 + return self + + +# Test fallback logic for arithmetic ops +def test_add_arithmetic_ops_custom(): + array_add = CustomAddArray([1, 2, 3]) + expected_add = CustomAddArray([7, 8, 9]) + + # Add mixin ops + CustomAddArray._add_arithmetic_ops() + array_add += 3 + + # Should use custom add (elementwise equality) + tm.assert_numpy_array_equal(array_add._data, expected_add._data) + + # Check that another op (e.g., __sub__) is present and works + assert hasattr(CustomAddArray, "__sub__") + array_sub = CustomAddArray([1, 2, 3]) + expected_sub = CustomAddArray([0, 1, 2]) + + array_sub -= 1 + + assert isinstance(array_sub, CustomAddArray) + tm.assert_numpy_array_equal(array_sub._data, expected_sub._data) + + +# Test fallback logic for comparison ops +class CustomEqArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin): + def __init__(self, values): + self._data = np.array(values) + + @property + def dtype(self): + return self._data.dtype + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + def __array__(self): + return self._data + + def _from_sequence(self, scalars, dtype=None): + return type(self)(scalars) + + def _cast_pointwise_result(self, arr): + return type(self)(arr) + + # Custom __eq__ implementation + def __eq__(self, other): + # Dummy implementation + for i in range(len(self._data)): + if self._data[i] != other: + return False + return True + + +def test_add_comparison_ops_custom(): + arr_true = CustomEqArray([1, 1, 1]) + arr_false = CustomEqArray([1, 2, 3]) + CustomEqArray._add_comparison_ops() + + # Test custom __eq__ implementation + result_true = arr_true == 1 + result_false = arr_false == 1 + assert result_true + assert not result_false + + assert hasattr(CustomEqArray, "__ne__")