diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index d2404e96bab9..89a21871c373 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -350,11 +350,51 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: Return None if some dataclass base class hasn't been processed yet and thus we'll need to ask for another pass. """ - # First, collect attributes belonging to the current class. ctx = self._ctx cls = self._ctx.cls - attrs: list[DataclassAttribute] = [] - known_attrs: set[str] = set() + + # First, collect attributes belonging to any class in the MRO, ignoring duplicates. + # + # We iterate through the MRO in reverse because attrs defined in the parent must appear + # earlier in the attributes list than attrs defined in the child. See: + # https://docs.python.org/3/library/dataclasses.html#inheritance + # + # However, we also want attributes defined in the subtype to override ones defined + # in the parent. We can implement this via a dict without disrupting the attr order + # because dicts preserve insertion order in Python 3.7+. + found_attrs: dict[str, DataclassAttribute] = {} + found_dataclass_supertype = False + for info in reversed(cls.info.mro[1:-1]): + if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata: + # We haven't processed the base class yet. Need another pass. + return None + if "dataclass" not in info.metadata: + continue + + # Each class depends on the set of attributes in its dataclass ancestors. + ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) + found_dataclass_supertype = True + + for data in info.metadata["dataclass"]["attributes"]: + name: str = data["name"] + + attr = DataclassAttribute.deserialize(info, data, ctx.api) + # TODO: We shouldn't be performing type operations during the main + # semantic analysis pass, since some TypeInfo attributes might + # still be in flux. This should be performed in a later phase. + with state.strict_optional_set(ctx.api.options.strict_optional): + attr.expand_typevar_from_subtype(ctx.cls.info) + found_attrs[name] = attr + + sym_node = cls.info.names.get(name) + if sym_node and sym_node.node and not isinstance(sym_node.node, Var): + ctx.api.fail( + "Dataclass attribute may only be overridden by another attribute", + sym_node.node, + ) + + # Second, collect attributes belonging to the current class. + current_attr_names: set[str] = set() kw_only = _get_decorator_bool_argument(ctx, "kw_only", False) for stmt in cls.defs.body: # Any assignment that doesn't use the new type declaration @@ -435,8 +475,6 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: if field_kw_only_param is not None: is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param)) - known_attrs.add(lhs.name) - if sym.type is None and node.is_final and node.is_inferred: # This is a special case, assignment like x: Final = 42 is classified # annotated above, but mypy strips the `Final` turning it into x = 42. @@ -453,75 +491,27 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: ) node.type = AnyType(TypeOfAny.from_error) - attrs.append( - DataclassAttribute( - name=lhs.name, - is_in_init=is_in_init, - is_init_var=is_init_var, - has_default=has_default, - line=stmt.line, - column=stmt.column, - type=sym.type, - info=cls.info, - kw_only=is_kw_only, - ) + current_attr_names.add(lhs.name) + found_attrs[lhs.name] = DataclassAttribute( + name=lhs.name, + is_in_init=is_in_init, + is_init_var=is_init_var, + has_default=has_default, + line=stmt.line, + column=stmt.column, + type=sym.type, + info=cls.info, + kw_only=is_kw_only, ) - # Next, collect attributes belonging to any class in the MRO - # as long as those attributes weren't already collected. This - # makes it possible to overwrite attributes in subclasses. - # copy() because we potentially modify all_attrs below and if this code requires debugging - # we'll have unmodified attrs laying around. - all_attrs = attrs.copy() - known_super_attrs = set() - for info in cls.info.mro[1:-1]: - if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata: - # We haven't processed the base class yet. Need another pass. - return None - if "dataclass" not in info.metadata: - continue - - super_attrs = [] - # Each class depends on the set of attributes in its dataclass ancestors. - ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) - - for data in info.metadata["dataclass"]["attributes"]: - name: str = data["name"] - if name not in known_attrs: - attr = DataclassAttribute.deserialize(info, data, ctx.api) - # TODO: We shouldn't be performing type operations during the main - # semantic analysis pass, since some TypeInfo attributes might - # still be in flux. This should be performed in a later phase. - with state.strict_optional_set(ctx.api.options.strict_optional): - attr.expand_typevar_from_subtype(ctx.cls.info) - known_attrs.add(name) - known_super_attrs.add(name) - super_attrs.append(attr) - elif all_attrs: - # How early in the attribute list an attribute appears is determined by the - # reverse MRO, not simply MRO. - # See https://docs.python.org/3/library/dataclasses.html#inheritance for - # details. - for attr in all_attrs: - if attr.name == name: - all_attrs.remove(attr) - super_attrs.append(attr) - break - all_attrs = super_attrs + all_attrs + all_attrs = list(found_attrs.values()) + if found_dataclass_supertype: all_attrs.sort(key=lambda a: a.kw_only) - for known_super_attr_name in known_super_attrs: - sym_node = cls.info.names.get(known_super_attr_name) - if sym_node and sym_node.node and not isinstance(sym_node.node, Var): - ctx.api.fail( - "Dataclass attribute may only be overridden by another attribute", - sym_node.node, - ) - - # Ensure that arguments without a default don't follow - # arguments that have a default. + # Third, ensure that arguments without a default don't follow + # arguments that have a default and that the KW_ONLY sentinel + # is only provided once. found_default = False - # Ensure that the KW_ONLY sentinel is only provided once found_kw_sentinel = False for attr in all_attrs: # If we find any attribute that is_in_init, not kw_only, and that @@ -530,17 +520,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only: # If the issue comes from merging different classes, report it # at the class definition point. - context = Context(line=attr.line, column=attr.column) if attr in attrs else ctx.cls + context: Context = ctx.cls + if attr.name in current_attr_names: + context = Context(line=attr.line, column=attr.column) ctx.api.fail( "Attributes without a default cannot follow attributes with one", context ) found_default = found_default or (attr.has_default and attr.is_in_init) if found_kw_sentinel and self._is_kw_only_type(attr.type): - context = Context(line=attr.line, column=attr.column) if attr in attrs else ctx.cls + context = ctx.cls + if attr.name in current_attr_names: + context = Context(line=attr.line, column=attr.column) ctx.api.fail("There may not be more than one field with the KW_ONLY type", context) found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type) - return all_attrs def _freeze(self, attributes: list[DataclassAttribute]) -> None: