Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make imprecise constraints handling more robust #16502

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,25 +226,22 @@ def infer_constraints_for_callable(
actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
if (
param_spec
and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
and not incomplete_star_mapping
):
if param_spec and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
# If actual arguments are mapped to ParamSpec type, we can't infer individual
# constraints, instead store them and infer single constraint at the end.
# It is impossible to map actual kind to formal kind, so use some heuristic.
# This inference is used as a fallback, so relying on heuristic should be OK.
param_spec_arg_types.append(
mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
if not incomplete_star_mapping:
param_spec_arg_types.append(
mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
)
)
)
actual_kind = arg_kinds[actual]
param_spec_arg_kinds.append(
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
)
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
actual_kind = arg_kinds[actual]
param_spec_arg_kinds.append(
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
)
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
else:
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)
Expand All @@ -267,6 +264,9 @@ def infer_constraints_for_callable(
),
)
)
if any(isinstance(v, ParamSpecType) for v in callee.variables):
# As a perf optimization filter imprecise constraints only when we can have them.
constraints = filter_imprecise_kinds(constraints)
return constraints


Expand Down Expand Up @@ -1094,29 +1094,18 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)

param_spec_target: Type | None = None
skip_imprecise = (
any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds
)
if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
# This logic matches top-level callable constraint exception, if we managed
# to get other constraints for ParamSpec, don't infer one with imprecise kinds
if not skip_imprecise:
param_spec_target = Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
)
param_spec_target = Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables if not type_state.infer_polymorphic else [],
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
)
else:
if (
len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types)
and not skip_imprecise
):
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
param_spec_target = cactual_ps.copy_modified(
prefix=Parameters(
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
Expand Down Expand Up @@ -1611,3 +1600,24 @@ def infer_callable_arguments_constraints(
infer_directed_arg_constraints(left_by_name.typ, right_by_name.typ, direction)
)
return res


def filter_imprecise_kinds(cs: list[Constraint]) -> list[Constraint]:
"""For each ParamSpec remove all imprecise constraints, if at least one precise available."""
have_precise = set()
for c in cs:
if not isinstance(c.origin_type_var, ParamSpecType):
continue
if (
isinstance(c.target, ParamSpecType)
or isinstance(c.target, Parameters)
and not c.target.imprecise_arg_kinds
):
have_precise.add(c.type_var)
new_cs = []
for c in cs:
if not isinstance(c.origin_type_var, ParamSpecType) or c.type_var not in have_precise:
new_cs.append(c)
if not isinstance(c.target, Parameters) or not c.target.imprecise_arg_kinds:
new_cs.append(c)
return new_cs
1 change: 1 addition & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
imprecise_arg_kinds=repl.imprecise_arg_kinds,
)
else:
# We could encode Any as trivial parameters etc., but it would be too verbose.
Expand Down
23 changes: 23 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2108,3 +2108,26 @@ def func2(arg: T) -> List[Union[T, str]]:
reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecPreciseKindsUsedIfPossible]
from typing import Callable, Generic
from typing_extensions import ParamSpec

P = ParamSpec('P')

class Case(Generic[P]):
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
pass

def _test(a: int, b: int = 0) -> None: ...

def parametrize(
func: Callable[P, None], *cases: Case[P], **named_cases: Case[P]
) -> Callable[[], None]:
...

parametrize(_test, Case(1, 2), Case(3, 4))
parametrize(_test, Case(1, b=2), Case(3, b=4))
parametrize(_test, Case(1, 2), Case(3))
parametrize(_test, Case(1, 2), Case(3, b=4))
[builtins fixtures/paramspec.pyi]