Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,9 @@ def any_string_dtype(request):
return pd.StringDtype(storage, na_value)


# Expose the ``any_string_dtype`` fixture under a second name so that a single
# test can request two string dtypes at once (``test_comparison_methods_array``
# in pandas/tests/arithmetic/test_string.py takes both names as parameters).
any_string_dtype2 = any_string_dtype


@pytest.fixture(params=tm.DATETIME64_DTYPES)
def datetime64_dtype(request):
"""
Expand Down
308 changes: 308 additions & 0 deletions pandas/tests/arithmetic/test_string.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,49 @@
import operator
from pathlib import Path

import numpy as np
import pytest

from pandas.compat import HAS_PYARROW
from pandas.errors import Pandas4Warning
import pandas.util._test_decorators as td

import pandas as pd
from pandas import (
NA,
ArrowDtype,
Series,
StringDtype,
)
import pandas._testing as tm
from pandas.core.construction import extract_array


def string_dtype_highest_priority(dtype1, dtype2):
if HAS_PYARROW:
DTYPE_HIERARCHY = [
StringDtype("python", na_value=np.nan),
StringDtype("pyarrow", na_value=np.nan),
StringDtype("python", na_value=NA),
StringDtype("pyarrow", na_value=NA),
]
else:
DTYPE_HIERARCHY = [
StringDtype("python", na_value=np.nan),
StringDtype("python", na_value=NA),
]

h1 = DTYPE_HIERARCHY.index(dtype1)
h2 = DTYPE_HIERARCHY.index(dtype2)
return DTYPE_HIERARCHY[max(h1, h2)]


def test_eq_all_na():
    # ``==`` between an all-NA pyarrow-backed string array and itself is
    # expected to give an all-NA "boolean[pyarrow]" array.
    pytest.importorskip("pyarrow")
    all_missing = pd.array([NA, NA], dtype=StringDtype("pyarrow"))
    outcome = all_missing == all_missing
    tm.assert_extension_array_equal(
        outcome, pd.array([NA, NA], dtype="boolean[pyarrow]")
    )


def test_reversed_logical_ops(any_string_dtype):
Expand Down Expand Up @@ -134,3 +166,279 @@ def test_mul_bool_invalid(any_string_dtype):
ser * np.array([True, False, True], dtype=bool)
with pytest.raises(TypeError, match=msg):
np.array([True, False, True], dtype=bool) * ser


def test_add(any_string_dtype, request):
    # Element-wise concatenation of two string Series through ``+``, ``add``,
    # ``radd`` and ``add`` with a ``fill_value``.
    dtype = any_string_dtype
    if dtype == object:
        request.applymarker(
            pytest.mark.xfail(reason="Need to update expected for numpy object dtype")
        )

    left = Series(["a", "b", "c", None, None], dtype=dtype)
    right = Series(["x", "y", None, "z", None], dtype=dtype)

    # a missing value on either side gives a missing result
    total = left + right
    concatenated = Series(["ax", "by", None, None, None], dtype=dtype)
    tm.assert_series_equal(total, concatenated)
    tm.assert_series_equal(left.add(right), concatenated)

    # the reflected method puts the other operand first
    reflected = left.radd(right)
    tm.assert_series_equal(
        reflected, Series(["xa", "yb", None, None, None], dtype=dtype)
    )

    # fill_value stands in for a value missing on one side only
    filled = left.add(right, fill_value="-")
    tm.assert_series_equal(
        filled, Series(["ax", "by", "c-", "-z", None], dtype=dtype)
    )


def test_add_2d(any_string_dtype, request):
    # Adding a 2-D object ndarray to a length-3 string array (or a Series
    # wrapping it) is expected to raise ValueError.
    dtype = any_string_dtype

    if dtype == object or dtype.storage == "pyarrow":
        # these dtypes currently do not raise, so the test is marked xfail
        reason = "Failed: DID NOT RAISE <class 'ValueError'>"
        request.applymarker(pytest.mark.xfail(raises=None, reason=reason))

    arr = pd.array(["a", "b", "c"], dtype=dtype)
    two_dim = np.array([["a", "b", "c"]], dtype=object)
    with pytest.raises(ValueError, match="3 != 1"):
        arr + two_dim

    ser = Series(arr)
    with pytest.raises(ValueError, match="3 != 1"):
        ser + two_dim


def test_add_sequence(any_string_dtype, request):
    # A plain Python list is added element-wise from either side.
    dtype = any_string_dtype
    if dtype == np.dtype(object):
        request.applymarker(pytest.mark.xfail(reason="Cannot broadcast list"))

    arr = pd.array(["a", "b", None, None], dtype=dtype)
    values = ["x", None, "y", None]

    # only position 0 has a string on both sides
    tm.assert_extension_array_equal(
        arr + values, pd.array(["ax", None, None, None], dtype=dtype)
    )
    tm.assert_extension_array_equal(
        values + arr, pd.array(["xa", None, None, None], dtype=dtype)
    )


def test_mul(any_string_dtype):
    # Multiplying by 2, from either side, doubles each string and leaves the
    # missing entry missing.
    arr = pd.array(["a", "b", None], dtype=any_string_dtype)

    from_left = arr * 2
    doubled = pd.array(["aa", "bb", None], dtype=any_string_dtype)
    tm.assert_extension_array_equal(from_left, doubled)

    from_right = 2 * arr
    tm.assert_extension_array_equal(from_right, doubled)


def test_add_strings(any_string_dtype, request):
    # Adding a string array and a one-row object DataFrame in both operand
    # orders; the array's ``__add__`` must return NotImplemented for a frame.
    dtype = any_string_dtype
    if dtype != np.dtype(object):
        request.applymarker(pytest.mark.xfail(reason="GH-28527"))

    values = pd.array(["a", "b", "c", "d"], dtype=dtype)
    frame = pd.DataFrame([["t", "y", "v", "w"]], dtype=object)
    assert values.__add__(frame) is NotImplemented

    array_first = values + frame
    tm.assert_frame_equal(
        array_first, pd.DataFrame([["at", "by", "cv", "dw"]]).astype(dtype)
    )

    frame_first = frame + values
    tm.assert_frame_equal(
        frame_first, pd.DataFrame([["ta", "yb", "vc", "wd"]]).astype(dtype)
    )


@pytest.mark.xfail(reason="GH-28527")
def test_add_frame(dtype):
    # Adding a string array and a one-row DataFrame in both operand orders.
    # NOTE(review): the neighbouring tests request ``any_string_dtype`` while
    # this one requests ``dtype`` -- confirm a fixture of that name is in scope.
    values = pd.array(["a", "b", np.nan, np.nan], dtype=dtype)
    frame = pd.DataFrame([["x", np.nan, "y", np.nan]])

    assert values.__add__(frame) is NotImplemented

    # only column 0 has a string on both sides
    array_first = values + frame
    tm.assert_frame_equal(
        array_first,
        pd.DataFrame([["ax", np.nan, np.nan, np.nan]]).astype(dtype),
    )

    frame_first = frame + values
    tm.assert_frame_equal(
        frame_first,
        pd.DataFrame([["xa", np.nan, np.nan, np.nan]]).astype(dtype),
    )


def test_comparison_methods_scalar(comparison_op, any_string_dtype):
    # Compare a string array with a string scalar through the dunder method.
    dtype = any_string_dtype
    op_name = f"__{comparison_op.__name__}__"
    arr = pd.array(["a", None, "c"], dtype=dtype)
    scalar = "a"
    result = getattr(arr, op_name)(scalar)

    # reference values: the same dunder applied to every element in turn
    elementwise = [getattr(item, op_name)(scalar) for item in arr]

    if dtype == object or dtype.na_value is np.nan:
        # the missing slot is pinned to True for ``!=`` and False otherwise
        expected = np.array(elementwise)
        expected[1] = comparison_op == operator.ne
        result = extract_array(result, extract_numpy=True)
        tm.assert_numpy_array_equal(result, expected.astype(np.bool_))
        return

    # otherwise the element-wise values are used as-is in a nullable boolean
    if dtype.storage == "pyarrow":
        expected_dtype = "boolean[pyarrow]"
    else:
        expected_dtype = "boolean"
    expected = pd.array(np.array(elementwise, dtype=object), dtype=expected_dtype)
    tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_scalar_pd_na(comparison_op, any_string_dtype):
    # Compare a string array with the ``NA`` scalar through the dunder method.
    dtype = any_string_dtype
    op_name = f"__{comparison_op.__name__}__"
    a = pd.array(["a", None, "c"], dtype=dtype)
    result = getattr(a, op_name)(NA)

    if dtype == np.dtype(object) or dtype.na_value is np.nan:
        # object / NaN-backed dtypes: plain booleans, all True only for ``!=``
        if operator.ne == comparison_op:
            expected = np.array([True, True, True])
        else:
            expected = np.array([False, False, False])
        result = extract_array(result, extract_numpy=True)
        tm.assert_numpy_array_equal(result, expected)
    else:
        # remaining dtypes: an all-missing nullable boolean array
        expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
        expected = pd.array([None, None, None], dtype=expected_dtype)
        # asserted once; the original repeated this identical check twice
        tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_scalar_not_string(comparison_op, any_string_dtype):
    # Compare a string array with an integer scalar: ordering comparisons must
    # raise TypeError, while ``==`` / ``!=`` return a result.
    dtype = any_string_dtype
    op_name = f"__{comparison_op.__name__}__"
    arr = pd.array(["a", None, "c"], dtype=dtype)
    number = 42

    if op_name not in ("__eq__", "__ne__"):
        with pytest.raises(TypeError, match="Invalid comparison|not supported between"):
            getattr(arr, op_name)(number)
        return

    result = extract_array(getattr(arr, op_name)(number), extract_numpy=True)

    # no entry equals 42: ``==`` is False and ``!=`` is True where a value exists
    unequal = op_name == "__ne__"
    if dtype == np.dtype(object) or dtype.na_value is np.nan:
        # the missing slot gets the same boolean as the populated ones
        tm.assert_numpy_array_equal(result, np.array([unequal] * 3))
    else:
        # the missing slot stays missing in the nullable boolean result
        expected_dtype = "boolean[pyarrow]" if dtype.storage == "pyarrow" else "boolean"
        expected = pd.array([unequal, None, unequal], dtype=expected_dtype)
        tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_array(comparison_op, any_string_dtype, any_string_dtype2):
    # Compare two string arrays, whose dtypes may differ, in both operand orders.
    op_name = f"__{comparison_op.__name__}__"
    dtype = any_string_dtype
    dtype2 = any_string_dtype2

    def has_nan_semantics(dt):
        # numpy object dtype, or a dtype whose ``na_value`` is NaN
        return dt == object or dt.na_value is np.nan

    a = pd.array(["a", None, "c"], dtype=dtype)
    other = pd.array([None, None, "c"], dtype=dtype2)
    result = extract_array(comparison_op(a, other), extract_numpy=True)

    # swapping the operands must give the same outcome for this data
    result2 = extract_array(comparison_op(other, a), extract_numpy=True)
    tm.assert_equal(result, result2)

    # only the last position holds a string on both sides
    both_present = getattr(other[-1], op_name)(a[-1])

    if has_nan_semantics(dtype) and has_nan_semantics(dtype2):
        # positions with a missing operand are True for ``!=``, False otherwise
        if operator.ne == comparison_op:
            expected = np.array([True, True, False])
        else:
            expected = np.array([False, False, both_present])
        # ``result`` already went through extract_array above, so the second
        # extract_array call the original made here is dropped as redundant
        tm.assert_numpy_array_equal(result, expected)
        return

    # At least one dtype is neither object nor NaN-backed: the expected dtype
    # follows the higher-priority string dtype of the two operands.
    if dtype == object:
        max_dtype = dtype2
    elif dtype2 == object:
        max_dtype = dtype
    else:
        max_dtype = string_dtype_highest_priority(dtype, dtype2)
    expected_dtype = "boolean" if max_dtype.storage == "python" else "bool[pyarrow]"

    expected = np.full(len(a), fill_value=None, dtype="object")
    expected[-1] = both_present
    expected = pd.array(expected, dtype=expected_dtype)
    tm.assert_equal(result, expected)


@td.skip_if_no("pyarrow")
def test_comparison_methods_array_arrow_extension(comparison_op, any_string_dtype):
    # Compare an ``ArrowDtype(pa.string())`` array with arrays of the other
    # string dtypes, in both operand orders.
    import pyarrow as pa

    op_name = f"__{comparison_op.__name__}__"
    arrow_arr = pd.array(["a", None, "c"], dtype=ArrowDtype(pa.string()))
    other = pd.array([None, None, "c"], dtype=any_string_dtype)

    result = comparison_op(arrow_arr, other)

    # swapping the operands must give the same outcome for this data
    swapped = comparison_op(other, arrow_arr)
    tm.assert_equal(result, swapped)

    # only the last position holds a string on both sides
    last = getattr(other[-1], op_name)(arrow_arr[-1])
    expected = pd.array([None, None, last], dtype="bool[pyarrow]")
    tm.assert_extension_array_equal(result, expected)


def test_comparison_methods_list(comparison_op, any_string_dtype):
    # Compare a string array with a plain Python list, in both operand orders.
    dtype = any_string_dtype
    op_name = f"__{comparison_op.__name__}__"

    arr = pd.array(["a", None, "c"], dtype=dtype)
    values = [None, None, "c"]
    result = comparison_op(arr, values)

    # swapping the operands must give the same outcome for this data
    swapped = comparison_op(values, arr)
    tm.assert_equal(result, swapped)

    # only the last position holds a string on both sides
    last = getattr(values[-1], op_name)(arr[-1])

    if dtype == object or dtype.na_value is np.nan:
        # positions with a missing operand are True for ``!=``, False otherwise
        if operator.ne == comparison_op:
            expected = np.array([True, True, False])
        else:
            expected = np.array([False, False, last])
        result = extract_array(result, extract_numpy=True)
        tm.assert_numpy_array_equal(result, expected)
        return

    # otherwise positions with a missing operand stay missing
    if dtype.storage == "pyarrow":
        expected_dtype = "boolean[pyarrow]"
    else:
        expected_dtype = "boolean"
    expected = pd.array(
        np.array([None, None, last], dtype="object"), dtype=expected_dtype
    )
    tm.assert_extension_array_equal(result, expected)
Loading
Loading