Skip to content
Merged
12 changes: 11 additions & 1 deletion pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3093,6 +3093,7 @@ class DataclassSchema(TypedDict, total=False):
type: Required[Literal['dataclass']]
cls: Required[Type[Any]]
schema: Required[CoreSchema]
fields: Required[List[str]]
cls_name: str
post_init: bool # default: False
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
Expand All @@ -3101,11 +3102,13 @@ class DataclassSchema(TypedDict, total=False):
ref: str
metadata: Any
serialization: SerSchema
slots: bool


def dataclass_schema(
cls: Type[Any],
schema: CoreSchema,
fields: List[str],
*,
cls_name: str | None = None,
post_init: bool | None = None,
Expand All @@ -3115,14 +3118,17 @@ def dataclass_schema(
metadata: Any = None,
serialization: SerSchema | None = None,
frozen: bool | None = None,
slots: bool | None = None,
) -> DataclassSchema:
"""
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
another schema, not as the root type.

Args:
cls: The dataclass type, used to to perform subclass checks
cls: The dataclass type, used to perform subclass checks
schema: The schema to use for the dataclass fields
fields: Fields of the dataclass, this is used in serialization and in validation during re-validation
and while validating assignment
cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`)
post_init: Whether to call `__post_init__` after validation
revalidate_instances: whether instances of models and dataclasses (including subclass instances)
Expand All @@ -3132,10 +3138,13 @@ def dataclass_schema(
metadata: Any other information you want to include with the schema, not used by pydantic-core
serialization: Custom serialization schema
frozen: Whether the dataclass is frozen
slots: Whether `slots=True` on the dataclass, means each field is assigned independently, rather than
simply setting `__dict__`, default false
"""
return dict_not_none(
type='dataclass',
cls=cls,
fields=fields,
cls_name=cls_name,
schema=schema,
post_init=post_init,
Expand All @@ -3145,6 +3154,7 @@ def dataclass_schema(
metadata=metadata,
serialization=serialization,
frozen=frozen,
slots=slots,
)


Expand Down
9 changes: 2 additions & 7 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;

use pyo3::types::{PyDict, PyString, PyType};
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{InputValue, LocItem, ValResult};
Expand Down Expand Up @@ -40,15 +40,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

fn is_none(&self) -> bool;

#[cfg_attr(has_no_coverage, no_coverage)]
fn input_get_attr(&self, _name: &PyString) -> Option<PyResult<&PyAny>> {
fn input_is_instance(&self, _class: &PyType) -> Option<&PyAny> {
None
}

fn is_exact_instance(&self, _class: &PyType) -> bool {
false
}

fn is_python(&self) -> bool {
false
}
Expand Down
12 changes: 6 additions & 6 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@ impl<'a> Input<'a> for PyAny {
self.is_none()
}

fn input_get_attr(&self, name: &PyString) -> Option<PyResult<&PyAny>> {
Some(self.getattr(name))
}

fn is_exact_instance(&self, class: &PyType) -> bool {
self.get_type().is(class)
fn input_is_instance(&self, class: &PyType) -> Option<&PyAny> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal, but given this returns an Option<&PyAny> rather than a bool the name seems a bit weird to me (I would expect it to return a bool since the name sounds like an assertion). I'd suggest downcast or similar, but also fine keeping as is if you prefer.

if self.is_instance(class).unwrap_or(false) {
Some(self)
} else {
None
}
}

fn is_python(&self) -> bool {
Expand Down
106 changes: 46 additions & 60 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError};
use super::extra::{Extra, SerMode};
use super::filter::AnyFilter;
use super::ob_type::ObType;
use super::shared::object_to_dict;
use super::shared::dataclass_to_dict;

pub(crate) fn infer_to_python(
value: &PyAny,
Expand Down Expand Up @@ -97,29 +97,23 @@ pub(crate) fn infer_to_python_known(
Ok::<PyObject, PyErr>(new_dict.into_py(py))
};

let serialize_with_serializer = |value: &PyAny, is_model: bool| {
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
if let Ok(serializer) = py_serializer.extract::<SchemaSerializer>() {
let extra = serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
return serializer.serializer.to_python(value, include, exclude, &extra);
}
}
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
// This currently only affects non-pydantic dataclasses.
let dict = object_to_dict(value, is_model, extra)?;
serialize_dict(dict)
let serialize_with_serializer = || {
let py_serializer = value.getattr(intern!(py, "__pydantic_serializer__"))?;
let serializer: SchemaSerializer = py_serializer.extract()?;
let extra = serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
serializer.serializer.to_python(value, include, exclude, &extra)
};

let value = match extra.mode {
Expand Down Expand Up @@ -191,8 +185,8 @@ pub(crate) fn infer_to_python_known(
let py_url: PyMultiHostUrl = value.extract()?;
py_url.__str__().into_py(py)
}
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
ObType::Dataclass => serialize_with_serializer(value, false)?,
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
ObType::Enum => {
let v = value.getattr(intern!(py, "value"))?;
infer_to_python(v, include, exclude, extra)?.into_py(py)
Expand Down Expand Up @@ -257,8 +251,8 @@ pub(crate) fn infer_to_python_known(
}
new_dict.into_py(py)
}
ObType::PydanticSerializable => serialize_with_serializer(value, true)?,
ObType::Dataclass => serialize_with_serializer(value, false)?,
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_dict(dataclass_to_dict(value)?)?,
ObType::Generator => {
let iter = super::type_serializers::generator::SerializationIterator::new(
value.downcast()?,
Expand Down Expand Up @@ -406,36 +400,6 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}};
}

macro_rules! serialize_with_serializer {
($py_serializable:expr, $is_model:expr) => {{
let py = $py_serializable.py();
if let Ok(py_serializer) = value.getattr(intern!(py, "__pydantic_serializer__")) {
if let Ok(extracted_serializer) = py_serializer.extract::<SchemaSerializer>() {
let extra = extracted_serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
let pydantic_serializer =
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
return pydantic_serializer.serialize(serializer);
}
}
// Fallback to dict serialization if `__pydantic_serializer__` is not set.
// This currently only affects non-pydantic dataclasses.
let dict = object_to_dict(value, $is_model, extra).map_err(py_err_se_err)?;
serialize_dict!(dict)
}};
}

let ser_result = match ob_type {
ObType::None => serializer.serialize_none(),
ObType::Int | ObType::IntSubclass => serialize!(i64),
Expand Down Expand Up @@ -490,8 +454,30 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
let py_url: PyMultiHostUrl = value.extract().map_err(py_err_se_err)?;
serializer.serialize_str(&py_url.__str__())
}
ObType::Dataclass => serialize_with_serializer!(value, false),
ObType::PydanticSerializable => serialize_with_serializer!(value, true),
ObType::PydanticSerializable => {
let py = value.py();
let py_serializer = value
.getattr(intern!(py, "__pydantic_serializer__"))
.map_err(py_err_se_err)?;
let extracted_serializer: SchemaSerializer = py_serializer.extract().map_err(py_err_se_err)?;
let extra = extracted_serializer.build_extra(
py,
extra.mode,
extra.by_alias,
extra.warnings,
extra.exclude_unset,
extra.exclude_defaults,
extra.exclude_none,
extra.round_trip,
extra.rec_guard,
extra.serialize_unknown,
extra.fallback,
);
let pydantic_serializer =
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
pydantic_serializer.serialize(serializer)
}
ObType::Dataclass => serialize_dict!(dataclass_to_dict(value).map_err(py_err_se_err)?),
ObType::Enum => {
let v = value.getattr(intern!(value.py(), "value")).map_err(py_err_se_err)?;
infer_serialize(v, serializer, include, exclude, extra)
Expand Down
43 changes: 26 additions & 17 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::borrow::Cow;
use std::fmt::Debug;

use pyo3::exceptions::PyTypeError;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PySet};
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, PyTraverseError, PyVisit};

use enum_dispatch::enum_dispatch;
Expand Down Expand Up @@ -95,7 +96,6 @@ combined_serializer! {
super::type_serializers::other::CallableBuilder;
super::type_serializers::definitions::DefinitionsSerializerBuilder;
super::type_serializers::dataclass::DataclassArgsBuilder;
super::type_serializers::dataclass::DataclassBuilder;
super::type_serializers::function::FunctionBeforeSerializerBuilder;
super::type_serializers::function::FunctionAfterSerializerBuilder;
super::type_serializers::function::FunctionPlainSerializerBuilder;
Expand Down Expand Up @@ -123,6 +123,7 @@ combined_serializer! {
Generator: super::type_serializers::generator::GeneratorSerializer;
Dict: super::type_serializers::dict::DictSerializer;
Model: super::type_serializers::model::ModelSerializer;
Dataclass: super::type_serializers::dataclass::DataclassSerializer;
Url: super::type_serializers::url::UrlSerializer;
MultiHostUrl: super::type_serializers::url::MultiHostUrlSerializer;
Any: super::type_serializers::any::AnySerializer;
Expand Down Expand Up @@ -327,21 +328,29 @@ pub(crate) fn to_json_bytes(
Ok(bytes)
}

pub(super) fn object_to_dict<'py>(value: &'py PyAny, is_model: bool, extra: &Extra) -> PyResult<&'py PyDict> {
let py = value.py();
let attr = value.getattr(intern!(py, "__dict__"))?;
let attrs: &PyDict = attr.downcast()?;
if is_model && extra.exclude_unset {
let fields_set: &PySet = value.getattr(intern!(py, "__pydantic_fields_set__"))?.downcast()?;

let new_attrs = attrs.copy()?;
for key in new_attrs.keys() {
if !fields_set.contains(key)? {
new_attrs.del_item(key)?;
}
static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();

/// needed to match the logic from dataclasses.fields `tuple(f for f in fields.values() if f._field_type is _FIELD)`
pub(super) fn get_field_marker(py: Python<'_>) -> PyResult<&PyAny> {
let field_type_marker_obj = DC_FIELD_MARKER.get_or_try_init(py, || {
let field_ = py.import("dataclasses")?.getattr("_FIELD")?;
Ok::<PyObject, PyErr>(field_.into_py(py))
})?;
Ok(field_type_marker_obj.as_ref(py))
}

pub(super) fn dataclass_to_dict(dc: &PyAny) -> PyResult<&PyDict> {
let py = dc.py();
let dc_fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
let dict = PyDict::new(py);

let field_type_marker = get_field_marker(py)?;
for (field_name, field) in dc_fields.iter() {
let field_type = field.getattr(intern!(py, "_field_type"))?;
if field_type.is(field_type_marker) {
let field_name: &PyString = field_name.downcast()?;
dict.set_item(field_name, dc.getattr(field_name)?)?;
}
Ok(new_attrs)
} else {
Ok(attrs)
}
Ok(dict)
}
Loading