Skip to content

Commit

Permalink
Fix all() unroll for non-generators/non-list comprehensions (#5360)
Browse files Browse the repository at this point in the history
Fix all() unroll for non-generators/non-list comprehensions
  • Loading branch information
nicoddemus authored and asottile committed Jun 2, 2019
1 parent dba62f8 commit f078984
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 6 deletions.
1 change: 1 addition & 0 deletions changelog/5358.bugfix.rst
@@ -0,0 +1 @@
Fix assertion rewriting of ``all()`` calls to deal with non-generators.
16 changes: 12 additions & 4 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -949,11 +949,21 @@ def visit_BinOp(self, binop):
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation

@staticmethod
def _is_any_call_with_generator_or_list_comprehension(call):
"""Return True if the Call node is an 'any' call with a generator or list comprehension"""
return (
isinstance(call.func, ast.Name)
and call.func.id == "all"
and len(call.args) == 1
and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp))
)

def visit_Call_35(self, call):
"""
visit `ast.Call` nodes on Python3.5 and after
"""
if isinstance(call.func, ast.Name) and call.func.id == "all":
if self._is_any_call_with_generator_or_list_comprehension(call):
return self._visit_all(call)
new_func, func_expl = self.visit(call.func)
arg_expls = []
Expand All @@ -980,8 +990,6 @@ def visit_Call_35(self, call):

def _visit_all(self, call):
"""Special rewrite for the builtin all function, see #5062"""
if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)):
return
gen_exp = call.args[0]
assertion_module = ast.Module(
body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)]
Expand Down Expand Up @@ -1009,7 +1017,7 @@ def visit_Call_legacy(self, call):
"""
visit `ast.Call nodes on 3.4 and below`
"""
if isinstance(call.func, ast.Name) and call.func.id == "all":
if self._is_any_call_with_generator_or_list_comprehension(call):
return self._visit_all(call)
new_func, func_expl = self.visit(call.func)
arg_expls = []
Expand Down
29 changes: 27 additions & 2 deletions testing/test_assertrewrite.py
Expand Up @@ -677,7 +677,7 @@ def __repr__(self):
assert "UnicodeDecodeError" not in msg
assert "UnicodeEncodeError" not in msg

def test_unroll_generator(self, testdir):
def test_unroll_all_generator(self, testdir):
testdir.makepyfile(
"""
def check_even(num):
Expand All @@ -692,7 +692,7 @@ def test_generator():
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])

def test_unroll_list_comprehension(self, testdir):
def test_unroll_all_list_comprehension(self, testdir):
testdir.makepyfile(
"""
def check_even(num):
Expand All @@ -707,6 +707,31 @@ def test_list_comprehension():
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"])

def test_unroll_all_object(self, testdir):
"""all() for non generators/non list-comprehensions (#5358)"""
testdir.makepyfile(
"""
def test():
assert all((1, 0))
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*assert False*", "*where False = all((1, 0))*"])

def test_unroll_all_starred(self, testdir):
"""all() for non generators/non list-comprehensions (#5358)"""
testdir.makepyfile(
"""
def test():
x = ((1, 0),)
assert all(*x)
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(
["*assert False*", "*where False = all(*((1, 0),))*"]
)

def test_for_loop(self, testdir):
testdir.makepyfile(
"""
Expand Down

0 comments on commit f078984

Please sign in to comment.