Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for operators with union operands #5545

Merged
merged 11 commits into from
Sep 5, 2018
116 changes: 90 additions & 26 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,14 +1188,14 @@ def check_overload_call(self,
# gives a narrower type.
if unioned_return:
returns, inferred_types = zip(*unioned_return)
# Note that we use `union_overload_matches` instead of just returning
# Note that we use `combine_function_signatures` instead of just returning
# a union of inferred callables because for example a call
# Union[int -> int, str -> str](Union[int, str]) is invalid and
# we don't want to introduce internal inconsistencies.
unioned_result = (UnionType.make_simplified_union(list(returns),
context.line,
context.column),
self.union_overload_matches(inferred_types))
self.combine_function_signatures(inferred_types))

# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
Expand Down Expand Up @@ -1492,8 +1492,8 @@ def type_overrides_set(self, exprs: Sequence[Expression],
for expr in exprs:
del self.type_overrides[expr]

def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of function signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.

If there is at least one non-callable type, return Any (this can happen if there is
Expand All @@ -1507,7 +1507,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
return callables[0]

# Note: we are assuming here that if a user uses some TypeVar 'T' in
# two different overloads, they meant for that TypeVar to mean the
# two different functions, they meant for that TypeVar to mean the
# same thing.
#
# This function will make sure that all instances of that TypeVar 'T'
Expand All @@ -1525,7 +1525,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab

too_complex = False
for target in callables:
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
# We fall back to Callable[..., Union[<returns>]] if the functions do not have
# the exact same signature. The only exception is if one arg is optional and
# the other is positional: in that case, we continue unioning (and expect a
# positional arg).
Expand Down Expand Up @@ -1820,19 +1820,12 @@ def check_op_reversible(self,
left_expr: Expression,
right_type: Type,
right_expr: Expression,
context: Context) -> Tuple[Type, Type]:
# Note: this kludge exists mostly to maintain compatibility with
# existing error messages. Apparently, if the left-hand-side is a
# union and we have a type mismatch, we print out a special,
# abbreviated error message. (See messages.unsupported_operand_types).
unions_present = isinstance(left_type, UnionType)

context: Context,
msg: MessageBuilder) -> Tuple[Type, Type]:
def make_local_errors() -> MessageBuilder:
"""Creates a new MessageBuilder object."""
local_errors = self.msg.clean_copy()
local_errors = msg.clean_copy()
local_errors.disable_count = 0
if unions_present:
local_errors.disable_type_names += 1
return local_errors

def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
Expand Down Expand Up @@ -2006,30 +1999,101 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# TODO: Remove this extra case
return result

self.msg.add_errors(errors[0])
msg.add_errors(errors[0])
if len(results) == 1:
return results[0]
else:
error_any = AnyType(TypeOfAny.from_error)
result = error_any, error_any
return result

def check_op(self, method: str, base_type: Type, arg: Expression,
context: Context,
def check_op(self, method: str, base_type: Type,
arg: Expression, context: Context,
allow_reverse: bool = False) -> Tuple[Type, Type]:
"""Type check a binary operation which maps to a method call.

Return tuple (result type, inferred operator method type).
"""

if allow_reverse:
return self.check_op_reversible(
op_name=method,
left_type=base_type,
left_expr=TempNode(base_type),
right_type=self.accept(arg),
right_expr=arg,
context=context)
left_variants = [base_type]
if isinstance(base_type, UnionType):
left_variants = [item for item in base_type.relevant_items()]
right_type = self.accept(arg)

# Step 1: We first try leaving the right arguments alone and destructure
# just the left ones. (Mypy can sometimes perform some more precise inference
# if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.
msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
all_inferred = []

for left_possible_type in left_variants:
result, inferred = self.check_op_reversible(
op_name=method,
left_type=left_possible_type,
left_expr=TempNode(left_possible_type),
right_type=right_type,
right_expr=arg,
context=context,
msg=msg)
all_results.append(result)
all_inferred.append(inferred)

if not msg.is_errors():
results_final = UnionType.make_simplified_union(all_results)
inferred_final = UnionType.make_simplified_union(all_inferred)
return results_final, inferred_final

# Step 2: If that fails, we try again but also destructure the right argument.
# This is also necessary to make certain edge cases work -- see
# testOperatorDoubleUnionInterwovenUnionAdd, for example.

# Note: We want to pass in the original 'arg' for 'left_expr' and 'right_expr'
# whenever possible so that plugins and similar things can introspect on the original
# node if possible.
#
# We don't do the same for the base expression because it could lead to weird
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
# TODO: Can we use `type_overrides_set()` here?
right_variants = [(right_type, arg)]
if isinstance(right_type, UnionType):
right_variants = [(item, TempNode(item)) for item in right_type.relevant_items()]

msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
all_inferred = []

for left_possible_type in left_variants:
for right_possible_type, right_expr in right_variants:
result, inferred = self.check_op_reversible(
op_name=method,
left_type=left_possible_type,
left_expr=TempNode(left_possible_type),
right_type=right_possible_type,
right_expr=right_expr,
context=context,
msg=msg)
all_results.append(result)
all_inferred.append(inferred)

if msg.is_errors():
self.msg.add_errors(msg)
if len(left_variants) >= 2 and len(right_variants) >= 2:
self.msg.warn_both_operands_are_from_unions(context)
elif len(left_variants) >= 2:
self.msg.warn_operand_was_from_union("Left", base_type, context)
elif len(right_variants) >= 2:
self.msg.warn_operand_was_from_union("Right", right_type, context)

# See the comment in 'check_overload_call' for more details on why
# we call 'combine_function_signature' instead of just unioning the inferred
# callable types.
results_final = UnionType.make_simplified_union(all_results)
inferred_final = self.combine_function_signatures(all_inferred)
return results_final, inferred_final
else:
return self.check_op_local_by_name(
method=method,
Expand Down
6 changes: 6 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,12 @@ def overloaded_signatures_ret_specific(self, index: int, context: Context) -> No
self.fail('Overloaded function implementation cannot produce return type '
'of signature {}'.format(index), context)

def warn_both_operands_are_from_unions(self, context: Context) -> None:
self.note('Both left and right operands are unions', context)

def warn_operand_was_from_union(self, side: str, original: Type, context: Context) -> None:
self.note('{} operand is of type {}'.format(side, self.format(original)), context)

def operator_method_signatures_overlap(
self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type,
forward_method: str, context: Context) -> None:
Expand Down
9 changes: 6 additions & 3 deletions test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ from typing import Callable, Union
x = 5 # type: Union[int, Callable[[], str], Callable[[], int]]

if callable(x):
y = x() + 2 # E: Unsupported operand types for + (likely involving Union)
y = x() + 2 # E: Unsupported operand types for + ("str" and "int") \
# N: Left operand is of type "Union[str, int]"
else:
z = x + 6

Expand All @@ -60,7 +61,8 @@ x = 5 # type: Union[int, str, Callable[[], str]]
if callable(x):
y = x() + 'test'
else:
z = x + 6 # E: Unsupported operand types for + (likely involving Union)
z = x + 6 # E: Unsupported operand types for + ("str" and "int") \
# N: Left operand is of type "Union[int, str]"

[builtins fixtures/callable.pyi]

Expand Down Expand Up @@ -153,7 +155,8 @@ x = 5 # type: Union[int, Callable[[], str]]
if callable(x) and x() == 'test':
x()
else:
x + 5 # E: Unsupported left operand type for + (some union)
x + 5 # E: Unsupported left operand type for + ("Callable[[], str]") \
# N: Left operand is of type "Union[int, Callable[[], str]]"

[builtins fixtures/callable.pyi]

Expand Down
Loading