diff --git a/mypy/checker.py b/mypy/checker.py index 0c27da8b5ac8..b786155079e5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -734,8 +734,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # def foo(x: str) -> str: ... # # See Python 2's map function for a concrete example of this kind of overload. + current_class = self.scope.active_class() + type_vars = current_class.defn.type_vars if current_class else [] with state.strict_optional_set(True): - if is_unsafe_overlapping_overload_signatures(sig1, sig2): + if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars): self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func) if impl_type is not None: @@ -1702,7 +1704,9 @@ def is_unsafe_overlapping_op( first = forward_tweaked second = reverse_tweaked - return is_unsafe_overlapping_overload_signatures(first, second) + current_class = self.scope.active_class() + type_vars = current_class.defn.type_vars if current_class else [] + return is_unsafe_overlapping_overload_signatures(first, second, type_vars) def check_inplace_operator_method(self, defn: FuncBase) -> None: """Check an inplace operator method such as __iadd__. @@ -3918,11 +3922,12 @@ def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool: return True if len(t.args) == 1: arg = get_proper_type(t.args[0]) - # TODO: This is too permissive -- we only allow TypeVarType since - # they leak in cases like defaultdict(list) due to a bug. - # This can result in incorrect types being inferred, but only - # in rare cases. - if isinstance(arg, (TypeVarType, UninhabitedType, NoneType)): + if self.options.new_type_inference: + allowed = isinstance(arg, (UninhabitedType, NoneType)) + else: + # Allow leaked TypeVars for legacy inference logic. + allowed = isinstance(arg, (UninhabitedType, NoneType, TypeVarType)) + if allowed: return True return False @@ -7179,7 +7184,7 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool: def is_unsafe_overlapping_overload_signatures( - signature: CallableType, other: CallableType + signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType] ) -> bool: """Check if two overloaded signatures are unsafely overlapping or partially overlapping. @@ -7198,8 +7203,8 @@ def is_unsafe_overlapping_overload_signatures( # This lets us identify cases where the two signatures use completely # incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars # test case. - signature = detach_callable(signature) - other = detach_callable(other) + signature = detach_callable(signature, class_type_vars) + other = detach_callable(other, class_type_vars) # Note: We repeat this check twice in both directions due to a slight # asymmetry in 'is_callable_compatible'. When checking for partial overlaps, @@ -7230,7 +7235,7 @@ def is_unsafe_overlapping_overload_signatures( ) -def detach_callable(typ: CallableType) -> CallableType: +def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType: """Ensures that the callable's type variables are 'detached' and independent of the context. A callable normally keeps track of the type variables it uses within its 'variables' field. @@ -7240,42 +7245,17 @@ def detach_callable(typ: CallableType) -> CallableType: This function will traverse the callable and find all used type vars and add them to the variables field if it isn't already present. - The caller can then unify on all type variables whether or not the callable is originally - from a class or not.""" - type_list = typ.arg_types + [typ.ret_type] - - appear_map: dict[str, list[int]] = {} - for i, inner_type in enumerate(type_list): - typevars_available = get_type_vars(inner_type) - for var in typevars_available: - if var.fullname not in appear_map: - appear_map[var.fullname] = [] - appear_map[var.fullname].append(i) - - used_type_var_names = set() - for var_name, appearances in appear_map.items(): - used_type_var_names.add(var_name) - - all_type_vars = get_type_vars(typ) - new_variables = [] - for var in set(all_type_vars): - if var.fullname not in used_type_var_names: - continue - new_variables.append( - TypeVarType( - name=var.name, - fullname=var.fullname, - id=var.id, - values=var.values, - upper_bound=var.upper_bound, - default=var.default, - variance=var.variance, - ) - ) - out = typ.copy_modified( - variables=new_variables, arg_types=type_list[:-1], ret_type=type_list[-1] + The caller can then unify on all type variables whether the callable is originally from + the class or not.""" + if not class_type_vars: + # Fast path, nothing to update. + return typ + seen_type_vars = set() + for t in typ.arg_types + [typ.ret_type]: + seen_type_vars |= set(get_type_vars(t)) + return typ.copy_modified( + variables=list(typ.variables) + [tv for tv in class_type_vars if tv in seen_type_vars] ) - return out def overload_can_never_match(signature: CallableType, other: CallableType) -> bool: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 114cde8327e0..9e46d9ee39cb 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1857,7 +1857,7 @@ def infer_function_type_arguments_using_context( # expects_literal(identity(3)) # Should type-check if not is_generic_instance(ctx) and not is_literal_type_like(ctx): return callable.copy_modified() - args = infer_type_arguments(callable.type_var_ids(), ret_type, erased_ctx) + args = infer_type_arguments(callable.variables, ret_type, erased_ctx) # Only substitute non-Uninhabited and non-erased types. new_args: list[Type | None] = [] for arg in args: @@ -1906,7 +1906,7 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - inferred_args = infer_function_type_arguments( + inferred_args, _ = infer_function_type_arguments( callee_type, pass1_args, arg_kinds, @@ -1948,7 +1948,7 @@ def infer_function_type_arguments( # variables while allowing for polymorphic solutions, i.e. for solutions # potentially involving free variables. # TODO: support the similar inference for return type context. - poly_inferred_args = infer_function_type_arguments( + poly_inferred_args, free_vars = infer_function_type_arguments( callee_type, arg_types, arg_kinds, @@ -1957,30 +1957,28 @@ def infer_function_type_arguments( strict=self.chk.in_checked_function(), allow_polymorphic=True, ) - for i, pa in enumerate(get_proper_types(poly_inferred_args)): - if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa): - # Indicate that free variables should not be applied in the call below. - poly_inferred_args[i] = None poly_callee_type = self.apply_generic_arguments( callee_type, poly_inferred_args, context ) - yes_vars = poly_callee_type.variables - no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables} - if not set(get_type_vars(poly_callee_type)) & no_vars: - # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can - # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. - applied = apply_poly(poly_callee_type, yes_vars) - if applied is not None and poly_inferred_args != [UninhabitedType()] * len( - poly_inferred_args - ): - freeze_all_type_vars(applied) - return applied + # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can + # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. + applied = apply_poly(poly_callee_type, free_vars) + if applied is not None and all( + a is not None and not isinstance(get_proper_type(a), UninhabitedType) + for a in poly_inferred_args + ): + freeze_all_type_vars(applied) + return applied # If it didn't work, erase free variables as , to avoid confusing errors. + unknown = UninhabitedType() + unknown.ambiguous = True inferred_args = [ - expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables}) + expand_type( + a, {v.id: unknown for v in list(callee_type.variables) + free_vars} + ) if a is not None else None - for a in inferred_args + for a in poly_inferred_args ] else: # In dynamically typed functions use implicit 'Any' types for @@ -2019,7 +2017,7 @@ def infer_function_type_arguments_pass2( arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual) - inferred_args = infer_function_type_arguments( + inferred_args, _ = infer_function_type_arguments( callee_type, arg_types, arg_kinds, diff --git a/mypy/constraints.py b/mypy/constraints.py index f9124630a706..299c6292a259 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -73,6 +73,10 @@ def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None: self.op = op self.target = target self.origin_type_var = type_var + # These are additional type variables that should be solved for together with type_var. + # TODO: A cleaner solution may be to modify the return type of infer_constraints() + # to include these instead, but this is a rather big refactoring. + self.extra_tvars: list[TypeVarLikeType] = [] def __repr__(self) -> str: op_str = "<:" @@ -168,7 +172,9 @@ def infer_constraints_for_callable( return constraints -def infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]: +def infer_constraints( + template: Type, actual: Type, direction: int, skip_neg_op: bool = False +) -> list[Constraint]: """Infer type constraints. Match a template type, which may contain type variable references, @@ -187,7 +193,9 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons ((T, S), (X, Y)) --> T :> X and S :> Y (X[T], Any) --> T <: Any and T :> Any - The constraints are represented as Constraint objects. + The constraints are represented as Constraint objects. If skip_neg_op == True, + then skip adding reverse (polymorphic) constraints (since this is already a call + to infer such constraints). """ if any( get_proper_type(template) == get_proper_type(t) @@ -202,13 +210,15 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons # Return early on an empty branch. return [] type_state.inferring.append((template, actual)) - res = _infer_constraints(template, actual, direction) + res = _infer_constraints(template, actual, direction, skip_neg_op) type_state.inferring.pop() return res - return _infer_constraints(template, actual, direction) + return _infer_constraints(template, actual, direction, skip_neg_op) -def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]: +def _infer_constraints( + template: Type, actual: Type, direction: int, skip_neg_op: bool +) -> list[Constraint]: orig_template = template template = get_proper_type(template) actual = get_proper_type(actual) @@ -284,7 +294,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con return [] # Remaining cases are handled by ConstraintBuilderVisitor. - return template.accept(ConstraintBuilderVisitor(actual, direction)) + return template.accept(ConstraintBuilderVisitor(actual, direction, skip_neg_op)) def infer_constraints_if_possible( @@ -510,10 +520,14 @@ class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]): # TODO: The value may be None. Is that actually correct? actual: ProperType - def __init__(self, actual: ProperType, direction: int) -> None: + def __init__(self, actual: ProperType, direction: int, skip_neg_op: bool) -> None: # Direction must be SUBTYPE_OF or SUPERTYPE_OF. self.actual = actual self.direction = direction + # Whether to skip polymorphic inference (involves inference in opposite direction) + # this is used to prevent infinite recursion when both template and actual are + # generic callables. + self.skip_neg_op = skip_neg_op # Trivial leaf types @@ -648,13 +662,13 @@ def visit_instance(self, template: Instance) -> list[Constraint]: assert mapped.type.type_var_tuple_prefix is not None assert mapped.type.type_var_tuple_suffix is not None - unpack_constraints, mapped_args, instance_args = build_constraints_for_unpack( - mapped.args, - mapped.type.type_var_tuple_prefix, - mapped.type.type_var_tuple_suffix, + unpack_constraints, instance_args, mapped_args = build_constraints_for_unpack( instance.args, instance.type.type_var_tuple_prefix, instance.type.type_var_tuple_suffix, + mapped.args, + mapped.type.type_var_tuple_prefix, + mapped.type.type_var_tuple_suffix, self.direction, ) res.extend(unpack_constraints) @@ -879,6 +893,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # Note that non-normalized callables can be created in annotations # using e.g. callback protocols. template = template.with_unpacked_kwargs() + extra_tvars = False if isinstance(self.actual, CallableType): res: list[Constraint] = [] cactual = self.actual.with_unpacked_kwargs() @@ -890,6 +905,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: type_state.infer_polymorphic and cactual.variables and cactual.param_spec() is None + and not self.skip_neg_op # Technically, the correct inferred type for application of e.g. # Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic # like U -> U, should be Callable[..., Any], but if U is a self-type, we can @@ -897,18 +913,15 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: # depends on this old behaviour. and not any(tv.id.raw_id == 0 for tv in cactual.variables) ): - # If actual is generic, unify it with template. Note: this is - # not an ideal solution (which would be adding the generic variables - # to the constraint inference set), but it's a good first approximation, - # and this will prevent leaking these variables in the solutions. - # Note: this may infer constraints like T <: S or T <: List[S] - # that contain variables in the target. - unified = mypy.subtypes.unify_generic_callable( - cactual, template, ignore_return=True + # If the actual callable is generic, infer constraints in the opposite + # direction, and indicate to the solver there are extra type variables + # to solve for (see more details in mypy/solve.py). + res.extend( + infer_constraints( + cactual, template, neg_op(self.direction), skip_neg_op=True + ) ) - if unified is not None: - cactual = unified - res.extend(infer_constraints(cactual, template, neg_op(self.direction))) + extra_tvars = True # We can't infer constraints from arguments if the template is Callable[..., T] # (with literal '...'). @@ -978,6 +991,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: cactual_ret_type = cactual.type_guard res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) + if extra_tvars: + for c in res: + c.extra_tvars = list(cactual.variables) return res elif isinstance(self.actual, AnyType): param_spec = template.param_spec() @@ -1205,6 +1221,9 @@ def find_and_build_constraints_for_unpack( def build_constraints_for_unpack( + # TODO: this naming is misleading, these should be "actual", not "mapped" + # both template and actual can be mapped before, depending on direction. + # Also the convention is to put template related args first. mapped: tuple[Type, ...], mapped_prefix_len: int | None, mapped_suffix_len: int | None, diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 83d9bf4c8725..b599b49e4c12 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -272,6 +272,11 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: return repl def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type: + # Sometimes solver may need to expand a type variable with (a copy of) itself + # (usually together with other TypeVars, but it is hard to filter out TypeVarTuples). + repl = self.variables[t.id] + if isinstance(repl, TypeVarTupleType): + return repl raise NotImplementedError def visit_unpack_type(self, t: UnpackType) -> Type: diff --git a/mypy/infer.py b/mypy/infer.py index 66ca4169e2ff..f34087910e4b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -12,7 +12,7 @@ ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.types import CallableType, Instance, Type, TypeVarId +from mypy.types import CallableType, Instance, Type, TypeVarLikeType class ArgumentInferContext(NamedTuple): @@ -37,7 +37,7 @@ def infer_function_type_arguments( context: ArgumentInferContext, strict: bool = True, allow_polymorphic: bool = False, -) -> list[Type | None]: +) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Infer the type arguments of a generic function. Return an array of lower bound types for the type variables -1 (at @@ -57,14 +57,14 @@ def infer_function_type_arguments( ) # Solve constraints. - type_vars = callee_type.type_var_ids() + type_vars = callee_type.variables return solve_constraints(type_vars, constraints, strict, allow_polymorphic) def infer_type_arguments( - type_var_ids: list[TypeVarId], template: Type, actual: Type, is_supertype: bool = False + type_vars: Sequence[TypeVarLikeType], template: Type, actual: Type, is_supertype: bool = False ) -> list[Type | None]: # Like infer_function_type_arguments, but only match a single type # against a generic type. constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) - return solve_constraints(type_var_ids, constraints) + return solve_constraints(type_vars, constraints)[0] diff --git a/mypy/solve.py b/mypy/solve.py index 6693d66f3479..02df90aff1e1 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -2,9 +2,11 @@ from __future__ import annotations -from typing import Iterable +from collections import defaultdict +from typing import Iterable, Sequence +from typing_extensions import TypeAlias as _TypeAlias -from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, neg_op +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort from mypy.join import join_types @@ -17,6 +19,7 @@ Type, TypeOfAny, TypeVarId, + TypeVarLikeType, TypeVarType, UninhabitedType, UnionType, @@ -25,45 +28,72 @@ ) from mypy.typestate import type_state +Bounds: _TypeAlias = "dict[TypeVarId, set[Type]]" +Graph: _TypeAlias = "set[tuple[TypeVarId, TypeVarId]]" +Solutions: _TypeAlias = "dict[TypeVarId, Type | None]" + def solve_constraints( - vars: list[TypeVarId], + original_vars: Sequence[TypeVarLikeType], constraints: list[Constraint], strict: bool = True, allow_polymorphic: bool = False, -) -> list[Type | None]: +) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Solve type constraints. - Return the best type(s) for type variables; each type can be None if the value of the variable - could not be solved. + Return the best type(s) for type variables; each type can be None if the value of + the variable could not be solved. If a variable has no constraints, if strict=True then arbitrarily - pick NoneType as the value of the type variable. If strict=False, - pick AnyType. + pick UninhabitedType as the value of the type variable. If strict=False, pick AnyType. + If allow_polymorphic=True, then use the full algorithm that can potentially return + free type variables in solutions (these require special care when applying). Otherwise, + use a simplified algorithm that just solves each type variable individually if possible. """ + vars = [tv.id for tv in original_vars] if not vars: - return [] + return [], [] + + originals = {tv.id: tv for tv in original_vars} + extra_vars: list[TypeVarId] = [] + # Get additional type variables from generic actuals. + for c in constraints: + extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars]) + originals.update({v.id: v for v in c.extra_tvars if v.id not in originals}) if allow_polymorphic: # Constraints like T :> S and S <: T are semantically the same, but they are # represented differently. Normalize the constraint list w.r.t this equivalence. - constraints = normalize_constraints(constraints, vars) + constraints = normalize_constraints(constraints, vars + extra_vars) # Collect a list of constraints for each type variable. - cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars} + cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars} for con in constraints: - if con.type_var in vars: + if con.type_var in vars + extra_vars: cmap[con.type_var].append(con) if allow_polymorphic: - solutions = solve_non_linear(vars, constraints, cmap) + if constraints: + solutions, free_vars = solve_with_dependent( + vars + extra_vars, constraints, vars, originals + ) + else: + solutions = {} + free_vars = [] else: solutions = {} + free_vars = [] for tv, cs in cmap.items(): if not cs: continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] - solutions[tv] = solve_one(lowers, uppers, []) + solution = solve_one(lowers, uppers) + + # Do not leak type variables in non-polymorphic solutions. + if solution is None or not get_vars( + solution, [tv for tv in extra_vars if tv not in vars] + ): + solutions[tv] = solution res: list[Type | None] = [] for v in vars: @@ -78,129 +108,128 @@ def solve_constraints( else: candidate = AnyType(TypeOfAny.special_form) res.append(candidate) - return res + return res, [originals[tv] for tv in free_vars] -def solve_non_linear( - vars: list[TypeVarId], constraints: list[Constraint], cmap: dict[TypeVarId, list[Constraint]] -) -> dict[TypeVarId, Type | None]: - """Solve set of constraints that may include non-linear ones, like T <: List[S]. +def solve_with_dependent( + vars: list[TypeVarId], + constraints: list[Constraint], + original_vars: list[TypeVarId], + originals: dict[TypeVarId, TypeVarLikeType], +) -> tuple[Solutions, list[TypeVarId]]: + """Solve set of constraints that may depend on each other, like T <: List[S]. The whole algorithm consists of five steps: - * Propagate via linear constraints to get all possible constraints for each variable + * Propagate via linear constraints and use secondary constraints to get transitive closure * Find dependencies between type variables, group them in SCCs, and sort topologically - * Check all SCC are intrinsically linear, we can't solve (express) T <: List[T] + * Check that all SCC are intrinsically linear, we can't solve (express) T <: List[T] * Variables in leaf SCCs that don't have constant bounds are free (choose one per SCC) - * Solve constraints iteratively starting from leafs, updating targets after each step. + * Solve constraints iteratively starting from leafs, updating bounds after each step. """ - extra_constraints = [] - for tvar in vars: - extra_constraints.extend(propagate_constraints_for(tvar, SUBTYPE_OF, cmap)) - extra_constraints.extend(propagate_constraints_for(tvar, SUPERTYPE_OF, cmap)) - constraints += remove_dups(extra_constraints) - - # Recompute constraint map after propagating. - cmap = {tv: [] for tv in vars} - for con in constraints: - if con.type_var in vars: - cmap[con.type_var].append(con) + graph, lowers, uppers = transitive_closure(vars, constraints) - dmap = compute_dependencies(cmap) + dmap = compute_dependencies(vars, graph, lowers, uppers) sccs = list(strongly_connected_components(set(vars), dmap)) - if all(check_linear(scc, cmap) for scc in sccs): - raw_batches = list(topsort(prepare_sccs(sccs, dmap))) - leafs = raw_batches[0] - free_vars = [] - for scc in leafs: - # If all constrain targets in this SCC are type variables within the - # same SCC then the only meaningful solution we can express, is that - # each variable is equal to a new free variable. For example if we - # have T <: S, S <: U, we deduce: T = S = U = . - if all( - isinstance(c.target, TypeVarType) and c.target.id in vars - for tv in scc - for c in cmap[tv] - ): - # For convenience with current type application machinery, we randomly - # choose one of the existing type variables in SCC and designate it as free - # instead of defining a new type variable as a common solution. - # TODO: be careful about upper bounds (or values) when introducing free vars. - free_vars.append(sorted(scc, key=lambda x: x.raw_id)[0]) - - # Flatten the SCCs that are independent, we can solve them together, - # since we don't need to update any targets in between. - batches = [] - for batch in raw_batches: - next_bc = [] - for scc in batch: - next_bc.extend(list(scc)) - batches.append(next_bc) - - solutions: dict[TypeVarId, Type | None] = {} - for flat_batch in batches: - solutions.update(solve_iteratively(flat_batch, cmap, free_vars)) - # We remove the solutions like T = T for free variables. This will indicate - # to the apply function, that they should not be touched. - # TODO: return list of free type variables explicitly, this logic is fragile - # (but if we do, we need to be careful everything works in incremental modes). - for tv in free_vars: - if tv in solutions: - del solutions[tv] - return solutions - return {} + if not all(check_linear(scc, lowers, uppers) for scc in sccs): + return {}, [] + raw_batches = list(topsort(prepare_sccs(sccs, dmap))) + + free_vars = [] + for scc in raw_batches[0]: + # If there are no bounds on this SCC, then the only meaningful solution we can + # express, is that each variable is equal to a new free variable. For example, + # if we have T <: S, S <: U, we deduce: T = S = U = . + if all(not lowers[tv] and not uppers[tv] for tv in scc): + # For convenience with current type application machinery, we use a stable + # choice that prefers the original type variables (not polymorphic ones) in SCC. + # TODO: be careful about upper bounds (or values) when introducing free vars. + free_vars.append(sorted(scc, key=lambda x: (x not in original_vars, x.raw_id))[0]) + + # Update lowers/uppers with free vars, so these can now be used + # as valid solutions. + for l, u in graph.copy(): + if l in free_vars: + lowers[u].add(originals[l]) + if u in free_vars: + uppers[l].add(originals[u]) + + # Flatten the SCCs that are independent, we can solve them together, + # since we don't need to update any targets in between. + batches = [] + for batch in raw_batches: + next_bc = [] + for scc in batch: + next_bc.extend(list(scc)) + batches.append(next_bc) + + solutions: dict[TypeVarId, Type | None] = {} + for flat_batch in batches: + res = solve_iteratively(flat_batch, graph, lowers, uppers) + solutions.update(res) + return solutions, free_vars def solve_iteratively( - batch: list[TypeVarId], cmap: dict[TypeVarId, list[Constraint]], free_vars: list[TypeVarId] -) -> dict[TypeVarId, Type | None]: - """Solve constraints sequentially, updating constraint targets after each step. - - We solve for type variables that appear in `batch`. If a constraint target is not constant - (i.e. constraint looks like T :> F[S, ...]), we substitute solutions found so far in - the target F[S, ...]. This way we can gradually solve for all variables in the batch taking - one solvable variable at a time (i.e. such a variable that has at least one constant bound). - - Importantly, variables in free_vars are considered constants, so for example if we have just - one initial constraint T <: List[S], we will have two SCCs {T} and {S}, then we first - designate S as free, and therefore T = List[S] is a valid solution for T. + batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds +) -> Solutions: + """Solve transitive closure sequentially, updating upper/lower bounds after each step. + + Transitive closure is represented as a linear graph plus lower/upper bounds for each + type variable, see transitive_closure() docstring for details. + + We solve for type variables that appear in `batch`. If a bound is not constant (i.e. it + looks like T :> F[S, ...]), we substitute solutions found so far in the target F[S, ...] + after solving the batch. + + Importantly, after solving each variable in a batch, we move it from linear graph to + upper/lower bounds, this way we can guarantee consistency of solutions (see comment below + for an example when this is important). """ solutions = {} - relevant_constraints = [] - for tv in batch: - relevant_constraints.extend(cmap.get(tv, [])) - lowers, uppers = transitive_closure(batch, relevant_constraints) s_batch = set(batch) - not_allowed_vars = [v for v in batch if v not in free_vars] while s_batch: - for tv in s_batch: - if any(not get_vars(l, not_allowed_vars) for l in lowers[tv]) or any( - not get_vars(u, not_allowed_vars) for u in uppers[tv] - ): + for tv in sorted(s_batch, key=lambda x: x.raw_id): + if lowers[tv] or uppers[tv]: solvable_tv = tv break else: break # Solve each solvable type variable separately. s_batch.remove(solvable_tv) - result = solve_one(lowers[solvable_tv], uppers[solvable_tv], not_allowed_vars) + result = solve_one(lowers[solvable_tv], uppers[solvable_tv]) solutions[solvable_tv] = result if result is None: - # TODO: support backtracking lower/upper bound choices + # TODO: support backtracking lower/upper bound choices and order within SCCs. # (will require switching this function from iterative to recursive). continue - # Update the (transitive) constraints if there is a solution. - subs = {solvable_tv: result} - lowers = {tv: {expand_type(l, subs) for l in lowers[tv]} for tv in lowers} - uppers = {tv: {expand_type(u, subs) for u in uppers[tv]} for tv in uppers} - for v in cmap: - for c in cmap[v]: - c.target = expand_type(c.target, subs) + + # Update the (transitive) bounds from graph if there is a solution. + # This is needed to guarantee solutions will never contradict the initial + # constraints. For example, consider {T <: S, T <: A, S :> B} with A :> B. + # If we would not update the uppers/lowers from graph, we would infer T = A, S = B + # which is not correct. + for l, u in graph.copy(): + if l == u: + continue + if l == solvable_tv: + lowers[u].add(result) + graph.remove((l, u)) + if u == solvable_tv: + uppers[l].add(result) + graph.remove((l, u)) + + # We can update uppers/lowers only once after solving the whole SCC, + # since uppers/lowers can't depend on type variables in the SCC + # (and we would reject such SCC as non-linear and therefore not solvable). + subs = {tv: s for (tv, s) in solutions.items() if s is not None} + for tv in lowers: + lowers[tv] = {expand_type(lt, subs) for lt in lowers[tv]} + for tv in uppers: + uppers[tv] = {expand_type(ut, subs) for ut in uppers[tv]} return solutions -def solve_one( - lowers: Iterable[Type], uppers: Iterable[Type], not_allowed_vars: list[TypeVarId] -) -> Type | None: +def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" bottom: Type | None = None top: Type | None = None @@ -210,10 +239,6 @@ def solve_one( # bounds based on constraints. Note that we assume that the constraint # targets do not have constraint references. for target in lowers: - # There may be multiple steps needed to solve all vars within a - # (linear) SCC. We ignore targets pointing to not yet solved vars. - if get_vars(target, not_allowed_vars): - continue if bottom is None: bottom = target else: @@ -225,9 +250,6 @@ def solve_one( bottom = join_types(bottom, target) for target in uppers: - # Same as above. - if get_vars(target, not_allowed_vars): - continue if top is None: top = target else: @@ -262,6 +284,7 @@ def normalize_constraints( This includes two things currently: * Complement T :> S by S <: T * Remove strict duplicates + * Remove constrains for unrelated variables """ res = constraints.copy() for c in constraints: @@ -270,96 +293,81 @@ def normalize_constraints( return [c for c in remove_dups(constraints) if c.type_var in vars] -def propagate_constraints_for( - var: TypeVarId, direction: int, cmap: dict[TypeVarId, list[Constraint]] -) -> list[Constraint]: - """Propagate via linear constraints to get additional constraints for `var`. - - For example if we have constraints: - [T <: int, S <: T, S :> str] - we can add two more - [S <: int, T :> str] - """ - extra_constraints = [] - seen = set() - front = [var] - if cmap[var]: - var_def = cmap[var][0].origin_type_var - else: - return [] - while front: - tv = front.pop(0) - for c in cmap[tv]: - if ( - isinstance(c.target, TypeVarType) - and c.target.id not in seen - and c.target.id in cmap - and c.op == direction - ): - front.append(c.target.id) - seen.add(c.target.id) - elif c.op == direction: - new_c = Constraint(var_def, direction, c.target) - if new_c not in cmap[var]: - extra_constraints.append(new_c) - return extra_constraints - - def transitive_closure( tvars: list[TypeVarId], constraints: list[Constraint] -) -> tuple[dict[TypeVarId, set[Type]], dict[TypeVarId, set[Type]]]: +) -> tuple[Graph, Bounds, Bounds]: """Find transitive closure for given constraints on type variables. Transitive closure gives maximal set of lower/upper bounds for each type variable, such that we cannot deduce any further bounds by chaining other existing bounds. + The transitive closure is represented by: + * A set of lower and upper bounds for each type variable, where only constant and + non-linear terms are included in the bounds. + * A graph of linear constraints between type variables (represented as a set of pairs) + Such separation simplifies reasoning, and allows an efficient and simple incremental + transitive closure algorithm that we use here. + For example if we have initial constraints [T <: S, S <: U, U <: int], the transitive closure is given by: - * {} <: T <: {S, U, int} - * {T} <: S <: {U, int} - * {T, S} <: U <: {int} + * {} <: T <: {int} + * {} <: S <: {int} + * {} <: U <: {int} + * {T <: S, S <: U, T <: U} """ - # TODO: merge propagate_constraints_for() into this function. - # TODO: add secondary constraints here to make the algorithm complete. - uppers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} - lowers: dict[TypeVarId, set[Type]] = {tv: set() for tv in tvars} - graph: set[tuple[TypeVarId, TypeVarId]] = set() + uppers: Bounds = defaultdict(set) + lowers: Bounds = defaultdict(set) + graph: Graph = {(tv, tv) for tv in tvars} - # Prime the closure with the initial trivial values. - for c in constraints: - if isinstance(c.target, TypeVarType) and c.target.id in tvars: - if c.op == SUBTYPE_OF: - graph.add((c.type_var, c.target.id)) - else: - graph.add((c.target.id, c.type_var)) - if c.op == SUBTYPE_OF: - uppers[c.type_var].add(c.target) - else: - lowers[c.type_var].add(c.target) - - # At this stage we know that constant bounds have been propagated already, so we - # only need to propagate linear constraints. - for c in constraints: + remaining = set(constraints) + while remaining: + c = remaining.pop() if isinstance(c.target, TypeVarType) and c.target.id in tvars: if c.op == SUBTYPE_OF: lower, upper = c.type_var, c.target.id else: lower, upper = c.target.id, c.type_var - extras = { + if (lower, upper) in graph: + continue + graph |= { (l, u) for l in tvars for u in tvars if (l, lower) in graph and (upper, u) in graph } - graph |= extras for u in tvars: if (upper, u) in graph: lowers[u] |= lowers[lower] for l in tvars: if (l, lower) in graph: uppers[l] |= uppers[upper] - return lowers, uppers + for lt in lowers[lower]: + for ut in uppers[upper]: + # TODO: what if secondary constraints result in inference + # against polymorphic actual (also in below branches)? + remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF)) + remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF)) + elif c.op == SUBTYPE_OF: + if c.target in uppers[c.type_var]: + continue + for l in tvars: + if (l, c.type_var) in graph: + uppers[l].add(c.target) + for lt in lowers[c.type_var]: + remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF)) + remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF)) + else: + assert c.op == SUPERTYPE_OF + if c.target in lowers[c.type_var]: + continue + for u in tvars: + if (c.type_var, u) in graph: + lowers[u].add(c.target) + for ut in uppers[c.type_var]: + remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF)) + remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF)) + return graph, lowers, uppers def compute_dependencies( - cmap: dict[TypeVarId, list[Constraint]] + tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds ) -> dict[TypeVarId, list[TypeVarId]]: """Compute dependencies between type variables induced by constraints. @@ -367,25 +375,30 @@ def compute_dependencies( we will need to solve for S first before we can solve for T. """ res = {} - vars = list(cmap.keys()) - for tv in cmap: + for tv in tvars: deps = set() - for c in cmap[tv]: - deps |= get_vars(c.target, vars) + for lt in lowers[tv]: + deps |= get_vars(lt, tvars) + for ut in uppers[tv]: + deps |= get_vars(ut, tvars) + for other in tvars: + if other == tv: + continue + if (tv, other) in graph or (other, tv) in graph: + deps.add(other) res[tv] = list(deps) return res -def check_linear(scc: set[TypeVarId], cmap: dict[TypeVarId, list[Constraint]]) -> bool: +def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool: """Check there are only linear constraints between type variables in SCC. Linear are constraints like T <: S (while T <: F[S] are non-linear). """ for tv in scc: - if any( - get_vars(c.target, list(scc)) and not isinstance(c.target, TypeVarType) - for c in cmap[tv] - ): + if any(get_vars(lt, list(scc)) for lt in lowers[tv]): + return False + if any(get_vars(ut, list(scc)) for ut in uppers[tv]): return False return True diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a6dc071f92b0..5712d7375e50 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1708,8 +1708,7 @@ def unify_generic_callable( type.ret_type, target.ret_type, return_constraint_direction ) constraints.extend(c) - type_var_ids = [tvar.id for tvar in type.variables] - inferred_vars = mypy.solve.solve_constraints(type_var_ids, constraints) + inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints) if None in inferred_vars: return None non_none_inferred_vars = cast(List[Type], inferred_vars) diff --git a/mypy/test/testconstraints.py b/mypy/test/testconstraints.py index b46f31327150..f40996145cba 100644 --- a/mypy/test/testconstraints.py +++ b/mypy/test/testconstraints.py @@ -1,7 +1,5 @@ from __future__ import annotations -import pytest - from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints from mypy.test.helpers import Suite from mypy.test.typefixture import TypeFixture @@ -22,7 +20,6 @@ def test_basic_type_variable(self) -> None: Constraint(type_var=fx.t, op=direction, target=fx.a) ] - @pytest.mark.xfail def test_basic_type_var_tuple_subtype(self) -> None: fx = self.fx assert infer_constraints( diff --git a/mypy/test/testsolve.py b/mypy/test/testsolve.py index d6c585ef4aaa..5d67203dbbf5 100644 --- a/mypy/test/testsolve.py +++ b/mypy/test/testsolve.py @@ -6,7 +6,7 @@ from mypy.solve import solve_constraints from mypy.test.helpers import Suite, assert_equal from mypy.test.typefixture import TypeFixture -from mypy.types import Type, TypeVarId, TypeVarType +from mypy.types import Type, TypeVarLikeType, TypeVarType class SolveSuite(Suite): @@ -17,26 +17,24 @@ def test_empty_input(self) -> None: self.assert_solve([], [], []) def test_simple_supertype_constraints(self) -> None: + self.assert_solve([self.fx.t], [self.supc(self.fx.t, self.fx.a)], [(self.fx.a, self.fx.o)]) self.assert_solve( - [self.fx.t.id], [self.supc(self.fx.t, self.fx.a)], [(self.fx.a, self.fx.o)] - ) - self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.a), self.supc(self.fx.t, self.fx.b)], [(self.fx.a, self.fx.o)], ) def test_simple_subtype_constraints(self) -> None: - self.assert_solve([self.fx.t.id], [self.subc(self.fx.t, self.fx.a)], [self.fx.a]) + self.assert_solve([self.fx.t], [self.subc(self.fx.t, self.fx.a)], [self.fx.a]) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.subc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], [self.fx.b], ) def test_both_kinds_of_constraints(self) -> None: self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.a)], [(self.fx.b, self.fx.a)], ) @@ -44,21 +42,19 @@ def test_both_kinds_of_constraints(self) -> None: def test_unsatisfiable_constraints(self) -> None: # The constraints are impossible to satisfy. self.assert_solve( - [self.fx.t.id], - [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], - [None], + [self.fx.t], [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], [None] ) def test_exactly_specified_result(self) -> None: self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.b)], [(self.fx.b, self.fx.b)], ) def test_multiple_variables(self) -> None: self.assert_solve( - [self.fx.t.id, self.fx.s.id], + [self.fx.t, self.fx.s], [ self.supc(self.fx.t, self.fx.b), self.supc(self.fx.s, self.fx.c), @@ -68,40 +64,38 @@ def test_multiple_variables(self) -> None: ) def test_no_constraints_for_var(self) -> None: - self.assert_solve([self.fx.t.id], [], [self.fx.uninhabited]) - self.assert_solve( - [self.fx.t.id, self.fx.s.id], [], [self.fx.uninhabited, self.fx.uninhabited] - ) + self.assert_solve([self.fx.t], [], [self.fx.uninhabited]) + self.assert_solve([self.fx.t, self.fx.s], [], [self.fx.uninhabited, self.fx.uninhabited]) self.assert_solve( - [self.fx.t.id, self.fx.s.id], + [self.fx.t, self.fx.s], [self.supc(self.fx.s, self.fx.a)], [self.fx.uninhabited, (self.fx.a, self.fx.o)], ) def test_simple_constraints_with_dynamic_type(self) -> None: self.assert_solve( - [self.fx.t.id], [self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)], ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.a)], [(self.fx.anyt, self.fx.anyt)], ) self.assert_solve( - [self.fx.t.id], [self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] + [self.fx.t], [self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.subc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)], ) - # self.assert_solve([self.fx.t.id], + # self.assert_solve([self.fx.t], # [self.subc(self.fx.t, self.fx.anyt), # self.subc(self.fx.t, self.fx.a)], # [(self.fx.anyt, self.fx.anyt)]) @@ -111,20 +105,20 @@ def test_both_normal_and_any_types_in_results(self) -> None: # If one of the bounds is any, we promote the other bound to # any as well, since otherwise the type range does not make sense. self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)], ) self.assert_solve( - [self.fx.t.id], + [self.fx.t], [self.supc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.a)], [(self.fx.anyt, self.fx.anyt)], ) def assert_solve( self, - vars: list[TypeVarId], + vars: list[TypeVarLikeType], constraints: list[Constraint], results: list[None | Type | tuple[Type, Type]], ) -> None: @@ -134,7 +128,7 @@ def assert_solve( res.append(r[0]) else: res.append(r) - actual = solve_constraints(vars, constraints) + actual, _ = solve_constraints(vars, constraints) assert_equal(str(actual), str(res)) def supc(self, type_var: TypeVarType, bound: Type) -> Constraint: diff --git a/mypy/typeops.py b/mypy/typeops.py index 519d3de995f5..65ab4340403c 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -313,7 +313,9 @@ class B(A): pass original_type = get_proper_type(original_type) all_ids = func.type_var_ids() - typeargs = infer_type_arguments(all_ids, self_param_type, original_type, is_supertype=True) + typeargs = infer_type_arguments( + func.variables, self_param_type, original_type, is_supertype=True + ) if ( is_classmethod # TODO: why do we need the extra guards here? @@ -322,7 +324,7 @@ class B(A): pass ): # In case we call a classmethod through an instance x, fallback to type(x) typeargs = infer_type_arguments( - all_ids, self_param_type, TypeType(original_type), is_supertype=True + func.variables, self_param_type, TypeType(original_type), is_supertype=True ) ids = [tid for tid in all_ids if any(tid == t.id for t in get_type_vars(self_param_type))] diff --git a/mypy_self_check.ini b/mypy_self_check.ini index 7413e6407d4f..fcdbe641d6d6 100644 --- a/mypy_self_check.ini +++ b/mypy_self_check.ini @@ -8,6 +8,7 @@ always_false = MYPYC plugins = misc/proper_plugin.py python_version = 3.8 exclude = mypy/typeshed/|mypyc/test-data/|mypyc/lib-rt/ +new_type_inference = True enable_error_code = ignore-without-code,redundant-expr show_error_code_links = True diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 34588bfceb3d..5c510a11b970 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2911,9 +2911,23 @@ def test1(x: V) -> V: ... def test2(x: V, y: V) -> V: ... reveal_type(dec1(test1)) # N: Revealed type is "def () -> def [T] (T`1) -> T`1" -# TODO: support this situation -reveal_type(dec2(test2)) # N: Revealed type is "def (builtins.object) -> def (builtins.object) -> builtins.object" -[builtins fixtures/paramspec.pyi] +reveal_type(dec2(test2)) # N: Revealed type is "def [T] (T`3) -> def (T`3) -> T`3" +[builtins fixtures/list.pyi] + +[case testInferenceAgainstGenericCallableNewVariable] +# flags: --new-type-inference +from typing import TypeVar, Callable, List + +S = TypeVar('S') +T = TypeVar('T') +U = TypeVar('U') + +def dec(f: Callable[[S], T]) -> Callable[[S], T]: + ... +def test(x: List[U]) -> List[U]: + ... +reveal_type(dec(test)) # N: Revealed type is "def [U] (builtins.list[U`-1]) -> builtins.list[U`-1]" +[builtins fixtures/list.pyi] [case testInferenceAgainstGenericCallableGenericAlias] # flags: --new-type-inference diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 89e5aea210b4..50acd7d77c8c 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6554,3 +6554,18 @@ class Snafu(object): reveal_type(Snafu().snafu('123')) # N: Revealed type is "builtins.str" reveal_type(Snafu.snafu('123')) # N: Revealed type is "builtins.str" [builtins fixtures/staticmethod.pyi] + +[case testOverloadedWithInternalTypeVars] +# flags: --new-type-inference +import m + +[file m.pyi] +from typing import Callable, TypeVar, overload + +T = TypeVar("T") +S = TypeVar("S", bound=str) + +@overload +def foo(x: int = ...) -> Callable[[T], T]: ... +@overload +def foo(x: S = ...) -> Callable[[T], T]: ...