Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,21 +873,6 @@ def test_basic_equals(self, data):
class TestBaseArithmeticOps(base.BaseArithmeticOpsTests):
divmod_exc = NotImplementedError

@classmethod
def assert_equal(cls, left, right, **kwargs):
if isinstance(left, pd.DataFrame):
left_pa_type = left.iloc[:, 0].dtype.pyarrow_dtype
right_pa_type = right.iloc[:, 0].dtype.pyarrow_dtype
else:
left_pa_type = left.dtype.pyarrow_dtype
right_pa_type = right.dtype.pyarrow_dtype
if pa.types.is_decimal(left_pa_type) or pa.types.is_decimal(right_pa_type):
# decimal precision can resize in the result type depending on data
# just compare the float values
left = left.astype("float[pyarrow]")
right = right.astype("float[pyarrow]")
tm.assert_equal(left, right, **kwargs)

def get_op_from_name(self, op_name):
short_opname = op_name.strip("_")
if short_opname == "rtruediv":
Expand Down Expand Up @@ -934,6 +919,29 @@ def _patch_combine(self, obj, other, op):
unit = "us"

pa_expected = pa_expected.cast(f"duration[{unit}]")

elif pa.types.is_decimal(pa_expected.type) and pa.types.is_decimal(
original_dtype.pyarrow_dtype
):
# decimal precision can resize in the result type depending on data
# just compare the float values
alt = op(obj, other)
alt_dtype = tm.get_dtype(alt)
assert isinstance(alt_dtype, ArrowDtype)
if op is operator.pow and isinstance(other, Decimal):
# TODO: would it make more sense to retain Decimal here?
alt_dtype = ArrowDtype(pa.float64())
elif (
op is operator.pow
and isinstance(other, pd.Series)
and other.dtype == original_dtype
):
# TODO: would it make more sense to retain Decimal here?
alt_dtype = ArrowDtype(pa.float64())
else:
assert pa.types.is_decimal(alt_dtype.pyarrow_dtype)
return expected.astype(alt_dtype)

else:
pa_expected = pa_expected.cast(original_dtype.pyarrow_dtype)

Expand Down Expand Up @@ -1075,6 +1083,7 @@ def test_arith_series_with_scalar(
or pa.types.is_duration(pa_dtype)
or pa.types.is_timestamp(pa_dtype)
or pa.types.is_date(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
# not upcast
Expand Down Expand Up @@ -1107,6 +1116,7 @@ def test_arith_frame_with_scalar(
or pa.types.is_duration(pa_dtype)
or pa.types.is_timestamp(pa_dtype)
or pa.types.is_date(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
# BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
# not upcast
Expand Down Expand Up @@ -1160,6 +1170,7 @@ def test_arith_series_with_array(
or pa.types.is_duration(pa_dtype)
or pa.types.is_timestamp(pa_dtype)
or pa.types.is_date(pa_dtype)
or pa.types.is_decimal(pa_dtype)
):
monkeypatch.setattr(TestBaseArithmeticOps, "_combine", self._patch_combine)
self.check_opname(ser, op_name, other, exc=self.series_array_exc)
Expand Down
17 changes: 2 additions & 15 deletions pandas/tests/extension/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
import pytest

from pandas.core.dtypes.cast import can_hold_element
from pandas.core.dtypes.dtypes import (
ExtensionDtype,
NumpyEADtype,
)
from pandas.core.dtypes.dtypes import NumpyEADtype

import pandas as pd
import pandas._testing as tm
Expand Down Expand Up @@ -176,17 +173,7 @@ def skip_numpy_object(dtype, request):


class BaseNumPyTests:
@classmethod
def assert_series_equal(cls, left, right, *args, **kwargs):
# base class tests hard-code expected values with numpy dtypes,
# whereas we generally want the corresponding NumpyEADtype
if (
isinstance(right, pd.Series)
and not isinstance(right.dtype, ExtensionDtype)
and isinstance(left.dtype, NumpyEADtype)
):
right = right.astype(NumpyEADtype(right.dtype))
return tm.assert_series_equal(left, right, *args, **kwargs)
pass


class TestCasting(BaseNumPyTests, base.BaseCastingTests):
Expand Down