Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF: assert_frame_equal and assert_series_equal for frames/series with a MultiIndex #55949

Merged
merged 3 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ Other Deprecations

Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :func:`.testing.assert_frame_equal` and :func:`.testing.assert_series_equal` for objects indexed by a :class:`MultiIndex` (:issue:`55949`)
- Performance improvement in :func:`concat` with ``axis=1`` and objects with unaligned indexes (:issue:`55084`)
- Performance improvement in :func:`merge_asof` when ``by`` is not ``None`` (:issue:`55580`, :issue:`55678`)
- Performance improvement in :func:`read_stata` for files with many variables (:issue:`55515`)
Expand Down
2 changes: 1 addition & 1 deletion pandas/_libs/testing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ cpdef assert_almost_equal(a, b,
robj : str, default None
Specify right object name being compared, internally used to show
appropriate assertion message.
index_values : ndarray, default None
index_values : Index | ndarray, default None
Specify shared index values of objects being compared, internally used
to show appropriate assertion message.

Expand Down
66 changes: 42 additions & 24 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,22 +283,37 @@ def _get_ilevel_values(index, level):
right = cast(MultiIndex, right)

for level in range(left.nlevels):
# cannot use get_level_values here because it can change dtype
llevel = _get_ilevel_values(left, level)
rlevel = _get_ilevel_values(right, level)

lobj = f"MultiIndex level [{level}]"
assert_index_equal(
llevel,
rlevel,
exact=exact,
check_names=check_names,
check_exact=check_exact,
check_categorical=check_categorical,
rtol=rtol,
atol=atol,
obj=lobj,
)
try:
# try comparison on levels/codes to avoid densifying MultiIndex
assert_index_equal(
left.levels[level],
right.levels[level],
exact=exact,
check_names=check_names,
check_exact=check_exact,
check_categorical=check_categorical,
rtol=rtol,
atol=atol,
obj=lobj,
)
assert_numpy_array_equal(left.codes[level], right.codes[level])
except AssertionError:
# cannot use get_level_values here because it can change dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does `get_level_values` change the dtype?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an old comment and I'm not sure it's still valid. Opened #55971 to remove it.

llevel = _get_ilevel_values(left, level)
rlevel = _get_ilevel_values(right, level)

assert_index_equal(
llevel,
rlevel,
exact=exact,
check_names=check_names,
check_exact=check_exact,
check_categorical=check_categorical,
rtol=rtol,
atol=atol,
obj=lobj,
)
# get_level_values may change dtype
_check_types(left.levels[level], right.levels[level], obj=obj)

Expand Down Expand Up @@ -576,6 +591,9 @@ def raise_assert_detail(

{message}"""

if isinstance(index_values, Index):
index_values = np.array(index_values)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using `np.asarray` instead of `np.array` here can avoid a copy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #55971 to update this.


if isinstance(index_values, np.ndarray):
msg += f"\n[index]: {pprint_thing(index_values)}"

Expand Down Expand Up @@ -630,7 +648,7 @@ def assert_numpy_array_equal(
obj : str, default 'numpy array'
Specify object name being compared, internally used to show appropriate
assertion message.
index_values : numpy.ndarray, default None
index_values : Index | numpy.ndarray, default None
optional index (shared by both left and right), used in output.
"""
__tracebackhide__ = True
Expand Down Expand Up @@ -701,7 +719,7 @@ def assert_extension_array_equal(
The two arrays to compare.
check_dtype : bool, default True
Whether to check if the ExtensionArray dtypes are identical.
index_values : numpy.ndarray, default None
index_values : Index | numpy.ndarray, default None
Optional index (shared by both left and right), used in output.
check_exact : bool, default False
Whether to compare number exactly.
Expand Down Expand Up @@ -932,7 +950,7 @@ def assert_series_equal(
left_values,
right_values,
check_dtype=check_dtype,
index_values=np.asarray(left.index),
index_values=left.index,
obj=str(obj),
)
else:
Expand All @@ -941,7 +959,7 @@ def assert_series_equal(
right_values,
check_dtype=check_dtype,
obj=str(obj),
index_values=np.asarray(left.index),
index_values=left.index,
)
elif check_datetimelike_compat and (
needs_i8_conversion(left.dtype) or needs_i8_conversion(right.dtype)
Expand Down Expand Up @@ -972,7 +990,7 @@ def assert_series_equal(
atol=atol,
check_dtype=bool(check_dtype),
obj=str(obj),
index_values=np.asarray(left.index),
index_values=left.index,
)
elif isinstance(left.dtype, ExtensionDtype) and isinstance(
right.dtype, ExtensionDtype
Expand All @@ -983,7 +1001,7 @@ def assert_series_equal(
rtol=rtol,
atol=atol,
check_dtype=check_dtype,
index_values=np.asarray(left.index),
index_values=left.index,
obj=str(obj),
)
elif is_extension_array_dtype_and_needs_i8_conversion(
Expand All @@ -993,7 +1011,7 @@ def assert_series_equal(
left._values,
right._values,
check_dtype=check_dtype,
index_values=np.asarray(left.index),
index_values=left.index,
obj=str(obj),
)
elif needs_i8_conversion(left.dtype) and needs_i8_conversion(right.dtype):
Expand All @@ -1002,7 +1020,7 @@ def assert_series_equal(
left._values,
right._values,
check_dtype=check_dtype,
index_values=np.asarray(left.index),
index_values=left.index,
obj=str(obj),
)
else:
Expand All @@ -1013,7 +1031,7 @@ def assert_series_equal(
atol=atol,
check_dtype=bool(check_dtype),
obj=str(obj),
index_values=np.asarray(left.index),
index_values=left.index,
)

# metadata comparison
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/frame/methods/test_value_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def test_data_frame_value_counts_dropna_false(nulls_fixture):
index=pd.MultiIndex(
levels=[
pd.Index(["Anne", "Beth", "John"]),
pd.Index(["Louise", "Smith", nulls_fixture]),
pd.Index(["Louise", "Smith", np.nan]),
],
codes=[[0, 1, 2, 2], [2, 0, 1, 2]],
names=["first_name", "middle_name"],
Expand Down
Loading