Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ Other enhancements
- :class:`ArrowDtype` now supports ``pyarrow.JsonType`` (:issue:`60958`)
- :class:`DataFrameGroupBy` and :class:`SeriesGroupBy` methods ``sum``, ``mean``, ``median``, ``prod``, ``min``, ``max``, ``std``, ``var`` and ``sem`` now accept ``skipna`` parameter (:issue:`15675`)
- :class:`Easter` has gained a new constructor argument ``method`` which specifies the method used to calculate Easter — for example, Orthodox Easter (:issue:`61665`)
- :class:`ExtensionScalarOpsMixin` now supports having a combination of manually defined operators and element-wise implementation fallback for the others (:issue:`50767`)
- :class:`Holiday` constructor argument ``days_of_week`` will raise a ``ValueError`` when type is something other than ``None`` or ``tuple`` (:issue:`61658`)
- :class:`Holiday` has gained the constructor argument and field ``exclude_dates`` to exclude specific datetimes from a custom holiday calendar (:issue:`54382`)
- :class:`Rolling` and :class:`Expanding` now support ``nunique`` (:issue:`26958`)
Expand Down
75 changes: 48 additions & 27 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2740,37 +2740,57 @@ def _create_arithmetic_method(cls, op):

@classmethod
def _add_arithmetic_ops(cls) -> None:
setattr(cls, "__add__", cls._create_arithmetic_method(operator.add))
setattr(cls, "__radd__", cls._create_arithmetic_method(roperator.radd))
setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub))
setattr(cls, "__rsub__", cls._create_arithmetic_method(roperator.rsub))
setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul))
setattr(cls, "__rmul__", cls._create_arithmetic_method(roperator.rmul))
setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow))
setattr(cls, "__rpow__", cls._create_arithmetic_method(roperator.rpow))
setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod))
setattr(cls, "__rmod__", cls._create_arithmetic_method(roperator.rmod))
setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv))
setattr(
cls, "__rfloordiv__", cls._create_arithmetic_method(roperator.rfloordiv)
)
setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv))
setattr(cls, "__rtruediv__", cls._create_arithmetic_method(roperator.rtruediv))
setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod))
setattr(cls, "__rdivmod__", cls._create_arithmetic_method(roperator.rdivmod))
"""
Add arithmetic operator overloads to the class.

Only operators not already defined on the class will be added,
allowing custom implementations to coexist with mixin fallbacks.
"""
ops = [
("__add__", operator.add),
("__radd__", roperator.radd),
("__sub__", operator.sub),
("__rsub__", roperator.rsub),
("__mul__", operator.mul),
("__rmul__", roperator.rmul),
("__pow__", operator.pow),
("__rpow__", roperator.rpow),
("__mod__", operator.mod),
("__rmod__", roperator.rmod),
("__floordiv__", operator.floordiv),
("__rfloordiv__", roperator.rfloordiv),
("__truediv__", operator.truediv),
("__rtruediv__", roperator.rtruediv),
("__divmod__", divmod),
("__rdivmod__", roperator.rdivmod),
]
for name, op in ops:
if not hasattr(cls, name):
setattr(cls, name, cls._create_arithmetic_method(op))

@classmethod
def _create_comparison_method(cls, op):
raise AbstractMethodError(cls)

@classmethod
def _add_comparison_ops(cls) -> None:
setattr(cls, "__eq__", cls._create_comparison_method(operator.eq))
setattr(cls, "__ne__", cls._create_comparison_method(operator.ne))
setattr(cls, "__lt__", cls._create_comparison_method(operator.lt))
setattr(cls, "__gt__", cls._create_comparison_method(operator.gt))
setattr(cls, "__le__", cls._create_comparison_method(operator.le))
setattr(cls, "__ge__", cls._create_comparison_method(operator.ge))
"""
Add comparison operator overloads to the class.

Only operators not already defined on the class will be added,
allowing custom implementations to coexist with mixin fallbacks.
"""
ops = [
("__eq__", operator.eq),
("__ne__", operator.ne),
("__lt__", operator.lt),
("__gt__", operator.gt),
("__le__", operator.le),
("__ge__", operator.ge),
]
for name, op in ops:
if not hasattr(cls, name):
setattr(cls, name, cls._create_comparison_method(op))

@classmethod
def _create_logical_method(cls, op):
Expand All @@ -2797,13 +2817,14 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin):
-----
If you have defined a subclass MyExtensionArray(ExtensionArray), then
use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to
get the arithmetic operators. After the definition of MyExtensionArray,
insert the lines
get the arithmetic and comparison operators. After the definition of
MyExtensionArray, insert the lines

MyExtensionArray._add_arithmetic_ops()
MyExtensionArray._add_comparison_ops()

to link the operators to your class.
to link the operators to your class, allowing you to define custom operator methods
for performance, with fallback to element-wise implementations for others.

.. note::

Expand Down
105 changes: 105 additions & 0 deletions pandas/tests/extension/test_scalar_ops_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np

import pandas as pd
from pandas.core.arrays.base import ExtensionScalarOpsMixin
import pandas.testing as tm


class CustomAddArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin):
def __init__(self, values):
self._data = np.array(values)

@property
def dtype(self):
return self._data.dtype

def __getitem__(self, idx):
return self._data[idx]

def __len__(self):
return len(self._data)

def __array__(self):
return self._data

def _from_sequence(self, scalars, dtype=None):
return type(self)(scalars)

def _cast_pointwise_result(self, arr):
return type(self)(arr)

# Custom __add__ implementation
def __add__(self, other):
for i in range(len(self._data)):
self._data[i] += other * 2
return self


# Test fallback logic for arithmetic ops
def test_add_arithmetic_ops_custom():
array_add = CustomAddArray([1, 2, 3])
expected_add = CustomAddArray([7, 8, 9])

# Add mixin ops
CustomAddArray._add_arithmetic_ops()
array_add += 3

# Should use custom add (elementwise equality)
tm.assert_numpy_array_equal(array_add._data, expected_add._data)

# Check that another op (e.g., __sub__) is present and works
assert hasattr(CustomAddArray, "__sub__")
array_sub = CustomAddArray([1, 2, 3])
expected_sub = CustomAddArray([0, 1, 2])

array_sub -= 1

assert isinstance(array_sub, CustomAddArray)
tm.assert_numpy_array_equal(array_sub._data, expected_sub._data)


# Test fallback logic for comparison ops
class CustomEqArray(pd.api.extensions.ExtensionArray, ExtensionScalarOpsMixin):
def __init__(self, values):
self._data = np.array(values)

@property
def dtype(self):
return self._data.dtype

def __getitem__(self, idx):
return self._data[idx]

def __len__(self):
return len(self._data)

def __array__(self):
return self._data

def _from_sequence(self, scalars, dtype=None):
return type(self)(scalars)

def _cast_pointwise_result(self, arr):
return type(self)(arr)

# Custom __eq__ implementation
def __eq__(self, other):
# Dummy implementation
for i in range(len(self._data)):
if self._data[i] != other:
return False
return True


def test_add_comparison_ops_custom():
arr_true = CustomEqArray([1, 1, 1])
arr_false = CustomEqArray([1, 2, 3])
CustomEqArray._add_comparison_ops()

# Test custom __eq__ implementation
result_true = arr_true == 1
result_false = arr_false == 1
assert result_true
assert not result_false

assert hasattr(CustomEqArray, "__ne__")
Loading