From 43b9ebcc1de9c44a1c5a62af3ae4c915bdaafcb9 Mon Sep 17 00:00:00 2001 From: Mila Mathias Date: Wed, 8 Oct 2025 08:32:45 -0700 Subject: [PATCH 1/3] more formatting --- pandas/core/arrays/base.py | 75 ++++++++++------ .../tests/extension/test_scalar_ops_mixin.py | 88 +++++++++++++++++++ 2 files changed, 136 insertions(+), 27 deletions(-) create mode 100644 pandas/tests/extension/test_scalar_ops_mixin.py diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index e8ca51ef92a94..937955eab3cf2 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -2740,24 +2740,33 @@ def _create_arithmetic_method(cls, op): @classmethod def _add_arithmetic_ops(cls) -> None: - setattr(cls, "__add__", cls._create_arithmetic_method(operator.add)) - setattr(cls, "__radd__", cls._create_arithmetic_method(roperator.radd)) - setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub)) - setattr(cls, "__rsub__", cls._create_arithmetic_method(roperator.rsub)) - setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul)) - setattr(cls, "__rmul__", cls._create_arithmetic_method(roperator.rmul)) - setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow)) - setattr(cls, "__rpow__", cls._create_arithmetic_method(roperator.rpow)) - setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod)) - setattr(cls, "__rmod__", cls._create_arithmetic_method(roperator.rmod)) - setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv)) - setattr( - cls, "__rfloordiv__", cls._create_arithmetic_method(roperator.rfloordiv) - ) - setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv)) - setattr(cls, "__rtruediv__", cls._create_arithmetic_method(roperator.rtruediv)) - setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod)) - setattr(cls, "__rdivmod__", cls._create_arithmetic_method(roperator.rdivmod)) + """ + Add arithmetic operator overloads to the class. + + Only operators not already defined on the class will be added, + allowing custom implementations to coexist with mixin fallbacks. + """ + ops = [ + ("__add__", operator.add), + ("__radd__", roperator.radd), + ("__sub__", operator.sub), + ("__rsub__", roperator.rsub), + ("__mul__", operator.mul), + ("__rmul__", roperator.rmul), + ("__pow__", operator.pow), + ("__rpow__", roperator.rpow), + ("__mod__", operator.mod), + ("__rmod__", roperator.rmod), + ("__floordiv__", operator.floordiv), + ("__rfloordiv__", roperator.rfloordiv), + ("__truediv__", operator.truediv), + ("__rtruediv__", roperator.rtruediv), + ("__divmod__", divmod), + ("__rdivmod__", roperator.rdivmod), + ] + for name, op in ops: + if not hasattr(cls, name): + setattr(cls, name, cls._create_arithmetic_method(op)) @classmethod def _create_comparison_method(cls, op): @@ -2765,12 +2774,23 @@ def _create_comparison_method(cls, op): @classmethod def _add_comparison_ops(cls) -> None: - setattr(cls, "__eq__", cls._create_comparison_method(operator.eq)) - setattr(cls, "__ne__", cls._create_comparison_method(operator.ne)) - setattr(cls, "__lt__", cls._create_comparison_method(operator.lt)) - setattr(cls, "__gt__", cls._create_comparison_method(operator.gt)) - setattr(cls, "__le__", cls._create_comparison_method(operator.le)) - setattr(cls, "__ge__", cls._create_comparison_method(operator.ge)) + """ + Add comparison operator overloads to the class. + + Only operators not already defined on the class will be added, + allowing custom implementations to coexist with mixin fallbacks. + """ + ops = [ + ("__eq__", operator.eq), + ("__ne__", operator.ne), + ("__lt__", operator.lt), + ("__gt__", operator.gt), + ("__le__", operator.le), + ("__ge__", operator.ge), + ] + for name, op in ops: + if not hasattr(cls, name): + setattr(cls, name, cls._create_comparison_method(op)) @classmethod def _create_logical_method(cls, op): @@ -2797,13 +2817,14 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin): ----- If you have defined a subclass MyExtensionArray(ExtensionArray), then use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to - get the arithmetic operators. After the definition of MyExtensionArray, - insert the lines + get the arithmetic and comparison operators. After the definition of + MyExtensionArray, insert the lines MyExtensionArray._add_arithmetic_ops() MyExtensionArray._add_comparison_ops() - to link the operators to your class. + to link the operators to your class, allowing you to define custom operator methods + for performance, with fallback to element-wise implementations for others. .. note:: diff --git a/pandas/tests/extension/test_scalar_ops_mixin.py b/pandas/tests/extension/test_scalar_ops_mixin.py new file mode 100644 index 0000000000000..f7403e37d5271 --- /dev/null +++ b/pandas/tests/extension/test_scalar_ops_mixin.py @@ -0,0 +1,88 @@ +import numpy as np + +import pandas as pd +from pandas.core.arrays.base import ExtensionScalarOpsMixin + + +class CustomAddArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin): + def __init__(self, values): + self._data = np.array(values) + + @property + def dtype(self): + return self._data.dtype + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + def __array__(self): + return self._data + + def _from_sequence(self, scalars, dtype=None): + return type(self)(scalars) + + def _cast_pointwise_result(self, arr): + return type(self)(arr) + + # Custom __add__ implementation + def __add__(self, other): + return "custom_add" + + +# Test fallback logic for arithmetic ops +def test_add_arithmetic_ops_custom(): + arr = CustomAddArray([1, 2, 3]) + # Remove __add__ if present, then add custom + CustomAddArray.__add__ = lambda self, other: "custom_add" + # Add mixin ops + CustomAddArray._add_arithmetic_ops() + result = arr + 1 + + # Should use custom + assert result == "custom_add" + + # Check that another op (e.g., __sub__) is present and works + assert hasattr(CustomAddArray, "__sub__") + sub_result = arr - 1 + assert isinstance(sub_result, CustomAddArray) + + +# Test fallback logic for comparison ops +class CustomEqArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin): + def __init__(self, values): + self._data = np.array(values) + + @property + def dtype(self): + return self._data.dtype + + def __getitem__(self, idx): + return self._data[idx] + + def __len__(self): + return len(self._data) + + def __array__(self): + return self._data + + def _from_sequence(self, scalars, dtype=None): + return type(self)(scalars) + + def _cast_pointwise_result(self, arr): + return type(self)(arr) + + +def test_add_comparison_ops_custom(): + arr = CustomEqArray([1, 2, 3]) + CustomEqArray.__eq__ = lambda self, other: self != other + CustomEqArray._add_comparison_ops() + result = arr == 1 + + assert not result + # Check that another op (e.g., __ne__) is present and works + assert hasattr(CustomEqArray, "__ne__") + ne_result = arr != 1 + assert isinstance(ne_result, np.ndarray) or isinstance(ne_result, CustomEqArray) From edf22e14626227e22cb7bb1661fd910d52b7623f Mon Sep 17 00:00:00 2001 From: Mila Mathias Date: Mon, 13 Oct 2025 08:34:51 -0700 Subject: [PATCH 2/3] tests and formatting --- .../tests/extension/test_scalar_ops_mixin.py | 49 +++++++++++++------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/pandas/tests/extension/test_scalar_ops_mixin.py b/pandas/tests/extension/test_scalar_ops_mixin.py index f7403e37d5271..86d151f0c3eba 100644 --- a/pandas/tests/extension/test_scalar_ops_mixin.py +++ b/pandas/tests/extension/test_scalar_ops_mixin.py @@ -2,6 +2,7 @@ import pandas as pd from pandas.core.arrays.base import ExtensionScalarOpsMixin +import pandas.testing as tm class CustomAddArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin): @@ -29,25 +30,32 @@ def _cast_pointwise_result(self, arr): # Custom __add__ implementation def __add__(self, other): - return "custom_add" + for i in range(len(self._data)): + self._data[i] += other * 2 + return self # Test fallback logic for arithmetic ops def test_add_arithmetic_ops_custom(): - arr = CustomAddArray([1, 2, 3]) - # Remove __add__ if present, then add custom - CustomAddArray.__add__ = lambda self, other: "custom_add" + array_add = CustomAddArray([1, 2, 3]) + expected_add = CustomAddArray([7, 8, 9]) + # Add mixin ops CustomAddArray._add_arithmetic_ops() - result = arr + 1 + array_add += 3 - # Should use custom - assert result == "custom_add" + # Should use custom add (elementwise equality) + tm.assert_numpy_array_equal(array_add._data, expected_add._data) # Check that another op (e.g., __sub__) is present and works assert hasattr(CustomAddArray, "__sub__") - sub_result = arr - 1 - assert isinstance(sub_result, CustomAddArray) + array_sub = CustomAddArray([1, 2, 3]) + expected_sub = CustomAddArray([0, 1, 2]) + + array_sub -= 1 + + assert isinstance(array_sub, CustomAddArray) + tm.assert_numpy_array_equal(array_sub._data, expected_sub._data) # Test fallback logic for comparison ops @@ -74,15 +82,24 @@ def _from_sequence(self, scalars, dtype=None): def _cast_pointwise_result(self, arr): return type(self)(arr) + # Custom __eq__ implementation + def __eq__(self, other): + # Dummy implementation + for i in range(len(self._data)): + if self._data[i] != other: + return False + return True + def test_add_comparison_ops_custom(): - arr = CustomEqArray([1, 2, 3]) - CustomEqArray.__eq__ = lambda self, other: self != other + arr_true = CustomEqArray([1, 1, 1]) + arr_false = CustomEqArray([1, 2, 3]) CustomEqArray._add_comparison_ops() - result = arr == 1 - assert not result - # Check that another op (e.g., __ne__) is present and works + # Test custom __eq__ implementation + result_true = arr_true == 1 + result_false = arr_false == 1 + assert result_true + assert not result_false + assert hasattr(CustomEqArray, "__ne__") - ne_result = arr != 1 - assert isinstance(ne_result, np.ndarray) or isinstance(ne_result, CustomEqArray) From 9a2fe9eb15e442f749c21857e4255ee87be4d4a0 Mon Sep 17 00:00:00 2001 From: Mila Mathias Date: Mon, 13 Oct 2025 08:42:10 -0700 Subject: [PATCH 3/3] whatsnew --- doc/source/whatsnew/v3.0.0.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 448ceffdaa1eb..575db2338c6c7 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -196,6 +196,7 @@ Other enhancements - :class:`ArrowDtype` now supports ``pyarrow.JsonType`` (:issue:`60958`) - :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` methods ``sum``, ``mean``, ``median``, ``prod``, ``min``, ``max``, ``std``, ``var`` and ``sem`` now accept ``skipna`` parameter (:issue:`15675`) - :class:`Easter` has gained a new constructor argument ``method`` which specifies the method used to calculate Easter — for example, Orthodox Easter (:issue:`61665`) +- :class:`ExtensionScalarOpsMixin` now supports having a combination of manually defined operators and element-wise implementation fallback for the others (:issue:`50767`) - :class:`Holiday` constructor argument ``days_of_week`` will raise a ``ValueError`` when type is something other than ``None`` or ``tuple`` (:issue:`61658`) - :class:`Holiday` has gained the constructor argument and field ``exclude_dates`` to exclude specific datetimes from a custom holiday calendar (:issue:`54382`) - :class:`Rolling` and :class:`Expanding` now support ``nunique`` (:issue:`26958`)