From bd485c2977d888d6387efe412612d3b8c0100c5e Mon Sep 17 00:00:00 2001 From: Evgeniy Slobodkin Date: Sat, 20 Apr 2024 21:24:38 +0300 Subject: [PATCH] fix: TypeGuard becomes bool instead of Any when passed as TypeVar --- mypy/constraints.py | 8 ++++++-- test-data/unit/check-typeguard.test | 9 +++++++++ test-data/unit/check-typeis.test | 9 +++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index cdfa39ac45f3..967f6d590c5b 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -1018,13 +1018,17 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: param_spec = template.param_spec() template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type + bool_type = UnionType( + [LiteralType(True, cactual_ret_type), LiteralType(False, cactual_ret_type)] # type: ignore[arg-type] + ) + if template.type_guard is not None and cactual.type_guard is not None: template_ret_type = template.type_guard cactual_ret_type = cactual.type_guard elif template.type_guard is not None: template_ret_type = AnyType(TypeOfAny.special_form) elif cactual.type_guard is not None: - cactual_ret_type = AnyType(TypeOfAny.special_form) + cactual_ret_type = bool_type if template.type_is is not None and cactual.type_is is not None: template_ret_type = template.type_is @@ -1032,7 +1036,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]: elif template.type_is is not None: template_ret_type = AnyType(TypeOfAny.special_form) elif cactual.type_is is not None: - cactual_ret_type = AnyType(TypeOfAny.special_form) + cactual_ret_type = bool_type res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) diff --git a/test-data/unit/check-typeguard.test b/test-data/unit/check-typeguard.test index 27b88553fb43..e7a8eac4f043 100644 --- a/test-data/unit/check-typeguard.test +++ b/test-data/unit/check-typeguard.test @@ -87,6 +87,15 @@ def main(a: Tuple[T, ...]): reveal_type(a) # N: Revealed type is "Tuple[T`-1, T`-1]" [builtins fixtures/tuple.pyi] +[case testTypeGuardPassedAsTypeVarIsBool] +from typing import Callable, TypeVar +from typing_extensions import TypeGuard +T = TypeVar('T') +def is_str(x: object) -> TypeGuard[str]: ... +def main(f: Callable[[object], T]) -> T: ... +reveal_type(main(is_str)) # N: Revealed type is "builtins.bool" +[builtins fixtures/tuple.pyi] + [case testTypeGuardNonOverlapping] from typing import List from typing_extensions import TypeGuard diff --git a/test-data/unit/check-typeis.test b/test-data/unit/check-typeis.test index 6b96845504ab..2372f990fda1 100644 --- a/test-data/unit/check-typeis.test +++ b/test-data/unit/check-typeis.test @@ -104,6 +104,15 @@ def main(x: object, type_check_func: Callable[[object], TypeIs[T]]) -> T: reveal_type(main("a", is_str)) # N: Revealed type is "builtins.str" [builtins fixtures/exception.pyi] +[case testTypeIsPassedAsTypeVarIsBool] +from typing import Callable, TypeVar +from typing_extensions import TypeIs +T = TypeVar('T') +def is_str(x: object) -> TypeIs[str]: pass +def main(f: Callable[[object], T]) -> T: pass +reveal_type(main(is_str)) # N: Revealed type is "builtins.bool" +[builtins fixtures/tuple.pyi] + [case testTypeIsUnionIn] from typing import Union from typing_extensions import TypeIs