diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 1874f5fbca5a..c1dac63c231d 100644 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -59,7 +59,7 @@ Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, - IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, TempNode, + IfStmt, ReturnStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, TempNode, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, ) from mypy.stubgenc import parse_all_signatures, find_unique_signatures, generate_stub_for_c_module @@ -475,7 +475,7 @@ def visit_func_def(self, o: FuncDef) -> None: retname = None if isinstance(o.type, CallableType): retname = self.print_annotation(o.type.ret_type) - elif o.name() == '__init__': + elif o.name() == '__init__' or not has_return_statement(o): retname = 'None' retfield = '' if retname is not None: @@ -814,6 +814,19 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: return results +def has_return_statement(fdef: FuncBase) -> bool: + class ReturnSeeker(mypy.traverser.TraverserVisitor): + def __init__(self) -> None: + self.found = False + + def visit_return_stmt(self, o: ReturnStmt) -> None: + self.found = True + + seeker = ReturnSeeker() + fdef.accept(seeker) + return seeker.found + + def get_qualified_name(o: Expression) -> str: if isinstance(o, NameExpr): return o.name diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index a6a543992c6f..aa8b6136d660 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -5,7 +5,7 @@ def f(): x = 1 [out] -def f(): ... +def f() -> None: ... [case testTwoFunctions] def f(a, b): @@ -13,49 +13,49 @@ def f(a, b): def g(arg): pass [out] -def f(a, b): ... -def g(arg): ... +def f(a, b) -> None: ... +def g(arg) -> None: ... [case testDefaultArgInt] def f(a, b=2): ... def g(b=-1, c=0): ... [out] -def f(a, b: int = ...): ... -def g(b: int = ..., c: int = ...): ... +def f(a, b: int = ...) -> None: ... +def g(b: int = ..., c: int = ...) -> None: ... [case testDefaultArgNone] def f(x=None): ... [out] from typing import Any, Optional -def f(x: Optional[Any] = ...): ... +def f(x: Optional[Any] = ...) -> None: ... [case testDefaultArgBool] def f(x=True, y=False): ... [out] -def f(x: bool = ..., y: bool = ...): ... +def f(x: bool = ..., y: bool = ...) -> None: ... [case testDefaultArgStr] def f(x='foo'): ... [out] -def f(x: str = ...): ... +def f(x: str = ...) -> None: ... [case testDefaultArgBytes] def f(x=b'foo'): ... [out] -def f(x: bytes = ...): ... +def f(x: bytes = ...) -> None: ... [case testDefaultArgFloat] def f(x=1.2): ... [out] -def f(x: float = ...): ... +def f(x: float = ...) -> None: ... [case testDefaultArgOther] def f(x=ord): ... [out] from typing import Any -def f(x: Any = ...): ... +def f(x: Any = ...) -> None: ... [case testPreserveFunctionAnnotation] def f(x: Foo) -> Bar: ... @@ -75,12 +75,12 @@ x: Foo [case testVarArgs] def f(x, *y): ... [out] -def f(x, *y): ... +def f(x, *y) -> None: ... [case testKwVarArgs] def f(x, **y): ... [out] -def f(x, **y): ... +def f(x, **y) -> None: ... [case testClass] class A: @@ -89,9 +89,9 @@ class A: def g(): ... [out] class A: - def f(self, x): ... + def f(self, x) -> None: ... -def g(): ... +def g() -> None: ... [case testVariable] x = 1 @@ -169,15 +169,15 @@ class A: ... def _f(): ... def g(): ... [out] -def g(): ... +def g() -> None: ... [case testIncludePrivateFunction] # flags: --include-private def _f(): ... def g(): ... [out] -def _f(): ... -def g(): ... +def _f() -> None: ... +def g() -> None: ... [case testSkipPrivateMethod] class A: @@ -191,7 +191,7 @@ class A: def _f(self): ... [out] class A: - def _f(self): ... + def _f(self) -> None: ... [case testSkipPrivateVar] _x = 1 @@ -228,7 +228,7 @@ class B(A): ... @decorator def foo(x): ... [out] -def foo(x): ... +def foo(x) -> None: ... [case testMultipleAssignment] x, y = 1, 2 @@ -256,8 +256,8 @@ y: Any def f(x, *, y=1): ... def g(x, *, y=1, z=2): ... [out] -def f(x, *, y: int = ...): ... -def g(x, *, y: int = ..., z: int = ...): ... +def f(x, *, y: int = ...) -> None: ... +def g(x, *, y: int = ..., z: int = ...) -> None: ... [case testProperty] class A: @@ -271,7 +271,7 @@ class A: @property def f(self): ... @f.setter - def f(self, x): ... + def f(self, x) -> None: ... [case testStaticMethod] class A: @@ -280,7 +280,7 @@ class A: [out] class A: @staticmethod - def f(x): ... + def f(x) -> None: ... [case testClassMethod] class A: @@ -289,7 +289,7 @@ class A: [out] class A: @classmethod - def f(cls): ... + def f(cls) -> None: ... [case testIfMainCheck] def a(): ... @@ -298,8 +298,8 @@ if __name__ == '__main__': def f(): ... def b(): ... [out] -def a(): ... -def b(): ... +def a() -> None: ... +def b() -> None: ... [case testImportStar] from x import * @@ -309,7 +309,7 @@ def f(): ... from x import * from a.b import * -def f(): ... +def f() -> None: ... [case testNoSpacesBetweenEmptyClasses] class X: @@ -320,13 +320,13 @@ class C: def f(self): ... [out] class X: - def g(self): ... + def g(self) -> None: ... class A: ... class B: ... class C: - def f(self): ... + def f(self) -> None: ... [case testExceptionBaseClasses] class A(Exception): ... @@ -344,14 +344,14 @@ class A: def __setstate__(self, state): ... [out] class A: - def __eq__(self): ... + def __eq__(self) -> None: ... [case testOmitDefsNotInAll_import] __all__ = [] + ['f'] def f(): ... def g(): ... [out] -def f(): ... +def f() -> None: ... [case testVarDefsNotInAll_import] __all__ = [] + ['f', 'g'] @@ -360,15 +360,15 @@ x = 1 y = 1 def g(): ... [out] -def f(): ... -def g(): ... +def f() -> None: ... +def g() -> None: ... [case testIncludeClassNotInAll_import] __all__ = [] + ['f'] def f(): ... class A: ... [out] -def f(): ... +def f() -> None: ... class A: ... @@ -380,7 +380,7 @@ class A: [out] class A: x: int = ... - def f(self): ... + def f(self) -> None: ... [case testSkipMultiplePrivateDefs] class A: ... @@ -472,12 +472,12 @@ x = 1 class C: def g(self): ... [out] -def f(): ... +def f() -> None: ... x: int class C: - def g(self): ... + def g(self) -> None: ... # Names in __all__ with no definition: # g @@ -503,7 +503,7 @@ class A: [out] class A: @property - def _foo(self): ... + def _foo(self) -> None: ... [case testSkipPrivateStaticAndClassMethod] class A: @@ -524,9 +524,9 @@ class A: [out] class A: @staticmethod - def _foo(): ... + def _foo() -> None: ... @classmethod - def _bar(cls): ... + def _bar(cls) -> None: ... [case testNamedtuple] import collections, x @@ -552,11 +552,11 @@ def g(): ... [out] from collections import namedtuple -def f(): ... +def f() -> None: ... X = namedtuple('X', 'a b') -def g(): ... +def g() -> None: ... [case testNamedtupleBaseClass] import collections, x @@ -638,9 +638,9 @@ def f(): self.x = 1 def g(): ... [out] -def x(): ... -def f(): ... -def g(): ... +def x() -> None: ... +def f() -> None: ... +def g() -> None: ... [case testNestedClass] class A: @@ -652,8 +652,8 @@ class A: class A: class B: x: int = ... - def f(self): ... - def g(self): ... + def f(self) -> None: ... + def g(self) -> None: ... [case testExportViaRelativeImport] from .api import get @@ -682,13 +682,13 @@ class A(X): ... def syslog(a): pass def syslog(a): pass [out] -def syslog(a): ... +def syslog(a) -> None: ... [case testAsyncAwait_fast_parser] async def f(a): x = await y [out] -def f(a): ... +def f(a) -> None: ... [case testInferOptionalOnlyFunc] class A: @@ -703,7 +703,7 @@ from typing import Any, Optional class A: x: Any = ... def __init__(self, a: Optional[Any] = ...) -> None: ... - def method(self, a: Optional[Any] = ...): ... + def method(self, a: Optional[Any] = ...) -> None: ... [case testAnnotationImportsFrom] import foo @@ -853,3 +853,23 @@ noalias3: bool -- More features/fixes: -- do not export deleted names + +[case testFunctionNoReturnInfersReturnNone] +def f(): + x = 1 +[out] +def f() -> None: ... + +[case testFunctionReturnNoReturnType] +def f(): + return 1 +def g(): + return +[out] +def f(): ... +def g(): ... + +[case testFunctionEllipsisInfersReturnNone] +def f(): ... +[out] +def f() -> None: ...