Skip to content

Commit

Permalink
one-pass union validation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Aug 14, 2023
1 parent 8b22d55 commit ee74e77
Show file tree
Hide file tree
Showing 16 changed files with 286 additions and 329 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ serde = { version = "1.0.183", features = ["derive"] }
# disabled for benchmarks since it makes microbenchmark performance more flakey
mimalloc = { version = "0.1.30", optional = true, default-features = false, features = ["local_dynamic_tls"] }
speedate = "0.11.0"
smallvec = "1.11.0"
ahash = "0.8.0"
url = "2.3.1"
# idna is already required by url, added here to be explicit
Expand Down
56 changes: 11 additions & 45 deletions src/input/input_abstract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::fmt;
use pyo3::types::{PyDict, PyType};
use pyo3::{intern, prelude::*};

use crate::errors::{InputValue, LocItem, ValResult};
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
use crate::{PyMultiHostUrl, PyUrl};

use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
use super::return_enums::{EitherBytes, EitherInt, EitherString};
use super::return_enums::{EitherBytes, EitherInt, EitherString, ValidationMatch};
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput};

#[derive(Debug, Clone, Copy)]
Expand All @@ -29,7 +29,7 @@ impl IntoPy<PyObject> for InputType {
/// the convention is to either implement:
/// * `strict_*` & `lax_*` if they have different behavior
/// * or, `validate_*` and `strict_*` to just call `validate_*` if the behavior for strict and lax is the same
pub trait Input<'a>: fmt::Debug + ToPyObject {
pub trait Input<'a>: fmt::Debug + ToPyObject + Sized {
fn as_loc_item(&self) -> LocItem;

fn as_error_value(&'a self) -> InputValue<'a>;
Expand Down Expand Up @@ -98,36 +98,16 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_bytes()
}

fn validate_bool(&self, strict: bool) -> ValResult<bool> {
if strict {
self.strict_bool()
} else {
self.lax_bool()
}
}
fn strict_bool(&self) -> ValResult<bool>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_bool(&self) -> ValResult<bool> {
self.strict_bool()
}
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>>;

fn validate_int(&'a self, strict: bool) -> ValResult<EitherInt<'a>> {
if strict {
self.strict_int()
} else {
self.lax_int()
}
}
fn strict_int(&'a self) -> ValResult<EitherInt<'a>>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
}
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>>;

/// Extract an EitherInt from the input, only allowing exact
/// matches for an Int (no subclasses)
fn exact_int(&'a self) -> ValResult<EitherInt<'a>> {
self.strict_int()
self.validate_int(true).and_then(|val_match| {
val_match
.require_exact()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::IntType, self))
})
}

/// Extract a String from the input, only allowing exact
Expand All @@ -136,21 +116,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
self.strict_str()
}

fn validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult<EitherFloat<'a>> {
if ultra_strict {
self.ultra_strict_float()
} else if strict {
self.strict_float()
} else {
self.lax_float()
}
}
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
#[cfg_attr(has_no_coverage, no_coverage)]
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>>;

fn validate_decimal(&'a self, strict: bool, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
if strict {
Expand Down
116 changes: 41 additions & 75 deletions src/input/input_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use super::datetime::{
float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate, EitherDateTime, EitherTime,
};
use super::parse_json::JsonArray;
use super::return_enums::ValidationMatch;
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int};
use super::{
EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
Expand Down Expand Up @@ -106,74 +107,41 @@ impl<'a> Input<'a> for JsonInput {
self.validate_bytes(false)
}

fn strict_bool(&self) -> ValResult<bool> {
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
match self {
JsonInput::Bool(b) => Ok(*b),
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
}
}
fn lax_bool(&self) -> ValResult<bool> {
match self {
JsonInput::Bool(b) => Ok(*b),
JsonInput::String(s) => str_as_bool(self, s),
JsonInput::Int(int) => int_as_bool(self, *int),
JsonInput::Float(float) => match float_as_int(self, *float) {
JsonInput::Bool(b) => Ok(ValidationMatch::exact(*b)),
JsonInput::String(s) if !strict => str_as_bool(self, s).map(ValidationMatch::lax),
JsonInput::Int(int) if !strict => int_as_bool(self, *int).map(ValidationMatch::lax),
JsonInput::Float(float) if !strict => match float_as_int(self, *float) {
Ok(int) => int
.as_bool()
.ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self)),
.ok_or_else(|| ValError::new(ErrorTypeDefaults::BoolParsing, self))
.map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
},
_ => Err(ValError::new(ErrorTypeDefaults::BoolType, self)),
}
}

fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
match self {
JsonInput::Int(i) => Ok(EitherInt::I64(*i)),
JsonInput::Uint(u) => Ok(EitherInt::U64(*u)),
JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())),
_ => Err(ValError::new(ErrorTypeDefaults::IntType, self)),
}
}
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
match self {
JsonInput::Bool(b) => match *b {
true => Ok(EitherInt::I64(1)),
false => Ok(EitherInt::I64(0)),
},
JsonInput::Int(i) => Ok(EitherInt::I64(*i)),
JsonInput::Uint(u) => Ok(EitherInt::U64(*u)),
JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())),
JsonInput::Float(f) => float_as_int(self, *f),
JsonInput::String(str) => str_as_int(self, str),
JsonInput::Int(i) => Ok(ValidationMatch::exact(EitherInt::I64(*i))),
JsonInput::Uint(u) => Ok(ValidationMatch::exact(EitherInt::U64(*u))),
JsonInput::BigInt(b) => Ok(ValidationMatch::exact(EitherInt::BigInt(b.clone()))),
JsonInput::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherInt::I64((*b).into()))),
JsonInput::Float(f) if !strict => float_as_int(self, *f).map(ValidationMatch::lax),
JsonInput::String(str) if !strict => str_as_int(self, str).map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::IntType, self)),
}
}

fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
match self {
JsonInput::Float(f) => Ok(EitherFloat::F64(*f)),
_ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
match self {
JsonInput::Float(f) => Ok(EitherFloat::F64(*f)),
JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)),
JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)),
_ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
match self {
JsonInput::Bool(b) => match *b {
true => Ok(EitherFloat::F64(1.0)),
false => Ok(EitherFloat::F64(0.0)),
},
JsonInput::Float(f) => Ok(EitherFloat::F64(*f)),
JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)),
JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)),
JsonInput::String(str) => str_as_float(self, str),
JsonInput::Float(f) => Ok(ValidationMatch::exact(EitherFloat::F64(*f))),
JsonInput::Int(i) => Ok(ValidationMatch::strict(EitherFloat::F64(*i as f64))),
JsonInput::Uint(u) => Ok(ValidationMatch::strict(EitherFloat::F64(*u as f64))),
JsonInput::Bool(b) if !strict => Ok(ValidationMatch::lax(EitherFloat::F64(if *b { 1.0 } else { 0.0 }))),
JsonInput::String(str) if !strict => str_as_float(self, str).map(ValidationMatch::lax),
_ => Err(ValError::new(ErrorTypeDefaults::FloatType, self)),
}
}
Expand Down Expand Up @@ -410,33 +378,31 @@ impl<'a> Input<'a> for String {
self.validate_bytes(false)
}

fn strict_bool(&self) -> ValResult<bool> {
Err(ValError::new(ErrorTypeDefaults::BoolType, self))
}
fn lax_bool(&self) -> ValResult<bool> {
str_as_bool(self, self)
fn validate_bool(&self, strict: bool) -> ValResult<'_, ValidationMatch<bool>> {
if strict {
Err(ValError::new(ErrorTypeDefaults::BoolType, self))
} else {
str_as_bool(self, self).map(ValidationMatch::lax)
}
}

fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
Err(ValError::new(ErrorTypeDefaults::IntType, self))
}
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
match self.parse() {
Ok(i) => Ok(EitherInt::I64(i)),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)),
fn validate_int(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherInt<'a>>> {
if strict {
Err(ValError::new(ErrorTypeDefaults::IntType, self))
} else {
match self.parse() {
Ok(i) => Ok(ValidationMatch::lax(EitherInt::I64(i))),
Err(_) => Err(ValError::new(ErrorTypeDefaults::IntParsing, self)),
}
}
}

#[cfg_attr(has_no_coverage, no_coverage)]
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
self.strict_float()
}
#[cfg_attr(has_no_coverage, no_coverage)]
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
}
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
str_as_float(self, self)
fn validate_float(&'a self, strict: bool) -> ValResult<'a, ValidationMatch<EitherFloat<'a>>> {
if strict {
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
} else {
str_as_float(self, self).map(ValidationMatch::lax)
}
}

fn strict_decimal(&'a self, decimal_type: &'a PyType) -> ValResult<&'a PyAny> {
Expand Down
Loading

0 comments on commit ee74e77

Please sign in to comment.