Skip to content

Commit

Permalink
BUG: Series.str.split(expand=True) for ArrowDtype(pa.string()) (#53532)
Browse files Browse the repository at this point in the history
* BUG: Series.str.split(expand=True) for ArrowDtype(pa.string())

* whatsnew

* min versions

* ensure ArrowExtensionArray
  • Loading branch information
lukemanley committed Jun 7, 2023
1 parent 2f312da commit 66468ce
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 5 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v2.0.3.rst
Expand Up @@ -22,6 +22,8 @@ Fixed regressions
Bug fixes
~~~~~~~~~
- Bug in :func:`read_csv` when defining ``dtype`` with ``bool[pyarrow]`` for the ``"c"`` and ``"python"`` engines (:issue:`53390`)
- Bug in :meth:`Series.str.split` and :meth:`Series.str.rsplit` with ``expand=True`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`53532`)
-

.. ---------------------------------------------------------------------------
.. _whatsnew_203.other:
Expand Down
36 changes: 31 additions & 5 deletions pandas/core/strings/accessor.py
Expand Up @@ -275,14 +275,40 @@ def _wrap_result(
if isinstance(result.dtype, ArrowDtype):
import pyarrow as pa

from pandas.compat import pa_version_under11p0

from pandas.core.arrays.arrow.array import ArrowExtensionArray

max_len = pa.compute.max(
result._pa_array.combine_chunks().value_lengths()
).as_py()
if result.isna().any():
value_lengths = result._pa_array.combine_chunks().value_lengths()
max_len = pa.compute.max(value_lengths).as_py()
min_len = pa.compute.min(value_lengths).as_py()
if result._hasna:
# ArrowExtensionArray.fillna doesn't work for list scalars
result._pa_array = result._pa_array.fill_null([None] * max_len)
result = ArrowExtensionArray(
result._pa_array.fill_null([None] * max_len)
)
if min_len < max_len:
# append nulls to each scalar list element up to max_len
if not pa_version_under11p0:
result = ArrowExtensionArray(
pa.compute.list_slice(
result._pa_array,
start=0,
stop=max_len,
return_fixed_size_list=True,
)
)
else:
all_null = np.full(max_len, fill_value=None, dtype=object)
values = result.to_numpy()
new_values = []
for row in values:
if len(row) < max_len:
nulls = all_null[: max_len - len(row)]
row = np.append(row, nulls)
new_values.append(row)
pa_type = result._pa_array.type
result = ArrowExtensionArray(pa.array(new_values, type=pa_type))
if name is not None:
labels = name
else:
Expand Down
18 changes: 18 additions & 0 deletions pandas/tests/extension/test_arrow.py
Expand Up @@ -2286,6 +2286,15 @@ def test_str_split():
)
tm.assert_frame_equal(result, expected)

result = ser.str.split("1", expand=True)
expected = pd.DataFrame(
{
0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])),
1: ArrowExtensionArray(pa.array(["cbcb", None, None])),
}
)
tm.assert_frame_equal(result, expected)


def test_str_rsplit():
# GH 52401
Expand All @@ -2311,6 +2320,15 @@ def test_str_rsplit():
)
tm.assert_frame_equal(result, expected)

result = ser.str.rsplit("1", expand=True)
expected = pd.DataFrame(
{
0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])),
1: ArrowExtensionArray(pa.array(["cbcb", None, None])),
}
)
tm.assert_frame_equal(result, expected)


def test_str_unsupported_extract():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
Expand Down

0 comments on commit 66468ce

Please sign in to comment.