Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dataclass/protocol crash on joining types #15629

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
137 changes: 55 additions & 82 deletions mypy/plugins/dataclasses.py
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Final, Iterator
from typing import TYPE_CHECKING, Final, Iterator, Literal

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type, expand_type_by_instance
Expand Down Expand Up @@ -86,7 +86,7 @@
field_specifiers=("dataclasses.Field", "dataclasses.field"),
)
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__"
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-post_init"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sobolevn btw, __mypy-replace was intentionally class-private (starts with __ but doesn't end with __), aligning the post_init sym to be the same



class DataclassAttribute:
Expand Down Expand Up @@ -118,14 +118,33 @@ def __init__(
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
self._api = api

def to_argument(self, current_info: TypeInfo) -> Argument:
arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
def to_argument(
self, current_info: TypeInfo, *, of: Literal["__init__", "replace", "__post_init__"]
) -> Argument:
if of == "__init__":
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we're using to_argument for all use cases, it's the arg_pos that differs between the use cases.

arg_kind = ARG_POS
if self.kw_only and self.has_default:
arg_kind = ARG_NAMED_OPT
elif self.kw_only and not self.has_default:
arg_kind = ARG_NAMED
elif not self.kw_only and self.has_default:
arg_kind = ARG_OPT
elif of == "replace":
arg_kind = ARG_NAMED if self.is_init_var and not self.has_default else ARG_NAMED_OPT
elif of == "__post_init__":
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kind = ARG_POS
return Argument(
variable=self.to_var(current_info),
type_annotation=self.expand_type(current_info),
Expand Down Expand Up @@ -236,7 +255,7 @@ def transform(self) -> bool:
and attributes
):
args = [
attr.to_argument(info)
attr.to_argument(info, of="__init__")
for attr in attributes
if attr.is_in_init and not self._is_kw_only_type(attr.type)
]
Expand Down Expand Up @@ -375,70 +394,26 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
Stashes the signature of 'dataclasses.replace(...)' for this specific dataclass
to be used later whenever 'dataclasses.replace' is called for this dataclass.
"""
arg_types: list[Type] = []
arg_kinds = []
arg_names: list[str | None] = []

info = self._cls.info
for attr in attributes:
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
arg_kinds.append(
ARG_NAMED if attr.is_init_var and not attr.has_default else ARG_NAMED_OPT
)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
)

info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
add_method_to_class(
self._api,
self._cls,
_INTERNAL_REPLACE_SYM_NAME,
args=[attr.to_argument(self._cls.info, of="replace") for attr in attributes],
return_type=NoneType(),
is_staticmethod=True,
)

def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None:
arg_types: list[Type] = [fill_typevars(self._cls.info)]
arg_kinds = [ARG_POS]
arg_names: list[str | None] = ["self"]

info = self._cls.info
for attr in attributes:
if not attr.is_init_var:
continue
attr_type = attr.expand_type(info)
assert attr_type is not None
arg_types.append(attr_type)
# We always use `ARG_POS` without a default value, because it is practical.
# Consider this case:
#
# @dataclass
# class My:
# y: dataclasses.InitVar[str] = 'a'
# def __post_init__(self, y: str) -> None: ...
#
# We would be *required* to specify `y: str = ...` if default is added here.
# But, most people won't care about adding default values to `__post_init__`,
# because it is not designed to be called directly, and duplicating default values
# for the sake of type-checking is unpleasant.
arg_kinds.append(ARG_POS)
arg_names.append(attr.name)

signature = CallableType(
arg_types=arg_types,
arg_kinds=arg_kinds,
arg_names=arg_names,
ret_type=NoneType(),
fallback=self._api.named_type("builtins.function"),
name="__post_init__",
)

info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode(
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
add_method_to_class(
self._api,
self._cls,
_INTERNAL_POST_INIT_SYM_NAME,
args=[
attr.to_argument(self._cls.info, of="__post_init__")
for attr in attributes
if attr.is_init_var
],
return_type=NoneType(),
)

def add_slots(
Expand Down Expand Up @@ -1113,20 +1088,18 @@ def is_processed_dataclass(info: TypeInfo | None) -> bool:
def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None:
if defn.type is None:
return

ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
if ideal_sig is None or ideal_sig.type is None:
return

# We set it ourself, so it is always fine:
assert isinstance(ideal_sig.type, ProperType)
assert isinstance(ideal_sig.type, FunctionLike)
# Type of `FuncItem` is always `FunctionLike`:
assert isinstance(defn.type, FunctionLike)

ideal_sig_method = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
assert ideal_sig_method is not None and ideal_sig_method.type is not None
ideal_sig = ideal_sig_method.type
assert isinstance(ideal_sig, ProperType) # we set it ourselves
assert isinstance(ideal_sig, CallableType)
ideal_sig = ideal_sig.copy_modified(name="__post_init__")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have to copy with new name so that it appears correctly in the error message. Fortunately we do it just once per class.


api.check_override(
override=defn.type,
original=ideal_sig.type,
original=ideal_sig,
name="__post_init__",
name_in_super="__post_init__",
supertype="dataclass",
Expand Down
12 changes: 12 additions & 0 deletions test-data/unit/check-dataclasses.test
Expand Up @@ -2420,3 +2420,15 @@ class Test(Protocol):
def reset(self) -> None:
self.x = DEFAULT
[builtins fixtures/dataclasses.pyi]

[case testProtocolNoCrashOnJoining]
from dataclasses import dataclass
from typing import Protocol

@dataclass
class MyDataclass(Protocol): ...

a: MyDataclass
b = [a, a] # trigger joining the types

[builtins fixtures/dataclasses.pyi]