Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type checking plugin support for functions #3299

Merged
merged 3 commits into from May 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion mypy/checker.py
Expand Up @@ -2219,8 +2219,12 @@ def visit_decorator(self, e: Decorator) -> None:
continue
dec = self.expr_checker.accept(d)
temp = self.temp_node(sig)
fullname = None
if isinstance(d, RefExpr):
fullname = d.fullname
sig, t2 = self.expr_checker.check_call(dec, [temp],
[nodes.ARG_POS], e)
[nodes.ARG_POS], e,
callable_name=fullname)
sig = cast(FunctionLike, sig)
sig = set_callable_name(sig, e.func)
e.var.type = sig
Expand Down
43 changes: 39 additions & 4 deletions mypy/checkexpr.py
Expand Up @@ -44,6 +44,7 @@
from mypy.util import split_module_names
from mypy.typevars import fill_typevars
from mypy.visitor import ExpressionVisitor
from mypy.funcplugins import get_function_plugin_callbacks, PluginCallback

from mypy import experiments

Expand Down Expand Up @@ -103,6 +104,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
type_context = None # type: List[Optional[Type]]

strfrm_checker = None # type: StringFormatterChecker
function_plugins = None # type: Dict[str, PluginCallback]

def __init__(self,
chk: 'mypy.checker.TypeChecker',
Expand All @@ -112,6 +114,7 @@ def __init__(self,
self.msg = msg
self.type_context = [None]
self.strfrm_checker = StringFormatterChecker(self, self.chk, self.msg)
self.function_plugins = get_function_plugin_callbacks(self.chk.options.python_version)

def visit_name_expr(self, e: NameExpr) -> Type:
"""Type check a name expression.
Expand Down Expand Up @@ -198,7 +201,11 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
isinstance(callee_type, CallableType)
and callee_type.implicit):
return self.msg.untyped_function_call(callee_type, e)
ret_type = self.check_call_expr_with_callee_type(callee_type, e)
if not isinstance(e.callee, RefExpr):
fullname = None
else:
fullname = e.callee.fullname
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname)
if isinstance(ret_type, UninhabitedType):
self.chk.binder.unreachable()
if not allow_none_return and isinstance(ret_type, NoneTyp):
Expand Down Expand Up @@ -330,21 +337,44 @@ def try_infer_partial_type(self, e: CallExpr) -> None:
list(full_item_types))
del partial_types[var]

def apply_function_plugin(self,
arg_types: List[Type],
inferred_ret_type: Type,
arg_kinds: List[int],
formal_to_actual: List[List[int]],
args: List[Expression],
num_formals: int,
fullname: Optional[str]) -> Type:
"""Use special case logic to infer the return type for of a particular named function.

Return the inferred return type.
"""
formal_arg_types = [[] for _ in range(num_formals)] # type: List[List[Type]]
formal_arg_exprs = [[] for _ in range(num_formals)] # type: List[List[Expression]]
for formal, actuals in enumerate(formal_to_actual):
for actual in actuals:
formal_arg_types[formal].append(arg_types[actual])
formal_arg_exprs[formal].append(args[actual])
return self.function_plugins[fullname](
formal_arg_types, formal_arg_exprs, inferred_ret_type, self.chk.named_generic_type)

def check_call_expr_with_callee_type(self, callee_type: Type,
e: CallExpr) -> Type:
e: CallExpr, callable_name: Optional[str]) -> Type:
"""Type check call expression.

The given callee type overrides the type of the callee
expression.
"""
return self.check_call(callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee)[0]
e.arg_names, callable_node=e.callee,
callable_name=callable_name)[0]

def check_call(self, callee: Type, args: List[Expression],
arg_kinds: List[int], context: Context,
arg_names: List[str] = None,
callable_node: Expression = None,
arg_messages: MessageBuilder = None) -> Tuple[Type, Type]:
arg_messages: MessageBuilder = None,
callable_name: Optional[str] = None) -> Tuple[Type, Type]:
"""Type check a call.

Also infer type arguments if the callee is a generic function.
Expand Down Expand Up @@ -406,6 +436,11 @@ def check_call(self, callee: Type, args: List[Expression],
if callable_node:
# Store the inferred callable type.
self.chk.store_type(callable_node, callee)
if callable_name in self.function_plugins:
ret_type = self.apply_function_plugin(
arg_types, callee.ret_type, arg_kinds, formal_to_actual,
args, len(callee.arg_types), callable_name)
callee = callee.copy_modified(ret_type=ret_type)
return callee.ret_type, callee
elif isinstance(callee, Overloaded):
# Type check arguments in empty context. They will be checked again
Expand Down
81 changes: 81 additions & 0 deletions mypy/funcplugins.py
@@ -0,0 +1,81 @@
"""Plugins that implement special type checking rules for individual functions.

The plugins infer better types for tricky functions such as "open".
"""

from typing import Tuple, Dict, Callable, List

from mypy.nodes import Expression, StrExpr
from mypy.types import Type, Instance, CallableType


# A callback that infers the return type of a function with a special signature.
#
# A no-op callback would just return the inferred return type, but a useful callback
# at least sometimes can infer a more precise type.
PluginCallback = Callable[
[
List[List[Type]], # List of types caller provides for each formal argument
List[List[Expression]], # Actual argument expressions for each formal argument
Type, # Return type for call inferred using the regular signature
Callable[[str, List[Type]], Type] # Callable for constructing a named instance type
],
Type # Return type inferred by the callback
]


def get_function_plugin_callbacks(python_version: Tuple[int, int]) -> Dict[str, PluginCallback]:
"""Return all available function plugins for a given Python version."""
if python_version[0] == 3:
return {
'builtins.open': open_callback,
'contextlib.contextmanager': contextmanager_callback,
}
else:
return {
'contextlib.contextmanager': contextmanager_callback,
}


def open_callback(
arg_types: List[List[Type]],
args: List[List[Expression]],
inferred_return_type: Type,
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
"""Infer a better return type for 'open'.

Infer IO[str] or IO[bytes] as the return value if the mode argument is not
given or is a literal.
"""
mode = None
if not arg_types or len(arg_types[1]) != 1:
mode = 'r'
elif isinstance(args[1][0], StrExpr):
mode = args[1][0].value
if mode is not None:
assert isinstance(inferred_return_type, Instance)
if 'b' in mode:
arg = named_generic_type('builtins.bytes', [])
else:
arg = named_generic_type('builtins.str', [])
return Instance(inferred_return_type.type, [arg])
return inferred_return_type


def contextmanager_callback(
arg_types: List[List[Type]],
args: List[List[Expression]],
inferred_return_type: Type,
named_generic_type: Callable[[str, List[Type]], Type]) -> Type:
"""Infer a better return type for 'contextlib.contextmanager'."""
# Be defensive, just in case.
if arg_types and len(arg_types[0]) == 1:
arg_type = arg_types[0][0]
if isinstance(arg_type, CallableType) and isinstance(inferred_return_type, CallableType):
# The stub signature doesn't preserve information about arguments so
# add them back here.
return inferred_return_type.copy_modified(
arg_types=arg_type.arg_types,
arg_kinds=arg_type.arg_kinds,
arg_names=arg_type.arg_names)
return inferred_return_type
52 changes: 51 additions & 1 deletion test-data/unit/pythoneval.test
Expand Up @@ -399,7 +399,33 @@ f.write('x')
f.write(b'x')
f.foobar()
[out]
_program.py:4: error: IO[Any] has no attribute "foobar"
_program.py:3: error: Argument 1 to "write" of "IO" has incompatible type "bytes"; expected "str"
_program.py:4: error: IO[str] has no attribute "foobar"

[case testOpenReturnTypeInference]
reveal_type(open('x'))
reveal_type(open('x', 'r'))
reveal_type(open('x', 'rb'))
mode = 'rb'
reveal_type(open('x', mode))
[out]
_program.py:1: error: Revealed type is 'typing.IO[builtins.str]'
_program.py:2: error: Revealed type is 'typing.IO[builtins.str]'
_program.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
_program.py:5: error: Revealed type is 'typing.IO[Any]'

[case testOpenReturnTypeInferenceSpecialCases]
reveal_type(open())
reveal_type(open(mode='rb', file='x'))
reveal_type(open(file='x', mode='rb'))
mode = 'rb'
reveal_type(open(mode=mode, file='r'))
[out]
_testOpenReturnTypeInferenceSpecialCases.py:1: error: Revealed type is 'typing.IO[builtins.str]'
_testOpenReturnTypeInferenceSpecialCases.py:1: error: Too few arguments for "open"
_testOpenReturnTypeInferenceSpecialCases.py:2: error: Revealed type is 'typing.IO[builtins.bytes]'
_testOpenReturnTypeInferenceSpecialCases.py:3: error: Revealed type is 'typing.IO[builtins.bytes]'
_testOpenReturnTypeInferenceSpecialCases.py:5: error: Revealed type is 'typing.IO[Any]'

[case testGenericPatterns]
from typing import Pattern
Expand Down Expand Up @@ -1286,3 +1312,27 @@ a[1] = 2, 'y'
a[:] = [('z', 3)]
[out]
_program.py:4: error: Incompatible types in assignment (expression has type "Tuple[int, str]", target has type "Tuple[str, int]")

[case testContextManager]
import contextlib
from contextlib import contextmanager
from typing import Iterator

@contextmanager
def f(x: int) -> Iterator[str]:
yield 'foo'

@contextlib.contextmanager
def g(*x: str) -> Iterator[int]:
yield 1

reveal_type(f)
reveal_type(g)

with f('') as s:
reveal_type(s)
[out]
_program.py:13: error: Revealed type is 'def (x: builtins.int) -> contextlib.GeneratorContextManager[builtins.str*]'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized that that class is misnamed in typeshed, it should be _GeneratorContextManager (to match what it's called at runtime). I also don't understand what its __call__ method is for (contextlib doesn't seem to have reference docs, and the source has few clues).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __call__ is so that @contextmanager-decorated functions can also be used as decorators themselves (executing the decorated function within the context). Nick Coghlan has said that he considers this feature a design mistake in contextlib.

_program.py:14: error: Revealed type is 'def (*x: builtins.str) -> contextlib.GeneratorContextManager[builtins.int*]'
_program.py:16: error: Argument 1 to "f" has incompatible type "str"; expected "int"
_program.py:17: error: Revealed type is 'builtins.str*'