Skip to content

Commit 52f1aec

Browse files
BUG: avoid validation error for ufunc with string[python] array (#62498)
1 parent 4ee17b3 commit 52f1aec

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

doc/source/whatsnew/v2.3.3.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Bug fixes
4848
with a compiled regex and custom flags (:issue:`62240`)
4949
- Fix :meth:`Series.str.match` and :meth:`Series.str.fullmatch` not matching patterns with groups correctly for the Arrow-backed string dtype (:issue:`61072`)
5050
- Fix comparing a :class:`StringDtype` Series with mixed objects raising an error (:issue:`60228`)
51+
- Fix error being raised when using a numpy ufunc with a Python-backed string array (:issue:`40800`)
5152

5253
Improvements and fixes for Copy-on-Write
5354
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

pandas/core/arrays/numpy_.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
Any,
66
Literal,
77
Self,
8+
cast,
89
)
910

1011
import numpy as np
@@ -48,6 +49,7 @@
4849
)
4950

5051
from pandas import Index
52+
from pandas.arrays import StringArray
5153

5254

5355
class NumpyExtensionArray(
@@ -234,6 +236,16 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
234236
# e.g. test_np_max_nested_tuples
235237
return result
236238
else:
239+
if self.dtype.type is str: # type: ignore[comparison-overlap]
240+
# StringDtype
241+
self = cast("StringArray", self)
242+
try:
243+
# specify dtype to preserve storage/na_value
244+
return type(self)(result, dtype=self.dtype)
245+
except ValueError:
246+
# if validation of input fails (no strings)
247+
# -> fallback to returning raw numpy array
248+
return result
237249
# one return value; re-box array-like results
238250
return type(self)(result)
239251

pandas/tests/arrays/string_/test_string.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -840,3 +840,30 @@ def test_string_array_view_type_error():
840840
arr = pd.array(["a", "b", "c"], dtype="string")
841841
with pytest.raises(TypeError, match="Cannot change data-type for string array."):
842842
arr.view("i8")
843+
844+
845+
@pytest.mark.parametrize("box", [pd.Series, pd.array])
846+
def test_numpy_array_ufunc(dtype, box):
847+
arr = box(["a", "bb", "ccc"], dtype=dtype)
848+
849+
# custom ufunc that works with string (object) input -> returning numeric
850+
str_len_ufunc = np.frompyfunc(lambda x: len(x), 1, 1)
851+
result = str_len_ufunc(arr)
852+
expected_cls = pd.Series if box is pd.Series else np.array
853+
# TODO we should infer int64 dtype here?
854+
expected = expected_cls([1, 2, 3], dtype=object)
855+
tm.assert_equal(result, expected)
856+
857+
# custom ufunc returning strings
858+
str_multiply_ufunc = np.frompyfunc(lambda x: x * 2, 1, 1)
859+
result = str_multiply_ufunc(arr)
860+
expected = box(["aa", "bbbb", "cccccc"], dtype=dtype)
861+
if dtype.storage == "pyarrow":
862+
# TODO ArrowStringArray should also preserve the class / dtype
863+
if box is pd.array:
864+
expected = np.array(["aa", "bbbb", "cccccc"], dtype=object)
865+
else:
866+
# not specifying the dtype because the exact dtype is not yet preserved
867+
expected = pd.Series(["aa", "bbbb", "cccccc"])
868+
869+
tm.assert_equal(result, expected)

0 commit comments

Comments
 (0)