diff --git a/mypy/mro.py b/mypy/mro.py index f34f3fa0c46d..8d839b817708 100644 --- a/mypy/mro.py +++ b/mypy/mro.py @@ -3,7 +3,7 @@ from typing import Callable from mypy.nodes import TypeInfo -from mypy.types import Instance +from mypy.types import Instance, ProperType, TypeVarLikeType from mypy.typestate import type_state @@ -15,11 +15,61 @@ def calculate_mro(info: TypeInfo, obj_type: Callable[[], Instance] | None = None mro = linearize_hierarchy(info, obj_type) assert mro, f"Could not produce a MRO at all for {info}" info.mro = mro + fill_mapped_type_vars(info) # The property of falling back to Any is inherited. info.fallback_to_any = any(baseinfo.fallback_to_any for baseinfo in info.mro) type_state.reset_all_subtype_caches_for(info) +def fill_mapped_type_vars(info: TypeInfo) -> None: + """Calculates the final TypeVar value from inheritor to parent. + + class A[T1]: + # mapped_type_vars = {T1: str} + + class B[T2]: + # mapped_type_vars = {T2: T4} + + class C[T3](B[T3]): + # mapped_type_vars = {T3: T4} + + class D[T4](C[T4], A[str]): + # mapped_type_vars = {} + """ + bases = {b.type: b for b in info.bases} + + for subinfo in filter(lambda x: x.is_generic, info.mro): + if base_info := bases.get(subinfo): + subinfo.mapped_type_vars = { + tv: actual_type for tv, actual_type in zip(subinfo.defn.type_vars, base_info.args) + } + info.mapped_type_vars |= subinfo.mapped_type_vars + + final_mapped_type_vars: dict[TypeVarLikeType, ProperType] = {} + for k, v in info.mapped_type_vars.items(): + final_mapped_type_vars[k] = _resolve_mappped_vars(info.mapped_type_vars, v) + + for subinfo in filter(lambda x: x.is_generic, info.mro): + _resolve_info_type_vars(subinfo, final_mapped_type_vars) + + +def _resolve_info_type_vars( + info: TypeInfo, mapped_type_vars: dict[TypeVarLikeType, ProperType] +) -> None: + final_mapped_type_vars = {} + for tv in info.defn.type_vars: + final_mapped_type_vars[tv] = _resolve_mappped_vars(mapped_type_vars, tv) + info.mapped_type_vars = final_mapped_type_vars + + +def _resolve_mappped_vars( + mapped_type_vars: dict[TypeVarLikeType, ProperType], key: ProperType +) -> ProperType: + if key in mapped_type_vars: + return _resolve_mappped_vars(mapped_type_vars, mapped_type_vars[key]) + return key + + class MroError(Exception): """Raised if a consistent mro cannot be determined for a class.""" diff --git a/mypy/nodes.py b/mypy/nodes.py index 9e26103e2f58..c3304e391435 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2944,6 +2944,7 @@ class is generic then it will be a type constructor of higher kind. "fallback_to_any", "meta_fallback_to_any", "type_vars", + "mapped_type_vars", "has_param_spec_type", "bases", "_promote", @@ -3048,6 +3049,8 @@ class is generic then it will be a type constructor of higher kind. # Generic type variable names (full names) type_vars: list[str] + # Map of current class TypeVars and Inheritor specified type to calculate real type in MRO + mapped_type_vars: dict[mypy.types.TypeVarLikeType, mypy.types.ProperType] # Whether this class has a ParamSpec type variable has_param_spec_type: bool @@ -3139,6 +3142,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None self.defn = defn self.module_name = module_name self.type_vars = [] + self.mapped_type_vars = {} self.has_param_spec_type = False self.has_type_var_tuple_type = False self.bases = []