diff --git a/pydantic_core/core_schema.py b/pydantic_core/core_schema.py
index 162a1cc5a..7c2b4d00c 100644
--- a/pydantic_core/core_schema.py
+++ b/pydantic_core/core_schema.py
@@ -3093,6 +3093,7 @@ class DataclassSchema(TypedDict, total=False):
     type: Required[Literal['dataclass']]
     cls: Required[Type[Any]]
     schema: Required[CoreSchema]
+    fields: Required[List[str]]
     cls_name: str
     post_init: bool  # default: False
     revalidate_instances: Literal['always', 'never', 'subclass-instances']  # default: 'never'
@@ -3101,11 +3102,13 @@ class DataclassSchema(TypedDict, total=False):
     ref: str
     metadata: Any
     serialization: SerSchema
+    slots: bool


 def dataclass_schema(
     cls: Type[Any],
     schema: CoreSchema,
+    fields: List[str],
     *,
     cls_name: str | None = None,
     post_init: bool | None = None,
@@ -3115,14 +3118,17 @@ def dataclass_schema(
     metadata: Any = None,
     serialization: SerSchema | None = None,
     frozen: bool | None = None,
+    slots: bool | None = None,
 ) -> DataclassSchema:
     """
     Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
     another schema, not as the root type.

     Args:
-        cls: The dataclass type, used to to perform subclass checks
+        cls: The dataclass type, used to perform subclass checks
         schema: The schema to use for the dataclass fields
+        fields: Fields of the dataclass; used in serialization, and in validation when re-validating instances
+            and when validating assignment
         cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
         post_init: Whether to call `__post_init__` after validation
         revalidate_instances: whether instances of models and dataclasses (including subclass instances)
@@ -3132,10 +3138,13 @@ def dataclass_schema(
         metadata: Any other information you want to include with the schema, not used by pydantic-core
         serialization: Custom serialization schema
         frozen: Whether the dataclass is frozen
+        slots: Whether `slots=True` is set on the dataclass; if so, each field is set on the instance
+            individually rather than by assigning `__dict__` (default: False)
     """
     return dict_not_none(
         type='dataclass',
         cls=cls,
+        fields=fields,
         cls_name=cls_name,
         schema=schema,
         post_init=post_init,
@@ -3145,6 +3154,7 @@ def dataclass_schema(
         metadata=metadata,
         serialization=serialization,
         frozen=frozen,
+        slots=slots,
     )
diff --git a/src/input/input_abstract.rs b/src/input/input_abstract.rs
index 280026ffe..3b8e6b4c0 100644
--- a/src/input/input_abstract.rs
+++ b/src/input/input_abstract.rs
@@ -1,6 +1,6 @@
 use std::fmt;

-use pyo3::types::{PyDict, PyString, PyType};
+use pyo3::types::{PyDict, PyType};
 use pyo3::{intern, prelude::*};

 use crate::errors::{InputValue, LocItem, ValResult};
@@ -40,15 +40,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

     fn is_none(&self) -> bool;

-    #[cfg_attr(has_no_coverage, no_coverage)]
-    fn input_get_attr(&self, _name: &PyString) -> Option<PyResult<&'a PyAny>> {
+    fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> {
         None
     }

-    fn is_exact_instance(&self, _class: &PyType) -> bool {
-        false
-    }
-
     fn is_python(&self) -> bool {
         false
     }
diff --git a/src/input/input_python.rs b/src/input/input_python.rs
index a19f5d7ca..b6cea1b5d 100644
--- a/src/input/input_python.rs
+++ b/src/input/input_python.rs
@@ -109,12 +109,12 @@ impl<'a> Input<'a> for PyAny {
         self.is_none()
     }

-    fn input_get_attr(&self, name: &PyString) -> Option<PyResult<&'a PyAny>> {
-        Some(self.getattr(name))
-    }
-
-    fn is_exact_instance(&self, class: &PyType) -> bool {
-        self.get_type().is(class)
+    fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> {
+        if
self.is_instance(class).unwrap_or(false) { + Some(self) + } else { + None + } } fn is_python(&self) -> bool { diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index c96a6981f..63a6dd037 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError}; use super::extra::{Extra, SerMode}; use super::filter::AnyFilter; use super::ob_type::ObType; -use super::shared::object_to_dict; +use super::shared::dataclass_to_dict; pub(crate) fn infer_to_python( value: &PyAny, @@ -97,29 +97,23 @@ pub(crate) fn infer_to_python_known( Ok::(new_dict.into_py(py)) }; - let serialize_with_serializer = |value: &PyAny, is_model: bool| { - if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) { - if let Ok(serializer) = py_serializer.extract::() { - let extra = serializer.build_extra( - py, - extra.mode, - extra.by_alias, - extra.warnings, - extra.exclude_unset, - extra.exclude_defaults, - extra.exclude_none, - extra.round_trip, - extra.rec_guard, - extra.serialize_unknown, - extra.fallback, - ); - return serializer.serializer.to_python(value, include, exclude, &extra); - } - } - // Fallback to dict serialization if `__pydantic_serializer__` is not set. - // This currently only affects non-pydantic dataclasses. - let dict = object_to_dict(value, is_model, extra)?; - serialize_dict(dict) + let serialize_with_serializer = || { + let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?; + let serializer: SchemaSerializer = py_serializer.extract()?; + let extra = serializer.build_extra( + py, + extra.mode, + extra.by_alias, + extra.warnings, + extra.exclude_unset, + extra.exclude_defaults, + extra.exclude_none, + extra.round_trip, + extra.rec_guard, + extra.serialize_unknown, + extra.fallback, + ); + serializer.serializer.to_python(value, include, exclude, &extra) }; let value = match extra.mode { @@ -191,8 +185,8 @@ pub(crate) fn infer_to_python_known( let py_url: PyMultiHostUrl = value.extract()?; py_url.__str__().into_py(py) } - ObType::PydanticSerializable => serialize_with_serializer(value, true)?, - ObType::Dataclass => serialize_with_serializer(value, false)?, + ObType::PydanticSerializable => serialize_with_serializer()?, + ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, ObType::Enum => { let v = value.getattr(intern!(py, "value"))?; infer_to_python(v, include, exclude, extra)?.into_py(py) @@ -257,8 +251,8 @@ pub(crate) fn infer_to_python_known( } new_dict.into_py(py) } - ObType::PydanticSerializable => serialize_with_serializer(value, true)?, - ObType::Dataclass => serialize_with_serializer(value, false)?, + ObType::PydanticSerializable => serialize_with_serializer()?, + ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?, ObType::Generator => { let iter = super::type_serializers::generator::SerializationIterator::new( value.downcast()?, @@ -406,36 +400,6 @@ pub(crate) fn infer_serialize_known( }}; } - macro_rules! 
serialize_with_serializer { - ($py_serializable:expr, $is_model:expr) => {{ - let py = $py_serializable.py(); - if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) { - if let Ok(extracted_serializer) = py_serializer.extract::() { - let extra = extracted_serializer.build_extra( - py, - extra.mode, - extra.by_alias, - extra.warnings, - extra.exclude_unset, - extra.exclude_defaults, - extra.exclude_none, - extra.round_trip, - extra.rec_guard, - extra.serialize_unknown, - extra.fallback, - ); - let pydantic_serializer = - PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); - return pydantic_serializer.serialize(serializer); - } - } - // Fallback to dict serialization if `__pydantic_serializer__` is not set. - // This currently only affects non-pydantic dataclasses. - let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?; - serialize_dict!(dict) - }}; - } - let ser_result = match ob_type { ObType::None => serializer.serialize_none(), ObType::Int | ObType::IntSubclass => serialize!(i64), @@ -490,8 +454,30 @@ pub(crate) fn infer_serialize_known( let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?; serializer.serialize_str(&py_url.__str__()) } - ObType::Dataclass => serialize_with_serializer!(value, false), - ObType::PydanticSerializable => serialize_with_serializer!(value, true), + ObType::PydanticSerializable => { + let py = value.py(); + let py_serializer = value + .getattr(intern!(py, "__pydantic_serializer__")) + .map_err(py_err_se_err)?; + let extracted_serializer: SchemaSerializer = py_serializer.extract().map_err(py_err_se_err)?; + let extra = extracted_serializer.build_extra( + py, + extra.mode, + extra.by_alias, + extra.warnings, + extra.exclude_unset, + extra.exclude_defaults, + extra.exclude_none, + extra.round_trip, + extra.rec_guard, + extra.serialize_unknown, + extra.fallback, + ); + let pydantic_serializer = + PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra); + pydantic_serializer.serialize(serializer) + } + ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?), ObType::Enum => { let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?; infer_serialize(v, serializer, include, exclude, extra) diff --git a/src/serializers/shared.rs b/src/serializers/shared.rs index 8279dc415..c33cc61e8 100644 --- a/src/serializers/shared.rs +++ b/src/serializers/shared.rs @@ -2,8 +2,9 @@ use std::borrow::Cow; use std::fmt::Debug; use pyo3::exceptions::PyTypeError; +use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; -use pyo3::types::{PyDict, PySet}; +use pyo3::types::{PyDict, PyString}; use pyo3::{intern, PyTraverseError, PyVisit}; use enum_dispatch::enum_dispatch; @@ -95,7 +96,6 @@ combined_serializer! { super::type_serializers::other::CallableBuilder; super::type_serializers::definitions::DefinitionsSerializerBuilder; super::type_serializers::dataclass::DataclassArgsBuilder; - super::type_serializers::dataclass::DataclassBuilder; super::type_serializers::function::FunctionBeforeSerializerBuilder; super::type_serializers::function::FunctionAfterSerializerBuilder; super::type_serializers::function::FunctionPlainSerializerBuilder; @@ -123,6 +123,7 @@ combined_serializer! 
{ Generator: super::type_serializers::generator::GeneratorSerializer; Dict: super::type_serializers::dict::DictSerializer; Model: super::type_serializers::model::ModelSerializer; + Dataclass: super::type_serializers::dataclass::DataclassSerializer; Url: super::type_serializers::url::UrlSerializer; MultiHostUrl: super::type_serializers::url::MultiHostUrlSerializer; Any: super::type_serializers::any::AnySerializer; @@ -327,21 +328,29 @@ pub(crate) fn to_json_bytes( Ok(bytes) } -pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> { - let py = value.py(); - let attr = value.getattr(intern!(py, "__dict__"))?; - let attrs: &PyDict = attr.downcast()?; - if is_model && extra.exclude_unset { - let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?; - - let new_attrs = attrs.copy()?; - for key in new_attrs.keys() { - if !fields_set.contains(key)? { - new_attrs.del_item(key)?; - } +static DC_FIELD_MARKER: GILOnceCell = GILOnceCell::new(); + +/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)` +pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> { + let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || { + let field_ = py.import("dataclasses")?.getattr("_FIELD")?; + Ok::(field_.into_py(py)) + })?; + Ok(field_type_marker_obj.as_ref(py)) +} + +pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> { + let py = dc.py(); + let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?; + let dict = PyDict::new(py); + + let field_type_marker = get_field_marker(py)?; + for (field_name, field) in dc_fields.iter() { + let field_type = field.getattr(intern!(py, "_field_type"))?; + if field_type.is(field_type_marker) { + let field_name: &PyString = field_name.downcast()?; + dict.set_item(field_name, dc.getattr(field_name)?)?; } - Ok(new_attrs) - } else { - Ok(attrs) } + Ok(dict) } diff --git a/src/serializers/type_serializers/dataclass.rs b/src/serializers/type_serializers/dataclass.rs index e0463d4b2..8626e202e 100644 --- a/src/serializers/type_serializers/dataclass.rs +++ b/src/serializers/type_serializers/dataclass.rs @@ -1,14 +1,18 @@ -use pyo3::intern; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyList, PyString}; +use pyo3::types::{PyDict, PyList, PyString, PyType}; +use pyo3::{intern, PyTraverseError, PyVisit}; +use std::borrow::Cow; use ahash::AHashMap; use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict}; use crate::definitions::DefinitionsBuilder; -use super::model::ModelSerializer; -use super::{BuildSerializer, CombinedSerializer, ComputedFields, FieldsMode, GeneralFieldsSerializer, SerField}; +use super::{ + infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, + CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField, + TypeSerializer, +}; pub struct DataclassArgsBuilder; @@ -54,16 +58,131 @@ impl BuildSerializer for DataclassArgsBuilder { } } -pub struct DataclassBuilder; +#[derive(Debug, Clone)] +pub struct DataclassSerializer { + class: Py, + serializer: Box, + fields: Vec>, + name: String, +} -impl BuildSerializer for DataclassBuilder { +impl BuildSerializer for DataclassSerializer { const EXPECTED_TYPE: &'static str = "dataclass"; fn build( schema: &PyDict, - config: Option<&PyDict>, + _config: Option<&PyDict>, definitions: &mut DefinitionsBuilder, ) -> 
PyResult { - ModelSerializer::build(schema, config, definitions) + let py = schema.py(); + + // models ignore the parent config and always use the config from this model + let config = schema.get_as(intern!(py, "config"))?; + + let class: &PyType = schema.get_as_req(intern!(py, "cls"))?; + let sub_schema: &PyDict = schema.get_as_req(intern!(py, "schema"))?; + let serializer = Box::new(CombinedSerializer::build(sub_schema, config, definitions)?); + + let fields = schema + .get_as_req::<&PyList>(intern!(py, "fields"))? + .iter() + .map(|s| Ok(s.downcast::()?.into_py(py))) + .collect::>>()?; + + Ok(Self { + class: class.into(), + serializer, + fields, + name: class.getattr(intern!(py, "__name__"))?.extract()?, + } + .into()) + } +} + +impl DataclassSerializer { + fn allow_value(&self, value: &PyAny, extra: &Extra) -> PyResult { + match extra.check { + SerCheck::Strict => Ok(value.get_type().is(self.class.as_ref(value.py()))), + SerCheck::Lax => value.is_instance(self.class.as_ref(value.py())), + SerCheck::None => value.hasattr(intern!(value.py(), "__dataclass_fields__")), + } + } + + fn get_inner_value<'py>(&self, value: &'py PyAny) -> PyResult<&'py PyAny> { + let py = value.py(); + let dict = PyDict::new(py); + + for field_name in &self.fields { + let field_name = field_name.as_ref(py); + dict.set_item(field_name, value.getattr(field_name)?)?; + } + Ok(dict) + } +} + +impl TypeSerializer for DataclassSerializer { + fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { + visit.call(&self.class)?; + self.serializer.py_gc_traverse(visit)?; + Ok(()) + } + + fn to_python( + &self, + value: &PyAny, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> PyResult { + let extra = Extra { + model: Some(value), + ..*extra + }; + if self.allow_value(value, &extra)? { + let inner_value = self.get_inner_value(value)?; + self.serializer.to_python(inner_value, include, exclude, &extra) + } else { + extra.warnings.on_fallback_py(self.get_name(), value, &extra)?; + infer_to_python(value, include, exclude, &extra) + } + } + + fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { + if self.allow_value(key, extra)? { + infer_json_key_known(&ObType::Dataclass, key, extra) + } else { + extra.warnings.on_fallback_py(&self.name, key, extra)?; + infer_json_key(key, extra) + } + } + + fn serde_serialize( + &self, + value: &PyAny, + serializer: S, + include: Option<&PyAny>, + exclude: Option<&PyAny>, + extra: &Extra, + ) -> Result { + let extra = Extra { + model: Some(value), + ..*extra + }; + if self.allow_value(value, &extra).map_err(py_err_se_err)? 
{ + let inner_value = self.get_inner_value(value).map_err(py_err_se_err)?; + self.serializer + .serde_serialize(inner_value, serializer, include, exclude, &extra) + } else { + extra.warnings.on_fallback_ser::(self.get_name(), value, &extra)?; + infer_serialize(value, serializer, include, exclude, &extra) + } + } + + fn get_name(&self) -> &str { + &self.name + } + + fn retry_with_lax_check(&self) -> bool { + true } } diff --git a/src/serializers/type_serializers/mod.rs b/src/serializers/type_serializers/mod.rs index 75451076a..fde74742c 100644 --- a/src/serializers/type_serializers/mod.rs +++ b/src/serializers/type_serializers/mod.rs @@ -35,6 +35,4 @@ pub(self) use super::infer::{ infer_to_python_known, }; pub(self) use super::ob_type::{IsType, ObType}; -pub(self) use super::shared::{ - object_to_dict, to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer, -}; +pub(self) use super::shared::{to_json_bytes, BuildSerializer, CombinedSerializer, PydanticSerializer, TypeSerializer}; diff --git a/src/serializers/type_serializers/model.rs b/src/serializers/type_serializers/model.rs index 1c8df5c6c..6107ff2b2 100644 --- a/src/serializers/type_serializers/model.rs +++ b/src/serializers/type_serializers/model.rs @@ -1,7 +1,7 @@ use std::borrow::Cow; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString, PyType}; +use pyo3::types::{PyDict, PySet, PyString, PyType}; use pyo3::{intern, PyTraverseError, PyVisit}; use ahash::AHashMap; @@ -10,9 +10,9 @@ use crate::build_tools::{py_error_type, ExtraBehavior, SchemaDict}; use crate::definitions::DefinitionsBuilder; use super::{ - infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, object_to_dict, py_err_se_err, - BuildSerializer, CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, - SerField, TypeSerializer, + infer_json_key, infer_json_key_known, infer_serialize, infer_to_python, py_err_se_err, BuildSerializer, + CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField, + TypeSerializer, }; const ROOT_FIELD: &str = "root"; @@ -116,16 +116,28 @@ impl ModelSerializer { } } - fn get_inner_value<'py>(&self, value: &'py PyAny, extra: &Extra) -> PyResult<&'py PyAny> { - let py = value.py(); - let dict = object_to_dict(value, true, extra)?; + fn get_inner_value<'py>(&self, model: &'py PyAny, extra: &Extra) -> PyResult<&'py PyAny> { + let py = model.py(); + let mut attrs: &PyDict = model.getattr(intern!(py, "__dict__"))?.downcast()?; + + if extra.exclude_unset { + let fields_set: &PySet = model.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?; + + let new_attrs = attrs.copy()?; + for key in new_attrs.keys() { + if !fields_set.contains(key)? 
{ + new_attrs.del_item(key)?; + } + } + attrs = new_attrs; + } if self.has_extra { - let model_extra = value.getattr(intern!(py, "__pydantic_extra__"))?; - let py_tuple = (dict, model_extra).to_object(py); + let model_extra = model.getattr(intern!(py, "__pydantic_extra__"))?; + let py_tuple = (attrs, model_extra).to_object(py); Ok(py_tuple.into_ref(py)) } else { - Ok(dict) + Ok(attrs) } } } diff --git a/src/validators/dataclass.rs b/src/validators/dataclass.rs index 1eaaf5f0b..781cdb102 100644 --- a/src/validators/dataclass.rs +++ b/src/validators/dataclass.rs @@ -7,7 +7,6 @@ use ahash::AHashSet; use crate::build_tools::{is_strict, py_err, schema_or_config_same, ExtraBehavior, SchemaDict}; use crate::errors::{ErrorType, ValError, ValLineError, ValResult}; -use crate::input::InputType; use crate::input::{GenericArguments, Input}; use crate::lookup_key::LookupKey; use crate::recursion_guard::RecursionGuard; @@ -411,10 +410,12 @@ pub struct DataclassValidator { strict: bool, validator: Box, class: Py, + fields: Vec>, post_init: Option>, revalidate: Revalidate, name: String, frozen: bool, + slots: bool, } impl BuildValidator for DataclassValidator { @@ -441,10 +442,17 @@ impl BuildValidator for DataclassValidator { None }; + let fields = schema + .get_as_req::<&PyList>(intern!(py, "fields"))? + .iter() + .map(|s| Ok(s.downcast::()?.into_py(py))) + .collect::>>()?; + Ok(Self { strict: is_strict(schema, config)?, validator: Box::new(validator), class: class.into(), + fields, post_init, revalidate: Revalidate::from_str(schema_or_config_same( schema, @@ -453,6 +461,7 @@ impl BuildValidator for DataclassValidator { )?)?, name, frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false), + slots: schema.get_as(intern!(py, "slots"))?.unwrap_or(false), } .into()) } @@ -474,12 +483,12 @@ impl Validator for DataclassValidator { // same logic as on models let class = self.class.as_ref(py); - if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? 
{ - if self.revalidate.should_revalidate(input, class) { - let input = input.input_get_attr(intern!(py, "__dict__")).unwrap()?; + if let Some(py_input) = input.input_is_instance(class) { + if self.revalidate.should_revalidate(py_input, class) { + let input_dict: &PyAny = self.dataclass_to_dict(py, py_input)?; let val_output = self .validator - .validate(py, input, extra, definitions, recursion_guard)?; + .validate(py, input_dict, extra, definitions, recursion_guard)?; let dc = create_class(self.class.as_ref(py))?; self.set_dict_call(py, dc.as_ref(py), val_output, input)?; Ok(dc) @@ -516,14 +525,11 @@ impl Validator for DataclassValidator { if self.frozen { return Err(ValError::new(ErrorType::FrozenInstance, field_value)); } - let dict_py_str = intern!(py, "__dict__"); - let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?; - let new_dict = dict.copy()?; + let new_dict = self.dataclass_to_dict(py, obj)?; + new_dict.set_item(field_name, field_value)?; - // Discard the second return value, which is `init_only_args` but is always - // None anyway for validate_assignment; see validate_assignment in DataclassArgsValidator let val_assignment_result = self.validator.validate_assignment( py, new_dict, @@ -536,7 +542,14 @@ impl Validator for DataclassValidator { let (dc_dict, _): (&PyDict, PyObject) = val_assignment_result.extract(py)?; - force_setattr(py, obj, dict_py_str, dc_dict)?; + if self.slots { + let value = dc_dict + .get_item(field_name) + .ok_or_else(|| PyKeyError::new_err(field_name.to_string()))?; + force_setattr(py, obj, field_name, value)?; + } else { + force_setattr(py, obj, intern!(py, "__dict__"), dc_dict)?; + } Ok(obj.to_object(py)) } @@ -587,6 +600,17 @@ impl DataclassValidator { Ok(self_instance.into_py(py)) } + + fn dataclass_to_dict<'py>(&self, py: Python<'py>, dc: &'py PyAny) -> PyResult<&'py PyDict> { + let dict = PyDict::new(py); + + for field_name in &self.fields { + let field_name = field_name.as_ref(py); + dict.set_item(field_name, dc.getattr(field_name)?)?; + } + Ok(dict) + } + fn set_dict_call<'s, 'data>( &'s self, py: Python<'data>, @@ -595,7 +619,14 @@ impl DataclassValidator { input: &'data impl Input<'data>, ) -> ValResult<'data, ()> { let (dc_dict, post_init_kwargs): (&PyAny, &PyAny) = val_output.extract(py)?; - force_setattr(py, dc, intern!(py, "__dict__"), dc_dict)?; + if self.slots { + let dc_dict: &PyDict = dc_dict.downcast()?; + for (key, value) in dc_dict.iter() { + force_setattr(py, dc, key, value)?; + } + } else { + force_setattr(py, dc, intern!(py, "__dict__"), dc_dict)?; + } if let Some(ref post_init) = self.post_init { let post_init = post_init.as_ref(py); diff --git a/src/validators/model.rs b/src/validators/model.rs index 804c4af9f..e19d5ee6a 100644 --- a/src/validators/model.rs +++ b/src/validators/model.rs @@ -8,7 +8,7 @@ use pyo3::{ffi, intern}; use crate::build_tools::{py_err, schema_or_config_same, SchemaDict}; use crate::errors::{ErrorType, ValError, ValResult}; -use crate::input::{py_error_on_minusone, Input, InputType}; +use crate::input::{py_error_on_minusone, Input}; use crate::recursion_guard::RecursionGuard; use super::function::convert_err; @@ -37,11 +37,11 @@ impl Revalidate { } } - pub fn should_revalidate<'d>(&self, input: &impl Input<'d>, class: &PyType) -> bool { + pub fn should_revalidate(&self, input: &PyAny, class: &PyType) -> bool { match self { Revalidate::Always => true, Revalidate::Never => false, - Revalidate::SubclassInstances => !input.is_exact_instance(class), + Revalidate::SubclassInstances => 
!input.get_type().is(class), } } } @@ -125,16 +125,16 @@ impl Validator for ModelValidator { // if the input is an instance of the class, we "revalidate" it - e.g. we extract and reuse `__pydantic_fields_set__` // but use from attributes to create a new instance of the model field type let class = self.class.as_ref(py); - if matches!(extra.mode, InputType::Python) && input.to_object(py).as_ref(py).is_instance(class)? { - if self.revalidate.should_revalidate(input, class) { + if let Some(py_input) = input.input_is_instance(class) { + if self.revalidate.should_revalidate(py_input, class) { if self.root_model { - let inner_input: &PyAny = input.input_get_attr(intern!(py, ROOT_FIELD)).unwrap()?; + let inner_input = py_input.getattr(intern!(py, ROOT_FIELD))?; self.validate_construct(py, inner_input, None, extra, definitions, recursion_guard) } else { - let fields_set = input.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap()?; + let fields_set = py_input.getattr(intern!(py, DUNDER_FIELDS_SET_KEY))?; // get dict here so from_attributes logic doesn't apply - let dict = input.input_get_attr(intern!(py, DUNDER_DICT)).unwrap()?; - let model_extra = input.input_get_attr(intern!(py, DUNDER_MODEL_EXTRA_KEY)).unwrap()?; + let dict = py_input.getattr(intern!(py, DUNDER_DICT))?; + let model_extra = py_input.getattr(intern!(py, DUNDER_MODEL_EXTRA_KEY))?; let inner_input: &PyAny = if model_extra.is_none() { dict @@ -211,7 +211,7 @@ impl Validator for ModelValidator { let (output, _, updated_fields_set): (&PyDict, &PyAny, &PySet) = output.extract(py)?; - if let Ok(fields_set) = model.input_get_attr(intern!(py, DUNDER_FIELDS_SET_KEY)).unwrap() { + if let Ok(fields_set) = model.getattr(intern!(py, DUNDER_FIELDS_SET_KEY)) { let fields_set: &PySet = fields_set.downcast()?; for field_name in updated_fields_set { fields_set.add(field_name)?; diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 30350f654..377713111 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -36,11 +36,17 @@ class MyDataclass: class MyModel: - __pydantic_serializer__ = 42 - def __init__(self, **kwargs): + fields = {} for key, value in kwargs.items(): setattr(self, key, value) + fields[key] = core_schema.model_field(core_schema.any_schema()) + self.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema(MyModel, core_schema.model_fields_schema(fields)) + ) + + def __repr__(self): + return f'MyModel({self.__dict__})' @pytest.mark.parametrize('value', [None, 1, 1.0, True, 'foo', [1, 2, 3], {'a': 1, 'b': 2}]) @@ -58,6 +64,7 @@ def test_any_json_round_trip(any_serializer, value): ({1, 2, 3}, {1, 2, 3}, IsList(1, 2, 3, check_order=False)), ({1, '2', b'3'}, {1, '2', b'3'}, IsList(1, '2', '3', check_order=False)), ], + ids=repr, ) def test_any_python(any_serializer, input_value, expected_plain, expected_json_obj): assert any_serializer.to_python(input_value) == expected_plain @@ -247,8 +254,13 @@ class FieldsSetModel: __slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__' def __init__(self, **kwargs): + fields = {} for key, value in kwargs.items(): setattr(self, key, value) + fields[key] = core_schema.model_field(core_schema.any_schema()) + self.__pydantic_serializer__ = SchemaSerializer( + core_schema.model_schema(MyModel, core_schema.model_fields_schema(fields)) + ) def test_exclude_unset(any_serializer): @@ -427,6 +439,7 @@ class Foo: core_schema.dataclass_args_schema( 'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())] 
), + ['a'], ) Foo.__pydantic_serializer__ = SchemaSerializer(schema) @@ -444,20 +457,18 @@ class Foo: def test_any_model(): + @dataclasses.dataclass class Foo: a: str b: bytes - def __init__(self, a: str, b: bytes): - self.a = a - self.b = b - # Build a schema that does not include the field 'b', to test that it is not serialized schema = core_schema.dataclass_schema( Foo, core_schema.dataclass_args_schema( 'Foo', [core_schema.dataclass_field(name='a', schema=core_schema.str_schema())] ), + ['a'], ) Foo.__pydantic_validator__ = SchemaValidator(schema) Foo.__pydantic_serializer__ = SchemaSerializer(schema) @@ -473,3 +484,79 @@ def __init__(self, a: str, b: bytes): assert j == b'{"a":"hello"}' assert s.to_python(Foo(a='hello', b=b'more'), exclude={'a'}) == IsStrictDict() + assert s.to_json(Foo(a='hello', b=b'more'), exclude={'a'}) == b'{}' + + +def test_dataclass_classvar(any_serializer): + @dataclasses.dataclass + class Foo: + a: int + b: str + c: ClassVar[int] = 1 + + foo = Foo(1, 'a') + assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + assert any_serializer.to_json(foo) == b'{"a":1,"b":"a"}' + + @dataclasses.dataclass + class Foo2(Foo): + pass + + foo2 = Foo2(2, 'b') + assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') + assert any_serializer.to_json(foo2) == b'{"a":2,"b":"b"}' + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_dataclass_slots(any_serializer): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + + foo = Foo(1, 'a') + assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + assert any_serializer.to_json(foo) == b'{"a":1,"b":"a"}' + + @dataclasses.dataclass(slots=True) + class Foo2(Foo): + pass + + foo2 = Foo2(2, 'b') + assert any_serializer.to_python(foo2) == IsStrictDict(a=2, b='b') + assert any_serializer.to_json(foo2) == b'{"a":2,"b":"b"}' + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_dataclass_slots_init_vars(any_serializer): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + c: dataclasses.InitVar[int] + d: ClassVar[int] = 42 + + foo = Foo(1, 'a', 42) + assert any_serializer.to_python(foo) == IsStrictDict(a=1, b='a') + assert any_serializer.to_json(foo) == b'{"a":1,"b":"a"}' + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_slots_mixed(any_serializer): + @dataclasses.dataclass(slots=True) + class Model: + x: int + y: dataclasses.InitVar[str] + z: ClassVar[str] = 'z-classvar' + + @dataclasses.dataclass + class SubModel(Model): + x2: int + y2: dataclasses.InitVar[str] + z2: ClassVar[str] = 'z2-classvar' + + dc = SubModel(x=1, y='a', x2=2, y2='b') + assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} + + assert any_serializer.to_python(dc) == {'x': 1, 'x2': 2} + assert any_serializer.to_json(dc) == b'{"x":1,"x2":2}' diff --git a/tests/serializers/test_dataclasses.py b/tests/serializers/test_dataclasses.py index 980510b35..57b20c4b2 100644 --- a/tests/serializers/test_dataclasses.py +++ b/tests/serializers/test_dataclasses.py @@ -1,6 +1,10 @@ import dataclasses import json import platform +import sys +from typing import ClassVar + +import pytest from pydantic_core import SchemaSerializer, core_schema @@ -28,6 +32,7 @@ def test_dataclass(): core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema()), ], ), + ['a', 'b'], ) s = 
SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello', b=b'more') @@ -53,6 +58,7 @@ def test_serialization_exclude(): core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_exclude=True), ], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == {'a': 'hello'} @@ -75,6 +81,7 @@ def test_serialization_alias(): core_schema.dataclass_field(name='b', schema=core_schema.bytes_schema(), serialization_alias='BAR'), ], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(Foo(a='hello', b=b'more')) == IsStrictDict(a='hello', BAR=b'more') @@ -107,6 +114,7 @@ def c(self) -> str: ], computed_fields=[core_schema.computed_field('c', core_schema.str_schema())], ), + ['a', 'b'], ) s = SchemaSerializer(schema) assert s.to_python(FooProp(a='hello', b=b'more')) == IsStrictDict(a='hello', b=b'more', c='hello more') @@ -120,3 +128,39 @@ def c(self) -> str: assert s.to_python(FooProp(a='hello', b=b'more'), exclude={'b'}) == IsStrictDict(a='hello', c='hello more') assert s.to_json(FooProp(a='hello', b=b'more'), include={'a'}) == b'{"a":"hello"}' + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python > 3.10') +def test_slots_mixed(): + @dataclasses.dataclass(slots=True) + class Model: + x: int + y: dataclasses.InitVar[str] + z: ClassVar[str] = 'z-classvar' + + @dataclasses.dataclass + class SubModel(Model): + x2: int + y2: dataclasses.InitVar[str] + z2: ClassVar[str] = 'z2-classvar' + + schema = core_schema.dataclass_schema( + SubModel, + core_schema.dataclass_args_schema( + 'SubModel', + [ + core_schema.dataclass_field(name='x', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y', init_only=True, schema=core_schema.str_schema()), + core_schema.dataclass_field(name='x2', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()), + ], + ), + ['x', 'x2'], + slots=True, + ) + dc = SubModel(x=1, y='a', x2=2, y2='b') + assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} + + s = SchemaSerializer(schema) + assert s.to_python(dc) == {'x': 1, 'x2': 2} + assert s.to_json(dc) == b'{"x":1,"x2":2}' diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index ee6893ab0..86e8bf28c 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -1,3 +1,4 @@ +import dataclasses import re from datetime import date from typing import Any @@ -19,6 +20,12 @@ class MyModel: __slots__ = '__dict__', '__pydantic_extra__', '__pydantic_fields_set__' +@dataclasses.dataclass +class MyDataclass: + x: int + y: str + + def ids_function(val): if callable(val): return val.__name__ @@ -248,9 +255,13 @@ def args(*args, **kwargs): ), ( core_schema.dataclass_schema, - # MyModel should be a dataclass, but I'm being lazy here - args(MyModel, {'type': 'int'}), - {'type': 'dataclass', 'schema': {'type': 'int'}, 'cls': MyModel}, + args(MyDataclass, {'type': 'int'}, ['foobar']), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass}, + ), + ( + core_schema.dataclass_schema, + args(MyDataclass, {'type': 'int'}, ['foobar'], slots=True), + {'type': 'dataclass', 'schema': {'type': 'int'}, 'fields': ['foobar'], 'cls': MyDataclass, 'slots': True}, ), ] diff --git a/tests/validators/test_dataclasses.py b/tests/validators/test_dataclasses.py index 87a2578a5..e9a8d73d8 100644 --- a/tests/validators/test_dataclasses.py 
+++ b/tests/validators/test_dataclasses.py @@ -1,6 +1,7 @@ import dataclasses import re -from typing import Any, Dict, List, Optional, Union +import sys +from typing import Any, ClassVar, Dict, List, Optional, Union import pytest from dirty_equals import IsListOrTuple, IsStr @@ -201,6 +202,7 @@ def test_dataclass(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -248,7 +250,9 @@ class DuplicateDifferent: ('always', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), ('always', FooDataclass(a='hello', b=True), {'a': 'hello', 'b': True}), ('always', FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), - ('always', FooDataclassMore(a='hello', b=True, c='more'), Err(r'c\s+Unexpected keyword argument')), + # no error because we only look for fields in schema['fields'] + ('always', FooDataclassMore(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ('always', FooDataclassSame(a='hello', b='wrong'), Err(r'b\s+Input should be a valid boolean,')), ('always', DuplicateDifferent(a='hello', b=True), Err('should be a dictionary or an instance of FooDataclass')), # revalidate_instances='subclass-instances' ('subclass-instances', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), @@ -256,7 +260,9 @@ class DuplicateDifferent: ('subclass-instances', FooDataclass(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}), ('subclass-instances', FooDataclassSame(a='hello', b=True), {'a': 'hello', 'b': True}), ('subclass-instances', FooDataclassSame(a=b'hello', b='true'), {'a': 'hello', 'b': True}), - ('subclass-instances', FooDataclassMore(a='hello', b=True, c='more'), Err('Unexpected keyword argument')), + # no error because we only look for fields in schema['fields'] + ('subclass-instances', FooDataclassMore(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSame(a='hello', b='wrong'), Err(r'b\s+Input should be a valid boolean,')), ('subclass-instances', DuplicateDifferent(a='hello', b=True), Err('dictionary or an instance of FooDataclass')), # revalidate_instances='never' ('never', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), @@ -278,6 +284,7 @@ def test_dataclass_subclass(revalidate_instances, input_value, expected): ], extra_behavior='forbid', ), + ['a', 'b'], revalidate_instances=revalidate_instances, ) v = SchemaValidator(schema) @@ -306,6 +313,7 @@ def test_dataclass_subclass_strict_never_revalidate(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], revalidate_instances='never', strict=True, ) @@ -332,6 +340,7 @@ def test_dataclass_subclass_subclass_revalidate(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], revalidate_instances='subclass-instances', strict=True, ) @@ -364,6 +373,7 @@ def __post_init__(self): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], post_init=True, ) @@ -397,6 +407,7 @@ def __post_init__(self, c: int): ], collect_init_only=True, ), + ['a', 'b'], post_init=True, ) @@ -432,6 +443,7 @@ def __post_init__(self, *args): ], collect_init_only=True, ), + ['a', 'b'], post_init=True, ) @@ -462,6 +474,7 @@ def test_dataclass_exact_validation(revalidate_instances, input_value, expected) core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), ], ), + ['a', 'b'], revalidate_instances=revalidate_instances, ) @@ -495,6 +508,7 @@ def validate_b(cls, v: str, info: 
core_schema.FieldValidationInfo) -> str: ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -526,6 +540,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> str: ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -558,6 +573,7 @@ def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -593,6 +609,7 @@ def validate_b( ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -626,6 +643,7 @@ def validate_b( ), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -651,6 +669,7 @@ def __init__(self, *args, **kwargs): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), kw_only=False), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -674,6 +693,7 @@ class Foo: core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), validation_alias=['bAlias', 0]), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -715,6 +735,7 @@ class Foo: core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), validation_alias=['bAlias', 0]), ], ), + ['a', 'b'], ) v = SchemaValidator(schema, {'loc_by_alias': False}) @@ -767,6 +788,7 @@ def __post_init__(self, c): ], collect_init_only=True, ), + ['a', 'b', 'c'], post_init=True, ) v = SchemaValidator(schema) @@ -787,6 +809,7 @@ def test_dataclass_validate_assignment(): core_schema.dataclass_field(name='b', schema=core_schema.bool_schema(), kw_only=False), ], ), + ['a', 'b'], ) v = SchemaValidator(schema) @@ -816,7 +839,7 @@ def test_dataclass_validate_assignment(): assert not hasattr(foo, 'c') # wrong arguments - with pytest.raises(AttributeError, match="'str' object has no attribute '__dict__'"): + with pytest.raises(AttributeError, match="'str' object has no attribute 'a'"): v.validate_assignment('field_a', 'c', 123) @@ -846,6 +869,7 @@ def func(x, info): core_schema.dataclass_field('field_c', core_schema.int_schema()), ], ), + ['field_a', 'field_b', 'field_c'], ) ) @@ -873,6 +897,7 @@ class MyModel: core_schema.dataclass_schema( MyModel, core_schema.dataclass_args_schema('MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())]), + ['f'], frozen=True, ) ) @@ -900,6 +925,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema(), frozen=True)] ), + ['f'], ) ) @@ -936,6 +962,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())], **schema_extra_behavior_kw ), + ['f'], ), config=config, ) @@ -983,6 +1010,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())], **schema_extra_behavior_kw ), + ['f'], ), config=config, ) @@ -1028,6 +1056,7 @@ class MyModel: core_schema.dataclass_args_schema( 'MyModel', [core_schema.dataclass_field('f', core_schema.str_schema())], **schema_extra_behavior_kw ), + ['f'], ), config=config, ) @@ -1062,6 +1091,7 @@ class Model: 'Model', [core_schema.dataclass_field('number', core_schema.int_schema())] ), ), + ['number'], ) v = SchemaValidator(cs) @@ -1093,6 +1123,7 @@ class Model: 'Model', [core_schema.dataclass_field('number', core_schema.int_schema())] ), ), + ['number'], ) v = SchemaValidator(cs) @@ -1127,6 +1158,7 @@ class Model: 'Model', [core_schema.dataclass_field('number', core_schema.int_schema())] ), ), + ['number'], ) v = SchemaValidator(cs) @@ -1167,6 +1199,7 @@ def test_custom_dataclass_names(): core_schema.dataclass_field(name='b', 
schema=core_schema.bool_schema()), ], ), + ['a', 'b'], cls_name='FooDataclass[cls_name]', ), core_schema.none_schema(), @@ -1175,6 +1208,7 @@ def test_custom_dataclass_names(): ) ], ), + ['foo'], ) v = SchemaValidator(schema) @@ -1190,3 +1224,237 @@ def test_custom_dataclass_names(): }, {'input': 123, 'loc': ('foo', 'none'), 'msg': 'Input should be None', 'type': 'none_required'}, ] + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_slots() -> None: + @dataclasses.dataclass(slots=True) + class Model: + x: int + + schema = core_schema.dataclass_schema( + Model, + core_schema.dataclass_args_schema( + 'Model', [core_schema.dataclass_field(name='x', schema=core_schema.int_schema())] + ), + ['x'], + slots=True, + ) + + val = SchemaValidator(schema) + m: Model + + m = val.validate_python({'x': 123}) + assert m == Model(x=123) + + with pytest.raises(ValidationError): + val.validate_python({'x': 'abc'}) + + val.validate_assignment(m, 'x', 456) + assert m.x == 456 + + with pytest.raises(ValidationError): + val.validate_assignment(m, 'x', 'abc') + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_dataclass_slots_field_before_validator(): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + + @classmethod + def validate_b(cls, v: bytes, info: core_schema.FieldValidationInfo) -> bytes: + assert v == b'hello' + assert info.field_name == 'b' + assert info.data == {'a': 1} + return b'hello world!' + + schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), + core_schema.dataclass_field( + name='b', + schema=core_schema.field_before_validator_function(Foo.validate_b, core_schema.str_schema()), + ), + ], + ), + ['a', 'b'], + slots=True, + ) + + v = SchemaValidator(schema) + foo = v.validate_python({'a': 1, 'b': b'hello'}) + assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_dataclass_slots_field_after_validator(): + @dataclasses.dataclass(slots=True) + class Foo: + a: int + b: str + + @classmethod + def validate_b(cls, v: str, info: core_schema.FieldValidationInfo) -> str: + assert v == 'hello' + assert info.field_name == 'b' + assert info.data == {'a': 1} + return 'hello world!' 
+ + schema = core_schema.dataclass_schema( + Foo, + core_schema.dataclass_args_schema( + 'Foo', + [ + core_schema.dataclass_field(name='a', schema=core_schema.int_schema()), + core_schema.dataclass_field( + name='b', + schema=core_schema.field_after_validator_function(Foo.validate_b, core_schema.str_schema()), + ), + ], + ), + ['a', 'b'], + slots=True, + ) + + v = SchemaValidator(schema) + foo = v.validate_python({'a': 1, 'b': b'hello'}) + assert dataclasses.asdict(foo) == {'a': 1, 'b': 'hello world!'} + + +if sys.version_info < (3, 10): + kwargs = {} +else: + kwargs = {'slots': True} + + +@dataclasses.dataclass(**kwargs) +class FooDataclassSlots: + a: str + b: bool + + +@dataclasses.dataclass(**kwargs) +class FooDataclassSameSlots(FooDataclassSlots): + pass + + +@dataclasses.dataclass(**kwargs) +class FooDataclassMoreSlots(FooDataclassSlots): + c: str + + +@dataclasses.dataclass(**kwargs) +class DuplicateDifferentSlots: + a: str + b: bool + + +@pytest.mark.parametrize( + 'revalidate_instances,input_value,expected', + [ + ('always', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + ('always', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('always', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('always', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ( + 'always', + DuplicateDifferentSlots(a='hello', b=True), + Err('should be a dictionary or an instance of FooDataclass'), + ), + # revalidate_instances='subclass-instances' + ('subclass-instances', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSlots(a=b'hello', b='true'), {'a': b'hello', 'b': 'true'}), + ('subclass-instances', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSameSlots(a=b'hello', b='true'), {'a': 'hello', 'b': True}), + # no error because we don't look for fields unless their in schema['fields'] + ('subclass-instances', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True}), + ('subclass-instances', FooDataclassSameSlots(a=b'hello', b='wrong'), Err('Input should be a valid boolean,')), + ( + 'subclass-instances', + DuplicateDifferentSlots(a='hello', b=True), + Err('dictionary or an instance of FooDataclass'), + ), + # revalidate_instances='never' + ('never', {'a': 'hello', 'b': True}, {'a': 'hello', 'b': True}), + ('never', FooDataclassSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('never', FooDataclassSameSlots(a='hello', b=True), {'a': 'hello', 'b': True}), + ('never', FooDataclassMoreSlots(a='hello', b=True, c='more'), {'a': 'hello', 'b': True, 'c': 'more'}), + ('never', FooDataclassMoreSlots(a='hello', b='wrong', c='more'), {'a': 'hello', 'b': 'wrong', 'c': 'more'}), + ( + 'never', + DuplicateDifferentSlots(a='hello', b=True), + Err('should be a dictionary or an instance of FooDataclass'), + ), + ], +) +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_slots_dataclass_subclass(revalidate_instances, input_value, expected): + schema = core_schema.dataclass_schema( + FooDataclassSlots, + core_schema.dataclass_args_schema( + 'FooDataclass', + [ + core_schema.dataclass_field(name='a', schema=core_schema.str_schema()), + core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()), + ], + extra_behavior='forbid', + ), + ['a', 
'b'], + revalidate_instances=revalidate_instances, + slots=True, + ) + v = SchemaValidator(schema) + + if isinstance(expected, Err): + with pytest.raises(ValidationError, match=expected.message) as exc_info: + print(v.validate_python(input_value)) + + # debug(exc_info.value.errors(include_url=False)) + if expected.errors is not None: + assert exc_info.value.errors(include_url=False) == expected.errors + else: + dc = v.validate_python(input_value) + assert dataclasses.is_dataclass(dc) + assert dataclasses.asdict(dc) == expected + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason='slots are only supported for dataclasses in Python >= 3.10') +def test_slots_mixed(): + @dataclasses.dataclass(slots=True) + class Model: + x: int + y: dataclasses.InitVar[str] + z: ClassVar[str] = 'z-classvar' + + @dataclasses.dataclass + class SubModel(Model): + x2: int + y2: dataclasses.InitVar[str] + z2: ClassVar[str] = 'z2-classvar' + + schema = core_schema.dataclass_schema( + SubModel, + core_schema.dataclass_args_schema( + 'SubModel', + [ + core_schema.dataclass_field(name='x', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y', init_only=True, schema=core_schema.str_schema()), + core_schema.dataclass_field(name='x2', schema=core_schema.int_schema()), + core_schema.dataclass_field(name='y2', init_only=True, schema=core_schema.str_schema()), + ], + ), + ['x'], + slots=True, + ) + v = SchemaValidator(schema) + dc = v.validate_python({'x': 1, 'y': 'a', 'x2': 2, 'y2': 'b'}) + assert dc.x == 1 + assert dc.x2 == 2 + assert dataclasses.asdict(dc) == {'x': 1, 'x2': 2} diff --git a/tests/validators/test_definitions_recursive.py b/tests/validators/test_definitions_recursive.py index 8d09ee5a6..21e050f56 100644 --- a/tests/validators/test_definitions_recursive.py +++ b/tests/validators/test_definitions_recursive.py @@ -889,6 +889,7 @@ class Model: ) ], ), + ['x'], ref='model', ) v = SchemaValidator(schema, config=core_schema.CoreConfig(revalidate_instances='always'))