Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/14445.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed assertion rewriting evaluating walrus operator (``:=``) expressions multiple times, causing incorrect test results when the expression had side effects (e.g., incrementing a counter or calling a function).
104 changes: 37 additions & 67 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import ast
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
Expand Down Expand Up @@ -58,20 +57,13 @@
from _pytest.assertion import AssertionState


class Sentinel:
pass


assertstate_key = StashKey["AssertionState"]()

# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + ((__debug__ and "c") or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# Special marker that denotes we have just left a scope definition
_SCOPE_END_MARKER = Sentinel()


class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""
Expand Down Expand Up @@ -652,14 +644,8 @@ class AssertionRewriter(ast.NodeVisitor):
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.

:scope: A tuple containing the current scope used for variables_overwrite.

:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator

This state, except the variables_overwrite, is reset on every new assert
statement visited and used by the other visitors.
This state is reset on every new assert statement visited and used by
the other visitors.
"""

def __init__(
Expand All @@ -675,10 +661,6 @@ def __init__(
else:
self.enable_assertion_pass_hook = False
self.source = source
self.scope: tuple[ast.AST, ...] = ()
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = (
defaultdict(dict)
)

def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -728,16 +710,9 @@ def run(self, mod: ast.Module) -> None:
mod.body[pos:pos] = imports

# Collect asserts.
self.scope = (mod,)
nodes: list[ast.AST | Sentinel] = [mod]
nodes: list[ast.AST] = [mod]
while nodes:
node = nodes.pop()
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef):
self.scope = tuple((*self.scope, node))
nodes.append(_SCOPE_END_MARKER)
if node == _SCOPE_END_MARKER:
self.scope = self.scope[:-1]
continue
assert isinstance(node, ast.AST)
for name, field in ast.iter_fields(node):
if isinstance(field, list):
Expand Down Expand Up @@ -964,15 +939,17 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
# Return the NamedExpr as-is so it evaluates in its natural position
# (preserving left-to-right evaluation order). For the explanation,
# reference the target variable (already assigned by the walrus) to
# avoid re-evaluating the expression.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id
target_name = ast.Name(target_id, ast.Load())
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
dorepr = self.helper("_should_repr_global_name", target_name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id))
return name, self.explanation_param(expr)

def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
Expand All @@ -998,32 +975,30 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
for i, v in enumerate(boolop.values):
if i:
fail_inner: list[ast.stmt] = []
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
# expl_cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner
match v:
# Check if the left operand is an ast.NamedExpr and the value has already been visited
case ast.Compare(
left=ast.NamedExpr(target=ast.Name(id=target_id))
) if target_id in [
e.id for e in boolop.values[:i] if hasattr(e, "id")
]:
pytest_temp = self.variable()
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined]
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Constant(expl))
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond: ast.expr = res
# Use res_var (already assigned above) rather than res directly,
# so that NamedExpr operands aren't evaluated a second time.
cond: ast.expr = ast.Name(res_var, ast.Load())
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
# Capture the condition in a stable temp for the explanation
# path — res_var is overwritten by subsequent operands.
cond_var = self.variable()
body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond))
expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841
inner: list[ast.stmt] = []
self.statements.append(ast.If(cond, inner, []))
self.statements.append(
ast.If(ast.Name(cond_var, ast.Load()), inner, [])
)
self.statements = body = inner
self.statements = save
self.expl_stmts = fail_save
Expand Down Expand Up @@ -1053,19 +1028,10 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
self.scope, {}
):
arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
match keyword.value:
case ast.Name(id=id) if id in self.variables_overwrite.get(
self.scope, {}
):
keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
Expand Down Expand Up @@ -1100,17 +1066,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:

def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
match comp.left:
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
self.scope, {}
):
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
case ast.NamedExpr(target=ast.Name(id=target_id)):
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.Compare | ast.BoolOp):
left_expl = f"({left_expl})"
# If the left operand is a NamedExpr, assign it to a temp so the
# walrus executes before any right-side expressions are hoisted.
if isinstance(left_res, ast.NamedExpr):
left_res = self.assign(left_res)
res_variables = [self.variable() for i in range(len(comp.ops))]
load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
Expand All @@ -1119,17 +1081,25 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms: list[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
# If the next operand is a walrus that assigns to the same name as
# the current left_res, we must freeze left_res's value before the
# walrus modifies it.
match (next_operand, left_res):
case (
ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id),
) if target_id == name_id:
next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
left_res = self.assign(left_res)
results[-1] = left_res

next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, ast.Compare | ast.BoolOp):
next_expl = f"({next_expl})"
# Assign NamedExpr comparators to a temp so each walrus evaluates
# exactly once — critical for chained comparisons where the same
# node would otherwise be re-evaluated as left_res next iteration.
if isinstance(next_res, ast.NamedExpr):
next_res = self.assign(next_res)
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(ast.Constant(sym))
Expand Down
106 changes: 104 additions & 2 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value():
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])
result.stdout.fnmatch_lines(["*assert not (False and False is False)"])

def test_assertion_walrus_operator_boolean_none_fails(
self, pytester: Pytester
Expand All @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value():
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
result.stdout.fnmatch_lines(["*assert not (None and None is None)"])

def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
self, pytester: Pytester
Expand Down Expand Up @@ -1846,6 +1846,108 @@ def test_2():
assert result.ret == 0


class TestIssue14445:
"""Regression tests for #14445: walrus operator double evaluation."""

def test_walrus_no_double_eval_basic(self, pytester: Pytester) -> None:
"""Walrus captures the value at assignment time, not re-evaluated later."""
pytester.makepyfile(
"""
class Counter:
def __init__(self):
self.value = 0
def increment(self):
self.value += 1

def test_walrus_in_assertion_basic():
c = Counter()
assert (before := c.value) == 0
c.increment()
assert before != (after := c.value)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_walrus_no_double_eval_running_counter(self, pytester: Pytester) -> None:
"""Walrus increments fire exactly once per assert statement."""
pytester.makepyfile(
"""
def test_walrus_running_counter():
count = 0
items = []
items.append("a")
assert (count := count + 1) == len(items)
items.append("b")
assert (count := count + 1) == len(items)
items.append("c")
assert (count := count + 1) == len(items)
assert count == 3
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_walrus_no_double_eval_in_function_call(self, pytester: Pytester) -> None:
"""Walrus in function call arguments not evaluated twice."""
pytester.makepyfile(
"""
call_count = 0

def side_effect():
global call_count
call_count += 1
return call_count

def test_walrus_side_effect():
assert (val := side_effect()) == 1
assert val == 1
assert (val := side_effect()) == 2
assert val == 2
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_walrus_no_double_eval_in_boolop(self, pytester: Pytester) -> None:
"""Bare walrus as a BoolOp operand must not be evaluated twice."""
pytester.makepyfile(
"""
call_count = 0

def side_effect():
global call_count
call_count += 1
return call_count

def test_walrus_boolop():
assert (x := side_effect()) and x == 1
assert call_count == 1
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_walrus_no_double_eval_chained_compare(self, pytester: Pytester) -> None:
"""Same walrus target in chained comparison must evaluate each once."""
pytester.makepyfile(
"""
call_count = 0

def track(value):
global call_count
call_count += 1
return value

def test_walrus_chained():
assert (x := track(1)) < (x := track(3)) < (x := track(5))
assert call_count == 3
"""
)
result = pytester.runpytest()
assert result.ret == 0


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down
Loading