Skip to content

Commit

Permalink
add private API to build SchemaValidator and SchemaSerializer together
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 21, 2023
1 parent b8d3b95 commit 9432e26
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 25 deletions.
2 changes: 2 additions & 0 deletions python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Url,
ValidationError,
__version__,
_build_validator_and_serializer,
to_json,
to_jsonable_python,
)
Expand Down Expand Up @@ -61,6 +62,7 @@
'PydanticSerializationError',
'PydanticSerializationUnexpectedValue',
'TzInfo',
'_build_validator_and_serializer',
'to_json',
'to_jsonable_python',
]
Expand Down
5 changes: 5 additions & 0 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ __all__ = [
'PydanticUndefined',
'PydanticUndefinedType',
'Some',
'_build_validator_and_serializer',
'to_json',
'to_jsonable_python',
'list_all_errors',
Expand Down Expand Up @@ -836,3 +837,7 @@ class TzInfo(datetime.tzinfo):
def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ...
def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ...
def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ...

def _build_validator_and_serializer(
schema: CoreSchema, config: CoreConfig | None = None
) -> tuple[SchemaValidator, SchemaSerializer]: ...
20 changes: 20 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ extern crate core;

use std::sync::OnceLock;

use pyo3::types::PyDict;
use pyo3::{prelude::*, sync::GILOnceCell};
use validators::SelfValidator;

// parse this first to get access to the contained macro
#[macro_use]
Expand Down Expand Up @@ -71,6 +73,23 @@ pub fn build_info() -> String {
)
}

/// Private API to build both a `SchemaValidator` and `SchemaSerializer` in one go. This can be
/// helpful for performance because it avoids validating the schema twice.
#[pyfunction]
fn _build_validator_and_serializer(
py: Python,
schema: &PyAny,
config: Option<&PyDict>,
) -> PyResult<(SchemaValidator, SchemaSerializer)> {
let self_validator = SelfValidator::new(py)?;
let schema = self_validator.validate_schema(py, schema)?;

Ok((
SchemaValidator::new(py, &schema, config)?,
SchemaSerializer::new(&schema, config)?,
))
}

#[pymodule]
fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
m.add("__version__", get_pydantic_core_version())?;
Expand All @@ -97,5 +116,6 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(to_json, m)?)?;
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
m.add_function(wrap_pyfunction!(_build_validator_and_serializer, m)?)?;
Ok(())
}
22 changes: 12 additions & 10 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::DefinitionsBuilder;
use crate::py_gc::PyGcTraverse;
use crate::validators::SelfValidator;
use crate::validators::{SelfValidator, ValidatedSchema};

use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
Expand Down Expand Up @@ -37,6 +37,16 @@ pub struct SchemaSerializer {
}

impl SchemaSerializer {
pub(crate) fn new(schema: &ValidatedSchema<'_>, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();
Ok(Self {
serializer: CombinedSerializer::build(schema, config, &mut definitions_builder)?,
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
})
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn build_extra<'b, 'a: 'b>(
&'b self,
Expand Down Expand Up @@ -76,15 +86,7 @@ impl SchemaSerializer {
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let self_validator = SelfValidator::new(py)?;
let schema = self_validator.validate_schema(py, schema)?;
let mut definitions_builder = DefinitionsBuilder::new();

let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
})
Self::new(&schema, config)
}

#[allow(clippy::too_many_arguments)]
Expand Down
42 changes: 31 additions & 11 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,9 @@ pub struct SchemaValidator {
validation_error_cause: bool,
}

#[pymethods]
impl SchemaValidator {
#[new]
pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult<Self> {
let self_validator = SelfValidator::new(py)?;
let schema = self_validator.validate_schema(py, schema)?;

/// Construct from an already validated Schema. May raise
pub(crate) fn new(py: Python, schema: &ValidatedSchema<'_>, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();

let mut validator = build_validator(schema, config, &mut definitions_builder)?;
Expand Down Expand Up @@ -143,6 +139,17 @@ impl SchemaValidator {
validation_error_cause,
})
}
}

#[pymethods]
impl SchemaValidator {
#[new]
pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult<Self> {
let self_validator = SelfValidator::new(py)?;
let schema = self_validator.validate_schema(py, schema)?;

Self::new(py, &schema, config)
}

pub fn __reduce__(&self, py: Python) -> PyResult<PyObject> {
let args = (self.schema.as_ref(py),);
Expand Down Expand Up @@ -361,6 +368,19 @@ pub struct SelfValidator<'py> {
validator: &'py SchemaValidator,
}

/// Validated CoreSchema.
///
/// The only way to build this is by `SelfValidator::validate_schema`.
pub struct ValidatedSchema<'py>(&'py PyDict);

impl<'py> std::ops::Deref for ValidatedSchema<'py> {
type Target = &'py PyDict;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<'py> SelfValidator<'py> {
pub fn new(py: Python<'py>) -> PyResult<Self> {
let validator = SCHEMA_DEFINITION.get_or_init(py, || match Self::build(py) {
Expand All @@ -370,15 +390,15 @@ impl<'py> SelfValidator<'py> {
Ok(Self { validator })
}

pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny) -> PyResult<&'py PyAny> {
pub fn validate_schema(&self, py: Python<'py>, schema: &'py PyAny) -> PyResult<ValidatedSchema<'py>> {
let mut recursion_guard = RecursionGuard::default();
let mut state = ValidationState::new(
Extra::new(None, None, None, None, InputType::Python),
&self.validator.definitions,
&mut recursion_guard,
);
match self.validator.validator.validate(py, schema, &mut state) {
Ok(schema_obj) => Ok(schema_obj.into_ref(py)),
Ok(schema_obj) => Ok(ValidatedSchema(schema_obj.into_ref(py).downcast()?)),
Err(e) => Err(SchemaError::from_val_error(py, e)),
}
}
Expand Down Expand Up @@ -446,9 +466,9 @@ macro_rules! validator_match {
};
}

pub fn build_validator<'a>(
schema: &'a PyAny,
config: Option<&'a PyDict>,
pub fn build_validator(
schema: &PyAny,
config: Option<&PyDict>,
definitions: &mut DefinitionsBuilder<CombinedValidator>,
) -> PyResult<CombinedValidator> {
let dict: &PyDict = schema.downcast()?;
Expand Down
2 changes: 1 addition & 1 deletion src/validators/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl BuildValidator for WithDefaultValidator {
_ => unreachable!(),
};

let sub_schema: &PyAny = schema.get_as_req(intern!(schema.py(), "schema"))?;
let sub_schema = schema.get_as_req(intern!(schema.py(), "schema"))?;
let validator = Box::new(build_validator(sub_schema, config, definitions)?);

let copy_default = if let DefaultType::Default(default_obj) = &default {
Expand Down
4 changes: 1 addition & 3 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pydantic_core
from pydantic_core import (
PydanticSerializationError,
SchemaSerializer,
SchemaValidator,
ValidationError,
core_schema,
Expand Down Expand Up @@ -272,8 +271,7 @@ def __init__(self, my_foo: int, my_inners: List['Foobar']):
)
],
)
v = SchemaValidator(c)
s = SchemaSerializer(c)
v, s = pydantic_core._build_validator_and_serializer(c)

Foobar.__pydantic_validator__ = v
Foobar.__pydantic_serializer__ = s
Expand Down

0 comments on commit 9432e26

Please sign in to comment.