Skip to content

Commit

Permalink
REF: dont pass exception to check_opname (#54365)
Browse files Browse the repository at this point in the history
* REF: dont pass exception to check_opname

* future imports

* privatize

* REF: update pattern for check_divmod_op

* typo fixup

* suggested edit

* mypy fixup

* lint fixup
  • Loading branch information
jbrockmendel committed Aug 4, 2023
1 parent 92d1d6a commit a3d6c36
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 130 deletions.
65 changes: 49 additions & 16 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@


class BaseOpsUtil(BaseExtensionTests):
series_scalar_exc: type[Exception] | None = TypeError
frame_scalar_exc: type[Exception] | None = TypeError
series_array_exc: type[Exception] | None = TypeError
divmod_exc: type[Exception] | None = TypeError

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
# Find the Exception, if any we expect to raise calling
# obj.__op_name__(other)

# The self.obj_bar_exc pattern isn't great in part because it can depend
# on op_name or dtypes, but we use it here for backward-compatibility.
if op_name in ["__divmod__", "__rdivmod__"]:
return self.divmod_exc
if isinstance(obj, pd.Series) and isinstance(other, pd.Series):
return self.series_array_exc
elif isinstance(obj, pd.Series):
return self.series_scalar_exc
else:
return self.frame_scalar_exc

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
# In _check_op we check that the result of a pointwise operation
# (found via _combine) matches the result of the vectorized
Expand All @@ -24,17 +46,21 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
def get_op_from_name(self, op_name: str):
return tm.get_op_from_name(op_name)

def check_opname(self, ser: pd.Series, op_name: str, other, exc=Exception):
op = self.get_op_from_name(op_name)

self._check_op(ser, op, other, op_name, exc)

# Subclasses are not expected to need to override _check_op or _combine.
# Subclasses are not expected to need to override check_opname, _check_op,
# _check_divmod_op, or _combine.
# Ideally any relevant overriding can be done in _cast_pointwise_result,
# get_op_from_name, and the specification of `exc`. If you find a use
# case that still requires overriding _check_op or _combine, please let
# us know at github.com/pandas-dev/pandas/issues
@final
def check_opname(self, ser: pd.Series, op_name: str, other):
exc = self._get_expected_exception(op_name, ser, other)
op = self.get_op_from_name(op_name)

self._check_op(ser, op, other, op_name, exc)

# see comment on check_opname
@final
def _combine(self, obj, other, op):
if isinstance(obj, pd.DataFrame):
if len(obj.columns) != 1:
Expand All @@ -44,11 +70,14 @@ def _combine(self, obj, other, op):
expected = obj.combine(other, op)
return expected

# see comment on _combine
# see comment on check_opname
@final
def _check_op(
self, ser: pd.Series, op, other, op_name: str, exc=NotImplementedError
):
# Check that the Series/DataFrame arithmetic/comparison method matches
# the pointwise result from _combine.

if exc is None:
result = op(ser, other)
expected = self._combine(ser, other, op)
Expand All @@ -59,8 +88,14 @@ def _check_op(
with pytest.raises(exc):
op(ser, other)

def _check_divmod_op(self, ser: pd.Series, op, other, exc=Exception):
# divmod has multiple return values, so check separately
# see comment on check_opname
@final
def _check_divmod_op(self, ser: pd.Series, op, other):
# check that divmod behavior matches behavior of floordiv+mod
if op is divmod:
exc = self._get_expected_exception("__divmod__", ser, other)
else:
exc = self._get_expected_exception("__rdivmod__", ser, other)
if exc is None:
result_div, result_mod = op(ser, other)
if op is divmod:
Expand Down Expand Up @@ -96,26 +131,24 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
# series & scalar
op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(ser, op_name, ser.iloc[0], exc=self.series_scalar_exc)
self.check_opname(ser, op_name, ser.iloc[0])

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
op_name = all_arithmetic_operators
df = pd.DataFrame({"A": data})
self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc)
self.check_opname(df, op_name, data[0])

def test_arith_series_with_array(self, data, all_arithmetic_operators):
# ndarray & other series
op_name = all_arithmetic_operators
ser = pd.Series(data)
self.check_opname(
ser, op_name, pd.Series([ser.iloc[0]] * len(ser)), exc=self.series_array_exc
)
self.check_opname(ser, op_name, pd.Series([ser.iloc[0]] * len(ser)))

def test_divmod(self, data):
ser = pd.Series(data)
self._check_divmod_op(ser, divmod, 1, exc=self.divmod_exc)
self._check_divmod_op(1, ops.rdivmod, ser, exc=self.divmod_exc)
self._check_divmod_op(ser, divmod, 1)
self._check_divmod_op(1, ops.rdivmod, ser)

def test_divmod_series_array(self, data, data_for_twos):
ser = pd.Series(data)
Expand Down
16 changes: 10 additions & 6 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import decimal
import operator

Expand Down Expand Up @@ -311,8 +313,14 @@ def test_astype_dispatches(frame):


class TestArithmeticOps(base.BaseArithmeticOpsTests):
def check_opname(self, s, op_name, other, exc=None):
super().check_opname(s, op_name, other, exc=None)
series_scalar_exc = None
frame_scalar_exc = None
series_array_exc = None

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
return None

def test_arith_series_with_array(self, data, all_arithmetic_operators):
op_name = all_arithmetic_operators
Expand All @@ -336,10 +344,6 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators):
context.traps[decimal.DivisionByZero] = divbyzerotrap
context.traps[decimal.InvalidOperation] = invalidoptrap

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
# We implement divmod
super()._check_divmod_op(s, op, other, exc=None)


class TestComparisonOps(base.BaseComparisonOpsTests):
def test_compare_scalar(self, data, comparison_op):
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,6 @@ def test_divmod_series_array(self):
# skipping because it is not implemented
super().test_divmod_series_array()

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
return super()._check_divmod_op(s, op, other, exc=TypeError)


class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):
pass
Expand Down
34 changes: 17 additions & 17 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""
from __future__ import annotations

from datetime import (
date,
datetime,
Expand Down Expand Up @@ -964,16 +966,26 @@ def _is_temporal_supported(self, opname, pa_dtype):
and pa.types.is_temporal(pa_dtype)
)

def _get_scalar_exception(self, opname, pa_dtype):
arrow_temporal_supported = self._is_temporal_supported(opname, pa_dtype)
if opname in {
def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
if op_name in ("__divmod__", "__rdivmod__"):
return self.divmod_exc

dtype = tm.get_dtype(obj)
# error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
# attribute "pyarrow_dtype"
pa_dtype = dtype.pyarrow_dtype # type: ignore[union-attr]

arrow_temporal_supported = self._is_temporal_supported(op_name, pa_dtype)
if op_name in {
"__mod__",
"__rmod__",
}:
exc = NotImplementedError
elif arrow_temporal_supported:
exc = None
elif opname in ["__add__", "__radd__"] and (
elif op_name in ["__add__", "__radd__"] and (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
):
exc = None
Expand Down Expand Up @@ -1060,10 +1072,6 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators, request)
):
pytest.skip("Skip testing Python string formatting")

self.series_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)
Expand All @@ -1078,10 +1086,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
):
pytest.skip("Skip testing Python string formatting")

self.frame_scalar_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

mark = self._get_arith_xfail_marker(all_arithmetic_operators, pa_dtype)
if mark is not None:
request.node.add_marker(mark)
Expand All @@ -1091,10 +1095,6 @@ def test_arith_frame_with_scalar(self, data, all_arithmetic_operators, request):
def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
pa_dtype = data.dtype.pyarrow_dtype

self.series_array_exc = self._get_scalar_exception(
all_arithmetic_operators, pa_dtype
)

if (
all_arithmetic_operators
in (
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def test_arith_series_with_array(self, data, all_arithmetic_operators, request):
# since ser.iloc[0] is a python scalar
other = pd.Series(pd.array([ser.iloc[0]] * len(ser), dtype=data.dtype))

self.check_opname(ser, op_name, other, exc=self.series_array_exc)
self.check_opname(ser, op_name, other)

def test_add_series_with_extension_array(self, data, request):
pa_dtype = data.dtype.pyarrow_dtype
Expand Down
22 changes: 5 additions & 17 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,14 @@ class TestMissing(base.BaseMissingTests):
class TestArithmeticOps(base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
exc = None
def _get_expected_exception(self, op_name, obj, other):
if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
# match behavior with non-masked bool dtype
exc = NotImplementedError
return NotImplementedError
elif op_name in self.implements:
# exception message would include "numpy boolean subtract""
exc = TypeError

super().check_opname(s, op_name, other, exc=exc)
return TypeError
return None

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
if op_name in (
Expand Down Expand Up @@ -170,18 +167,9 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
def test_divmod_series_array(self, data, data_for_twos):
super().test_divmod_series_array(data, data_for_twos)

@pytest.mark.xfail(
reason="Inconsistency between floordiv and divmod; we raise for floordiv "
"but not for divmod. This matches what we do for non-masked bool dtype."
)
def test_divmod(self, data):
super().test_divmod(data)


class TestComparisonOps(base.BaseComparisonOpsTests):
def check_opname(self, s, op_name, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(s, op_name, other, exc=None)
pass


class TestReshaping(base.BaseReshapingTests):
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,6 @@ def test_divmod_series_array(self):
# skipping because it is not implemented
pass

def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
return super()._check_divmod_op(s, op, other, exc=TypeError)


class TestComparisonOps(base.BaseComparisonOpsTests):
def _compare_other(self, s, data, op, other):
Expand Down
28 changes: 4 additions & 24 deletions pandas/tests/extension/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,10 @@ class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
implements = {"__sub__", "__rsub__"}

def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
# frame & scalar
if all_arithmetic_operators in self.implements:
df = pd.DataFrame({"A": data})
self.check_opname(df, all_arithmetic_operators, data[0], exc=None)
else:
# ... but not the rest.
super().test_arith_frame_with_scalar(data, all_arithmetic_operators)

def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
ser = pd.Series(data)
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)
def _get_expected_exception(self, op_name, obj, other):
if op_name in self.implements:
return None
return super()._get_expected_exception(op_name, obj, other)

def test_add_series_with_extension_array(self, data):
# Datetime + Datetime not implemented
Expand All @@ -154,14 +142,6 @@ def test_add_series_with_extension_array(self, data):
with pytest.raises(TypeError, match=msg):
ser + data

def test_arith_series_with_array(self, data, all_arithmetic_operators):
if all_arithmetic_operators in self.implements:
ser = pd.Series(data)
self.check_opname(ser, all_arithmetic_operators, ser.iloc[0], exc=None)
else:
# ... but not the rest.
super().test_arith_series_with_scalar(data, all_arithmetic_operators)

def test_divmod_series_array(self):
# GH 23287
# skipping because it is not implemented
Expand Down
17 changes: 8 additions & 9 deletions pandas/tests/extension/test_masked_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,20 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
expected = expected.astype(sdtype)
return expected

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
# overwriting to indicate ops don't raise an error
super().check_opname(ser, op_name, other, exc=None)

def _check_divmod_op(self, ser: pd.Series, op, other, exc=None):
super()._check_divmod_op(ser, op, other, None)
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None
divmod_exc = None


class TestComparisonOps(base.BaseComparisonOpsTests):
series_scalar_exc = None
series_array_exc = None
frame_scalar_exc = None

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
return pointwise_result.astype("boolean")

def check_opname(self, ser: pd.Series, op_name: str, other, exc=None):
super().check_opname(ser, op_name, other, exc=None)

def _compare_other(self, ser: pd.Series, data, op, other):
op_name = f"__{op.__name__}__"
self.check_opname(ser, op_name, other)
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_divmod(self, data):
@skip_nested
def test_divmod_series_array(self, data):
ser = pd.Series(data)
self._check_divmod_op(ser, divmod, data, exc=None)
self._check_divmod_op(ser, divmod, data)

@skip_nested
def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
Expand Down

0 comments on commit a3d6c36

Please sign in to comment.