From f7bf1f28241574af8caea53f07b16ea82cd29488 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 22 Sep 2022 12:48:11 +0100 Subject: [PATCH 1/2] Only use RecursiveContainerValidator when necessary --- src/validators/mod.rs | 75 ++++++++++++++++++++++++------ tests/validators/test_recursive.py | 11 ++++- 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 03989fa2e..c06f9254a 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -2,11 +2,12 @@ use std::fmt::Debug; use enum_dispatch::enum_dispatch; +use ahash::AHashSet; use pyo3::exceptions::PyTypeError; use pyo3::intern; use pyo3::once_cell::GILOnceCell; use pyo3::prelude::*; -use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyString}; +use pyo3::types::{PyAny, PyByteArray, PyBytes, PyDict, PyList, PyString}; use crate::build_tools::{py_error, SchemaDict, SchemaError}; use crate::errors::{ErrorKind, ValError, ValLineError, ValResult, ValidationError}; @@ -69,7 +70,10 @@ impl SchemaValidator { .map_err(|e| SchemaError::from_val_error(py, e))?; let schema = schema_obj.as_ref(py); - let mut build_context = BuildContext::default(); + let mut used_refs = AHashSet::new(); + extract_used_refs(schema, &mut used_refs)?; + let mut build_context = BuildContext::new(used_refs); + let mut validator = build_validator(schema, config, &mut build_context)?; validator.complete(&build_context)?; let slots = build_context.into_slots()?; @@ -219,7 +223,12 @@ impl SchemaValidator { py.run(code, None, Some(locals))?; let self_schema: &PyDict = locals.get_as_req(intern!(py, "self_schema"))?; - let mut build_context = BuildContext::default(); + let mut used_refs = AHashSet::new(); + // NOTE: we don't call `extract_used_refs` for performance reasons, if more recursive references + // are used, they would need to be manually added here. + used_refs.insert("root-schema".to_string()); + let mut build_context = BuildContext::new(used_refs); + let validator = match build_validator(self_schema, None, &mut build_context) { Ok(v) => v, Err(err) => return Err(SchemaError::new_err(format!("Error building self-schema:\n {}", err))), @@ -260,6 +269,7 @@ pub trait BuildValidator: Sized { -> PyResult; } +/// Logic to create a particular validator, called in the `validator_match` macro, then in turn by `build_validator` fn build_single_validator<'a, T: BuildValidator>( val_type: &str, schema_dict: &'a PyDict, @@ -267,19 +277,23 @@ fn build_single_validator<'a, T: BuildValidator>( build_context: &mut BuildContext, ) -> PyResult { let py = schema_dict.py(); - let val: CombinedValidator = if let Some(schema_ref) = schema_dict.get_as::(intern!(py, "ref"))? { - let slot_id = build_context.prepare_slot(schema_ref)?; - let inner_val = T::build(schema_dict, config, build_context) - .map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))?; - let name = inner_val.get_name().to_string(); - build_context.complete_slot(slot_id, inner_val)?; - recursive::RecursiveContainerValidator::create(slot_id, name) - } else { - T::build(schema_dict, config, build_context) - .map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)))? - }; + if let Some(schema_ref) = schema_dict.get_as::(intern!(py, "ref"))? { + // we only want to use a RecursiveContainerValidator if the ref is actually used, + // this means refs can always be set without having an effect on the validator which is generated + // unless it's used/referenced + if build_context.ref_used(&schema_ref) { + let slot_id = build_context.prepare_slot(schema_ref)?; + let inner_val = T::build(schema_dict, config, build_context).map_err(|err| { + SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)) + })?; + let name = inner_val.get_name().to_string(); + build_context.complete_slot(slot_id, inner_val)?; + return Ok(recursive::RecursiveContainerValidator::create(slot_id, name)); + } + } - Ok(val) + T::build(schema_dict, config, build_context) + .map_err(|err| SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err))) } // macro to build the match statement for validator selection @@ -523,10 +537,23 @@ pub trait Validator: Send + Sync + Clone + Debug { /// and therefore can't be owned by them directly. #[derive(Default, Clone)] pub struct BuildContext { + used_refs: AHashSet, slots: Vec<(String, Option)>, } impl BuildContext { + pub fn new(used_refs: AHashSet) -> Self { + Self { + used_refs, + ..Default::default() + } + } + + /// check if a ref is used elsewhere in the schema + pub fn ref_used(&self, ref_: &str) -> bool { + self.used_refs.contains(ref_) + } + /// First of two part process to add a new validator slot, we add the `slot_ref` to the array, but not the /// actual `validator`, we can't add the validator until it's build. /// We need the `id` to build the validator, hence this two-step process. @@ -584,3 +611,21 @@ impl BuildContext { .collect() } } + +fn extract_used_refs(schema: &PyAny, refs: &mut AHashSet) -> PyResult<()> { + if let Ok(dict) = schema.cast_as::() { + let py = schema.py(); + if matches!(dict.get_as(intern!(py, "type")), Ok(Some("recursive-ref"))) { + refs.insert(dict.get_as_req(intern!(py, "schema_ref"))?); + } else { + for (_, value) in dict.iter() { + extract_used_refs(value, refs)?; + } + } + } else if let Ok(list) = schema.cast_as::() { + for item in list.iter() { + extract_used_refs(item, refs)?; + } + } + Ok(()) +} diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index 9008a6b07..9aa8eb7e9 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -5,7 +5,7 @@ from pydantic_core import SchemaError, SchemaValidator, ValidationError -from ..conftest import Err +from ..conftest import Err, plain_repr from .test_typed_dict import Cls @@ -31,6 +31,7 @@ def test_branch_nullable(): ) assert v.validate_python({'name': 'root'}) == {'name': 'root', 'sub_branch': None} + assert plain_repr(v).startswith('SchemaValidator(name="typed-dict",validator=Recursive(RecursiveContainerValidator') assert v.validate_python({'name': 'root', 'sub_branch': {'name': 'b1'}}) == ( {'name': 'root', 'sub_branch': {'name': 'b1', 'sub_branch': None}} @@ -40,6 +41,14 @@ def test_branch_nullable(): ) +def test_unused_ref(): + v = SchemaValidator( + {'type': 'typed-dict', 'ref': 'Branch', 'fields': {'name': {'schema': 'str'}, 'other': {'schema': 'int'}}} + ) + assert plain_repr(v).startswith('SchemaValidator(name="typed-dict",validator=TypedDict(TypedDictValidator') + assert v.validate_python({'name': 'root', 'other': '4'}) == {'name': 'root', 'other': 4} + + def test_nullable_error(): v = SchemaValidator( { From 8b6cbccf847fb8cc24221cdaf715c6992c9aaffd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 22 Sep 2022 13:12:59 +0100 Subject: [PATCH 2/2] tweak tests --- src/validators/mod.rs | 4 +--- tests/validators/test_recursive.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/validators/mod.rs b/src/validators/mod.rs index c06f9254a..bba4e63e2 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -283,9 +283,7 @@ fn build_single_validator<'a, T: BuildValidator>( // unless it's used/referenced if build_context.ref_used(&schema_ref) { let slot_id = build_context.prepare_slot(schema_ref)?; - let inner_val = T::build(schema_dict, config, build_context).map_err(|err| { - SchemaError::new_err(format!("Error building \"{}\" validator:\n {}", val_type, err)) - })?; + let inner_val = T::build(schema_dict, config, build_context)?; let name = inner_val.get_name().to_string(); build_context.complete_slot(slot_id, inner_val)?; return Ok(recursive::RecursiveContainerValidator::create(slot_id, name)); diff --git a/tests/validators/test_recursive.py b/tests/validators/test_recursive.py index 9aa8eb7e9..da6146912 100644 --- a/tests/validators/test_recursive.py +++ b/tests/validators/test_recursive.py @@ -19,10 +19,7 @@ def test_branch_nullable(): 'sub_branch': { 'schema': { 'type': 'default', - 'schema': { - 'type': 'union', - 'choices': [{'type': 'none'}, {'type': 'recursive-ref', 'schema_ref': 'Branch'}], - }, + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}, 'default': None, } }, @@ -689,3 +686,28 @@ def test_many_uses_of_ref(): long_input = {'name': 'Anne', 'other_names': [f'p-{i}' for i in range(300)]} assert v.validate_python(long_input) == long_input + + +def test_error_inside_recursive_wrapper(): + with pytest.raises(SchemaError) as exc_info: + SchemaValidator( + { + 'type': 'typed-dict', + 'ref': 'Branch', + 'fields': { + 'sub_branch': { + 'schema': { + 'type': 'default', + 'schema': {'type': 'nullable', 'schema': {'type': 'recursive-ref', 'schema_ref': 'Branch'}}, + 'default': None, + 'default_factory': lambda x: 'foobar', + } + } + }, + } + ) + assert str(exc_info.value) == ( + 'Field "sub_branch":\n' + ' SchemaError: Error building "default" validator:\n' + " SchemaError: 'default' and 'default_factory' cannot be used together" + )