Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes type inference for generic calls in if expr #11128

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 19 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3836,8 +3836,10 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
allow_none_return=allow_none_return)

# Analyze the right branch using full type context and store the type
full_context_else_type = self.analyze_cond_branch(else_map, e.else_expr, context=ctx,
allow_none_return=allow_none_return)
full_context_else_type = self.analyze_cond_branch(else_map, e.else_expr,
context=ctx,
allow_none_return=allow_none_return,
is_else=True)
if not mypy.checker.is_valid_inferred_type(if_type):
# Analyze the right branch disregarding the left branch.
else_type = full_context_else_type
Expand All @@ -3855,7 +3857,8 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
# Analyze the right branch in the context of the left
# branch's type.
else_type = self.analyze_cond_branch(else_map, e.else_expr, context=if_type,
allow_none_return=allow_none_return)
allow_none_return=allow_none_return,
is_else=True)

# Only create a union type if the type context is a union, to be mostly
# compatible with older mypy versions where we always did a join.
Expand All @@ -3870,14 +3873,25 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F

def analyze_cond_branch(self, map: Optional[Dict[Expression, Type]],
node: Expression, context: Optional[Type],
allow_none_return: bool = False) -> Type:
allow_none_return: bool = False,
is_else: bool = False) -> Type:
with self.chk.binder.frame_context(can_skip=True, fall_through=0):
if map is not None:
self.chk.push_type_map(map)
if is_else and context is not None and isinstance(node, CallExpr):
# When calling a function on the else part,
# we can face a generic function with multiple type vars.
# When inferecing it, `context` might be used instead of real args.
# Usually, we don't want that.
# https://github.com/python/mypy/issues/11049
if not is_subtype(self.accept(node), context, ignore_type_params=True):
sobolevn marked this conversation as resolved.
Show resolved Hide resolved
context = None

if map is None:
# We still need to type check node, in case we want to
# process it for isinstance checks later
self.accept(node, type_context=context, allow_none_return=allow_none_return)
return UninhabitedType()
self.chk.push_type_map(map)
return self.accept(node, type_context=context, allow_none_return=allow_none_return)

def visit_backquote_expr(self, e: BackquoteExpr) -> Type:
Expand Down
78 changes: 77 additions & 1 deletion test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -1975,7 +1975,7 @@ T = TypeVar('T')

class A:
def f(self) -> None:
self.g() # E: Too few arguments for "g" of "A"
self.g() # E: Too few arguments for "g" of "A"
self.g(1)
@dec
def g(self, x: str) -> None: pass
Expand Down Expand Up @@ -2246,6 +2246,30 @@ a = set() if f() else {0}
a() # E: "Set[int]" not callable
[builtins fixtures/set.pyi]

[case testUnificationEmptySetRight]
def f(): pass
a = {0} if f() else set()
a() # E: "Set[int]" not callable
[builtins fixtures/set.pyi]

[case testUnificationEmptyCustomSetLeft]
from typing import Set, TypeVar
T = TypeVar('T')
class customset(Set[T]): pass
def f(): pass
a = customset() if f() else {1}
a() # E: "Set[int]" not callable
[builtins fixtures/set.pyi]

[case testUnificationEmptyCustomSetRight]
from typing import Set, TypeVar
T = TypeVar('T')
class customset(Set[T]): pass
def f(): pass
a = {0} if f() else customset()
a() # E: "Set[int]" not callable
[builtins fixtures/set.pyi]

[case testUnificationEmptyDictLeft]
def f(): pass
a = {} if f() else {0: 0}
Expand All @@ -2270,6 +2294,58 @@ a = {0: [0]} if f() else {0: []}
a() # E: "Dict[int, List[int]]" not callable
[builtins fixtures/dict.pyi]

[case testConditionalInferenceGenericFunctionRight]
from typing import TypeVar, Union

T1 = TypeVar("T1")
T2 = TypeVar("T2")

def foo(a: T1, b: T2) -> Union[T1, T2]: pass
x: bool

reveal_type(1 if x else foo(1, "s")) # N: Revealed type is "Union[builtins.int*, builtins.str*]"
reveal_type("a" if x else foo(1, "s")) # N: Revealed type is "Union[builtins.int*, builtins.str*]"
reveal_type(1 if x else foo("s", 1)) # N: Revealed type is "Union[builtins.str*, builtins.int*]"
reveal_type("a" if x else foo("s", 1)) # N: Revealed type is "Union[builtins.str*, builtins.int*]"
[builtins fixtures/bool.pyi]

[case testConditionalInferenceGenericFunctionLeft]
from typing import TypeVar, Union

T1 = TypeVar("T1")
T2 = TypeVar("T2")

def foo(a: T1, b: T2) -> Union[T1, T2]: pass
x: bool

reveal_type(foo(1, "s") if x else 1) # N: Revealed type is "Union[builtins.int*, builtins.str*]"
reveal_type(foo(1, "s") if x else "a") # N: Revealed type is "Union[builtins.int*, builtins.str*]"
reveal_type(foo("s", 1) if x else 1) # N: Revealed type is "Union[builtins.str*, builtins.int*]"
reveal_type(foo("s", 1) if x else "a") # N: Revealed type is "Union[builtins.str*, builtins.int*]"
[builtins fixtures/bool.pyi]

[case testConditionalInferenceSelfNarrowingRight]
from typing import Optional

class C:
x: Optional[int]
def check(self) -> Optional[int]:
return None if self.x is None else self.x.conjugate()

reveal_type(C().check()) # N: Revealed type is "Union[builtins.int, None]"
[builtins fixtures/bool.pyi]

[case testConditionalInferenceSelfNarrowingLeft]
from typing import Optional

class C:
x: Optional[int]
def check(self) -> Optional[int]:
return self.x.conjugate() if self.x is not None else None

reveal_type(C().check()) # N: Revealed type is "Union[builtins.int, None]"
[builtins fixtures/bool.pyi]

[case testMisguidedSetItem]
from typing import Generic, Sequence, TypeVar
T = TypeVar('T')
Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/fixtures/bool.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ class object:
class type: pass
class tuple(Generic[T]): pass
class function: pass
class int: pass
class int:
def conjugate(self) -> int: pass
class bool(int): pass
class float: pass
class str: pass
Expand Down