Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support kw_only on dataclasses #3674

Merged
merged 4 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3670-detachhead.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support `kw_only` in dataclasses
119 changes: 84 additions & 35 deletions pydantic/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class M:
The trick is to create a wrapper around `M` that will act as a proxy to trigger
validation without altering default `M` behaviour.
"""
import sys
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload
Expand Down Expand Up @@ -85,38 +86,73 @@ def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
'make_dataclass_validator',
]


@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']:
...


@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[Any],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> 'DataclassClassOrWrapper':
...
if sys.version_info >= (3, 10):

@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = ...,
) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']:
...

@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[Any],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = ...,
) -> 'DataclassClassOrWrapper':
...

else:

@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> Callable[[Type[Any]], 'DataclassClassOrWrapper']:
...

@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[Any],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
) -> 'DataclassClassOrWrapper':
...


@__dataclass_transform__(kw_only_default=True, field_descriptors=(Field, FieldInfo))
Expand All @@ -131,6 +167,7 @@ def dataclass(
frozen: bool = False,
config: Union[ConfigDict, Type[Any], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = False,
) -> Union[Callable[[Type[Any]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
"""
Like the python standard lib dataclasses but with type validation.
Expand All @@ -149,9 +186,21 @@ def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
default_validate_on_init = False
else:
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
dc_cls = dataclasses.dataclass( # type: ignore
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
if sys.version_info >= (3, 10):
dc_cls = dataclasses.dataclass(
cls,
init=init,
repr=repr,
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen,
kw_only=kw_only,
)
else:
dc_cls = dataclasses.dataclass( # type: ignore
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
default_validate_on_init = True

should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
Expand Down
14 changes: 14 additions & 0 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import pickle
import re
import sys
from collections.abc import Hashable
from datetime import datetime
from pathlib import Path
Expand Down Expand Up @@ -1346,3 +1347,16 @@ class MyDataclass:
self_reference: 'MyDataclass'

assert MyDataclass.__pydantic_model__.__fields__['self_reference'].type_ is MyDataclass


@pytest.mark.skipif(sys.version_info < (3, 10), reason='kw_only is not available in python < 3.10')
def test_kw_only():
@pydantic.dataclasses.dataclass(kw_only=True)
class A:
a: int | None = None
b: str

with pytest.raises(TypeError, match='takes 1 positional argument but 3 were given'):
A(1, '')

assert A(b='hi').b == 'hi'