Skip to content

Commit

Permalink
Add validate_core_schema function and remove validation from `Schem…
Browse files Browse the repository at this point in the history
…aValidator` and `SchemaSerializer` constructors (#982)

Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
  • Loading branch information
davidhewitt and adriangb authored Sep 21, 2023
1 parent 33a7cc0 commit 916d909
Show file tree
Hide file tree
Showing 26 changed files with 118 additions and 94 deletions.
8 changes: 5 additions & 3 deletions benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ use test::{black_box, Bencher};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};

use _pydantic_core::SchemaValidator;
use _pydantic_core::{validate_core_schema, SchemaValidator};

fn build_schema_validator_with_globals(py: Python, code: &str, globals: Option<&PyDict>) -> SchemaValidator {
let schema: &PyDict = py.eval(code, globals, None).unwrap().extract().unwrap();
let mut schema: &PyDict = py.eval(code, globals, None).unwrap().extract().unwrap();
schema = validate_core_schema(py, schema).unwrap().extract().unwrap();
SchemaValidator::py_new(py, schema, None).unwrap()
}

Expand Down Expand Up @@ -444,7 +445,8 @@ fn complete_model(bench: &mut Bencher) {
sys_path.call_method1("append", ("./tests/benchmarks/",)).unwrap();

let complete_schema = py.import("complete_schema").unwrap();
let schema = complete_schema.call_method0("schema").unwrap();
let mut schema = complete_schema.call_method0("schema").unwrap();
schema = validate_core_schema(py, schema).unwrap().extract().unwrap();
let validator = SchemaValidator::py_new(py, schema, None).unwrap();

let input = complete_schema.call_method0("input_data_lax").unwrap();
Expand Down
2 changes: 2 additions & 0 deletions python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
__version__,
to_json,
to_jsonable_python,
validate_core_schema,
)
from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType

Expand Down Expand Up @@ -63,6 +64,7 @@
'TzInfo',
'to_json',
'to_jsonable_python',
'validate_core_schema',
]


Expand Down
9 changes: 9 additions & 0 deletions python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ __all__ = [
'to_jsonable_python',
'list_all_errors',
'TzInfo',
'validate_core_schema',
]
__version__: str
build_profile: str
Expand Down Expand Up @@ -836,3 +837,11 @@ class TzInfo(datetime.tzinfo):
def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ...
def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ...
def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ...

def validate_core_schema(schema: CoreSchema) -> CoreSchema:
"""Validate a CoreSchema
This currently uses lax mode for validation (i.e. will coerce strings to dates and such)
but may use strict mode in the future.
We may also remove this function altogether, do not rely on it being present if you are
using pydantic-core directly.
"""
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pub use errors::{
pub use serializers::{
to_json, to_jsonable_python, PydanticSerializationError, PydanticSerializationUnexpectedValue, SchemaSerializer,
};
pub use validators::{PySome, SchemaValidator};
pub use validators::{validate_core_schema, PySome, SchemaValidator};

pub fn get_pydantic_core_version() -> &'static str {
static PYDANTIC_CORE_VERSION: OnceLock<String> = OnceLock::new();
Expand Down Expand Up @@ -97,5 +97,6 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(to_json, m)?)?;
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
Ok(())
}
5 changes: 1 addition & 4 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::DefinitionsBuilder;
use crate::py_gc::PyGcTraverse;
use crate::validators::SelfValidator;

use config::SerializationConfig;
pub use errors::{PydanticSerializationError, PydanticSerializationUnexpectedValue};
Expand Down Expand Up @@ -73,9 +72,7 @@ impl SchemaSerializer {
#[pymethods]
impl SchemaSerializer {
#[new]
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let self_validator = SelfValidator::new(py)?;
let schema = self_validator.validate_schema(py, schema)?;
pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();

let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Expand Down
9 changes: 6 additions & 3 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ pub struct SchemaValidator {
impl SchemaValidator {
#[new]
pub fn py_new(py: Python, schema: &PyAny, config: Option<&PyDict>) -> PyResult<Self> {
let self_validator = SelfValidator::new(py)?;
let schema = self_validator.validate_schema(py, schema)?;

let mut definitions_builder = DefinitionsBuilder::new();

let mut validator = build_validator(schema, config, &mut definitions_builder)?;
Expand Down Expand Up @@ -411,6 +408,12 @@ impl<'py> SelfValidator<'py> {
}
}

#[pyfunction]
pub fn validate_core_schema<'a>(py: Python<'a>, schema: &'a PyAny) -> PyResult<&'a PyAny> {
let self_validator = SelfValidator::new(py)?;
self_validator.validate_schema(py, schema)
}

pub trait BuildValidator: Sized {
const EXPECTED_TYPE: &'static str;

Expand Down
4 changes: 2 additions & 2 deletions src/validators/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,8 @@ enum RegexEngine {
}

impl RegexEngine {
const RUST_REGEX: &str = "rust-regex";
const PYTHON_RE: &str = "python-re";
const RUST_REGEX: &'static str = "rust-regex";
const PYTHON_RE: &'static str = "python-re";
}

impl Pattern {
Expand Down
20 changes: 10 additions & 10 deletions tests/benchmarks/test_complete_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

import pytest

from pydantic_core import SchemaValidator, ValidationError
from pydantic_core import SchemaValidator, ValidationError, validate_core_schema

from .complete_schema import input_data_lax, input_data_strict, input_data_wrong, schema


def test_complete_valid():
lax_schema = schema()
cls = lax_schema['cls']
lax_validator = SchemaValidator(lax_schema)
lax_validator = SchemaValidator(validate_core_schema(lax_schema))
output = lax_validator.validate_python(input_data_lax())
assert isinstance(output, cls)
assert len(output.__pydantic_fields_set__) == 41
Expand Down Expand Up @@ -73,34 +73,34 @@ def test_complete_valid():
},
}

strict_validator = SchemaValidator(schema(strict=True))
strict_validator = SchemaValidator(validate_core_schema(schema(strict=True)))
output2 = strict_validator.validate_python(input_data_strict())
assert output_dict == output2.__dict__


def test_complete_invalid():
lax_schema = schema()
lax_validator = SchemaValidator(lax_schema)
lax_validator = SchemaValidator(validate_core_schema(lax_schema))
with pytest.raises(ValidationError) as exc_info:
lax_validator.validate_python(input_data_wrong())
assert len(exc_info.value.errors(include_url=False)) == 739


@pytest.mark.benchmark(group='complete')
def test_complete_core_lax(benchmark):
v = SchemaValidator(schema())
v = SchemaValidator(validate_core_schema(schema()))
benchmark(v.validate_python, input_data_lax())


@pytest.mark.benchmark(group='complete')
def test_complete_core_strict(benchmark):
v = SchemaValidator(schema(strict=True))
v = SchemaValidator(validate_core_schema(schema(strict=True)))
benchmark(v.validate_python, input_data_strict())


@pytest.mark.benchmark(group='complete-wrong')
def test_complete_core_error(benchmark):
v = SchemaValidator(schema())
v = SchemaValidator(validate_core_schema(schema()))
data = input_data_wrong()

@benchmark
Expand All @@ -115,7 +115,7 @@ def f():

@pytest.mark.benchmark(group='complete-wrong')
def test_complete_core_isinstance(benchmark):
v = SchemaValidator(schema())
v = SchemaValidator(validate_core_schema(schema()))
data = input_data_wrong()
assert v.isinstance_python(data) is False

Expand All @@ -135,12 +135,12 @@ def default_json_encoder(obj):

@pytest.mark.benchmark(group='complete-json')
def test_complete_core_json(benchmark):
v = SchemaValidator(schema())
v = SchemaValidator(validate_core_schema(schema()))
json_data = json.dumps(input_data_lax(), default=default_json_encoder)
benchmark(v.validate_json, json_data)


@pytest.mark.benchmark(group='build')
def test_build_schema(benchmark):
lax_schema = schema()
benchmark(SchemaValidator, lax_schema)
benchmark(lambda s: SchemaValidator(validate_core_schema(s)), lax_schema)
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest
from typing_extensions import Literal

from pydantic_core import ArgsKwargs, SchemaValidator, ValidationError
from pydantic_core import ArgsKwargs, SchemaValidator, ValidationError, validate_core_schema
from pydantic_core.core_schema import CoreConfig

__all__ = 'Err', 'PyAndJson', 'plain_repr', 'infinite_generator'
Expand Down Expand Up @@ -53,7 +53,7 @@ class PyAndJsonValidator:
def __init__(
self, schema, config: CoreConfig | None = None, *, validator_type: Literal['json', 'python'] | None = None
):
self.validator = SchemaValidator(schema, config)
self.validator = SchemaValidator(validate_core_schema(schema), config)
self.validator_type = validator_type

def validate_python(self, py_input, strict: bool | None = None, context: Any = None):
Expand Down
4 changes: 2 additions & 2 deletions tests/serializers/test_definitions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from pydantic_core import SchemaError, SchemaSerializer, core_schema
from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema


def test_custom_ser():
Expand All @@ -25,7 +25,7 @@ def test_ignored_def():

def test_def_error():
with pytest.raises(SchemaError) as exc_info:
SchemaSerializer(
validate_core_schema(
core_schema.definitions_schema(
core_schema.list_schema(core_schema.definition_reference_schema('foobar')),
[core_schema.int_schema(ref='foobar'), {'type': 'wrong'}],
Expand Down
6 changes: 4 additions & 2 deletions tests/serializers/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from dirty_equals import IsStrictDict

from pydantic_core import SchemaError, SchemaSerializer, core_schema
from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema


def test_dict_str_int():
Expand Down Expand Up @@ -155,4 +155,6 @@ def test_filter_runtime_int():
)
def test_include_error(include_value, error_msg):
with pytest.raises(SchemaError, match=error_msg):
SchemaSerializer(core_schema.dict_schema(serialization=core_schema.filter_dict_schema(include=include_value)))
validate_core_schema(
core_schema.dict_schema(serialization=core_schema.filter_dict_schema(include=include_value))
)
10 changes: 6 additions & 4 deletions tests/serializers/test_list_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from pydantic_core import SchemaError, SchemaSerializer, core_schema
from pydantic_core import SchemaError, SchemaSerializer, core_schema, validate_core_schema


def test_list_any():
Expand Down Expand Up @@ -144,8 +144,10 @@ def test_exclude(schema_func, seq_f):
@pytest.mark.parametrize('include,exclude', [({1, 3, 5}, {5, 6}), ([1, 3, 5], [5, 6])])
def test_filter(include, exclude):
v = SchemaSerializer(
core_schema.list_schema(
core_schema.any_schema(), serialization=core_schema.filter_seq_schema(include=include, exclude=exclude)
validate_core_schema(
core_schema.list_schema(
core_schema.any_schema(), serialization=core_schema.filter_seq_schema(include=include, exclude=exclude)
)
)
)
assert v.to_python([0, 1, 2, 3, 4, 5, 6, 7]) == [1, 3]
Expand Down Expand Up @@ -186,7 +188,7 @@ class RemovedContains(ImplicitContains):
@pytest.mark.parametrize('schema_func', [core_schema.list_schema, core_schema.tuple_variable_schema])
def test_include_error(schema_func, include_value, error_msg):
with pytest.raises(SchemaError, match=error_msg):
SchemaSerializer(
validate_core_schema(
schema_func(core_schema.any_schema(), serialization=core_schema.filter_seq_schema(include=include_value))
)

Expand Down
4 changes: 2 additions & 2 deletions tests/serializers/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from pydantic_core import SchemaError, SchemaSerializer, core_schema
from pydantic_core import SchemaError, core_schema, validate_core_schema


@pytest.mark.parametrize(
Expand All @@ -12,4 +12,4 @@
)
def test_invalid_ser_schema(ser_schema, msg):
with pytest.raises(SchemaError, match=msg):
SchemaSerializer(core_schema.any_schema(serialization=ser_schema))
validate_core_schema(core_schema.any_schema(serialization=ser_schema))
4 changes: 2 additions & 2 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ mod tests {
]
}"#;
let schema: &PyDict = py.eval(code, None, None).unwrap().extract().unwrap();
SchemaSerializer::py_new(py, schema, None).unwrap();
SchemaSerializer::py_new(schema, None).unwrap();
});
}

Expand Down Expand Up @@ -77,7 +77,7 @@ a = A()
py.run(code, None, Some(locals)).unwrap();
let a: &PyAny = locals.get_item("a").unwrap().extract().unwrap();
let schema: &PyDict = locals.get_item("schema").unwrap().extract().unwrap();
let serialized: Vec<u8> = SchemaSerializer::py_new(py, schema, None)
let serialized: Vec<u8> = SchemaSerializer::py_new(schema, None)
.unwrap()
.to_json(py, a, None, None, None, true, false, false, false, false, true, None)
.unwrap()
Expand Down
18 changes: 9 additions & 9 deletions tests/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

import pytest

from pydantic_core import SchemaError, SchemaValidator
from pydantic_core import SchemaError, SchemaValidator, validate_core_schema
from pydantic_core import core_schema as cs


def test_build_error_type():
with pytest.raises(SchemaError, match="Input tag 'foobar' found using 'type' does not match any of the"):
SchemaValidator({'type': 'foobar', 'title': 'TestModel'})
validate_core_schema({'type': 'foobar', 'title': 'TestModel'})


def test_build_error_internal():
with pytest.raises(SchemaError, match='Input should be a valid integer, unable to parse string as an integer'):
SchemaValidator({'type': 'str', 'min_length': 'xxx', 'title': 'TestModel'})
validate_core_schema({'type': 'str', 'min_length': 'xxx', 'title': 'TestModel'})


def test_build_error_deep():
with pytest.raises(SchemaError, match='Input should be a valid integer, unable to parse string as an integer'):
SchemaValidator(
validate_core_schema(
{
'title': 'MyTestModel',
'type': 'typed-dict',
Expand All @@ -34,7 +34,7 @@ def test_schema_as_string():

def test_schema_wrong_type(pydantic_version):
with pytest.raises(SchemaError) as exc_info:
SchemaValidator(1)
validate_core_schema(1)
assert str(exc_info.value) == (
'Invalid Schema:\n Input should be a valid dictionary or object to'
' extract fields from [type=model_attributes_type, input_value=1, input_type=int]\n'
Expand Down Expand Up @@ -66,7 +66,7 @@ def test_schema_definition_error():
schema = {'type': 'union', 'choices': []}
schema['choices'].append({'type': 'nullable', 'schema': schema})
with pytest.raises(SchemaError, match='Recursion error - cyclic reference detected'):
SchemaValidator(schema)
validate_core_schema(schema)


def test_not_schema_definition_error():
Expand All @@ -83,17 +83,17 @@ def test_not_schema_definition_error():

def test_no_type():
with pytest.raises(SchemaError, match="Unable to extract tag using discriminator 'type'"):
SchemaValidator({})
validate_core_schema({})


def test_wrong_type():
with pytest.raises(SchemaError, match="Input tag 'unknown' found using 'type' does not match any of the"):
SchemaValidator({'type': 'unknown'})
validate_core_schema({'type': 'unknown'})


def test_function_no_mode():
with pytest.raises(SchemaError, match="Input tag 'function' found using 'type' does not match any of the"):
SchemaValidator({'type': 'function'})
validate_core_schema({'type': 'function'})


def test_try_self_schema_discriminator():
Expand Down
Loading

0 comments on commit 916d909

Please sign in to comment.