diff --git a/mypy/constraints.py b/mypy/constraints.py index 26504ed06b3e..47f312117264 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -595,15 +595,11 @@ def visit_parameters(self, template: Parameters) -> list[Constraint]: return self.infer_against_any(template.arg_types, self.actual) if type_state.infer_polymorphic and isinstance(self.actual, Parameters): # For polymorphic inference we need to be able to infer secondary constraints - # in situations like [x: T] <: P <: [x: int]. - res = [] - if len(template.arg_types) == len(self.actual.arg_types): - for tt, at in zip(template.arg_types, self.actual.arg_types): - # This avoids bogus constraints like T <: P.args - if isinstance(at, ParamSpecType): - continue - res.extend(infer_constraints(tt, at, self.direction)) - return res + # in situations like [x: T] <: P <: [x: int]. Note we invert direction, since + # this function expects direction between callables. + return infer_callable_arguments_constraints( + template, self.actual, neg_op(self.direction) + ) raise RuntimeError("Parameters cannot be constrained to") # Non-leaf types @@ -722,7 +718,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]: prefix = mapped_arg.prefix if isinstance(instance_arg, Parameters): # No such thing as variance for ParamSpecs, consider them invariant - # TODO: constraints between prefixes + # TODO: constraints between prefixes using + # infer_callable_arguments_constraints() suffix: Type = instance_arg.copy_modified( instance_arg.arg_types[len(prefix.arg_types) :], instance_arg.arg_kinds[len(prefix.arg_kinds) :], @@ -793,7 +790,8 @@ def visit_instance(self, template: Instance) -> list[Constraint]: prefix = template_arg.prefix if isinstance(mapped_arg, Parameters): # No such thing as variance for ParamSpecs, consider them invariant - # TODO: constraints between prefixes + # TODO: constraints between prefixes using + # infer_callable_arguments_constraints() suffix = mapped_arg.copy_modified( mapped_arg.arg_types[len(prefix.arg_types) :], mapped_arg.arg_kinds[len(prefix.arg_kinds) :], @@ -962,24 +960,12 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: unpack_constraints = build_constraints_for_simple_unpack( template_types, actual_types, neg_op(self.direction) ) - template_args = [] - cactual_args = [] res.extend(unpack_constraints) else: - template_args = template.arg_types - cactual_args = cactual.arg_types - # TODO: use some more principled "formal to actual" logic - # instead of this lock-step loop over argument types. This identical - # logic should be used in 5 places: in Parameters vs Parameters - # inference, in Instance vs Instance inference for prefixes (two - # branches), and in Callable vs Callable inference (two branches). - for t, a in zip(template_args, cactual_args): - # This avoids bogus constraints like T <: P.args - if isinstance(a, (ParamSpecType, UnpackType)): - # TODO: can we infer something useful for *T vs P? - continue # Negate direction due to function argument type contravariance. - res.extend(infer_constraints(t, a, neg_op(self.direction))) + res.extend( + infer_callable_arguments_constraints(template, cactual, self.direction) + ) else: prefix = param_spec.prefix prefix_len = len(prefix.arg_types) @@ -1028,11 +1014,9 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: arg_kinds=cactual.arg_kinds[:prefix_len], arg_names=cactual.arg_names[:prefix_len], ) - - for t, a in zip(prefix.arg_types, cactual_prefix.arg_types): - if isinstance(a, ParamSpecType): - continue - res.extend(infer_constraints(t, a, neg_op(self.direction))) + res.extend( + infer_callable_arguments_constraints(prefix, cactual_prefix, self.direction) + ) template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type if template.type_guard is not None: @@ -1435,3 +1419,89 @@ def build_constraints_for_unpack( for template_arg, item in zip(template_unpack.items, mapped_middle): res.extend(infer_constraints(template_arg, item, direction)) return res, mapped_prefix + mapped_suffix, template_prefix + template_suffix + + +def infer_directed_arg_constraints(left: Type, right: Type, direction: int) -> list[Constraint]: + """Infer constraints between two arguments using direction between original callables.""" + if isinstance(left, (ParamSpecType, UnpackType)) or isinstance( + right, (ParamSpecType, UnpackType) + ): + # This avoids bogus constraints like T <: P.args + # TODO: can we infer something useful for *T vs P? + return [] + if direction == SUBTYPE_OF: + # We invert direction to account for argument contravariance. + return infer_constraints(left, right, neg_op(direction)) + else: + return infer_constraints(right, left, neg_op(direction)) + + +def infer_callable_arguments_constraints( + template: CallableType | Parameters, actual: CallableType | Parameters, direction: int +) -> list[Constraint]: + """Infer constraints between argument types of two callables. + + This function essentially extracts four steps from are_parameters_compatible() in + subtypes.py that involve subtype checks between argument types. We keep the argument + matching logic, but ignore various strictness flags present there, and checks that + do not involve subtyping. Then in place of every subtype check we put an infer_constraints() + call for the same types. + """ + res = [] + if direction == SUBTYPE_OF: + left, right = template, actual + else: + left, right = actual, template + left_star = left.var_arg() + left_star2 = left.kw_arg() + right_star = right.var_arg() + right_star2 = right.kw_arg() + + # Numbering of steps below matches the one in are_parameters_compatible() for convenience. + # Phase 1a: compare star vs star arguments. + if left_star is not None and right_star is not None: + res.extend(infer_directed_arg_constraints(left_star.typ, right_star.typ, direction)) + if left_star2 is not None and right_star2 is not None: + res.extend(infer_directed_arg_constraints(left_star2.typ, right_star2.typ, direction)) + + # Phase 1b: compare left args with corresponding non-star right arguments. + for right_arg in right.formal_arguments(): + left_arg = mypy.typeops.callable_corresponding_argument(left, right_arg) + if left_arg is None: + continue + res.extend(infer_directed_arg_constraints(left_arg.typ, right_arg.typ, direction)) + + # Phase 1c: compare left args with right *args. + if right_star is not None: + right_by_position = right.try_synthesizing_arg_from_vararg(None) + assert right_by_position is not None + i = right_star.pos + assert i is not None + while i < len(left.arg_kinds) and left.arg_kinds[i].is_positional(): + left_by_position = left.argument_by_position(i) + assert left_by_position is not None + res.extend( + infer_directed_arg_constraints( + left_by_position.typ, right_by_position.typ, direction + ) + ) + i += 1 + + # Phase 1d: compare left args with right **kwargs. + if right_star2 is not None: + right_names = {name for name in right.arg_names if name is not None} + left_only_names = set() + for name, kind in zip(left.arg_names, left.arg_kinds): + if name is None or kind.is_star() or name in right_names: + continue + left_only_names.add(name) + + right_by_name = right.try_synthesizing_arg_from_kwarg(None) + assert right_by_name is not None + for name in left_only_names: + left_by_name = left.argument_by_name(name) + assert left_by_name is not None + res.extend( + infer_directed_arg_constraints(left_by_name.typ, right_by_name.typ, direction) + ) + return res diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 11847858c62c..288de10cc234 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -590,6 +590,7 @@ def check_mixed( ): nominal = False else: + # TODO: everywhere else ParamSpecs are handled as invariant. if not check_type_parameter( lefta, righta, COVARIANT, self.proper_subtype, self.subtype_context ): @@ -666,13 +667,12 @@ def visit_unpack_type(self, left: UnpackType) -> bool: return False def visit_parameters(self, left: Parameters) -> bool: - if isinstance(self.right, (Parameters, CallableType)): - right = self.right - if isinstance(right, CallableType): - right = right.with_unpacked_kwargs() + if isinstance(self.right, Parameters): + # TODO: direction here should be opposite, this function expects + # order of callables, while parameters are contravariant. return are_parameters_compatible( left, - right, + self.right, is_compat=self._is_subtype, ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, ) @@ -723,14 +723,6 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, TypeType): # This is unsound, we don't check the __init__ signature. return left.is_type_obj() and self._is_subtype(left.ret_type, right.item) - elif isinstance(right, Parameters): - # this doesn't check return types.... but is needed for is_equivalent - return are_parameters_compatible( - left.with_unpacked_kwargs(), - right, - is_compat=self._is_subtype, - ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, - ) else: return False @@ -1456,7 +1448,6 @@ def g(x: int) -> int: ... right, is_compat=is_compat, ignore_pos_arg_names=ignore_pos_arg_names, - check_args_covariantly=check_args_covariantly, allow_partial_overlap=allow_partial_overlap, strict_concatenate_check=strict_concatenate_check, ) @@ -1480,7 +1471,6 @@ def are_parameters_compatible( *, is_compat: Callable[[Type, Type], bool], ignore_pos_arg_names: bool = False, - check_args_covariantly: bool = False, allow_partial_overlap: bool = False, strict_concatenate_check: bool = False, ) -> bool: @@ -1534,7 +1524,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1b: Check non-star args: for every arg right can accept, left must # also accept. The only exception is if we are allowing partial - # partial overlaps: in that case, we ignore optional args on the right. + # overlaps: in that case, we ignore optional args on the right. for right_arg in right.formal_arguments(): left_arg = mypy.typeops.callable_corresponding_argument(left, right_arg) if left_arg is None: @@ -1548,7 +1538,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1c: Check var args. Right has an infinite series of optional positional # arguments. Get all further positional args of left, and make sure - # they're more general then the corresponding member in right. + # they're more general than the corresponding member in right. if right_star is not None: # Synthesize an anonymous formal argument for the right right_by_position = right.try_synthesizing_arg_from_vararg(None) @@ -1575,7 +1565,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1d: Check kw args. Right has an infinite series of optional named # arguments. Get all further named args of left, and make sure - # they're more general then the corresponding member in right. + # they're more general than the corresponding member in right. if right_star2 is not None: right_names = {name for name in right.arg_names if name is not None} left_only_names = set() @@ -1643,6 +1633,10 @@ def are_args_compatible( allow_partial_overlap: bool, is_compat: Callable[[Type, Type], bool], ) -> bool: + if left.required and right.required: + # If both arguments are required allow_partial_overlap has no effect. + allow_partial_overlap = False + def is_different(left_item: object | None, right_item: object | None) -> bool: """Checks if the left and right items are different. @@ -1670,7 +1664,7 @@ def is_different(left_item: object | None, right_item: object | None) -> bool: # If right's argument is optional, left's must also be # (unless we're relaxing the checks to allow potential - # rather then definite compatibility). + # rather than definite compatibility). if not allow_partial_overlap and not right.required and left.required: return False diff --git a/mypy/types.py b/mypy/types.py index d4e2fc7cb63c..301ce6e0cf18 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1545,9 +1545,6 @@ class FormalArgument(NamedTuple): required: bool -# TODO: should this take bound typevars too? what would this take? -# ex: class Z(Generic[P, T]): ...; Z[[V], V] -# What does a typevar even mean in this context? class Parameters(ProperType): """Type that represents the parameters to a function. @@ -1559,6 +1556,8 @@ class Parameters(ProperType): "arg_names", "min_args", "is_ellipsis_args", + # TODO: variables don't really belong here, but they are used to allow hacky support + # for forall . Foo[[x: T], T] by capturing generic callable with ParamSpec, see #15909 "variables", ) @@ -1602,7 +1601,7 @@ def copy_modified( variables=variables if variables is not _dummy else self.variables, ) - # the following are copied from CallableType. Is there a way to decrease code duplication? + # TODO: here is a lot of code duplication with Callable type, fix this. def var_arg(self) -> FormalArgument | None: """The formal argument for *args.""" for position, (type, kind) in enumerate(zip(self.arg_types, self.arg_kinds)): @@ -2046,7 +2045,6 @@ def param_spec(self) -> ParamSpecType | None: return arg_type.copy_modified(flavor=ParamSpecFlavor.BARE, prefix=prefix) def expand_param_spec(self, c: Parameters) -> CallableType: - # TODO: try deleting variables from Parameters after new type inference is default. variables = c.variables return self.copy_modified( arg_types=self.arg_types[:-2] + c.arg_types, diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 9ee30b4df859..56d3fe2b4ce7 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3553,3 +3553,136 @@ class E(D): ... reveal_type([E(), D()]) # N: Revealed type is "builtins.list[__main__.D]" reveal_type([D(), E()]) # N: Revealed type is "builtins.list[__main__.D]" + +[case testCallableInferenceAgainstCallablePosVsStar] +from typing import TypeVar, Callable, Tuple + +T = TypeVar('T') +S = TypeVar('S') + +def f(x: Callable[[T, S], None]) -> Tuple[T, S]: ... +def g(*x: int) -> None: ... +reveal_type(f(g)) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallableStarVsPos] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T, S]): + def __call__(self, __x: T, *args: S) -> None: ... + +def f(x: Call[T, S]) -> Tuple[T, S]: ... +def g(*x: int) -> None: ... +reveal_type(f(g)) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallableNamedVsStar] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T, S]): + def __call__(self, *, x: T, y: S) -> None: ... + +def f(x: Call[T, S]) -> Tuple[T, S]: ... +def g(**kwargs: int) -> None: ... +reveal_type(f(g)) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallableStarVsNamed] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T, S]): + def __call__(self, *, x: T, **kwargs: S) -> None: ... + +def f(x: Call[T, S]) -> Tuple[T, S]: ... +def g(**kwargs: int) -> None: pass +reveal_type(f(g)) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallableNamedVsNamed] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T, S]): + def __call__(self, *, x: T, y: S) -> None: ... + +def f(x: Call[T, S]) -> Tuple[T, S]: ... + +# Note: order of names is different w.r.t. protocol +def g(*, y: int, x: str) -> None: pass +reveal_type(f(g)) # N: Revealed type is "Tuple[builtins.str, builtins.int]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallablePosOnlyVsNamed] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T]): + def __call__(self, *, x: T) -> None: ... + +def f(x: Call[T]) -> Tuple[T, T]: ... + +def g(__x: str) -> None: pass +reveal_type(f(g)) # N: Revealed type is "Tuple[, ]" \ + # E: Argument 1 to "f" has incompatible type "Callable[[str], None]"; expected "Call[]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallableNamedVsPosOnly] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T]): + def __call__(self, __x: T) -> None: ... + +def f(x: Call[T]) -> Tuple[T, T]: ... + +def g(*, x: str) -> None: pass +reveal_type(f(g)) # N: Revealed type is "Tuple[, ]" \ + # E: Argument 1 to "f" has incompatible type "Callable[[NamedArg(str, 'x')], None]"; expected "Call[]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallablePosOnlyVsKwargs] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T]): + def __call__(self, __x: T) -> None: ... + +def f(x: Call[T]) -> Tuple[T, T]: ... + +def g(**x: str) -> None: pass +reveal_type(f(g)) # N: Revealed type is "Tuple[, ]" \ + # E: Argument 1 to "f" has incompatible type "Callable[[KwArg(str)], None]"; expected "Call[]" +[builtins fixtures/list.pyi] + +[case testCallableInferenceAgainstCallableNamedVsArgs] +from typing import TypeVar, Callable, Tuple, Protocol + +T = TypeVar('T', contravariant=True) +S = TypeVar('S', contravariant=True) + +class Call(Protocol[T]): + def __call__(self, *, x: T) -> None: ... + +def f(x: Call[T]) -> Tuple[T, T]: ... + +def g(*args: str) -> None: pass +reveal_type(f(g)) # N: Revealed type is "Tuple[, ]" \ + # E: Argument 1 to "f" has incompatible type "Callable[[VarArg(str)], None]"; expected "Call[]" +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index b778dc50b376..ede4a2e4cf62 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6640,3 +6640,13 @@ def bar(x): ... reveal_type(bar) # N: Revealed type is "Overload(def (builtins.int) -> builtins.float, def (builtins.str) -> builtins.str)" [builtins fixtures/paramspec.pyi] + +[case testOverloadOverlapWithNameOnlyArgs] +from typing import overload + +@overload +def d(x: int) -> int: ... +@overload +def d(f: int, *, x: int) -> str: ... +def d(*args, **kwargs): ... +[builtins fixtures/tuple.pyi]