Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,7 +1366,7 @@ def list_schema(
class TuplePositionalSchema(TypedDict, total=False):
type: Required[Literal['tuple-positional']]
items_schema: Required[List[CoreSchema]]
extra_schema: CoreSchema
extras_schema: CoreSchema
strict: bool
ref: str
metadata: Any
Expand All @@ -1376,7 +1376,7 @@ class TuplePositionalSchema(TypedDict, total=False):
def tuple_positional_schema(
items_schema: list[CoreSchema],
*,
extra_schema: CoreSchema | None = None,
extras_schema: CoreSchema | None = None,
strict: bool | None = None,
ref: str | None = None,
metadata: Any = None,
Expand All @@ -1397,7 +1397,7 @@ def tuple_positional_schema(

Args:
items_schema: The value must be a tuple with items that match these schemas
extra_schema: The value must be a tuple with items that match this schema
extras_schema: The value must be a tuple with items that match this schema
This was inspired by JSON schema's `prefixItems` and `items` fields.
In python's `typing.Tuple`, you can't specify a type for "extra" items -- they must all be the same type
if the length is variable. So this field won't be set from a `typing.Tuple` annotation on a pydantic model.
Expand All @@ -1409,7 +1409,7 @@ def tuple_positional_schema(
return _dict_not_none(
type='tuple-positional',
items_schema=items_schema,
extra_schema=extra_schema,
extras_schema=extras_schema,
strict=strict,
ref=ref,
metadata=metadata,
Expand Down Expand Up @@ -2829,7 +2829,7 @@ class TypedDictSchema(TypedDict, total=False):
fields: Required[Dict[str, TypedDictField]]
computed_fields: List[ComputedField]
strict: bool
extra_validator: CoreSchema
extras_schema: CoreSchema
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
extra_behavior: ExtraBehavior
total: bool # default: True
Expand All @@ -2845,7 +2845,7 @@ def typed_dict_schema(
*,
computed_fields: list[ComputedField] | None = None,
strict: bool | None = None,
extra_validator: CoreSchema | None = None,
extras_schema: CoreSchema | None = None,
extra_behavior: ExtraBehavior | None = None,
total: bool | None = None,
populate_by_name: bool | None = None,
Expand All @@ -2871,7 +2871,7 @@ def typed_dict_schema(
fields: The fields to use for the typed dict
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
strict: Whether the typed dict is strict
extra_validator: The extra validator to use for the typed dict
extras_schema: The extra validator to use for the typed dict
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
extra_behavior: The extra behavior to use for the typed dict
Expand All @@ -2884,7 +2884,7 @@ def typed_dict_schema(
fields=fields,
computed_fields=computed_fields,
strict=strict,
extra_validator=extra_validator,
extras_schema=extras_schema,
extra_behavior=extra_behavior,
total=total,
populate_by_name=populate_by_name,
Expand Down Expand Up @@ -2948,7 +2948,7 @@ class ModelFieldsSchema(TypedDict, total=False):
model_name: str
computed_fields: List[ComputedField]
strict: bool
extra_validator: CoreSchema
extras_schema: CoreSchema
# all these values can be set via config, equivalent fields have `typed_dict_` prefix
extra_behavior: ExtraBehavior
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
Expand All @@ -2964,7 +2964,7 @@ def model_fields_schema(
model_name: str | None = None,
computed_fields: list[ComputedField] | None = None,
strict: bool | None = None,
extra_validator: CoreSchema | None = None,
extras_schema: CoreSchema | None = None,
extra_behavior: ExtraBehavior | None = None,
populate_by_name: bool | None = None,
from_attributes: bool | None = None,
Expand All @@ -2991,7 +2991,7 @@ def model_fields_schema(
model_name: The name of the model, used for error messages, defaults to "Model"
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model
strict: Whether the typed dict is strict
extra_validator: The extra validator to use for the typed dict
extras_schema: The extra validator to use for the typed dict
ref: optional unique identifier of the schema, used to reference the schema in other places
metadata: Any other information you want to include with the schema, not used by pydantic-core
extra_behavior: The extra behavior to use for the typed dict
Expand All @@ -3005,7 +3005,7 @@ def model_fields_schema(
model_name=model_name,
computed_fields=computed_fields,
strict=strict,
extra_validator=extra_validator,
extras_schema=extras_schema,
extra_behavior=extra_behavior,
populate_by_name=populate_by_name,
from_attributes=from_attributes,
Expand Down
13 changes: 11 additions & 2 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pub struct GeneralFieldsSerializer {
fields: AHashMap<String, SerField>,
computed_fields: Option<ComputedFields>,
mode: FieldsMode,
extra_serializer: Option<Box<CombinedSerializer>>,
// isize because we look up filter via `.hash()` which returns an isize
filter: SchemaFilter<isize>,
required_fields: usize,
Expand All @@ -103,12 +104,14 @@ impl GeneralFieldsSerializer {
pub(super) fn new(
fields: AHashMap<String, SerField>,
mode: FieldsMode,
extra_serializer: Option<CombinedSerializer>,
computed_fields: Option<ComputedFields>,
) -> Self {
let required_fields = fields.values().filter(|f| f.required).count();
Self {
fields,
mode,
extra_serializer: extra_serializer.map(Box::new),
filter: SchemaFilter::default(),
computed_fields,
required_fields,
Expand Down Expand Up @@ -205,7 +208,10 @@ impl TypeSerializer for GeneralFieldsSerializer {
used_req_fields += 1;
}
} else if self.mode == FieldsMode::TypedDictAllow {
let value = infer_to_python(value, next_include, next_exclude, &extra)?;
let value = match &self.extra_serializer {
Some(serializer) => serializer.to_python(value, next_include, next_exclude, &extra)?,
None => infer_to_python(value, next_include, next_exclude, &extra)?,
};
output_dict.set_item(key, value)?;
} else if extra.check == SerCheck::Strict {
return Err(PydanticSerializationUnexpectedValue::new_err(None));
Expand All @@ -227,7 +233,10 @@ impl TypeSerializer for GeneralFieldsSerializer {
continue;
}
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
let value = infer_to_python(value, next_include, next_exclude, &td_extra)?;
let value = match &self.extra_serializer {
Some(serializer) => serializer.to_python(value, next_include, next_exclude, extra)?,
None => infer_to_python(value, next_include, next_exclude, extra)?,
};
output_dict.set_item(key, value)?;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl BuildSerializer for DataclassArgsBuilder {

let computed_fields = ComputedFields::new(schema, config, definitions)?;

Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
Ok(GeneralFieldsSerializer::new(fields, fields_mode, None, computed_fields).into())
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use super::{
CombinedSerializer, ComputedFields, Extra, FieldsMode, GeneralFieldsSerializer, ObType, SerCheck, SerField,
TypeSerializer,
};
use crate::build_tools::py_schema_err;
use crate::build_tools::{py_schema_error_type, ExtraBehavior};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::errors::PydanticSerializationUnexpectedValue;
Expand Down Expand Up @@ -38,6 +39,12 @@ impl BuildSerializer for ModelFieldsBuilder {
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: AHashMap<String, SerField> = AHashMap::with_capacity(fields_dict.len());

let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) {
(Some(v), FieldsMode::ModelExtra) => Some(CombinedSerializer::build(v.extract()?, config, definitions)?),
(Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"),
(_, _) => None,
};

for (key, value) in fields_dict {
let key_py: &PyString = key.downcast()?;
let key: String = key_py.extract()?;
Expand All @@ -60,7 +67,7 @@ impl BuildSerializer for ModelFieldsBuilder {

let computed_fields = ComputedFields::new(schema, config, definitions)?;

Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
Ok(GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into())
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ impl BuildSerializer for TuplePositionalSerializer {
let py = schema.py();
let items: &PyList = schema.get_as_req(intern!(py, "items_schema"))?;

let extra_serializer = match schema.get_as::<&PyDict>(intern!(py, "extra_schema"))? {
Some(extra_schema) => CombinedSerializer::build(extra_schema, config, definitions)?,
let extra_serializer = match schema.get_as::<&PyDict>(intern!(py, "extras_schema"))? {
Some(extras_schema) => CombinedSerializer::build(extras_schema, config, definitions)?,
None => AnySerializer::build(schema, config, definitions)?,
};
let items_serializers: Vec<CombinedSerializer> = items
Expand Down
11 changes: 10 additions & 1 deletion src/serializers/type_serializers/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::types::{PyDict, PyString};

use ahash::AHashMap;

use crate::build_tools::py_schema_err;
use crate::build_tools::{py_schema_error_type, schema_or_config, ExtraBehavior};
use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;
Expand Down Expand Up @@ -34,6 +35,14 @@ impl BuildSerializer for TypedDictBuilder {
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: AHashMap<String, SerField> = AHashMap::with_capacity(fields_dict.len());

let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) {
(Some(v), FieldsMode::TypedDictAllow) => {
Some(CombinedSerializer::build(v.extract()?, config, definitions)?)
}
(Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"),
(_, _) => None,
};

for (key, value) in fields_dict {
let key_py: &PyString = key.downcast()?;
let key: String = key_py.extract()?;
Expand All @@ -56,6 +65,6 @@ impl BuildSerializer for TypedDictBuilder {

let computed_fields = ComputedFields::new(schema, config, definitions)?;

Ok(GeneralFieldsSerializer::new(fields, fields_mode, computed_fields).into())
Ok(GeneralFieldsSerializer::new(fields, fields_mode, extra_serializer, computed_fields).into())
}
}
25 changes: 24 additions & 1 deletion src/validators/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct DataclassArgsValidator {
dataclass_name: String,
validator_name: String,
extra_behavior: ExtraBehavior,
extras_validator: Option<Box<CombinedValidator>>,
loc_by_alias: bool,
}

Expand All @@ -55,6 +56,12 @@ impl BuildValidator for DataclassArgsValidator {

let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?;

let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) {
(Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)),
(Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"),
(_, _) => None,
};

let fields_schema: &PyList = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: Vec<Field> = Vec::with_capacity(fields_schema.len());

Expand Down Expand Up @@ -118,6 +125,7 @@ impl BuildValidator for DataclassArgsValidator {
dataclass_name,
validator_name,
extra_behavior,
extras_validator,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
}
.into())
Expand Down Expand Up @@ -267,7 +275,22 @@ impl Validator for DataclassArgsValidator {
}
ExtraBehavior::Ignore => {}
ExtraBehavior::Allow => {
output_dict.set_item(either_str.as_py_string(py), value)?
if let Some(ref validator) = self.extras_validator {
match validator.validate(py, value, state) {
Ok(value) => output_dict
.set_item(either_str.as_py_string(py), value)?,
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
errors.push(err.with_outer_location(
raw_key.as_loc_item(),
));
}
}
Err(err) => return Err(err),
}
} else {
output_dict.set_item(either_str.as_py_string(py), value)?
}
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions src/validators/model_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct ModelFieldsValidator {
fields: Vec<Field>,
model_name: String,
extra_behavior: ExtraBehavior,
extra_validator: Option<Box<CombinedValidator>>,
extras_validator: Option<Box<CombinedValidator>>,
strict: bool,
from_attributes: bool,
loc_by_alias: bool,
Expand All @@ -58,9 +58,9 @@ impl BuildValidator for ModelFieldsValidator {

let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?;

let extra_validator = match (schema.get_item(intern!(py, "extra_validator")), &extra_behavior) {
let extras_validator = match (schema.get_item(intern!(py, "extras_schema")), &extra_behavior) {
(Some(v), ExtraBehavior::Allow) => Some(Box::new(build_validator(v, config, definitions)?)),
(Some(_), _) => return py_schema_err!("extra_validator can only be used if extra_behavior=allow"),
(Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"),
(_, _) => None,
};
let model_name: String = schema
Expand Down Expand Up @@ -102,7 +102,7 @@ impl BuildValidator for ModelFieldsValidator {
fields,
model_name,
extra_behavior,
extra_validator,
extras_validator,
strict,
from_attributes,
loc_by_alias: config.get_as(intern!(py, "loc_by_alias"))?.unwrap_or(true),
Expand All @@ -113,7 +113,7 @@ impl BuildValidator for ModelFieldsValidator {

impl_py_gc_traverse!(ModelFieldsValidator {
fields,
extra_validator
extras_validator
});

impl Validator for ModelFieldsValidator {
Expand Down Expand Up @@ -265,7 +265,7 @@ impl Validator for ModelFieldsValidator {
ExtraBehavior::Ignore => {}
ExtraBehavior::Allow => {
let py_key = either_str.as_py_string(py);
if let Some(ref validator) = self.extra_validator {
if let Some(ref validator) = self.extras_validator {
match validator.validate(py, value, state) {
Ok(value) => {
model_extra_dict.set_item(py_key, value)?;
Expand Down Expand Up @@ -373,7 +373,7 @@ impl Validator for ModelFieldsValidator {
// For models / typed dicts we forbid assigning extra attributes
// unless the user explicitly set extra_behavior to 'allow'
match self.extra_behavior {
ExtraBehavior::Allow => match self.extra_validator {
ExtraBehavior::Allow => match self.extras_validator {
Some(ref validator) => prepare_result(
state.with_new_extra(new_extra, |state| validator.validate(py, field_value, state)),
),
Expand Down Expand Up @@ -430,7 +430,7 @@ impl Validator for ModelFieldsValidator {
self.fields
.iter_mut()
.try_for_each(|f| f.validator.complete(definitions))?;
match &mut self.extra_validator {
match &mut self.extras_validator {
Some(v) => v.complete(definitions),
None => Ok(()),
}
Expand Down
Loading