Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inference for overloaded __call__ with generic self #16053

Merged
merged 15 commits into from Sep 19, 2023
13 changes: 6 additions & 7 deletions mypy/checkmember.py
Expand Up @@ -330,13 +330,12 @@ def analyze_instance_member_access(
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if name != "__call__":
# TODO: use proper treatment of special methods on unions instead
# of this hack here and below (i.e. mx.self_type).
dispatched_type = meet.meet_types(mx.original_type, typ)
signature = check_self_arg(
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
)
# TODO: use proper treatment of special methods on unions instead
# of this hack here and below (i.e. mx.self_type).
dispatched_type = meet.meet_types(mx.original_type, typ)
signature = check_self_arg(
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
)
signature = bind_self(signature, mx.self_type, is_classmethod=method.is_class)
# TODO: should we skip these steps for static methods as well?
# Since generic static methods should not be allowed.
Expand Down
11 changes: 8 additions & 3 deletions mypy/subtypes.py
Expand Up @@ -463,10 +463,15 @@ def visit_instance(self, left: Instance) -> bool:
assert unpacked.type.fullname == "builtins.tuple"
if isinstance(get_proper_type(unpacked.args[0]), AnyType):
return not self.proper_subtype
if mapped.type.fullname == "builtins.tuple" and isinstance(
get_proper_type(mapped.args[0]), AnyType
):
if all(isinstance(get_proper_type(a), AnyType) for a in mapped.args):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should actually be combined with the if just above. I just realised that the if I added is incomplete, in fact if all type argument are either Any or *tuple[Any, ...] we are still good. But currently we are only considering only plain *tuple[Any, ...] or e.g. Any, Any, Any, while missing Any, *tuple[Any, ...], Any, which should be equally fine (note you don't need to check number of unpacks, it is already validated during semantic analyzis).

Note this also equally applies to the equivalent special case for plain instances I added around line 538. You may fix that as well if you want while you are at it.

return not self.proper_subtype
if mapped.type.tuple_type:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This additional branch looks like a hack. We shouldn't need this. I think the real bug is that we are passing incomplete information when accessing __call__ on a tuple type. This diff I think is a (much) more cleaner fix:

--- a/mypy/checkexpr.py
+++ b/mypy/checkexpr.py
@@ -1476,6 +1476,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
         callable_node: Expression | None = None,
         callable_name: str | None = None,
         object_type: Type | None = None,
+        original_type: Type | None = None,
     ) -> tuple[Type, Type]:
         """Type check a call.
 
@@ -1538,7 +1539,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
                 is_super=False,
                 is_operator=True,
                 msg=self.msg,
-                original_type=callee,
+                original_type=original_type or callee,
                 chk=self.chk,
                 in_literal_context=self.is_literal_context(),
             )
@@ -1579,6 +1580,7 @@ class ExpressionChecker(ExpressionVisitor[Type]):
                 callable_node,
                 callable_name,
                 object_type,
+                original_type=callee,
             )
         else:
             return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

# similar to logic inside map_instance_to_supertype
tuple_type = expand_type_by_instance(mapped.type.tuple_type, mapped)
assert isinstance(tuple_type, TupleType)
return self._is_subtype(
tuple_type, right.copy_modified(fallback=tuple_type.partial_fallback)
) and self._is_subtype(left, mypy.typeops.tuple_fallback(right))
return False
if isinstance(right, TypeVarTupleType):
# tuple[Any, ...] is like Any in the world of tuples (see special case above).
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-overloading.test
Expand Up @@ -6650,3 +6650,27 @@ def d(x: int) -> int: ...
def d(f: int, *, x: int) -> str: ...
def d(*args, **kwargs): ...
[builtins fixtures/tuple.pyi]

[case testOverloadCallableGenericSelf]
from typing import Any, TypeVar, Generic, overload, reveal_type

T = TypeVar("T")

class MyCallable(Generic[T]):
def __init__(self, t: T):
self.t = t

@overload
def __call__(self: "MyCallable[int]") -> str: ...
@overload
def __call__(self: "MyCallable[str]") -> int: ...
def __call__(self): ...

c = MyCallable(5)
reveal_type(c) # N: Revealed type is "__main__.MyCallable[builtins.int]"
reveal_type(c()) # N: Revealed type is "builtins.str"

c2 = MyCallable("test")
reveal_type(c2) # N: Revealed type is "__main__.MyCallable[builtins.str]"
reveal_type(c2()) # should be int # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
14 changes: 14 additions & 0 deletions test-data/unit/check-tuples.test
Expand Up @@ -1434,7 +1434,21 @@ def foo(o: CallableTuple) -> int:
class CallableTuple(Tuple[str, int]):
def __call__(self, n: int, m: int) -> int:
return n
[builtins fixtures/tuple.pyi]

[case testTypeTupleGenericCall]
from typing import Generic, Tuple, TypeVar

T = TypeVar('T')

def foo(o: CallableTuple[int]) -> int:
reveal_type(o) # N: Revealed type is "Tuple[builtins.str, builtins.int, fallback=__main__.CallableTuple[builtins.int]]"
reveal_type(o.count(3)) # N: Revealed type is "builtins.int"
return o(1, 2)

class CallableTuple(Tuple[str, T]):
def __call__(self, n: int, m: int) -> int:
return n
[builtins fixtures/tuple.pyi]

[case testTupleCompatibleWithSequence]
Expand Down