Skip to content

Commit

Permalink
Backport PR #52633 on branch 2.0.x (BUG: Logical and comparison ops w…
Browse files Browse the repository at this point in the history
…ith ArrowDtype & masked) (#52767)

* Backport PR #52633: BUG: Logical and comparison ops with ArrowDtype & masked

* Make runtime import
  • Loading branch information
mroeschke committed Apr 19, 2023
1 parent c10de3a commit 5a95983
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Bug fixes
- Bug in :meth:`DataFrame.max` and related casting different :class:`Timestamp` resolutions always to nanoseconds (:issue:`52524`)
- Bug in :meth:`Series.describe` not returning :class:`ArrowDtype` with ``pyarrow.float64`` type with numeric data (:issue:`52427`)
- Bug in :meth:`Series.dt.tz_localize` incorrectly localizing timestamps with :class:`ArrowDtype` (:issue:`52677`)
- Bug in logical and comparison operations between :class:`ArrowDtype` and numpy masked types (e.g. ``"boolean"``) (:issue:`52625`)
- Fixed bug in :func:`merge` when merging with ``ArrowDtype`` on one side and a NumPy dtype on the other side (:issue:`52406`)
- Fixed segfault in :meth:`Series.to_numpy` with ``null[pyarrow]`` dtype (:issue:`52443`)

Expand Down
10 changes: 10 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,11 +434,16 @@ def __setstate__(self, state) -> None:
self.__dict__.update(state)

def _cmp_method(self, other, op):
from pandas.core.arrays.masked import BaseMaskedArray

pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, ArrowExtensionArray):
result = pc_func(self._data, other._data)
elif isinstance(other, (np.ndarray, list)):
result = pc_func(self._data, other)
elif isinstance(other, BaseMaskedArray):
# GH 52625
result = pc_func(self._data, other.__arrow_array__())
elif is_scalar(other):
try:
result = pc_func(self._data, pa.scalar(other))
Expand All @@ -456,6 +461,8 @@ def _cmp_method(self, other, op):
return ArrowExtensionArray(result)

def _evaluate_op_method(self, other, op, arrow_funcs):
from pandas.core.arrays.masked import BaseMaskedArray

pa_type = self._data.type
if (pa.types.is_string(pa_type) or pa.types.is_binary(pa_type)) and op in [
operator.add,
Expand Down Expand Up @@ -486,6 +493,9 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
result = pc_func(self._data, other._data)
elif isinstance(other, (np.ndarray, list)):
result = pc_func(self._data, pa.array(other, from_pandas=True))
elif isinstance(other, BaseMaskedArray):
# GH 52625
result = pc_func(self._data, other.__arrow_array__())
elif is_scalar(other):
if isna(other) and op.__name__ in ARROW_LOGICAL_FUNCS:
# pyarrow kleene ops require null to be typed
Expand Down
34 changes: 33 additions & 1 deletion pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BytesIO,
StringIO,
)
import operator
import pickle
import re

Expand Down Expand Up @@ -1218,7 +1219,7 @@ def test_add_series_with_extension_array(self, data, request):


class TestBaseComparisonOps(base.BaseComparisonOpsTests):
def test_compare_array(self, data, comparison_op, na_value, request):
def test_compare_array(self, data, comparison_op, na_value):
ser = pd.Series(data)
# pd.Series([ser.iloc[0]] * len(ser)) may not return ArrowExtensionArray
# since ser.iloc[0] is a python scalar
Expand Down Expand Up @@ -1257,6 +1258,20 @@ def test_invalid_other_comp(self, data, comparison_op):
):
comparison_op(data, object())

@pytest.mark.parametrize("masked_dtype", ["boolean", "Int64", "Float64"])
def test_comp_masked_numpy(self, masked_dtype, comparison_op):
    """Compare a pyarrow-backed Series against a masked Series (GH 52625).

    Both operands are built from the same values, one with the masked
    dtype named by ``masked_dtype`` and one with the matching
    ``"<name>[pyarrow]"`` dtype. The result must be a pyarrow boolean
    Series whose last (missing) entry stays missing.
    """
    values = [1, 0, None]
    masked_ser = pd.Series(values, dtype=masked_dtype)
    arrow_ser = pd.Series(values, dtype=f"{masked_dtype.lower()}[pyarrow]")

    result = comparison_op(arrow_ser, masked_ser)

    # Each element is compared with itself: lt/gt/ne are False there,
    # every other comparison operator is True.
    same_value_outcome = comparison_op not in [
        operator.lt,
        operator.gt,
        operator.ne,
    ]
    expected = pd.Series(
        [same_value_outcome, same_value_outcome, None],
        dtype=ArrowDtype(pa.bool_()),
    )
    tm.assert_series_equal(result, expected)


class TestLogicalOps:
"""Various Series and DataFrame logical ops methods."""
Expand Down Expand Up @@ -1401,6 +1416,23 @@ def test_kleene_xor_scalar(self, other, expected):
a, pd.Series([True, False, None], dtype="boolean[pyarrow]")
)

@pytest.mark.parametrize(
    "op, exp",
    [
        ("__and__", True),
        ("__or__", True),
        ("__xor__", False),
    ],
)
def test_logical_masked_numpy(self, op, exp):
    """Apply a logical dunder op between pyarrow and masked booleans (GH 52625).

    ``op`` names the method looked up on the pyarrow-backed Series and
    ``exp`` is the expected outcome for the first element (True with True).
    The second element (False with False) is False for all three ops and
    the missing third element stays missing.
    """
    values = [True, False, None]
    masked_ser = pd.Series(values, dtype="boolean")
    arrow_ser = pd.Series(values, dtype="boolean[pyarrow]")

    logical_method = getattr(arrow_ser, op)
    result = logical_method(masked_ser)

    expected = pd.Series(
        [exp, False, None],
        dtype=ArrowDtype(pa.bool_()),
    )
    tm.assert_series_equal(result, expected)


def test_arrowdtype_construct_from_string_type_with_unsupported_parameters():
with pytest.raises(NotImplementedError, match="Passing pyarrow type"):
Expand Down

0 comments on commit 5a95983

Please sign in to comment.