Skip to content

Commit

Permalink
Support method plugin hooks on unions (#6560)
Browse files Browse the repository at this point in the history
Fixes #6117
Fixes #5930

Currently both our plugin method hooks don't work with unions. This PR fixes this with three things:
* Moves a bit of logic from `visit_call_expr_inner()` (which is a long method already) to `check_call_expr_with_callee_type()` (which is a short method).
* Special-cases unions in `check_call_expr_with_callee_type()` (normal method calls) and `check_method_call_by_name()` (dunder/operator method calls).
* Adds some clarifying comments and a docstring.

The week point is interaction with binder, but IMO this is the best we can have for now. I left a comment mentioning that check for overlap should be consistent in two functions.

In general, I don't like special-casing, but I spent several days thinking of other solutions, and it looks like special-casing unions in couple more places is the only reasonable way to fix unions-vs-plugins interactions.

This PR may interfere with #6558 that fixes an "opposite" problem, hopefully they will work together unmodified, so that accessing union of literals on union of typed dicts works. Whatever PR lands second, should add a test for this.
  • Loading branch information
ilevkivskyi committed Jul 4, 2019
1 parent b724cca commit 72734f2
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 21 deletions.
130 changes: 109 additions & 21 deletions mypy/checkexpr.py
Expand Up @@ -5,7 +5,7 @@
from typing import (
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator
)
from typing_extensions import ClassVar, Final
from typing_extensions import ClassVar, Final, overload

from mypy.errors import report_internal_error
from mypy.typeanal import (
Expand Down Expand Up @@ -284,24 +284,27 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
return self.msg.untyped_function_call(callee_type, e)
# Figure out the full name of the callee for plugin lookup.
object_type = None
if not isinstance(e.callee, RefExpr):
fullname = None
else:
member = None
fullname = None
if isinstance(e.callee, RefExpr):
# There are two special cases where plugins might act:
# * A "static" reference/alias to a class or function;
# get_function_hook() will be invoked for these.
fullname = e.callee.fullname
if (isinstance(e.callee.node, TypeAlias) and
isinstance(e.callee.node.target, Instance)):
fullname = e.callee.node.target.type.fullname()
# * Call to a method on object that has a full name (see
# method_fullname() for details on supported objects);
# get_method_hook() and get_method_signature_hook() will
# be invoked for these.
if (fullname is None
and isinstance(e.callee, MemberExpr)
and e.callee.expr in self.chk.type_map
and isinstance(callee_type, FunctionLike)):
# For method calls we include the defining class for the method
# in the full name (example: 'typing.Mapping.get').
callee_expr_type = self.chk.type_map[e.callee.expr]
fullname = self.method_fullname(callee_expr_type, e.callee.name)
if fullname is not None:
object_type = callee_expr_type
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type)
and e.callee.expr in self.chk.type_map):
member = e.callee.name
object_type = self.chk.type_map[e.callee.expr]
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname,
object_type, member)
if isinstance(e.callee, RefExpr) and len(e.args) == 2:
if e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass'):
self.check_runtime_protocol_test(e)
Expand Down Expand Up @@ -632,21 +635,53 @@ def check_call_expr_with_callee_type(self,
callee_type: Type,
e: CallExpr,
callable_name: Optional[str],
object_type: Optional[Type]) -> Type:
object_type: Optional[Type],
member: Optional[str] = None) -> Type:
"""Type check call expression.
The given callee type overrides the type of the callee
expression.
"""
# Try to refine the call signature using plugin hooks before checking the call.
callee_type = self.transform_callee_type(
callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type)
The callee_type should be used as the type of callee expression. In particular,
in case of a union type this can be a particular item of the union, so that we can
apply plugin hooks to each item.
The 'member', 'callable_name' and 'object_type' are only used to call plugin hooks.
If 'callable_name' is None but 'member' is not None (member call), try constructing
'callable_name' using 'object_type' (the base type on which the method is called),
for example 'typing.Mapping.get'.
"""
if callable_name is None and member is not None:
assert object_type is not None
callable_name = self.method_fullname(object_type, member)
if callable_name:
# Try to refine the call signature using plugin hooks before checking the call.
callee_type = self.transform_callee_type(
callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type)
# Unions are special-cased to allow plugins to act on each item in the union.
elif member is not None and isinstance(object_type, UnionType):
return self.check_union_call_expr(e, object_type, member)
return self.check_call(callee_type, e.args, e.arg_kinds, e,
e.arg_names, callable_node=e.callee,
callable_name=callable_name,
object_type=object_type)[0]

def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
""""Type check calling a member expression where the base type is a union."""
res = [] # type: List[Type]
for typ in object_type.relevant_items():
# Member access errors are already reported when visiting the member expression.
self.msg.disable_errors()
item = analyze_member_access(member, typ, e, False, False, False,
self.msg, original_type=object_type, chk=self.chk,
in_literal_context=self.is_literal_context())
self.msg.enable_errors()
narrowed = self.narrow_type_from_binder(e.callee, item, skip_non_overlapping=True)
if narrowed is None:
continue
callable_name = self.method_fullname(typ, member)
item_object_type = typ if callable_name else None
res.append(self.check_call_expr_with_callee_type(narrowed, e, callable_name,
item_object_type))
return UnionType.make_simplified_union(res)

def check_call(self,
callee: Type,
args: List[Expression],
Expand Down Expand Up @@ -2018,13 +2053,48 @@ def check_method_call_by_name(self,
"""
local_errors = local_errors or self.msg
original_type = original_type or base_type
# Unions are special-cased to allow plugins to act on each element of the union.
if isinstance(base_type, UnionType):
return self.check_union_method_call_by_name(method, base_type,
args, arg_kinds,
context, local_errors, original_type)

method_type = analyze_member_access(method, base_type, context, False, False, True,
local_errors, original_type=original_type,
chk=self.chk,
in_literal_context=self.is_literal_context())
return self.check_method_call(
method, base_type, method_type, args, arg_kinds, context, local_errors)

def check_union_method_call_by_name(self,
method: str,
base_type: UnionType,
args: List[Expression],
arg_kinds: List[int],
context: Context,
local_errors: MessageBuilder,
original_type: Optional[Type] = None
) -> Tuple[Type, Type]:
"""Type check a call to a named method on an object with union type.
This essentially checks the call using check_method_call_by_name() for each
union item and unions the result. We do this to allow plugins to act on
individual union items.
"""
res = [] # type: List[Type]
meth_res = [] # type: List[Type]
for typ in base_type.relevant_items():
# Format error messages consistently with
# mypy.checkmember.analyze_union_member_access().
local_errors.disable_type_names += 1
item, meth_item = self.check_method_call_by_name(method, typ, args, arg_kinds,
context, local_errors,
original_type)
local_errors.disable_type_names -= 1
res.append(item)
meth_res.append(meth_item)
return UnionType.make_simplified_union(res), UnionType.make_simplified_union(meth_res)

def check_method_call(self,
method_name: str,
base_type: Type,
Expand Down Expand Up @@ -3524,14 +3594,32 @@ def bool_type(self) -> Instance:
"""Return instance type 'bool'."""
return self.named_type('builtins.bool')

def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
@overload
def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type: ...

@overload # noqa
def narrow_type_from_binder(self, expr: Expression, known_type: Type,
skip_non_overlapping: bool) -> Optional[Type]: ...

def narrow_type_from_binder(self, expr: Expression, known_type: Type, # noqa
skip_non_overlapping: bool = False) -> Optional[Type]:
"""Narrow down a known type of expression using information in conditional type binder.
If 'skip_non_overlapping' is True, return None if the type and restriction are
non-overlapping.
"""
if literal(expr) >= LITERAL_TYPE:
restriction = self.chk.binder.get(expr)
# If the current node is deferred, some variables may get Any types that they
# otherwise wouldn't have. We don't want to narrow down these since it may
# produce invalid inferred Optional[Any] types, at least.
if restriction and not (isinstance(known_type, AnyType)
and self.chk.current_node_deferred):
# Note: this call should match the one in narrow_declared_type().
if (skip_non_overlapping and
not is_overlapping_types(known_type, restriction,
prohibit_none_typevar_overlap=True)):
return None
ans = narrow_declared_type(known_type, restriction)
return ans
return known_type
Expand Down
5 changes: 5 additions & 0 deletions test-data/unit/check-ctypes.test
Expand Up @@ -23,6 +23,7 @@ for x in a:

[case testCtypesArrayCustomElementType]
import ctypes
from typing import Union, List

class MyCInt(ctypes.c_int):
pass
Expand All @@ -46,6 +47,10 @@ mya[3] = b"bytes" # E: No overload variant of "__setitem__" of "Array" matches
# N: def __setitem__(self, slice, List[Union[MyCInt, int]]) -> None
for myx in mya:
reveal_type(myx) # N: Revealed type is '__main__.MyCInt*'

myu: Union[ctypes.Array[ctypes.c_int], List[str]]
for myi in myu:
reveal_type(myi) # N: Revealed type is 'Union[builtins.int*, builtins.str*]'
[builtins fixtures/floatdict.pyi]

[case testCtypesArrayUnionElementType]
Expand Down
61 changes: 61 additions & 0 deletions test-data/unit/check-custom-plugin.test
Expand Up @@ -585,6 +585,67 @@ reveal_type(instance(2)) # N: Revealed type is 'builtins.float*'
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/callable_instance.py

[case testGetMethodHooksOnUnions]
# flags: --config-file tmp/mypy.ini --no-strict-optional
from typing import Union

class Foo:
def meth(self, x: str) -> str: ...
class Bar:
def meth(self, x: int) -> float: ...
class Other:
meth: int

x: Union[Foo, Bar, Other]
if isinstance(x.meth, int):
reveal_type(x.meth) # N: Revealed type is 'builtins.int'
else:
reveal_type(x.meth(int())) # N: Revealed type is 'builtins.int'

[builtins fixtures/isinstancelist.pyi]
[file mypy.ini]
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/union_method.py

[case testGetMethodHooksOnUnionsStrictOptional]
# flags: --config-file tmp/mypy.ini --strict-optional
from typing import Union

class Foo:
def meth(self, x: str) -> str: ...
class Bar:
def meth(self, x: int) -> float: ...
class Other:
meth: int

x: Union[Foo, Bar, Other]
if isinstance(x.meth, int):
reveal_type(x.meth) # N: Revealed type is 'builtins.int'
else:
reveal_type(x.meth(int())) # N: Revealed type is 'builtins.int'

[builtins fixtures/isinstancelist.pyi]
[file mypy.ini]
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/union_method.py

[case testGetMethodHooksOnUnionsSpecial]
# flags: --config-file tmp/mypy.ini
from typing import Union

class Foo:
def __getitem__(self, x: str) -> str: ...
class Bar:
def __getitem__(self, x: int) -> float: ...

x: Union[Foo, Bar]
reveal_type(x[int()]) # N: Revealed type is 'builtins.int'

[builtins fixtures/isinstancelist.pyi]
[file mypy.ini]
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/union_method.py

[case testPluginDependencies]
# flags: --config-file tmp/mypy.ini

Expand Down
49 changes: 49 additions & 0 deletions test-data/unit/check-typeddict.test
Expand Up @@ -1589,6 +1589,55 @@ alias('x') # E: Argument 1 has incompatible type "str"; expected "NoReturn"
alias(s) # E: Argument 1 has incompatible type "str"; expected "NoReturn"
[builtins fixtures/dict.pyi]

[case testPluginUnionsOfTypedDicts]
from typing import Union
from mypy_extensions import TypedDict

class TDA(TypedDict):
a: int
b: str

class TDB(TypedDict):
a: int
b: int
c: int

td: Union[TDA, TDB]

reveal_type(td.get('a')) # N: Revealed type is 'builtins.int'
reveal_type(td.get('b')) # N: Revealed type is 'Union[builtins.str, builtins.int]'
reveal_type(td.get('c')) # N: Revealed type is 'Union[Any, builtins.int]' \
# E: TypedDict "TDA" has no key 'c'

reveal_type(td['a']) # N: Revealed type is 'builtins.int'
reveal_type(td['b']) # N: Revealed type is 'Union[builtins.str, builtins.int]'
reveal_type(td['c']) # N: Revealed type is 'Union[Any, builtins.int]' \
# E: TypedDict "TDA" has no key 'c'
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]

[case testPluginUnionsOfTypedDictsNonTotal]
from typing import Union
from mypy_extensions import TypedDict

class TDA(TypedDict, total=False):
a: int
b: str

class TDB(TypedDict, total=False):
a: int
b: int
c: int

td: Union[TDA, TDB]

reveal_type(td.pop('a')) # N: Revealed type is 'builtins.int'
reveal_type(td.pop('b')) # N: Revealed type is 'Union[builtins.str, builtins.int]'
reveal_type(td.pop('c')) # N: Revealed type is 'Union[Any, builtins.int]' \
# E: TypedDict "TDA" has no key 'c'
[builtins fixtures/dict.pyi]
[typing fixtures/typing-full.pyi]

[case testCanCreateTypedDictWithTypingExtensions]
# flags: --python-version 3.6
from typing_extensions import TypedDict
Expand Down
49 changes: 49 additions & 0 deletions test-data/unit/plugins/union_method.py
@@ -0,0 +1,49 @@
from mypy.plugin import (
CallableType, CheckerPluginInterface, MethodSigContext, MethodContext, Plugin
)
from mypy.types import Instance, Type


class MethodPlugin(Plugin):
def get_method_signature_hook(self, fullname):
if fullname.startswith('__main__.Foo.'):
return my_meth_sig_hook
return None

def get_method_hook(self, fullname):
if fullname.startswith('__main__.Bar.'):
return my_meth_hook
return None


def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
if isinstance(typ, Instance):
if typ.type.fullname() == 'builtins.str':
return api.named_generic_type('builtins.int', [])
elif typ.args:
return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args])
return typ


def _float_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
if isinstance(typ, Instance):
if typ.type.fullname() == 'builtins.float':
return api.named_generic_type('builtins.int', [])
elif typ.args:
return typ.copy_modified(args=[_float_to_int(api, t) for t in typ.args])
return typ


def my_meth_sig_hook(ctx: MethodSigContext) -> CallableType:
return ctx.default_signature.copy_modified(
arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types],
ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type),
)


def my_meth_hook(ctx: MethodContext) -> Type:
return _float_to_int(ctx.api, ctx.default_return_type)


def plugin(version):
return MethodPlugin

0 comments on commit 72734f2

Please sign in to comment.