Skip to content

Commit

Permalink
Merge pull request #8540 from hauntsaninja/assert310
Browse files Browse the repository at this point in the history
(cherry picked from commit af31c60)
  • Loading branch information
nicoddemus authored and asottile committed May 4, 2021
1 parent a506148 commit d8d6812
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ Sankt Petersbug
Segev Finer
Serhii Mozghovyi
Seth Junot
Shantanu Jain
Shubham Adep
Simon Gomizelj
Simon Kerr
Expand Down
1 change: 1 addition & 0 deletions changelog/8539.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed assertion rewriting on Python 3.10.
26 changes: 20 additions & 6 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,12 +673,9 @@ def run(self, mod: ast.Module) -> None:
if not mod.body:
# Nothing to do.
return
# Insert some special imports at the top of the module but after any
# docstrings and __future__ imports.
aliases = [
ast.alias("builtins", "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]

# We'll insert some special imports at the top of the module, but after any
# docstrings and __future__ imports, so first figure out where that is.
doc = getattr(mod, "docstring", None)
expect_docstring = doc is None
if doc is not None and self.is_rewrite_disabled(doc):
Expand Down Expand Up @@ -710,10 +707,27 @@ def run(self, mod: ast.Module) -> None:
lineno = item.decorator_list[0].lineno
else:
lineno = item.lineno
# Now actually insert the special imports.
if sys.version_info >= (3, 10):
aliases = [
ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
ast.alias(
"_pytest.assertion.rewrite",
"@pytest_ar",
lineno=lineno,
col_offset=0,
),
]
else:
aliases = [
ast.alias("builtins", "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]
imports = [
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
]
mod.body[pos:pos] = imports

# Collect asserts.
nodes: List[ast.AST] = [mod]
while nodes:
Expand Down

0 comments on commit d8d6812

Please sign in to comment.