Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lenient handling of trivial Callable suffixes #15913

Merged
merged 8 commits into from
Sep 14, 2023
4 changes: 3 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,9 @@ def check_func_def(
):
if defn.is_class or defn.name == "__new__":
ref_type = mypy.types.TypeType.make_normalized(ref_type)
erased = get_proper_type(erase_to_bound(arg_type))
# This level of erasure matches the one in checkmember.check_self_arg(),
# better keep these two checks consistent.
erased = get_proper_type(erase_typevars(erase_to_bound(arg_type)))
ilevkivskyi marked this conversation as resolved.
Show resolved Hide resolved
if not is_subtype(ref_type, erased, ignore_type_params=True):
if (
isinstance(erased, Instance)
Expand Down
7 changes: 7 additions & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
return self.replacement
return t

def visit_callable_type(self, t: CallableType) -> Type:
result = super().visit_callable_type(t)
if t.param_spec():
assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.erased = True
return result

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# Type alias target can't contain bound type variables (not bound by the type
# alias itself), so it is safe to just erase the arguments.
Expand Down
3 changes: 3 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2132,6 +2132,9 @@ def report_protocol_problems(
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
or not supertype.type.defn.type_vars
# Always show detailed message for ParamSpec
or subtype.type.has_param_spec_type
or supertype.type.has_param_spec_type
):
type_name = format_type(subtype, self.options, module_names=True)
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)
Expand Down
21 changes: 18 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,6 +1476,18 @@ def are_trivial_parameters(param: Parameters | NormalizedCallableType) -> bool:
)


def is_trivial_suffix(param: Parameters | NormalizedCallableType) -> bool:
param_star = param.var_arg()
param_star2 = param.kw_arg()
return (
param.arg_kinds[-2:] == [ARG_STAR, ARG_STAR2]
and param_star is not None
and isinstance(get_proper_type(param_star.typ), AnyType)
and param_star2 is not None
and isinstance(get_proper_type(param_star2.typ), AnyType)
)


def are_parameters_compatible(
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
Expand All @@ -1498,6 +1510,9 @@ def are_parameters_compatible(
if are_trivial_parameters(right):
return True

# Parameters should not contain nested ParamSpec, so erasure doesn't make them less general.
trivial_suffix = isinstance(right, CallableType) and right.erased and is_trivial_suffix(right)

# Match up corresponding arguments and check them for compatibility. In
# every pair (argL, argR) of corresponding arguments from L and R, argL must
# be "more general" than argR if L is to be a subtype of R.
Expand Down Expand Up @@ -1527,7 +1542,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
if right_arg is None:
return False
if left_arg is None:
return not allow_partial_overlap
return not allow_partial_overlap and not trivial_suffix
return not is_compat(right_arg.typ, left_arg.typ)

if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2):
Expand All @@ -1551,7 +1566,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# arguments. Get all further positional args of left, and make sure
# they're more general than the corresponding member in right.
# TODO: are we handling UnpackType correctly here?
if right_star is not None:
if right_star is not None and not trivial_suffix:
# Synthesize an anonymous formal argument for the right
right_by_position = right.try_synthesizing_arg_from_vararg(None)
assert right_by_position is not None
Expand All @@ -1578,7 +1593,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
# Phase 1d: Check kw args. Right has an infinite series of optional named
# arguments. Get all further named args of left, and make sure
# they're more general than the corresponding member in right.
if right_star2 is not None:
if right_star2 is not None and not trivial_suffix:
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
Expand Down
4 changes: 4 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ def supported_self_type(typ: ProperType) -> bool:
"""
if isinstance(typ, TypeType):
return supported_self_type(typ.item)
if isinstance(typ, CallableType):
# Special case: allow class callable instead of Type[...] as cls annotation,
# as well as callable self for callback protocols.
return True
return isinstance(typ, TypeVarType) or (
isinstance(typ, Instance) and typ != fill_typevars(typ.type)
)
Expand Down
7 changes: 7 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,6 +1778,7 @@ class CallableType(FunctionLike):
# (this is used for error messages)
"imprecise_arg_kinds",
"unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable?
"erased", # Is this callable created as an erased form of a more precise type?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels kind of ad hoc -- and there seems to overlap with is_ellipsis_args. Is there a way to merge these two? For example, allow is_ellipsis_args to be used with an argument prefix. Then erasing a ParamSpec could produce a type with is_ellipsis_args=True.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. This looks quite ad-hoc. The more think about it the more I think we should do this unconditionally. This will make the whole thing much simpler.

)

def __init__(
Expand All @@ -1803,6 +1804,7 @@ def __init__(
from_concatenate: bool = False,
imprecise_arg_kinds: bool = False,
unpack_kwargs: bool = False,
erased: bool = False,
) -> None:
super().__init__(line, column)
assert len(arg_types) == len(arg_kinds) == len(arg_names)
Expand Down Expand Up @@ -1850,6 +1852,7 @@ def __init__(
self.def_extras = {}
self.type_guard = type_guard
self.unpack_kwargs = unpack_kwargs
self.erased = erased

def copy_modified(
self: CT,
Expand All @@ -1873,6 +1876,7 @@ def copy_modified(
from_concatenate: Bogus[bool] = _dummy,
imprecise_arg_kinds: Bogus[bool] = _dummy,
unpack_kwargs: Bogus[bool] = _dummy,
erased: Bogus[bool] = _dummy,
) -> CT:
modified = CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
Expand Down Expand Up @@ -1903,6 +1907,7 @@ def copy_modified(
else self.imprecise_arg_kinds
),
unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs,
erased=erased if erased is not _dummy else self.erased,
)
# Optimization: Only NewTypes are supported as subtypes since
# the class is effectively final, so we can use a cast safely.
Expand Down Expand Up @@ -2220,6 +2225,7 @@ def serialize(self) -> JsonDict:
"from_concatenate": self.from_concatenate,
"imprecise_arg_kinds": self.imprecise_arg_kinds,
"unpack_kwargs": self.unpack_kwargs,
"erased": self.erased,
}

@classmethod
Expand All @@ -2244,6 +2250,7 @@ def deserialize(cls, data: JsonDict) -> CallableType:
from_concatenate=data["from_concatenate"],
imprecise_arg_kinds=data["imprecise_arg_kinds"],
unpack_kwargs=data["unpack_kwargs"],
erased=data["erased"],
)


Expand Down
113 changes: 112 additions & 1 deletion test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -1729,7 +1729,12 @@ class A(Protocol[P]):
...

def bar(b: A[P]) -> A[Concatenate[int, P]]:
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]")
return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") \
# N: Following member(s) of "A[P]" have conflicts: \
# N: Expected: \
# N: def foo(self, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \
# N: Got: \
# N: def foo(self, *args: P.args, **kwargs: P.kwargs) -> Any
[builtins fixtures/paramspec.pyi]

[case testParamSpecPrefixSubtypingValidNonStrict]
Expand Down Expand Up @@ -1825,6 +1830,112 @@ c: C[int, [int, str], str] # E: Nested parameter specifications are not allowed
reveal_type(c) # N: Revealed type is "__main__.C[Any]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateSelfType]
from typing import Callable
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
class A:
def __init__(self, a_param_1: str) -> None: ...

@classmethod
def add_params(cls: Callable[P, A]) -> Callable[Concatenate[float, P], A]:
def new_constructor(i: float, *args: P.args, **kwargs: P.kwargs) -> A:
return cls(*args, **kwargs)
return new_constructor

@classmethod
def remove_params(cls: Callable[Concatenate[str, P], A]) -> Callable[P, A]:
def new_constructor(*args: P.args, **kwargs: P.kwargs) -> A:
return cls("my_special_str", *args, **kwargs)
return new_constructor

reveal_type(A.add_params()) # N: Revealed type is "def (builtins.float, a_param_1: builtins.str) -> __main__.A"
reveal_type(A.remove_params()) # N: Revealed type is "def () -> __main__.A"
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackProtocol]
from typing import Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

class Path: ...

class Function(Protocol[P, R]):
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def file_cache(fn: Function[Concatenate[Path, P], R]) -> Function[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "__main__.Function[[*, some_arg: builtins.int], builtins.int]"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateKeywordOnly]
from typing import Callable, TypeVar
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")
R = TypeVar("R")

class Path: ...

def file_cache(fn: Callable[Concatenate[Path, P], R]) -> Callable[P, R]:
def wrapper(*args: P.args, **kw: P.kwargs) -> R:
return fn(Path(), *args, **kw)
return wrapper

@file_cache
def get_thing(path: Path, *, some_arg: int) -> int: ...
reveal_type(get_thing) # N: Revealed type is "def (*, some_arg: builtins.int) -> builtins.int"
get_thing(some_arg=1) # OK
[builtins fixtures/paramspec.pyi]

[case testParamSpecConcatenateCallbackApply]
from typing import Callable, Protocol
from typing_extensions import ParamSpec, Concatenate

P = ParamSpec("P")

class FuncType(Protocol[P]):
def __call__(self, x: int, s: str, *args: P.args, **kw_args: P.kwargs) -> str: ...

def forwarder1(fp: FuncType[P], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def forwarder2(fp: Callable[Concatenate[int, str, P], str], *args: P.args, **kw_args: P.kwargs) -> str:
return fp(0, '', *args, **kw_args)

def my_f(x: int, s: str, d: bool) -> str: ...
forwarder1(my_f, True) # OK
forwarder2(my_f, True) # OK
forwarder1(my_f, 1.0) # E: Argument 2 to "forwarder1" has incompatible type "float"; expected "bool"
forwarder2(my_f, 1.0) # E: Argument 2 to "forwarder2" has incompatible type "float"; expected "bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecCallbackProtocolSelf]
from typing import Callable, Protocol, TypeVar
from typing_extensions import ParamSpec, Concatenate

Params = ParamSpec("Params")
Result = TypeVar("Result", covariant=True)

class FancyMethod(Protocol):
def __call__(self, arg1: int, arg2: str) -> bool: ...
def return_me(self: Callable[Params, Result]) -> Callable[Params, Result]: ...
def return_part(self: Callable[Concatenate[int, Params], Result]) -> Callable[Params, Result]: ...

m: FancyMethod
reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2: builtins.str) -> builtins.bool"
reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool"
[builtins fixtures/paramspec.pyi]

[case testParamSpecInferenceWithCallbackProtocol]
from typing import Protocol, Callable, ParamSpec

Expand Down
1 change: 1 addition & 0 deletions test-data/unit/fixtures/paramspec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class object:

class function: ...
class ellipsis: ...
class classmethod: ...

class type:
def __init__(self, *a: object) -> None: ...
Expand Down