Skip to content
Closed
81 changes: 63 additions & 18 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2739,26 +2739,57 @@ def check_protocol_variance(self, defn: ClassDef) -> None:
if expected != tvar.variance:
self.msg.bad_proto_variance(tvar.variance, tvar.name, expected, defn)

def get_parameterized_base_classes(self, typ: TypeInfo) -> list[Instance]:
"""Build an MRO-like structure with generic type args substituted.

Excludes the class itself.

When several bases have a common ancestor, includes an :class:`Instance`
for each param.
"""
bases = []
for parent in typ.mro[1:]:
if parent.is_generic():
for base in typ.bases:
if parent in base.type.mro:
bases.append(map_instance_to_supertype(base, parent))
else:
bases.append(Instance(parent, []))
return bases

def check_multiple_inheritance(self, typ: TypeInfo) -> None:
"""Check for multiple inheritance related errors."""
if len(typ.bases) <= 1:
# No multiple inheritance.
return

# Verify that inherited attributes are compatible.
mro = typ.mro[1:]
for i, base in enumerate(mro):
typed_mro = self.get_parameterized_base_classes(typ)
# If the first MRO entry is compatible with everything following, we don't need
# (and shouldn't) compare further pairs
# (see testMultipleInheritanceExplcitDiamondResolution)
seen_names = set()
for i, base in enumerate(typed_mro):
# Attributes defined in both the type and base are skipped.
# Normal checks for attribute compatibility should catch any problems elsewhere.
non_overridden_attrs = base.names.keys() - typ.names.keys()
# Sort for consistent messages order.
non_overridden_attrs = sorted(typed_mro[i].type.names - typ.names.keys())
for name in non_overridden_attrs:
if is_private(name):
continue
for base2 in mro[i + 1 :]:
if name in seen_names:
continue
for base2 in typed_mro[i + 1 :]:
# We only need to check compatibility of attributes from classes not
# in a subclass relationship. For subclasses, normal (single inheritance)
# checks suffice (these are implemented elsewhere).
if name in base2.names and base2 not in base.mro:
if name in base2.type.names and not is_subtype(
base, base2, ignore_promotions=True
):
# If base1 already inherits from base2 with correct type args,
# we have reported errors if any. Avoid reporting them again.
self.check_compatibility(name, base, base2, typ)
seen_names.add(name)

def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
if sym.type is not None:
Expand All @@ -2783,8 +2814,23 @@ def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
# TODO: handle more node kinds here.
return None

def attribute_type_from_base(
self, name: str, base: Instance
) -> tuple[ProperType | None, SymbolTableNode]:
"""For a NameExpr that is part of a class, walk all base classes and try
to find the first class that defines a Type for the same name."""
base_var = base.type[name]
base_type = self.determine_type_of_member(base_var)
if base_type is None:
return None, base_var

if not has_no_typevars(base_type):
base_type = expand_type_by_instance(base_type, base)

return get_proper_type(base_type), base_var

def check_compatibility(
self, name: str, base1: TypeInfo, base2: TypeInfo, ctx: TypeInfo
self, name: str, base1: Instance, base2: Instance, ctx: TypeInfo
) -> None:
"""Check if attribute name in base1 is compatible with base2 in multiple inheritance.

Expand All @@ -2809,10 +2855,9 @@ class C(B, A[int]): ... # this is unsafe because...
if name in ("__init__", "__new__", "__init_subclass__"):
# __init__ and friends can be incompatible -- it's a special case.
return
first = base1.names[name]
second = base2.names[name]
first_type = get_proper_type(self.determine_type_of_member(first))
second_type = get_proper_type(self.determine_type_of_member(second))

first_type, first = self.attribute_type_from_base(name, base1)
second_type, second = self.attribute_type_from_base(name, base2)

# TODO: use more principled logic to decide is_subtype() vs is_equivalent().
# We should rely on mutability of superclass node, not on types being Callable.
Expand All @@ -2822,7 +2867,7 @@ class C(B, A[int]): ... # this is unsafe because...
if isinstance(first_type, Instance):
call = find_member("__call__", first_type, first_type, is_operator=True)
if call and isinstance(second_type, FunctionLike):
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
second_sig = self.bind_and_map_method(second, second_type, ctx, base2.type)
ok = is_subtype(call, second_sig, ignore_pos_arg_names=True)
elif isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
if first_type.is_type_obj() and second_type.is_type_obj():
Expand All @@ -2834,8 +2879,8 @@ class C(B, A[int]): ... # this is unsafe because...
)
else:
# First bind/map method types when necessary.
first_sig = self.bind_and_map_method(first, first_type, ctx, base1)
second_sig = self.bind_and_map_method(second, second_type, ctx, base2)
first_sig = self.bind_and_map_method(first, first_type, ctx, base1.type)
second_sig = self.bind_and_map_method(second, second_type, ctx, base2.type)
ok = is_subtype(first_sig, second_sig, ignore_pos_arg_names=True)
elif first_type and second_type:
if isinstance(first.node, Var):
Expand All @@ -2844,7 +2889,7 @@ class C(B, A[int]): ... # this is unsafe because...
second_type = expand_self_type(second.node, second_type, fill_typevars(ctx))
ok = is_equivalent(first_type, second_type)
if not ok:
second_node = base2[name].node
second_node = base2.type[name].node
if (
isinstance(second_type, FunctionLike)
and second_node is not None
Expand All @@ -2854,22 +2899,22 @@ class C(B, A[int]): ... # this is unsafe because...
ok = is_subtype(first_type, second_type)
else:
if first_type is None:
self.msg.cannot_determine_type_in_base(name, base1.name, ctx)
self.msg.cannot_determine_type_in_base(name, base1.type.name, ctx)
if second_type is None:
self.msg.cannot_determine_type_in_base(name, base2.name, ctx)
self.msg.cannot_determine_type_in_base(name, base2.type.name, ctx)
ok = True
# Final attributes can never be overridden, but can override
# non-final read-only attributes.
if is_final_node(second.node) and not is_private(name):
self.msg.cant_override_final(name, base2.name, ctx)
self.msg.cant_override_final(name, base2.type.name, ctx)
if is_final_node(first.node):
self.check_if_final_var_override_writable(name, second.node, ctx)
# Some attributes like __slots__ and __deletable__ are special, and the type can
# vary across class hierarchy.
if isinstance(second.node, Var) and second.node.allow_incompatible_override:
ok = True
if not ok:
self.msg.base_class_definitions_incompatible(name, base1, base2, ctx)
self.msg.base_class_definitions_incompatible(name, base1.type, base2.type, ctx)

def check_metaclass_compatibility(self, typ: TypeInfo) -> None:
"""Ensures that metaclasses of all parent types are compatible."""
Expand Down
50 changes: 45 additions & 5 deletions test-data/unit/check-generic-subtyping.test
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ x1: X1[str, int]
reveal_type(list(x1)) # N: Revealed type is "builtins.list[builtins.int]"
reveal_type([*x1]) # N: Revealed type is "builtins.list[builtins.int]"

class X2(Generic[T, U], Iterator[U], Mapping[T, U]):
class X2(Generic[T, U], Iterator[U], Mapping[T, U]): # E: Definition of "__iter__" in base class "Iterable" is incompatible with definition in base class "Iterable"
pass

x2: X2[str, int]
Expand Down Expand Up @@ -1017,10 +1017,7 @@ x1: X1[str, int]
reveal_type(iter(x1)) # N: Revealed type is "typing.Iterator[builtins.int]"
reveal_type({**x1}) # N: Revealed type is "builtins.dict[builtins.int, builtins.str]"

# Some people would expect this to raise an error, but this currently does not:
# `Mapping` has `Iterable[U]` base class, `X2` has direct `Iterable[T]` base class.
# It would be impossible to define correct `__iter__` method for incompatible `T` and `U`.
class X2(Generic[T, U], Mapping[U, T], Iterable[T]):
class X2(Generic[T, U], Mapping[U, T], Iterable[T]): # E: Definition of "__iter__" in base class "Iterable" is incompatible with definition in base class "Iterable"
pass

x2: X2[str, int]
Expand Down Expand Up @@ -1065,3 +1062,46 @@ class F(E[T_co], Generic[T_co]): ... # E: Variance of TypeVar "T_co" incompatib

class G(Generic[T]): ...
class H(G[T_contra], Generic[T_contra]): ... # E: Variance of TypeVar "T_contra" incompatible with variance in parent type

[case testMultipleInheritanceCompatibleTypeVar]
from typing import Generic, TypeVar

T = TypeVar("T")
U = TypeVar("U")

class A(Generic[T]):
x: T
def fn(self, t: T) -> None: ...

class A2(A[T]):
y: str
z: str

class B(Generic[T]):
x: T
def fn(self, t: T) -> None: ...

class C1(A2[str], B[str]): pass
class C2(A2[str], B[int]): pass # E: Definition of "fn" in base class "A" is incompatible with definition in base class "B" \
# E: Definition of "x" in base class "A" is incompatible with definition in base class "B"
class C3(A2[T], B[T]): pass
class C4(A2[U], B[U]): pass
class C5(A2[U], B[T]): pass # E: Definition of "fn" in base class "A" is incompatible with definition in base class "B" \
# E: Definition of "x" in base class "A" is incompatible with definition in base class "B"

[builtins fixtures/tuple.pyi]

[case testMultipleInheritanceNestedTypeVarPropagation]
from typing import Generic, TypeVar

T = TypeVar("T")

class A(Generic[T]):
foo: T
class B(A[str]): ...
class C(B): ...
class D(C): ...

class Bad(D, A[T]): ... # E: Definition of "foo" in base class "A" is incompatible with definition in base class "A"
class Good(D, A[str]): ... # OK
[builtins fixtures/tuple.pyi]
45 changes: 44 additions & 1 deletion test-data/unit/check-multiple-inheritance.test
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,6 @@ class D2(B[Union[int, str]], C2): ...
class D3(C2, B[str]): ...
class D4(B[str], C2): ... # E: Definition of "foo" in base class "A" is incompatible with definition in base class "C2"


[case testMultipleInheritanceOverridingOfFunctionsWithCallableInstances]
from typing import Any, Callable

Expand Down Expand Up @@ -706,3 +705,47 @@ class C34(B3, B4): ...
class C41(B4, B1): ...
class C42(B4, B2): ...
class C43(B4, B3): ...

[case testMultipleInheritanceTransitive]
class A:
def fn(self, x: int) -> None: ...
class B(A): ...
class C(A):
def fn(self, x: "int | str") -> None: ...
class D(B, C): ...

[case testMultipleInheritanceCompatErrorPropagation]
class A:
foo: bytes
class B(A):
foo: str # type: ignore[assignment]

class Ok(B, A): pass

class C(A): pass
class Ok2(B, C): pass

[case testMultipleInheritanceExplcitDiamondResolution]
class A:
class M:
pass

class B0(A):
class M(A.M):
pass

class B1(A):
class M(A.M):
pass

class C(B0,B1):
class M(B0.M, B1.M):
pass

class D0(B0):
pass
class D1(B1):
pass

class D(D0,D1,C):
pass
Loading