diff --git a/mypy/meet.py b/mypy/meet.py index ee32f239df8c..a0c54bbe03b3 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -541,6 +541,9 @@ def _type_object_overlap(left: Type, right: Type) -> bool: return False if isinstance(left, CallableType) and isinstance(right, CallableType): + # We run is_callable_compatible in both directions, similar to the logic + # in is_unsafe_overlapping_overload_signatures + # See comments in https://github.com/python/mypy/pull/5476 return is_callable_compatible( left, right, @@ -548,6 +551,14 @@ def _type_object_overlap(left: Type, right: Type) -> bool: is_proper_subtype=False, ignore_pos_arg_names=not overlap_for_overloads, allow_partial_overlap=True, + ) or is_callable_compatible( + right, + left, + is_compat=_is_overlapping_types, + is_proper_subtype=False, + ignore_pos_arg_names=not overlap_for_overloads, + check_args_covariantly=True, + allow_partial_overlap=True, ) call = None diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 6562f541d73b..ec9af3e66934 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -8,7 +8,7 @@ from mypy.erasetype import erase_type, remove_instance_last_known_values from mypy.indirection import TypeIndirectionVisitor from mypy.join import join_types -from mypy.meet import meet_types, narrow_declared_type +from mypy.meet import is_overlapping_types, meet_types, narrow_declared_type from mypy.nodes import ( ARG_NAMED, ARG_OPT, @@ -645,6 +645,20 @@ def assert_simplified_union(self, original: list[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) assert_equal(make_simplified_union(list(reversed(original))), union) + def test_generic_callable_overlap_is_symmetric(self) -> None: + any_type = AnyType(TypeOfAny.from_omitted_generics) + outer_t = TypeVarType("T", "T", TypeVarId(1), [], self.fx.o, any_type) + outer_s = TypeVarType("S", "S", TypeVarId(2), [], self.fx.o, any_type) + generic_t = TypeVarType("T", "T", TypeVarId(-1), [], self.fx.o, any_type) + + callable_type = CallableType([outer_t], [ARG_POS], [None], outer_s, self.fx.function) + generic_identity = CallableType( + [generic_t], [ARG_POS], [None], generic_t, self.fx.function, variables=[generic_t] + ) + + assert is_overlapping_types(callable_type, generic_identity) + assert is_overlapping_types(generic_identity, callable_type) + # Helpers def tuple(self, *a: Type) -> TupleType: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 648a4b001da5..fcf8e08ef548 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -3995,6 +3995,25 @@ def f2(func: Callable[..., T], arg: str) -> T: [builtins fixtures/primitives.pyi] +[case testNarrowGenericCallableEquality] +# flags: --strict-equality --warn-unreachable +from typing import Callable, TypeVar + +S = TypeVar("S") +T = TypeVar("T") + +def identity(x: T) -> T: + return x + +def msg(cmp_property: Callable[[T], S]) -> None: + if cmp_property == identity: + # TODO: the swapping of these reveal's is not ideal + reveal_type(cmp_property) # N: Revealed type is "def [T] (x: T`-1) -> T`-1" + reveal_type(identity) # N: Revealed type is "def (T`-1) -> S`-2" + return +[builtins fixtures/primitives.pyi] + + [case testPropagatedParentNarrowingMeet] # flags: --strict-equality --warn-unreachable from __future__ import annotations