@@ -873,21 +873,6 @@ def test_basic_equals(self, data):
873873class TestBaseArithmeticOps (base .BaseArithmeticOpsTests ):
874874 divmod_exc = NotImplementedError
875875
876- @classmethod
877- def assert_equal (cls , left , right , ** kwargs ):
878- if isinstance (left , pd .DataFrame ):
879- left_pa_type = left .iloc [:, 0 ].dtype .pyarrow_dtype
880- right_pa_type = right .iloc [:, 0 ].dtype .pyarrow_dtype
881- else :
882- left_pa_type = left .dtype .pyarrow_dtype
883- right_pa_type = right .dtype .pyarrow_dtype
884- if pa .types .is_decimal (left_pa_type ) or pa .types .is_decimal (right_pa_type ):
885- # decimal precision can resize in the result type depending on data
886- # just compare the float values
887- left = left .astype ("float[pyarrow]" )
888- right = right .astype ("float[pyarrow]" )
889- tm .assert_equal (left , right , ** kwargs )
890-
891876 def get_op_from_name (self , op_name ):
892877 short_opname = op_name .strip ("_" )
893878 if short_opname == "rtruediv" :
@@ -934,6 +919,29 @@ def _patch_combine(self, obj, other, op):
934919 unit = "us"
935920
936921 pa_expected = pa_expected .cast (f"duration[{ unit } ]" )
922+
923+ elif pa .types .is_decimal (pa_expected .type ) and pa .types .is_decimal (
924+ original_dtype .pyarrow_dtype
925+ ):
926+ # decimal precision can resize in the result type depending on data
927+ # just compare the float values
928+ alt = op (obj , other )
929+ alt_dtype = tm .get_dtype (alt )
930+ assert isinstance (alt_dtype , ArrowDtype )
931+ if op is operator .pow and isinstance (other , Decimal ):
932+ # TODO: would it make more sense to retain Decimal here?
933+ alt_dtype = ArrowDtype (pa .float64 ())
934+ elif (
935+ op is operator .pow
936+ and isinstance (other , pd .Series )
937+ and other .dtype == original_dtype
938+ ):
939+ # TODO: would it make more sense to retain Decimal here?
940+ alt_dtype = ArrowDtype (pa .float64 ())
941+ else :
942+ assert pa .types .is_decimal (alt_dtype .pyarrow_dtype )
943+ return expected .astype (alt_dtype )
944+
937945 else :
938946 pa_expected = pa_expected .cast (original_dtype .pyarrow_dtype )
939947
@@ -1075,6 +1083,7 @@ def test_arith_series_with_scalar(
10751083 or pa .types .is_duration (pa_dtype )
10761084 or pa .types .is_timestamp (pa_dtype )
10771085 or pa .types .is_date (pa_dtype )
1086+ or pa .types .is_decimal (pa_dtype )
10781087 ):
10791088 # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
10801089 # not upcast
@@ -1107,6 +1116,7 @@ def test_arith_frame_with_scalar(
11071116 or pa .types .is_duration (pa_dtype )
11081117 or pa .types .is_timestamp (pa_dtype )
11091118 or pa .types .is_date (pa_dtype )
1119+ or pa .types .is_decimal (pa_dtype )
11101120 ):
11111121 # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
11121122 # not upcast
@@ -1160,6 +1170,7 @@ def test_arith_series_with_array(
11601170 or pa .types .is_duration (pa_dtype )
11611171 or pa .types .is_timestamp (pa_dtype )
11621172 or pa .types .is_date (pa_dtype )
1173+ or pa .types .is_decimal (pa_dtype )
11631174 ):
11641175 monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
11651176 self .check_opname (ser , op_name , other , exc = self .series_array_exc )
0 commit comments