From 7259e8db9844f6f973c1d0c0ce46cc68c8248abb Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 9 Sep 2023 14:09:31 +0200 Subject: [PATCH] Fix assert rewriting with assignment expressions (#11414) Fixes #11239 --- AUTHORS | 1 + changelog/11239.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 54 +++++++++++++++++++++++--------- testing/test_assertrewrite.py | 21 +++++++++++++ 4 files changed, 63 insertions(+), 14 deletions(-) create mode 100644 changelog/11239.bugfix.rst diff --git a/AUTHORS b/AUTHORS index 466779f6d11..e9e033c73f0 100644 --- a/AUTHORS +++ b/AUTHORS @@ -235,6 +235,7 @@ Maho Maik Figura Mandeep Bhutani Manuel Krebber +Marc Mueller Marc Schlaich Marcelo Duarte Trevisani Marcin Bachry diff --git a/changelog/11239.bugfix.rst b/changelog/11239.bugfix.rst new file mode 100644 index 00000000000..a486224cdda --- /dev/null +++ b/changelog/11239.bugfix.rst @@ -0,0 +1 @@ +Fixed ``:=`` in asserts impacting unrelated test cases. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 9bf79f1e107..258ed9f9ab0 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -13,6 +13,7 @@ import sys import tokenize import types +from collections import defaultdict from pathlib import Path from pathlib import PurePath from typing import Callable @@ -45,6 +46,10 @@ from _pytest.assertion import AssertionState +class Sentinel: + pass + + assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -52,6 +57,9 @@ PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT +# Special marker that denotes we have just left a scope definition +_SCOPE_END_MARKER = Sentinel() + class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -634,6 +642,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. + :scope: A tuple containing the current scope used for variables_overwrite. + :variables_overwrite: A dict filled with references to variables that change value within an assert. This happens when a variable is reassigned with the walrus operator @@ -655,7 +665,10 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.variables_overwrite: Dict[str, str] = {} + self.scope: tuple[ast.AST, ...] = () + self.variables_overwrite: defaultdict[ + tuple[ast.AST, ...], Dict[str, str] + ] = defaultdict(dict) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -719,9 +732,17 @@ def run(self, mod: ast.Module) -> None: mod.body[pos:pos] = imports # Collect asserts. - nodes: List[ast.AST] = [mod] + self.scope = (mod,) + nodes: List[Union[ast.AST, Sentinel]] = [mod] while nodes: node = nodes.pop() + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + self.scope = tuple((*self.scope, node)) + nodes.append(_SCOPE_END_MARKER) + if node == _SCOPE_END_MARKER: + self.scope = self.scope[:-1] + continue + assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): new: List[ast.AST] = [] @@ -992,7 +1013,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: ] ): pytest_temp = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ v.left.target.id ] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp @@ -1035,17 +1056,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: - arg = self.variables_overwrite[arg.id] # type:ignore[assignment] + if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( + self.scope, {} + ): + arg = self.variables_overwrite[self.scope][ + arg.id + ] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if ( - isinstance(keyword.value, ast.Name) - and keyword.value.id in self.variables_overwrite - ): - keyword.value = self.variables_overwrite[ + if isinstance( + keyword.value, ast.Name + ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): + keyword.value = self.variables_overwrite[self.scope][ keyword.value.id ] # type:ignore[assignment] res, expl = self.visit(keyword.value) @@ -1081,12 +1105,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: - comp.left = self.variables_overwrite[ + if isinstance( + comp.left, ast.Name + ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): + comp.left = self.variables_overwrite[self.scope][ comp.left.id ] # type:ignore[assignment] if isinstance(comp.left, ast.NamedExpr): - self.variables_overwrite[ + self.variables_overwrite[self.scope][ comp.left.target.id ] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) @@ -1106,7 +1132,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ left_res.id ] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 08813c4dcf0..b3fd0c2f2e7 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1543,6 +1543,27 @@ def test_gt(): result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"]) +class TestIssue11239: + def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None: + """Regression for (#11239) + + Walrus operator rewriting would leak to separate test cases if they used the same variables. + """ + pytester.makepyfile( + """ + def test_1(): + state = {"x": 2}.get("x") + assert state is not None + + def test_2(): + db = {"x": 2} + assert (state := db.get("x")) is not None + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )