Skip to content

Commit

Permalink
Backport PR #55362 on branch 2.1.x (BUG: rank raising for arrow strin…
Browse files Browse the repository at this point in the history
…g dtypes) (#55406)
  • Loading branch information
phofl committed Oct 15, 2023
1 parent 2a32088 commit 5933c60
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Bug fixes
- Fixed bug in :meth:`DataFrame.interpolate` raising incorrect error message (:issue:`55347`)
- Fixed bug in :meth:`Index.insert` raising when inserting ``None`` into :class:`Index` with ``dtype="string[pyarrow_numpy]"`` (:issue:`55365`)
- Fixed bug in :meth:`Series.all` and :meth:`Series.any` not treating missing values correctly for ``dtype="string[pyarrow_numpy]"`` (:issue:`55367`)
- Fixed bug in :meth:`Series.rank` for ``string[pyarrow_numpy]`` dtype (:issue:`55362`)
- Silence ``Period[B]`` warnings introduced by :issue:`53446` during normal plotting activity (:issue:`55138`)
-

Expand Down
31 changes: 25 additions & 6 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1712,7 +1712,7 @@ def __setitem__(self, key, value) -> None:
data = pa.chunked_array([data])
self._pa_array = data

def _rank(
def _rank_calc(
self,
*,
axis: AxisInt = 0,
Expand All @@ -1721,9 +1721,6 @@ def _rank(
ascending: bool = True,
pct: bool = False,
):
"""
See Series.rank.__doc__.
"""
if pa_version_under9p0 or axis != 0:
ranked = super()._rank(
axis=axis,
Expand All @@ -1738,7 +1735,7 @@ def _rank(
else:
pa_type = pa.uint64()
result = pa.array(ranked, type=pa_type, from_pandas=True)
return type(self)(result)
return result

data = self._pa_array.combine_chunks()
sort_keys = "ascending" if ascending else "descending"
Expand Down Expand Up @@ -1777,7 +1774,29 @@ def _rank(
divisor = pc.count(result)
result = pc.divide(result, divisor)

return type(self)(result)
return result

def _rank(
self,
*,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
pct: bool = False,
):
"""
See Series.rank.__doc__.
"""
return type(self)(
self._rank_calc(
axis=axis,
method=method,
na_option=na_option,
ascending=ascending,
pct=pct,
)
)

def _quantile(self, qs: npt.NDArray[np.float64], interpolation: str) -> Self:
"""
Expand Down
30 changes: 30 additions & 0 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

if TYPE_CHECKING:
from pandas._typing import (
AxisInt,
Dtype,
Scalar,
npt,
Expand Down Expand Up @@ -444,6 +445,31 @@ def _str_rstrip(self, to_strip=None):
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)

def _convert_int_dtype(self, result):
return Int64Dtype().__from_arrow__(result)

def _rank(
self,
*,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
pct: bool = False,
):
"""
See Series.rank.__doc__.
"""
return self._convert_int_dtype(
self._rank_calc(
axis=axis,
method=method,
na_option=na_option,
ascending=ascending,
pct=pct,
)
)


class ArrowStringArrayNumpySemantics(ArrowStringArray):
_storage = "pyarrow_numpy"
Expand Down Expand Up @@ -527,6 +553,10 @@ def _str_map(
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _convert_int_dtype(self, result):
if isinstance(result, pa.Array):
result = result.to_numpy(zero_copy_only=False)
elif not isinstance(result, np.ndarray):
result = result.to_numpy()
if result.dtype == np.int32:
result = result.astype(np.int64)
return result
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/frame/methods/test_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,15 @@ def test_rank_mixed_axis_zero(self, data, expected):
df.rank()
result = df.rank(numeric_only=True)
tm.assert_frame_equal(result, expected)

@pytest.mark.parametrize(
"dtype, exp_dtype",
[("string[pyarrow]", "Int64"), ("string[pyarrow_numpy]", "float64")],
)
def test_rank_string_dtype(self, dtype, exp_dtype):
# GH#55362
pytest.importorskip("pyarrow")
obj = Series(["foo", "foo", None, "foo"], dtype=dtype)
result = obj.rank(method="first")
expected = Series([1, 2, None, 3], dtype=exp_dtype)
tm.assert_series_equal(result, expected)

0 comments on commit 5933c60

Please sign in to comment.