Skip to content

Commit

Permalink
Backport PR #55364 on branch 2.1.x (BUG: eq not implemented for categ…
Browse files Browse the repository at this point in the history
…orical and arrow backed strings) (#55381)

Backport PR #55364: BUG: eq not implemented for categorical and arrow backed strings

Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and phofl committed Oct 3, 2023
1 parent 2201408 commit 0191caf
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Fixed regressions

Bug fixes
~~~~~~~~~
- Fixed bug in :meth:`Categorical.equals` if other has arrow backed string dtype (:issue:`55364`)
- Fixed bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax` raising for arrow dtypes (:issue:`55368`)
- Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`)
-
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from pandas.core.dtypes.cast import infer_dtype_from_scalar
from pandas.core.dtypes.common import (
CategoricalDtype,
is_array_like,
is_bool_dtype,
is_integer,
Expand Down Expand Up @@ -628,7 +629,9 @@ def __setstate__(self, state) -> None:

def _cmp_method(self, other, op):
pc_func = ARROW_CMP_FUNCS[op.__name__]
if isinstance(other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)):
if isinstance(
other, (ArrowExtensionArray, np.ndarray, list, BaseMaskedArray)
) or isinstance(getattr(other, "dtype", None), CategoricalDtype):
result = pc_func(self._pa_array, self._box_pa(other))
elif is_scalar(other):
try:
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/indexes/categorical/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,9 @@ def test_equals_multiindex(self):
ci = mi.to_flat_index().astype("category")

assert not ci.equals(mi)

def test_equals_string_dtype(self, any_string_dtype):
# GH#55364
idx = CategoricalIndex(list("abc"), name="B")
other = Index(["a", "b", "c"], name="B", dtype=any_string_dtype)
assert idx.equals(other)

0 comments on commit 0191caf

Please sign in to comment.