From 04907a3f9ca81871d59f5fe71aaffc9c45102f41 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 19 May 2023 19:40:24 +0330 Subject: [PATCH 01/14] Add slots to dataclass schema --- pydantic_core/core_schema.py | 3 +++ src/validators/dataclass.rs | 10 +++++++++- tests/test_schema_functions.py | 2 +- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 162a1cc5a..df9721eac 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -3101,6 +3101,7 @@ class DataclassSchema(TypedDict, total=False): ref: str metadata: Any serialization: SerSchema + slots: bool def dataclass_schema( @@ -3115,6 +3116,7 @@ def dataclass_schema( metadata: Any = None, serialization: SerSchema | None = None, frozen: bool | None = None, + slots: bool = False, ) -> DataclassSchema: """ Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within @@ -3145,6 +3147,7 @@ def dataclass_schema( metadata=metadata, serialization=serialization, frozen=frozen, + slots=slots, ) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 1eaaf5f0b..b77e915a1 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -415,6 +415,7 @@ pub struct DataclassValidator { revalidate: Revalidate, name: String, frozen: bool, + slots: bool, } impl BuildValidator for DataclassValidator { @@ -453,6 +454,7 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), + slots: schema.get_as(intern!(py, "slots"))?.unwrap_or(false), } .into()) } @@ -595,7 +597,13 @@ impl DataclassValidator { input: &'data impl Input<'data>, ) -> ValResult<'data, ()> { let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?; - force_setattr(py, dc, intern!(py, "__dict__"), dc_dict)?; + if self.slots { + for (key, value) in dc_dict.downcast::()?.iter() { + force_setattr(py, dc, key, value)?; + } + } else { + force_setattr(py, dc, intern!(py, "__dict__"), dc_dict)?; + } if let Some(ref post_init) = self.post_init { let post_init = post_init.as_ref(py); diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index ee6893ab0..809b87897 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -250,7 +250,7 @@ def args(*args, **kwargs): core_schema.dataclass_schema, # MyModel should be a dataclass, but I'm being lazy here args(MyModel, {'type': 'int'}), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel}, + {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': False}, ), ] From c2dc0fddd0910154a1dcfd8e3e7548adb056306b Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 19 May 2023 20:05:23 +0330 Subject: [PATCH 02/14] add test --- pydantic_core/core_schema.py | 1 + tests/test_schema_functions.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index df9721eac..0a2848ca3 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -3134,6 +3134,7 @@ def dataclass_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema frozen: Whether the dataclass is frozen + slots: Whether the slots is enabled on dataclass """ return dict_not_none( type='dataclass', diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index 809b87897..ed946068e 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -252,6 +252,12 @@ def args(*args, **kwargs): args(MyModel, {'type': 'int'}), {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': False}, ), + ( + core_schema.dataclass_schema, + # MyModel should be a dataclass, but I'm being lazy here + args(MyModel, {'type': 'int'}, slots=True), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': True}, + ), ] From 83ac340771f2cde52a9ea8a440566ef37226d170 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 19 May 2023 12:45:38 -0500 Subject: [PATCH 03/14] Add test --- tests/validators/test_dataclasses.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 87a2578a5..d8f10e0e5 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1,5 +1,6 @@ import dataclasses import re +import sys from typing import Any, Dict, List, Optional, Union import pytest @@ -1190,3 +1191,34 @@ def test_custom_dataclass_names(): }, {'input': 123, 'loc': ('foo', 'none'), 'msg': 'Input should be None', 'type': 'none_required'}, ] + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_slots() -> None: + kwargs = {'slots': True} + + @dataclasses.dataclass(**kwargs) + class Model: + x: int + + schema = core_schema.dataclass_schema( + Model, + core_schema.dataclass_args_schema( + 'Model', [core_schema.dataclass_field(name='x', schema=core_schema.int_schema())] + ), + ) + + val = SchemaValidator(schema) + m: Model + + m = val.validate_python({'x': 123}) + assert m == Model(x=1) + + with pytest.raises(ValidationError): + val.validate_python({'x': 'abc'}) + + val.validate_assignment(m, 'x', 456) + assert m.x == 456 + + with pytest.raises(ValidationError): + val.validate_assignment(m, 'x', 'abc') From d2623dad18324bf453d2e048c2ee65c0f580c82b Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Fri, 19 May 2023 23:46:07 +0330 Subject: [PATCH 04/14] handle tests --- src/validators/dataclass.rs | 2 +- tests/validators/test_dataclasses.py | 70 +++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index b77e915a1..0ce939003 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -454,7 +454,7 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), - slots: schema.get_as(intern!(py, "slots"))?.unwrap_or(false), + slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)), } .into()) } diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index d8f10e0e5..8b8d9f99c 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1212,7 +1212,7 @@ class Model: m: Model m = val.validate_python({'x': 123}) - assert m == Model(x=1) + assert m == Model(x=123) with pytest.raises(ValidationError): val.validate_python({'x': 'abc'}) @@ -1222,3 +1222,71 @@ class Model: with pytest.raises(ValidationError): val.validate_assignment(m, 'x', 'abc') + + +def test_dataclass_slots_field_before_validator(): + kwargs = {'slots': True} + + @dataclasses.dataclass(**kwargs) + class Foo: + a: int + b: str + + @classmethod + def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: + assert v == b'hello' + assert info.field_name == 'b' + assert info.data == {'a': 1} + return b'hello world!' + + schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), + core_schema.dataclass_field( + name='b', + schema=core_schema.field_before_validator_function(Foo.validate_b, core_schema.str_schema()), + ), + ], + ), + ) + + v = SchemaValidator(schema) + foo = v.validate_python({'a': 1, 'b': b'hello'}) + assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} + + +def test_dataclass_slots_field_after_validator(): + kwargs = {'slots': True} + + @dataclasses.dataclass(**kwargs) + class Foo: + a: int + b: str + + @classmethod + def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: + assert v == 'hello' + assert info.field_name == 'b' + assert info.data == {'a': 1} + return 'hello world!' + + schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), + core_schema.dataclass_field( + name='b', + schema=core_schema.field_after_validator_function(Foo.validate_b, core_schema.str_schema()), + ), + ], + ), + ) + + v = SchemaValidator(schema) + foo = v.validate_python({'a': 1, 'b': b'hello'}) + assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} From 1081443121df85ef0023075a932fabede52e8c97 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Mon, 22 May 2023 15:35:59 +0330 Subject: [PATCH 05/14] Fix for validation and revalidation --- src/validators/dataclass.rs | 34 ++++++++-- tests/validators/test_dataclasses.py | 92 +++++++++++++++++++++++++++- 2 files changed, 118 insertions(+), 8 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 0ce939003..1f41e262d 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -454,7 +454,7 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), - slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)), + slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)) | schema.get_as(intern!(py, "slots"))?.unwrap_or(false), } .into()) } @@ -478,10 +478,19 @@ impl Validator for DataclassValidator { let class = self.class.as_ref(py); if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? { if self.revalidate.should_revalidate(input, class) { - let input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?; + let mut validator_input = PyDict::new(py); + if self.slots { + let slots = input.input_get_attr(intern!(py, "__slots__")).unwrap()?.downcast::()?; + for key in slots.iter() { + let key: &PyString = key.downcast()?; + validator_input.set_item(key, input.input_get_attr(key).unwrap()?)?; + } + } else { + validator_input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?.downcast::()?; + } let val_output = self .validator - .validate(py, input, extra, definitions, recursion_guard)?; + .validate(py, validator_input.downcast::()?, extra, definitions, recursion_guard)?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -518,8 +527,19 @@ impl Validator for DataclassValidator { if self.frozen { return Err(ValError::new(ErrorType::FrozenInstance, field_value)); } + + let mut dict = PyDict::new(py); let dict_py_str = intern!(py, "__dict__"); - let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?; + + if self.slots { + let slots = obj.input_get_attr(intern!(py, "__slots__")).unwrap()?.downcast::()?; + for key in slots.iter() { + let key: &PyString = key.downcast()?; + dict.set_item(key, obj.input_get_attr(key).unwrap()?)?; + } + } else { + dict = obj.getattr(dict_py_str)?.downcast()?; + } let new_dict = dict.copy()?; new_dict.set_item(field_name, field_value)?; @@ -538,7 +558,11 @@ impl Validator for DataclassValidator { let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?; - force_setattr(py, obj, dict_py_str, dc_dict)?; + if self.slots { + force_setattr(py, obj, field_name, field_value)?; + } else { + force_setattr(py, obj, dict_py_str, dc_dict)?; + } Ok(obj.to_object(py)) } diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 8b8d9f99c..ecf888dda 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1195,9 +1195,7 @@ def test_custom_dataclass_names(): @pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') def test_slots() -> None: - kwargs = {'slots': True} - - @dataclasses.dataclass(**kwargs) + @dataclasses.dataclass(slots=True) class Model: x: int @@ -1290,3 +1288,91 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: v = SchemaValidator(schema) foo = v.validate_python({'a': 1, 'b': b'hello'}) assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} + + +@dataclasses.dataclass(slots=True) +class FooDataclassSlots: + a: str + b: bool + + +@dataclasses.dataclass(slots=True) +class FooDataclassSameSlots(FooDataclassSlots): + pass + + +@dataclasses.dataclass(slots=True) +class FooDataclassMoreSlots(FooDataclassSlots): + c: str + + +@dataclasses.dataclass(slots=True) +class DuplicateDifferentSlots: + a: str + b: bool + + +@pytest.mark.parametrize( + 'revalidate_instances,input_value,expected', + [ + ('always', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + ('always', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('always', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('always', FooDataclassMoreSlots(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')), + ( + 'always', + DuplicateDifferentSlots(a='hello', b=True), + Err('should be a dictionary or an instance of FooDataclass'), + ), + # revalidate_instances='subclass-instances' + ('subclass-instances', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSlots(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}), + ('subclass-instances', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSameSlots(a=b'hello', b='true'), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassMoreSlots(a='hello', b=True, c='more'), Err('Unexpected keyword argument')), + ( + 'subclass-instances', + DuplicateDifferentSlots(a='hello', b=True), + Err('dictionary or an instance of FooDataclass'), + ), + # revalidate_instances='never' + ('never', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + ('never', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('never', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('never', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True, 'c': 'more'}), + ('never', FooDataclassMoreSlots(a='hello', b='wrong', c='more'), {'a': 'hello', 'b': 'wrong', 'c': 'more'}), + ( + 'never', + DuplicateDifferentSlots(a='hello', b=True), + Err('should be a dictionary or an instance of FooDataclass'), + ), + ], +) +def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): + schema = core_schema.dataclass_schema( + FooDataclassSlots, + core_schema.dataclass_args_schema( + 'FooDataclass', + [ + core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), + ], + extra_behavior='forbid', + ), + revalidate_instances=revalidate_instances, + slots=True, + ) + v = SchemaValidator(schema) + + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=expected.message) as exc_info: + print(v.validate_python(input_value)) + + # debug(exc_info.value.errors(include_url=False)) + if expected.errors is not None: + assert exc_info.value.errors(include_url=False) == expected.errors + else: + dc = v.validate_python(input_value) + assert dataclasses.is_dataclass(dc) + assert dataclasses.asdict(dc) == expected From 2862361862b94f89b7272c88f3ee1cd421751356 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Mon, 22 May 2023 15:42:50 +0330 Subject: [PATCH 06/14] Fix lint --- src/validators/dataclass.rs | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 1f41e262d..c82ed78c7 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -454,7 +454,8 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), - slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)) | schema.get_as(intern!(py, "slots"))?.unwrap_or(false), + slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)) + | schema.get_as(intern!(py, "slots"))?.unwrap_or(false), } .into()) } @@ -480,17 +481,29 @@ impl Validator for DataclassValidator { if self.revalidate.should_revalidate(input, class) { let mut validator_input = PyDict::new(py); if self.slots { - let slots = input.input_get_attr(intern!(py, "__slots__")).unwrap()?.downcast::()?; + let slots = input + .input_get_attr(intern!(py, "__slots__")) + .unwrap()? + .downcast::()?; for key in slots.iter() { let key: &PyString = key.downcast()?; validator_input.set_item(key, input.input_get_attr(key).unwrap()?)?; } } else { - validator_input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?.downcast::()?; + validator_input = input + .input_get_attr(intern!(py, "__dict__")) + .unwrap()? + .downcast::()?; } let val_output = self .validator - .validate(py, validator_input.downcast::()?, extra, definitions, recursion_guard)?; + .validate( + py, + validator_input.downcast::()?, + extra, + definitions, + recursion_guard + )?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -532,7 +545,10 @@ impl Validator for DataclassValidator { let dict_py_str = intern!(py, "__dict__"); if self.slots { - let slots = obj.input_get_attr(intern!(py, "__slots__")).unwrap()?.downcast::()?; + let slots = obj + .input_get_attr(intern!(py, "__slots__")) + .unwrap()? + .downcast::()?; for key in slots.iter() { let key: &PyString = key.downcast()?; dict.set_item(key, obj.input_get_attr(key).unwrap()?)?; From 7daf210ebc13cb3d21b3068184e0ffde8179de23 Mon Sep 17 00:00:00 2001 From: Hasan Ramezani Date: Mon, 22 May 2023 15:46:47 +0330 Subject: [PATCH 07/14] Skip tests --- tests/validators/test_dataclasses.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index ecf888dda..661c669bb 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1222,10 +1222,9 @@ class Model: val.validate_assignment(m, 'x', 'abc') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') def test_dataclass_slots_field_before_validator(): - kwargs = {'slots': True} - - @dataclasses.dataclass(**kwargs) + @dataclasses.dataclass(slots=True) class Foo: a: int b: str @@ -1256,10 +1255,9 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') def test_dataclass_slots_field_after_validator(): - kwargs = {'slots': True} - - @dataclasses.dataclass(**kwargs) + @dataclasses.dataclass(slots=True) class Foo: a: int b: str @@ -1290,23 +1288,29 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} -@dataclasses.dataclass(slots=True) +if sys.version_info < (3, 10): + kwargs = {} +else: + kwargs = {'slots': True} + + +@dataclasses.dataclass(**kwargs) class FooDataclassSlots: a: str b: bool -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**kwargs) class FooDataclassSameSlots(FooDataclassSlots): pass -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**kwargs) class FooDataclassMoreSlots(FooDataclassSlots): c: str -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**kwargs) class DuplicateDifferentSlots: a: str b: bool @@ -1349,6 +1353,7 @@ class DuplicateDifferentSlots: ), ], ) +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): schema = core_schema.dataclass_schema( FooDataclassSlots, From 50d3e242dc9300cb1d47728dc9bfc7a7133c341a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 22 May 2023 21:45:34 +0100 Subject: [PATCH 08/14] fix dataclass support with slots, cleanup input --- pydantic_core/core_schema.py | 6 +- src/input/input_abstract.rs | 4 +- src/input/input_python.rs | 8 ++- src/serializers/shared.rs | 16 ++++- src/validators/dataclass.rs | 95 ++++++++++++++-------------- src/validators/model.rs | 18 +++--- tests/serializers/test_any.py | 11 ++++ tests/test_schema_functions.py | 6 +- tests/validators/test_dataclasses.py | 5 +- 9 files changed, 97 insertions(+), 72 deletions(-) diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 0a2848ca3..e43c2e503 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -3101,7 +3101,7 @@ class DataclassSchema(TypedDict, total=False): ref: str metadata: Any serialization: SerSchema - slots: bool + slots: List[str] def dataclass_schema( @@ -3116,7 +3116,7 @@ def dataclass_schema( metadata: Any = None, serialization: SerSchema | None = None, frozen: bool | None = None, - slots: bool = False, + slots: List[str] | None = None, ) -> DataclassSchema: """ Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within @@ -3134,7 +3134,7 @@ def dataclass_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema frozen: Whether the dataclass is frozen - slots: Whether the slots is enabled on dataclass + slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass """ return dict_not_none( type='dataclass', diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index 280026ffe..a5490c553 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -45,8 +45,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { None } - fn is_exact_instance(&self, _class: &PyType) -> bool { - false + fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> { + None } fn is_python(&self) -> bool { diff --git a/src/input/input_python.rs b/src/input/input_python.rs index a19f5d7ca..5b9b9aff2 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -113,8 +113,12 @@ impl<'a> Input<'a> for PyAny { Some(self.getattr(name)) } - fn is_exact_instance(&self, class: &PyType) -> bool { - self.get_type().is(class) + fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> { + if self.is_instance(class).unwrap_or(false) { + Some(self) + } else { + None + } } fn is_python(&self) -> bool { diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 8279dc415..3a0f9a527 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyDict, PySet}; +use pyo3::types::{PyDict, PySet, PyString, PyTuple}; use pyo3::{intern, PyTraverseError, PyVisit}; use enum_dispatch::enum_dispatch; @@ -329,8 +329,18 @@ pub(crate) fn to_json_bytes( pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> { let py = value.py(); - let attr = value.getattr(intern!(py, "__dict__"))?; - let attrs: &PyDict = attr.downcast()?; + let attrs: &PyDict = match value.getattr(intern!(py, "__dict__")) { + Ok(attr) => attr.downcast()?, + Err(_) => { + let slots: &PyTuple = value.getattr(intern!(py, "__slots__"))?.downcast()?; + let dict = PyDict::new(py); + for slot in slots { + let slot: &PyString = slot.downcast()?; + dict.set_item(slot, value.getattr(slot)?)?; + } + dict + } + }; if is_model && extra.exclude_unset { let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?; diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index c82ed78c7..9ee938a2d 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -7,7 +7,6 @@ use ahash::AHashSet; use crate::build_tools::{is_strict, py_err, schema_or_config_same, ExtraBehavior, SchemaDict}; use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; -use crate::input::InputType; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::recursion_guard::RecursionGuard; @@ -415,7 +414,7 @@ pub struct DataclassValidator { revalidate: Revalidate, name: String, frozen: bool, - slots: bool, + slots: Option>>, } impl BuildValidator for DataclassValidator { @@ -442,6 +441,17 @@ impl BuildValidator for DataclassValidator { None }; + let slots = match schema.get_as::<&PyList>(intern!(py, "slots"))? { + Some(slots) => { + let slots = slots + .iter() + .map(|s| Ok(s.downcast::()?.into_py(py))) + .collect::>>()?; + Some(slots) + } + None => None, + }; + Ok(Self { strict: is_strict(schema, config)?, validator: Box::new(validator), @@ -454,8 +464,7 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), - slots: matches!(class.hasattr(intern!(class.py(), "__slots__")), Ok(true)) - | schema.get_as(intern!(py, "slots"))?.unwrap_or(false), + slots, } .into()) } @@ -477,33 +486,25 @@ impl Validator for DataclassValidator { // same logic as on models let class = self.class.as_ref(py); - if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? { - if self.revalidate.should_revalidate(input, class) { - let mut validator_input = PyDict::new(py); - if self.slots { - let slots = input - .input_get_attr(intern!(py, "__slots__")) - .unwrap()? - .downcast::()?; - for key in slots.iter() { - let key: &PyString = key.downcast()?; - validator_input.set_item(key, input.input_get_attr(key).unwrap()?)?; + if let Some(py_input) = input.input_is_instance(class) { + if self.revalidate.should_revalidate(py_input, class) { + let input_dict = match py_input.getattr(intern!(py, "__dict__")) { + Ok(attr) => attr, + Err(_) => { + // we inspect `__slots__` to get the attributes instead of using `self.slots` as a + // subclass could have `slots=True` + let slots: &PyTuple = py_input.getattr(intern!(py, "__slots__"))?.downcast()?; + let dict = PyDict::new(py); + for slot in slots { + let slot: &PyString = slot.downcast()?; + dict.set_item(slot, py_input.getattr(slot)?)?; + } + dict } - } else { - validator_input = input - .input_get_attr(intern!(py, "__dict__")) - .unwrap()? - .downcast::()?; - } + }; let val_output = self .validator - .validate( - py, - validator_input.downcast::()?, - extra, - definitions, - recursion_guard - )?; + .validate(py, input_dict, extra, definitions, recursion_guard)?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -541,27 +542,20 @@ impl Validator for DataclassValidator { return Err(ValError::new(ErrorType::FrozenInstance, field_value)); } - let mut dict = PyDict::new(py); - let dict_py_str = intern!(py, "__dict__"); - - if self.slots { - let slots = obj - .input_get_attr(intern!(py, "__slots__")) - .unwrap()? - .downcast::()?; - for key in slots.iter() { - let key: &PyString = key.downcast()?; - dict.set_item(key, obj.input_get_attr(key).unwrap()?)?; + let new_dict = if let Some(ref slots) = self.slots { + let slots_dict = PyDict::new(py); + for slot in slots { + let slot = slot.as_ref(py); + slots_dict.set_item(slot, obj.getattr(slot)?)?; } + slots_dict } else { - dict = obj.getattr(dict_py_str)?.downcast()?; - } + let dunder_dict: &PyDict = obj.getattr(intern!(py, "__dict__"))?.downcast()?; + dunder_dict.copy()? + }; - let new_dict = dict.copy()?; new_dict.set_item(field_name, field_value)?; - // Discard the second return value, which is `init_only_args` but is always - // None anyway for validate_assignment; see validate_assignment in DataclassArgsValidator let val_assignment_result = self.validator.validate_assignment( py, new_dict, @@ -574,10 +568,11 @@ impl Validator for DataclassValidator { let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?; - if self.slots { - force_setattr(py, obj, field_name, field_value)?; + if self.slots.is_some() { + let value = dc_dict.get_item(field_name).unwrap(); + force_setattr(py, obj, field_name, value)?; } else { - force_setattr(py, obj, dict_py_str, dc_dict)?; + force_setattr(py, obj, intern!(py, "__dict__"), dc_dict)?; } Ok(obj.to_object(py)) @@ -629,6 +624,7 @@ impl DataclassValidator { Ok(self_instance.into_py(py)) } + fn set_dict_call<'s, 'data>( &'s self, py: Python<'data>, @@ -637,8 +633,9 @@ impl DataclassValidator { input: &'data impl Input<'data>, ) -> ValResult<'data, ()> { let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?; - if self.slots { - for (key, value) in dc_dict.downcast::()?.iter() { + if self.slots.is_some() { + let dc_dict: &PyDict = dc_dict.downcast()?; + for (key, value) in dc_dict.iter() { force_setattr(py, dc, key, value)?; } } else { diff --git a/src/validators/model.rs b/src/validators/model.rs index 804c4af9f..cddf47216 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -8,7 +8,7 @@ use pyo3::{ffi, intern}; use crate::build_tools::{py_err, schema_or_config_same, SchemaDict}; use crate::errors::{ErrorType, ValError, ValResult}; -use crate::input::{py_error_on_minusone, Input, InputType}; +use crate::input::{py_error_on_minusone, Input}; use crate::recursion_guard::RecursionGuard; use super::function::convert_err; @@ -37,11 +37,11 @@ impl Revalidate { } } - pub fn should_revalidate<'d>(&self, input: &impl Input<'d>, class: &PyType) -> bool { + pub fn should_revalidate(&self, input: &PyAny, class: &PyType) -> bool { match self { Revalidate::Always => true, Revalidate::Never => false, - Revalidate::SubclassInstances => !input.is_exact_instance(class), + Revalidate::SubclassInstances => !input.get_type().is(class), } } } @@ -125,16 +125,16 @@ impl Validator for ModelValidator { // if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__pydantic_fields_set__` // but use from attributes to create a new instance of the model field type let class = self.class.as_ref(py); - if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? { - if self.revalidate.should_revalidate(input, class) { + if let Some(py_input) = input.input_is_instance(class) { + if self.revalidate.should_revalidate(py_input, class) { if self.root_model { - let inner_input: &PyAny = input.input_get_attr(intern!(py, ROOT_FIELD)).unwrap()?; + let inner_input = py_input.getattr(intern!(py, ROOT_FIELD))?; self.validate_construct(py, inner_input, None, extra, definitions, recursion_guard) } else { - let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?; + let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?; // get dict here so from_attributes logic doesn't apply - let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?; - let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?; + let dict = py_input.getattr(intern!(py, DUNDER_DICT))?; + let model_extra = py_input.getattr(intern!(py, DUNDER_MODEL_EXTRA_KEY))?; let inner_input: &PyAny = if model_extra.is_none() { dict diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 30350f654..bef061ae0 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -473,3 +473,14 @@ def __init__(self, a: str, b: bytes): assert j == b'{"a":"hello"}' assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict() + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_dataclass_slots(any_serializer): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + + foo = Foo(1, 'a') + assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index ed946068e..935d101a0 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -250,13 +250,13 @@ def args(*args, **kwargs): core_schema.dataclass_schema, # MyModel should be a dataclass, but I'm being lazy here args(MyModel, {'type': 'int'}), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': False}, + {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel}, ), ( core_schema.dataclass_schema, # MyModel should be a dataclass, but I'm being lazy here - args(MyModel, {'type': 'int'}, slots=True), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': True}, + args(MyModel, {'type': 'int'}, slots=['a']), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': ['a']}, ), ] diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 661c669bb..0137116a9 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1204,6 +1204,7 @@ class Model: core_schema.dataclass_args_schema( 'Model', [core_schema.dataclass_field(name='x', schema=core_schema.int_schema())] ), + slots=['x'], ) val = SchemaValidator(schema) @@ -1248,6 +1249,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: ), ], ), + slots=['a', 'b'], ) v = SchemaValidator(schema) @@ -1281,6 +1283,7 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: ), ], ), + slots=['a', 'b'], ) v = SchemaValidator(schema) @@ -1366,7 +1369,7 @@ def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): extra_behavior='forbid', ), revalidate_instances=revalidate_instances, - slots=True, + slots=['a', 'b'], ) v = SchemaValidator(schema) From 00a9aafa0627ce2a387f9e7674b04858ca5a376d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 22 May 2023 21:58:07 +0100 Subject: [PATCH 09/14] fix for python 3.11 --- src/input/input_abstract.rs | 7 +------ src/input/input_python.rs | 4 ---- src/serializers/mod.rs | 1 + src/serializers/shared.rs | 24 ++++++++++++++---------- src/validators/dataclass.rs | 13 ++----------- src/validators/model.rs | 2 +- tests/serializers/test_any.py | 7 +++++++ 7 files changed, 26 insertions(+), 32 deletions(-) diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs index a5490c553..3b8e6b4c0 100644 --- a/src/input/input_abstract.rs +++ b/src/input/input_abstract.rs @@ -1,6 +1,6 @@ use std::fmt; -use pyo3::types::{PyDict, PyString, PyType}; +use pyo3::types::{PyDict, PyType}; use pyo3::{intern, prelude::*}; use crate::errors::{InputValue, LocItem, ValResult}; @@ -40,11 +40,6 @@ pub trait Input<'a>: fmt::Debug + ToPyObject { fn is_none(&self) -> bool; - #[cfg_attr(has_no_coverage, no_coverage)] - fn input_get_attr(&self, _name: &PyString) -> Option> { - None - } - fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> { None } diff --git a/src/input/input_python.rs b/src/input/input_python.rs index 5b9b9aff2..b6cea1b5d 100644 --- a/src/input/input_python.rs +++ b/src/input/input_python.rs @@ -109,10 +109,6 @@ impl<'a> Input<'a> for PyAny { self.is_none() } - fn input_get_attr(&self, name: &PyString) -> Option> { - Some(self.getattr(name)) - } - fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> { if self.is_instance(class).unwrap_or(false) { Some(self) diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 8013c1767..9976b74e6 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -11,6 +11,7 @@ use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; use extra::{CollectWarnings, SerRecursionGuard}; pub(crate) use extra::{Extra, SerMode, SerializationState}; +pub(crate) use shared::slots_dc_dict; pub use shared::CombinedSerializer; use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 3a0f9a527..fe23107a2 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -3,7 +3,7 @@ use std::fmt::Debug; use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; -use pyo3::types::{PyDict, PySet, PyString, PyTuple}; +use pyo3::types::{PyDict, PySet, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use enum_dispatch::enum_dispatch; @@ -331,16 +331,9 @@ pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Ext let py = value.py(); let attrs: &PyDict = match value.getattr(intern!(py, "__dict__")) { Ok(attr) => attr.downcast()?, - Err(_) => { - let slots: &PyTuple = value.getattr(intern!(py, "__slots__"))?.downcast()?; - let dict = PyDict::new(py); - for slot in slots { - let slot: &PyString = slot.downcast()?; - dict.set_item(slot, value.getattr(slot)?)?; - } - dict - } + Err(_) => return slots_dc_dict(value), }; + if is_model && extra.exclude_unset { let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?; @@ -355,3 +348,14 @@ pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Ext Ok(attrs) } } + +pub(crate) fn slots_dc_dict(value: &PyAny) -> PyResult<&PyDict> { + let py = value.py(); + let dc_fields: &PyDict = value.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + let dict = PyDict::new(py); + for field in dc_fields.keys() { + let field: &PyString = field.downcast()?; + dict.set_item(field, value.getattr(field)?)?; + } + Ok(dict) +} diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 9ee938a2d..28e3ac7a8 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -10,6 +10,7 @@ use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::recursion_guard::RecursionGuard; +use crate::serializers::slots_dc_dict; use crate::validators::function::convert_err; use super::arguments::{json_get, json_slice, py_get, py_slice}; @@ -490,17 +491,7 @@ impl Validator for DataclassValidator { if self.revalidate.should_revalidate(py_input, class) { let input_dict = match py_input.getattr(intern!(py, "__dict__")) { Ok(attr) => attr, - Err(_) => { - // we inspect `__slots__` to get the attributes instead of using `self.slots` as a - // subclass could have `slots=True` - let slots: &PyTuple = py_input.getattr(intern!(py, "__slots__"))?.downcast()?; - let dict = PyDict::new(py); - for slot in slots { - let slot: &PyString = slot.downcast()?; - dict.set_item(slot, py_input.getattr(slot)?)?; - } - dict - } + Err(_) => slots_dc_dict(py_input)?, }; let val_output = self .validator diff --git a/src/validators/model.rs b/src/validators/model.rs index cddf47216..e19d5ee6a 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -211,7 +211,7 @@ impl Validator for ModelValidator { let (output, _, updated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?; - if let Ok(fields_set) = model.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap() { + if let Ok(fields_set) = model.getattr(intern!(py, DUNDER_FIELDS_SET_KEY)) { let fields_set: &PySet = fields_set.downcast()?; for field_name in updated_fields_set { fields_set.add(field_name)?; diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index bef061ae0..ce155e48e 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -484,3 +484,10 @@ class Foo: foo = Foo(1, 'a') assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + + @dataclasses.dataclass(slots=True) + class Foo2(Foo): + pass + + foo2 = Foo2(2, 'b') + assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') From 9eae5de48e7ff9d490f2459cefb582600b59af41 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 22 May 2023 22:31:58 +0100 Subject: [PATCH 10/14] properly match dataclasses.fields logic --- src/serializers/shared.rs | 26 ++++++++++++++++++++------ tests/serializers/test_any.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index fe23107a2..5b2cd1975 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::fmt::Debug; use pyo3::exceptions::PyTypeError; +use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; use pyo3::types::{PyDict, PySet, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; @@ -349,13 +350,26 @@ pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Ext } } -pub(crate) fn slots_dc_dict(value: &PyAny) -> PyResult<&PyDict> { - let py = value.py(); - let dc_fields: &PyDict = value.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; +static DC_FIELD_MARKER: GILOnceCell = GILOnceCell::new(); + +pub(crate) fn slots_dc_dict(dc: &PyAny) -> PyResult<&PyDict> { + let py = dc.py(); + let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; let dict = PyDict::new(py); - for field in dc_fields.keys() { - let field: &PyString = field.downcast()?; - dict.set_item(field, value.getattr(field)?)?; + + // need to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)` + let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || { + let field_ = py.import("dataclasses")?.getattr("_FIELD")?; + Ok::(field_.into_py(py)) + })?; + let field_type_marker = field_type_marker_obj.as_ref(py); + + for (field_name, field) in dc_fields.iter() { + let field_type = field.getattr(intern!(py, "_field_type"))?; + if field_type.is(field_type_marker) { + let field_name: &PyString = field_name.downcast()?; + dict.set_item(field_name, dc.getattr(field_name)?)?; + } } Ok(dict) } diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index ce155e48e..de423170d 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -475,6 +475,24 @@ def __init__(self, a: str, b: bytes): assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict() +def test_dataclass_classvar(any_serializer): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + c: ClassVar[int] = 1 + + foo = Foo(1, 'a') + assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + + @dataclasses.dataclass(slots=True) + class Foo2(Foo): + pass + + foo2 = Foo2(2, 'b') + assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') + + @pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') def test_dataclass_slots(any_serializer): @dataclasses.dataclass(slots=True) @@ -491,3 +509,16 @@ class Foo2(Foo): foo2 = Foo2(2, 'b') assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_dataclass_slots_init_vars(any_serializer): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + c: dataclasses.InitVar[int] + d: ClassVar[int] = 42 + + foo = Foo(1, 'a', 42) + assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') From c92e7da2767ad4a372b4346451c987abb6a5c088 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 22 May 2023 22:41:49 +0100 Subject: [PATCH 11/14] fix test_dataclass_classvar --- tests/serializers/test_any.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index de423170d..34834613b 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -476,7 +476,7 @@ def __init__(self, a: str, b: bytes): def test_dataclass_classvar(any_serializer): - @dataclasses.dataclass(slots=True) + @dataclasses.dataclass class Foo: a: int b: str @@ -485,7 +485,7 @@ class Foo: foo = Foo(1, 'a') assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') - @dataclasses.dataclass(slots=True) + @dataclasses.dataclass class Foo2(Foo): pass From 66072e0ba2526ce6d39da6c2977c6456978b967a Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 22 May 2023 15:51:20 -0600 Subject: [PATCH 12/14] Update the note about python versions under which dataclass slots are supported --- tests/serializers/test_any.py | 4 ++-- tests/validators/test_dataclasses.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 34834613b..1ff149cbc 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -493,7 +493,7 @@ class Foo2(Foo): assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') -@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_dataclass_slots(any_serializer): @dataclasses.dataclass(slots=True) class Foo: @@ -511,7 +511,7 @@ class Foo2(Foo): assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') -@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_dataclass_slots_init_vars(any_serializer): @dataclasses.dataclass(slots=True) class Foo: diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 0137116a9..58b6d4f87 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1193,7 +1193,7 @@ def test_custom_dataclass_names(): ] -@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_slots() -> None: @dataclasses.dataclass(slots=True) class Model: @@ -1223,7 +1223,7 @@ class Model: val.validate_assignment(m, 'x', 'abc') -@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_dataclass_slots_field_before_validator(): @dataclasses.dataclass(slots=True) class Foo: @@ -1257,7 +1257,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} -@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_dataclass_slots_field_after_validator(): @dataclasses.dataclass(slots=True) class Foo: @@ -1356,7 +1356,7 @@ class DuplicateDifferentSlots: ), ], ) -@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): schema = core_schema.dataclass_schema( FooDataclassSlots, From a0826b1c5498045d97b6268a099edc4b4dedd850 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Tue, 23 May 2023 13:37:51 +0100 Subject: [PATCH 13/14] fix dataclass validation & serialization --- pydantic_core/core_schema.py | 2 +- src/serializers/infer.rs | 106 ++++++------- src/serializers/mod.rs | 2 +- src/serializers/shared.rs | 43 ++---- src/serializers/type_serializers/dataclass.rs | 141 +++++++++++++++++- src/serializers/type_serializers/mod.rs | 2 +- src/serializers/type_serializers/model.rs | 32 ++-- src/validators/dataclass.rs | 7 +- tests/serializers/test_any.py | 48 +++++- tests/serializers/test_dataclasses.py | 39 +++++ tests/test_schema_functions.py | 17 ++- tests/validators/test_dataclasses.py | 35 ++++- 12 files changed, 344 insertions(+), 130 deletions(-) diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index e43c2e503..7743829f9 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -3123,7 +3123,7 @@ def dataclass_schema( another schema, not as the root type. Args: - cls: The dataclass type, used to to perform subclass checks + cls: The dataclass type, used to perform subclass checks schema: The schema to use for the dataclass fields cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`) post_init: Whether to call `__post_init__` after validation diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index c96a6981f..63a6dd037 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; use super::filter::AnyFilter; use super::ob_type::ObType; -use super::shared::object_to_dict; +use super::shared::dataclass_to_dict; pub(crate) fn infer_to_python( value: &PyAny, @@ -97,29 +97,23 @@ pub(crate) fn infer_to_python_known( Ok::(new_dict.into_py(py)) }; - let serialize_with_serializer = |value: &PyAny, is_model: bool| { - if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) { - if let Ok(serializer) = py_serializer.extract::() { - let extra = serializer.build_extra( - py, - extra.mode, - extra.by_alias, - extra.warnings, - extra.exclude_unset, - extra.exclude_defaults, - extra.exclude_none, - extra.round_trip, - extra.rec_guard, - extra.serialize_unknown, - extra.fallback, - ); - return serializer.serializer.to_python(value, include, exclude, &extra); - } - } - // Fallback to dict serialization if `__pydantic_serializer__` is not set. - // This currently only affects non-pydantic dataclasses. - let dict = object_to_dict(value, is_model, extra)?; - serialize_dict(dict) + let serialize_with_serializer = || { + let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?; + let serializer: SchemaSerializer = py_serializer.extract()?; + let extra = serializer.build_extra( + py, + extra.mode, + extra.by_alias, + extra.warnings, + extra.exclude_unset, + extra.exclude_defaults, + extra.exclude_none, + extra.round_trip, + extra.rec_guard, + extra.serialize_unknown, + extra.fallback, + ); + serializer.serializer.to_python(value, include, exclude, &extra) }; let value = match extra.mode { @@ -191,8 +185,8 @@ pub(crate) fn infer_to_python_known( let py_url: PyMultiHostUrl = value.extract()?; py_url.__str__().into_py(py) } - ObType::PydanticSerializable => serialize_with_serializer(value, true)?, - ObType::Dataclass => serialize_with_serializer(value, false)?, + ObType::PydanticSerializable => serialize_with_serializer()?, + ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, ObType::Enum => { let v = value.getattr(intern!(py, "value"))?; infer_to_python(v, include, exclude, extra)?.into_py(py) @@ -257,8 +251,8 @@ pub(crate) fn infer_to_python_known( } new_dict.into_py(py) } - ObType::PydanticSerializable => serialize_with_serializer(value, true)?, - ObType::Dataclass => serialize_with_serializer(value, false)?, + ObType::PydanticSerializable => serialize_with_serializer()?, + ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, ObType::Generator => { let iter = super::type_serializers::generator::SerializationIterator::new( value.downcast()?, @@ -406,36 +400,6 @@ pub(crate) fn infer_serialize_known( }}; } - macro_rules! serialize_with_serializer { - ($py_serializable:expr, $is_model:expr) => {{ - let py = $py_serializable.py(); - if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) { - if let Ok(extracted_serializer) = py_serializer.extract::() { - let extra = extracted_serializer.build_extra( - py, - extra.mode, - extra.by_alias, - extra.warnings, - extra.exclude_unset, - extra.exclude_defaults, - extra.exclude_none, - extra.round_trip, - extra.rec_guard, - extra.serialize_unknown, - extra.fallback, - ); - let pydantic_serializer = - PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); - return pydantic_serializer.serialize(serializer); - } - } - // Fallback to dict serialization if `__pydantic_serializer__` is not set. - // This currently only affects non-pydantic dataclasses. - let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?; - serialize_dict!(dict) - }}; - } - let ser_result = match ob_type { ObType::None => serializer.serialize_none(), ObType::Int | ObType::IntSubclass => serialize!(i64), @@ -490,8 +454,30 @@ pub(crate) fn infer_serialize_known( let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?; serializer.serialize_str(&py_url.__str__()) } - ObType::Dataclass => serialize_with_serializer!(value, false), - ObType::PydanticSerializable => serialize_with_serializer!(value, true), + ObType::PydanticSerializable => { + let py = value.py(); + let py_serializer = value + .getattr(intern!(py, "__pydantic_serializer__")) + .map_err(py_err_se_err)?; + let extracted_serializer: SchemaSerializer = py_serializer.extract().map_err(py_err_se_err)?; + let extra = extracted_serializer.build_extra( + py, + extra.mode, + extra.by_alias, + extra.warnings, + extra.exclude_unset, + extra.exclude_defaults, + extra.exclude_none, + extra.round_trip, + extra.rec_guard, + extra.serialize_unknown, + extra.fallback, + ); + let pydantic_serializer = + PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); + pydantic_serializer.serialize(serializer) + } + ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?), ObType::Enum => { let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?; infer_serialize(v, serializer, include, exclude, extra) diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index 9976b74e6..c38d5c187 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -11,7 +11,7 @@ use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; use extra::{CollectWarnings, SerRecursionGuard}; pub(crate) use extra::{Extra, SerMode, SerializationState}; -pub(crate) use shared::slots_dc_dict; +pub(crate) use shared::dataclass_to_dict; pub use shared::CombinedSerializer; use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 5b2cd1975..d6f559711 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -4,7 +4,7 @@ use std::fmt::Debug; use pyo3::exceptions::PyTypeError; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; -use pyo3::types::{PyDict, PySet, PyString}; +use pyo3::types::{PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use enum_dispatch::enum_dispatch; @@ -96,7 +96,6 @@ combined_serializer! { super::type_serializers::other::CallableBuilder; super::type_serializers::definitions::DefinitionsSerializerBuilder; super::type_serializers::dataclass::DataclassArgsBuilder; - super::type_serializers::dataclass::DataclassBuilder; super::type_serializers::function::FunctionBeforeSerializerBuilder; super::type_serializers::function::FunctionAfterSerializerBuilder; super::type_serializers::function::FunctionPlainSerializerBuilder; @@ -124,6 +123,7 @@ combined_serializer! { Generator: super::type_serializers::generator::GeneratorSerializer; Dict: super::type_serializers::dict::DictSerializer; Model: super::type_serializers::model::ModelSerializer; + Dataclass: super::type_serializers::dataclass::DataclassSerializer; Url: super::type_serializers::url::UrlSerializer; MultiHostUrl: super::type_serializers::url::MultiHostUrlSerializer; Any: super::type_serializers::any::AnySerializer; @@ -328,42 +328,23 @@ pub(crate) fn to_json_bytes( Ok(bytes) } -pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> { - let py = value.py(); - let attrs: &PyDict = match value.getattr(intern!(py, "__dict__")) { - Ok(attr) => attr.downcast()?, - Err(_) => return slots_dc_dict(value), - }; - - if is_model && extra.exclude_unset { - let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?; - - let new_attrs = attrs.copy()?; - for key in new_attrs.keys() { - if !fields_set.contains(key)? { - new_attrs.del_item(key)?; - } - } - Ok(new_attrs) - } else { - Ok(attrs) - } -} - static DC_FIELD_MARKER: GILOnceCell = GILOnceCell::new(); -pub(crate) fn slots_dc_dict(dc: &PyAny) -> PyResult<&PyDict> { - let py = dc.py(); - let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; - let dict = PyDict::new(py); - - // need to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)` +/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)` +pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || { let field_ = py.import("dataclasses")?.getattr("_FIELD")?; Ok::(field_.into_py(py)) })?; - let field_type_marker = field_type_marker_obj.as_ref(py); + Ok(field_type_marker_obj.as_ref(py)) +} + +pub(crate) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> { + let py = dc.py(); + let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + let dict = PyDict::new(py); + let field_type_marker = get_field_marker(py)?; for (field_name, field) in dc_fields.iter() { let field_type = field.getattr(intern!(py, "_field_type"))?; if field_type.is(field_type_marker) { diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index e0463d4b2..73afc06ca 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -1,14 +1,18 @@ -use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyString}; +use pyo3::types::{PyDict, PyList, PyString, PyType}; +use pyo3::{intern, PyTraverseError, PyVisit}; +use std::borrow::Cow; use ahash::AHashMap; use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict}; use crate::definitions::DefinitionsBuilder; -use super::model::ModelSerializer; -use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField}; +use super::{ + get_field_marker, infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, + BuildSerializer, CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, + SerField, TypeSerializer, +}; pub struct DataclassArgsBuilder; @@ -54,16 +58,137 @@ impl BuildSerializer for DataclassArgsBuilder { } } -pub struct DataclassBuilder; +#[derive(Debug, Clone)] +pub struct DataclassSerializer { + class: Py, + serializer: Box, + fields: Vec>, + name: String, +} -impl BuildSerializer for DataclassBuilder { +impl BuildSerializer for DataclassSerializer { const EXPECTED_TYPE: &'static str = "dataclass"; fn build( schema: &PyDict, - config: Option<&PyDict>, + _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> PyResult { - ModelSerializer::build(schema, config, definitions) + let py = schema.py(); + + // models ignore the parent config and always use the config from this model + let config = schema.get_as(intern!(py, "config"))?; + + let class: &PyType = schema.get_as_req(intern!(py, "cls"))?; + let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; + let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?); + + let dc_fields: &PyDict = class.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + let mut fields = Vec::with_capacity(dc_fields.len()); + + let field_type_marker = get_field_marker(py)?; + for (field_name, field) in dc_fields.iter() { + let field_type = field.getattr(intern!(py, "_field_type"))?; + if field_type.is(field_type_marker) { + let field_name: &PyString = field_name.downcast()?; + fields.push(field_name.into_py(py)); + } + } + + Ok(Self { + class: class.into(), + serializer, + fields, + name: class.getattr(intern!(py, "__name__"))?.extract()?, + } + .into()) + } +} + +impl DataclassSerializer { + fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult { + match extra.check { + SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))), + SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())), + SerCheck::None => value.hasattr(intern!(value.py(), "__dataclass_fields__")), + } + } + + fn get_inner_value<'py>(&self, value: &'py PyAny) -> PyResult<&'py PyAny> { + let py = value.py(); + let dict = PyDict::new(py); + + for field_name in &self.fields { + let field_name = field_name.as_ref(py); + dict.set_item(field_name, value.getattr(field_name)?)?; + } + Ok(dict) + } +} + +impl TypeSerializer for DataclassSerializer { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.class)?; + self.serializer.py_gc_traverse(visit)?; + Ok(()) + } + + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let extra = Extra { + model: Some(value), + ..*extra + }; + if self.allow_value(value, &extra)? { + let inner_value = self.get_inner_value(value)?; + self.serializer.to_python(inner_value, include, exclude, &extra) + } else { + extra.warnings.on_fallback_py(self.get_name(), value, &extra)?; + infer_to_python(value, include, exclude, &extra) + } + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { + if self.allow_value(key, extra)? { + infer_json_key_known(&ObType::Dataclass, key, extra) + } else { + extra.warnings.on_fallback_py(&self.name, key, extra)?; + infer_json_key(key, extra) + } + } + + fn serde_serialize( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result { + let extra = Extra { + model: Some(value), + ..*extra + }; + if self.allow_value(value, &extra).map_err(py_err_se_err)? { + let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?; + self.serializer + .serde_serialize(inner_value, serializer, include, exclude, &extra) + } else { + extra.warnings.on_fallback_ser::(self.get_name(), value, &extra)?; + infer_serialize(value, serializer, include, exclude, &extra) + } + } + + fn get_name(&self) -> &str { + &self.name + } + + fn retry_with_lax_check(&self) -> bool { + true } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index 75451076a..40eb8ed83 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -36,5 +36,5 @@ pub(self) use super::infer::{ }; pub(self) use super::ob_type::{IsType, ObType}; pub(self) use super::shared::{ - object_to_dict, to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer, + get_field_marker, to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer, }; diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 1c8df5c6c..6107ff2b2 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString, PyType}; +use pyo3::types::{PyDict, PySet, PyString, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use ahash::AHashMap; @@ -10,9 +10,9 @@ use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict}; use crate::definitions::DefinitionsBuilder; use super::{ - infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, object_to_dict, py_err_se_err, - BuildSerializer, CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, - SerField, TypeSerializer, + infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, + CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField, + TypeSerializer, }; const ROOT_FIELD: &str = "root"; @@ -116,16 +116,28 @@ impl ModelSerializer { } } - fn get_inner_value<'py>(&self, value: &'py PyAny, extra: &Extra) -> PyResult<&'py PyAny> { - let py = value.py(); - let dict = object_to_dict(value, true, extra)?; + fn get_inner_value<'py>(&self, model: &'py PyAny, extra: &Extra) -> PyResult<&'py PyAny> { + let py = model.py(); + let mut attrs: &PyDict = model.getattr(intern!(py, "__dict__"))?.downcast()?; + + if extra.exclude_unset { + let fields_set: &PySet = model.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?; + + let new_attrs = attrs.copy()?; + for key in new_attrs.keys() { + if !fields_set.contains(key)? { + new_attrs.del_item(key)?; + } + } + attrs = new_attrs; + } if self.has_extra { - let model_extra = value.getattr(intern!(py, "__pydantic_extra__"))?; - let py_tuple = (dict, model_extra).to_object(py); + let model_extra = model.getattr(intern!(py, "__pydantic_extra__"))?; + let py_tuple = (attrs, model_extra).to_object(py); Ok(py_tuple.into_ref(py)) } else { - Ok(dict) + Ok(attrs) } } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 28e3ac7a8..d5b80ca98 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -10,7 +10,7 @@ use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::recursion_guard::RecursionGuard; -use crate::serializers::slots_dc_dict; +use crate::serializers::dataclass_to_dict; use crate::validators::function::convert_err; use super::arguments::{json_get, json_slice, py_get, py_slice}; @@ -489,10 +489,7 @@ impl Validator for DataclassValidator { let class = self.class.as_ref(py); if let Some(py_input) = input.input_is_instance(class) { if self.revalidate.should_revalidate(py_input, class) { - let input_dict = match py_input.getattr(intern!(py, "__dict__")) { - Ok(attr) => attr, - Err(_) => slots_dc_dict(py_input)?, - }; + let input_dict: &PyAny = dataclass_to_dict(py_input)?; let val_output = self .validator .validate(py, input_dict, extra, definitions, recursion_guard)?; diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 1ff149cbc..36f23b819 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -36,11 +36,17 @@ class MyDataclass: class MyModel: - __pydantic_serializer__ = 42 - def __init__(self, **kwargs): + fields = {} for key, value in kwargs.items(): setattr(self, key, value) + fields[key] = core_schema.model_field(core_schema.any_schema()) + self.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema(MyModel, core_schema.model_fields_schema(fields)) + ) + + def __repr__(self): + return f'MyModel({self.__dict__})' @pytest.mark.parametrize('value', [None, 1, 1.0, True, 'foo', [1, 2, 3], {'a': 1, 'b': 2}]) @@ -58,6 +64,7 @@ def test_any_json_round_trip(any_serializer, value): ({1, 2, 3}, {1, 2, 3}, IsList(1, 2, 3, check_order=False)), ({1, '2', b'3'}, {1, '2', b'3'}, IsList(1, '2', '3', check_order=False)), ], + ids=repr, ) def test_any_python(any_serializer, input_value, expected_plain, expected_json_obj): assert any_serializer.to_python(input_value) == expected_plain @@ -247,8 +254,13 @@ class FieldsSetModel: __slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__' def __init__(self, **kwargs): + fields = {} for key, value in kwargs.items(): setattr(self, key, value) + fields[key] = core_schema.model_field(core_schema.any_schema()) + self.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema(MyModel, core_schema.model_fields_schema(fields)) + ) def test_exclude_unset(any_serializer): @@ -444,14 +456,11 @@ class Foo: def test_any_model(): + @dataclasses.dataclass class Foo: a: str b: bytes - def __init__(self, a: str, b: bytes): - self.a = a - self.b = b - # Build a schema that does not include the field 'b', to test that it is not serialized schema = core_schema.dataclass_schema( Foo, @@ -473,6 +482,7 @@ def __init__(self, a: str, b: bytes): assert j == b'{"a":"hello"}' assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict() + assert s.to_json(Foo(a='hello', b=b'more'), exclude={'a'}) == b'{}' def test_dataclass_classvar(any_serializer): @@ -484,6 +494,7 @@ class Foo: foo = Foo(1, 'a') assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + assert any_serializer.to_json(foo) == b'{"a":1,"b":"a"}' @dataclasses.dataclass class Foo2(Foo): @@ -491,6 +502,7 @@ class Foo2(Foo): foo2 = Foo2(2, 'b') assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') + assert any_serializer.to_json(foo2) == b'{"a":2,"b":"b"}' @pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') @@ -502,6 +514,7 @@ class Foo: foo = Foo(1, 'a') assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + assert any_serializer.to_json(foo) == b'{"a":1,"b":"a"}' @dataclasses.dataclass(slots=True) class Foo2(Foo): @@ -509,6 +522,7 @@ class Foo2(Foo): foo2 = Foo2(2, 'b') assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') + assert any_serializer.to_json(foo2) == b'{"a":2,"b":"b"}' @pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') @@ -522,3 +536,25 @@ class Foo: foo = Foo(1, 'a', 42) assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + assert any_serializer.to_json(foo) == b'{"a":1,"b":"a"}' + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_slots_mixed(any_serializer): + @dataclasses.dataclass(slots=True) + class Model: + x: int + y: dataclasses.InitVar[str] + z: ClassVar[str] = 'z-classvar' + + @dataclasses.dataclass + class SubModel(Model): + x2: int + y2: dataclasses.InitVar[str] + z2: ClassVar[str] = 'z2-classvar' + + dc = SubModel(x=1, y='a', x2=2, y2='b') + assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} + + assert any_serializer.to_python(dc) == {'x': 1, 'x2': 2} + assert any_serializer.to_json(dc) == b'{"x":1,"x2":2}' diff --git a/tests/serializers/test_dataclasses.py b/tests/serializers/test_dataclasses.py index 980510b35..b0c15bfb1 100644 --- a/tests/serializers/test_dataclasses.py +++ b/tests/serializers/test_dataclasses.py @@ -1,6 +1,10 @@ import dataclasses import json import platform +import sys +from typing import ClassVar + +import pytest from pydantic_core import SchemaSerializer, core_schema @@ -120,3 +124,38 @@ def c(self) -> str: assert s.to_python(FooProp(a='hello', b=b'more'), exclude={'b'}) == IsStrictDict(a='hello', c='hello more') assert s.to_json(FooProp(a='hello', b=b'more'), include={'a'}) == b'{"a":"hello"}' + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_slots_mixed(): + @dataclasses.dataclass(slots=True) + class Model: + x: int + y: dataclasses.InitVar[str] + z: ClassVar[str] = 'z-classvar' + + @dataclasses.dataclass + class SubModel(Model): + x2: int + y2: dataclasses.InitVar[str] + z2: ClassVar[str] = 'z2-classvar' + + schema = core_schema.dataclass_schema( + SubModel, + core_schema.dataclass_args_schema( + 'SubModel', + [ + core_schema.dataclass_field(name='x', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y', init_only=True, schema=core_schema.str_schema()), + core_schema.dataclass_field(name='x2', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()), + ], + ), + slots=['x'], + ) + dc = SubModel(x=1, y='a', x2=2, y2='b') + assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} + + s = SchemaSerializer(schema) + assert s.to_python(dc) == {'x': 1, 'x2': 2} + assert s.to_json(dc) == b'{"x":1,"x2":2}' diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index 935d101a0..dcd86e532 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -1,3 +1,4 @@ +import dataclasses import re from datetime import date from typing import Any @@ -19,6 +20,12 @@ class MyModel: __slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__' +@dataclasses.dataclass +class MyDataclass: + x: int + y: str + + def ids_function(val): if callable(val): return val.__name__ @@ -248,15 +255,13 @@ def args(*args, **kwargs): ), ( core_schema.dataclass_schema, - # MyModel should be a dataclass, but I'm being lazy here - args(MyModel, {'type': 'int'}), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel}, + args(MyDataclass, {'type': 'int'}), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyDataclass}, ), ( core_schema.dataclass_schema, - # MyModel should be a dataclass, but I'm being lazy here - args(MyModel, {'type': 'int'}, slots=['a']), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel, 'slots': ['a']}, + args(MyDataclass, {'type': 'int'}, slots=['a']), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyDataclass, 'slots': ['a']}, ), ] diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 58b6d4f87..a90fd45b1 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -1,7 +1,7 @@ import dataclasses import re import sys -from typing import Any, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import pytest from dirty_equals import IsListOrTuple, IsStr @@ -1384,3 +1384,36 @@ def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): dc = v.validate_python(input_value) assert dataclasses.is_dataclass(dc) assert dataclasses.asdict(dc) == expected + + +def test_slots_mixed(): + @dataclasses.dataclass(slots=True) + class Model: + x: int + y: dataclasses.InitVar[str] + z: ClassVar[str] = 'z-classvar' + + @dataclasses.dataclass + class SubModel(Model): + x2: int + y2: dataclasses.InitVar[str] + z2: ClassVar[str] = 'z2-classvar' + + schema = core_schema.dataclass_schema( + SubModel, + core_schema.dataclass_args_schema( + 'SubModel', + [ + core_schema.dataclass_field(name='x', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y', init_only=True, schema=core_schema.str_schema()), + core_schema.dataclass_field(name='x2', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()), + ], + ), + slots=['x'], + ) + v = SchemaValidator(schema) + dc = v.validate_python({'x': 1, 'y': 'a', 'x2': 2, 'y2': 'b'}) + assert dc.x == 1 + assert dc.x2 == 2 + assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} From 559e0c6e43de80ef5d0956a6e58cec445529f7f9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Wed, 24 May 2023 12:43:15 +0100 Subject: [PATCH 14/14] add dataclass.fields to schema --- pydantic_core/core_schema.py | 12 +++- src/serializers/mod.rs | 1 - src/serializers/shared.rs | 2 +- src/serializers/type_serializers/dataclass.rs | 22 +++---- src/serializers/type_serializers/mod.rs | 4 +- src/validators/dataclass.rs | 54 ++++++++-------- tests/serializers/test_any.py | 2 + tests/serializers/test_dataclasses.py | 7 ++- tests/test_schema_functions.py | 8 +-- tests/validators/test_dataclasses.py | 61 ++++++++++++++++--- .../validators/test_definitions_recursive.py | 1 + 11 files changed, 109 insertions(+), 65 deletions(-) diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py index 7743829f9..7c2b4d00c 100644 --- a/pydantic_core/core_schema.py +++ b/pydantic_core/core_schema.py @@ -3093,6 +3093,7 @@ class DataclassSchema(TypedDict, total=False): type: Required[Literal['dataclass']] cls: Required[Type[Any]] schema: Required[CoreSchema] + fields: Required[List[str]] cls_name: str post_init: bool # default: False revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never' @@ -3101,12 +3102,13 @@ class DataclassSchema(TypedDict, total=False): ref: str metadata: Any serialization: SerSchema - slots: List[str] + slots: bool def dataclass_schema( cls: Type[Any], schema: CoreSchema, + fields: List[str], *, cls_name: str | None = None, post_init: bool | None = None, @@ -3116,7 +3118,7 @@ def dataclass_schema( metadata: Any = None, serialization: SerSchema | None = None, frozen: bool | None = None, - slots: List[str] | None = None, + slots: bool | None = None, ) -> DataclassSchema: """ Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within @@ -3125,6 +3127,8 @@ def dataclass_schema( Args: cls: The dataclass type, used to perform subclass checks schema: The schema to use for the dataclass fields + fields: Fields of the dataclass, this is used in serialization and in validation during re-validation + and while validating assignment cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`) post_init: Whether to call `__post_init__` after validation revalidate_instances: whether instances of models and dataclasses (including subclass instances) @@ -3134,11 +3138,13 @@ def dataclass_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema frozen: Whether the dataclass is frozen - slots: The slots to use for the dataclass, set only if `slots=True` on the dataclass + slots: Whether `slots=True` on the dataclass, means each field is assigned independently, rather than + simply setting `__dict__`, default false """ return dict_not_none( type='dataclass', cls=cls, + fields=fields, cls_name=cls_name, schema=schema, post_init=post_init, diff --git a/src/serializers/mod.rs b/src/serializers/mod.rs index c38d5c187..8013c1767 100644 --- a/src/serializers/mod.rs +++ b/src/serializers/mod.rs @@ -11,7 +11,6 @@ use config::SerializationConfig; pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue}; use extra::{CollectWarnings, SerRecursionGuard}; pub(crate) use extra::{Extra, SerMode, SerializationState}; -pub(crate) use shared::dataclass_to_dict; pub use shared::CombinedSerializer; use shared::{to_json_bytes, BuildSerializer, TypeSerializer}; diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index d6f559711..c33cc61e8 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -339,7 +339,7 @@ pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { Ok(field_type_marker_obj.as_ref(py)) } -pub(crate) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> { +pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> { let py = dc.py(); let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; let dict = PyDict::new(py); diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index 73afc06ca..8626e202e 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -9,9 +9,9 @@ use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict}; use crate::definitions::DefinitionsBuilder; use super::{ - get_field_marker, infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, - BuildSerializer, CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, - SerField, TypeSerializer, + infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, + CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField, + TypeSerializer, }; pub struct DataclassArgsBuilder; @@ -83,17 +83,11 @@ impl BuildSerializer for DataclassSerializer { let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?); - let dc_fields: &PyDict = class.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; - let mut fields = Vec::with_capacity(dc_fields.len()); - - let field_type_marker = get_field_marker(py)?; - for (field_name, field) in dc_fields.iter() { - let field_type = field.getattr(intern!(py, "_field_type"))?; - if field_type.is(field_type_marker) { - let field_name: &PyString = field_name.downcast()?; - fields.push(field_name.into_py(py)); - } - } + let fields = schema + .get_as_req::<&PyList>(intern!(py, "fields"))? + .iter() + .map(|s| Ok(s.downcast::()?.into_py(py))) + .collect::>>()?; Ok(Self { class: class.into(), diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index 40eb8ed83..fde74742c 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -35,6 +35,4 @@ pub(self) use super::infer::{ infer_to_python_known, }; pub(self) use super::ob_type::{IsType, ObType}; -pub(self) use super::shared::{ - get_field_marker, to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer, -}; +pub(self) use super::shared::{to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer}; diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index d5b80ca98..781cdb102 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -10,7 +10,6 @@ use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::recursion_guard::RecursionGuard; -use crate::serializers::dataclass_to_dict; use crate::validators::function::convert_err; use super::arguments::{json_get, json_slice, py_get, py_slice}; @@ -411,11 +410,12 @@ pub struct DataclassValidator { strict: bool, validator: Box, class: Py, + fields: Vec>, post_init: Option>, revalidate: Revalidate, name: String, frozen: bool, - slots: Option>>, + slots: bool, } impl BuildValidator for DataclassValidator { @@ -442,21 +442,17 @@ impl BuildValidator for DataclassValidator { None }; - let slots = match schema.get_as::<&PyList>(intern!(py, "slots"))? { - Some(slots) => { - let slots = slots - .iter() - .map(|s| Ok(s.downcast::()?.into_py(py))) - .collect::>>()?; - Some(slots) - } - None => None, - }; + let fields = schema + .get_as_req::<&PyList>(intern!(py, "fields"))? + .iter() + .map(|s| Ok(s.downcast::()?.into_py(py))) + .collect::>>()?; Ok(Self { strict: is_strict(schema, config)?, validator: Box::new(validator), class: class.into(), + fields, post_init, revalidate: Revalidate::from_str(schema_or_config_same( schema, @@ -465,7 +461,7 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), - slots, + slots: schema.get_as(intern!(py, "slots"))?.unwrap_or(false), } .into()) } @@ -489,7 +485,7 @@ impl Validator for DataclassValidator { let class = self.class.as_ref(py); if let Some(py_input) = input.input_is_instance(class) { if self.revalidate.should_revalidate(py_input, class) { - let input_dict: &PyAny = dataclass_to_dict(py_input)?; + let input_dict: &PyAny = self.dataclass_to_dict(py, py_input)?; let val_output = self .validator .validate(py, input_dict, extra, definitions, recursion_guard)?; @@ -530,17 +526,7 @@ impl Validator for DataclassValidator { return Err(ValError::new(ErrorType::FrozenInstance, field_value)); } - let new_dict = if let Some(ref slots) = self.slots { - let slots_dict = PyDict::new(py); - for slot in slots { - let slot = slot.as_ref(py); - slots_dict.set_item(slot, obj.getattr(slot)?)?; - } - slots_dict - } else { - let dunder_dict: &PyDict = obj.getattr(intern!(py, "__dict__"))?.downcast()?; - dunder_dict.copy()? - }; + let new_dict = self.dataclass_to_dict(py, obj)?; new_dict.set_item(field_name, field_value)?; @@ -556,8 +542,10 @@ impl Validator for DataclassValidator { let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?; - if self.slots.is_some() { - let value = dc_dict.get_item(field_name).unwrap(); + if self.slots { + let value = dc_dict + .get_item(field_name) + .ok_or_else(|| PyKeyError::new_err(field_name.to_string()))?; force_setattr(py, obj, field_name, value)?; } else { force_setattr(py, obj, intern!(py, "__dict__"), dc_dict)?; @@ -613,6 +601,16 @@ impl DataclassValidator { Ok(self_instance.into_py(py)) } + fn dataclass_to_dict<'py>(&self, py: Python<'py>, dc: &'py PyAny) -> PyResult<&'py PyDict> { + let dict = PyDict::new(py); + + for field_name in &self.fields { + let field_name = field_name.as_ref(py); + dict.set_item(field_name, dc.getattr(field_name)?)?; + } + Ok(dict) + } + fn set_dict_call<'s, 'data>( &'s self, py: Python<'data>, @@ -621,7 +619,7 @@ impl DataclassValidator { input: &'data impl Input<'data>, ) -> ValResult<'data, ()> { let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?; - if self.slots.is_some() { + if self.slots { let dc_dict: &PyDict = dc_dict.downcast()?; for (key, value) in dc_dict.iter() { force_setattr(py, dc, key, value)?; diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 36f23b819..377713111 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -439,6 +439,7 @@ class Foo: core_schema.dataclass_args_schema( 'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())] ), + ['a'], ) Foo.__pydantic_serializer__ = SchemaSerializer(schema) @@ -467,6 +468,7 @@ class Foo: core_schema.dataclass_args_schema( 'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())] ), + ['a'], ) Foo.__pydantic_validator__ = SchemaValidator(schema) Foo.__pydantic_serializer__ = SchemaSerializer(schema) diff --git a/tests/serializers/test_dataclasses.py b/tests/serializers/test_dataclasses.py index b0c15bfb1..57b20c4b2 100644 --- a/tests/serializers/test_dataclasses.py +++ b/tests/serializers/test_dataclasses.py @@ -32,6 +32,7 @@ def test_dataclass(): core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema()), ], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello', b=b'more') @@ -57,6 +58,7 @@ def test_serialization_exclude(): core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True), ], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'} @@ -79,6 +81,7 @@ def test_serialization_alias(): core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_alias='BAR'), ], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello', BAR=b'more') @@ -111,6 +114,7 @@ def c(self) -> str: ], computed_fields=[core_schema.computed_field('c', core_schema.str_schema())], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(FooProp(a='hello', b=b'more')) == IsStrictDict(a='hello', b=b'more', c='hello more') @@ -151,7 +155,8 @@ class SubModel(Model): core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()), ], ), - slots=['x'], + ['x', 'x2'], + slots=True, ) dc = SubModel(x=1, y='a', x2=2, y2='b') assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index dcd86e532..86e8bf28c 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -255,13 +255,13 @@ def args(*args, **kwargs): ), ( core_schema.dataclass_schema, - args(MyDataclass, {'type': 'int'}), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyDataclass}, + args(MyDataclass, {'type': 'int'}, ['foobar']), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass}, ), ( core_schema.dataclass_schema, - args(MyDataclass, {'type': 'int'}, slots=['a']), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyDataclass, 'slots': ['a']}, + args(MyDataclass, {'type': 'int'}, ['foobar'], slots=True), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass, 'slots': True}, ), ] diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index a90fd45b1..e9a8d73d8 100644 --- a/tests/validators/test_dataclasses.py +++ b/tests/validators/test_dataclasses.py @@ -202,6 +202,7 @@ def test_dataclass(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -249,7 +250,9 @@ class DuplicateDifferent: ('always', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), ('always', FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), ('always', FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), - ('always', FooDataclassMore(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')), + # no error because we only look for fields in schema['fields'] + ('always', FooDataclassMore(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ('always', FooDataclassSame(a='hello', b='wrong'), Err(r'b\s+Input should be a valid boolean,')), ('always', DuplicateDifferent(a='hello', b=True), Err('should be a dictionary or an instance of FooDataclass')), # revalidate_instances='subclass-instances' ('subclass-instances', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), @@ -257,7 +260,9 @@ class DuplicateDifferent: ('subclass-instances', FooDataclass(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}), ('subclass-instances', FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), ('subclass-instances', FooDataclassSame(a=b'hello', b='true'), {'a': 'hello', 'b': True}), - ('subclass-instances', FooDataclassMore(a='hello', b=True, c='more'), Err('Unexpected keyword argument')), + # no error because we only look for fields in schema['fields'] + ('subclass-instances', FooDataclassMore(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSame(a='hello', b='wrong'), Err(r'b\s+Input should be a valid boolean,')), ('subclass-instances', DuplicateDifferent(a='hello', b=True), Err('dictionary or an instance of FooDataclass')), # revalidate_instances='never' ('never', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), @@ -279,6 +284,7 @@ def test_dataclass_subclass(revalidate_instances, input_value, expected): ], extra_behavior='forbid', ), + ['a', 'b'], revalidate_instances=revalidate_instances, ) v = SchemaValidator(schema) @@ -307,6 +313,7 @@ def test_dataclass_subclass_strict_never_revalidate(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], revalidate_instances='never', strict=True, ) @@ -333,6 +340,7 @@ def test_dataclass_subclass_subclass_revalidate(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], revalidate_instances='subclass-instances', strict=True, ) @@ -365,6 +373,7 @@ def __post_init__(self): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], post_init=True, ) @@ -398,6 +407,7 @@ def __post_init__(self, c: int): ], collect_init_only=True, ), + ['a', 'b'], post_init=True, ) @@ -433,6 +443,7 @@ def __post_init__(self, *args): ], collect_init_only=True, ), + ['a', 'b'], post_init=True, ) @@ -463,6 +474,7 @@ def test_dataclass_exact_validation(revalidate_instances, input_value, expected) core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], revalidate_instances=revalidate_instances, ) @@ -496,6 +508,7 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -527,6 +540,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> str: ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -559,6 +573,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -594,6 +609,7 @@ def validate_b( ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -627,6 +643,7 @@ def validate_b( ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -652,6 +669,7 @@ def __init__(self, *args, **kwargs): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), kw_only=False), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -675,6 +693,7 @@ class Foo: core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), validation_alias=['bAlias', 0]), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -716,6 +735,7 @@ class Foo: core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), validation_alias=['bAlias', 0]), ], ), + ['a', 'b'], ) v = SchemaValidator(schema, {'loc_by_alias': False}) @@ -768,6 +788,7 @@ def __post_init__(self, c): ], collect_init_only=True, ), + ['a', 'b', 'c'], post_init=True, ) v = SchemaValidator(schema) @@ -788,6 +809,7 @@ def test_dataclass_validate_assignment(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), kw_only=False), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -817,7 +839,7 @@ def test_dataclass_validate_assignment(): assert not hasattr(foo, 'c') # wrong arguments - with pytest.raises(AttributeError, match="'str' object has no attribute '__dict__'"): + with pytest.raises(AttributeError, match="'str' object has no attribute 'a'"): v.validate_assignment('field_a', 'c', 123) @@ -847,6 +869,7 @@ def func(x, info): core_schema.dataclass_field('field_c', core_schema.int_schema()), ], ), + ['field_a', 'field_b', 'field_c'], ) ) @@ -874,6 +897,7 @@ class MyModel: core_schema.dataclass_schema( MyModel, core_schema.dataclass_args_schema('MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())]), + ['f'], frozen=True, ) ) @@ -901,6 +925,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema(), frozen=True)] ), + ['f'], ) ) @@ -937,6 +962,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())], **schema_extra_behavior_kw ), + ['f'], ), config=config, ) @@ -984,6 +1010,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())], **schema_extra_behavior_kw ), + ['f'], ), config=config, ) @@ -1029,6 +1056,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())], **schema_extra_behavior_kw ), + ['f'], ), config=config, ) @@ -1063,6 +1091,7 @@ class Model: 'Model', [core_schema.dataclass_field('number', core_schema.int_schema())] ), ), + ['number'], ) v = SchemaValidator(cs) @@ -1094,6 +1123,7 @@ class Model: 'Model', [core_schema.dataclass_field('number', core_schema.int_schema())] ), ), + ['number'], ) v = SchemaValidator(cs) @@ -1128,6 +1158,7 @@ class Model: 'Model', [core_schema.dataclass_field('number', core_schema.int_schema())] ), ), + ['number'], ) v = SchemaValidator(cs) @@ -1168,6 +1199,7 @@ def test_custom_dataclass_names(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], cls_name='FooDataclass[cls_name]', ), core_schema.none_schema(), @@ -1176,6 +1208,7 @@ def test_custom_dataclass_names(): ) ], ), + ['foo'], ) v = SchemaValidator(schema) @@ -1204,7 +1237,8 @@ class Model: core_schema.dataclass_args_schema( 'Model', [core_schema.dataclass_field(name='x', schema=core_schema.int_schema())] ), - slots=['x'], + ['x'], + slots=True, ) val = SchemaValidator(schema) @@ -1249,7 +1283,8 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: ), ], ), - slots=['a', 'b'], + ['a', 'b'], + slots=True, ) v = SchemaValidator(schema) @@ -1283,7 +1318,8 @@ def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: ), ], ), - slots=['a', 'b'], + ['a', 'b'], + slots=True, ) v = SchemaValidator(schema) @@ -1325,7 +1361,7 @@ class DuplicateDifferentSlots: ('always', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), ('always', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), ('always', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), - ('always', FooDataclassMoreSlots(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')), + ('always', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), ( 'always', DuplicateDifferentSlots(a='hello', b=True), @@ -1337,7 +1373,9 @@ class DuplicateDifferentSlots: ('subclass-instances', FooDataclassSlots(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}), ('subclass-instances', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), ('subclass-instances', FooDataclassSameSlots(a=b'hello', b='true'), {'a': 'hello', 'b': True}), - ('subclass-instances', FooDataclassMoreSlots(a='hello', b=True, c='more'), Err('Unexpected keyword argument')), + # no error because we don't look for fields unless their in schema['fields'] + ('subclass-instances', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSameSlots(a=b'hello', b='wrong'), Err('Input should be a valid boolean,')), ( 'subclass-instances', DuplicateDifferentSlots(a='hello', b=True), @@ -1368,8 +1406,9 @@ def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): ], extra_behavior='forbid', ), + ['a', 'b'], revalidate_instances=revalidate_instances, - slots=['a', 'b'], + slots=True, ) v = SchemaValidator(schema) @@ -1386,6 +1425,7 @@ def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): assert dataclasses.asdict(dc) == expected +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') def test_slots_mixed(): @dataclasses.dataclass(slots=True) class Model: @@ -1410,7 +1450,8 @@ class SubModel(Model): core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()), ], ), - slots=['x'], + ['x'], + slots=True, ) v = SchemaValidator(schema) dc = v.validate_python({'x': 1, 'y': 'a', 'x2': 2, 'y2': 'b'}) diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 8d09ee5a6..21e050f56 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -889,6 +889,7 @@ class Model: ) ], ), + ['x'], ref='model', ) v = SchemaValidator(schema, config=core_schema.CoreConfig(revalidate_instances='always'))