Skip to content

Commit

Permalink
Special-case unions in polymorphic inference (#16461)
Browse files Browse the repository at this point in the history
Fixes #16451

This special-casing is unfortunate, but this is the best I came up so
far.
  • Loading branch information
ilevkivskyi authored and JukkaL committed Nov 15, 2023
1 parent f862d3e commit 4b5b316
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 9 deletions.
53 changes: 44 additions & 9 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Iterable, Sequence
from typing_extensions import TypeAlias as _TypeAlias

from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
Expand Down Expand Up @@ -69,6 +69,10 @@ def solve_constraints(
extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars])
originals.update({v.id: v for v in c.extra_tvars if v.id not in originals})

if allow_polymorphic:
# Constraints inferred from unions require special handling in polymorphic inference.
constraints = skip_reverse_union_constraints(constraints)

# Collect a list of constraints for each type variable.
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars}
for con in constraints:
Expand Down Expand Up @@ -431,19 +435,15 @@ def transitive_closure(
uppers[l] |= uppers[upper]
for lt in lowers[lower]:
for ut in uppers[upper]:
# TODO: what if secondary constraints result in inference
# against polymorphic actual (also in below branches)?
remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF))
remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF))
add_secondary_constraints(remaining, lt, ut)
elif c.op == SUBTYPE_OF:
if c.target in uppers[c.type_var]:
continue
for l in tvars:
if (l, c.type_var) in graph:
uppers[l].add(c.target)
for lt in lowers[c.type_var]:
remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF))
remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF))
add_secondary_constraints(remaining, lt, c.target)
else:
assert c.op == SUPERTYPE_OF
if c.target in lowers[c.type_var]:
Expand All @@ -452,11 +452,24 @@ def transitive_closure(
if (c.type_var, u) in graph:
lowers[u].add(c.target)
for ut in uppers[c.type_var]:
remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF))
remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF))
add_secondary_constraints(remaining, c.target, ut)
return graph, lowers, uppers


def add_secondary_constraints(cs: set[Constraint], lower: Type, upper: Type) -> None:
"""Add secondary constraints inferred between lower and upper (in place)."""
if isinstance(get_proper_type(upper), UnionType) and isinstance(
get_proper_type(lower), UnionType
):
# When both types are unions, this can lead to inferring spurious constraints,
# for example Union[T, int] <: S <: Union[T, int] may infer T <: int.
# To avoid this, just skip them for now.
return
# TODO: what if secondary constraints result in inference against polymorphic actual?
cs.update(set(infer_constraints(lower, upper, SUBTYPE_OF)))
cs.update(set(infer_constraints(upper, lower, SUPERTYPE_OF)))


def compute_dependencies(
tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> dict[TypeVarId, list[TypeVarId]]:
Expand Down Expand Up @@ -494,6 +507,28 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
return True


def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
"""Avoid ambiguities for constraints inferred from unions during polymorphic inference.
Polymorphic inference implicitly relies on assumption that a reverse of a linear constraint
is a linear constraint. This is however not true in presence of union types, for example
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
solution T = Union[S, int], S = <free>.
TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
this would require passing around a flag through all infer_constraints() calls.
"""
reverse_union_cs = set()
for c in cs:
p_target = get_proper_type(c.target)
if isinstance(p_target, UnionType):
for item in p_target.items:
if isinstance(item, TypeVarType):
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
return [c for c in cs if c not in reverse_union_cs]


def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
"""Find type variables for which we are solving in a target type."""
return {tv.id for tv in get_all_type_vars(target)} & set(vars)
Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3767,3 +3767,24 @@ def f(values: List[T]) -> T: ...
x = foo(f([C()]))
reveal_type(x) # N: Revealed type is "__main__.C"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableUnion]
from typing import Callable, TypeVar, List, Union

T = TypeVar("T")
S = TypeVar("S")

def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: ...
@dec
def func(arg: T) -> Union[T, str]:
...
reveal_type(func) # N: Revealed type is "def [S] (S`1) -> builtins.list[Union[S`1, builtins.str]]"
reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"

def dec2(f: Callable[[S], List[T]]) -> Callable[[S], T]: ...
@dec2
def func2(arg: T) -> List[Union[T, str]]:
...
reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/list.pyi]
22 changes: 22 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2086,3 +2086,25 @@ reveal_type(d(b, f1)) # E: Cannot infer type argument 1 of "d" \
# N: Revealed type is "def (*Any, **Any)"
reveal_type(d(b, f2)) # N: Revealed type is "def (builtins.int)"
[builtins fixtures/paramspec.pyi]

[case testInferenceAgainstGenericCallableUnionParamSpec]
from typing import Callable, TypeVar, List, Union
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")

def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
@dec
def func(arg: T) -> Union[T, str]:
...
reveal_type(func) # N: Revealed type is "def [T] (arg: T`-1) -> builtins.list[Union[T`-1, builtins.str]]"
reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"

def dec2(f: Callable[P, List[T]]) -> Callable[P, T]: ...
@dec2
def func2(arg: T) -> List[Union[T, str]]:
...
reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/paramspec.pyi]

0 comments on commit 4b5b316

Please sign in to comment.