Skip to content

Commit

Permalink
Backport PR #52614: ENH: Implement more str accessor methods for Arro…
Browse files Browse the repository at this point in the history
…wDtype (#52842)
  • Loading branch information
mroeschke committed Apr 22, 2023
1 parent 6d63392 commit fa94d3b
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 77 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.1.rst
Expand Up @@ -50,6 +50,7 @@ Other
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
- Implemented most ``str`` accessor methods for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)

.. ---------------------------------------------------------------------------
.. _whatsnew_201.contributors:
Expand Down
137 changes: 85 additions & 52 deletions pandas/core/arrays/arrow/array.py
@@ -1,8 +1,11 @@
from __future__ import annotations

from copy import deepcopy
import functools
import operator
import re
import sys
import textwrap
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -12,6 +15,7 @@
TypeVar,
cast,
)
import unicodedata

import numpy as np

Expand Down Expand Up @@ -1655,6 +1659,16 @@ def _replace_with_mask(
result[mask] = replacements
return pa.array(result, type=values.type, from_pandas=True)

def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
    """
    Map ``func`` over every non-missing element, one output list per chunk.

    Parameters
    ----------
    func : Callable
        Called with a single element; its return value is stored as-is.

    Returns
    -------
    list of list
        One inner list per chunk yielded by ``self._data.iterchunks()``,
        in the same order. Elements that are ``None`` are not passed to
        ``func`` and stay ``None`` in the output.
    """
    mapped_chunks: list[list[Any]] = []
    for chunk in self._data.iterchunks():
        # zero_copy_only=False lets the conversion copy when it has to.
        values = chunk.to_numpy(zero_copy_only=False)
        mapped_chunks.append(
            [func(item) if item is not None else None for item in values]
        )
    return mapped_chunks

def _str_count(self, pat: str, flags: int = 0):
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
Expand Down Expand Up @@ -1788,14 +1802,14 @@ def _str_join(self, sep: str):
return type(self)(pc.binary_join(self._data, sep))

def _str_partition(self, sep: str, expand: bool):
    """
    Split each string around the first occurrence of ``sep``.

    Parameters
    ----------
    sep : str
        Separator handed to ``str.partition``.
    expand : bool
        Accepted for signature compatibility; not used in this method.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the per-element
        ``(head, sep, tail)`` results; missing values stay missing.
    """
    result = self._apply_elementwise(lambda val: val.partition(sep))
    return type(self)(pa.chunked_array(result))

def _str_rpartition(self, sep: str, expand: bool):
    """
    Split each string around the last occurrence of ``sep``.

    Parameters
    ----------
    sep : str
        Separator handed to ``str.rpartition``.
    expand : bool
        Accepted for signature compatibility; not used in this method.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the per-element
        ``(head, sep, tail)`` results; missing values stay missing.
    """
    result = self._apply_elementwise(lambda val: val.rpartition(sep))
    return type(self)(pa.chunked_array(result))

def _str_slice(
self, start: int | None = None, stop: int | None = None, step: int | None = None
Expand Down Expand Up @@ -1884,14 +1898,21 @@ def _str_rstrip(self, to_strip=None):
return type(self)(result)

def _str_removeprefix(self, prefix: str):
    """
    Strip ``prefix`` from the start of each string that begins with it.

    Parameters
    ----------
    prefix : str
        Prefix to remove.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the per-element results;
        missing values stay missing.
    """
    # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
    # starts_with = pc.starts_with(self._data, pattern=prefix)
    # removed = pc.utf8_slice_codeunits(self._data, len(prefix))
    # result = pc.if_else(starts_with, removed, self._data)
    # return type(self)(result)
    if sys.version_info < (3, 9):
        # NOTE pyupgrade will remove this when we run it with --py39-plus
        # so don't remove the unnecessary `else` statement below
        from pandas.util._str_methods import removeprefix

        predicate = functools.partial(removeprefix, prefix=prefix)
    else:
        predicate = lambda val: val.removeprefix(prefix)
    result = self._apply_elementwise(predicate)
    return type(self)(pa.chunked_array(result))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._data, pattern=suffix)
Expand All @@ -1900,49 +1921,59 @@ def _str_removesuffix(self, suffix: str):
return type(self)(result)

def _str_casefold(self):
    """
    Apply ``str.casefold`` to each element.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the casefolded strings;
        missing values stay missing.
    """
    result = self._apply_elementwise(lambda val: val.casefold())
    return type(self)(pa.chunked_array(result))

def _str_encode(self, encoding: str, errors: str = "strict"):
    """
    Encode each string with ``str.encode``.

    Parameters
    ----------
    encoding : str
        Codec name passed to ``str.encode``.
    errors : str, default "strict"
        Error-handling scheme passed to ``str.encode``.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the encoded values;
        missing values stay missing.
    """
    result = self._apply_elementwise(lambda val: val.encode(encoding, errors))
    return type(self)(pa.chunked_array(result))

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
    """
    Placeholder for ``str.extract``; always raises.

    Raises
    ------
    NotImplementedError
        Unconditionally, whatever arguments are given.
    """
    message = "str.extract not supported with pd.ArrowDtype(pa.string())."
    raise NotImplementedError(message)

def _str_findall(self, pat: str, flags: int = 0):
    """
    Collect all matches of ``pat`` in each string via ``re.findall``.

    Parameters
    ----------
    pat : str
        Regular expression pattern.
    flags : int, default 0
        ``re`` module flags used when compiling ``pat``.

    Returns
    -------
    type(self)
        New array wrapping a chunked array holding, per element, the
        list returned by ``findall``; missing values stay missing.
    """
    # Compile once so the pattern is not re-resolved for every element.
    regex = re.compile(pat, flags=flags)
    result = self._apply_elementwise(regex.findall)
    return type(self)(pa.chunked_array(result))

def _str_get_dummies(self, sep: str = "|"):
    """
    Build one row of boolean indicators per element from ``sep``-split tokens.

    Parameters
    ----------
    sep : str, default "|"
        Pattern the strings are split on.

    Returns
    -------
    tuple
        ``(result, labels)``. ``labels`` is the sorted list of unique
        tokens found across all elements. ``result`` is a ``type(self)``
        array holding, for each element, a list of booleans aligned with
        ``labels``.
    """
    split = pc.split_pattern(self._data, sep).combine_chunks()
    uniques = split.flatten().unique()
    uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques))
    result_data = []
    for lst in split.to_pylist():
        if lst is None:
            # A missing element contains none of the tokens.
            result_data.append([False] * len(uniques_sorted))
        else:
            # Mark each sorted unique token that occurs in this element.
            res = pc.is_in(uniques_sorted, pa.array(set(lst)))
            result_data.append(res.to_pylist())
    result = type(self)(pa.array(result_data))
    return result, uniques_sorted.to_pylist()

def _str_index(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the lowest index of ``sub`` in each string using ``str.index``.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Position where the search begins.
    end : int or None, default None
        Position where the search stops.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the found positions;
        missing values stay missing.
    """
    find_position = operator.methodcaller("index", sub, start, end)
    per_chunk = self._apply_elementwise(find_position)
    return type(self)(pa.chunked_array(per_chunk))

def _str_rindex(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the highest index of ``sub`` in each string using ``str.rindex``.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Position where the search begins.
    end : int or None, default None
        Position where the search stops.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the found positions;
        missing values stay missing.
    """
    find_position = operator.methodcaller("rindex", sub, start, end)
    per_chunk = self._apply_elementwise(find_position)
    return type(self)(pa.chunked_array(per_chunk))

def _str_normalize(self, form: str):
    """
    Normalize each string with ``unicodedata.normalize``.

    Parameters
    ----------
    form : str
        Normalization form passed as the first argument to
        ``unicodedata.normalize``.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the normalized strings;
        missing values stay missing.
    """
    normalizer = functools.partial(unicodedata.normalize, form)
    per_chunk = self._apply_elementwise(normalizer)
    return type(self)(pa.chunked_array(per_chunk))

def _str_rfind(self, sub: str, start: int = 0, end: int | None = None):
    """
    Find the highest index of ``sub`` in each string using ``str.rfind``.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Position where the search begins.
    end : int or None, default None
        Position where the search stops.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the positions returned by
        ``str.rfind``; missing values stay missing.
    """
    result = self._apply_elementwise(lambda val: val.rfind(sub, start, end))
    return type(self)(pa.chunked_array(result))

def _str_split(
self,
Expand All @@ -1964,15 +1995,17 @@ def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
n = None
return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True))

def _str_translate(self, table: dict[int, str]):
    """
    Map the characters of each string through ``table`` via ``str.translate``.

    Parameters
    ----------
    table : dict[int, str]
        Translation table handed to ``str.translate``.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the translated strings;
        missing values stay missing.
    """
    result = self._apply_elementwise(lambda val: val.translate(table))
    return type(self)(pa.chunked_array(result))

def _str_wrap(self, width: int, **kwargs):
    """
    Wrap each string to ``width`` columns and rejoin the lines with newlines.

    Parameters
    ----------
    width : int
        Maximum line width given to ``textwrap.TextWrapper``.
    **kwargs
        Further keyword arguments forwarded to ``textwrap.TextWrapper``.

    Returns
    -------
    type(self)
        New array wrapping a chunked array of the wrapped strings;
        missing values stay missing.
    """
    # ``width`` is a named parameter, so it can never also arrive in kwargs.
    wrapper = textwrap.TextWrapper(width=width, **kwargs)

    def wrap_one(text):
        return "\n".join(wrapper.wrap(text))

    per_chunk = self._apply_elementwise(wrap_one)
    return type(self)(pa.chunked_array(per_chunk))

@property
def _dt_year(self):
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/strings/accessor.py
Expand Up @@ -267,7 +267,6 @@ def _wrap_result(
if expand is None:
# infer from ndim if expand is not specified
expand = result.ndim != 1

elif expand is True and not isinstance(self._orig, ABCIndex):
# required when expand=True is explicitly specified
# not needed when inferred
Expand All @@ -280,10 +279,15 @@ def _wrap_result(
result._data.combine_chunks().value_lengths()
).as_py()
if result.isna().any():
# ArrowExtensionArray.fillna doesn't work for list scalars
result._data = result._data.fill_null([None] * max_len)
if name is not None:
labels = name
else:
labels = range(max_len)
result = {
i: ArrowExtensionArray(pa.array(res))
for i, res in enumerate(zip(*result.tolist()))
label: ArrowExtensionArray(pa.array(res))
for label, res in zip(labels, (zip(*result.tolist())))
}
elif is_object_dtype(result):

Expand Down

0 comments on commit fa94d3b

Please sign in to comment.