Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for operators with union operands #5545

Merged
merged 11 commits into from
Sep 5, 2018
118 changes: 91 additions & 27 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,14 +1188,14 @@ def check_overload_call(self,
# gives a narrower type.
if unioned_return:
returns, inferred_types = zip(*unioned_return)
# Note that we use `union_overload_matches` instead of just returning
# Note that we use `combine_function_signatures` instead of just returning
# a union of inferred callables because for example a call
# Union[int -> int, str -> str](Union[int, str]) is invalid and
# we don't want to introduce internal inconsistencies.
unioned_result = (UnionType.make_simplified_union(list(returns),
context.line,
context.column),
self.union_overload_matches(inferred_types))
self.combine_function_signatures(inferred_types))

# Step 3: We try checking each branch one-by-one.
inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types,
Expand Down Expand Up @@ -1492,8 +1492,8 @@ def type_overrides_set(self, exprs: Sequence[Expression],
for expr in exprs:
del self.type_overrides[expr]

def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, CallableType]:
"""Accepts a list of function signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.

If there is at least one non-callable type, return Any (this can happen if there is
Expand All @@ -1507,7 +1507,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab
return callables[0]

# Note: we are assuming here that if a user uses some TypeVar 'T' in
# two different overloads, they meant for that TypeVar to mean the
# two different functions, they meant for that TypeVar to mean the
# same thing.
#
# This function will make sure that all instances of that TypeVar 'T'
Expand All @@ -1525,7 +1525,7 @@ def union_overload_matches(self, types: Sequence[Type]) -> Union[AnyType, Callab

too_complex = False
for target in callables:
# We fall back to Callable[..., Union[<returns>]] if the overloads do not have
# We fall back to Callable[..., Union[<returns>]] if the functions do not have
# the exact same signature. The only exception is if one arg is optional and
# the other is positional: in that case, we continue unioning (and expect a
# positional arg).
Expand Down Expand Up @@ -1820,19 +1820,12 @@ def check_op_reversible(self,
left_expr: Expression,
right_type: Type,
right_expr: Expression,
context: Context) -> Tuple[Type, Type]:
# Note: this kludge exists mostly to maintain compatibility with
# existing error messages. Apparently, if the left-hand-side is a
# union and we have a type mismatch, we print out a special,
# abbreviated error message. (See messages.unsupported_operand_types).
unions_present = isinstance(left_type, UnionType)

context: Context,
msg: MessageBuilder) -> Tuple[Type, Type]:
def make_local_errors() -> MessageBuilder:
"""Creates a new MessageBuilder object."""
local_errors = self.msg.clean_copy()
local_errors = msg.clean_copy()
local_errors.disable_count = 0
if unions_present:
local_errors.disable_type_names += 1
return local_errors

def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]:
Expand Down Expand Up @@ -2009,9 +2002,9 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
# TODO: Remove this extra case
return result

self.msg.add_errors(errors[0])
msg.add_errors(errors[0])
if warn_about_uncalled_reverse_operator:
self.msg.reverse_operator_method_never_called(
msg.reverse_operator_method_never_called(
nodes.op_methods_to_symbols[op_name],
op_name,
right_type,
Expand All @@ -2025,22 +2018,93 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]:
result = error_any, error_any
return result

def check_op(self, method: str, base_type: Type, arg: Expression,
context: Context,
def check_op(self, method: str, base_type: Type,
arg: Expression, context: Context,
allow_reverse: bool = False) -> Tuple[Type, Type]:
"""Type check a binary operation which maps to a method call.

Return tuple (result type, inferred operator method type).
"""

if allow_reverse:
return self.check_op_reversible(
op_name=method,
left_type=base_type,
left_expr=TempNode(base_type),
right_type=self.accept(arg),
right_expr=arg,
context=context)
left_variants = [base_type]
if isinstance(base_type, UnionType):
left_variants = [item for item in base_type.relevant_items()]
right_type = self.accept(arg)

# Step 1: We first try leaving the right arguments alone and destructure
# just the left ones. (Mypy can sometimes perform some more precise inference
# if we leave the right operands a union -- see testOperatorWithEmptyListAndSum.
msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
all_inferred = []

for left_possible_type in left_variants:
result, inferred = self.check_op_reversible(
op_name=method,
left_type=left_possible_type,
left_expr=TempNode(left_possible_type),
right_type=right_type,
right_expr=arg,
context=context,
msg=msg)
all_results.append(result)
all_inferred.append(inferred)

if not msg.is_errors():
results_final = UnionType.make_simplified_union(all_results)
inferred_final = UnionType.make_simplified_union(all_inferred)
return results_final, inferred_final

# Step 2: If that fails, we try again but also destructure the right argument.
# This is also necessary to make certain edge cases work -- see
# testOperatorDoubleUnionInterwovenUnionAdd, for example.

# Note: We want to pass in the original 'arg' for 'left_expr' and 'right_expr'
# whenever possible so that plugins and similar things can introspect on the original
# node if possible.
#
# We don't do the same for the base expression because it could lead to weird
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
# TODO: Can we use `type_overrides_set()` here?
right_variants = [(right_type, arg)]
if isinstance(right_type, UnionType):
right_variants = [(item, TempNode(item)) for item in right_type.relevant_items()]
Copy link
Member

@ilevkivskyi ilevkivskyi Sep 2, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be that all these things will be simplified if you use

with self.type_overrides_set(args, arg_types): ...

?
I remember similar situations from union overloads, where an argument expression once gets in type map, and then you see weird errors.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, that's a good point -- it didn't occur to me until you pointed it out that type_overrides_set is trying to solve the same problem.

I think the overall level of complexity is going to be about the same though -- e.g. if I switch to using type_overrides_set, I'd be able to simplify this list comprehension a bit but would need to add the with block to the doubly nested loop below.

I decided to keep the code the same for now mostly because I didn't want to have to worry about what would happen if you tried doing union math things on an overloaded operator. But if you think it's worth switching over to the other approach for consistency, LMK.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is OK. It is probably safer because a more robust implementation of type_overrides_set should allow "stacking" the overrides. But it is a separate topic, we can do this later if this pattern will be needed elsewhere.


msg = self.msg.clean_copy()
msg.disable_count = 0
all_results = []
all_inferred = []

for left_possible_type in left_variants:
for right_possible_type, right_expr in right_variants:
result, inferred = self.check_op_reversible(
op_name=method,
left_type=left_possible_type,
left_expr=TempNode(left_possible_type),
right_type=right_possible_type,
right_expr=right_expr,
context=context,
msg=msg)
all_results.append(result)
all_inferred.append(inferred)

if msg.is_errors():
self.msg.add_errors(msg)
if len(left_variants) >= 2 and len(right_variants) >= 2:
self.msg.warn_both_operands_are_from_unions(context)
elif len(left_variants) >= 2:
self.msg.warn_operand_was_from_union("Left", base_type, context)
elif len(right_variants) >= 2:
self.msg.warn_operand_was_from_union("Right", right_type, context)

# See the comment in 'check_overload_call' for more details on why
# we call 'combine_function_signature' instead of just unioning the inferred
# callable types.
results_final = UnionType.make_simplified_union(all_results)
inferred_final = self.combine_function_signatures(all_inferred)
return results_final, inferred_final
else:
return self.check_op_local_by_name(
method=method,
Expand Down
6 changes: 6 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,12 @@ def reverse_operator_method_never_called(self,
),
context=context)

def warn_both_operands_are_from_unions(self, context: Context) -> None:
self.note('Both left and right operands are unions', context)

def warn_operand_was_from_union(self, side: str, original: Type, context: Context) -> None:
self.note('{} operand is of type {}'.format(side, self.format(original)), context)

def operator_method_signatures_overlap(
self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type,
forward_method: str, context: Context) -> None:
Expand Down
9 changes: 6 additions & 3 deletions test-data/unit/check-callable.test
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ from typing import Callable, Union
x = 5 # type: Union[int, Callable[[], str], Callable[[], int]]

if callable(x):
y = x() + 2 # E: Unsupported operand types for + (likely involving Union)
y = x() + 2 # E: Unsupported operand types for + ("str" and "int") \
# N: Left operand is of type "Union[str, int]"
else:
z = x + 6

Expand All @@ -60,7 +61,8 @@ x = 5 # type: Union[int, str, Callable[[], str]]
if callable(x):
y = x() + 'test'
else:
z = x + 6 # E: Unsupported operand types for + (likely involving Union)
z = x + 6 # E: Unsupported operand types for + ("str" and "int") \
# N: Left operand is of type "Union[int, str]"

[builtins fixtures/callable.pyi]

Expand Down Expand Up @@ -153,7 +155,8 @@ x = 5 # type: Union[int, Callable[[], str]]
if callable(x) and x() == 'test':
x()
else:
x + 5 # E: Unsupported left operand type for + (some union)
x + 5 # E: Unsupported left operand type for + ("Callable[[], str]") \
# N: Left operand is of type "Union[int, Callable[[], str]]"

[builtins fixtures/callable.pyi]

Expand Down
Loading