diff --git a/pydantic_core/_types.py b/pydantic_core/_types.py index 547a150ed..352e750f7 100644 --- a/pydantic_core/_types.py +++ b/pydantic_core/_types.py @@ -2,12 +2,12 @@ import sys from datetime import date, datetime, time -from typing import Any, Callable, Dict, List, Sequence, Union +from typing import Any, Callable, Dict, List, Union if sys.version_info < (3, 11): from typing_extensions import NotRequired, Required else: - from typing import NotRequired + from typing import NotRequired, Required if sys.version_info < (3, 8): from typing_extensions import Literal, TypedDict @@ -19,9 +19,10 @@ class AnySchema(TypedDict): type: Literal['any'] -class BoolSchema(TypedDict): - type: Literal['bool'] - strict: NotRequired[bool] +class BoolSchema(TypedDict, total=False): + type: Required[Literal['bool']] + strict: bool + ref: str class ConfigSchema(TypedDict, total=False): @@ -39,6 +40,7 @@ class DictSchema(TypedDict, total=False): min_items: int max_items: int strict: bool + ref: str class FloatSchema(TypedDict, total=False): @@ -49,7 +51,7 @@ class FloatSchema(TypedDict, total=False): lt: float gt: float strict: bool - default: float + ref: str class FunctionSchema(TypedDict): @@ -57,12 +59,14 @@ class FunctionSchema(TypedDict): mode: Literal['before', 'after', 'wrap'] function: Callable[..., Any] schema: Schema + ref: NotRequired[str] class FunctionPlainSchema(TypedDict): type: Literal['function'] mode: Literal['plain'] function: Callable[..., Any] + ref: NotRequired[str] class IntSchema(TypedDict, total=False): @@ -73,6 +77,7 @@ class IntSchema(TypedDict, total=False): lt: int gt: int strict: bool + ref: str class ListSchema(TypedDict, total=False): @@ -81,17 +86,20 @@ class ListSchema(TypedDict, total=False): min_items: int max_items: int strict: bool + ref: str class LiteralSchema(TypedDict): type: Literal['literal'] - expected: Sequence[Any] + expected: List[Any] + ref: NotRequired[str] class ModelClassSchema(TypedDict): type: Literal['model-class'] class_type: type schema: TypedDictSchema + ref: NotRequired[str] class TypedDictField(TypedDict, total=False): @@ -102,33 +110,30 @@ class TypedDictField(TypedDict, total=False): aliases: List[List[Union[str, int]]] -class TypedDictSchema(TypedDict): - type: Literal['typed-dict'] - fields: Dict[str, TypedDictField] - extra_validator: NotRequired[Schema] - config: NotRequired[ConfigSchema] - return_fields_set: NotRequired[bool] +class TypedDictSchema(TypedDict, total=False): + type: Required[Literal['typed-dict']] + fields: Required[Dict[str, TypedDictField]] + extra_validator: Schema + config: ConfigSchema + return_fields_set: bool + ref: str class NoneSchema(TypedDict): type: Literal['none'] + ref: NotRequired[str] -class NullableSchema(TypedDict): - type: Literal['nullable'] - schema: Schema - strict: NotRequired[bool] +class NullableSchema(TypedDict, total=False): + type: Required[Literal['nullable']] + schema: Required[Schema] + strict: bool + ref: str class RecursiveReferenceSchema(TypedDict): type: Literal['recursive-ref'] - name: str - - -class RecursiveContainerSchema(TypedDict): - type: Literal['recursive-container'] - name: str - schema: Schema + schema_ref: str class SetSchema(TypedDict, total=False): @@ -137,6 +142,7 @@ class SetSchema(TypedDict, total=False): min_items: int max_items: int strict: bool + ref: str class FrozenSetSchema(TypedDict, total=False): @@ -145,6 +151,7 @@ class FrozenSetSchema(TypedDict, total=False): min_items: int max_items: int strict: bool + ref: str class StringSchema(TypedDict, total=False): @@ -156,13 +163,14 @@ class StringSchema(TypedDict, total=False): to_lower: bool to_upper: bool strict: bool + ref: str -class UnionSchema(TypedDict): - type: Literal['union'] - choices: List[Schema] - strict: NotRequired[bool] - default: NotRequired[Any] +class UnionSchema(TypedDict, total=False): + type: Required[Literal['union']] + choices: Required[List[Schema]] + strict: bool + ref: str class BytesSchema(TypedDict, total=False): @@ -170,6 +178,7 @@ class BytesSchema(TypedDict, total=False): max_length: int min_length: int strict: bool + ref: str class DateSchema(TypedDict, total=False): @@ -179,7 +188,7 @@ class DateSchema(TypedDict, total=False): ge: date lt: date gt: date - default: date + ref: str class TimeSchema(TypedDict, total=False): @@ -189,7 +198,7 @@ class TimeSchema(TypedDict, total=False): ge: time lt: time gt: time - default: time + ref: str class DatetimeSchema(TypedDict, total=False): @@ -199,13 +208,14 @@ class DatetimeSchema(TypedDict, total=False): ge: datetime lt: datetime gt: datetime - default: datetime + ref: str -class TupleFixLenSchema(TypedDict): - type: Literal['tuple-fix-len'] - items_schema: List[Schema] - strict: NotRequired[bool] +class TupleFixLenSchema(TypedDict, total=False): + type: Required[Literal['tuple-fix-len']] + items_schema: Required[List[Schema]] + strict: bool + ref: str class TupleVarLenSchema(TypedDict, total=False): @@ -214,6 +224,7 @@ class TupleVarLenSchema(TypedDict, total=False): min_items: int max_items: int strict: bool + ref: str # pydantic allows types to be defined via a simple string instead of dict with just `type`, e.g. @@ -256,7 +267,6 @@ class TupleVarLenSchema(TypedDict, total=False): ModelClassSchema, NoneSchema, NullableSchema, - RecursiveContainerSchema, RecursiveReferenceSchema, SetSchema, FrozenSetSchema, diff --git a/src/validators/mod.rs b/src/validators/mod.rs index bb34948e0..33ac898cb 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -151,19 +151,35 @@ pub trait BuildValidator: Sized { ) -> PyResult; } +fn build_single_validator<'a, T: BuildValidator>( + val_type: &str, + schema_dict: &'a PyDict, + config: Option<&'a PyDict>, + build_context: &mut BuildContext, +) -> PyResult<(CombinedValidator, &'a PyDict)> { + build_context.incr_check_depth()?; + + let val: CombinedValidator = if let Some(schema_ref) = schema_dict.get_as::("ref")? { + let slot_id = build_context.prepare_slot(schema_ref)?; + let inner_val = T::build(schema_dict, config, build_context) + .map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?; + build_context.complete_slot(slot_id, inner_val); + recursive::RecursiveContainerValidator::create(slot_id) + } else { + T::build(schema_dict, config, build_context) + .map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))? + }; + + build_context.decr_depth(); + Ok((val, schema_dict)) +} + // macro to build the match statement for validator selection macro_rules! validator_match { ($type:ident, $dict:ident, $config:ident, $build_context:ident, $($validator:path,)+) => { match $type { $( - <$validator>::EXPECTED_TYPE => { - $build_context.incr_check_depth()?; - let val = <$validator>::build($dict, $config, $build_context).map_err(|err| { - SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", $type, err)) - })?; - $build_context.decr_depth(); - Ok((val, $dict)) - }, + <$validator>::EXPECTED_TYPE => build_single_validator::<$validator>($type, $dict, $config, $build_context), )+ _ => { return py_error!(r#"Unknown schema type: "{}""#, $type) @@ -221,7 +237,6 @@ pub fn build_validator<'a>( // functions - before, after, plain & wrap function::FunctionBuilder, // recursive (self-referencing) models - recursive::RecursiveValidator, recursive::RecursiveRefValidator, // literals literal::LiteralBuilder, @@ -294,7 +309,7 @@ pub enum CombinedValidator { FunctionPlain(function::FunctionPlainValidator), FunctionWrap(function::FunctionWrapValidator), // recursive (self-referencing) models - Recursive(recursive::RecursiveValidator), + Recursive(recursive::RecursiveContainerValidator), RecursiveRef(recursive::RecursiveRefValidator), // literals LiteralSingleString(literal::LiteralSingleStringValidator), @@ -348,6 +363,7 @@ pub trait Validator: Send + Sync + Clone + Debug { fn get_name(&self, py: Python) -> String; } +#[derive(Default)] pub struct BuildContext { named_slots: Vec<(Option, Option)>, depth: usize, @@ -355,22 +371,18 @@ pub struct BuildContext { const MAX_DEPTH: usize = 100; -impl Default for BuildContext { - fn default() -> Self { - let named_slots: Vec<(Option, Option)> = Vec::new(); - BuildContext { named_slots, depth: 0 } - } -} - impl BuildContext { - pub fn add_named_slot(&mut self, name: String, schema: &PyAny, config: Option<&PyDict>) -> PyResult { + pub fn prepare_slot(&mut self, slot_ref: String) -> PyResult { let id = self.named_slots.len(); - self.named_slots.push((Some(name), None)); - let validator = build_validator(schema, config, self)?.0; - self.named_slots[id] = (None, Some(validator)); + self.named_slots.push((Some(slot_ref), None)); Ok(id) } + pub fn complete_slot(&mut self, slot_id: usize, validator: CombinedValidator) { + let (name, _) = self.named_slots.get(slot_id).unwrap(); + self.named_slots[slot_id] = (name.clone(), Some(validator)); + } + pub fn incr_check_depth(&mut self) -> PyResult<()> { self.depth += 1; if self.depth > MAX_DEPTH { @@ -384,14 +396,14 @@ impl BuildContext { self.depth -= 1; } - pub fn find_id(&self, name: &str) -> PyResult { + pub fn find_slot_id(&self, slot_ref: &str) -> PyResult { let is_match = |(n, _): &(Option, Option)| match n { - Some(n) => n == name, + Some(n) => n == slot_ref, None => false, }; match self.named_slots.iter().position(is_match) { Some(id) => Ok(id), - None => py_error!("Recursive reference error: ref '{}' not found", name), + None => py_error!("Recursive reference error: ref '{}' not found", slot_ref), } } diff --git a/src/validators/recursive.rs b/src/validators/recursive.rs index 5be761bde..5de7c2607 100644 --- a/src/validators/recursive.rs +++ b/src/validators/recursive.rs @@ -8,26 +8,17 @@ use crate::input::Input; use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; #[derive(Debug, Clone)] -pub struct RecursiveValidator { +pub struct RecursiveContainerValidator { validator_id: usize, } -impl BuildValidator for RecursiveValidator { - const EXPECTED_TYPE: &'static str = "recursive-container"; - - fn build( - schema: &PyDict, - config: Option<&PyDict>, - build_context: &mut BuildContext, - ) -> PyResult { - let sub_schema: &PyAny = schema.get_as_req("schema")?; - let name: String = schema.get_as_req("name")?; - let validator_id = build_context.add_named_slot(name, sub_schema, config)?; - Ok(Self { validator_id }.into()) +impl RecursiveContainerValidator { + pub fn create(validator_id: usize) -> CombinedValidator { + Self { validator_id }.into() } } -impl Validator for RecursiveValidator { +impl Validator for RecursiveContainerValidator { fn validate<'s, 'data>( &'s self, py: Python<'data>, @@ -40,7 +31,7 @@ impl Validator for RecursiveValidator { } fn get_name(&self, _py: Python) -> String { - Self::EXPECTED_TYPE.to_string() + "recursive-container".to_string() } } @@ -57,8 +48,8 @@ impl BuildValidator for RecursiveRefValidator { _config: Option<&PyDict>, build_context: &mut BuildContext, ) -> PyResult { - let name: String = schema.get_as_req("name")?; - let validator_id = build_context.find_id(&name)?; + let name: String = schema.get_as_req("schema_ref")?; + let validator_id = build_context.find_slot_id(&name)?; Ok(Self { validator_id }.into()) } } diff --git a/tests/benchmarks/complete_schema.py b/tests/benchmarks/complete_schema.py index 2aad48861..0dd309a13 100644 --- a/tests/benchmarks/complete_schema.py +++ b/tests/benchmarks/complete_schema.py @@ -125,19 +125,16 @@ def wrap_function(input_value, *, validator, **kwargs): }, 'field_recursive': { 'schema': { - 'type': 'recursive-container', - 'name': 'Branch', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'name': {'schema': 'str'}, - 'sub_branch': { - 'schema': { - 'type': 'nullable', - 'schema': {'type': 'recursive-ref', 'name': 'Branch'}, - }, - 'default': None, + 'ref': 'Branch', + 'type': 'typed-dict', + 'fields': { + 'name': {'schema': 'str'}, + 'sub_branch': { + 'schema': { + 'type': 'nullable', + 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}, }, + 'default': None, }, }, } diff --git a/tests/benchmarks/test_micro_benchmarks.py b/tests/benchmarks/test_micro_benchmarks.py index 4dc018b64..bbb19d799 100644 --- a/tests/benchmarks/test_micro_benchmarks.py +++ b/tests/benchmarks/test_micro_benchmarks.py @@ -189,20 +189,17 @@ class CoreBranch: v = SchemaValidator( { - 'type': 'recursive-container', - 'name': 'Branch', + 'ref': 'Branch', + 'type': 'model-class', + 'class_type': CoreBranch, 'schema': { - 'type': 'model-class', - 'class_type': CoreBranch, - 'schema': { - 'type': 'typed-dict', - 'return_fields_set': True, - 'fields': { - 'width': {'schema': {'type': 'int'}}, - 'branch': { - 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'name': 'Branch'}}, - 'default': None, - }, + 'type': 'typed-dict', + 'return_fields_set': True, + 'fields': { + 'width': {'schema': {'type': 'int'}}, + 'branch': { + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}, + 'default': None, }, }, }, diff --git a/tests/test_typing.py b/tests/test_typing.py index 5b13772a6..d97dc0986 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -64,19 +64,16 @@ def test_schema_typing() -> None: schema: Schema = {'type': 'function', 'mode': 'plain', 'function': foo} SchemaValidator(schema) schema: Schema = { - 'type': 'recursive-container', - 'name': 'Branch', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'name': {'schema': {'type': 'str'}}, - 'sub_branch': { - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'name': 'Branch'}], - }, - 'default': None, + 'ref': 'Branch', + 'type': 'typed-dict', + 'fields': { + 'name': {'schema': {'type': 'str'}}, + 'sub_branch': { + 'schema': { + 'type': 'union', + 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}], }, + 'default': None, }, }, } diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index 1f0c09732..a4e39f810 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -8,19 +8,16 @@ def test_branch_nullable(): v = SchemaValidator( { - 'type': 'recursive-container', - 'name': 'Branch', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'name': {'schema': {'type': 'str'}}, - 'sub_branch': { - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'name': 'Branch'}], - }, - 'default': None, + 'type': 'typed-dict', + 'ref': 'Branch', + 'fields': { + 'name': {'schema': {'type': 'str'}}, + 'sub_branch': { + 'schema': { + 'type': 'union', + 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}], }, + 'default': None, }, }, } @@ -39,19 +36,16 @@ def test_branch_nullable(): def test_nullable_error(): v = SchemaValidator( { - 'type': 'recursive-container', - 'name': 'Branch', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'width': {'schema': 'int'}, - 'sub_branch': { - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'name': 'Branch'}], - }, - 'default': None, + 'ref': 'Branch', + 'type': 'typed-dict', + 'fields': { + 'width': {'schema': 'int'}, + 'sub_branch': { + 'schema': { + 'type': 'union', + 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}], }, + 'default': None, }, }, } @@ -80,16 +74,13 @@ def test_nullable_error(): def test_list(): v = SchemaValidator( { - 'type': 'recursive-container', - 'name': 'BranchList', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'width': {'schema': 'int'}, - 'branches': { - 'schema': {'type': 'list', 'items_schema': {'type': 'recursive-ref', 'name': 'BranchList'}}, - 'default': None, - }, + 'type': 'typed-dict', + 'ref': 'BranchList', + 'fields': { + 'width': {'schema': 'int'}, + 'branches': { + 'schema': {'type': 'list', 'items_schema': {'type': 'recursive-ref', 'schema_ref': 'BranchList'}}, + 'default': None, }, }, } @@ -117,38 +108,32 @@ class Bar: v = SchemaValidator( { - 'type': 'recursive-container', - 'name': 'Foo', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'height': {'schema': 'int'}, - 'bar': { - 'schema': { - 'type': 'recursive-container', - 'name': 'Bar', - 'schema': { - 'type': 'typed-dict', - 'fields': { - 'width': {'schema': 'int'}, - 'bars': { - 'schema': { - 'type': 'list', - 'items_schema': {'type': 'recursive-ref', 'name': 'Bar'}, - }, - 'default': None, - }, - 'foo': { - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'name': 'Foo'}], - }, - 'default': None, - }, + 'ref': 'Foo', + 'type': 'typed-dict', + 'fields': { + 'height': {'schema': 'int'}, + 'bar': { + 'schema': { + 'ref': 'Bar', + 'type': 'typed-dict', + 'fields': { + 'width': {'schema': 'int'}, + 'bars': { + 'schema': { + 'type': 'list', + 'items_schema': {'type': 'recursive-ref', 'schema_ref': 'Bar'}, }, + 'default': None, }, - } - }, + 'foo': { + 'schema': { + 'type': 'union', + 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Foo'}], + }, + 'default': None, + }, + }, + } }, }, } @@ -175,23 +160,20 @@ class Branch: v = SchemaValidator( { - 'type': 'recursive-container', - 'name': 'Branch', + 'type': 'model-class', + 'ref': 'Branch', + 'class_type': Branch, 'schema': { - 'type': 'model-class', - 'class_type': Branch, - 'schema': { - 'type': 'typed-dict', - 'return_fields_set': True, - 'fields': { - 'width': {'schema': 'int'}, - 'branch': { - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'name': 'Branch'}], - }, - 'default': None, + 'type': 'typed-dict', + 'return_fields_set': True, + 'fields': { + 'width': {'schema': 'int'}, + 'branch': { + 'schema': { + 'type': 'union', + 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}], }, + 'default': None, }, }, }, @@ -223,10 +205,29 @@ def test_invalid_schema(): 'fields': { 'width': {'schema': {'type': 'int'}}, 'branch': { - 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'name': 'Branch'}}, + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}, 'default': None, }, }, }, } ) + + +def test_outside_parent(): + v = SchemaValidator( + { + 'type': 'typed-dict', + 'fields': { + 'tuple1': { + 'schema': {'type': 'tuple-fix-len', 'ref': 'tuple-iis', 'items_schema': ['int', 'int', 'str']} + }, + 'tuple2': {'schema': {'type': 'recursive-ref', 'schema_ref': 'tuple-iis'}}, + }, + } + ) + + assert v.validate_python({'tuple1': [1, '1', 'frog'], 'tuple2': [2, '2', 'toad']}) == { + 'tuple1': (1, 1, 'frog'), + 'tuple2': (2, 2, 'toad'), + }