Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

11028 - Fix warlus operator behavior when called by a function #11041

Merged
merged 5 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/11028.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug related to the walrus operator in asserts. Now a variable that is assigned with the walrus operator can be used later in a function call.
aless10 marked this conversation as resolved.
Show resolved Hide resolved
25 changes: 22 additions & 3 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,9 @@
]
):
pytest_temp = self.variable()
self.variables_overwrite[v.left.target.id] = pytest_temp
self.variables_overwrite[

Check warning on line 999 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L999

Added line #L999 was not covered by tests
v.left.target.id
] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
self.push_format_context()
res, expl = self.visit(v)
Expand Down Expand Up @@ -1037,10 +1039,19 @@
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]

Check warning on line 1043 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1043

Added line #L1043 was not covered by tests
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
if (
isinstance(keyword.value, ast.Name)
and keyword.value.id in self.variables_overwrite
):
keyword.value = self.variables_overwrite[

Check warning on line 1052 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1052

Added line #L1052 was not covered by tests
keyword.value.id
] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
Expand Down Expand Up @@ -1075,7 +1086,13 @@
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left.id = self.variables_overwrite[comp.left.id]
comp.left = self.variables_overwrite[

Check warning on line 1089 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1089

Added line #L1089 was not covered by tests
comp.left.id
] # type:ignore[assignment]
if isinstance(comp.left, namedExpr):
self.variables_overwrite[

Check warning on line 1093 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1093

Added line #L1093 was not covered by tests
comp.left.target.id
] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
Expand All @@ -1093,7 +1110,9 @@
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[left_res.id] = next_operand.target.id
self.variables_overwrite[

Check warning on line 1113 in src/_pytest/assertion/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/_pytest/assertion/rewrite.py#L1113

Added line #L1113 was not covered by tests
left_res.id
] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
next_expl = f"({next_expl})"
Expand Down
90 changes: 90 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,96 @@ def test_walrus_operator_not_override_value():
assert result.ret == 0


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="walrus operator not available in py<38"
)
class TestIssue11028:
def test_assertion_walrus_operator_in_operand(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def test_in_string():
assert (obj := "foo") in obj
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_in_operand_json_dumps(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
import json

def test_json_encoder():
assert (obj := "foo") in json.dumps(obj)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_equals_operand_function(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def f(a):
return a

def test_call_other_function_arg():
assert (obj := "foo") == f(obj)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_equals_operand_function_keyword_arg(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def f(a='test'):
return a

def test_call_other_function_k_arg():
assert (obj := "foo") == f(a=obj)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_equals_operand_function_arg_as_function(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def f(a='test'):
return a

def test_function_of_function():
assert (obj := "foo") == f(f(obj))
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_assertion_walrus_operator_gt_operand_function(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def add_one(a):
return a + 1

def test_gt():
assert (object := 4) > add_one(object)
aless10 marked this conversation as resolved.
Show resolved Hide resolved
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down