Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pydantic_core/_pydantic_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ else:

__all__ = (
'__version__',
'build_profile',
'SchemaValidator',
'SchemaError',
'ValidationError',
'PydanticCustomError',
'PydanticKindError',
'PydanticOmit',
'list_all_errors',
)
__version__: str
build_profile: str
Expand Down
2 changes: 2 additions & 0 deletions pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, ext
'union_tag_invalid',
'union_tag_not_found',
'arguments_type',
'positional_arguments_type',
'keyword_arguments_type',
'unexpected_keyword_argument',
'missing_keyword_argument',
'unexpected_positional_argument',
Expand Down
6 changes: 5 additions & 1 deletion src/errors/kinds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,12 @@ pub enum ErrorKind {
},
// ---------------------
// argument errors
#[strum(message = "Arguments must be a tuple of (positional arguments, keyword arguments) or a plain dict")]
#[strum(message = "Arguments must be a tuple, list or a dictionary")]
ArgumentsType,
#[strum(message = "Positional arguments must be a list or tuple")]
PositionalArgumentsType,
#[strum(message = "Keyword arguments must be a dictionary")]
KeywordArgumentsType,
#[strum(message = "Unexpected keyword argument")]
UnexpectedKeywordArgument,
#[strum(message = "Missing required keyword argument")]
Expand Down
54 changes: 37 additions & 17 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::prelude::*;

use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult};
use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValLineError, ValResult};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
Expand Down Expand Up @@ -56,24 +56,44 @@ impl<'a> Input<'a> for JsonInput {

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
match self {
JsonInput::Object(kwargs) => Ok(JsonArgs::new(None, Some(kwargs)).into()),
JsonInput::Array(array) => {
if array.len() != 2 {
Err(ValError::new(ErrorKind::ArgumentsType, self))
} else {
let args = match unsafe { array.get_unchecked(0) } {
JsonInput::Null => None,
JsonInput::Array(args) => Some(args.as_slice()),
_ => return Err(ValError::new(ErrorKind::ArgumentsType, self)),
};
let kwargs = match unsafe { array.get_unchecked(1) } {
JsonInput::Null => None,
JsonInput::Object(kwargs) => Some(kwargs),
_ => return Err(ValError::new(ErrorKind::ArgumentsType, self)),
};
Ok(JsonArgs::new(args, kwargs).into())
JsonInput::Object(object) => {
if let Some(args) = object.get("__args__") {
if let Some(kwargs) = object.get("__kwargs__") {
// we only try this logic if there are only these two items in the dict
if object.len() == 2 {
let args = match args {
JsonInput::Null => Ok(None),
JsonInput::Array(args) => Ok(Some(args.as_slice())),
_ => Err(ValLineError::new_with_loc(
ErrorKind::PositionalArgumentsType,
args,
"__args__",
)),
};
let kwargs = match kwargs {
JsonInput::Null => Ok(None),
JsonInput::Object(kwargs) => Ok(Some(kwargs)),
_ => Err(ValLineError::new_with_loc(
ErrorKind::KeywordArgumentsType,
kwargs,
"__kwargs__",
)),
};

return match (args, kwargs) {
(Ok(args), Ok(kwargs)) => Ok(JsonArgs::new(args, kwargs).into()),
(Err(args_error), Err(kwargs_error)) => {
return Err(ValError::LineErrors(vec![args_error, kwargs_error]))
}
(Err(error), _) => Err(ValError::LineErrors(vec![error])),
(_, Err(error)) => Err(ValError::LineErrors(vec![error])),
};
}
}
}
Ok(JsonArgs::new(None, Some(object)).into())
}
JsonInput::Array(array) => Ok(JsonArgs::new(Some(array), None).into()),
_ => Err(ValError::new(ErrorKind::ArgumentsType, self)),
}
}
Expand Down
70 changes: 49 additions & 21 deletions src/input/input_python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use pyo3::types::{
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo};

use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValLineError, ValResult};

use super::datetime::{
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
Expand Down Expand Up @@ -114,26 +114,54 @@ impl<'a> Input<'a> for PyAny {
}

fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
if let Ok(kwargs) = self.cast_as::<PyDict>() {
Ok(PyArgs::new(None, Some(kwargs)).into())
} else if let Ok((args, kwargs)) = self.extract::<(&PyAny, &PyAny)>() {
let args = if let Ok(tuple) = args.cast_as::<PyTuple>() {
Some(tuple)
} else if args.is_none() {
None
} else if let Ok(list) = args.cast_as::<PyList>() {
Some(PyTuple::new(self.py(), list.iter().collect::<Vec<_>>()))
} else {
return Err(ValError::new(ErrorKind::ArgumentsType, self));
};
let kwargs = if let Ok(dict) = kwargs.cast_as::<PyDict>() {
Some(dict)
} else if kwargs.is_none() {
None
} else {
return Err(ValError::new(ErrorKind::ArgumentsType, self));
};
Ok(PyArgs::new(args, kwargs).into())
if let Ok(dict) = self.cast_as::<PyDict>() {
if let Some(args) = dict.get_item("__args__") {
if let Some(kwargs) = dict.get_item("__kwargs__") {
// we only try this logic if there are only these two items in the dict
if dict.len() == 2 {
let args = if let Ok(tuple) = args.cast_as::<PyTuple>() {
Ok(Some(tuple))
} else if args.is_none() {
Ok(None)
} else if let Ok(list) = args.cast_as::<PyList>() {
Ok(Some(PyTuple::new(self.py(), list.iter().collect::<Vec<_>>())))
} else {
Err(ValLineError::new_with_loc(
ErrorKind::PositionalArgumentsType,
args,
"__args__",
))
};

let kwargs = if let Ok(dict) = kwargs.cast_as::<PyDict>() {
Ok(Some(dict))
} else if kwargs.is_none() {
Ok(None)
} else {
Err(ValLineError::new_with_loc(
ErrorKind::KeywordArgumentsType,
kwargs,
"__kwargs__",
))
};

return match (args, kwargs) {
(Ok(args), Ok(kwargs)) => Ok(PyArgs::new(args, kwargs).into()),
(Err(args_error), Err(kwargs_error)) => {
Err(ValError::LineErrors(vec![args_error, kwargs_error]))
}
(Err(error), _) => Err(ValError::LineErrors(vec![error])),
(_, Err(error)) => Err(ValError::LineErrors(vec![error])),
};
}
}
}
Ok(PyArgs::new(None, Some(dict)).into())
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
Ok(PyArgs::new(Some(tuple), None).into())
} else if let Ok(list) = self.cast_as::<PyList>() {
let tuple = PyTuple::new(self.py(), list.iter().collect::<Vec<_>>());
Ok(PyArgs::new(Some(tuple), None).into())
} else {
Err(ValError::new(ErrorKind::ArgumentsType, self))
}
Expand Down
7 changes: 5 additions & 2 deletions tests/benchmarks/test_micro_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,9 +889,12 @@ def test_arguments(benchmark):
],
}
)
assert v.validate_python(((1, 'a', 'true'), {'b': 'bb', 'c': 3})) == ((1, 'a', True), {'b': 'bb', 'c': 3})
assert v.validate_python({'__args__': (1, 'a', 'true'), '__kwargs__': {'b': 'bb', 'c': 3}}) == (
(1, 'a', True),
{'b': 'bb', 'c': 3},
)

benchmark(v.validate_python, ((1, 'a', 'true'), {'b': 'bb', 'c': 3}))
benchmark(v.validate_python, {'__args__': (1, 'a', 'true'), '__kwargs__': {'b': 'bb', 'c': 3}})


@pytest.mark.benchmark(group='defaults')
Expand Down
Loading