Skip to content

Commit

Permalink
BUG: Index.str.cat casting result always to object (#56157)
Browse files Browse the repository at this point in the history
* BUG: Index.str.cat casting result always to object

* Update accessor.py

* Fix further bugs

* Fix

* Update accessor.py

* Update v2.1.4.rst

* Update v2.2.0.rst
  • Loading branch information
phofl committed Dec 8, 2023
1 parent 45361a4 commit a3626f2
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 59 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ Strings
^^^^^^^
- Bug in :func:`pandas.api.types.is_string_dtype` while checking object array with no elements is of the string dtype (:issue:`54661`)
- Bug in :meth:`DataFrame.apply` failing when ``engine="numba"`` and columns or index have ``StringDtype`` (:issue:`56189`)
- Bug in :meth:`Index.str.cat` always casting result to object dtype (:issue:`56157`)
- Bug in :meth:`Series.__mul__` for :class:`ArrowDtype` with ``pyarrow.string`` dtype and ``string[pyarrow]`` for the pyarrow backend (:issue:`51970`)
- Bug in :meth:`Series.str.replace` when ``n < 0`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`56404`)
- Bug in :meth:`Series.str.startswith` and :meth:`Series.str.endswith` with arguments of type ``tuple[str, ...]`` for ``string[pyarrow]`` (:issue:`54942`)
Expand Down
10 changes: 7 additions & 3 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays import ExtensionArray
from pandas.core.base import NoNewAttributesMixin
from pandas.core.construction import extract_array

Expand Down Expand Up @@ -456,7 +457,7 @@ def _get_series_list(self, others):
# in case of list-like `others`, all elements must be
# either Series/Index/np.ndarray (1-dim)...
if all(
isinstance(x, (ABCSeries, ABCIndex))
isinstance(x, (ABCSeries, ABCIndex, ExtensionArray))
or (isinstance(x, np.ndarray) and x.ndim == 1)
for x in others
):
Expand Down Expand Up @@ -690,12 +691,15 @@ def cat(
out: Index | Series
if isinstance(self._orig, ABCIndex):
# add dtype for case that result is all-NA
dtype = None
if isna(result).all():
dtype = object

out = Index(result, dtype=object, name=self._orig.name)
out = Index(result, dtype=dtype, name=self._orig.name)
else: # Series
if isinstance(self._orig.dtype, CategoricalDtype):
# We need to infer the new categories.
dtype = None
dtype = self._orig.dtype.categories.dtype # type: ignore[assignment]
else:
dtype = self._orig.dtype
res_ser = Series(
Expand Down
2 changes: 2 additions & 0 deletions pandas/tests/strings/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from pandas import (
CategoricalDtype,
DataFrame,
Index,
MultiIndex,
Expand Down Expand Up @@ -178,6 +179,7 @@ def test_api_for_categorical(any_string_method, any_string_dtype):
s = Series(list("aabb"), dtype=any_string_dtype)
s = s + " " + s
c = s.astype("category")
c = c.astype(CategoricalDtype(c.dtype.categories.astype("object")))
assert isinstance(c.str, StringMethods)

method_name, args, kwargs = any_string_method
Expand Down
135 changes: 79 additions & 56 deletions pandas/tests/strings/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import numpy as np
import pytest

import pandas.util._test_decorators as td

from pandas import (
DataFrame,
Index,
MultiIndex,
Series,
_testing as tm,
concat,
option_context,
)


Expand All @@ -26,45 +29,49 @@ def test_str_cat_name(index_or_series, other):
assert result.name == "name"


def test_str_cat(index_or_series):
box = index_or_series
# test_cat above tests "str_cat" from ndarray;
# here testing "str.cat" from Series/Index to ndarray/list
s = box(["a", "a", "b", "b", "c", np.nan])
@pytest.mark.parametrize(
"infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))]
)
def test_str_cat(index_or_series, infer_string):
with option_context("future.infer_string", infer_string):
box = index_or_series
# test_cat above tests "str_cat" from ndarray;
# here testing "str.cat" from Series/Index to ndarray/list
s = box(["a", "a", "b", "b", "c", np.nan])

# single array
result = s.str.cat()
expected = "aabbc"
assert result == expected
# single array
result = s.str.cat()
expected = "aabbc"
assert result == expected

result = s.str.cat(na_rep="-")
expected = "aabbc-"
assert result == expected
result = s.str.cat(na_rep="-")
expected = "aabbc-"
assert result == expected

result = s.str.cat(sep="_", na_rep="NA")
expected = "a_a_b_b_c_NA"
assert result == expected
result = s.str.cat(sep="_", na_rep="NA")
expected = "a_a_b_b_c_NA"
assert result == expected

t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object)
expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"])
t = np.array(["a", np.nan, "b", "d", "foo", np.nan], dtype=object)
expected = box(["aa", "a-", "bb", "bd", "cfoo", "--"])

# Series/Index with array
result = s.str.cat(t, na_rep="-")
tm.assert_equal(result, expected)
# Series/Index with array
result = s.str.cat(t, na_rep="-")
tm.assert_equal(result, expected)

# Series/Index with list
result = s.str.cat(list(t), na_rep="-")
tm.assert_equal(result, expected)
# Series/Index with list
result = s.str.cat(list(t), na_rep="-")
tm.assert_equal(result, expected)

# errors for incorrect lengths
rgx = r"If `others` contains arrays or lists \(or other list-likes.*"
z = Series(["1", "2", "3"])
# errors for incorrect lengths
rgx = r"If `others` contains arrays or lists \(or other list-likes.*"
z = Series(["1", "2", "3"])

with pytest.raises(ValueError, match=rgx):
s.str.cat(z.values)
with pytest.raises(ValueError, match=rgx):
s.str.cat(z.values)

with pytest.raises(ValueError, match=rgx):
s.str.cat(list(z))
with pytest.raises(ValueError, match=rgx):
s.str.cat(list(z))


def test_str_cat_raises_intuitive_error(index_or_series):
Expand All @@ -78,39 +85,54 @@ def test_str_cat_raises_intuitive_error(index_or_series):
s.str.cat(" ")


@pytest.mark.parametrize(
"infer_string", [False, pytest.param(True, marks=td.skip_if_no("pyarrow"))]
)
@pytest.mark.parametrize("sep", ["", None])
@pytest.mark.parametrize("dtype_target", ["object", "category"])
@pytest.mark.parametrize("dtype_caller", ["object", "category"])
def test_str_cat_categorical(index_or_series, dtype_caller, dtype_target, sep):
def test_str_cat_categorical(
index_or_series, dtype_caller, dtype_target, sep, infer_string
):
box = index_or_series

s = Index(["a", "a", "b", "a"], dtype=dtype_caller)
s = s if box == Index else Series(s, index=s)
t = Index(["b", "a", "b", "c"], dtype=dtype_target)

expected = Index(["ab", "aa", "bb", "ac"])
expected = expected if box == Index else Series(expected, index=s)
with option_context("future.infer_string", infer_string):
s = Index(["a", "a", "b", "a"], dtype=dtype_caller)
s = s if box == Index else Series(s, index=s)
t = Index(["b", "a", "b", "c"], dtype=dtype_target)

# Series/Index with unaligned Index -> t.values
result = s.str.cat(t.values, sep=sep)
tm.assert_equal(result, expected)

# Series/Index with Series having matching Index
t = Series(t.values, index=s)
result = s.str.cat(t, sep=sep)
tm.assert_equal(result, expected)

# Series/Index with Series.values
result = s.str.cat(t.values, sep=sep)
tm.assert_equal(result, expected)
expected = Index(["ab", "aa", "bb", "ac"])
expected = (
expected
if box == Index
else Series(expected, index=Index(s, dtype=dtype_caller))
)

# Series/Index with Series having different Index
t = Series(t.values, index=t.values)
expected = Index(["aa", "aa", "bb", "bb", "aa"])
expected = expected if box == Index else Series(expected, index=expected.str[:1])
# Series/Index with unaligned Index -> t.values
result = s.str.cat(t.values, sep=sep)
tm.assert_equal(result, expected)

# Series/Index with Series having matching Index
t = Series(t.values, index=Index(s, dtype=dtype_caller))
result = s.str.cat(t, sep=sep)
tm.assert_equal(result, expected)

# Series/Index with Series.values
result = s.str.cat(t.values, sep=sep)
tm.assert_equal(result, expected)

# Series/Index with Series having different Index
t = Series(t.values, index=t.values)
expected = Index(["aa", "aa", "bb", "bb", "aa"])
dtype = object if dtype_caller == "object" else s.dtype.categories.dtype
expected = (
expected
if box == Index
else Series(expected, index=Index(expected.str[:1], dtype=dtype))
)

result = s.str.cat(t, sep=sep)
tm.assert_equal(result, expected)
result = s.str.cat(t, sep=sep)
tm.assert_equal(result, expected)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -321,8 +343,9 @@ def test_str_cat_all_na(index_or_series, index_or_series2):

# all-NA target
if box == Series:
expected = Series([np.nan] * 4, index=s.index, dtype=object)
expected = Series([np.nan] * 4, index=s.index, dtype=s.dtype)
else: # box == Index
# TODO: Strimg option, this should return string dtype
expected = Index([np.nan] * 4, dtype=object)
result = s.str.cat(t, join="left")
tm.assert_equal(result, expected)
Expand Down

0 comments on commit a3626f2

Please sign in to comment.