Skip to content

Commit

Permalink
New type inference: complete transitive closure (#15754)
Browse files Browse the repository at this point in the history
This is a first follow-up for #15287 (I like how my PR titles sound like
research paper titles, LOL)

This PR completes the new type inference foundations by switching to a
complete and well-founded algorithm [1] for transitive closure (which
replaces the more ad hoc initial algorithm that covered 80% of cases and
was good for experimenting with the new inference scheme). In particular,
the algorithm in this PR covers two important edge cases (see tests). Some
comments:
* I don't intend to switch the default for `--new-type-inference`, I
just want to see the effect of the switch on `mypy_primer`, I will
switch back to false before merging
* This flag is still not ready to be publicly announced, I am going to
make another 2-3 PRs from the list in #15287 before making this public.
* I am not adding yet the unit tests as discussed in previous PR. This
PR is already quite big, and the next one (support for upper bounds and
values) should be much smaller. I am going to add unit tests only for
`transitive_closure()` which is the core of new logic.
* While working on this I fixed a couple of bugs exposed in `TypeVarTuple`
support: one is a rare technical corner case, another one is serious:
template and actual were swapped during constraint inference,
effectively causing outer/return context to be completely ignored for
instances.
* It is better to review the PR with "ignore whitespace" option turned
on (there is big chunk in solve.py that is just change of indentation).
* There is one questionable design choice I am making in this PR: I am
adding `extra_tvars` as an attribute of the `Constraint` class, while it
logically should not be attributed to any individual constraint, but
rather to the full list of constraints. However, doing this properly
would require changing the return type of `infer_constraints()` and all
related functions, which would be a really big refactoring.

[1] Definition 7.1 in https://inria.hal.science/inria-00073205/document

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ilevkivskyi and pre-commit-ci[bot] committed Aug 3, 2023
1 parent 2b613e5 commit 0d708cb
Show file tree
Hide file tree
Showing 13 changed files with 356 additions and 319 deletions.
72 changes: 26 additions & 46 deletions mypy/checker.py
Expand Up @@ -734,8 +734,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# def foo(x: str) -> str: ...
#
# See Python 2's map function for a concrete example of this kind of overload.
current_class = self.scope.active_class()
type_vars = current_class.defn.type_vars if current_class else []
with state.strict_optional_set(True):
if is_unsafe_overlapping_overload_signatures(sig1, sig2):
if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars):
self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func)

if impl_type is not None:
Expand Down Expand Up @@ -1702,7 +1704,9 @@ def is_unsafe_overlapping_op(
first = forward_tweaked
second = reverse_tweaked

return is_unsafe_overlapping_overload_signatures(first, second)
current_class = self.scope.active_class()
type_vars = current_class.defn.type_vars if current_class else []
return is_unsafe_overlapping_overload_signatures(first, second, type_vars)

def check_inplace_operator_method(self, defn: FuncBase) -> None:
"""Check an inplace operator method such as __iadd__.
Expand Down Expand Up @@ -3918,11 +3922,12 @@ def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool:
return True
if len(t.args) == 1:
arg = get_proper_type(t.args[0])
# TODO: This is too permissive -- we only allow TypeVarType since
# they leak in cases like defaultdict(list) due to a bug.
# This can result in incorrect types being inferred, but only
# in rare cases.
if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)):
if self.options.new_type_inference:
allowed = isinstance(arg, (UninhabitedType, NoneType))
else:
# Allow leaked TypeVars for legacy inference logic.
allowed = isinstance(arg, (UninhabitedType, NoneType, TypeVarType))
if allowed:
return True
return False

Expand Down Expand Up @@ -7179,7 +7184,7 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool:


def is_unsafe_overlapping_overload_signatures(
signature: CallableType, other: CallableType
signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType]
) -> bool:
"""Check if two overloaded signatures are unsafely overlapping or partially overlapping.
Expand All @@ -7198,8 +7203,8 @@ def is_unsafe_overlapping_overload_signatures(
# This lets us identify cases where the two signatures use completely
# incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars
# test case.
signature = detach_callable(signature)
other = detach_callable(other)
signature = detach_callable(signature, class_type_vars)
other = detach_callable(other, class_type_vars)

# Note: We repeat this check twice in both directions due to a slight
# asymmetry in 'is_callable_compatible'. When checking for partial overlaps,
Expand Down Expand Up @@ -7230,7 +7235,7 @@ def is_unsafe_overlapping_overload_signatures(
)


def detach_callable(typ: CallableType) -> CallableType:
def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType:
"""Ensures that the callable's type variables are 'detached' and independent of the context.
A callable normally keeps track of the type variables it uses within its 'variables' field.
Expand All @@ -7240,42 +7245,17 @@ def detach_callable(typ: CallableType) -> CallableType:
This function will traverse the callable and find all used type vars and add them to the
variables field if it isn't already present.
The caller can then unify on all type variables whether or not the callable is originally
from a class or not."""
type_list = typ.arg_types + [typ.ret_type]

appear_map: dict[str, list[int]] = {}
for i, inner_type in enumerate(type_list):
typevars_available = get_type_vars(inner_type)
for var in typevars_available:
if var.fullname not in appear_map:
appear_map[var.fullname] = []
appear_map[var.fullname].append(i)

used_type_var_names = set()
for var_name, appearances in appear_map.items():
used_type_var_names.add(var_name)

all_type_vars = get_type_vars(typ)
new_variables = []
for var in set(all_type_vars):
if var.fullname not in used_type_var_names:
continue
new_variables.append(
TypeVarType(
name=var.name,
fullname=var.fullname,
id=var.id,
values=var.values,
upper_bound=var.upper_bound,
default=var.default,
variance=var.variance,
)
)
out = typ.copy_modified(
variables=new_variables, arg_types=type_list[:-1], ret_type=type_list[-1]
The caller can then unify on all type variables whether the callable is originally from
the class or not."""
if not class_type_vars:
# Fast path, nothing to update.
return typ
seen_type_vars = set()
for t in typ.arg_types + [typ.ret_type]:
seen_type_vars |= set(get_type_vars(t))
return typ.copy_modified(
variables=list(typ.variables) + [tv for tv in class_type_vars if tv in seen_type_vars]
)
return out


def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:
Expand Down
40 changes: 19 additions & 21 deletions mypy/checkexpr.py
Expand Up @@ -1857,7 +1857,7 @@ def infer_function_type_arguments_using_context(
# expects_literal(identity(3)) # Should type-check
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
return callable.copy_modified()
args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx)
args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
# Only substitute non-Uninhabited and non-erased types.
new_args: list[Type | None] = []
for arg in args:
Expand Down Expand Up @@ -1906,7 +1906,7 @@ def infer_function_type_arguments(
else:
pass1_args.append(arg)

inferred_args = infer_function_type_arguments(
inferred_args, _ = infer_function_type_arguments(
callee_type,
pass1_args,
arg_kinds,
Expand Down Expand Up @@ -1948,7 +1948,7 @@ def infer_function_type_arguments(
# variables while allowing for polymorphic solutions, i.e. for solutions
# potentially involving free variables.
# TODO: support the similar inference for return type context.
poly_inferred_args = infer_function_type_arguments(
poly_inferred_args, free_vars = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
Expand All @@ -1957,30 +1957,28 @@ def infer_function_type_arguments(
strict=self.chk.in_checked_function(),
allow_polymorphic=True,
)
for i, pa in enumerate(get_proper_types(poly_inferred_args)):
if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
# Indicate that free variables should not be applied in the call below.
poly_inferred_args[i] = None
poly_callee_type = self.apply_generic_arguments(
callee_type, poly_inferred_args, context
)
yes_vars = poly_callee_type.variables
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
if not set(get_type_vars(poly_callee_type)) & no_vars:
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, yes_vars)
if applied is not None and poly_inferred_args != [UninhabitedType()] * len(
poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, free_vars)
if applied is not None and all(
a is not None and not isinstance(get_proper_type(a), UninhabitedType)
for a in poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
# If it didn't work, erase free variables as <nothing>, to avoid confusing errors.
unknown = UninhabitedType()
unknown.ambiguous = True
inferred_args = [
expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables})
expand_type(
a, {v.id: unknown for v in list(callee_type.variables) + free_vars}
)
if a is not None
else None
for a in inferred_args
for a in poly_inferred_args
]
else:
# In dynamically typed functions use implicit 'Any' types for
Expand Down Expand Up @@ -2019,7 +2017,7 @@ def infer_function_type_arguments_pass2(

arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)

inferred_args = infer_function_type_arguments(
inferred_args, _ = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
Expand Down
63 changes: 41 additions & 22 deletions mypy/constraints.py
Expand Up @@ -73,6 +73,10 @@ def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None:
self.op = op
self.target = target
self.origin_type_var = type_var
# These are additional type variables that should be solved for together with type_var.
# TODO: A cleaner solution may be to modify the return type of infer_constraints()
# to include these instead, but this is a rather big refactoring.
self.extra_tvars: list[TypeVarLikeType] = []

def __repr__(self) -> str:
op_str = "<:"
Expand Down Expand Up @@ -168,7 +172,9 @@ def infer_constraints_for_callable(
return constraints


def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
def infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool = False
) -> list[Constraint]:
"""Infer type constraints.
Match a template type, which may contain type variable references,
Expand All @@ -187,7 +193,9 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
((T, S), (X, Y)) --> T :> X and S :> Y
(X[T], Any) --> T <: Any and T :> Any
The constraints are represented as Constraint objects.
The constraints are represented as Constraint objects. If skip_neg_op == True,
then skip adding reverse (polymorphic) constraints (since this is already a call
to infer such constraints).
"""
if any(
get_proper_type(template) == get_proper_type(t)
Expand All @@ -202,13 +210,15 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons
# Return early on an empty branch.
return []
type_state.inferring.append((template, actual))
res = _infer_constraints(template, actual, direction)
res = _infer_constraints(template, actual, direction, skip_neg_op)
type_state.inferring.pop()
return res
return _infer_constraints(template, actual, direction)
return _infer_constraints(template, actual, direction, skip_neg_op)


def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:
def _infer_constraints(
template: Type, actual: Type, direction: int, skip_neg_op: bool
) -> list[Constraint]:
orig_template = template
template = get_proper_type(template)
actual = get_proper_type(actual)
Expand Down Expand Up @@ -284,7 +294,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con
return []

# Remaining cases are handled by ConstraintBuilderVisitor.
return template.accept(ConstraintBuilderVisitor(actual, direction))
return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op))


def infer_constraints_if_possible(
Expand Down Expand Up @@ -510,10 +520,14 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
# TODO: The value may be None. Is that actually correct?
actual: ProperType

def __init__(self, actual: ProperType, direction: int) -> None:
def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None:
# Direction must be SUBTYPE_OF or SUPERTYPE_OF.
self.actual = actual
self.direction = direction
# Whether to skip polymorphic inference (involves inference in opposite direction)
# this is used to prevent infinite recursion when both template and actual are
# generic callables.
self.skip_neg_op = skip_neg_op

# Trivial leaf types

Expand Down Expand Up @@ -648,13 +662,13 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
assert mapped.type.type_var_tuple_prefix is not None
assert mapped.type.type_var_tuple_suffix is not None

unpack_constraints, mapped_args, instance_args = build_constraints_for_unpack(
mapped.args,
mapped.type.type_var_tuple_prefix,
mapped.type.type_var_tuple_suffix,
unpack_constraints, instance_args, mapped_args = build_constraints_for_unpack(
instance.args,
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
mapped.args,
mapped.type.type_var_tuple_prefix,
mapped.type.type_var_tuple_suffix,
self.direction,
)
res.extend(unpack_constraints)
Expand Down Expand Up @@ -879,6 +893,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
extra_tvars = False
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual.with_unpacked_kwargs()
Expand All @@ -890,25 +905,23 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
type_state.infer_polymorphic
and cactual.variables
and cactual.param_spec() is None
and not self.skip_neg_op
# Technically, the correct inferred type for application of e.g.
# Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic
# like U -> U, should be Callable[..., Any], but if U is a self-type, we can
# allow it to leak, to be later bound to self. A bunch of existing code
# depends on this old behaviour.
and not any(tv.id.raw_id == 0 for tv in cactual.variables)
):
# If actual is generic, unify it with template. Note: this is
# not an ideal solution (which would be adding the generic variables
# to the constraint inference set), but it's a good first approximation,
# and this will prevent leaking these variables in the solutions.
# Note: this may infer constraints like T <: S or T <: List[S]
# that contain variables in the target.
unified = mypy.subtypes.unify_generic_callable(
cactual, template, ignore_return=True
# If the actual callable is generic, infer constraints in the opposite
# direction, and indicate to the solver there are extra type variables
# to solve for (see more details in mypy/solve.py).
res.extend(
infer_constraints(
cactual, template, neg_op(self.direction), skip_neg_op=True
)
)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))
extra_tvars = True

# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
Expand Down Expand Up @@ -978,6 +991,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
cactual_ret_type = cactual.type_guard

res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))
if extra_tvars:
for c in res:
c.extra_tvars = list(cactual.variables)
return res
elif isinstance(self.actual, AnyType):
param_spec = template.param_spec()
Expand Down Expand Up @@ -1205,6 +1221,9 @@ def find_and_build_constraints_for_unpack(


def build_constraints_for_unpack(
# TODO: this naming is misleading, these should be "actual", not "mapped"
# both template and actual can be mapped before, depending on direction.
# Also the convention is to put template related args first.
mapped: tuple[Type, ...],
mapped_prefix_len: int | None,
mapped_suffix_len: int | None,
Expand Down
5 changes: 5 additions & 0 deletions mypy/expandtype.py
Expand Up @@ -272,6 +272,11 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return repl

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# Sometimes solver may need to expand a type variable with (a copy of) itself
# (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).
repl = self.variables[t.id]
if isinstance(repl, TypeVarTupleType):
return repl
raise NotImplementedError

def visit_unpack_type(self, t: UnpackType) -> Type:
Expand Down

0 comments on commit 0d708cb

Please sign in to comment.