Skip to content

Commit

Permalink
Backport PR #35750: Pass check_dtype to assert_extension_array_equal (#…
Browse files Browse the repository at this point in the history
…35773)

Co-authored-by: Daniel Saxton <2658661+dsaxton@users.noreply.github.com>
  • Loading branch information
meeseeksmachine and dsaxton committed Aug 17, 2020
1 parent 66d08dc commit 56e95ad
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.1.1.rst
Expand Up @@ -37,6 +37,7 @@ Bug fixes
~~~~~~~~~

- Bug in ``Styler`` whereby `cell_ids` argument had no effect due to other recent changes (:issue:`35588`) (:issue:`35663`).
- Bug in :func:`pandas.testing.assert_series_equal` and :func:`pandas.testing.assert_frame_equal` where extension dtypes were not ignored when ``check_dtypes`` was set to ``False`` (:issue:`35715`).

Categorical
^^^^^^^^^^^
Expand Down
10 changes: 8 additions & 2 deletions pandas/_testing.py
Expand Up @@ -1377,12 +1377,18 @@ def assert_series_equal(
)
elif is_extension_array_dtype(left.dtype) and is_extension_array_dtype(right.dtype):
assert_extension_array_equal(
left._values, right._values, index_values=np.asarray(left.index)
left._values,
right._values,
check_dtype=check_dtype,
index_values=np.asarray(left.index),
)
elif needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype):
# DatetimeArray or TimedeltaArray
assert_extension_array_equal(
left._values, right._values, index_values=np.asarray(left.index)
left._values,
right._values,
check_dtype=check_dtype,
index_values=np.asarray(left.index),
)
else:
_testing.assert_almost_equal(
Expand Down
9 changes: 9 additions & 0 deletions pandas/tests/util/test_assert_extension_array_equal.py
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from pandas import array
import pandas._testing as tm
from pandas.core.arrays.sparse import SparseArray

Expand Down Expand Up @@ -102,3 +103,11 @@ def test_assert_extension_array_equal_non_extension_array(side):

with pytest.raises(AssertionError, match=msg):
tm.assert_extension_array_equal(*args)


@pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
def test_assert_extension_array_equal_ignore_dtype_mismatch(right_dtype):
# https://github.com/pandas-dev/pandas/issues/35715
left = array([1, 2, 3], dtype="Int64")
right = array([1, 2, 3], dtype=right_dtype)
tm.assert_extension_array_equal(left, right, check_dtype=False)
8 changes: 8 additions & 0 deletions pandas/tests/util/test_assert_frame_equal.py
Expand Up @@ -260,3 +260,11 @@ def test_assert_frame_equal_interval_dtype_mismatch():

with pytest.raises(AssertionError, match=msg):
tm.assert_frame_equal(left, right, check_dtype=True)


@pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
def test_assert_frame_equal_ignore_extension_dtype_mismatch(right_dtype):
# https://github.com/pandas-dev/pandas/issues/35715
left = pd.DataFrame({"a": [1, 2, 3]}, dtype="Int64")
right = pd.DataFrame({"a": [1, 2, 3]}, dtype=right_dtype)
tm.assert_frame_equal(left, right, check_dtype=False)
8 changes: 8 additions & 0 deletions pandas/tests/util/test_assert_series_equal.py
Expand Up @@ -296,3 +296,11 @@ def test_series_equal_exact_for_nonnumeric():
tm.assert_series_equal(s1, s3, check_exact=True)
with pytest.raises(AssertionError):
tm.assert_series_equal(s3, s1, check_exact=True)


@pytest.mark.parametrize("right_dtype", ["Int32", "int64"])
def test_assert_series_equal_ignore_extension_dtype_mismatch(right_dtype):
# https://github.com/pandas-dev/pandas/issues/35715
left = pd.Series([1, 2, 3], dtype="Int64")
right = pd.Series([1, 2, 3], dtype=right_dtype)
tm.assert_series_equal(left, right, check_dtype=False)

0 comments on commit 56e95ad

Please sign in to comment.