Skip to content

Commit

Permalink
Backport PR #52499 on branch 2.0.x (ENH: Implement str.r/split for Ar…
Browse files Browse the repository at this point in the history
…rowDtype) (#52603)
  • Loading branch information
meeseeksmachine committed Apr 11, 2023
1 parent d3902a7 commit 99bcfe3
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 29 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Bug fixes

Other
~~~~~
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)

Expand Down
24 changes: 16 additions & 8 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1929,16 +1929,24 @@ def _str_rfind(self, sub, start: int = 0, end=None):
)

def _str_split(
self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
self,
pat: str | None = None,
n: int | None = -1,
expand: bool = False,
regex: bool | None = None,
):
raise NotImplementedError(
"str.split not supported with pd.ArrowDtype(pa.string())."
)
if n in {-1, 0}:
n = None
if regex:
split_func = pc.split_pattern_regex
else:
split_func = pc.split_pattern
return type(self)(split_func(self._data, pat, max_splits=n))

def _str_rsplit(self, pat=None, n=-1):
raise NotImplementedError(
"str.rsplit not supported with pd.ArrowDtype(pa.string())."
)
def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
if n in {-1, 0}:
n = None
return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True))

def _str_translate(self, table):
raise NotImplementedError(
Expand Down
51 changes: 32 additions & 19 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays.arrow.dtype import ArrowDtype
from pandas.core.base import NoNewAttributesMixin
from pandas.core.construction import extract_array

Expand Down Expand Up @@ -267,27 +268,39 @@ def _wrap_result(
# infer from ndim if expand is not specified
expand = result.ndim != 1

elif (
expand is True
and is_object_dtype(result)
and not isinstance(self._orig, ABCIndex)
):
elif expand is True and not isinstance(self._orig, ABCIndex):
# required when expand=True is explicitly specified
# not needed when inferred

def cons_row(x):
if is_list_like(x):
return x
else:
return [x]

result = [cons_row(x) for x in result]
if result and not self._is_string:
# propagate nan values to match longest sequence (GH 18450)
max_len = max(len(x) for x in result)
result = [
x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result
]
if isinstance(result.dtype, ArrowDtype):
import pyarrow as pa

from pandas.core.arrays.arrow.array import ArrowExtensionArray

max_len = pa.compute.max(
result._data.combine_chunks().value_lengths()
).as_py()
if result.isna().any():
result._data = result._data.fill_null([None] * max_len)
result = {
i: ArrowExtensionArray(pa.array(res))
for i, res in enumerate(zip(*result.tolist()))
}
elif is_object_dtype(result):

def cons_row(x):
if is_list_like(x):
return x
else:
return [x]

result = [cons_row(x) for x in result]
if result and not self._is_string:
# propagate nan values to match longest sequence (GH 18450)
max_len = max(len(x) for x in result)
result = [
x * max_len if len(x) == 0 or x[0] is np.nan else x
for x in result
]

if not isinstance(expand, bool):
raise ValueError("expand must be True or False")
Expand Down
58 changes: 56 additions & 2 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,6 +2098,62 @@ def test_str_removesuffix(val):
tm.assert_series_equal(result, expected)


def test_str_split():
# GH 52401
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
result = ser.str.split("c")
expected = pd.Series(
ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
)
tm.assert_series_equal(result, expected)

result = ser.str.split("c", n=1)
expected = pd.Series(
ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None]))
)
tm.assert_series_equal(result, expected)

result = ser.str.split("[1-2]", regex=True)
expected = pd.Series(
ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None]))
)
tm.assert_series_equal(result, expected)

result = ser.str.split("[1-2]", regex=True, expand=True)
expected = pd.DataFrame(
{
0: ArrowExtensionArray(pa.array(["a", "a", None])),
1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])),
}
)
tm.assert_frame_equal(result, expected)


def test_str_rsplit():
# GH 52401
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
result = ser.str.rsplit("c")
expected = pd.Series(
ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
)
tm.assert_series_equal(result, expected)

result = ser.str.rsplit("c", n=1)
expected = pd.Series(
ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None]))
)
tm.assert_series_equal(result, expected)

result = ser.str.rsplit("c", n=1, expand=True)
expected = pd.DataFrame(
{
0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])),
1: ArrowExtensionArray(pa.array(["b", "b", None])),
}
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"method, args",
[
Expand All @@ -2113,8 +2169,6 @@ def test_str_removesuffix(val):
["rindex", ("abc",)],
["normalize", ("abc",)],
["rfind", ("abc",)],
["split", ()],
["rsplit", ()],
["translate", ("abc",)],
["wrap", ("abc",)],
],
Expand Down

0 comments on commit 99bcfe3

Please sign in to comment.