diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index b82d35607ef1..5f0ef79acbd7 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -196,6 +196,8 @@ def expr_to_unanalyzed_type( elif isinstance(expr, EllipsisExpr): return EllipsisType(expr.line) elif allow_unpack and isinstance(expr, StarExpr): - return UnpackType(expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax)) + return UnpackType( + expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax), from_star_syntax=True + ) else: raise TypeTranslationError() diff --git a/mypy/fastparse.py b/mypy/fastparse.py index a96e697d40bf..fe158d468ce8 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -2041,7 +2041,7 @@ def visit_Attribute(self, n: Attribute) -> Type: # Used for Callable[[X *Ys, Z], R] def visit_Starred(self, n: ast3.Starred) -> Type: - return UnpackType(self.visit(n.value)) + return UnpackType(self.visit(n.value), from_star_syntax=True) # List(expr* elts, expr_context ctx) def visit_List(self, n: ast3.List) -> Type: diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py index 3e11951376c9..ed04b30e90ba 100644 --- a/mypy/semanal_typeargs.py +++ b/mypy/semanal_typeargs.py @@ -214,7 +214,9 @@ def visit_unpack_type(self, typ: UnpackType) -> None: # Avoid extra errors if there were some errors already. Also interpret plain Any # as tuple[Any, ...] (this is better for the code in type checker). self.fail( - message_registry.INVALID_UNPACK.format(format_type(proper_type, self.options)), typ + message_registry.INVALID_UNPACK.format(format_type(proper_type, self.options)), + typ.type, + code=codes.VALID_TYPE, ) typ.type = self.named_type("builtins.tuple", [AnyType(TypeOfAny.from_error)]) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index e297f2bf1631..385c5d35d67f 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -961,7 +961,7 @@ def visit_unpack_type(self, t: UnpackType) -> Type: if not self.allow_unpack: self.fail(message_registry.INVALID_UNPACK_POSITION, t.type, code=codes.VALID_TYPE) return AnyType(TypeOfAny.from_error) - return UnpackType(self.anal_type(t.type)) + return UnpackType(self.anal_type(t.type), from_star_syntax=t.from_star_syntax) def visit_parameters(self, t: Parameters) -> Type: raise NotImplementedError("ParamSpec literals cannot have unbound TypeVars") @@ -969,6 +969,7 @@ def visit_parameters(self, t: Parameters) -> Type: def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: # Every Callable can bind its own type variables, if they're not in the outer scope with self.tvar_scope_frame(): + unpacked_kwargs = False if self.defining_alias: variables = t.variables else: @@ -996,6 +997,15 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: ) validated_args.append(AnyType(TypeOfAny.from_error)) else: + if nested and isinstance(at, UnpackType) and i == star_index: + # TODO: it would be better to avoid this get_proper_type() call. + p_at = get_proper_type(at.type) + if isinstance(p_at, TypedDictType) and not at.from_star_syntax: + # Automatically detect Unpack[Foo] in Callable as backwards + # compatible syntax for **Foo, if Foo is a TypedDict. + at = p_at + arg_kinds[i] = ARG_STAR2 + unpacked_kwargs = True validated_args.append(at) arg_types = validated_args # If there were multiple (invalid) unpacks, the arg types list will become shorter, @@ -1013,6 +1023,7 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: fallback=(t.fallback if t.fallback.type else self.named_type("builtins.function")), variables=self.anal_var_defs(variables), type_guard=special, + unpack_kwargs=unpacked_kwargs, ) return ret diff --git a/mypy/types.py b/mypy/types.py index 04d90c9dc124..22fcd601d6a0 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1053,11 +1053,14 @@ class UnpackType(ProperType): wild west, technically anything can be present in the wrapped type. """ - __slots__ = ["type"] + __slots__ = ["type", "from_star_syntax"] - def __init__(self, typ: Type, line: int = -1, column: int = -1) -> None: + def __init__( + self, typ: Type, line: int = -1, column: int = -1, from_star_syntax: bool = False + ) -> None: super().__init__(line, column) self.type = typ + self.from_star_syntax = from_star_syntax def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_unpack_type(self) diff --git a/test-data/unit/check-varargs.test b/test-data/unit/check-varargs.test index ef2c3c57fad5..41668e991972 100644 --- a/test-data/unit/check-varargs.test +++ b/test-data/unit/check-varargs.test @@ -1079,3 +1079,18 @@ class C: class D: def __init__(self, **kwds: Unpack[int, str]) -> None: ... # E: Unpack[...] requires exactly one type argument [builtins fixtures/dict.pyi] + +[case testUnpackInCallableType] +from typing import Callable +from typing_extensions import Unpack, TypedDict + +class TD(TypedDict): + key: str + value: str + +foo: Callable[[Unpack[TD]], None] +foo(key="yes", value=42) # E: Argument "value" has incompatible type "int"; expected "str" +foo(key="yes", value="ok") + +bad: Callable[[*TD], None] # E: "TD" cannot be unpacked (must be tuple or TypeVarTuple) +[builtins fixtures/dict.pyi]