Skip to content

Commit

Permalink
Fix subtyping between ParamSpecs (#15892)
Browse files Browse the repository at this point in the history
Fixes #14169
Fixes #14168

Two sings here:
* Actually check prefix when we should
* `strict_concatenate` check should be off by default (IIUC it is not
mandated by the PEP)
  • Loading branch information
ilevkivskyi committed Aug 17, 2023
1 parent 76c16a4 commit b3d0937
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 16 deletions.
3 changes: 1 addition & 2 deletions mypy/expandtype.py
Expand Up @@ -383,8 +383,6 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
t = t.expand_param_spec(repl)
return t.copy_modified(
arg_types=self.expand_types(t.arg_types),
arg_kinds=t.arg_kinds,
arg_names=t.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
Expand All @@ -402,6 +400,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:],
arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
ret_type=t.ret_type.accept(self),
from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
)

var_arg = t.var_arg()
Expand Down
18 changes: 12 additions & 6 deletions mypy/messages.py
Expand Up @@ -2116,9 +2116,11 @@ def report_protocol_problems(
return

# Report member type conflicts
conflict_types = get_conflict_protocol_types(subtype, supertype, class_obj=class_obj)
conflict_types = get_conflict_protocol_types(
subtype, supertype, class_obj=class_obj, options=self.options
)
if conflict_types and (
not is_subtype(subtype, erase_type(supertype))
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
or not supertype.type.defn.type_vars
):
Expand Down Expand Up @@ -2780,7 +2782,11 @@ def [T <: int] f(self, x: int, y: T) -> None
slash = True

# If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list
if isinstance(tp.definition, FuncDef) and hasattr(tp.definition, "arguments"):
if (
isinstance(tp.definition, FuncDef)
and hasattr(tp.definition, "arguments")
and not tp.from_concatenate
):
definition_arg_names = [arg.variable.name for arg in tp.definition.arguments]
if (
len(definition_arg_names) > len(tp.arg_names)
Expand Down Expand Up @@ -2857,7 +2863,7 @@ def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str


def get_conflict_protocol_types(
left: Instance, right: Instance, class_obj: bool = False
left: Instance, right: Instance, class_obj: bool = False, options: Options | None = None
) -> list[tuple[str, Type, Type]]:
"""Find members that are defined in 'left' but have incompatible types.
Return them as a list of ('member', 'got', 'expected').
Expand All @@ -2872,9 +2878,9 @@ def get_conflict_protocol_types(
subtype = mypy.typeops.get_protocol_member(left, member, class_obj)
if not subtype:
continue
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True, options=options)
if IS_SETTABLE in get_member_flags(member, right):
is_compat = is_compat and is_subtype(supertype, subtype)
is_compat = is_compat and is_subtype(supertype, subtype, options=options)
if not is_compat:
conflicts.append((member, subtype, supertype))
return conflicts
Expand Down
17 changes: 10 additions & 7 deletions mypy/subtypes.py
Expand Up @@ -600,7 +600,7 @@ def check_mixed(
type_state.record_negative_subtype_cache_entry(self._subtype_kind, left, right)
return nominal
if right.type.is_protocol and is_protocol_implementation(
left, right, proper_subtype=self.proper_subtype
left, right, proper_subtype=self.proper_subtype, options=self.options
):
return True
# We record negative cache entry here, and not in the protocol check like we do for
Expand Down Expand Up @@ -647,7 +647,7 @@ def visit_param_spec(self, left: ParamSpecType) -> bool:
and right.id == left.id
and right.flavor == left.flavor
):
return True
return self._is_subtype(left.prefix, right.prefix)
if isinstance(right, Parameters) and are_trivial_parameters(right):
return True
return self._is_subtype(left.upper_bound, self.right)
Expand Down Expand Up @@ -696,7 +696,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate)
if self.options
else True,
else False,
)
elif isinstance(right, Overloaded):
return all(self._is_subtype(left, item) for item in right.items)
Expand Down Expand Up @@ -863,7 +863,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
strict_concat = (
(self.options.extra_checks or self.options.strict_concatenate)
if self.options
else True
else False
)
if left_index not in matched_overloads and (
is_callable_compatible(
Expand Down Expand Up @@ -1003,6 +1003,7 @@ def is_protocol_implementation(
proper_subtype: bool = False,
class_obj: bool = False,
skip: list[str] | None = None,
options: Options | None = None,
) -> bool:
"""Check whether 'left' implements the protocol 'right'.
Expand Down Expand Up @@ -1068,7 +1069,9 @@ def f(self) -> A: ...
# Nominal check currently ignores arg names
# NOTE: If we ever change this, be sure to also change the call to
# SubtypeVisitor.build_subtype_kind(...) down below.
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=ignore_names)
is_compat = is_subtype(
subtype, supertype, ignore_pos_arg_names=ignore_names, options=options
)
else:
is_compat = is_proper_subtype(subtype, supertype)
if not is_compat:
Expand All @@ -1080,7 +1083,7 @@ def f(self) -> A: ...
superflags = get_member_flags(member, right)
if IS_SETTABLE in superflags:
# Check opposite direction for settable attributes.
if not is_subtype(supertype, subtype):
if not is_subtype(supertype, subtype, options=options):
return False
if not class_obj:
if IS_SETTABLE not in superflags:
Expand Down Expand Up @@ -1479,7 +1482,7 @@ def are_parameters_compatible(
ignore_pos_arg_names: bool = False,
check_args_covariantly: bool = False,
allow_partial_overlap: bool = False,
strict_concatenate_check: bool = True,
strict_concatenate_check: bool = False,
) -> bool:
"""Helper function for is_callable_compatible, used for Parameter compatibility"""
if right.is_ellipsis_args:
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-overloading.test
Expand Up @@ -6483,7 +6483,7 @@ P = ParamSpec("P")
R = TypeVar("R")

@overload
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ...
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
@overload
def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ...
def func(x: Callable[..., R]) -> Callable[..., R]: ...
Expand Down
70 changes: 70 additions & 0 deletions test-data/unit/check-parameter-specification.test
Expand Up @@ -1576,3 +1576,73 @@ def test() -> None: ...
# TODO: avoid this error, although it may be non-trivial.
apply(apply, test) # E: Argument 2 to "apply" has incompatible type "Callable[[], None]"; expected "Callable[P, T]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingGenericInvalid]
from typing import Generic
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class A(Generic[P]):
def foo(self, *args: P.args, **kwargs: P.kwargs):
...

def bar(b: A[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingProtocolInvalid]
from typing import Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class A(Protocol[P]):
def foo(self, *args: P.args, **kwargs: P.kwargs):
...

def bar(b: A[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingValidNonStrict]
from typing import Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class A(Protocol[P]):
def foo(self, a: int, *args: P.args, **kwargs: P.kwargs):
...

class B(Protocol[P]):
def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs):
...

def bar(b: B[P]) -> A[Concatenate[int, P]]:
return b
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingInvalidStrict]
# flags: --extra-checks
from typing import Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class A(Protocol[P]):
def foo(self, a: int, *args: P.args, **kwargs: P.kwargs):
...

class B(Protocol[P]):
def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs):
...

def bar(b: B[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "B[P]", expected "A[[int, **P]]") \
# N: Following member(s) of "B[P]" have conflicts: \
# N: Expected: \
# N: def foo(self, a: int, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
# N: Got: \
# N: def foo(self, a: int, b: int, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

0 comments on commit b3d0937

Please sign in to comment.