diff --git a/flake8_trio/visitors/visitor91x.py b/flake8_trio/visitors/visitor91x.py index 973cf29e..64f13199 100644 --- a/flake8_trio/visitors/visitor91x.py +++ b/flake8_trio/visitors/visitor91x.py @@ -769,9 +769,11 @@ def visit_CompFor(self, node: cst.CompFor): return False - # We don't have any logic on if generators are guaranteed to unroll, so always - # ignore their content by not visiting subnodes. + # The generator target will be immediately evaluated, but the other + # elements will be lazily evaluated as the generator is consumed so we don't + # visit them as any checkpoints in them are not guaranteed to execute. def visit_GeneratorExp(self, node: cst.GeneratorExp): + node.for_in.iter.visit(self) return False def visit_Import(self, node: cst.Import): diff --git a/tests/autofix_files/trio910.py b/tests/autofix_files/trio910.py index ae4f6d56..fe301a18 100644 --- a/tests/autofix_files/trio910.py +++ b/tests/autofix_files/trio910.py @@ -583,3 +583,19 @@ async def foo_comprehension_2(): # error: 0, "exit", Statement("function defini async def foo_comprehension_3(): [... async for x in bar()] + + +# Issue #714 +# (await x async for y in await z) +# ^ ^ ^ this always runs! +# ^ ^ this might not run +# ^ this might not run + + +async def await_in_gen_target(): + (print(x) for x in await foo()) + + +async def await_everywhere_except_gen_target(): # error: 0, "exit", Statement("function definition", lineno) + (await x async for x in bar()) + await trio.lowlevel.checkpoint() diff --git a/tests/autofix_files/trio910.py.diff b/tests/autofix_files/trio910.py.diff index d563f817..d36009bd 100644 --- a/tests/autofix_files/trio910.py.diff +++ b/tests/autofix_files/trio910.py.diff @@ -205,3 +205,8 @@ async def foo_comprehension_3(): +@@ x,3 x,4 @@ + + async def await_everywhere_except_gen_target(): # error: 0, "exit", Statement("function definition", lineno) + (await x async for x in bar()) ++ await trio.lowlevel.checkpoint() diff --git a/tests/eval_files/trio910.py b/tests/eval_files/trio910.py index 930694ed..87707882 100644 --- a/tests/eval_files/trio910.py +++ b/tests/eval_files/trio910.py @@ -555,3 +555,18 @@ async def foo_comprehension_2(): # error: 0, "exit", Statement("function defini async def foo_comprehension_3(): [... async for x in bar()] + + +# Issue #714 +# (await x async for y in await z) +# ^ ^ ^ this always runs! +# ^ ^ this might not run +# ^ this might not run + + +async def await_in_gen_target(): + (print(x) for x in await foo()) + + +async def await_everywhere_except_gen_target(): # error: 0, "exit", Statement("function definition", lineno) + (await x async for x in bar())