From 5adf014ed1498ce7353a2af3d9185bd78aabfa05 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Mon, 1 Sep 2025 20:50:50 +0200 Subject: [PATCH 1/2] Fix wrong test --- mypy/checker.py | 5 ++- test-data/unit/check-inference-context.test | 2 + test-data/unit/fixtures/typing-async.pyi | 45 ++++++++++++--------- 3 files changed, 32 insertions(+), 20 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 77822b7068ae..680d1228336e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -80,6 +80,7 @@ AssertStmt, AssignmentExpr, AssignmentStmt, + AwaitExpr, Block, BreakStmt, BytesExpr, @@ -4924,7 +4925,9 @@ def check_return_stmt(self, s: ReturnStmt) -> None: allow_none_func_call = is_lambda or declared_none_return or declared_any_return # Return with a value. - if isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)): + if isinstance( + s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AwaitExpr) + ): # For expressions that (strongly) depend on type context (i.e. those that # are handled like a function call), we allow fallback to empty type context # in case of errors, this improves user experience in some cases, diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index 7dbbd68c4215..a41ee5f59670 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1582,3 +1582,5 @@ async def inner(c: Cls[T]) -> Optional[T]: async def outer(c: Cls[T]) -> Optional[T]: return await inner(c) +[builtins fixtures/async_await.pyi] +[typing fixtures/typing-async.pyi] diff --git a/test-data/unit/fixtures/typing-async.pyi b/test-data/unit/fixtures/typing-async.pyi index 03728f822316..7ce2821d2916 100644 --- a/test-data/unit/fixtures/typing-async.pyi +++ b/test-data/unit/fixtures/typing-async.pyi @@ -28,7 +28,9 @@ Self = 0 T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) +R_co = TypeVar('R_co', covariant=True) T_contra = TypeVar('T_contra', contravariant=True) +S_contra = TypeVar('S_contra', contravariant=True) U = TypeVar('U') V = TypeVar('V') S = TypeVar('S') @@ -49,9 +51,9 @@ class Iterator(Iterable[T_co], Protocol): @abstractmethod def __next__(self) -> T_co: pass -class Generator(Iterator[T], Generic[T, U, V]): +class Generator(Iterator[T_co], Generic[T_co, S_contra, R_co]): @abstractmethod - def send(self, value: U) -> T: pass + def send(self, value: S_contra) -> T_co: pass @abstractmethod def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass @@ -60,34 +62,39 @@ class Generator(Iterator[T], Generic[T, U, V]): def close(self) -> None: pass @abstractmethod - def __iter__(self) -> 'Generator[T, U, V]': pass + def __iter__(self) -> 'Generator[T_co, S_contra, R_co]': pass -class AsyncGenerator(AsyncIterator[T], Generic[T, U]): +class AsyncGenerator(AsyncIterator[T_co], Generic[T_co, S_contra]): @abstractmethod - def __anext__(self) -> Awaitable[T]: pass + def __anext__(self) -> Awaitable[T_co]: pass @abstractmethod - def asend(self, value: U) -> Awaitable[T]: pass + def asend(self, value: S_contra) -> Awaitable[T_co]: pass @abstractmethod - def athrow(self, typ: Any, val: Any=None, tb: Any=None) -> Awaitable[T]: pass + def athrow(self, typ: Any, val: Any=None, tb: Any=None) -> Awaitable[T_co]: pass @abstractmethod - def aclose(self) -> Awaitable[T]: pass + def aclose(self) -> Awaitable[T_co]: pass @abstractmethod - def __aiter__(self) -> 'AsyncGenerator[T, U]': pass + def __aiter__(self) -> 'AsyncGenerator[T_co, S_contra]': pass -class Awaitable(Protocol[T]): +class Awaitable(Protocol[T_co]): @abstractmethod - def __await__(self) -> Generator[Any, Any, T]: pass + def __await__(self) -> Generator[Any, Any, T_co]: pass -class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S], metaclass=ABCMeta): +class AwaitableGenerator( + Awaitable[R_co], + Generator[T_co, S_contra, R_co], + Generic[T_co, S_contra, R_co, S], + metaclass=ABCMeta +): pass -class Coroutine(Awaitable[V], Generic[T, U, V]): +class Coroutine(Awaitable[R_co], Generic[T_co, S_contra, R_co]): @abstractmethod - def send(self, value: U) -> T: pass + def send(self, value: S_contra) -> T_co: pass @abstractmethod def throw(self, typ: Any, val: Any=None, tb: Any=None) -> None: pass @@ -95,14 +102,14 @@ class Coroutine(Awaitable[V], Generic[T, U, V]): @abstractmethod def close(self) -> None: pass -class AsyncIterable(Protocol[T]): +class AsyncIterable(Protocol[T_co]): @abstractmethod - def __aiter__(self) -> 'AsyncIterator[T]': pass + def __aiter__(self) -> 'AsyncIterator[T_co]': pass -class AsyncIterator(AsyncIterable[T], Protocol): - def __aiter__(self) -> 'AsyncIterator[T]': return self +class AsyncIterator(AsyncIterable[T_co], Protocol): + def __aiter__(self) -> 'AsyncIterator[T_co]': return self @abstractmethod - def __anext__(self) -> Awaitable[T]: pass + def __anext__(self) -> Awaitable[T_co]: pass class Sequence(Iterable[T_co], Container[T_co]): @abstractmethod From 75863a4c66fa2d21331b4c8f54966d4c3bf0c606 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 2 Sep 2025 18:15:46 +0200 Subject: [PATCH 2/2] Specialize to `await call()` only --- mypy/checker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 680d1228336e..ba821df621e5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4925,8 +4925,10 @@ def check_return_stmt(self, s: ReturnStmt) -> None: allow_none_func_call = is_lambda or declared_none_return or declared_any_return # Return with a value. - if isinstance( - s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AwaitExpr) + if ( + isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)) + or isinstance(s.expr, AwaitExpr) + and isinstance(s.expr.expr, CallExpr) ): # For expressions that (strongly) depend on type context (i.e. those that # are handled like a function call), we allow fallback to empty type context