Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
AssertStmt,
AssignmentExpr,
AssignmentStmt,
AwaitExpr,
Block,
BreakStmt,
BytesExpr,
Expand Down Expand Up @@ -4924,7 +4925,11 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
allow_none_func_call = is_lambda or declared_none_return or declared_any_return

# Return with a value.
if isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)):
if (
isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr))
or isinstance(s.expr, AwaitExpr)
and isinstance(s.expr.expr, CallExpr)
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about using isinstance(s.expr, AwaitExpr) and isinstance(s.expr.expr, CallExpr)? This will save us some pointless busy-work for cases like return await some_name.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, yes, makes sense. I doubt that await not_a_call accounts for some significant portion of await expressions, but this shouldn't harm.

# For expressions that (strongly) depend on type context (i.e. those that
# are handled like a function call), we allow fallback to empty type context
# in case of errors, this improves user experience in some cases,
Expand Down
2 changes: 2 additions & 0 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -1582,3 +1582,5 @@ async def inner(c: Cls[T]) -> Optional[T]:

async def outer(c: Cls[T]) -> Optional[T]:
return await inner(c)
[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]
45 changes: 26 additions & 19 deletions test-data/unit/fixtures/typing-async.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ Self = 0

T = TypeVar('T')
T_co = TypeVar('T_co', covariant=True)
R_co = TypeVar('R_co', covariant=True)
T_contra = TypeVar('T_contra', contravariant=True)
S_contra = TypeVar('S_contra', contravariant=True)
U = TypeVar('U')
V = TypeVar('V')
S = TypeVar('S')
Expand All @@ -49,9 +51,9 @@ class Iterator(Iterable[T_co], Protocol):
@abstractmethod
def __next__(self) -> T_co: pass

class Generator(Iterator[T], Generic[T, U, V]):
class Generator(Iterator[T_co], Generic[T_co, S_contra, R_co]):
@abstractmethod
def send(self, value: U) -> T: pass
def send(self, value: S_contra) -> T_co: pass

@abstractmethod
def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass
Expand All @@ -60,49 +62,54 @@ class Generator(Iterator[T], Generic[T, U, V]):
def close(self) -> None: pass

@abstractmethod
def __iter__(self) -> 'Generator[T, U, V]': pass
def __iter__(self) -> 'Generator[T_co, S_contra, R_co]': pass

class AsyncGenerator(AsyncIterator[T], Generic[T, U]):
class AsyncGenerator(AsyncIterator[T_co], Generic[T_co, S_contra]):
@abstractmethod
def __anext__(self) -> Awaitable[T]: pass
def __anext__(self) -> Awaitable[T_co]: pass

@abstractmethod
def asend(self, value: U) -> Awaitable[T]: pass
def asend(self, value: S_contra) -> Awaitable[T_co]: pass

@abstractmethod
def athrow(self, typ: Any, val: Any=None, tb: Any=None) -> Awaitable[T]: pass
def athrow(self, typ: Any, val: Any=None, tb: Any=None) -> Awaitable[T_co]: pass

@abstractmethod
def aclose(self) -> Awaitable[T]: pass
def aclose(self) -> Awaitable[T_co]: pass

@abstractmethod
def __aiter__(self) -> 'AsyncGenerator[T, U]': pass
def __aiter__(self) -> 'AsyncGenerator[T_co, S_contra]': pass

class Awaitable(Protocol[T]):
class Awaitable(Protocol[T_co]):
@abstractmethod
def __await__(self) -> Generator[Any, Any, T]: pass
def __await__(self) -> Generator[Any, Any, T_co]: pass

class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S], metaclass=ABCMeta):
class AwaitableGenerator(
Awaitable[R_co],
Generator[T_co, S_contra, R_co],
Generic[T_co, S_contra, R_co, S],
metaclass=ABCMeta
):
pass

class Coroutine(Awaitable[V], Generic[T, U, V]):
class Coroutine(Awaitable[R_co], Generic[T_co, S_contra, R_co]):
@abstractmethod
def send(self, value: U) -> T: pass
def send(self, value: S_contra) -> T_co: pass

@abstractmethod
def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass

@abstractmethod
def close(self) -> None: pass

class AsyncIterable(Protocol[T]):
class AsyncIterable(Protocol[T_co]):
@abstractmethod
def __aiter__(self) -> 'AsyncIterator[T]': pass
def __aiter__(self) -> 'AsyncIterator[T_co]': pass

class AsyncIterator(AsyncIterable[T], Protocol):
def __aiter__(self) -> 'AsyncIterator[T]': return self
class AsyncIterator(AsyncIterable[T_co], Protocol):
def __aiter__(self) -> 'AsyncIterator[T_co]': return self
@abstractmethod
def __anext__(self) -> Awaitable[T]: pass
def __anext__(self) -> Awaitable[T_co]: pass

class Sequence(Iterable[T_co], Container[T_co]):
@abstractmethod
Expand Down
Loading