Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer correct types with overloads of Type[Guard | Is] #17678

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5879,15 +5879,29 @@ def find_isinstance_check_helper(
# considered "always right" (i.e. even if the types are not overlapping).
# Also note that a care must be taken to unwrap this back at read places
# where we use this to narrow down declared type.
if node.callee.type_guard is not None:
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
with self.msg.filter_errors(), self.local_type_map():
_, real_func = self.expr_checker.check_call(
get_proper_type(self.lookup_type(node.callee)),
node.args,
node.arg_kinds,
node,
node.arg_names,
)
real_func = get_proper_type(real_func)
if not isinstance(real_func, CallableType) or not (
real_func.type_guard or real_func.type_is
):
return {}, {}

if real_func.type_guard is not None:
return {expr: TypeGuardedType(real_func.type_guard)}, {}
else:
assert node.callee.type_is is not None
assert real_func.type_is is not None
return conditional_types_to_typemaps(
expr,
*self.conditional_types_with_intersection(
self.lookup_type(expr),
[TypeRange(node.callee.type_is, is_upper_bound=False)],
[TypeRange(real_func.type_is, is_upper_bound=False)],
expr,
),
)
Expand Down
83 changes: 73 additions & 10 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,16 +2876,37 @@ def infer_overload_return_type(
elif all_same_types([erase_type(typ) for typ in return_types]):
self.chk.store_types(type_maps[0])
return erase_type(return_types[0]), erase_type(inferred_types[0])
else:
return self.check_call(
callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
return self.check_call(
callee=AnyType(TypeOfAny.special_form),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
elif not all_same_type_narrowers(matches):
# This is an example of how overloads can be:
#
# @overload
# def is_int(obj: float) -> TypeGuard[float]: ...
# @overload
# def is_int(obj: int) -> TypeGuard[int]: ...
#
# x: Any
# if is_int(x):
# reveal_type(x) # N: int | float
#
# So, we need to check that special case.
return self.check_call(
callee=self.combine_function_signatures(cast("list[ProperType]", matches)),
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
else:
# Success! No ambiguity; return the first match.
self.chk.store_types(type_maps[0])
Expand Down Expand Up @@ -3100,6 +3121,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
new_args: list[list[Type]] = [[] for _ in range(len(callables[0].arg_types))]
new_kinds = list(callables[0].arg_kinds)
new_returns: list[Type] = []
new_type_guards: list[Type] = []
new_type_narrowers: list[Type] = []

too_complex = False
for target in callables:
Expand All @@ -3126,8 +3149,25 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
for i, arg in enumerate(target.arg_types):
new_args[i].append(arg)
new_returns.append(target.ret_type)
if target.type_guard:
new_type_guards.append(target.type_guard)
if target.type_is:
new_type_narrowers.append(target.type_is)

if new_type_guards and new_type_narrowers:
# They cannot be definined at the same time,
# declaring this function as too complex!
too_complex = True
union_type_guard = None
union_type_is = None
else:
union_type_guard = make_simplified_union(new_type_guards) if new_type_guards else None
union_type_is = (
make_simplified_union(new_type_narrowers) if new_type_narrowers else None
)

union_return = make_simplified_union(new_returns)

if too_complex:
any = AnyType(TypeOfAny.special_form)
return callables[0].copy_modified(
Expand All @@ -3137,6 +3177,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
ret_type=union_return,
variables=variables,
implicit=True,
type_guard=union_type_guard,
type_is=union_type_is,
)

final_args = []
Expand All @@ -3150,6 +3192,8 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call
ret_type=union_return,
variables=variables,
implicit=True,
type_guard=union_type_guard,
type_is=union_type_is,
)

def erased_signature_similarity(
Expand Down Expand Up @@ -6464,6 +6508,25 @@ def all_same_types(types: list[Type]) -> bool:
return all(is_same_type(t, types[0]) for t in types[1:])


def all_same_type_narrowers(types: list[CallableType]) -> bool:
if not types:
return True

type_guards: list[Type] = []
type_narrowers: list[Type] = []

for typ in types:
if typ.type_guard:
type_guards.append(typ.type_guard)
if typ.type_is:
type_narrowers.append(typ.type_is)
if type_guards and type_narrowers:
# Some overloads declare `TypeGuard` and some declare `TypeIs`,
# we cannot handle this in a union.
return False
return all_same_types(type_guards) and all_same_types(type_narrowers)


def merge_typevars_in_callables_by_name(
callables: Sequence[CallableType],
) -> tuple[list[CallableType], list[TypeVarType]]:
Expand Down
56 changes: 56 additions & 0 deletions test-data/unit/check-typeguard.test
Original file line number Diff line number Diff line change
Expand Up @@ -721,3 +721,59 @@ x: object
assert a(x=x)
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testTypeGuardInOverloads]
from typing import Any, overload, Union
from typing_extensions import TypeGuard

@overload
def func1(x: str) -> TypeGuard[str]:
...

@overload
def func1(x: int) -> TypeGuard[int]:
...

def func1(x: Any) -> Any:
return True

def func2(val: Any):
if func1(val):
reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]"
else:
reveal_type(val) # N: Revealed type is "Any"

def func3(val: Union[int, str]):
if func1(val):
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
else:
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"

def func4(val: int):
if func1(val):
reveal_type(val) # N: Revealed type is "builtins.int"
else:
reveal_type(val) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testTypeIsInOverloadsSameReturn]
from typing import Any, overload, Union
from typing_extensions import TypeGuard

@overload
def func1(x: str) -> TypeGuard[str]:
...

@overload
def func1(x: int) -> TypeGuard[str]:
...

def func1(x: Any) -> Any:
return True

def func2(val: Union[int, str]):
if func1(val):
reveal_type(val) # N: Revealed type is "builtins.str"
else:
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]
119 changes: 119 additions & 0 deletions test-data/unit/check-typeis.test
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,122 @@ accept_typeguard(typeis) # E: Argument 1 to "accept_typeguard" has incompatible
accept_typeguard(typeguard)

[builtins fixtures/tuple.pyi]

[case testTypeIsInOverloads]
from typing import Any, overload, Union
from typing_extensions import TypeIs

@overload
def func1(x: str) -> TypeIs[str]:
...

@overload
def func1(x: int) -> TypeIs[int]:
...

def func1(x: Any) -> Any:
return True

def func2(val: Any):
if func1(val):
reveal_type(val) # N: Revealed type is "Union[builtins.str, builtins.int]"
else:
reveal_type(val) # N: Revealed type is "Any"

def func3(val: Union[int, str]):
if func1(val):
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
else:
reveal_type(val)

def func4(val: int):
if func1(val):
reveal_type(val) # N: Revealed type is "builtins.int"
else:
reveal_type(val)
[builtins fixtures/tuple.pyi]

[case testTypeIsInOverloadsSameReturn]
from typing import Any, overload, Union
from typing_extensions import TypeIs

@overload
def func1(x: str) -> TypeIs[str]:
...

@overload
def func1(x: int) -> TypeIs[str]: # type: ignore
...

def func1(x: Any) -> Any:
return True

def func2(val: Union[int, str]):
if func1(val):
reveal_type(val) # N: Revealed type is "builtins.str"
else:
reveal_type(val) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testTypeIsInOverloadsUnionizeError]
from typing import Any, overload, Union
from typing_extensions import TypeIs, TypeGuard

@overload
def func1(x: str) -> TypeIs[str]:
...

@overload
def func1(x: int) -> TypeGuard[int]:
...

def func1(x: Any) -> Any:
return True

def func2(val: Union[int, str]):
if func1(val):
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
else:
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeIsInOverloadsUnionizeError2]
from typing import Any, overload, Union
from typing_extensions import TypeIs, TypeGuard

@overload
def func1(x: int) -> TypeGuard[int]:
...

@overload
def func1(x: str) -> TypeIs[str]:
...

def func1(x: Any) -> Any:
return True

def func2(val: Union[int, str]):
if func1(val):
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
else:
reveal_type(val) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/tuple.pyi]

[case testTypeIsLikeIsDataclass]
from typing import Any, overload, Union, Type
from typing_extensions import TypeIs

class DataclassInstance: ...

@overload
def is_dataclass(obj: type) -> TypeIs[Type[DataclassInstance]]: ...
@overload
def is_dataclass(obj: object) -> TypeIs[Union[DataclassInstance, Type[DataclassInstance]]]: ...

def is_dataclass(obj: Union[type, object]) -> bool:
return False

def func(arg: Any) -> None:
if is_dataclass(arg):
reveal_type(arg) # N: Revealed type is "Union[Type[__main__.DataclassInstance], __main__.DataclassInstance]"
[builtins fixtures/tuple.pyi]
Loading