Skip to content

Commit

Permalink
Backport PR #54534 on branch 2.1.x (REF: Move methods that can be sha…
Browse files Browse the repository at this point in the history
…red with new string dtype) (#54539)

Backport PR #54534: REF: Move methods that can be shared with new string dtype

Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and phofl committed Aug 14, 2023
1 parent 723b2c6 commit 64e4527
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 63 deletions.
84 changes: 84 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from typing import Literal

import numpy as np

from pandas.compat import pa_version_under7p0

if not pa_version_under7p0:
import pyarrow as pa
import pyarrow.compute as pc


class ArrowStringArrayMixin:
    """
    String-method implementations built on ``pyarrow.compute`` kernels.

    The methods read the pyarrow data from ``self._pa_array`` and wrap each
    result with ``type(self)(...)``, so a concrete subclass must provide both
    the ``_pa_array`` attribute and a constructor that accepts a pyarrow
    array. The mixin itself cannot be instantiated.
    """

    # Placeholder only; concrete subclasses are expected to set the real data.
    _pa_array = None

    def __init__(self, *args, **kwargs) -> None:
        # The mixin carries no construction logic of its own.
        raise NotImplementedError

    def _str_pad(
        self,
        width: int,
        side: Literal["left", "right", "both"] = "left",
        fillchar: str = " ",
    ):
        """
        Pad each string to ``width`` using ``fillchar``.

        Parameters
        ----------
        width : int
            Target width, forwarded to the pyarrow kernel.
        side : {"left", "right", "both"}, default "left"
            Selects ``pc.utf8_lpad``, ``pc.utf8_rpad`` or ``pc.utf8_center``.
        fillchar : str, default " "
            Forwarded to the kernel as ``padding``.

        Raises
        ------
        ValueError
            If ``side`` is not one of the three accepted values.
        """
        if side == "left":
            pa_pad = pc.utf8_lpad
        elif side == "right":
            pa_pad = pc.utf8_rpad
        elif side == "both":
            pa_pad = pc.utf8_center
        else:
            raise ValueError(
                f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
            )
        return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

    def _str_get(self, i: int):
        """
        Extract the element at position ``i`` from each string.

        Elements for which ``i`` is out of bounds, and elements whose length
        is null, yield null in the result.
        """
        lengths = pc.utf8_length(self._pa_array)
        if i >= 0:
            out_of_bounds = pc.greater_equal(i, lengths)
            start = i
            stop = i + 1
            step = 1
        else:
            out_of_bounds = pc.greater(-i, lengths)
            # Walk backwards one position: a forward slice would need
            # ``stop = i + 1``, which is 0 for ``i == -1`` and would select
            # nothing.
            start = i
            stop = i - 1
            step = -1
        # A null length is treated as out of bounds so the row ends up null.
        not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
        selected = pc.utf8_slice_codeunits(
            self._pa_array, start=start, stop=stop, step=step
        )
        null_value = pa.scalar(
            None, type=self._pa_array.type  # type: ignore[attr-defined]
        )
        result = pc.if_else(not_out_of_bounds, selected, null_value)
        return type(self)(result)

    def _str_slice_replace(
        self, start: int | None = None, stop: int | None = None, repl: str | None = None
    ):
        """
        Replace the slice ``[start:stop]`` of each string with ``repl``.

        ``None`` arguments are replaced by concrete values before calling
        ``pc.utf8_replace_slice``: ``repl`` becomes ``""``, ``start`` becomes
        ``0`` and ``stop`` becomes the largest int64 value, which serves as an
        effectively unbounded end position.
        """
        if repl is None:
            repl = ""
        if start is None:
            start = 0
        if stop is None:
            stop = np.iinfo(np.int64).max
        return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

    def _str_capitalize(self):
        """Apply ``pc.utf8_capitalize`` to every element."""
        return type(self)(pc.utf8_capitalize(self._pa_array))

    def _str_title(self):
        """Apply ``pc.utf8_title`` to every element."""
        return type(self)(pc.utf8_title(self._pa_array))

    def _str_swapcase(self):
        """Apply ``pc.utf8_swapcase`` to every element."""
        return type(self)(pc.utf8_swapcase(self._pa_array))

    def _str_removesuffix(self, suffix: str):
        """
        Strip ``suffix`` from the end of each string that ends with it.

        Strings that do not end with ``suffix`` are returned unchanged. An
        empty ``suffix`` leaves every string unchanged, matching
        ``str.removesuffix("")``.
        """
        if not suffix:
            # ``-len("")`` is 0, so the slice below would be ``[0:0]`` (an
            # empty string) for every element; any row reported as matching
            # the empty pattern would then lose its content. Removing nothing
            # is a no-op, so hand back the data untouched.
            return type(self)(self._pa_array)
        ends_with = pc.ends_with(self._pa_array, pattern=suffix)
        removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
        result = pc.if_else(ends_with, removed, self._pa_array)
        return type(self)(result)
68 changes: 5 additions & 63 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from pandas.core import roperator
from pandas.core.arraylike import OpsMixin
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
from pandas.core.arrays.base import (
ExtensionArray,
ExtensionArraySupportsAnyAll,
Expand Down Expand Up @@ -184,7 +185,10 @@ def to_pyarrow_type(


class ArrowExtensionArray(
OpsMixin, ExtensionArraySupportsAnyAll, BaseStringArrayMethods
OpsMixin,
ExtensionArraySupportsAnyAll,
ArrowStringArrayMixin,
BaseStringArrayMethods,
):
"""
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
Expand Down Expand Up @@ -1986,24 +1990,6 @@ def _str_count(self, pat: str, flags: int = 0):
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))

def _str_pad(
self,
width: int,
side: Literal["left", "right", "both"] = "left",
fillchar: str = " ",
):
if side == "left":
pa_pad = pc.utf8_lpad
elif side == "right":
pa_pad = pc.utf8_rpad
elif side == "both":
pa_pad = pc.utf8_center
else:
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._pa_array, width=width, padding=fillchar))

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
Expand Down Expand Up @@ -2088,26 +2074,6 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
)
return type(self)(result)

def _str_get(self, i: int):
lengths = pc.utf8_length(self._pa_array)
if i >= 0:
out_of_bounds = pc.greater_equal(i, lengths)
start = i
stop = i + 1
step = 1
else:
out_of_bounds = pc.greater(-i, lengths)
start = i
stop = i - 1
step = -1
not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
selected = pc.utf8_slice_codeunits(
self._pa_array, start=start, stop=stop, step=step
)
null_value = pa.scalar(None, type=self._pa_array.type)
result = pc.if_else(not_out_of_bounds, selected, null_value)
return type(self)(result)

def _str_join(self, sep: str):
if pa.types.is_string(self._pa_array.type):
result = self._apply_elementwise(list)
Expand Down Expand Up @@ -2137,15 +2103,6 @@ def _str_slice(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

def _str_slice_replace(
self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
if repl is None:
repl = ""
if start is None:
start = 0
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

def _str_isalnum(self):
return type(self)(pc.utf8_is_alnum(self._pa_array))

Expand All @@ -2170,18 +2127,9 @@ def _str_isspace(self):
def _str_istitle(self):
return type(self)(pc.utf8_is_title(self._pa_array))

def _str_capitalize(self):
return type(self)(pc.utf8_capitalize(self._pa_array))

def _str_title(self):
return type(self)(pc.utf8_title(self._pa_array))

def _str_isupper(self):
return type(self)(pc.utf8_is_upper(self._pa_array))

def _str_swapcase(self):
return type(self)(pc.utf8_swapcase(self._pa_array))

def _str_len(self):
return type(self)(pc.utf8_length(self._pa_array))

Expand Down Expand Up @@ -2222,12 +2170,6 @@ def _str_removeprefix(self, prefix: str):
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_casefold(self):
predicate = lambda val: val.casefold()
result = self._apply_elementwise(predicate)
Expand Down

0 comments on commit 64e4527

Please sign in to comment.