Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: Support yield from statements #15271

Merged
merged 2 commits into from May 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
27 changes: 18 additions & 9 deletions mypy/stubgen.py
Expand Up @@ -126,7 +126,12 @@
report_missing,
walk_packages,
)
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
from mypy.traverser import (
all_yield_expressions,
has_return_statement,
has_yield_expression,
has_yield_from_expression,
)
from mypy.types import (
OVERLOAD_NAMES,
TPDICT_NAMES,
Expand Down Expand Up @@ -774,18 +779,22 @@ def visit_func_def(self, o: FuncDef) -> None:
retname = None # implicit Any
elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES:
retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name]
elif has_yield_expression(o):
elif has_yield_expression(o) or has_yield_from_expression(o):
self.add_typing_import("Generator")
yield_name = "None"
send_name = "None"
return_name = "None"
for expr, in_assignment in all_yield_expressions(o):
if expr.expr is not None and not self.is_none_expr(expr.expr):
self.add_typing_import("Incomplete")
yield_name = self.typing_name("Incomplete")
if in_assignment:
self.add_typing_import("Incomplete")
send_name = self.typing_name("Incomplete")
if has_yield_from_expression(o):
self.add_typing_import("Incomplete")
yield_name = send_name = self.typing_name("Incomplete")
else:
for expr, in_assignment in all_yield_expressions(o):
if expr.expr is not None and not self.is_none_expr(expr.expr):
self.add_typing_import("Incomplete")
yield_name = self.typing_name("Incomplete")
if in_assignment:
self.add_typing_import("Incomplete")
send_name = self.typing_name("Incomplete")
if has_return_statement(o):
self.add_typing_import("Incomplete")
return_name = self.typing_name("Incomplete")
Expand Down
36 changes: 36 additions & 0 deletions mypy/traverser.py
Expand Up @@ -873,6 +873,21 @@ def has_yield_expression(fdef: FuncBase) -> bool:
return seeker.found


class YieldFromSeeker(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
self.found = False

def visit_yield_from_expr(self, o: YieldFromExpr) -> None:
self.found = True


def has_yield_from_expression(fdef: FuncBase) -> bool:
seeker = YieldFromSeeker()
fdef.accept(seeker)
return seeker.found


class AwaitSeeker(TraverserVisitor):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -922,3 +937,24 @@ def all_yield_expressions(node: Node) -> list[tuple[YieldExpr, bool]]:
v = YieldCollector()
node.accept(v)
return v.yield_expressions


class YieldFromCollector(FuncCollectorBase):
def __init__(self) -> None:
super().__init__()
self.in_assignment = False
self.yield_from_expressions: list[tuple[YieldFromExpr, bool]] = []

def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
self.in_assignment = True
super().visit_assignment_stmt(stmt)
self.in_assignment = False

def visit_yield_from_expr(self, expr: YieldFromExpr) -> None:
self.yield_from_expressions.append((expr, self.in_assignment))


def all_yield_from_expressions(node: Node) -> list[tuple[YieldFromExpr, bool]]:
v = YieldFromCollector()
node.accept(v)
return v.yield_from_expressions
82 changes: 79 additions & 3 deletions test-data/unit/stubgen.test
Expand Up @@ -1231,6 +1231,9 @@ def h1():
def h2():
yield
return "abc"
def h3():
yield
return None
def all():
x = yield 123
return "abc"
Expand All @@ -1242,6 +1245,7 @@ def f() -> Generator[Incomplete, None, None]: ...
def g() -> Generator[None, Incomplete, None]: ...
def h1() -> Generator[None, None, None]: ...
def h2() -> Generator[None, None, Incomplete]: ...
def h3() -> Generator[None, None, None]: ...
def all() -> Generator[Incomplete, Incomplete, Incomplete]: ...

[case testFunctionYieldsNone]
Expand Down Expand Up @@ -1270,6 +1274,69 @@ class Generator: ...

def f() -> _Generator[Incomplete, None, None]: ...

[case testGeneratorYieldFrom]
def g1():
yield from x
def g2():
y = yield from x
def g3():
yield from x
return
def g4():
yield from x
return None
def g5():
yield from x
return z

[out]
from _typeshed import Incomplete
from collections.abc import Generator

def g1() -> Generator[Incomplete, Incomplete, None]: ...
def g2() -> Generator[Incomplete, Incomplete, None]: ...
hamdanal marked this conversation as resolved.
Show resolved Hide resolved
def g3() -> Generator[Incomplete, Incomplete, None]: ...
def g4() -> Generator[Incomplete, Incomplete, None]: ...
def g5() -> Generator[Incomplete, Incomplete, Incomplete]: ...

[case testGeneratorYieldAndYieldFrom]
def g1():
yield x1
yield from x2
def g2():
yield x1
y = yield from x2
def g3():
y = yield x1
yield from x2
def g4():
yield x1
yield from x2
return
def g5():
yield x1
yield from x2
return None
def g6():
yield x1
yield from x2
return z
def g7():
yield None
yield from x2

[out]
from _typeshed import Incomplete
from collections.abc import Generator

def g1() -> Generator[Incomplete, Incomplete, None]: ...
def g2() -> Generator[Incomplete, Incomplete, None]: ...
def g3() -> Generator[Incomplete, Incomplete, None]: ...
def g4() -> Generator[Incomplete, Incomplete, None]: ...
def g5() -> Generator[Incomplete, Incomplete, None]: ...
def g6() -> Generator[Incomplete, Incomplete, Incomplete]: ...
def g7() -> Generator[Incomplete, Incomplete, None]: ...

[case testCallable]
from typing import Callable

Expand Down Expand Up @@ -2977,13 +3044,17 @@ def func(*, non_default_kwarg: bool, default_kwarg: bool = True): ...
def func(*, non_default_kwarg: bool, default_kwarg: bool = ...): ...

[case testNestedGenerator]
def f():
def f1():
def g():
yield 0

return 0
def f2():
def g():
yield from [0]
return 0
[out]
def f(): ...
def f1(): ...
def f2(): ...

[case testKnownMagicMethodsReturnTypes]
class Some:
Expand Down Expand Up @@ -3193,6 +3264,10 @@ def gen():
y = yield x
return z

def gen2():
y = yield from x
return z

class X(unknown_call("X", "a b")): ...
class Y(collections.namedtuple("Y", xx)): ...
[out]
Expand Down Expand Up @@ -3227,6 +3302,7 @@ TD2: _Incomplete
TD3: _Incomplete

def gen() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...
def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...