diff --git a/mypy/checker.py b/mypy/checker.py index 2dd7f10e6f35..c1cb29652e88 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -395,6 +395,8 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi): # A helper state to produce unique temporary names on demand. _unique_id: int + # Fake concrete type used when checking variance + _variance_dummy_type: Instance | None def __init__( self, @@ -469,6 +471,7 @@ def __init__( self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options) self._unique_id = 0 + self._variance_dummy_type = None @property def expr_checker(self) -> mypy.checkexpr.ExpressionChecker: @@ -2918,17 +2921,19 @@ def check_protocol_variance(self, defn: ClassDef) -> None: info = defn.info object_type = Instance(info.mro[-1], []) tvars = info.defn.type_vars + if self._variance_dummy_type is None: + _, dummy_info = self.make_fake_typeinfo("", "Dummy", "Dummy", []) + self._variance_dummy_type = Instance(dummy_info, []) + dummy = self._variance_dummy_type for i, tvar in enumerate(tvars): if not isinstance(tvar, TypeVarType): # Variance of TypeVarTuple and ParamSpec is underspecified by PEPs. continue up_args: list[Type] = [ - object_type if i == j else AnyType(TypeOfAny.special_form) - for j, _ in enumerate(tvars) + object_type if i == j else dummy.copy_modified() for j, _ in enumerate(tvars) ] down_args: list[Type] = [ - UninhabitedType() if i == j else AnyType(TypeOfAny.special_form) - for j, _ in enumerate(tvars) + UninhabitedType() if i == j else dummy.copy_modified() for j, _ in enumerate(tvars) ] up, down = Instance(info, up_args), Instance(info, down_args) # TODO: add advanced variance checks for recursive protocols diff --git a/test-data/unit/check-protocols.test b/test-data/unit/check-protocols.test index ae6f60355512..e7971cd5b5d8 100644 --- a/test-data/unit/check-protocols.test +++ b/test-data/unit/check-protocols.test @@ -1373,6 +1373,19 @@ main:16: note: def meth(self, x: int) -> int main:16: note: @overload main:16: note: def meth(self, x: bytes) -> str +[case testProtocolWithMultiContravariantTypeVarOverloads] +from typing import overload, Protocol, TypeVar + +T1 = TypeVar("T1", contravariant=True) +T2 = TypeVar("T2", contravariant=True) + +class A(Protocol[T1, T2]): + @overload + def method(self, a: T1) -> None: ... + @overload + def method(self, a: T2) -> None: ... + + -- Join and meet with protocol types -- ---------------------------------