Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
MYPY = False
if MYPY:
import typing # for typing.Type, which conflicts with types.Type
from typing_extensions import Final
from typing_extensions import Final, Literal

from mypy.sharedparse import (
special_function_elide_names, argument_elide_name,
Expand Down Expand Up @@ -248,7 +248,8 @@ def __init__(self,
options: Options,
is_stub: bool,
errors: Errors) -> None:
self.class_nesting = 0
# 'C' for class, 'F' for function
self.class_and_function_stack = [] # type: List[Literal['C', 'F']]
self.imports = [] # type: List[ImportBase]

self.options = options
Expand Down Expand Up @@ -382,8 +383,8 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret.append(OverloadedFuncDef(current_overload))
return ret

def in_class(self) -> bool:
return self.class_nesting > 0
def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ['C', 'F']

def translate_module_id(self, id: str) -> str:
"""Return the actual, internal module id for a source text id.
Expand Down Expand Up @@ -424,6 +425,7 @@ def visit_AsyncFunctionDef(self, n: ast3.AsyncFunctionDef) -> Union[FuncDef, Dec
def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
is_coroutine: bool = False) -> Union[FuncDef, Decorator]:
"""Helper shared between visit_FunctionDef and visit_AsyncFunctionDef."""
self.class_and_function_stack.append('F')
no_type_check = bool(n.decorator_list and
any(is_no_type_check_decorator(d) for d in n.decorator_list))

Expand Down Expand Up @@ -465,7 +467,7 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],
line=lineno).visit(func_type_ast.returns)

# add implicit self type
if self.in_class() and len(arg_types) < len(args):
if self.in_method_scope() and len(arg_types) < len(args):
arg_types.insert(0, AnyType(TypeOfAny.special_form))
except SyntaxError:
self.fail(TYPE_COMMENT_SYNTAX_ERROR, lineno, n.col_offset)
Expand Down Expand Up @@ -529,11 +531,13 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef],

deco = Decorator(func_def, self.translate_expr_list(n.decorator_list), var)
deco.set_line(n.decorator_list[0].lineno)
return deco
retval = deco # type: Union[FuncDef, Decorator]
else:
# FuncDef overrides set_line -- can't use self.set_line
func_def.set_line(lineno, n.col_offset)
return func_def
retval = func_def
self.class_and_function_stack.pop()
return retval

def set_type_optional(self, type: Optional[Type], initializer: Optional[Expression]) -> None:
if self.options.no_implicit_optional:
Expand Down Expand Up @@ -614,7 +618,7 @@ def fail_arg(self, msg: str, arg: ast3.arg) -> None:
# stmt* body,
# expr* decorator_list)
def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
self.class_nesting += 1
self.class_and_function_stack.append('C')
keywords = [(kw.arg, self.visit(kw.value))
for kw in n.keywords if kw.arg]

Expand All @@ -634,7 +638,7 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef:
cdef.column = n.col_offset
else:
self.set_line(cdef, n)
self.class_nesting -= 1
self.class_and_function_stack.pop()
return cdef

# Return(expr? value)
Expand Down
24 changes: 14 additions & 10 deletions mypy/fastparse2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
MYPY = False
if MYPY:
import typing # for typing.Type, which conflicts with types.Type
from typing_extensions import Final
from typing_extensions import Final, Literal

from mypy.sharedparse import (
special_function_elide_names, argument_elide_name,
Expand Down Expand Up @@ -134,7 +134,8 @@ class ASTConverter:
def __init__(self,
options: Options,
errors: Errors) -> None:
self.class_nesting = 0
# 'C' for class, 'F' for function
self.class_and_function_stack = [] # type: List[Literal['C', 'F']]
self.imports = [] # type: List[ImportBase]

self.options = options
Expand Down Expand Up @@ -285,8 +286,8 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret.append(OverloadedFuncDef(current_overload))
return ret

def in_class(self) -> bool:
return self.class_nesting > 0
def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ['C', 'F']

def translate_module_id(self, id: str) -> str:
"""Return the actual, internal module id for a source text id.
Expand All @@ -313,10 +314,11 @@ def visit_Module(self, mod: ast27.Module) -> MypyFile:

# --- stmt ---
# FunctionDef(identifier name, arguments args,
# stmt* body, expr* decorator_list, expr? returns, string? type_comment)
# stmt* body, expr* decorator_list, expr? returns, string? type_comment)
# arguments = (arg* args, arg? vararg, arg* kwonlyargs, expr* kw_defaults,
# arg? kwarg, expr* defaults)
def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
self.class_and_function_stack.append('F')
lineno = n.lineno
converter = TypeConverter(self.errors, line=lineno,
assume_str_is_unicode=self.unicode_literals)
Expand Down Expand Up @@ -353,7 +355,7 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
return_type = converter.visit(func_type_ast.returns)

# add implicit self type
if self.in_class() and len(arg_types) < len(args):
if self.in_method_scope() and len(arg_types) < len(args):
arg_types.insert(0, AnyType(TypeOfAny.special_form))
except SyntaxError:
self.fail(TYPE_COMMENT_SYNTAX_ERROR, lineno, n.col_offset)
Expand Down Expand Up @@ -407,11 +409,13 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement:
func_def.body.set_line(func_def.get_line())
dec = Decorator(func_def, self.translate_expr_list(n.decorator_list), var)
dec.set_line(lineno, n.col_offset)
return dec
retval = dec # type: Statement
else:
# Overrides set_line -- can't use self.set_line
func_def.set_line(lineno, n.col_offset)
return func_def
retval = func_def
self.class_and_function_stack.pop()
return retval

def set_type_optional(self, type: Optional[Type], initializer: Optional[Expression]) -> None:
if self.options.no_implicit_optional:
Expand Down Expand Up @@ -515,7 +519,7 @@ def stringify_name(self, n: AST) -> str:
# stmt* body,
# expr* decorator_list)
def visit_ClassDef(self, n: ast27.ClassDef) -> ClassDef:
self.class_nesting += 1
self.class_and_function_stack.append('C')

cdef = ClassDef(n.name,
self.as_required_block(n.body, n.lineno),
Expand All @@ -524,7 +528,7 @@ def visit_ClassDef(self, n: ast27.ClassDef) -> ClassDef:
metaclass=None)
cdef.decorators = self.translate_expr_list(n.decorator_list)
self.set_line(cdef, n)
self.class_nesting -= 1
self.class_and_function_stack.pop()
return cdef

# Return(expr? value)
Expand Down
29 changes: 29 additions & 0 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,35 @@ class A:
main:6: error: Incompatible types in assignment (expression has type "int", variable has type "A")
main:8: error: Argument 1 to "g" has incompatible type "A"; expected "int"

[case testNestedFunctionInMethodWithTooFewArgumentsInTypeComment]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps you could also add a test for a few other configurations of the stack, e.g. a method inside a class inside a method (i.e. ['C', 'F', 'C', 'F'])?

class A:
def f(self):
# type: () -> None
def g(x): # E: Type signature has too few arguments
# type: () -> None
pass

[case testDeepNestedFunctionWithTooFewArgumentsInTypeComment]
class A:
def f(self):
# type: () -> None
class B:
def g(self):
# type: () -> None
def h(x): # E: Type signature has too few arguments
# type: () -> None
pass

[case testDeepNestedMethodInTypeComment]
class A:
def f(self):
# type: () -> None
class B:
class C:
def g(self):
# type: () -> None
pass

[case testMutuallyRecursiveNestedFunctions]
def f() -> None:
def g() -> None:
Expand Down