From f36ea01e9de2ffac42a2615307f3692fd8e84c4a Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Fri, 23 Jun 2023 06:05:05 -0700 Subject: [PATCH] Fix async iterator body stripping (#15491) Fixes #15489 --- mypy/fastparse.py | 58 +++++++++++++++------------ test-data/unit/check-async-await.test | 13 ++++-- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 902bde110421..fe59ff48bdfc 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -521,7 +521,14 @@ def translate_stmt_list( return [block] stack = self.class_and_function_stack - if self.strip_function_bodies and len(stack) == 1 and stack[0] == "F": + # Fast case for stripping function bodies + if ( + can_strip + and self.strip_function_bodies + and len(stack) == 1 + and stack[0] == "F" + and not is_coroutine + ): return [] res: list[Statement] = [] @@ -529,32 +536,33 @@ def translate_stmt_list( node = self.visit(stmt) res.append(node) - if ( - self.strip_function_bodies - and can_strip - and stack[-2:] == ["C", "F"] - and not is_possible_trivial_body(res) - ): - # We only strip method bodies if they don't assign to an attribute, as - # this may define an attribute which has an externally visible effect. - visitor = FindAttributeAssign() - for s in res: - s.accept(visitor) - if visitor.found: - break - else: - if is_coroutine: - # Yields inside an async function affect the return type and should not - # be stripped. - yield_visitor = FindYield() + # Slow case for stripping function bodies + if can_strip and self.strip_function_bodies: + if stack[-2:] == ["C", "F"]: + if is_possible_trivial_body(res): + can_strip = False + else: + # We only strip method bodies if they don't assign to an attribute, as + # this may define an attribute which has an externally visible effect. + visitor = FindAttributeAssign() for s in res: - s.accept(yield_visitor) - if yield_visitor.found: + s.accept(visitor) + if visitor.found: + can_strip = False break - else: - return [] - else: - return [] + + if can_strip and stack[-1] == "F" and is_coroutine: + # Yields inside an async function affect the return type and should not + # be stripped. + yield_visitor = FindYield() + for s in res: + s.accept(yield_visitor) + if yield_visitor.found: + can_strip = False + break + + if can_strip: + return [] return res def translate_type_comment( diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 83a66ef4a815..d0df2c9e0104 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -945,17 +945,21 @@ async def bar(x: Union[A, B]) -> None: [typing fixtures/typing-async.pyi] [case testAsyncIteratorWithIgnoredErrors] -from m import L +import m -async def func(l: L) -> None: +async def func(l: m.L) -> None: reveal_type(l.get_iterator) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]" reveal_type(l.get_iterator2) # N: Revealed type is "def () -> typing.AsyncIterator[builtins.str]" async for i in l.get_iterator(): reveal_type(i) # N: Revealed type is "builtins.str" + reveal_type(m.get_generator) # N: Revealed type is "def () -> typing.AsyncGenerator[builtins.int, None]" + async for i2 in m.get_generator(): + reveal_type(i2) # N: Revealed type is "builtins.int" + [file m.py] # mypy: ignore-errors=True -from typing import AsyncIterator +from typing import AsyncIterator, AsyncGenerator class L: async def some_func(self, i: int) -> str: @@ -968,6 +972,9 @@ class L: if self: a = (yield 'x') +async def get_generator() -> AsyncGenerator[int, None]: + yield 1 + [builtins fixtures/async_await.pyi] [typing fixtures/typing-async.pyi]