Skip to content

Commit

Permalink
Backport PR #55384 on branch 2.1.x (BUG: idxmax raising for arrow str…
Browse files Browse the repository at this point in the history
…ings) (#55531)

BUG: idxmax raising for arrow strings (#55384)

(cherry picked from commit 68e3c4b)
  • Loading branch information
phofl committed Oct 15, 2023
1 parent 5933c60 commit 0c17c96
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
11 changes: 10 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,6 +1596,15 @@ def _reduce(
------
TypeError : subclass does not define reductions
"""
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if isinstance(result, pa.Array):
return type(self)(result)
else:
return result

def _reduce_calc(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)

if keepdims:
Expand All @@ -1606,7 +1615,7 @@ def _reduce(
[pa_result],
type=to_pyarrow_type(infer_dtype_from_scalar(pa_result)[0]),
)
return type(self)(result)
return result

if pc.is_null(pa_result).as_py():
return self.dtype.na_value
Expand Down
11 changes: 11 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,17 @@ def _str_rstrip(self, to_strip=None):
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)

def _reduce(
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
):
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_dtype(result)
elif isinstance(result, pa.Array):
return type(self)(result)
else:
return result

def _convert_int_dtype(self, result):
return Int64Dtype().__from_arrow__(result)

Expand Down
9 changes: 9 additions & 0 deletions pandas/tests/frame/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,15 @@ def test_idxmax_arrow_types(self):
expected = Series([2, 1], index=["a", "b"])
tm.assert_series_equal(result, expected)

df = DataFrame({"a": ["b", "c", "a"]}, dtype="string[pyarrow]")
result = df.idxmax(numeric_only=False)
expected = Series([1], index=["a"])
tm.assert_series_equal(result, expected)

result = df.idxmin(numeric_only=False)
expected = Series([2], index=["a"])
tm.assert_series_equal(result, expected)

def test_idxmax_axis_2(self, float_frame):
frame = float_frame
msg = "No axis named 2 for object type DataFrame"
Expand Down

0 comments on commit 0c17c96

Please sign in to comment.