From d8d6812bdf2e5d12e3d57baab648df3e869950f0 Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Thu, 15 Apr 2021 08:55:42 -0300 Subject: [PATCH] Merge pull request #8540 from hauntsaninja/assert310 (cherry picked from commit af31c60db1dc4e513f47aabf8f6e844b23afd35f) --- AUTHORS | 1 + changelog/8539.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 26 ++++++++++++++++++++------ 3 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 changelog/8539.bugfix.rst diff --git a/AUTHORS b/AUTHORS index b35bebf7af3..2c690c5d28f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -273,6 +273,7 @@ Sankt Petersbug Segev Finer Serhii Mozghovyi Seth Junot +Shantanu Jain Shubham Adep Simon Gomizelj Simon Kerr diff --git a/changelog/8539.bugfix.rst b/changelog/8539.bugfix.rst new file mode 100644 index 00000000000..a2098610e29 --- /dev/null +++ b/changelog/8539.bugfix.rst @@ -0,0 +1 @@ +Fixed assertion rewriting on Python 3.10. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 805d4c8b35b..37ff076aab5 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -673,12 +673,9 @@ def run(self, mod: ast.Module) -> None: if not mod.body: # Nothing to do. return - # Insert some special imports at the top of the module but after any - # docstrings and __future__ imports. - aliases = [ - ast.alias("builtins", "@py_builtins"), - ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), - ] + + # We'll insert some special imports at the top of the module, but after any + # docstrings and __future__ imports, so first figure out where that is. doc = getattr(mod, "docstring", None) expect_docstring = doc is None if doc is not None and self.is_rewrite_disabled(doc): @@ -710,10 +707,27 @@ def run(self, mod: ast.Module) -> None: lineno = item.decorator_list[0].lineno else: lineno = item.lineno + # Now actually insert the special imports. + if sys.version_info >= (3, 10): + aliases = [ + ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), + ast.alias( + "_pytest.assertion.rewrite", + "@pytest_ar", + lineno=lineno, + col_offset=0, + ), + ] + else: + aliases = [ + ast.alias("builtins", "@py_builtins"), + ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), + ] imports = [ ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases ] mod.body[pos:pos] = imports + # Collect asserts. nodes: List[ast.AST] = [mod] while nodes: