diff --git a/pydantic_core/_pydantic_core.pyi b/pydantic_core/_pydantic_core.pyi index cdf6bc1b1..c3f4d6489 100644 --- a/pydantic_core/_pydantic_core.pyi +++ b/pydantic_core/_pydantic_core.pyi @@ -154,7 +154,8 @@ class MultiHostUrl: def __repr__(self) -> str: ... class SchemaError(Exception): - pass + def error_count(self) -> int: ... + def errors(self) -> 'list[ErrorDetails]': ... class ValidationError(ValueError): @property diff --git a/src/build_tools.rs b/src/build_tools.rs index d60abc3f5..27e76ba41 100644 --- a/src/build_tools.rs +++ b/src/build_tools.rs @@ -4,10 +4,11 @@ use std::fmt; use pyo3::exceptions::{PyException, PyKeyError}; use pyo3::prelude::*; -use pyo3::types::{PyDict, PyString}; +use pyo3::types::{PyDict, PyList, PyString}; use pyo3::{intern, FromPyObject, PyErrArguments}; -use crate::errors::{pretty_line_errors, ValError}; +use crate::errors::ValError; +use crate::ValidationError; pub trait SchemaDict<'py> { fn get_as(&'py self, key: &PyString) -> PyResult> @@ -98,22 +99,25 @@ pub fn is_strict(schema: &PyDict, config: Option<&PyDict>) -> PyResult { Ok(schema_or_config_same(schema, config, intern!(py, "strict"))?.unwrap_or(false)) } +enum SchemaErrorEnum { + Message(String), + ValidationError(ValidationError), +} + // we could perhaps do clever things here to store each schema error, or have different types for the top // level error group, and other errors, we could perhaps also support error groups!? #[pyclass(extends=PyException, module="pydantic_core._pydantic_core")] -pub struct SchemaError { - message: String, -} +pub struct SchemaError(SchemaErrorEnum); impl fmt::Debug for SchemaError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "SchemaError({:?})", self.message) + write!(f, "SchemaError({:?})", self.message()) } } impl fmt::Display for SchemaError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&self.message) + f.write_str(self.message()) } } @@ -134,12 +138,24 @@ impl SchemaError { pub fn from_val_error(py: Python, error: ValError) -> PyErr { match error { - ValError::LineErrors(line_errors) => { - let details = pretty_line_errors(py, line_errors); - SchemaError::new_err(format!("Invalid Schema:\n{details}")) + ValError::LineErrors(raw_errors) => { + let line_errors = raw_errors.into_iter().map(|e| e.into_py(py)).collect(); + let validation_error = ValidationError::new(line_errors, "Schema".to_object(py)); + let schema_error = SchemaError(SchemaErrorEnum::ValidationError(validation_error)); + match Py::new(py, schema_error) { + Ok(err) => PyErr::from_value(err.into_ref(py)), + Err(err) => err, + } } - ValError::InternalErr(py_err) => py_err, - ValError::Omit => unreachable!(), + ValError::InternalErr(err) => err, + ValError::Omit => Self::new_err("Unexpected Omit error."), + } + } + + fn message(&self) -> &str { + match &self.0 { + SchemaErrorEnum::Message(message) => message.as_str(), + SchemaErrorEnum::ValidationError(_) => "", } } } @@ -148,15 +164,35 @@ impl SchemaError { impl SchemaError { #[new] fn py_new(message: String) -> Self { - Self { message } + Self(SchemaErrorEnum::Message(message)) + } + + fn error_count(&self) -> usize { + match &self.0 { + SchemaErrorEnum::Message(_) => 0, + SchemaErrorEnum::ValidationError(error) => error.error_count(), + } } - fn __repr__(&self) -> String { - format!("{self:?}") + fn errors(&self, py: Python) -> PyResult> { + match &self.0 { + SchemaErrorEnum::Message(_) => Ok(PyList::empty(py).into_py(py)), + SchemaErrorEnum::ValidationError(error) => error.errors(py, None), + } } - fn __str__(&self) -> String { - self.to_string() + fn __str__(&self, py: Python) -> String { + match &self.0 { + SchemaErrorEnum::Message(message) => message.to_owned(), + SchemaErrorEnum::ValidationError(error) => error.display(py), + } + } + + fn __repr__(&self, py: Python) -> String { + match &self.0 { + SchemaErrorEnum::Message(message) => format!("SchemaError({message:?})"), + SchemaErrorEnum::ValidationError(error) => error.display(py), + } } } diff --git a/src/errors/line_error.rs b/src/errors/line_error.rs index 8e4c9cd4b..e39c116c6 100644 --- a/src/errors/line_error.rs +++ b/src/errors/line_error.rs @@ -6,7 +6,6 @@ use crate::input::{Input, JsonInput}; use super::location::{LocItem, Location}; use super::types::ErrorType; -use super::validation_exception::{pretty_py_line_errors, PyLineError}; pub type ValResult<'a, T> = Result>; @@ -71,11 +70,6 @@ impl<'a> ValError<'a> { } } -pub fn pretty_line_errors(py: Python, line_errors: Vec) -> String { - let py_line_errors: Vec = line_errors.into_iter().map(|e| e.into_py(py)).collect(); - pretty_py_line_errors(py, py_line_errors.iter()) -} - /// A `ValLineError` is a single error that occurred during validation which is converted to a `PyLineError` /// to eventually form a `ValidationError`. /// I don't like the name `ValLineError`, but it's the best I could come up with (for now). diff --git a/src/errors/mod.rs b/src/errors/mod.rs index 21f029f27..914629272 100644 --- a/src/errors/mod.rs +++ b/src/errors/mod.rs @@ -6,7 +6,7 @@ mod types; mod validation_exception; mod value_exception; -pub use self::line_error::{pretty_line_errors, InputValue, ValError, ValLineError, ValResult}; +pub use self::line_error::{InputValue, ValError, ValLineError, ValResult}; pub use self::location::LocItem; pub use self::types::{list_all_errors, ErrorType}; pub use self::validation_exception::ValidationError; diff --git a/src/errors/validation_exception.rs b/src/errors/validation_exception.rs index 0b1218a81..f13147183 100644 --- a/src/errors/validation_exception.rs +++ b/src/errors/validation_exception.rs @@ -24,6 +24,10 @@ pub struct ValidationError { } impl ValidationError { + pub fn new(line_errors: Vec, title: PyObject) -> Self { + Self { line_errors, title } + } + pub fn from_val_error(py: Python, title: PyObject, error: ValError, outer_location: Option) -> PyErr { match error { ValError::LineErrors(raw_errors) => { @@ -41,7 +45,7 @@ impl ValidationError { } } - fn display(&self, py: Python) -> String { + pub fn display(&self, py: Python) -> String { let count = self.line_errors.len(); let plural = if count == 1 { "" } else { "s" }; let title: &str = self.title.extract(py).unwrap(); @@ -77,11 +81,11 @@ impl ValidationError { self.title.clone_ref(py) } - fn error_count(&self) -> usize { + pub fn error_count(&self) -> usize { self.line_errors.len() } - fn errors(&self, py: Python, include_context: Option) -> PyResult> { + pub fn errors(&self, py: Python, include_context: Option) -> PyResult> { // taken approximately from the pyo3, but modified to return the error during iteration // https://github.com/PyO3/pyo3/blob/a3edbf4fcd595f0e234c87d4705eb600a9779130/src/types/list.rs#L27-L55 unsafe { diff --git a/tests/serializers/test_definitions.py b/tests/serializers/test_definitions.py index 7d18e4aba..fadf0bed9 100644 --- a/tests/serializers/test_definitions.py +++ b/tests/serializers/test_definitions.py @@ -36,8 +36,8 @@ def test_def_error(): ) ) - assert exc_info.value.args[0].startswith( - "Invalid Schema:\ndefinitions -> definitions -> 1\n Input tag 'wrong' found using 'type'" + assert str(exc_info.value).startswith( + "1 validation error for Schema\ndefinitions -> definitions -> 1\n Input tag 'wrong' found using 'type'" ) diff --git a/tests/serializers/test_list_tuple.py b/tests/serializers/test_list_tuple.py index c04263587..dfb82f413 100644 --- a/tests/serializers/test_list_tuple.py +++ b/tests/serializers/test_list_tuple.py @@ -160,8 +160,8 @@ class RemovedContains(ImplicitContains): ({4.2}, 'Input should be a valid integer, got a number with a fractional part'), ({'a'}, 'Input should be a valid integer, unable to parse string as an integer'), (ImplicitContains(), 'Input should be a valid set'), - (ExplicitContains(), re.compile('.*Invalid Schema:.*Input should be a valid set.*', re.DOTALL)), - (RemovedContains(), re.compile('.*Invalid Schema:.*Input should be a valid set.*', re.DOTALL)), + (ExplicitContains(), re.compile('.*1 validation error for Schema.*Input should be a valid set.*', re.DOTALL)), + (RemovedContains(), re.compile('.*1 validation error for Schema.*Input should be a valid set.*', re.DOTALL)), ], ) @pytest.mark.parametrize('schema_func', [core_schema.list_schema, core_schema.tuple_variable_schema]) diff --git a/tests/test_build.py b/tests/test_build.py index 51195baba..3448c5539 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -34,10 +34,19 @@ def test_schema_as_string(): def test_schema_wrong_type(): with pytest.raises(SchemaError) as exc_info: SchemaValidator(1) - assert exc_info.value.args[0] == ( - 'Invalid Schema:\n Input should be a valid dictionary or instance to' + assert str(exc_info.value) == ( + '1 validation error for Schema\n Input should be a valid dictionary or instance to' ' extract fields from [type=dict_attributes_type, input_value=1, input_type=int]' ) + assert exc_info.value.errors() == [ + { + 'input': 1, + 'loc': (), + 'msg': 'Input should be a valid dictionary or instance to extract fields ' 'from', + 'type': 'dict_attributes_type', + } + ] + assert exc_info.value.error_count() == 1 @pytest.mark.parametrize('pickle_protocol', range(1, pickle.HIGHEST_PROTOCOL + 1)) diff --git a/tests/validators/test_definitions.py b/tests/validators/test_definitions.py index 358af3eac..9a074dbb6 100644 --- a/tests/validators/test_definitions.py +++ b/tests/validators/test_definitions.py @@ -41,9 +41,10 @@ def test_def_error(): [core_schema.int_schema(ref='foobar'), {'type': 'wrong'}], ) ) - assert exc_info.value.args[0].startswith( - "Invalid Schema:\ndefinitions -> definitions -> 1\n Input tag 'wrong' found using 'type'" + assert str(exc_info.value).startswith( + "1 validation error for Schema\ndefinitions -> definitions -> 1\n Input tag 'wrong' found using 'type'" ) + assert exc_info.value.error_count() == 1 def test_dict_repeat(): diff --git a/tests/validators/test_union.py b/tests/validators/test_union.py index f81e23939..9dc276312 100644 --- a/tests/validators/test_union.py +++ b/tests/validators/test_union.py @@ -227,11 +227,15 @@ def test_no_choices(): with pytest.raises(SchemaError) as exc_info: SchemaValidator({'type': 'union'}) - assert exc_info.value.args[0] == ( - 'Invalid Schema:\n' + assert str(exc_info.value) == ( + '1 validation error for Schema\n' 'union -> choices\n' " Field required [type=missing, input_value={'type': 'union'}, input_type=dict]" ) + assert exc_info.value.error_count() == 1 + assert exc_info.value.errors() == [ + {'input': {'type': 'union'}, 'loc': ('union', 'choices'), 'msg': 'Field required', 'type': 'missing'} + ] def test_empty_choices():