Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,8 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi):

# A helper state to produce unique temporary names on demand.
_unique_id: int
# Fake concrete type used when checking variance
_variance_dummy_type: Instance | None

def __init__(
self,
Expand Down Expand Up @@ -469,6 +471,7 @@ def __init__(

self.pattern_checker = PatternChecker(self, self.msg, self.plugin, options)
self._unique_id = 0
self._variance_dummy_type = None

@property
def expr_checker(self) -> mypy.checkexpr.ExpressionChecker:
Expand Down Expand Up @@ -2918,17 +2921,19 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
info = defn.info
object_type = Instance(info.mro[-1], [])
tvars = info.defn.type_vars
if self._variance_dummy_type is None:
_, dummy_info = self.make_fake_typeinfo("<dummy>", "Dummy", "Dummy", [])
self._variance_dummy_type = Instance(dummy_info, [])
dummy = self._variance_dummy_type
for i, tvar in enumerate(tvars):
if not isinstance(tvar, TypeVarType):
# Variance of TypeVarTuple and ParamSpec is underspecified by PEPs.
continue
up_args: list[Type] = [
object_type if i == j else AnyType(TypeOfAny.special_form)
for j, _ in enumerate(tvars)
object_type if i == j else dummy.copy_modified() for j, _ in enumerate(tvars)
]
down_args: list[Type] = [
UninhabitedType() if i == j else AnyType(TypeOfAny.special_form)
for j, _ in enumerate(tvars)
UninhabitedType() if i == j else dummy.copy_modified() for j, _ in enumerate(tvars)
]
up, down = Instance(info, up_args), Instance(info, down_args)
# TODO: add advanced variance checks for recursive protocols
Expand Down
13 changes: 13 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,19 @@ main:16: note: def meth(self, x: int) -> int
main:16: note: @overload
main:16: note: def meth(self, x: bytes) -> str

[case testProtocolWithMultiContravariantTypeVarOverloads]
from typing import overload, Protocol, TypeVar

T1 = TypeVar("T1", contravariant=True)
T2 = TypeVar("T2", contravariant=True)

class A(Protocol[T1, T2]):
@overload
def method(self, a: T1) -> None: ...
@overload
def method(self, a: T2) -> None: ...


-- Join and meet with protocol types
-- ---------------------------------

Expand Down