Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix async generators #431

Merged
merged 2 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Fix return type of async generator functions (#430)
JelleZijlstra marked this conversation as resolved.
Show resolved Hide resolved
- Type check function decorators (#428)
- Handle `NoReturn` in `async def` functions (#427)
- Support PEP 673 (`typing_extensions.Self`) (#423)
Expand Down
39 changes: 38 additions & 1 deletion pyanalyze/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,37 @@ def compute_parameters(
return params


@dataclass
class IsGeneratorVisitor(ast.NodeVisitor):
"""Determine whether an async function is a generator.

This is important because the return type of async generators
should not be wrapped in Awaitable.

We avoid recursing into nested functions, which is why we can't
just use ast.walk.

We do not need to check for yield from because it is illegal
in async generators. We also skip checking nested comprehensions,
because we error anyway if there is a yield within a comprehension.

"""

is_generator: bool = False

def visit_Yield(self, node: ast.Yield) -> None:
self.is_generator = True

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
pass

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
pass

def visit_Lambda(self, node: ast.Lambda) -> None:
pass


def compute_value_of_function(
info: FunctionInfo, ctx: Context, *, result: Optional[Value] = None
) -> Value:
Expand All @@ -316,7 +347,13 @@ def compute_value_of_function(
if result is None:
result = AnyValue(AnySource.unannotated)
if isinstance(info.node, ast.AsyncFunctionDef):
result = GenericValue(collections.abc.Awaitable, [result])
visitor = IsGeneratorVisitor()
for line in info.node.body:
visitor.visit(line)
if visitor.is_generator:
break
if not visitor.is_generator:
result = GenericValue(collections.abc.Awaitable, [result])
sig = Signature.make(
[param_info.param for param_info in info.params],
result,
Expand Down
8 changes: 6 additions & 2 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3171,8 +3171,10 @@ def visit_Break(self, node: ast.Break) -> None:
def visit_Continue(self, node: ast.Continue) -> None:
self._set_name_in_scope(LEAVES_LOOP, node, AnyValue(AnySource.marker))

def visit_For(self, node: ast.For) -> None:
iterated_value = self._member_value_of_iterator(node.iter)
def visit_For(self, node: Union[ast.For, ast.AsyncFor]) -> None:
iterated_value = self._member_value_of_iterator(
node.iter, is_async=isinstance(node, ast.AsyncFor)
)
if self.options.get_value_for(ForLoopAlwaysEntered):
always_entered = True
elif isinstance(iterated_value, Value):
Expand Down Expand Up @@ -3209,6 +3211,8 @@ def visit_For(self, node: ast.For) -> None:
self.visit(node.target)
self._generic_visit_list(node.body)

visit_AsyncFor = visit_For

def visit_While(self, node: ast.While) -> None:
# see comments under For for discussion

Expand Down
18 changes: 18 additions & 0 deletions pyanalyze/test_async_await.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,21 @@ async def pacarana(cond) -> int: # E: missing_return
async def hutia(cond) -> int: # E: missing_return
if cond:
return 3


class TestAsyncGenerator(TestNameCheckVisitorBase):
@assert_passes()
def test_async_gen(self):
import collections.abc
from typing import AsyncIterator

async def gen() -> AsyncIterator[int]:
yield 3

async def capybara() -> None:
assert_is_value(
gen(), GenericValue(collections.abc.AsyncIterator, [TypedValue(int)])
)
async for i in gen():
# TODO should be int
assert_is_value(i, AnyValue(AnySource.generic_argument))
10 changes: 10 additions & 0 deletions pyanalyze/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,13 @@ def bad_deco(x: int) -> str:
@bad_deco # E: incompatible_argument
def capybara():
pass

@skip_before((3, 7))
@assert_passes()
def test_asynccontextmanager(self):
from contextlib import asynccontextmanager
from typing import AsyncIterator

@asynccontextmanager
async def make_cm() -> AsyncIterator[None]:
yield