Skip to content

Commit

Permalink
Refactor and fix assert_expected_matched_actual
Browse files Browse the repository at this point in the history
This PR:

- Refactors assert_expected_matched_actual function to avoid repeated
  matching between expected and actual output
- Fixes typeddjango#63, typeddjango#64
  • Loading branch information
zero323 committed Oct 7, 2021
1 parent f7b249d commit ea253b8
Showing 1 changed file with 54 additions and 73 deletions.
127 changes: 54 additions & 73 deletions pytest_mypy_plugins/utils.py
Expand Up @@ -3,6 +3,7 @@
import contextlib
import inspect
import io
from itertools import zip_longest
import os
import re
import sys
Expand All @@ -11,6 +12,7 @@
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Expand Down Expand Up @@ -129,20 +131,6 @@ def remove_common_prefix(lines: List[str]) -> List[str]:
return cleaned_lines


def _num_skipped_prefix_lines(a1: List[OutputMatcher], a2: List[str]) -> int:
num_eq = 0
while num_eq < min(len(a1), len(a2)) and a1[num_eq].matches(a2[num_eq]):
num_eq += 1
return max(0, num_eq - 4)


def _num_skipped_suffix_lines(a1: List[OutputMatcher], a2: List[str]) -> int:
num_eq = 0
while num_eq < min(len(a1), len(a2)) and a1[-num_eq - 1].matches(a2[-num_eq - 1]):
num_eq += 1
return max(0, num_eq - 4)


def _add_aligned_message(s1: str, s2: str, error_message: str) -> str:
"""Align s1 and s2 so that the their first difference is highlighted.
Expand Down Expand Up @@ -224,78 +212,71 @@ def assert_expected_matched_actual(expected: List[OutputMatcher], actual: List[s
Display any differences in a human-readable form.
"""

def format_mismatched_line(line: str) -> str:
return " {:<45} (diff)".format(str(line))

def format_matched_line(line: str, width: int = 100) -> str:
return " {}...".format(line[:width]) if len(line) > width else " {}".format(line)

def format_error_lines(lines: List[str]) -> str:
return "\n".join(lines) if lines else " (empty)"

expected = sorted(expected, key=lambda om: (om.fname, om.lnum))
actual = sorted_by_file_and_line(remove_empty_lines(actual))

actual = remove_common_prefix(actual)
error_message = ""

if not all(e.matches(a) for e, a in zip(expected, actual)):
num_skip_start = _num_skipped_prefix_lines(expected, actual)
num_skip_end = _num_skipped_suffix_lines(expected, actual)

error_message += "Expected:\n"
diff_lines: Dict[int, Tuple[OutputMatcher, str]] = {
i: (e, a)
for i, (e, a) in enumerate(zip_longest(expected, actual))
if e is None or a is None or not e.matches(a)
}

# If omit some lines at the beginning, indicate it by displaying a line
# with '...'.
if num_skip_start > 0:
error_message += " ...\n"
if diff_lines:
first_diff_line = min(diff_lines.keys())
last_diff_line = max(diff_lines.keys())

# Keep track of the first different line.
first_diff = -1
expected_message_lines = []
actual_message_lines = []

# Display only this many first characters of identical lines.
width = 100
for i in range(first_diff_line, last_diff_line + 1):
if i in diff_lines:
expected_line, actual_line = diff_lines[i]
if expected_line:
expected_message_lines.append(format_mismatched_line(str(expected_line)))
if actual_line:
actual_message_lines.append(format_mismatched_line(actual_line))

for i in range(num_skip_start, len(expected) - num_skip_end):
if i >= len(actual) or not expected[i].matches(actual[i]):
if first_diff < 0:
first_diff = i
error_message += " {:<45} (diff)".format(expected[i])
else:
e = expected[i]
error_message += " " + str(e)[:width]
if len(e) > width:
error_message += "..."
error_message += "\n"
if num_skip_end > 0:
error_message += " ...\n"

error_message += "Actual:\n"

if num_skip_start > 0:
error_message += " ...\n"

for j in range(num_skip_start, len(actual) - num_skip_end):
if j >= len(expected) or not expected[j].matches(actual[j]):
error_message += " {:<45} (diff)".format(actual[j])
else:
a = actual[j]
error_message += " " + a[:width]
if len(a) > width:
error_message += "..."
error_message += "\n"
if not actual:
error_message += " (empty)\n"
if num_skip_end > 0:
error_message += " ...\n"

error_message += "\n"

if 0 <= first_diff < len(actual) and (
len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
expected_line, actual_line = expected[i], actual[i]
actual_message_lines.append(format_matched_line(actual_line))
expected_message_lines.append(format_matched_line(str(expected_line)))

error_message = "Actual:\n{}\nExpected:\n{}\n".format(
format_error_lines(actual_message_lines), format_error_lines(expected_message_lines)
)

first_diff_expected, first_diff_actual = diff_lines[first_diff_line]

if (
first_diff_actual is not None
and first_diff_expected is not None
and (
len(first_diff_actual) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
or len(str(first_diff_expected)) >= MIN_LINE_LENGTH_FOR_ALIGNMENT
)
):
# Display message that helps visualize the differences between two
# long lines.
error_message = _add_aligned_message(str(expected[first_diff]), actual[first_diff], error_message)
error_message = _add_aligned_message(str(first_diff_expected), first_diff_actual, error_message)

if len(expected) == 0:
raise TypecheckAssertionError(f"Output is not expected: \n{error_message}")
failure_reason = (
"Invalid output" if first_diff_actual and first_diff_expected is None else "Output is not expected"
)

first_failure = expected[first_diff]
if first_failure:
raise TypecheckAssertionError(error_message=f"Invalid output: \n{error_message}", lineno=first_failure.lnum)
raise TypecheckAssertionError(
error_message=f"{failure_reason}: \n{error_message}",
lineno=first_diff_expected.lnum if first_diff_expected else 0,
)


def extract_output_matchers_from_comments(fname: str, input_lines: List[str], regex: bool) -> List[OutputMatcher]:
Expand Down

0 comments on commit ea253b8

Please sign in to comment.