diff --git a/mypy/checkmember.py b/mypy/checkmember.py index c81b3fbe4f7e0..1c38bb4f00dcd 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -6,7 +6,11 @@ from mypy import meet, message_registry, subtypes from mypy.erasetype import erase_typevars -from mypy.expandtype import expand_self_type, expand_type_by_instance, freshen_function_type_vars +from mypy.expandtype import ( + expand_self_type, + expand_type_by_instance, + freshen_all_functions_type_vars, +) from mypy.maptype import map_instance_to_supertype from mypy.messages import MessageBuilder from mypy.nodes import ( @@ -66,6 +70,7 @@ get_proper_type, has_type_vars, ) +from mypy.typetraverser import TypeTraverserVisitor if TYPE_CHECKING: # import for forward declaration only import mypy.checker @@ -311,7 +316,7 @@ def analyze_instance_member_access( if mx.is_lvalue: mx.msg.cant_assign_to_method(mx.context) signature = function_type(method, mx.named_type("builtins.function")) - signature = freshen_function_type_vars(signature) + signature = freshen_all_functions_type_vars(signature) if name == "__new__" or method.is_static: # __new__ is special and behaves like a static method -- don't strip # the first argument. @@ -329,7 +334,7 @@ def analyze_instance_member_access( # Since generic static methods should not be allowed. typ = map_instance_to_supertype(typ, method.info) member_type = expand_type_by_instance(signature, typ) - freeze_type_vars(member_type) + freeze_all_type_vars(member_type) return member_type else: # Not a method. @@ -727,11 +732,13 @@ def analyze_var( mx.msg.read_only_property(name, itype.type, mx.context) if mx.is_lvalue and var.is_classvar: mx.msg.cant_assign_to_classvar(name, mx.context) + t = freshen_all_functions_type_vars(typ) if not (mx.is_self or mx.is_super) or supported_self_type( get_proper_type(mx.original_type) ): - typ = expand_self_type(var, typ, mx.original_type) - t = get_proper_type(expand_type_by_instance(typ, itype)) + t = expand_self_type(var, t, mx.original_type) + t = get_proper_type(expand_type_by_instance(t, itype)) + freeze_all_type_vars(t) result: Type = t typ = get_proper_type(typ) if ( @@ -759,13 +766,13 @@ def analyze_var( # In `x.f`, when checking `x` against A1 we assume x is compatible with A # and similarly for B1 when checking against B dispatched_type = meet.meet_types(mx.original_type, itype) - signature = freshen_function_type_vars(functype) + signature = freshen_all_functions_type_vars(functype) signature = check_self_arg( signature, dispatched_type, var.is_classmethod, mx.context, name, mx.msg ) signature = bind_self(signature, mx.self_type, var.is_classmethod) expanded_signature = expand_type_by_instance(signature, itype) - freeze_type_vars(expanded_signature) + freeze_all_type_vars(expanded_signature) if var.is_property: # A property cannot have an overloaded type => the cast is fine. assert isinstance(expanded_signature, CallableType) @@ -788,16 +795,14 @@ def analyze_var( return result -def freeze_type_vars(member_type: Type) -> None: - if not isinstance(member_type, ProperType): - return - if isinstance(member_type, CallableType): - for v in member_type.variables: +def freeze_all_type_vars(member_type: Type) -> None: + member_type.accept(FreezeTypeVarsVisitor()) + + +class FreezeTypeVarsVisitor(TypeTraverserVisitor): + def visit_callable_type(self, t: CallableType) -> None: + for v in t.variables: v.id.meta_level = 0 - if isinstance(member_type, Overloaded): - for it in member_type.items: - for v in it.variables: - v.id.meta_level = 0 def lookup_member_var_or_accessor(info: TypeInfo, name: str, is_lvalue: bool) -> SymbolNode | None: @@ -1131,11 +1136,11 @@ class B(A[str]): pass if isinstance(t, CallableType): tvars = original_vars if original_vars is not None else [] if is_classmethod: - t = freshen_function_type_vars(t) + t = freshen_all_functions_type_vars(t) t = bind_self(t, original_type, is_classmethod=True) assert isuper is not None t = cast(CallableType, expand_type_by_instance(t, isuper)) - freeze_type_vars(t) + freeze_all_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) elif isinstance(t, Overloaded): return Overloaded( diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 5a56857e11145..f8206d152f9b2 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -3,6 +3,7 @@ from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload from mypy.nodes import ARG_STAR, Var +from mypy.type_visitor import TypeTranslator from mypy.types import ( AnyType, CallableType, @@ -124,6 +125,26 @@ def freshen_function_type_vars(callee: F) -> F: return cast(F, fresh_overload) +T = TypeVar("T", bound=Type) + + +def freshen_all_functions_type_vars(t: T) -> T: + result = t.accept(FreshenCallableVisitor()) + assert isinstance(result, type(t)) + return result + + +class FreshenCallableVisitor(TypeTranslator): + def visit_callable_type(self, t: CallableType) -> Type: + result = super().visit_callable_type(t) + assert isinstance(result, ProperType) and isinstance(result, CallableType) + return freshen_function_type_vars(result) + + def visit_type_alias_type(self, t: TypeAliasType) -> Type: + # Same as for ExpandTypeVisitor + return t.copy_modified(args=[arg.accept(self) for arg in t.args]) + + class ExpandTypeVisitor(TypeVisitor[Type]): """Visitor that substitutes type variables with values.""" diff --git a/mypy/typestate.py b/mypy/typestate.py index a5d65c4b4ea39..7398f0d7f5242 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -10,7 +10,7 @@ from mypy.nodes import TypeInfo from mypy.server.trigger import make_trigger -from mypy.types import Instance, Type, get_proper_type +from mypy.types import Instance, Type, TypeVarId, get_proper_type # Represents that the 'left' instance is a subtype of the 'right' instance SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance] @@ -275,3 +275,4 @@ def reset_global_state() -> None: """ TypeState.reset_all_subtype_caches() TypeState.reset_protocol_deps() + TypeVarId.next_raw_id = 1 diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 7df52b60fc0bc..04108dded7234 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -1544,7 +1544,7 @@ class C(Generic[T]): reveal_type(C.F(17).foo()) # N: Revealed type is "builtins.int" reveal_type(C("").F(17).foo()) # N: Revealed type is "builtins.int" reveal_type(C.F) # N: Revealed type is "def [K] (k: K`1) -> __main__.C.F[K`1]" -reveal_type(C("").F) # N: Revealed type is "def [K] (k: K`1) -> __main__.C.F[K`1]" +reveal_type(C("").F) # N: Revealed type is "def [K] (k: K`6) -> __main__.C.F[K`6]" -- Callable subtyping with generic functions @@ -2580,3 +2580,68 @@ class Bar(Foo[AnyStr]): [out] main:10: error: Argument 1 to "method1" of "Foo" has incompatible type "str"; expected "AnyStr" main:10: error: Argument 2 to "method1" of "Foo" has incompatible type "bytes"; expected "AnyStr" + +[case testTypeVariableClashVar] +from typing import Generic, TypeVar, Callable + +T = TypeVar("T") +R = TypeVar("R") +class C(Generic[R]): + x: Callable[[T], R] + +def func(x: C[R]) -> R: + return x.x(42) # OK + +[case testTypeVariableClashVarTuple] +from typing import Generic, TypeVar, Callable, Tuple + +T = TypeVar("T") +R = TypeVar("R") +class C(Generic[R]): + x: Callable[[T], Tuple[R, T]] + +def func(x: C[R]) -> R: + if bool(): + return x.x(42)[0] # OK + else: + return x.x(42)[1] # E: Incompatible return value type (got "int", expected "R") +[builtins fixtures/tuple.pyi] + +[case testTypeVariableClashMethod] +from typing import Generic, TypeVar, Callable + +T = TypeVar("T") +R = TypeVar("R") +class C(Generic[R]): + def x(self) -> Callable[[T], R]: ... + +def func(x: C[R]) -> R: + return x.x()(42) # OK + +[case testTypeVariableClashMethodTuple] +from typing import Generic, TypeVar, Callable, Tuple + +T = TypeVar("T") +R = TypeVar("R") +class C(Generic[R]): + def x(self) -> Callable[[T], Tuple[R, T]]: ... + +def func(x: C[R]) -> R: + if bool(): + return x.x()(42)[0] # OK + else: + return x.x()(42)[1] # E: Incompatible return value type (got "int", expected "R") +[builtins fixtures/tuple.pyi] + +[case testTypeVariableClashVarSelf] +from typing import Self, TypeVar, Generic, Callable + +T = TypeVar("T") +S = TypeVar("S") + +class C(Generic[T]): + x: Callable[[S], Self] + y: T + +def foo(x: C[T]) -> T: + return x.x(42).y # OK diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index a7dc41a2ff867..57e10a764e185 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -1654,7 +1654,7 @@ class C: def bar(self) -> Self: ... foo: Callable[[S, Self], Tuple[Self, S]] -reveal_type(C().foo) # N: Revealed type is "def [S] (S`-1, __main__.C) -> Tuple[__main__.C, S`-1]" +reveal_type(C().foo) # N: Revealed type is "def [S] (S`1, __main__.C) -> Tuple[__main__.C, S`1]" reveal_type(C().foo(42, C())) # N: Revealed type is "Tuple[__main__.C, builtins.int]" class This: ... [builtins fixtures/tuple.pyi]