From e8eb0f29c2744aaa51ecb71ae10b06ed4e9f56d0 Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 31 Jul 2023 20:56:16 -0700 Subject: [PATCH] REF: don patch assert_series_equal, assert_equal --- pandas/tests/extension/test_arrow.py | 41 ++++++++++++++++++---------- pandas/tests/extension/test_numpy.py | 17 ++---------- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index eea79a789148e..311727b124df1 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -873,21 +873,6 @@ def test_basic_equals(self, data): class TestBaseArithmeticOps(base.BaseArithmeticOpsTests): divmod_exc = NotImplementedError - @classmethod - def assert_equal(cls, left, right, **kwargs): - if isinstance(left, pd.DataFrame): - left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype - right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype - else: - left_pa_type = left.dtype.pyarrow_dtype - right_pa_type = right.dtype.pyarrow_dtype - if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type): - # decimal precision can resize in the result type depending on data - # just compare the float values - left = left.astype("float[pyarrow]") - right = right.astype("float[pyarrow]") - tm.assert_equal(left, right, **kwargs) - def get_op_from_name(self, op_name): short_opname = op_name.strip("_") if short_opname == "rtruediv": @@ -934,6 +919,29 @@ def _patch_combine(self, obj, other, op): unit = "us" pa_expected = pa_expected.cast(f"duration[{unit}]") + + elif pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal( + original_dtype.pyarrow_dtype + ): + # decimal precision can resize in the result type depending on data + # just compare the float values + alt = op(obj, other) + alt_dtype = tm.get_dtype(alt) + assert isinstance(alt_dtype, ArrowDtype) + if op is operator.pow and isinstance(other, Decimal): + # TODO: would it make more sense to retain Decimal here? + alt_dtype = ArrowDtype(pa.float64()) + elif ( + op is operator.pow + and isinstance(other, pd.Series) + and other.dtype == original_dtype + ): + # TODO: would it make more sense to retain Decimal here? + alt_dtype = ArrowDtype(pa.float64()) + else: + assert pa.types.is_decimal(alt_dtype.pyarrow_dtype) + return expected.astype(alt_dtype) + else: pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype) @@ -1075,6 +1083,7 @@ def test_arith_series_with_scalar( or pa.types.is_duration(pa_dtype) or pa.types.is_timestamp(pa_dtype) or pa.types.is_date(pa_dtype) + or pa.types.is_decimal(pa_dtype) ): # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does # not upcast @@ -1107,6 +1116,7 @@ def test_arith_frame_with_scalar( or pa.types.is_duration(pa_dtype) or pa.types.is_timestamp(pa_dtype) or pa.types.is_date(pa_dtype) + or pa.types.is_decimal(pa_dtype) ): # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does # not upcast @@ -1160,6 +1170,7 @@ def test_arith_series_with_array( or pa.types.is_duration(pa_dtype) or pa.types.is_timestamp(pa_dtype) or pa.types.is_date(pa_dtype) + or pa.types.is_decimal(pa_dtype) ): monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine) self.check_opname(ser, op_name, other, exc=self.series_array_exc) diff --git a/pandas/tests/extension/test_numpy.py b/pandas/tests/extension/test_numpy.py index 0d4624087fffd..993a08d7bd369 100644 --- a/pandas/tests/extension/test_numpy.py +++ b/pandas/tests/extension/test_numpy.py @@ -19,10 +19,7 @@ import pytest from pandas.core.dtypes.cast import can_hold_element -from pandas.core.dtypes.dtypes import ( - ExtensionDtype, - NumpyEADtype, -) +from pandas.core.dtypes.dtypes import NumpyEADtype import pandas as pd import pandas._testing as tm @@ -176,17 +173,7 @@ def skip_numpy_object(dtype, request): class BaseNumPyTests: - @classmethod - def assert_series_equal(cls, left, right, *args, **kwargs): - # base class tests hard-code expected values with numpy dtypes, - # whereas we generally want the corresponding NumpyEADtype - if ( - isinstance(right, pd.Series) - and not isinstance(right.dtype, ExtensionDtype) - and isinstance(left.dtype, NumpyEADtype) - ): - right = right.astype(NumpyEADtype(right.dtype)) - return tm.assert_series_equal(left, right, *args, **kwargs) + pass class TestCasting(BaseNumPyTests, base.BaseCastingTests):