Skip to content

Commit

Permalink
Fix async iterator body stripping (#15491)
Browse files Browse the repository at this point in the history
Fixes #15489
  • Loading branch information
hauntsaninja committed Jun 24, 2023
1 parent ba7887b commit f36ea01
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
58 changes: 33 additions & 25 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,40 +521,48 @@ def translate_stmt_list(
return [block]

stack = self.class_and_function_stack
if self.strip_function_bodies and len(stack) == 1 and stack[0] == "F":
# Fast case for stripping function bodies
if (
can_strip
and self.strip_function_bodies
and len(stack) == 1
and stack[0] == "F"
and not is_coroutine
):
return []

res: list[Statement] = []
for stmt in stmts:
node = self.visit(stmt)
res.append(node)

if (
self.strip_function_bodies
and can_strip
and stack[-2:] == ["C", "F"]
and not is_possible_trivial_body(res)
):
# We only strip method bodies if they don't assign to an attribute, as
# this may define an attribute which has an externally visible effect.
visitor = FindAttributeAssign()
for s in res:
s.accept(visitor)
if visitor.found:
break
else:
if is_coroutine:
# Yields inside an async function affect the return type and should not
# be stripped.
yield_visitor = FindYield()
# Slow case for stripping function bodies
if can_strip and self.strip_function_bodies:
if stack[-2:] == ["C", "F"]:
if is_possible_trivial_body(res):
can_strip = False
else:
# We only strip method bodies if they don't assign to an attribute, as
# this may define an attribute which has an externally visible effect.
visitor = FindAttributeAssign()
for s in res:
s.accept(yield_visitor)
if yield_visitor.found:
s.accept(visitor)
if visitor.found:
can_strip = False
break
else:
return []
else:
return []

if can_strip and stack[-1] == "F" and is_coroutine:
# Yields inside an async function affect the return type and should not
# be stripped.
yield_visitor = FindYield()
for s in res:
s.accept(yield_visitor)
if yield_visitor.found:
can_strip = False
break

if can_strip:
return []
return res

def translate_type_comment(
Expand Down
13 changes: 10 additions & 3 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -945,17 +945,21 @@ async def bar(x: Union[A, B]) -> None:
[typing fixtures/typing-async.pyi]

[case testAsyncIteratorWithIgnoredErrors]
from m import L
import m

async def func(l: L) -> None:
async def func(l: m.L) -> None:
reveal_type(l.get_iterator) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]"
reveal_type(l.get_iterator2) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]"
async for i in l.get_iterator():
reveal_type(i) # N: Revealed type is "builtins.str"

reveal_type(m.get_generator) # N: Revealed type is "def () -> typing.AsyncGenerator[builtins.int, None]"
async for i2 in m.get_generator():
reveal_type(i2) # N: Revealed type is "builtins.int"

[file m.py]
# mypy: ignore-errors=True
from typing import AsyncIterator
from typing import AsyncIterator, AsyncGenerator

class L:
async def some_func(self, i: int) -> str:
Expand All @@ -968,6 +972,9 @@ class L:
if self:
a = (yield 'x')

async def get_generator() -> AsyncGenerator[int, None]:
yield 1

[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

Expand Down

0 comments on commit f36ea01

Please sign in to comment.