Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Treat obvious return types as annotated #4411

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,8 @@ def check_func_item(self, defn: FuncItem,

If type_override is provided, use it as the function type.
"""
if not defn.is_checkable:
return
# We may be checking a function definition or an anonymous function. In
# the first case, set up another reference with the precise type.
fdef = None # type: Optional[FuncDef]
Expand Down
41 changes: 20 additions & 21 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,27 +376,26 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
self.set_type_optional(arg_type, arg.initializer)

func_type = None
if any(arg_types) or return_type:
if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types):
self.fail("Ellipses cannot accompany other argument types "
"in function type signature.", n.lineno, 0)
elif len(arg_types) > len(arg_kinds):
self.fail('Type signature has too many arguments', n.lineno, 0)
elif len(arg_types) < len(arg_kinds):
self.fail('Type signature has too few arguments', n.lineno, 0)
else:
func_type = CallableType([a if a is not None else
AnyType(TypeOfAny.unannotated) for a in arg_types],
arg_kinds,
arg_names,
return_type if return_type is not None else
AnyType(TypeOfAny.unannotated),
_dummy_fallback)

func_def = FuncDef(n.name,
args,
self.as_required_block(n.body, n.lineno),
func_type)
if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types):
self.fail("Ellipses cannot accompany other argument types "
"in function type signature.", n.lineno, 0)
elif len(arg_types) > len(arg_kinds):
self.fail('Type signature has too many arguments', n.lineno, 0)
elif len(arg_types) < len(arg_kinds):
self.fail('Type signature has too few arguments', n.lineno, 0)
else:
func_type = CallableType([a if a is not None else
AnyType(TypeOfAny.unannotated) for a in arg_types],
arg_kinds,
arg_names,
return_type if return_type is not None else
AnyType(TypeOfAny.unannotated),
_dummy_fallback)

func_def = FuncDef(n.name, args,
self.as_required_block(n.body, n.lineno),
func_type)
func_def.is_checkable = bool(any(arg_types) or return_type)
if is_coroutine:
# A coroutine is also a generator, mostly for internal reasons.
func_def.is_generator = func_def.is_coroutine = True
Expand Down
3 changes: 3 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ def add_invertible_flag(flag: str,
"(experimental -- read documentation before using!). "
"Implies --strict-optional. Has the undesirable side-effect of "
"suppressing other errors in non-whitelisted files.")
add_invertible_flag('--obvious-return', default=False, strict_flag=True,
help="Treat obvious return values as function annotations")

parser.add_argument('--junit-xml', help="write junit.xml to the given file")
parser.add_argument('--pdb', action='store_true', help="invoke pdb on fatal error")
parser.add_argument('--show-traceback', '--tb', action='store_true',
Expand Down
3 changes: 2 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,13 @@ class FuncItem(FuncBase):
is_awaitable_coroutine = False # Decorated with '@{typing,asyncio}.coroutine'?
is_static = False # Uses @staticmethod?
is_class = False # Uses @classmethod?
is_checkable = False
# Variants of function with type variables with values expanded
expanded = None # type: List[FuncItem]

FLAGS = [
'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator',
'is_awaitable_coroutine', 'is_static', 'is_class',
'is_awaitable_coroutine', 'is_static', 'is_class', 'is_checkable',
]

def __init__(self, arguments: List[Argument], body: 'Block',
Expand Down
3 changes: 3 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class Options:
"no_implicit_optional",
"strict_optional",
"disallow_untyped_decorators",
"obvious_return",
}

OPTIONS_AFFECTING_CACHE = ((PER_MODULE_OPTIONS | {"quick_and_dirty", "platform"})
Expand Down Expand Up @@ -112,6 +113,8 @@ def __init__(self) -> None:
# Apply strict None checking
self.strict_optional = False

self.obvious_return = False

# Show "note: In function "foo":" messages.
self.show_error_context = False

Expand Down
48 changes: 36 additions & 12 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) -
def analyze_function(self, defn: FuncItem) -> None:
is_method = self.is_class_scope()
with self.tvar_scope_frame(self.tvar_scope.method_frame()):
if defn.type:

if defn.type is not None and defn.is_checkable:
self.check_classvar_in_signature(defn.type)
assert isinstance(defn.type, CallableType)
# Signature must be analyzed in the surrounding scope so that
Expand All @@ -601,9 +602,10 @@ def analyze_function(self, defn: FuncItem) -> None:
if arg.initializer:
arg.initializer.accept(self)
# Bind the type variables again to visit the body.
if defn.type:
if defn.type is not None and defn.is_checkable:
a = self.type_analyzer()
a.bind_function_type_variables(cast(CallableType, defn.type), defn)
assert isinstance(defn.type, CallableType)
a.bind_function_type_variables(defn.type, defn)
self.function_stack.append(defn)
self.enter()
for arg in defn.arguments:
Expand All @@ -627,6 +629,17 @@ def analyze_function(self, defn: FuncItem) -> None:
self.postpone_nested_functions_stack.pop()
self.postponed_functions_stack.pop()

if self.options.obvious_return:
if isinstance(defn.type, CallableType):
if isinstance(defn.type.ret_type, AnyType):
if defn.type.ret_type.type_of_any == TypeOfAny.unannotated:
finder = RetFinder()
defn.accept(finder)
# Ad-hoc join: trivial identity
ret_types = {self.analyze_simple_literal_type(ret) for ret in finder.rets}
if None not in ret_types and len(ret_types) == 1:
defn.type.ret_type = ret_types.pop()

self.leave()
self.function_stack.pop()

Expand Down Expand Up @@ -1703,7 +1716,16 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:
# Set the type if the rvalue is a simple literal (even if the above error occurred).
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr):
if s.lvalues[0].is_inferred_def:
s.type = self.analyze_simple_literal_type(s.rvalue)
if self.options.semantic_analysis_only or self.function_stack:
# Skip this if we're only doing the semantic analysis pass.
# This is mostly to avoid breaking unit tests.
# Also skip inside a function; this is to avoid confusing
# the code that handles dead code due to isinstance()
# inside type variables with value restrictions (like
# AnyStr).
s.type = None
else:
s.type = self.analyze_simple_literal_type(s.rvalue)
if s.type:
# Store type into nodes.
for lvalue in s.lvalues:
Expand All @@ -1724,14 +1746,6 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None:

def analyze_simple_literal_type(self, rvalue: Expression) -> Optional[Type]:
"""Return builtins.int if rvalue is an int literal, etc."""
if self.options.semantic_analysis_only or self.function_stack:
# Skip this if we're only doing the semantic analysis pass.
# This is mostly to avoid breaking unit tests.
# Also skip inside a function; this is to avoid confusing
# the code that handles dead code due to isinstance()
# inside type variables with value restrictions (like
# AnyStr).
return None
if isinstance(rvalue, IntExpr):
return self.named_type_or_none('builtins.int')
if isinstance(rvalue, FloatExpr):
Expand Down Expand Up @@ -4232,6 +4246,16 @@ def make_any_non_explicit(t: Type) -> Type:
return t.accept(MakeAnyNonExplicit())


class RetFinder(TraverserVisitor):
rets = None # type: List[Expression]

def __init__(self) -> None:
self.rets = []

def visit_return_stmt(self, ret: ReturnStmt) -> None:
self.rets.append(ret.expr or NameExpr("None"))


class MakeAnyNonExplicit(TypeTranslator):
def visit_any(self, t: AnyType) -> Type:
if t.type_of_any == TypeOfAny.explicit:
Expand Down
1 change: 1 addition & 0 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
'check-serialize.test',
'check-bound.test',
'check-optional.test',
'check-obvious.test',
'check-fastparse.test',
'check-warnings.test',
'check-async-await.test',
Expand Down
5 changes: 3 additions & 2 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ def __init__(self,
arg_kinds: List[int],
arg_names: Sequence[Optional[str]],
ret_type: Type,
fallback: Instance,
fallback: Instance, # or none at parsing time
name: Optional[str] = None,
definition: Optional[SymbolNode] = None,
variables: Optional[List[TypeVarDef]] = None,
Expand Down Expand Up @@ -720,7 +720,8 @@ def copy_modified(self,
)

def is_type_obj(self) -> bool:
return self.fallback.type.is_metaclass()
# self.fallback is None at initialization
return self.fallback is not None and self.fallback.type.is_metaclass()

def is_concrete_type_obj(self) -> bool:
return self.is_type_obj() and self.is_classmethod_class
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-obvious.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
-- Tests for obvious return types

[case testObviousLiteralString]
# flags: --obvious-return
def foo():
"" + 1
return ""

def bar():
1 + ""
return 1

reveal_type(foo() + "") # E: Revealed type is 'builtins.str'
bar() + foo() # E: Unsupported operand types for + ("int" and "str")

[case testObviousOverride]
# flags: --obvious-return
class A:
def foo(self):
return ""

class B(A):
def foo(self): # E: Return type of "foo" incompatible with supertype "A"
return 1