Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer unions for ternary expressions #17427

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5759,16 +5759,15 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
context=if_type_fallback,
allow_none_return=allow_none_return,
)

# Only create a union type if the type context is a union, to be mostly
# compatible with older mypy versions where we always did a join.
#
# TODO: Always create a union or at least in more cases?
if isinstance(get_proper_type(self.type_context[-1]), UnionType):
res: Type = make_simplified_union([if_type, full_context_else_type])
else:
res = join.join_types(if_type, else_type)

res: Type = make_simplified_union([if_type, else_type])
if has_uninhabited_component(res) and not isinstance(
get_proper_type(self.type_context[-1]), UnionType
):
# In rare cases with empty collections join may give a better result.
alternative = join.join_types(if_type, else_type)
p_alt = get_proper_type(alternative)
if not isinstance(p_alt, Instance) or p_alt.type.fullname != "builtins.object":
res = alternative
return res

def analyze_cond_branch(
Expand Down
4 changes: 3 additions & 1 deletion mypyc/test-data/irbuild-any.test
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def f4(a, n, b):
a :: object
n :: int
b :: bool
r0, r1, r2, r3 :: object
r0 :: union[object, int]
r1, r2 :: object
r3 :: union[int, object]
r4 :: int
L0:
if b goto L1 else goto L2 :: bool
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-errorcodes.test
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ a: D = {'x': ''} # E: Incompatible types (expression has type "str", TypedDict
b: D = {'y': ''} # E: Missing key "x" for TypedDict "D" [typeddict-item] \
# E: Extra key "y" for TypedDict "D" [typeddict-unknown-key]
c = D(x=0) if int() else E(x=0, y=0)
c = {} # E: Expected TypedDict key "x" but found no keys [typeddict-item]
c = {} # E: Missing key "x" for TypedDict "D" [typeddict-item]
d: D = {'x': '', 'y': 1} # E: Extra key "y" for TypedDict "D" [typeddict-unknown-key] \
# E: Incompatible types (expression has type "str", TypedDict item "x" has type "int") [typeddict-item]

Expand Down
17 changes: 8 additions & 9 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1470,10 +1470,9 @@ if int():

[case testConditionalExpressionUnion]
from typing import Union
reveal_type(1 if bool() else 2) # N: Revealed type is "builtins.int"
reveal_type(1 if bool() else '') # N: Revealed type is "builtins.object"
x: Union[int, str] = reveal_type(1 if bool() else '') \
# N: Revealed type is "Union[Literal[1]?, Literal['']?]"
reveal_type(1 if bool() else 2) # N: Revealed type is "Union[Literal[1]?, Literal[2]?]"
reveal_type(1 if bool() else '') # N: Revealed type is "Union[Literal[1]?, Literal['']?]"
x: Union[int, str] = reveal_type(1 if bool() else '') # N: Revealed type is "Union[Literal[1]?, Literal['']?]"
class A:
pass
class B(A):
Expand All @@ -1487,17 +1486,17 @@ b = B()
c = C()
d = D()
reveal_type(a if bool() else b) # N: Revealed type is "__main__.A"
reveal_type(b if bool() else c) # N: Revealed type is "builtins.object"
reveal_type(c if bool() else b) # N: Revealed type is "builtins.object"
reveal_type(c if bool() else a) # N: Revealed type is "builtins.object"
reveal_type(d if bool() else b) # N: Revealed type is "__main__.A"
reveal_type(b if bool() else c) # N: Revealed type is "Union[__main__.B, __main__.C]"
reveal_type(c if bool() else b) # N: Revealed type is "Union[__main__.C, __main__.B]"
reveal_type(c if bool() else a) # N: Revealed type is "Union[__main__.C, __main__.A]"
reveal_type(d if bool() else b) # N: Revealed type is "Union[__main__.D, __main__.B]"
[builtins fixtures/bool.pyi]

[case testConditionalExpressionUnionWithAny]
from typing import Union, Any
a: Any
x: Union[int, str] = reveal_type(a if int() else 1) # N: Revealed type is "Union[Any, Literal[1]?]"
reveal_type(a if int() else 1) # N: Revealed type is "Any"
reveal_type(a if int() else 1) # N: Revealed type is "Union[Any, Literal[1]?]"

[case testConditionalExpressionStatementNoReturn]
from typing import List, Union
Expand Down
17 changes: 15 additions & 2 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2250,13 +2250,26 @@ def dec(f: Callable[[A, str], None]) -> Callable[[A, int], None]: pass
[out]

[case testUnknownFunctionNotCallable]
from typing import TypeVar

def f() -> None:
pass
def g(x: int) -> None:
pass
h = f if bool() else g
reveal_type(h) # N: Revealed type is "builtins.function"
h(7) # E: Cannot call function of unknown type
reveal_type(h) # N: Revealed type is "Union[def (), def (x: builtins.int)]"
h(7) # E: Too many arguments for "f"

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

h2 = join(f, g)
reveal_type(h2) # N: Revealed type is "builtins.function"
h2(7) # E: Cannot call function of unknown type

h3 = join(g, f)
reveal_type(h3) # N: Revealed type is "builtins.function"
h3(7) # E: Cannot call function of unknown type
[builtins fixtures/bool.pyi]

[case testFunctionWithNameUnderscore]
Expand Down
6 changes: 3 additions & 3 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ class A: pass
class B(A): pass
class C(A): pass
def f(func: Callable[[T], S], *z: T, r: Optional[S] = None) -> S: pass
reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "builtins.int"
reveal_type(f(lambda x: 0 if isinstance(x, B) else 1)) # N: Revealed type is "Union[Literal[0]?, Literal[1]?]"
f(lambda x: 0 if isinstance(x, B) else 1, A())() # E: "int" not callable
f(lambda x: x if isinstance(x, B) else B(), A(), r=B())() # E: "B" not callable
f(
Expand Down Expand Up @@ -1391,15 +1391,15 @@ from typing import Union, List, Any

def f(x: Union[List[str], Any]) -> None:
a = x if x else []
reveal_type(a) # N: Revealed type is "Union[builtins.list[Union[builtins.str, Any]], builtins.list[builtins.str], Any]"
reveal_type(a) # N: Revealed type is "Union[builtins.list[builtins.str], Any, builtins.list[Union[builtins.str, Any]]]"
[builtins fixtures/list.pyi]

[case testConditionalExpressionWithEmptyIteableAndUnionWithAny]
from typing import Union, Iterable, Any

def f(x: Union[Iterable[str], Any]) -> None:
a = x if x else []
reveal_type(a) # N: Revealed type is "Union[builtins.list[Union[builtins.str, Any]], typing.Iterable[builtins.str], Any]"
reveal_type(a) # N: Revealed type is "Union[typing.Iterable[builtins.str], Any, builtins.list[Union[builtins.str, Any]]]"
[builtins fixtures/list.pyi]

[case testInferMultipleAnyUnionCovariant]
Expand Down
10 changes: 7 additions & 3 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -1438,18 +1438,22 @@ class Wrapper:

def f(cond: bool) -> Any:
f = Wrapper if cond else lambda x: x
reveal_type(f) # N: Revealed type is "def (x: Any) -> Any"
reveal_type(f) # N: Revealed type is "Union[def (x: Any) -> __main__.Wrapper, def (x: Any) -> Any]"
return f(3)

def g(cond: bool) -> Any:
f = lambda x: x if cond else Wrapper
reveal_type(f) # N: Revealed type is "def (x: Any) -> Any"
reveal_type(f) # N: Revealed type is "def (x: Any) -> Union[Any, def (x: Any) -> __main__.Wrapper]"
return f(3)

def h(cond: bool) -> Any:
f = (lambda x: x) if cond else Wrapper
reveal_type(f) # N: Revealed type is "Union[def (x: Any) -> Any, def (x: Any) -> __main__.Wrapper]"
return f(3)

-- Boolean operators
-- -----------------


[case testOrOperationWithGenericOperands]
from typing import List
a: List[A]
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-optional.test
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def lookup_field(name, obj):
attr = None

[case testTernaryWithNone]
reveal_type(None if bool() else 0) # N: Revealed type is "Union[Literal[0]?, None]"
reveal_type(None if bool() else 0) # N: Revealed type is "Union[None, Literal[0]?]"
[builtins fixtures/bool.pyi]

[case testListWithNone]
Expand Down
87 changes: 52 additions & 35 deletions test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1228,68 +1228,76 @@ x, y = g(z) # E: Argument 1 to "g" has incompatible type "int"; expected "Tuple[
[out]

[case testFixedTupleJoinVarTuple]
from typing import Tuple
from typing import Tuple, TypeVar

class A: pass
class B(A): pass

fixtup: Tuple[B, B]

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

vartup_b: Tuple[B, ...]
reveal_type(fixtup if int() else vartup_b) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
reveal_type(vartup_b if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
reveal_type(join(fixtup, vartup_b)) # N: Revealed type is "builtins.tuple[__main__.B, ...]"
reveal_type(join(vartup_b, fixtup)) # N: Revealed type is "builtins.tuple[__main__.B, ...]"

vartup_a: Tuple[A, ...]
reveal_type(fixtup if int() else vartup_a) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(vartup_a if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

reveal_type(join(fixtup, vartup_a)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(vartup_a, fixtup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

[builtins fixtures/tuple.pyi]
[out]

[case testFixedTupleJoinList]
from typing import Tuple, List
from typing import Tuple, List, TypeVar

class A: pass
class B(A): pass

fixtup: Tuple[B, B]

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

lst_b: List[B]
reveal_type(fixtup if int() else lst_b) # N: Revealed type is "typing.Sequence[__main__.B]"
reveal_type(lst_b if int() else fixtup) # N: Revealed type is "typing.Sequence[__main__.B]"
reveal_type(join(fixtup, lst_b)) # N: Revealed type is "typing.Sequence[__main__.B]"
reveal_type(join(lst_b, fixtup)) # N: Revealed type is "typing.Sequence[__main__.B]"

lst_a: List[A]
reveal_type(fixtup if int() else lst_a) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(lst_a if int() else fixtup) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(fixtup, lst_a)) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(lst_a, fixtup)) # N: Revealed type is "typing.Sequence[__main__.A]"

[builtins fixtures/tuple.pyi]
[out]

[case testEmptyTupleJoin]
from typing import Tuple, List
from typing import Tuple, List, TypeVar

class A: pass

empty = ()

T = TypeVar("T")
def join(x: T, y: T) -> T: ...

fixtup: Tuple[A]
reveal_type(fixtup if int() else empty) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(empty if int() else fixtup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(fixtup, empty)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(empty, fixtup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

vartup: Tuple[A, ...]
reveal_type(empty if int() else vartup) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(vartup if int() else empty) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(vartup, empty)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"
reveal_type(join(empty, vartup)) # N: Revealed type is "builtins.tuple[__main__.A, ...]"

lst: List[A]
reveal_type(empty if int() else lst) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(lst if int() else empty) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(empty, lst)) # N: Revealed type is "typing.Sequence[__main__.A]"
reveal_type(join(lst, empty)) # N: Revealed type is "typing.Sequence[__main__.A]"

[builtins fixtures/tuple.pyi]
[out]

[case testTupleSubclassJoin]
from typing import Tuple, NamedTuple
from typing import Tuple, NamedTuple, TypeVar

class NTup(NamedTuple):
a: bool
Expand All @@ -1302,32 +1310,38 @@ ntup: NTup
subtup: SubTuple
vartup: SubVarTuple

reveal_type(ntup if int() else vartup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(subtup if int() else vartup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
T = TypeVar("T")
def join(x: T, y: T) -> T: ...

reveal_type(join(ntup, vartup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(subtup, vartup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

[builtins fixtures/tuple.pyi]
[out]

[case testTupleJoinIrregular]
from typing import Tuple
from typing import Tuple, TypeVar

tup1: Tuple[bool, int]
tup2: Tuple[bool]

reveal_type(tup1 if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(tup2 if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
T = TypeVar("T")
def join(x: T, y: T) -> T: ...

reveal_type(join(tup1, tup2)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup2, tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

reveal_type(tup1 if int() else ()) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(() if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup1, ())) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join((), tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

reveal_type(tup2 if int() else ()) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(() if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(join(tup2, ())) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(join((), tup2)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"

[builtins fixtures/tuple.pyi]
[out]

[case testTupleSubclassJoinIrregular]
from typing import Tuple, NamedTuple
from typing import Tuple, NamedTuple, TypeVar

class NTup1(NamedTuple):
a: bool
Expand All @@ -1342,14 +1356,17 @@ tup1: NTup1
tup2: NTup2
subtup: SubTuple

reveal_type(tup1 if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(tup2 if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
T = TypeVar("T")
def join(x: T, y: T) -> T: ...

reveal_type(join(tup1, tup2)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"
reveal_type(join(tup2, tup1)) # N: Revealed type is "builtins.tuple[builtins.bool, ...]"

reveal_type(tup1 if int() else subtup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(subtup if int() else tup1) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup1, subtup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(subtup, tup1)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

reveal_type(tup2 if int() else subtup) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(subtup if int() else tup2) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(tup2, subtup)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"
reveal_type(join(subtup, tup2)) # N: Revealed type is "builtins.tuple[builtins.int, ...]"

[builtins fixtures/tuple.pyi]
[out]
Expand Down
Loading