From a7d6e68541e1c6b52338dbe2e3f4a7055a033483 Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Tue, 21 Sep 2021 09:26:45 +0300 Subject: [PATCH] Fix mypy crash on `dataclasses.field(**unpack)` (#11137) --- mypy/plugins/dataclasses.py | 15 ++++++++--- test-data/unit/check-dataclasses.test | 36 +++++++++++++++++++++++++ test-data/unit/fixtures/dataclasses.pyi | 25 +++++++++++++++-- 3 files changed, 71 insertions(+), 5 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 96b58b3f43a7..9c615f857731 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -269,7 +269,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: if self._is_kw_only_type(node_type): kw_only = True - has_field_call, field_args = _collect_field_args(stmt.rvalue) + has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx) is_in_init_param = field_args.get('init') if is_in_init_param is None: @@ -447,7 +447,8 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> None: transformer.transform() -def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]: +def _collect_field_args(expr: Expression, + ctx: ClassDefContext) -> Tuple[bool, Dict[str, Expression]]: """Returns a tuple where the first value represents whether or not the expression is a call to dataclass.field and the second is a dictionary of the keyword arguments that field() was called with. @@ -460,7 +461,15 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]: # field() only takes keyword arguments. args = {} for name, arg in zip(expr.arg_names, expr.args): - assert name is not None + if name is None: + # This means that `field` is used with `**` unpacking, + # the best we can do for now is not to fail. + # TODO: we can infer what's inside `**` and try to collect it. + ctx.api.fail( + 'Unpacking **kwargs in "field()" is not supported', + expr, + ) + return True, {} args[name] = arg return True, args return False, {} diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index 62959488aa27..80ad554d846c 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -1300,3 +1300,39 @@ a.x = x a.x = x2 # E: Incompatible types in assignment (expression has type "Callable[[str], str]", variable has type "Callable[[int], int]") [builtins fixtures/dataclasses.pyi] + + +[case testDataclassFieldDoesNotFailOnKwargsUnpacking] +# flags: --python-version 3.7 +# https://github.com/python/mypy/issues/10879 +from dataclasses import dataclass, field + +@dataclass +class Foo: + bar: float = field(**{"repr": False}) +[out] +main:7: error: Unpacking **kwargs in "field()" is not supported +main:7: error: No overload variant of "field" matches argument type "Dict[str, bool]" +main:7: note: Possible overload variants: +main:7: note: def [_T] field(*, default: _T, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T +main:7: note: def [_T] field(*, default_factory: Callable[[], _T], init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T +main:7: note: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> Any +[builtins fixtures/dataclasses.pyi] + + +[case testDataclassFieldWithTypedDictUnpacking] +# flags: --python-version 3.7 +from dataclasses import dataclass, field +from typing_extensions import TypedDict + +class FieldKwargs(TypedDict): + repr: bool + +field_kwargs: FieldKwargs = {"repr": False} + +@dataclass +class Foo: + bar: float = field(**field_kwargs) # E: Unpacking **kwargs in "field()" is not supported + +reveal_type(Foo(bar=1.5)) # N: Revealed type is "__main__.Foo" +[builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/fixtures/dataclasses.pyi b/test-data/unit/fixtures/dataclasses.pyi index fb0053c80b25..206843a88b24 100644 --- a/test-data/unit/fixtures/dataclasses.pyi +++ b/test-data/unit/fixtures/dataclasses.pyi @@ -1,7 +1,12 @@ -from typing import Generic, Sequence, TypeVar +from typing import ( + Generic, Iterator, Iterable, Mapping, Optional, Sequence, Tuple, + TypeVar, Union, overload, +) _T = TypeVar('_T') _U = TypeVar('_U') +KT = TypeVar('KT') +VT = TypeVar('VT') class object: def __init__(self) -> None: pass @@ -15,7 +20,23 @@ class int: pass class float: pass class str: pass class bool(int): pass -class dict(Generic[_T, _U]): pass + +class dict(Mapping[KT, VT]): + @overload + def __init__(self, **kwargs: VT) -> None: pass + @overload + def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass + def __getitem__(self, key: KT) -> VT: pass + def __setitem__(self, k: KT, v: VT) -> None: pass + def __iter__(self) -> Iterator[KT]: pass + def __contains__(self, item: object) -> int: pass + def update(self, a: Mapping[KT, VT]) -> None: pass + @overload + def get(self, k: KT) -> Optional[VT]: pass + @overload + def get(self, k: KT, default: Union[KT, _T]) -> Union[VT, _T]: pass + def __len__(self) -> int: ... + class list(Generic[_T], Sequence[_T]): pass class function: pass class classmethod: pass