Skip to content

Commit

Permalink
use ValidationState to propagate exactness
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Aug 14, 2023
1 parent 87b4789 commit 8b22d55
Show file tree
Hide file tree
Showing 43 changed files with 783 additions and 792 deletions.
25 changes: 9 additions & 16 deletions src/input/return_enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ use pyo3::PyTypeInfo;
use serde::{ser::Error, Serialize, Serializer};

use crate::errors::{py_err_string, ErrorType, ErrorTypeDefaults, InputValue, ValError, ValLineError, ValResult};
use crate::recursion_guard::RecursionGuard;
use crate::tools::py_err;
use crate::validators::{CombinedValidator, Extra, Validator};
use crate::validators::{CombinedValidator, ValidationState, Validator};

use super::parse_json::{JsonArray, JsonInput, JsonObject};
use super::{py_error_on_minusone, Input};
Expand Down Expand Up @@ -157,15 +156,14 @@ fn validate_iter_to_vec<'a, 's>(
capacity: usize,
mut max_length_check: MaxLengthCheck<'a, impl Input<'a>>,
validator: &'s CombinedValidator,
extra: &Extra,
state: &mut ValidationState,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, Vec<PyObject>> {
let mut output: Vec<PyObject> = Vec::with_capacity(capacity);
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, max_length_check.input, index))?;
match validator.validate(py, item, extra, definitions, recursion_guard) {
match validator.validate(py, item, state, definitions) {
Ok(item) => {
max_length_check.incr()?;
output.push(item);
Expand Down Expand Up @@ -226,14 +224,13 @@ fn validate_iter_to_set<'a, 's>(
field_type: &'static str,
max_length: Option<usize>,
validator: &'s CombinedValidator,
extra: &Extra,
state: &mut ValidationState,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, ()> {
let mut errors: Vec<ValLineError> = Vec::new();
for (index, item_result) in iter.enumerate() {
let item = item_result.map_err(|e| any_next_error!(py, e, input, index))?;
match validator.validate(py, item, extra, definitions, recursion_guard) {
match validator.validate(py, item, state, definitions) {
Ok(item) => {
set.build_add(item)?;
if let Some(max_length) = max_length {
Expand Down Expand Up @@ -315,9 +312,8 @@ impl<'a> GenericIterable<'a> {
max_length: Option<usize>,
field_type: &'static str,
validator: &'s CombinedValidator,
extra: &Extra,
state: &mut ValidationState,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, Vec<PyObject>> {
let capacity = self
.generic_len()
Expand All @@ -332,9 +328,8 @@ impl<'a> GenericIterable<'a> {
capacity,
max_length_check,
validator,
extra,
state,
definitions,
recursion_guard,
)
};
}
Expand All @@ -360,9 +355,8 @@ impl<'a> GenericIterable<'a> {
max_length: Option<usize>,
field_type: &'static str,
validator: &'s CombinedValidator,
extra: &Extra,
state: &mut ValidationState,
definitions: &'a [CombinedValidator],
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'a, ()> {
macro_rules! validate_set {
($iter:expr) => {
Expand All @@ -374,9 +368,8 @@ impl<'a> GenericIterable<'a> {
field_type,
max_length,
validator,
extra,
state,
definitions,
recursion_guard,
)
};
}
Expand Down
12 changes: 6 additions & 6 deletions src/validators/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use pyo3::types::PyDict;
use crate::errors::ValResult;
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{
validation_state::{Exactness, ValidationState},
BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Validator,
};

/// This might seem useless, but it's useful in DictValidator to avoid Option<Validator> a lot
#[derive(Debug, Clone)]
Expand All @@ -31,11 +32,10 @@ impl Validator for AnyValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
_extra: &Extra,
state: &mut ValidationState,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
// Ok(input.clone().into_py(py))
state.merge_exactness(Exactness::Lax);
Ok(input.to_object(py))
}

Expand Down
17 changes: 8 additions & 9 deletions src/validators/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use crate::errors::{ErrorTypeDefaults, ValError, ValLineError, ValResult};
use crate::input::{GenericArguments, Input};
use crate::lookup_key::LookupKey;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
struct Parameter {
Expand Down Expand Up @@ -165,9 +165,8 @@ impl Validator for ArgumentsValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let args = input.validate_args()?;

Expand Down Expand Up @@ -207,7 +206,7 @@ impl Validator for ArgumentsValidator {
(Some(pos_value), None) => {
match parameter
.validator
.validate(py, pos_value, extra, definitions, recursion_guard)
.validate(py, pos_value, state, definitions)
{
Ok(value) => output_args.push(value),
Err(ValError::LineErrors(line_errors)) => {
Expand All @@ -219,7 +218,7 @@ impl Validator for ArgumentsValidator {
(None, Some((lookup_path, kw_value))) => {
match parameter
.validator
.validate(py, kw_value, extra, definitions, recursion_guard)
.validate(py, kw_value, state, definitions)
{
Ok(value) => output_kwargs.set_item(parameter.kwarg_key.as_ref().unwrap(), value)?,
Err(ValError::LineErrors(line_errors)) => {
Expand All @@ -231,7 +230,7 @@ impl Validator for ArgumentsValidator {
}
}
(None, None) => {
if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), extra, definitions, recursion_guard)? {
if let Some(value) = parameter.validator.default_value(py, Some(parameter.name.as_str()), state, definitions)? {
if let Some(ref kwarg_key) = parameter.kwarg_key {
output_kwargs.set_item(kwarg_key, value)?;
} else {
Expand Down Expand Up @@ -261,7 +260,7 @@ impl Validator for ArgumentsValidator {
if len > self.positional_params_count {
if let Some(ref validator) = self.var_args_validator {
for (index, item) in $slice_macro!(args, self.positional_params_count, len).iter().enumerate() {
match validator.validate(py, item, extra, definitions, recursion_guard) {
match validator.validate(py, item, state, definitions) {
Ok(value) => output_args.push(value),
Err(ValError::LineErrors(line_errors)) => {
errors.extend(
Expand Down Expand Up @@ -303,7 +302,7 @@ impl Validator for ArgumentsValidator {
};
if !used_kwargs.contains(either_str.as_cow()?.as_ref()) {
match self.var_kwargs_validator {
Some(ref validator) => match validator.validate(py, value, extra, definitions, recursion_guard) {
Some(ref validator) => match validator.validate(py, value, state, definitions) {
Ok(value) => output_kwargs.set_item(either_str.as_py_string(py), value)?,
Err(ValError::LineErrors(line_errors)) => {
for err in line_errors {
Expand Down
11 changes: 5 additions & 6 deletions src/validators/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ use crate::build_tools::is_strict;
use crate::errors::ValResult;
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BoolValidator {
Expand Down Expand Up @@ -36,13 +34,14 @@ impl Validator for BoolValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
// TODO in theory this could be quicker if we used PyBool rather than going to a bool
// and back again, might be worth profiling?
Ok(input.validate_bool(extra.strict.unwrap_or(self.strict))?.into_py(py))
let strict = state.strict_or(self.strict);
state.set_exactness_unknown();
Ok(input.validate_bool(strict)?.into_py(py))
}

fn different_strict_behavior(
Expand Down
15 changes: 7 additions & 8 deletions src/validators/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ use crate::build_tools::is_strict;
use crate::errors::{ErrorType, ValError, ValResult};
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, ValidationState, Validator};

#[derive(Debug, Clone)]
pub struct BytesValidator {
Expand Down Expand Up @@ -45,11 +44,11 @@ impl Validator for BytesValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let either_bytes = input.validate_bytes(extra.strict.unwrap_or(self.strict))?;
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?;
state.set_exactness_unknown();
Ok(either_bytes.into_py(py))
}

Expand Down Expand Up @@ -84,11 +83,10 @@ impl Validator for BytesConstrainedValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let either_bytes = input.validate_bytes(extra.strict.unwrap_or(self.strict))?;
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?;
let len = either_bytes.len()?;

if let Some(min_length) = self.min_length {
Expand All @@ -114,6 +112,7 @@ impl Validator for BytesConstrainedValidator {
}
}

state.set_exactness_unknown();
Ok(either_bytes.into_py(py))
}

Expand Down
13 changes: 5 additions & 8 deletions src/validators/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ use pyo3::types::{PyDict, PyTuple};
use crate::errors::ValResult;
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub struct CallValidator {
Expand Down Expand Up @@ -76,13 +76,10 @@ impl Validator for CallValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let args = self
.arguments_validator
.validate(py, input, extra, definitions, recursion_guard)?;
let args = self.arguments_validator.validate(py, input, state, definitions)?;

let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) {
self.function.call(py, args, Some(kwargs))?
Expand All @@ -95,7 +92,7 @@ impl Validator for CallValidator {

if let Some(return_validator) = &self.return_validator {
return_validator
.validate(py, return_value.into_ref(py), extra, definitions, recursion_guard)
.validate(py, return_value.into_ref(py), state, definitions)
.map_err(|e| e.with_outer_location("return".into()))
} else {
Ok(return_value.to_object(py))
Expand Down
9 changes: 4 additions & 5 deletions src/validators/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ use pyo3::types::PyDict;
use crate::errors::{ErrorTypeDefaults, ValError, ValResult};
use crate::input::Input;

use crate::recursion_guard::RecursionGuard;

use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::{Exactness, ValidationState};
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub struct CallableValidator;
Expand All @@ -30,10 +29,10 @@ impl Validator for CallableValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
_extra: &Extra,
state: &mut ValidationState,
_definitions: &'data Definitions<CombinedValidator>,
_recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
state.merge_exactness(Exactness::Lax);
match input.callable() {
true => Ok(input.to_object(py)),
false => Err(ValError::new(ErrorTypeDefaults::CallableType, input)),
Expand Down
13 changes: 5 additions & 8 deletions src/validators/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use pyo3::types::{PyDict, PyList};
use crate::build_tools::py_schema_err;
use crate::errors::ValResult;
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub struct ChainValidator {
Expand Down Expand Up @@ -74,17 +74,14 @@ impl Validator for ChainValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
let mut steps_iter = self.steps.iter();
let first_step = steps_iter.next().unwrap();
let value = first_step.validate(py, input, extra, definitions, recursion_guard)?;
let value = first_step.validate(py, input, state, definitions)?;

steps_iter.try_fold(value, |v, step| {
step.validate(py, v.into_ref(py), extra, definitions, recursion_guard)
})
steps_iter.try_fold(value, |v, step| step.validate(py, v.into_ref(py), state, definitions))
}

fn different_strict_behavior(
Expand Down
9 changes: 4 additions & 5 deletions src/validators/custom_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ use pyo3::types::PyDict;
use crate::build_tools::py_schema_err;
use crate::errors::{ErrorType, PydanticCustomError, PydanticKnownError, ValError, ValResult};
use crate::input::Input;
use crate::recursion_guard::RecursionGuard;
use crate::tools::SchemaDict;

use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
use super::validation_state::ValidationState;
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Validator};

#[derive(Debug, Clone)]
pub enum CustomError {
Expand Down Expand Up @@ -92,12 +92,11 @@ impl Validator for CustomErrorValidator {
&'s self,
py: Python<'data>,
input: &'data impl Input<'data>,
extra: &Extra,
state: &mut ValidationState,
definitions: &'data Definitions<CombinedValidator>,
recursion_guard: &'s mut RecursionGuard,
) -> ValResult<'data, PyObject> {
self.validator
.validate(py, input, extra, definitions, recursion_guard)
.validate(py, input, state, definitions)
.map_err(|_| self.custom_error.as_val_error(input))
}

Expand Down
Loading

0 comments on commit 8b22d55

Please sign in to comment.