210 changes: 153 additions & 57 deletions Cargo.lock

Large diffs are not rendered by default.

26 changes: 14 additions & 12 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pydantic-core"
version = "2.10.1"
version = "2.14.1"
edition = "2021"
license = "MIT"
homepage = "https://github.com/pydantic/pydantic-core"
Expand All @@ -26,23 +26,25 @@ include = [
]

[dependencies]
pyo3 = { version = "0.19.2", features = ["generate-import-lib", "num-bigint"] }
regex = "1.9.5"
pyo3 = { version = "0.20.0", features = ["generate-import-lib", "num-bigint"] }
regex = "1.10.2"
strum = { version = "0.25.0", features = ["derive"] }
strum_macros = "0.25.2"
serde_json = {version = "1.0.107", features = ["arbitrary_precision", "preserve_order"]}
strum_macros = "0.25.3"
serde_json = {version = "1.0.108", features = ["arbitrary_precision", "preserve_order"]}
enum_dispatch = "0.3.8"
serde = { version = "1.0.188", features = ["derive"] }
speedate = "0.12.0"
serde = { version = "1.0.190", features = ["derive"] }
speedate = "0.13.0"
smallvec = "1.11.1"
ahash = "0.8.0"
ahash = "0.8.6"
url = "2.4.1"
# idna is already required by url, added here to be explicit
idna = "0.4.0"
base64 = "0.21.4"
base64 = "0.21.5"
num-bigint = "0.4.4"
python3-dll-a = "0.2.7"
uuid = "1.4.1"
uuid = "1.5.0"
jiter = {version = "0.0.4", features = ["python"]}
#jiter = {path = "../jiter", features = ["python"]}

[lib]
name = "_pydantic_core"
Expand All @@ -62,9 +64,9 @@ debug = true
strip = false

[dev-dependencies]
pyo3 = { version= "0.19.2", features = ["auto-initialize"] }
pyo3 = { version = "0.20.0", features = ["auto-initialize"] }

[build-dependencies]
version_check = "0.9.4"
# used where logic has to be version/distribution specific, e.g. pypy
pyo3-build-config = { version = "0.19.2" }
pyo3-build-config = { version = "0.20.0" }
12 changes: 6 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.DEFAULT_GOAL := all
black = black python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py
ruff = ruff python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py
sources = python/pydantic_core tests generate_self_schema.py wasm-preview/run_tests.py

mypy-stubtest = python -m mypy.stubtest pydantic_core._pydantic_core --allowlist .mypy-stubtest-allowlist

# using pip install cargo (via maturin via pip) doesn't get the tty handle
Expand Down Expand Up @@ -90,14 +90,14 @@ build-wasm:

.PHONY: format
format:
$(black)
$(ruff) --fix --exit-zero
ruff --fix $(sources)
ruff format $(sources)
cargo fmt

.PHONY: lint-python
lint-python:
$(ruff)
$(black) --check --diff
ruff $(sources)
ruff format --check $(sources)
$(mypy-stubtest)
griffe dump -f -d google -LWARNING -o/dev/null python/pydantic_core

Expand Down
44 changes: 44 additions & 0 deletions benches/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,50 @@ fn complete_model(bench: &mut Bencher) {
})
}

/// Benchmark validation of a deeply nested model whose core schema uses
/// shared `$ref` definitions (as produced by `schema_using_defs`).
#[bench]
fn nested_model_using_definitions(bench: &mut Bencher) {
    Python::with_gil(|py| {
        // make ./tests/benchmarks importable so `nested_schema` resolves
        py.import("sys")
            .unwrap()
            .getattr("path")
            .unwrap()
            .call_method1("append", ("./tests/benchmarks/",))
            .unwrap();

        let schema_mod = py.import("nested_schema").unwrap();
        let raw_schema = schema_mod.call_method0("schema_using_defs").unwrap();
        let mut schema = raw_schema;
        schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap();
        let validator = SchemaValidator::py_new(py, schema, None).unwrap();

        let input = black_box(schema_mod.call_method0("input_data_valid").unwrap());

        // sanity-check the input validates before entering the timed loop
        validator.validate_python(py, input, None, None, None, None).unwrap();

        bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap()));
    })
}

/// Benchmark validation of the same deeply nested model as above, but with
/// the schema fully inlined (no `$ref` definitions; see `inlined_schema`).
#[bench]
fn nested_model_inlined(bench: &mut Bencher) {
    Python::with_gil(|py| {
        // make ./tests/benchmarks importable so `nested_schema` resolves
        py.import("sys")
            .unwrap()
            .getattr("path")
            .unwrap()
            .call_method1("append", ("./tests/benchmarks/",))
            .unwrap();

        let schema_mod = py.import("nested_schema").unwrap();
        let raw_schema = schema_mod.call_method0("inlined_schema").unwrap();
        let mut schema = raw_schema;
        schema = validate_core_schema(py, schema, None).unwrap().extract().unwrap();
        let validator = SchemaValidator::py_new(py, schema, None).unwrap();

        let input = black_box(schema_mod.call_method0("input_data_valid").unwrap());

        // sanity-check the input validates before entering the timed loop
        validator.validate_python(py, input, None, None, None, None).unwrap();

        bench.iter(|| black_box(validator.validate_python(py, input, None, None, None, None).unwrap()));
    })
}

#[bench]
fn literal_ints_few_python(bench: &mut Bencher) {
Python::with_gil(|py| {
Expand Down
11 changes: 3 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@ features = ["pyo3/extension-module"]
line-length = 120
extend-select = ['Q', 'RUF100', 'C90', 'I']
extend-ignore = [
'E501', # ignore line too long and let black handle it
'E721', # using type() instead of isinstance() - we use this in tests
]
flake8-quotes = {inline-quotes = 'single', multiline-quotes = 'double'}
mccabe = { max-complexity = 13 }
isort = { known-first-party = ['pydantic_core', 'tests'] }

[tool.ruff.format]
quote-style = 'single'

[tool.pytest.ini_options]
testpaths = 'tests'
log_format = '%(name)s %(levelname)s: %(message)s'
Expand Down Expand Up @@ -97,13 +99,6 @@ exclude_lines = [
'@overload',
]

[tool.black]
color = true
line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310']
skip-string-normalization = true
skip-magic-trailing-comma = true

# configuring https://github.com/pydantic/hooky
[tool.hooky]
assignees = ['samuelcolvin', 'adriangb', 'dmontagu', 'davidhewitt', 'lig']
Expand Down
2 changes: 2 additions & 0 deletions python/pydantic_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Url,
ValidationError,
__version__,
from_json,
to_json,
to_jsonable_python,
validate_core_schema,
Expand Down Expand Up @@ -63,6 +64,7 @@
'PydanticSerializationUnexpectedValue',
'TzInfo',
'to_json',
'from_json',
'to_jsonable_python',
'validate_core_schema',
]
Expand Down
19 changes: 18 additions & 1 deletion python/pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ __all__ = [
'PydanticUndefinedType',
'Some',
'to_json',
'from_json',
'to_jsonable_python',
'list_all_errors',
'TzInfo',
Expand Down Expand Up @@ -384,6 +385,23 @@ def to_json(
JSON bytes.
"""

def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True) -> Any:
    """
    Deserialize JSON data to a Python object.

    This is effectively a faster version of [`json.loads()`][json.loads].

    Arguments:
        data: The JSON data to deserialize.
        allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default.

    Raises:
        ValueError: If deserialization fails.

    Returns:
        The deserialized Python object.
    """

def to_jsonable_python(
value: Any,
*,
Expand Down Expand Up @@ -829,7 +847,6 @@ def list_all_errors() -> list[ErrorTypeInfo]:
Returns:
A list of `ErrorTypeInfo` typed dicts.
"""

@final
class TzInfo(datetime.tzinfo):
def tzname(self, _dt: datetime.datetime | None) -> str | None: ...
Expand Down
9 changes: 8 additions & 1 deletion python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class CoreConfig(TypedDict, total=False):
allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`.
ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'.
ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'.
ser_json_inf_nan: The serialization option for infinity and NaN values
in float fields. Default is 'null'.
hide_input_in_errors: Whether to hide input data from `ValidationError` representation.
validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError.
Requires exceptiongroup backport pre Python 3.11.
Expand Down Expand Up @@ -101,11 +103,13 @@ class CoreConfig(TypedDict, total=False):
allow_inf_nan: bool # default: True
# the config options are used to customise serialization to JSON
ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601'
ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8'
ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8'
ser_json_inf_nan: Literal['null', 'constants'] # default: 'null'
# used to hide input data from ValidationError repr
hide_input_in_errors: bool
validation_error_cause: bool # default: False
coerce_numbers_to_str: bool # default: False
regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex'


IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None'
Expand Down Expand Up @@ -3909,6 +3913,9 @@ def general_plain_validator_function(*args, **kwargs):
'FieldWrapValidatorFunction': WithInfoWrapValidatorFunction,
}

if TYPE_CHECKING:
FieldValidationInfo = ValidationInfo


def __getattr__(attr_name: str) -> object:
new_attr = _deprecated_import_lookup.get(attr_name)
Expand Down
256 changes: 192 additions & 64 deletions src/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
/// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar.
/// We use DefinitionsBuilder to collect the references / definitions into a single vector
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
use std::collections::hash_map::Entry;
use std::{
collections::hash_map::Entry,
fmt::Debug,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
};

use pyo3::prelude::*;
use pyo3::{prelude::*, PyTraverseError, PyVisit};

use ahash::AHashMap;

use crate::build_tools::py_schema_err;

// An integer id for the reference
pub type ReferenceId = usize;
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};

/// Definitions are validators and serializers that are
/// shared by reference.
Expand All @@ -24,91 +28,215 @@ pub type ReferenceId = usize;
/// They get indexed by a ReferenceId, which are integer identifiers
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets build.
pub type Definitions<T> = [T];
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

#[derive(Clone, Debug)]
struct Definition<T> {
pub id: ReferenceId,
pub value: Option<T>,
/// Internal type which contains a definition to be filled
pub struct Definition<T>(Arc<DefinitionInner<T>>);

/// Shared payload behind a `Definition`: the (possibly not-yet-filled)
/// value plus a lazily computed display name.
struct DefinitionInner<T> {
    // filled exactly once, by `DefinitionsBuilder::add_definition`
    value: OnceLock<T>,
    // computed on first use; renders as "..." while still unset
    name: LazyName,
}

/// Reference to a definition; cheap to clone, since it shares the
/// underlying definition allocation via `Arc`.
pub struct DefinitionRef<T> {
    // the `$ref` string this definition was registered under
    name: Arc<String>,
    // shared handle to the (possibly not-yet-filled) definition payload
    value: Definition<T>,
}

// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
impl<T> Clone for DefinitionRef<T> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
value: self.value.clone(),
}
}
}

impl<T> DefinitionRef<T> {
    /// A stable identifier for this definition, derived from the address of
    /// the shared allocation (unique for as long as the definition lives).
    pub fn id(&self) -> usize {
        Arc::as_ptr(&self.value.0) as usize
    }

    /// Get the definition's display name, computing it with `init` on first
    /// use; yields "..." while the definition has not been filled yet.
    pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
        if let Some(value) = self.value.0.value.get() {
            self.value.0.name.get_or_init(|| init(value))
        } else {
            "..."
        }
    }

    /// The filled value, or `None` if the definition is still pending.
    pub fn get(&self) -> Option<&T> {
        self.value.0.value.get()
    }
}

impl<T: Debug> Debug for DefinitionRef<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Recursive definitions would make a value-based Debug impl recurse
        // forever, so a DefinitionRef debug-prints only its name.
        Debug::fmt(&self.name, f)
    }
}

impl<T: Debug> Debug for Definitions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatted as a list for backwards compatibility; in principle
        // this could be formatted as a map. Maybe change in a future
        // minor release of pydantic.
        write!(f, "[")?;
        for (i, def) in self.0.values().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{def:?}")?;
        }
        write!(f, "]")
    }
}

impl<T> Clone for Definition<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T: Debug> Debug for Definition<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // an unfilled definition renders as "..."
        if let Some(value) = self.0.value.get() {
            value.fmt(f)
        } else {
            "...".fmt(f)
        }
    }
}

impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        // an unfilled definition holds no Python objects to visit
        match self.value.0.value.get() {
            Some(value) => value.py_gc_traverse(visit),
            None => Ok(()),
        }
    }
}

impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        // visit every filled definition; unfilled ones are skipped
        self.0
            .values()
            .filter_map(|def| def.0.value.get())
            .try_for_each(|value| value.py_gc_traverse(visit))
    }
}

#[derive(Clone, Debug)]
pub struct DefinitionsBuilder<T> {
definitions: AHashMap<String, Definition<T>>,
definitions: Definitions<T>,
}

impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
pub fn new() -> Self {
Self {
definitions: AHashMap::new(),
definitions: Definitions(AHashMap::new()),
}
}

/// Get a ReferenceId for the given reference string.
// This ReferenceId can later be used to retrieve a definition
pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId {
let next_id = self.definitions.len();
pub fn get_definition(&mut self, reference: &str) -> DefinitionRef<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
match self.definitions.entry(reference.to_string()) {
Entry::Occupied(entry) => entry.get().id,
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: None,
});
next_id
}
let name = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::new(),
name: LazyName::new(),
}))),
};
DefinitionRef {
name,
value: value.clone(),
}
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
let next_id = self.definitions.len();
match self.definitions.entry(reference.clone()) {
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
None => Ok(entry.get().id),
},
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: Some(value),
});
Ok(next_id)
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
let name = Arc::new(reference);
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => {
let definition = entry.into_mut();
match definition.0.value.set(value) {
Ok(()) => definition.clone(),
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
}
}
Entry::Vacant(entry) => entry
.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::from(value),
name: LazyName::new(),
})))
.clone(),
};
Ok(DefinitionRef { name, value })
}

/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
pub fn finish(self) -> PyResult<Definitions<T>> {
for (reference, def) in &self.definitions.0 {
if def.0.value.get().is_none() {
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
}
}
Ok(self.definitions)
}
}

/// A name that is computed at most once, on first access.
struct LazyName {
    // holds the name once it has been computed
    initialized: OnceLock<String>,
    // set while the initializer is running, to break recursion loops
    in_recursion: AtomicBool,
}

impl LazyName {
fn new() -> Self {
Self {
initialized: OnceLock::new(),
in_recursion: AtomicBool::new(false),
}
}

/// Retrieve an item definition using a ReferenceId
/// If the definition doesn't yet exist (as happens in recursive types) then we create it
/// At the end (in finish()) we check that there are no undefined definitions
pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> {
let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) {
Some(v) => v,
None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id),
};
match def.value.as_ref() {
Some(v) => Ok(v),
None => py_schema_err!(
"Definitions error: attempted to use `{}` before it was filled",
reference
),
/// Gets the validator name, returning the default in the case of recursion loops
fn get_or_init(&self, init: impl FnOnce() -> String) -> &str {
if let Some(s) = self.initialized.get() {
return s.as_str();
}

if self
.in_recursion
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return "...";
}
let result = self.initialized.get_or_init(init).as_str();
self.in_recursion.store(false, Ordering::SeqCst);
result
}
}

/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
pub fn finish(self) -> PyResult<Vec<T>> {
// We need to create a vec of defs according to the order in their ids
let mut defs: Vec<(usize, T)> = Vec::new();
for (reference, def) in self.definitions {
match def.value {
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
Some(v) => defs.push((def.id, v)),
}
impl Debug for LazyName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // a not-yet-computed name renders as "..."
        match self.initialized.get() {
            Some(name) => name.fmt(f),
            None => "...".fmt(f),
        }
    }
}

impl Clone for LazyName {
fn clone(&self) -> Self {
Self {
initialized: OnceLock::new(),
in_recursion: AtomicBool::new(false),
}
defs.sort_by_key(|(id, _)| *id);
Ok(defs.into_iter().map(|(_, v)| v).collect())
}
}
6 changes: 4 additions & 2 deletions src/errors/line_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::PyDowncastError;

use crate::input::{Input, JsonInput};
use jiter::JsonValue;

use crate::input::Input;

use super::location::{LocItem, Location};
use super::types::ErrorType;
Expand Down Expand Up @@ -147,7 +149,7 @@ impl<'a> ValLineError<'a> {
#[derive(Clone)]
pub enum InputValue<'a> {
PyAny(&'a PyAny),
JsonInput(JsonInput),
JsonInput(JsonValue),
String(&'a str),
}

Expand Down
28 changes: 9 additions & 19 deletions src/errors/location.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use pyo3::once_cell::GILOnceCell;
use std::fmt;

use pyo3::prelude::*;
use pyo3::types::{PyList, PyString, PyTuple};
use pyo3::types::{PyList, PyTuple};
use serde::ser::SerializeSeq;
use serde::{Serialize, Serializer};

use crate::lookup_key::{LookupPath, PathItem};
use crate::tools::extract_i64;

/// Used to store individual items of the error location, e.g. a string for key/field names
/// or a number for array indices.
Expand All @@ -35,6 +34,12 @@ impl fmt::Display for LocItem {
}
}

// TODO rename to ToLocItem
/// Conversion of a value into a [`LocItem`] (a string key/field name or an
/// integer array index) for use in error locations.
pub trait AsLocItem {
    // TODO rename to to_loc_item
    /// Produce the `LocItem` representing this value in an error location.
    fn as_loc_item(&self) -> LocItem;
}

impl From<String> for LocItem {
fn from(s: String) -> Self {
Self::S(s)
Expand Down Expand Up @@ -82,21 +87,6 @@ impl ToPyObject for LocItem {
}
}

impl TryFrom<&PyAny> for LocItem {
type Error = PyErr;

fn try_from(loc_item: &PyAny) -> PyResult<Self> {
if let Ok(py_str) = loc_item.downcast::<PyString>() {
let str = py_str.to_str()?.to_string();
Ok(Self::S(str))
} else if let Ok(int) = extract_i64(loc_item) {
Ok(Self::I(int))
} else {
Err(PyTypeError::new_err("Item in a location must be a string or int"))
}
}
}

impl Serialize for LocItem {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand Down Expand Up @@ -211,9 +201,9 @@ impl TryFrom<Option<&PyAny>> for Location {
fn try_from(location: Option<&PyAny>) -> PyResult<Self> {
if let Some(location) = location {
let mut loc_vec: Vec<LocItem> = if let Ok(tuple) = location.downcast::<PyTuple>() {
tuple.iter().map(LocItem::try_from).collect::<PyResult<_>>()?
tuple.iter().map(AsLocItem::as_loc_item).collect()
} else if let Ok(list) = location.downcast::<PyList>() {
list.iter().map(LocItem::try_from).collect::<PyResult<_>>()?
list.iter().map(AsLocItem::as_loc_item).collect()
} else {
return Err(PyTypeError::new_err(
"Location must be a list or tuple of strings and ints",
Expand Down
2 changes: 1 addition & 1 deletion src/errors/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod validation_exception;
mod value_exception;

pub use self::line_error::{InputValue, ValError, ValLineError, ValResult};
pub use self::location::LocItem;
pub use self::location::{AsLocItem, LocItem};
pub use self::types::{list_all_errors, ErrorType, ErrorTypeDefaults, Number};
pub use self::validation_exception::ValidationError;
pub use self::value_exception::{PydanticCustomError, PydanticKnownError, PydanticOmit, PydanticUseDefault};
Expand Down
62 changes: 43 additions & 19 deletions src/errors/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn field_from_context<'py, T: FromPyObject<'py>>(
) -> PyResult<T> {
context
.ok_or_else(|| py_error_type!(PyTypeError; "{}: '{}' required in context", enum_name, field_name))?
.get_item(field_name)
.get_item(field_name)?
.ok_or_else(|| py_error_type!(PyTypeError; "{}: '{}' required in context", enum_name, field_name))?
.extract::<T>()
.map_err(|_| py_error_type!(PyTypeError; "{}: '{}' context value must be a {}", enum_name, field_name, type_name_fn()))
Expand Down Expand Up @@ -445,8 +445,8 @@ macro_rules! to_string_render {
};
}

fn plural_s(value: usize) -> &'static str {
if value == 1 {
fn plural_s<T: From<u8> + PartialEq>(value: T) -> &'static str {
if value == 1.into() {
""
} else {
"s"
Expand Down Expand Up @@ -494,8 +494,8 @@ impl ErrorType {
Self::StringType {..} => "Input should be a valid string",
Self::StringSubType {..} => "Input should be a string, not an instance of a subclass of str",
Self::StringUnicode {..} => "Input should be a valid string, unable to parse raw data as a unicode string",
Self::StringTooShort {..} => "String should have at least {min_length} characters",
Self::StringTooLong {..} => "String should have at most {max_length} characters",
Self::StringTooShort {..} => "String should have at least {min_length} character{expected_plural}",
Self::StringTooLong {..} => "String should have at most {max_length} character{expected_plural}",
Self::StringPatternMismatch {..} => "String should match pattern '{pattern}'",
Self::Enum {..} => "Input should be {expected}",
Self::DictType {..} => "Input should be a valid dictionary",
Expand All @@ -512,8 +512,8 @@ impl ErrorType {
Self::FloatType {..} => "Input should be a valid number",
Self::FloatParsing {..} => "Input should be a valid number, unable to parse string as a number",
Self::BytesType {..} => "Input should be a valid bytes",
Self::BytesTooShort {..} => "Data should have at least {min_length} bytes",
Self::BytesTooLong {..} => "Data should have at most {max_length} bytes",
Self::BytesTooShort {..} => "Data should have at least {min_length} byte{expected_plural}",
Self::BytesTooLong {..} => "Data should have at most {max_length} byte{expected_plural}",
Self::ValueError {..} => "Value error, {error}",
Self::AssertionError {..} => "Assertion failed, {error}",
Self::CustomError {..} => "", // custom errors are handled separately
Expand Down Expand Up @@ -552,16 +552,16 @@ impl ErrorType {
Self::UrlType {..} => "URL input should be a string or URL",
Self::UrlParsing {..} => "Input should be a valid URL, {error}",
Self::UrlSyntaxViolation {..} => "Input violated strict URL syntax rules, {error}",
Self::UrlTooLong {..} => "URL should have at most {max_length} characters",
Self::UrlTooLong {..} => "URL should have at most {max_length} character{expected_plural}",
Self::UrlScheme {..} => "URL scheme should be {expected_schemes}",
Self::UuidType {..} => "UUID input should be a string, bytes or UUID object",
Self::UuidParsing {..} => "Input should be a valid UUID, {error}",
Self::UuidVersion {..} => "UUID version {expected_version} expected",
Self::DecimalType {..} => "Decimal input should be an integer, float, string or Decimal object",
Self::DecimalParsing {..} => "Input should be a valid decimal",
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digits in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal places",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digits before the decimal point",
Self::DecimalMaxDigits {..} => "Decimal input should have no more than {max_digits} digit{expected_plural} in total",
Self::DecimalMaxPlaces {..} => "Decimal input should have no more than {decimal_places} decimal place{expected_plural}",
Self::DecimalWholeDigits {..} => "Decimal input should have no more than {whole_digits} digit{expected_plural} before the decimal point",
}
}

Expand Down Expand Up @@ -643,13 +643,25 @@ impl ErrorType {
to_string_render!(tmpl, field_type, max_length, actual_length, expected_plural,)
}
Self::IterationError { error, .. } => render!(tmpl, error),
Self::StringTooShort { min_length, .. } => to_string_render!(tmpl, min_length),
Self::StringTooLong { max_length, .. } => to_string_render!(tmpl, max_length),
Self::StringTooShort { min_length, .. } => {
let expected_plural = plural_s(*min_length);
to_string_render!(tmpl, min_length, expected_plural)
}
Self::StringTooLong { max_length, .. } => {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::StringPatternMismatch { pattern, .. } => render!(tmpl, pattern),
Self::Enum { expected, .. } => to_string_render!(tmpl, expected),
Self::MappingType { error, .. } => render!(tmpl, error),
Self::BytesTooShort { min_length, .. } => to_string_render!(tmpl, min_length),
Self::BytesTooLong { max_length, .. } => to_string_render!(tmpl, max_length),
Self::BytesTooShort { min_length, .. } => {
let expected_plural = plural_s(*min_length);
to_string_render!(tmpl, min_length, expected_plural)
}
Self::BytesTooLong { max_length, .. } => {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::ValueError { error, .. } => {
let error = &error
.as_ref()
Expand Down Expand Up @@ -688,13 +700,25 @@ impl ErrorType {
Self::UnionTagNotFound { discriminator, .. } => render!(tmpl, discriminator),
Self::UrlParsing { error, .. } => render!(tmpl, error),
Self::UrlSyntaxViolation { error, .. } => render!(tmpl, error),
Self::UrlTooLong { max_length, .. } => to_string_render!(tmpl, max_length),
Self::UrlTooLong { max_length, .. } => {
let expected_plural = plural_s(*max_length);
to_string_render!(tmpl, max_length, expected_plural)
}
Self::UrlScheme { expected_schemes, .. } => render!(tmpl, expected_schemes),
Self::UuidParsing { error, .. } => render!(tmpl, error),
Self::UuidVersion { expected_version, .. } => to_string_render!(tmpl, expected_version),
Self::DecimalMaxDigits { max_digits, .. } => to_string_render!(tmpl, max_digits),
Self::DecimalMaxPlaces { decimal_places, .. } => to_string_render!(tmpl, decimal_places),
Self::DecimalWholeDigits { whole_digits, .. } => to_string_render!(tmpl, whole_digits),
Self::DecimalMaxDigits { max_digits, .. } => {
let expected_plural = plural_s(*max_digits);
to_string_render!(tmpl, max_digits, expected_plural)
}
Self::DecimalMaxPlaces { decimal_places, .. } => {
let expected_plural = plural_s(*decimal_places);
to_string_render!(tmpl, decimal_places, expected_plural)
}
Self::DecimalWholeDigits { whole_digits, .. } => {
let expected_plural = plural_s(*whole_digits);
to_string_render!(tmpl, whole_digits, expected_plural)
}
_ => Ok(tmpl.to_string()),
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/errors/validation_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use std::str::from_utf8;

use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
use pyo3::ffi;
use pyo3::intern;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyString};
use pyo3::{intern, AsPyPointer};
use serde::ser::{Error, SerializeMap, SerializeSeq};
use serde::{Serialize, Serializer};

Expand Down Expand Up @@ -324,12 +324,12 @@ impl ValidationError {
Some(indent) => {
let indent = vec![b' '; indent];
let formatter = PrettyFormatter::with_indent(&indent);
let mut ser = serde_json::Serializer::with_formatter(writer, formatter);
let mut ser = crate::serializers::ser::PythonSerializer::with_formatter(writer, formatter);
serializer.serialize(&mut ser).map_err(json_py_err)?;
ser.into_inner()
}
None => {
let mut ser = serde_json::Serializer::new(writer);
let mut ser = crate::serializers::ser::PythonSerializer::new(writer);
serializer.serialize(&mut ser).map_err(json_py_err)?;
ser.into_inner()
}
Expand Down Expand Up @@ -445,7 +445,7 @@ impl TryFrom<&PyAny> for PyLineError {
let py = value.py();

let type_raw = dict
.get_item(intern!(py, "type"))
.get_item(intern!(py, "type"))?
.ok_or_else(|| PyKeyError::new_err("type"))?;

let error_type = if let Ok(type_str) = type_raw.downcast::<PyString>() {
Expand All @@ -459,9 +459,9 @@ impl TryFrom<&PyAny> for PyLineError {
));
};

let location = Location::try_from(dict.get_item("loc"))?;
let location = Location::try_from(dict.get_item("loc")?)?;

let input_value = match dict.get_item("input") {
let input_value = match dict.get_item("input")? {
Some(i) => i.into_py(py),
None => py.None(),
};
Expand Down
4 changes: 2 additions & 2 deletions src/errors/value_exception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl PydanticCustomError {
}
}

#[getter(type)]
#[getter(r#type)]
pub fn error_type(&self) -> String {
self.error_type.clone()
}
Expand Down Expand Up @@ -147,7 +147,7 @@ impl PydanticKnownError {
Ok(Self { error_type })
}

#[getter(type)]
#[getter(r#type)]
pub fn error_type(&self) -> String {
self.error_type.to_string()
}
Expand Down
168 changes: 29 additions & 139 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use pyo3::exceptions::PyValueError;
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{InputValue, LocItem, ValResult};
use jiter::JsonValue;

use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, ValError, ValResult};
use crate::tools::py_err;
use crate::{PyMultiHostUrl, PyUrl};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, ValidationMatch};

#[derive(Debug, Clone, Copy)]
pub enum InputType {
Expand Down Expand Up @@ -46,9 +48,7 @@ impl TryFrom<&str> for InputType {
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'a>: fmt::Debug + ToPyObject {
fn as_loc_item(&self) -> LocItem;

pub trait Input<'a>: fmt::Debug + ToPyObject + AsLocItem + Sized {
fn as_error_value(&'a self) -> InputValue<'a>;

fn identity(&self) -> Option<usize> {
Expand Down Expand Up @@ -89,87 +89,39 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

fn validate_dataclass_args(&'a self, dataclass_name: &str) -> ValResult<'a, GenericArguments<'a>>;

fn parse_json(&'a self) -> ValResult<'a, JsonInput>;
fn parse_json(&'a self) -> ValResult<'a, JsonValue>;

fn validate_str(&'a self, strict: bool, coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {
if strict {
self.strict_str()
} else {
self.lax_str(coerce_numbers_to_str)
}
}
fn strict_str(&'a self) -> ValResult<EitherString<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_str(&'a self, _coerce_numbers_to_str: bool) -> ValResult<EitherString<'a>> {
self.strict_str()
}
fn validate_str(
&'a self,
strict: bool,
coerce_numbers_to_str: bool,
) -> ValResult<ValidationMatch<EitherString<'a>>>;

fn validate_bytes(&'a self, strict: bool) -> ValResult<EitherBytes<'a>> {
if strict {
self.strict_bytes()
} else {
self.lax_bytes()
}
}
fn strict_bytes(&'a self) -> ValResult<EitherBytes<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_bytes(&'a self) -> ValResult<EitherBytes<'a>> {
self.strict_bytes()
}
fn validate_bytes(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a>>>;

fn validate_bool(&self, strict: bool) -> ValResult<bool> {
if strict {
self.strict_bool()
} else {
self.lax_bool()
}
}
fn strict_bool(&self) -> ValResult<bool>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_bool(&self) -> ValResult<bool> {
self.strict_bool()
}
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>>;

fn validate_int(&'a self, strict: bool) -> ValResult<EitherInt<'a>> {
if strict {
self.strict_int()
} else {
self.lax_int()
}
}
fn strict_int(&'a self) -> ValResult<EitherInt<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
}
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>>;

/// Extract an EitherInt from the input, only allowing exact
/// matches for an Int (no subclasses)
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
self.validate_int(true).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::IntType, self))
})
}

/// Extract a String from the input, only allowing exact
/// matches for a String (no subclasses)
fn exact_str(&'a self) -> ValResult<EitherString<'a>> {
self.strict_str()
self.validate_str(true, false).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::StringType, self))
})
}

fn validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult<EitherFloat<'a>> {
if ultra_strict {
self.ultra_strict_float()
} else if strict {
self.strict_float()
} else {
self.lax_float()
}
}
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>>;

fn validate_decimal(&'a self, strict: bool, py: Python<'a>) -> ValResult<&'a PyAny> {
if strict {
Expand Down Expand Up @@ -257,87 +209,25 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {

fn validate_iter(&self) -> ValResult<GenericIterator>;

fn validate_date(&self, strict: bool) -> ValResult<EitherDate> {
if strict {
self.strict_date()
} else {
self.lax_date()
}
}
fn strict_date(&self) -> ValResult<EitherDate>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_date(&self) -> ValResult<EitherDate> {
self.strict_date()
}
fn validate_date(&self, strict: bool) -> ValResult<ValidationMatch<EitherDate>>;

fn validate_time(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime> {
if strict {
self.strict_time(microseconds_overflow_behavior)
} else {
self.lax_time(microseconds_overflow_behavior)
}
}
fn strict_time(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_time(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime> {
self.strict_time(microseconds_overflow_behavior)
}
) -> ValResult<ValidationMatch<EitherTime>>;

fn validate_datetime(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime> {
if strict {
self.strict_datetime(microseconds_overflow_behavior)
} else {
self.lax_datetime(microseconds_overflow_behavior)
}
}
fn strict_datetime(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_datetime(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime> {
self.strict_datetime(microseconds_overflow_behavior)
}
) -> ValResult<ValidationMatch<EitherDateTime>>;

fn validate_timedelta(
&self,
strict: bool,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta> {
if strict {
self.strict_timedelta(microseconds_overflow_behavior)
} else {
self.lax_timedelta(microseconds_overflow_behavior)
}
}
fn strict_timedelta(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta>;
#[cfg_attr(has_coverage_attribute, coverage(off))]
fn lax_timedelta(
&self,
microseconds_overflow_behavior: speedate::MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta> {
self.strict_timedelta(microseconds_overflow_behavior)
}
) -> ValResult<ValidationMatch<EitherTimedelta>>;
}

/// The problem to solve here is that iterating a `StringMapping` returns an owned
Expand Down
328 changes: 148 additions & 180 deletions src/input/input_json.rs

Large diffs are not rendered by default.

538 changes: 282 additions & 256 deletions src/input/input_python.rs

Large diffs are not rendered by default.

79 changes: 39 additions & 40 deletions src/input/input_string.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};

use jiter::JsonValue;
use speedate::MicrosecondsPrecisionOverflowBehavior;

use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::errors::{AsLocItem, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::input::py_string_str;
use crate::tools::safe_repr;
use crate::validators::decimal::create_decimal;
Expand All @@ -14,7 +15,7 @@ use super::datetime::{
use super::shared::{map_json_err, str_as_bool, str_as_float};
use super::{
BorrowInput, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
GenericIterator, GenericMapping, Input, JsonInput,
GenericIterator, GenericMapping, Input, ValidationMatch,
};

#[derive(Debug)]
Expand Down Expand Up @@ -52,14 +53,16 @@ impl<'py> StringMapping<'py> {
}
}

impl<'a> Input<'a> for StringMapping<'a> {
impl AsLocItem for StringMapping<'_> {
fn as_loc_item(&self) -> LocItem {
match self {
Self::String(s) => s.to_string_lossy().as_ref().into(),
Self::Mapping(d) => safe_repr(d).to_string().into(),
}
}
}

impl<'a> Input<'a> for StringMapping<'a> {
fn as_error_value(&'a self) -> InputValue<'a> {
match self {
Self::String(s) => s.as_error_value(),
Expand All @@ -83,64 +86,54 @@ impl<'a> Input<'a> for StringMapping<'a> {
}
}

fn parse_json(&'a self) -> ValResult<'a, JsonInput> {
fn parse_json(&'a self) -> ValResult<'a, JsonValue> {
match self {
Self::String(s) => {
let str = py_string_str(s)?;
serde_json::from_str(str).map_err(|e| map_json_err(self, e))
JsonValue::parse(str.as_bytes(), true).map_err(|e| map_json_err(self, e))
}
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::JsonType, self)),
}
}

fn strict_str(&'a self) -> ValResult<EitherString<'a>> {
fn validate_str(
&'a self,
_strict: bool,
_coerce_numbers_to_str: bool,
) -> ValResult<ValidationMatch<EitherString<'a>>> {
match self {
Self::String(s) => Ok((*s).into()),
Self::String(s) => Ok(ValidationMatch::strict((*s).into())),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::StringType, self)),
}
}

fn strict_bytes(&'a self) -> ValResult<EitherBytes<'a>> {
fn validate_bytes(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a>>> {
match self {
Self::String(s) => py_string_str(s).map(|b| b.as_bytes().into()),
Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
}
}

fn lax_bytes(&'a self) -> ValResult<EitherBytes<'a>> {
fn validate_bool(&self, _strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
match self {
Self::String(s) => {
let str = py_string_str(s)?;
Ok(str.as_bytes().into())
}
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
}
}

fn strict_bool(&self) -> ValResult<bool> {
match self {
Self::String(s) => str_as_bool(self, py_string_str(s)?),
Self::String(s) => str_as_bool(self, py_string_str(s)?).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
}
}

fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
fn validate_int(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
match self {
Self::String(s) => match py_string_str(s)?.parse() {
Ok(i) => Ok(EitherInt::I64(i)),
Ok(i) => Ok(ValidationMatch::strict(EitherInt::I64(i))),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)),
},
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::IntType, self)),
}
}

fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}

fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
fn validate_float(&'a self, _strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
match self {
Self::String(s) => str_as_float(self, py_string_str(s)?),
Self::String(s) => str_as_float(self, py_string_str(s)?).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
Expand Down Expand Up @@ -183,39 +176,45 @@ impl<'a> Input<'a> for StringMapping<'a> {
Err(ValError::new(ErrorTypeDefaults::IterableType, self))
}

fn strict_date(&self) -> ValResult<EitherDate> {
fn validate_date(&self, _strict: bool) -> ValResult<ValidationMatch<EitherDate>> {
match self {
Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()),
Self::String(s) => bytes_as_date(self, py_string_str(s)?.as_bytes()).map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DateType, self)),
}
}

fn strict_time(
fn validate_time(
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTime> {
) -> ValResult<ValidationMatch<EitherTime>> {
match self {
Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior),
Self::String(s) => bytes_as_time(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior)
.map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeType, self)),
}
}

fn strict_datetime(
fn validate_datetime(
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherDateTime> {
) -> ValResult<ValidationMatch<EitherDateTime>> {
match self {
Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior),
Self::String(s) => bytes_as_datetime(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior)
.map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DatetimeType, self)),
}
}

fn strict_timedelta(
fn validate_timedelta(
&self,
_strict: bool,
microseconds_overflow_behavior: MicrosecondsPrecisionOverflowBehavior,
) -> ValResult<EitherTimedelta> {
) -> ValResult<ValidationMatch<EitherTimedelta>> {
match self {
Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior),
Self::String(s) => bytes_as_timedelta(self, py_string_str(s)?.as_bytes(), microseconds_overflow_behavior)
.map(ValidationMatch::strict),
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::TimeDeltaType, self)),
}
}
Expand Down
4 changes: 1 addition & 3 deletions src/input/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ mod input_abstract;
mod input_json;
mod input_python;
mod input_string;
mod parse_json;
mod return_enums;
mod shared;

Expand All @@ -18,11 +17,10 @@ pub(crate) use datetime::{
};
pub(crate) use input_abstract::{BorrowInput, Input, InputType};
pub(crate) use input_string::StringMapping;
pub(crate) use parse_json::{JsonArray, JsonInput, JsonObject};
pub(crate) use return_enums::{
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString,
GenericArguments, GenericIterable, GenericIterator, GenericMapping, Int, JsonArgs, JsonObjectGenericIterator,
MappingGenericIterator, PyArgs, StringMappingGenericIterator,
MappingGenericIterator, PyArgs, StringMappingGenericIterator, ValidationMatch,
};

// Defined here as it's not exported by pyo3
Expand Down
222 changes: 0 additions & 222 deletions src/input/parse_json.rs

This file was deleted.

53 changes: 43 additions & 10 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::ops::Rem;
use std::slice::Iter as SliceIter;
use std::str::FromStr;

use jiter::{JsonArray, JsonObject, JsonValue};
use num_bigint::BigInt;

use pyo3::exceptions::PyTypeError;
Expand All @@ -13,7 +14,7 @@ use pyo3::types::{
PyByteArray, PyBytes, PyDict, PyFloat, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString,
PyTuple,
};
use pyo3::{ffi, intern, AsPyPointer, PyNativeType};
use pyo3::{ffi, intern, PyNativeType};

#[cfg(not(PyPy))]
use pyo3::types::PyFunction;
Expand All @@ -23,12 +24,44 @@ use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::tools::py_err;
use crate::validators::{CombinedValidator, ValidationState, Validator};
use crate::validators::{CombinedValidator, Exactness, ValidationState, Validator};

use super::input_string::StringMapping;
use super::parse_json::{JsonArray, JsonInput, JsonObject};
use super::{py_error_on_minusone, Input};

pub struct ValidationMatch<T>(T, Exactness);

impl<T> ValidationMatch<T> {
pub const fn new(value: T, exactness: Exactness) -> Self {
Self(value, exactness)
}

pub const fn exact(value: T) -> Self {
Self(value, Exactness::Exact)
}

pub const fn strict(value: T) -> Self {
Self(value, Exactness::Strict)
}

pub const fn lax(value: T) -> Self {
Self(value, Exactness::Lax)
}

pub fn require_exact(self) -> Option<T> {
(self.1 == Exactness::Exact).then_some(self.0)
}

pub fn unpack(self, state: &mut ValidationState) -> T {
state.floor_exactness(self.1);
self.0
}

pub fn into_inner(self) -> T {
self.0
}
}

/// Container for all the collections (sized iterable containers) types, which
/// can mostly be converted to each other in lax mode.
/// This mostly matches python's definition of `Collection`.
Expand All @@ -50,7 +83,7 @@ pub enum GenericIterable<'a> {
PyByteArray(&'a PyByteArray),
Sequence(&'a PySequence),
Iterator(&'a PyIterator),
JsonArray(&'a [JsonInput]),
JsonArray(&'a [JsonValue]),
JsonObject(&'a JsonObject),
JsonString(&'a String),
}
Expand Down Expand Up @@ -573,7 +606,7 @@ impl<'py> Iterator for AttributesGenericIterator<'py> {
}

pub struct JsonObjectGenericIterator<'py> {
object_iter: SliceIter<'py, (String, JsonInput)>,
object_iter: SliceIter<'py, (String, JsonValue)>,
}

impl<'py> JsonObjectGenericIterator<'py> {
Expand All @@ -585,7 +618,7 @@ impl<'py> JsonObjectGenericIterator<'py> {
}

impl<'py> Iterator for JsonObjectGenericIterator<'py> {
type Item = ValResult<'py, (&'py String, &'py JsonInput)>;
type Item = ValResult<'py, (&'py String, &'py JsonValue)>;

fn next(&mut self) -> Option<Self::Item> {
self.object_iter.next().map(|(key, value)| Ok((key, value)))
Expand Down Expand Up @@ -653,7 +686,7 @@ pub struct GenericJsonIterator {
}

impl GenericJsonIterator {
pub fn next(&mut self, _py: Python) -> PyResult<Option<(&JsonInput, usize)>> {
pub fn next(&mut self, _py: Python) -> PyResult<Option<(&JsonValue, usize)>> {
if self.index < self.array.len() {
// panic here is impossible due to bounds check above; compiler should be
// able to optimize it away even
Expand All @@ -667,7 +700,7 @@ impl GenericJsonIterator {
}

pub fn input_as_error_value<'py>(&self, _py: Python<'py>) -> InputValue<'py> {
InputValue::JsonInput(JsonInput::Array(self.array.clone()))
InputValue::JsonInput(JsonValue::Array(self.array.clone()))
}

pub fn index(&self) -> usize {
Expand All @@ -689,12 +722,12 @@ impl<'a> PyArgs<'a> {

#[cfg_attr(debug_assertions, derive(Debug))]
pub struct JsonArgs<'a> {
pub args: Option<&'a [JsonInput]>,
pub args: Option<&'a [JsonValue]>,
pub kwargs: Option<&'a JsonObject>,
}

impl<'a> JsonArgs<'a> {
pub fn new(args: Option<&'a [JsonInput]>, kwargs: Option<&'a JsonObject>) -> Self {
pub fn new(args: Option<&'a [JsonValue]>, kwargs: Option<&'a JsonObject>) -> Self {
Self { args, kwargs }
}
}
Expand Down
24 changes: 17 additions & 7 deletions src/input/shared.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
use pyo3::sync::GILOnceCell;
use pyo3::{intern, Py, PyAny, Python, ToPyObject};

use jiter::JsonValueError;
use num_bigint::BigInt;
use pyo3::{intern, PyAny, Python};

use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValResult};

use super::parse_json::{JsonArray, JsonInput};
use super::{EitherFloat, EitherInt, Input};
static ENUM_META_OBJECT: GILOnceCell<Py<PyAny>> = GILOnceCell::new();

pub fn get_enum_meta_object(py: Python) -> Py<PyAny> {
ENUM_META_OBJECT
.get_or_init(py, || {
py.import(intern!(py, "enum"))
.and_then(|enum_module| enum_module.getattr(intern!(py, "EnumMeta")))
.unwrap()
.to_object(py)
})
.clone()
}

pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: serde_json::Error) -> ValError<'a> {
pub fn map_json_err<'a>(input: &'a impl Input<'a>, error: JsonValueError) -> ValError<'a> {
ValError::new(
ErrorType::JsonInvalid {
error: error.to_string(),
Expand Down Expand Up @@ -150,7 +164,3 @@ pub fn decimal_as_int<'a>(py: Python, input: &'a impl Input<'a>, decimal: &'a Py
}
Ok(EitherInt::Py(numerator))
}

pub fn string_to_vec(s: &str) -> JsonArray {
JsonArray::new(s.chars().map(|c| JsonInput::String(c.to_string())).collect())
}
63 changes: 0 additions & 63 deletions src/lazy_index_map.rs

This file was deleted.

17 changes: 16 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ extern crate core;

use std::sync::OnceLock;

use pyo3::exceptions::PyTypeError;
use pyo3::types::{PyByteArray, PyBytes, PyString};
use pyo3::{prelude::*, sync::GILOnceCell};

// parse this first to get access to the contained macro
Expand All @@ -15,7 +17,6 @@ mod build_tools;
mod definitions;
mod errors;
mod input;
mod lazy_index_map;
mod lookup_key;
mod recursion_guard;
mod serializers;
Expand All @@ -36,6 +37,19 @@ pub use serializers::{
};
pub use validators::{validate_core_schema, PySome, SchemaValidator};

#[pyfunction(signature = (data, *, allow_inf_nan=true))]
pub fn from_json(py: Python, data: &PyAny, allow_inf_nan: bool) -> PyResult<PyObject> {
if let Ok(py_bytes) = data.downcast::<PyBytes>() {
jiter::python_parse(py, py_bytes.as_bytes(), allow_inf_nan)
} else if let Ok(py_str) = data.downcast::<PyString>() {
jiter::python_parse(py, py_str.to_str()?.as_bytes(), allow_inf_nan)
} else if let Ok(py_byte_array) = data.downcast::<PyByteArray>() {
jiter::python_parse(py, &py_byte_array.to_vec(), allow_inf_nan)
} else {
Err(PyTypeError::new_err("Expected bytes, bytearray or str"))
}
}

pub fn get_pydantic_core_version() -> &'static str {
static PYDANTIC_CORE_VERSION: OnceLock<String> = OnceLock::new();

Expand Down Expand Up @@ -95,6 +109,7 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<SchemaSerializer>()?;
m.add_class::<TzInfo>()?;
m.add_function(wrap_pyfunction!(to_json, m)?)?;
m.add_function(wrap_pyfunction!(from_json, m)?)?;
m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?;
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?;
Expand Down
24 changes: 13 additions & 11 deletions src/lookup_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ use pyo3::exceptions::{PyAttributeError, PyTypeError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyMapping, PyString};

use jiter::{JsonObject, JsonValue};

use crate::build_tools::py_schema_err;
use crate::errors::{py_err_string, ErrorType, ValError, ValLineError, ValResult};
use crate::input::{Input, JsonInput, JsonObject, StringMapping};
use crate::input::{Input, StringMapping};
use crate::tools::{extract_i64, py_err};

/// Used for getting items from python dicts, python objects, or JSON objects, in different ways
Expand Down Expand Up @@ -111,7 +113,7 @@ impl LookupKey {
dict: &'data PyDict,
) -> ValResult<'data, Option<(&'s LookupPath, &'data PyAny)>> {
match self {
Self::Simple { py_key, path, .. } => match dict.get_item(py_key) {
Self::Simple { py_key, path, .. } => match dict.get_item(py_key)? {
Some(value) => Ok(Some((path, value))),
None => Ok(None),
},
Expand All @@ -121,9 +123,9 @@ impl LookupKey {
py_key2,
path2,
..
} => match dict.get_item(py_key1) {
} => match dict.get_item(py_key1)? {
Some(value) => Ok(Some((path1, value))),
None => match dict.get_item(py_key2) {
None => match dict.get_item(py_key2)? {
Some(value) => Ok(Some((path2, value))),
None => Ok(None),
},
Expand Down Expand Up @@ -264,7 +266,7 @@ impl LookupKey {
pub fn json_get<'data, 's>(
&'s self,
dict: &'data JsonObject,
) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonInput)>> {
) -> ValResult<'data, Option<(&'s LookupPath, &'data JsonValue)>> {
match self {
Self::Simple { key, path, .. } => match dict.get(key) {
Some(value) => Ok(Some((path, value))),
Expand All @@ -289,13 +291,13 @@ impl LookupKey {

// first step is different from the rest as we already know dict is JsonObject
// because of above checks, we know that path should have at least one element, hence unwrap
let v: &JsonInput = match path_iter.next().unwrap().json_obj_get(dict) {
let v: &JsonValue = match path_iter.next().unwrap().json_obj_get(dict) {
Some(v) => v,
None => continue,
};

// similar to above
// iterate over the path and plug each value into the JsonInput from the last step, starting with v
// iterate over the path and plug each value into the JsonValue from the last step, starting with v
// from the first step, this could just be a loop but should be somewhat faster with a functional design
if let Some(v) = path_iter.try_fold(v, |d, loc| loc.json_get(d)) {
// Successfully found an item, return it
Expand Down Expand Up @@ -481,10 +483,10 @@ impl PathItem {
}
}

pub fn json_get<'a>(&self, any_json: &'a JsonInput) -> Option<&'a JsonInput> {
pub fn json_get<'a>(&self, any_json: &'a JsonValue) -> Option<&'a JsonValue> {
match any_json {
JsonInput::Object(v_obj) => self.json_obj_get(v_obj),
JsonInput::Array(v_array) => match self {
JsonValue::Object(v_obj) => self.json_obj_get(v_obj),
JsonValue::Array(v_array) => match self {
Self::Pos(index) => v_array.get(*index),
Self::Neg(index) => {
if let Some(index) = v_array.len().checked_sub(*index) {
Expand All @@ -499,7 +501,7 @@ impl PathItem {
}
}

pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonInput> {
pub fn json_obj_get<'a>(&self, json_obj: &'a JsonObject) -> Option<&'a JsonValue> {
match self {
Self::S(key, _) => json_obj.get(key),
_ => None,
Expand Down
8 changes: 8 additions & 0 deletions src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit};
Expand Down Expand Up @@ -35,6 +37,12 @@ impl<T: PyGcTraverse> PyGcTraverse for AHashMap<String, T> {
}
}

impl<T: PyGcTraverse> PyGcTraverse for Arc<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
}
}

impl<T: PyGcTraverse> PyGcTraverse for Box<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
Expand Down
42 changes: 41 additions & 1 deletion src/serializers/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub(crate) enum BytesMode {
#[default]
Utf8,
Base64,
Hex,
}

impl FromStr for BytesMode {
Expand All @@ -138,7 +139,11 @@ impl FromStr for BytesMode {
match s {
"utf8" => Ok(Self::Utf8),
"base64" => Ok(Self::Base64),
s => py_schema_err!("Invalid bytes serialization mode: `{}`, expected `utf8` or `base64`", s),
"hex" => Ok(Self::Hex),
s => py_schema_err!(
"Invalid bytes serialization mode: `{}`, expected `utf8`, `base64` or `hex`",
s
),
}
}
}
Expand All @@ -158,6 +163,9 @@ impl BytesMode {
.map_err(|err| utf8_py_error(py, err, bytes))
.map(Cow::Borrowed),
Self::Base64 => Ok(Cow::Owned(base64::engine::general_purpose::URL_SAFE.encode(bytes))),
Self::Hex => Ok(Cow::Owned(
bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")),
)),
}
}

Expand All @@ -168,6 +176,9 @@ impl BytesMode {
Err(e) => Err(Error::custom(e.to_string())),
},
Self::Base64 => serializer.serialize_str(&base64::engine::general_purpose::URL_SAFE.encode(bytes)),
Self::Hex => {
serializer.serialize_str(&bytes.iter().fold(String::new(), |acc, b| acc + &format!("{b:02x}")))
}
}
}
}
Expand All @@ -178,3 +189,32 @@ pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {
Err(err) => err,
}
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) enum InfNanMode {
#[default]
Null,
Constants,
}

impl FromStr for InfNanMode {
type Err = PyErr;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"null" => Ok(Self::Null),
"constants" => Ok(Self::Constants),
s => py_schema_err!(
"Invalid inf_nan serialization mode: `{}`, expected `null` or `constants`",
s
),
}
}
}

impl FromPyObject<'_> for InfNanMode {
fn extract(ob: &'_ PyAny) -> PyResult<Self> {
let s = ob.extract::<&str>()?;
Self::from_str(s)
}
}
27 changes: 26 additions & 1 deletion src/serializers/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,33 @@ pub(super) fn py_err_se_err<T: ser::Error, E: fmt::Display>(py_error: E) -> T {
T::custom(py_error.to_string())
}

#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
#[derive(Debug, Clone)]
pub struct PythonSerializerError {
pub message: String,
}

impl fmt::Display for PythonSerializerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.message)
}
}

impl std::error::Error for PythonSerializerError {}

impl serde::ser::Error for PythonSerializerError {
fn custom<T>(msg: T) -> Self
where
T: fmt::Display,
{
PythonSerializerError {
message: format!("{msg}"),
}
}
}

/// convert a serde serialization error into a `PyErr`
pub(super) fn se_err_py_err(error: serde_json::Error) -> PyErr {
pub(super) fn se_err_py_err(error: PythonSerializerError) -> PyErr {
let s = error.to_string();
if let Some(msg) = s.strip_prefix(UNEXPECTED_TYPE_SER_MARKER) {
if msg.is_empty() {
Expand Down
11 changes: 1 addition & 10 deletions src/serializers/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ use std::cell::RefCell;
use std::fmt;

use pyo3::exceptions::PyValueError;
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::{intern, AsPyPointer};

use serde::ser::Error;

use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::recursion_guard::RecursionGuard;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -48,7 +46,6 @@ impl SerializationState {
Extra::new(
py,
mode,
&[],
by_alias,
&self.warnings,
false,
Expand All @@ -72,7 +69,6 @@ impl SerializationState {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct Extra<'a> {
pub mode: &'a SerMode,
pub definitions: &'a Definitions<CombinedSerializer>,
pub ob_type_lookup: &'a ObTypeLookup,
pub warnings: &'a CollectWarnings,
pub by_alias: bool,
Expand All @@ -98,7 +94,6 @@ impl<'a> Extra<'a> {
pub fn new(
py: Python<'a>,
mode: &'a SerMode,
definitions: &'a Definitions<CombinedSerializer>,
by_alias: bool,
warnings: &'a CollectWarnings,
exclude_unset: bool,
Expand All @@ -112,7 +107,6 @@ impl<'a> Extra<'a> {
) -> Self {
Self {
mode,
definitions,
ob_type_lookup: ObTypeLookup::cached(py),
warnings,
by_alias,
Expand Down Expand Up @@ -156,7 +150,6 @@ impl SerCheck {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct ExtraOwned {
mode: SerMode,
definitions: Vec<CombinedSerializer>,
warnings: CollectWarnings,
by_alias: bool,
exclude_unset: bool,
Expand All @@ -176,7 +169,6 @@ impl ExtraOwned {
pub fn new(extra: &Extra) -> Self {
Self {
mode: extra.mode.clone(),
definitions: extra.definitions.to_vec(),
warnings: extra.warnings.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
Expand All @@ -196,7 +188,6 @@ impl ExtraOwned {
pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> {
Extra {
mode: &self.mode,
definitions: &self.definitions,
ob_type_lookup: ObTypeLookup::cached(py),
warnings: &self.warnings,
by_alias: self.by_alias,
Expand Down
10 changes: 5 additions & 5 deletions src/serializers/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ impl SchemaFilter<usize> {
let py = schema.py();
match schema.get_as::<&PyDict>(intern!(py, "serialization"))? {
Some(ser) => {
let include = Self::build_set_ints(ser.get_item(intern!(py, "include")))?;
let exclude = Self::build_set_ints(ser.get_item(intern!(py, "exclude")))?;
let include = Self::build_set_ints(ser.get_item(intern!(py, "include"))?)?;
let exclude = Self::build_set_ints(ser.get_item(intern!(py, "exclude"))?)?;
Ok(Self { include, exclude })
}
None => Ok(SchemaFilter::default()),
Expand Down Expand Up @@ -325,8 +325,8 @@ fn is_ellipsis_like(v: &PyAny) -> bool {

/// lookup the dict, for the key and "__all__" key, and merge them following the same rules as pydantic V1
fn merge_all_value(dict: &PyDict, py_key: impl ToPyObject + Copy) -> PyResult<Option<&PyAny>> {
let op_item_value = dict.get_item(py_key);
let op_all_value = dict.get_item(intern!(dict.py(), "__all__"));
let op_item_value = dict.get_item(py_key)?;
let op_all_value = dict.get_item(intern!(dict.py(), "__all__"))?;

match (op_item_value, op_all_value) {
(Some(item_value), Some(all_value)) => {
Expand Down Expand Up @@ -365,7 +365,7 @@ fn merge_dicts<'py>(item_dict: &'py PyDict, all_value: &'py PyAny) -> PyResult<&
let item_dict = item_dict.copy()?;
if let Ok(all_dict) = all_value.downcast::<PyDict>() {
for (all_key, all_value) in all_dict {
if let Some(item_value) = item_dict.get_item(all_key) {
if let Some(item_value) = item_dict.get_item(all_key)? {
if is_ellipsis_like(item_value) {
continue;
}
Expand Down
36 changes: 27 additions & 9 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict};
use pyo3::{PyTraverseError, PyVisit};

use crate::definitions::DefinitionsBuilder;
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::py_gc::PyGcTraverse;

use config::SerializationConfig;
Expand All @@ -23,16 +23,21 @@ mod fields;
mod filter;
mod infer;
mod ob_type;
pub mod ser;
mod shared;
mod type_serializers;

#[pyclass(module = "pydantic_core._pydantic_core")]
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
definitions: Vec<CombinedSerializer>,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for pickle support (see `__reduce__`).
py_schema: Py<PyDict>,
py_config: Option<Py<PyDict>>,
}

impl SchemaSerializer {
Expand All @@ -54,7 +59,6 @@ impl SchemaSerializer {
Extra::new(
py,
mode,
&self.definitions,
by_alias,
warnings,
exclude_unset,
Expand All @@ -72,15 +76,19 @@ impl SchemaSerializer {
#[pymethods]
impl SchemaSerializer {
#[new]
pub fn py_new(schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
pub fn py_new(py: Python, schema: &PyDict, config: Option<&PyDict>) -> PyResult<Self> {
let mut definitions_builder = DefinitionsBuilder::new();

let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
py_schema: schema.into_py(py),
py_config: match config {
Some(c) if !c.is_empty() => Some(c.into_py(py)),
_ => None,
},
})
}

Expand Down Expand Up @@ -175,6 +183,14 @@ impl SchemaSerializer {
Ok(py_bytes.into())
}

pub fn __reduce__(slf: &PyCell<Self>) -> PyResult<(PyObject, (PyObject, PyObject))> {
    // Support for `pickle`: return `(callable, args)` so unpickling rebuilds the
    // object via `SchemaSerializer(schema, config)` with the originally-saved
    // Python schema and config objects.
    let py = slf.py();
    let this = slf.get();
    let init_args = (this.py_schema.to_object(py), this.py_config.to_object(py));
    Ok((slf.get_type().into(), init_args))
}

pub fn __repr__(&self) -> String {
format!(
"SchemaSerializer(serializer={:#?}, definitions={:#?})",
Expand All @@ -183,10 +199,12 @@ impl SchemaSerializer {
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
self.serializer.py_gc_traverse(&visit)?;
for slot in &self.definitions {
slot.py_gc_traverse(&visit)?;
visit.call(&self.py_schema)?;
if let Some(ref py_config) = self.py_config {
visit.call(py_config)?;
}
self.serializer.py_gc_traverse(&visit)?;
self.definitions.py_gc_traverse(&visit)?;
Ok(())
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/serializers/ob_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,9 @@ impl ObTypeLookup {
fn is_enum(&self, op_value: Option<&PyAny>, py_type: &PyType) -> bool {
// only test on the type itself, not base types
if op_value.is_some() {
let enum_meta_type = self.enum_object.as_ref(py_type.py()).get_type();
let meta_type = py_type.get_type();
meta_type.is(&self.enum_object)
meta_type.is(enum_meta_type)
} else {
false
}
Expand Down Expand Up @@ -332,6 +333,7 @@ fn is_dataclass(op_value: Option<&PyAny>) -> bool {
value
.hasattr(intern!(value.py(), "__dataclass_fields__"))
.unwrap_or(false)
&& !value.is_instance_of::<PyType>()
} else {
false
}
Expand All @@ -342,6 +344,7 @@ fn is_pydantic_serializable(op_value: Option<&PyAny>) -> bool {
value
.hasattr(intern!(value.py(), "__pydantic_serializer__"))
.unwrap_or(false)
&& !value.is_instance_of::<PyType>()
} else {
false
}
Expand Down
1,299 changes: 1,299 additions & 0 deletions src/serializers/ser.rs

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ use serde_json::ser::PrettyFormatter;

use crate::build_tools::py_schema_err;
use crate::build_tools::py_schema_error_type;
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::definitions::DefinitionsBuilder;
use crate::py_gc::PyGcTraverse;
use crate::serializers::ser::PythonSerializer;
use crate::tools::{py_err, SchemaDict};

use super::errors::se_err_py_err;
Expand Down Expand Up @@ -112,7 +113,7 @@ combined_serializer! {
Nullable: super::type_serializers::nullable::NullableSerializer;
Int: super::type_serializers::simple::IntSerializer;
Bool: super::type_serializers::simple::BoolSerializer;
Float: super::type_serializers::simple::FloatSerializer;
Float: super::type_serializers::float::FloatSerializer;
Decimal: super::type_serializers::decimal::DecimalSerializer;
Str: super::type_serializers::string::StrSerializer;
Bytes: super::type_serializers::bytes::BytesSerializer;
Expand Down Expand Up @@ -293,7 +294,7 @@ pub(crate) trait TypeSerializer: Send + Sync + Clone + Debug {
fn get_name(&self) -> &str;

/// Used by union serializers to decide if it's worth trying again while allowing subclasses
fn retry_with_lax_check(&self, _definitions: &Definitions<CombinedSerializer>) -> bool {
fn retry_with_lax_check(&self) -> bool {
false
}

Expand Down Expand Up @@ -352,12 +353,12 @@ pub(crate) fn to_json_bytes(
Some(indent) => {
let indent = vec![b' '; indent];
let formatter = PrettyFormatter::with_indent(&indent);
let mut ser = serde_json::Serializer::with_formatter(writer, formatter);
let mut ser = PythonSerializer::with_formatter(writer, formatter);
serializer.serialize(&mut ser).map_err(se_err_py_err)?;
ser.into_inner()
}
None => {
let mut ser = serde_json::Serializer::new(writer);
let mut ser = PythonSerializer::new(writer);
serializer.serialize(&mut ser).map_err(se_err_py_err)?;
ser.into_inner()
}
Expand Down
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::borrow::Cow;
use ahash::AHashMap;

use crate::build_tools::{py_schema_error_type, ExtraBehavior};
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;

use super::{
Expand Down Expand Up @@ -179,7 +179,7 @@ impl TypeSerializer for DataclassSerializer {
&self.name
}

fn retry_with_lax_check(&self, _definitions: &Definitions<CombinedSerializer>) -> bool {
fn retry_with_lax_check(&self) -> bool {
true
}
}
32 changes: 17 additions & 15 deletions src/serializers/type_serializers/definitions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};

use crate::definitions::Definitions;
use crate::definitions::DefinitionRef;
use crate::definitions::DefinitionsBuilder;

use crate::tools::SchemaDict;
Expand Down Expand Up @@ -41,7 +41,7 @@ impl BuildSerializer for DefinitionsSerializerBuilder {

#[derive(Debug, Clone)]
pub struct DefinitionRefSerializer {
serializer_id: usize,
definition: DefinitionRef<CombinedSerializer>,
}

impl BuildSerializer for DefinitionRefSerializer {
Expand All @@ -52,9 +52,9 @@ impl BuildSerializer for DefinitionRefSerializer {
_config: Option<&PyDict>,
definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let schema_ref: String = schema.get_as_req(intern!(schema.py(), "schema_ref"))?;
let serializer_id = definitions.get_reference_id(&schema_ref);
Ok(Self { serializer_id }.into())
let schema_ref = schema.get_as_req(intern!(schema.py(), "schema_ref"))?;
let definition = definitions.get_definition(schema_ref);
Ok(Self { definition }.into())
}
}

Expand All @@ -68,15 +68,15 @@ impl TypeSerializer for DefinitionRefSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let value_id = extra.rec_guard.add(value, self.serializer_id)?;
let comb_serializer = extra.definitions.get(self.serializer_id).unwrap();
let comb_serializer = self.definition.get().unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.serializer_id);
extra.rec_guard.pop(value_id, self.definition.id());
r
}

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
self._invalid_as_json_key(key, extra, Self::EXPECTED_TYPE)
self.definition.get().unwrap().json_key(key, extra)
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -87,19 +87,21 @@ impl TypeSerializer for DefinitionRefSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let value_id = extra.rec_guard.add(value, self.serializer_id).map_err(py_err_se_err)?;
let comb_serializer = extra.definitions.get(self.serializer_id).unwrap();
let comb_serializer = self.definition.get().unwrap();
let value_id = extra
.rec_guard
.add(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.serializer_id);
extra.rec_guard.pop(value_id, self.definition.id());
r
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self, definitions: &Definitions<CombinedSerializer>) -> bool {
let comb_serializer = definitions.get(self.serializer_id).unwrap();
comb_serializer.retry_with_lax_check(definitions)
fn retry_with_lax_check(&self) -> bool {
self.definition.get().unwrap().retry_with_lax_check()
}
}
4 changes: 2 additions & 2 deletions src/serializers/type_serializers/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ impl BuildSerializer for DictSerializer {
};
let filter = match schema.get_as::<&PyDict>(intern!(py, "serialization"))? {
Some(ser) => {
let include = ser.get_item(intern!(py, "include"));
let exclude = ser.get_item(intern!(py, "exclude"));
let include = ser.get_item(intern!(py, "include"))?;
let exclude = ser.get_item(intern!(py, "exclude"))?;
SchemaFilter::from_set_hash(include, exclude)?
}
None => SchemaFilter::default(),
Expand Down
102 changes: 102 additions & 0 deletions src/serializers/type_serializers/float.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use pyo3::types::PyDict;
use pyo3::{intern, prelude::*};

use std::borrow::Cow;

use serde::Serializer;

use crate::definitions::DefinitionsBuilder;
use crate::serializers::config::InfNanMode;
use crate::tools::SchemaDict;

use super::simple::to_str_json_key;
use super::{
infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType,
SerMode, TypeSerializer,
};

/// Serializer for `float` schema values.
///
/// `inf_nan_mode` is read from the `ser_json_inf_nan` core-config setting and
/// controls how non-finite values (`inf`/`-inf`/`nan`) are emitted during JSON
/// serialization (see `serde_serialize` below).
#[derive(Debug, Clone)]
pub struct FloatSerializer {
    inf_nan_mode: InfNanMode,
}

impl BuildSerializer for FloatSerializer {
const EXPECTED_TYPE: &'static str = "float";

fn build(
schema: &PyDict,
config: Option<&PyDict>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let inf_nan_mode = config
.and_then(|c| c.get_as(intern!(schema.py(), "ser_json_inf_nan")).transpose())
.transpose()?
.unwrap_or_default();
Ok(Self { inf_nan_mode }.into())
}
}

impl_py_gc_traverse!(FloatSerializer {});

impl TypeSerializer for FloatSerializer {
    /// Serialize `value` to a Python object.
    ///
    /// Exact `float` instances are returned as-is. For `float` subclasses in
    /// JSON mode the value is extracted to a Rust `f64` and re-wrapped, which
    /// downcasts it to a plain `float`; in other modes subclasses fall through
    /// to inference. Non-float values emit a fallback warning and are
    /// serialized by type inference.
    fn to_python(
        &self,
        value: &PyAny,
        include: Option<&PyAny>,
        exclude: Option<&PyAny>,
        extra: &Extra,
    ) -> PyResult<PyObject> {
        let py = value.py();
        match extra.ob_type_lookup.is_type(value, ObType::Float) {
            IsType::Exact => Ok(value.into_py(py)),
            IsType::Subclass => match extra.mode {
                SerMode::Json => {
                    let rust_value = value.extract::<f64>()?;
                    Ok(rust_value.to_object(py))
                }
                _ => infer_to_python(value, include, exclude, extra),
            },
            IsType::False => {
                extra.warnings.on_fallback_py(self.get_name(), value, extra)?;
                infer_to_python(value, include, exclude, extra)
            }
        }
    }

    /// Convert a float dict key to its JSON string form; non-float keys warn
    /// and fall back to inferred key serialization.
    fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
        match extra.ob_type_lookup.is_type(key, ObType::Float) {
            IsType::Exact | IsType::Subclass => to_str_json_key(key),
            IsType::False => {
                extra.warnings.on_fallback_py(self.get_name(), key, extra)?;
                infer_json_key(key, extra)
            }
        }
    }

    /// Serialize via serde. `nan`/`inf` values are written as `null` when
    /// `inf_nan_mode == InfNanMode::Null`; otherwise they are passed to
    /// `serialize_f64`. Values that don't extract to `f64` warn and fall back
    /// to inference.
    fn serde_serialize<S: Serializer>(
        &self,
        value: &PyAny,
        serializer: S,
        include: Option<&PyAny>,
        exclude: Option<&PyAny>,
        extra: &Extra,
    ) -> Result<S::Ok, S::Error> {
        match value.extract::<f64>() {
            Ok(v) => {
                if (v.is_nan() || v.is_infinite()) && self.inf_nan_mode == InfNanMode::Null {
                    serializer.serialize_none()
                } else {
                    serializer.serialize_f64(v)
                }
            }
            Err(_) => {
                extra.warnings.on_fallback_ser::<S>(self.get_name(), value, extra)?;
                infer_serialize(value, serializer, include, exclude, extra)
            }
        }
    }

    fn get_name(&self) -> &str {
        Self::EXPECTED_TYPE
    }
}
4 changes: 4 additions & 0 deletions src/serializers/type_serializers/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,8 @@ impl TypeSerializer for ListSerializer {
fn get_name(&self) -> &str {
&self.name
}

fn retry_with_lax_check(&self) -> bool {
self.item_serializer.retry_with_lax_check()
}
}
1 change: 1 addition & 0 deletions src/serializers/type_serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ pub mod datetime_etc;
pub mod decimal;
pub mod definitions;
pub mod dict;
pub mod float;
pub mod format;
pub mod function;
pub mod generator;
Expand Down
6 changes: 3 additions & 3 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::{
};
use crate::build_tools::py_schema_err;
use crate::build_tools::{py_schema_error_type, ExtraBehavior};
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::definitions::DefinitionsBuilder;
use crate::serializers::errors::PydanticSerializationUnexpectedValue;
use crate::tools::SchemaDict;

Expand All @@ -39,7 +39,7 @@ impl BuildSerializer for ModelFieldsBuilder {
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: AHashMap<String, SerField> = AHashMap::with_capacity(fields_dict.len());

let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) {
let extra_serializer = match (schema.get_item(intern!(py, "extras_schema"))?, &fields_mode) {
(Some(v), FieldsMode::ModelExtra) => Some(CombinedSerializer::build(v.extract()?, config, definitions)?),
(Some(_), _) => return py_schema_err!("extras_schema can only be used if extra_behavior=allow"),
(_, _) => None,
Expand Down Expand Up @@ -228,7 +228,7 @@ impl TypeSerializer for ModelSerializer {
&self.name
}

fn retry_with_lax_check(&self, _definitions: &Definitions<CombinedSerializer>) -> bool {
fn retry_with_lax_check(&self) -> bool {
true
}
}
6 changes: 3 additions & 3 deletions src/serializers/type_serializers/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;

use super::{infer_json_key_known, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, TypeSerializer};
Expand Down Expand Up @@ -75,7 +75,7 @@ impl TypeSerializer for NullableSerializer {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self, definitions: &Definitions<CombinedSerializer>) -> bool {
self.serializer.retry_with_lax_check(definitions)
fn retry_with_lax_check(&self) -> bool {
self.serializer.retry_with_lax_check()
}
}
1 change: 0 additions & 1 deletion src/serializers/type_serializers/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,3 @@ pub(crate) fn bool_json_key(key: &PyAny) -> PyResult<Cow<str>> {
}

build_simple_serializer!(BoolSerializer, "bool", bool, ObType::Bool, bool_json_key);
build_simple_serializer!(FloatSerializer, "float", f64, ObType::Float, to_str_json_key);
2 changes: 1 addition & 1 deletion src/serializers/type_serializers/typed_dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl BuildSerializer for TypedDictBuilder {
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
let mut fields: AHashMap<String, SerField> = AHashMap::with_capacity(fields_dict.len());

let extra_serializer = match (schema.get_item(intern!(py, "extras_schema")), &fields_mode) {
let extra_serializer = match (schema.get_item(intern!(py, "extras_schema"))?, &fields_mode) {
(Some(v), FieldsMode::TypedDictAllow) => {
Some(CombinedSerializer::build(v.extract()?, config, definitions)?)
}
Expand Down
17 changes: 8 additions & 9 deletions src/serializers/type_serializers/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::types::{PyDict, PyList, PyTuple};
use std::borrow::Cow;

use crate::build_tools::py_schema_err;
use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;
use crate::PydanticSerializationUnexpectedValue;

Expand Down Expand Up @@ -75,9 +75,10 @@ impl TypeSerializer for UnionSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
// try the serializers in with error_on fallback=true
// try the serializers in left to right order with error_on fallback=true
let mut new_extra = extra.clone();
new_extra.check = SerCheck::Strict;

for comb_serializer in &self.choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Ok(v) => return Ok(v),
Expand All @@ -87,7 +88,7 @@ impl TypeSerializer for UnionSerializer {
},
}
}
if self.retry_with_lax_check(extra.definitions) {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
for comb_serializer in &self.choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Expand Down Expand Up @@ -116,7 +117,7 @@ impl TypeSerializer for UnionSerializer {
},
}
}
if self.retry_with_lax_check(extra.definitions) {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
for comb_serializer in &self.choices {
match comb_serializer.json_key(key, &new_extra) {
Expand Down Expand Up @@ -153,7 +154,7 @@ impl TypeSerializer for UnionSerializer {
},
}
}
if self.retry_with_lax_check(extra.definitions) {
if self.retry_with_lax_check() {
new_extra.check = SerCheck::Lax;
for comb_serializer in &self.choices {
match comb_serializer.to_python(value, include, exclude, &new_extra) {
Expand All @@ -174,10 +175,8 @@ impl TypeSerializer for UnionSerializer {
&self.name
}

fn retry_with_lax_check(&self, definitions: &Definitions<CombinedSerializer>) -> bool {
self.choices
.iter()
.any(|choice| choice.retry_with_lax_check(definitions))
fn retry_with_lax_check(&self) -> bool {
self.choices.iter().any(CombinedSerializer::retry_with_lax_check)
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/serializers/type_serializers/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use pyo3::intern;
use pyo3::prelude::*;
use pyo3::types::PyDict;

use crate::definitions::{Definitions, DefinitionsBuilder};
use crate::definitions::DefinitionsBuilder;
use crate::tools::SchemaDict;
use crate::validators::DefaultType;

Expand Down Expand Up @@ -67,8 +67,8 @@ impl TypeSerializer for WithDefaultSerializer {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self, definitions: &Definitions<CombinedSerializer>) -> bool {
self.serializer.retry_with_lax_check(definitions)
fn retry_with_lax_check(&self) -> bool {
self.serializer.retry_with_lax_check()
}

fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
Expand Down
4 changes: 2 additions & 2 deletions src/tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<'py> SchemaDict<'py> for PyDict {
where
T: FromPyObject<'py>,
{
match self.get_item(key) {
match self.get_item(key)? {
Some(t) => Ok(Some(<T>::extract(t)?)),
None => Ok(None),
}
Expand All @@ -30,7 +30,7 @@ impl<'py> SchemaDict<'py> for PyDict {
where
T: FromPyObject<'py>,
{
match self.get_item(key) {
match self.get_item(key)? {
Some(t) => <T>::extract(t),
None => py_err!(PyKeyError; "{}", key),
}
Expand Down
20 changes: 6 additions & 14 deletions src/validators/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ use pyo3::types::PyDict;
use crate::errors::ValResult;
use crate::input::Input;

use super::{validation_state::ValidationState, BuildValidator, CombinedValidator, DefinitionsBuilder, Validator};
use super::{
validation_state::Exactness, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator,
};

/// This might seem useless, but it's useful in DictValidator to avoid Option<Validator> a lot
#[derive(Debug, Clone)]
Expand All @@ -29,24 +31,14 @@ impl Validator for AnyValidator {
&self,
py: Python<'data>,
input: &'data impl Input<'data>,
_state: &mut ValidationState,
state: &mut ValidationState,
) -> ValResult<'data, PyObject> {
// in a union, Any should be preferred to doing lax coercions
state.floor_exactness(Exactness::Strict);
Ok(input.to_object(py))
}

fn different_strict_behavior(
&self,
_definitions: Option<&DefinitionsBuilder<CombinedValidator>>,
_ultra_strict: bool,
) -> bool {
false
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn complete(&mut self, _definitions: &DefinitionsBuilder<CombinedValidator>) -> PyResult<()> {
Ok(())
}
}
Loading