Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine how overload selection handles *args, **kwargs, and Any #5166

Merged
Diff settings

Always

Just for now

Copy path View file
@@ -1195,18 +1195,49 @@ def plausible_overload_call_targets(self,
arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
overload: Overloaded) -> List[CallableType]:
"""Returns all overload call targets that having matching argument counts."""
"""Returns all overload call targets that having matching argument counts.
If the given args contains a star-arg (*arg or **kwarg argument), this method
will ensure all star-arg overloads appear at the start of the list, instead
of their usual location.
The only exception is if the starred argument is something like a Tuple or a
NamedTuple, which has a definitive "shape". If so, we don't move the corresponding
alternative to the front since we can infer a more precise match using the original
order."""

def has_shape(typ: Type) -> bool:
# TODO: Once https://github.com/python/mypy/issues/5198 is fixed,
# add 'isinstance(typ, TypedDictType)' somewhere below.
return (isinstance(typ, TupleType)
or (isinstance(typ, Instance) and typ.type.is_named_tuple))

matches = [] # type: List[CallableType]
star_matches = [] # type: List[CallableType]

args_have_var_arg = False
args_have_kw_arg = False
for kind, typ in zip(arg_kinds, arg_types):
if kind == ARG_STAR and not has_shape(typ):
args_have_var_arg = True
if kind == ARG_STAR2 and not has_shape(typ):
args_have_kw_arg = True

for typ in overload.items():
formal_to_actual = map_actuals_to_formals(arg_kinds, arg_names,
typ.arg_kinds, typ.arg_names,
lambda i: arg_types[i])

if self.check_argument_count(typ, arg_types, arg_kinds, arg_names,
formal_to_actual, None, None):
matches.append(typ)
if args_have_var_arg and typ.is_var_arg:
star_matches.append(typ)
elif args_have_kw_arg and typ.is_kw_arg:
star_matches.append(typ)
else:
matches.append(typ)

return matches
return star_matches + matches

def infer_overload_return_type(self,
plausible_targets: List[CallableType],
@@ -1270,15 +1301,20 @@ def infer_overload_return_type(self,
return None
elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names):
# An argument of type or containing the type 'Any' caused ambiguity.
# We infer a type of 'Any'
return self.check_call(callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
if all(is_subtype(ret_type, return_types[-1]) for ret_type in return_types[:-1]):
# The last match is a supertype of all the previous ones, so it's safe
# to return that inferred type.
return return_types[-1], inferred_types[-1]
else:
# We give up and return 'Any'.
return self.check_call(callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
else:
# Success! No ambiguity; return the first match.
return return_types[0], inferred_types[0]
@@ -3174,16 +3210,20 @@ def any_causes_overload_ambiguity(items: List[CallableType],
matching_formals_unfiltered = [(item_idx, lookup[arg_idx])
for item_idx, lookup in enumerate(actual_to_formal)
if lookup[arg_idx]]

matching_returns = []
matching_formals = []
for item_idx, formals in matching_formals_unfiltered:
if len(formals) > 1:
# An actual maps to multiple formals -- give up as too
# complex, just assume it overlaps.
return True
matching_formals.append((item_idx, items[item_idx].arg_types[formals[0]]))
if (not all_same_types(t for _, t in matching_formals) and
not all_same_types(items[idx].ret_type
for idx, _ in matching_formals)):
matched_callable = items[item_idx]
matching_returns.append(matched_callable.ret_type)

# Note: if an actual maps to multiple formals of differing types within
# a single callable, then we know at least one of those formals must be
# a different type then the formal(s) in some other callable.
# So it's safe to just append everything to the same list.
for formal in formals:
matching_formals.append(matched_callable.arg_types[formal])
if not all_same_types(matching_formals) and not all_same_types(matching_returns):
# Any maps to multiple different types, and the return types of these items differ.
return True
return False
Oops, something went wrong.
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.