diff --git a/src/input/shared.rs b/src/input/shared.rs index e99bfabcf..f4251a11d 100644 --- a/src/input/shared.rs +++ b/src/input/shared.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use pyo3::prelude::*; use pyo3::sync::GILOnceCell; use pyo3::{intern, Py, PyAny, Python}; @@ -61,17 +63,19 @@ fn strip_underscores(s: &str) -> Option { // Double consecutive underscores are also not valid // If there are no underscores at all, no need to replace anything if s.starts_with('_') || s.ends_with('_') || !s.contains('_') || s.contains("__") { - // no underscores to strip - return None; + // no underscores to strip, or underscores in the wrong place + None + } else { + Some(s.replace('_', "")) } - Some(s.replace('_', "")) } /// parse a string as an int +/// max length of the input is 4300 which is checked by jiter, see +/// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and +/// https://github.com/python/cpython/issues/95778 for more info on that length bound pub fn str_as_int<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValResult> { - let str = str.trim(); - - // we have to call `NumberInt::try_from` directly first so we fail fast if the string is too long + // we can't move `NumberInt::try_from` into its own function as we fail fast if the string is too long match NumberInt::try_from(str.as_bytes()) { Ok(NumberInt::Int(i)) => return Ok(EitherInt::I64(i)), Ok(NumberInt::BigInt(i)) => return Ok(EitherInt::BigInt(i)), @@ -82,10 +86,12 @@ pub fn str_as_int<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValResu } } - if let Some(str_stripped) = strip_decimal_zeros(str) { - _parse_str(input, str_stripped) - } else if let Some(str_stripped) = strip_underscores(str) { - _parse_str(input, &str_stripped) + if let Some(cleaned_str) = clean_int_str(str) { + match NumberInt::try_from(cleaned_str.as_ref().as_bytes()) { + Ok(NumberInt::Int(i)) => Ok(EitherInt::I64(i)), + Ok(NumberInt::BigInt(i)) => Ok(EitherInt::BigInt(i)), + Err(_) => 
Err(ValError::new(ErrorTypeDefaults::IntParsing, input)), + } } else { Err(ValError::new(ErrorTypeDefaults::IntParsing, input)) } @@ -102,30 +108,32 @@ pub fn str_as_float<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValRe } } -/// parse a string as an int, `input` is required here to get lifetimes to match up -/// max length of the input is 4300 which is checked by jiter, see -/// https://docs.python.org/3/whatsnew/3.11.html#other-cpython-implementation-changes and -/// https://github.com/python/cpython/issues/95778 for more info in that length bound -fn _parse_str<'py>(input: &(impl Input<'py> + ?Sized), str: &str) -> ValResult> { - match NumberInt::try_from(str.as_bytes()) { - Ok(jiter::NumberInt::Int(i)) => Ok(EitherInt::I64(i)), - Ok(jiter::NumberInt::BigInt(i)) => Ok(EitherInt::BigInt(i)), - Err(e) => match e.error_type { - JsonErrorType::NumberOutOfRange => Err(ValError::new(ErrorTypeDefaults::IntParsingSize, input)), - _ => Err(ValError::new(ErrorTypeDefaults::IntParsing, input)), - }, - } -} +fn clean_int_str(mut s: &str) -> Option> { + let len_before = s.len(); + + // strip leading and trailing whitespace + s = s.trim(); -/// we don't want to parse as f64 then call `float_as_int` as it can loose precision for large ints, therefore -/// we strip `.0+` manually instead, then parse as i64 -fn strip_decimal_zeros(s: &str) -> Option<&str> { + // strip leading zeros (NOTE(review): an all-zero integer part is removed entirely, so `00` becomes an empty string and `0.0` becomes `.0` - confirm this is intended) + s = s.trim_start_matches('0'); + + // we don't want to parse as f64 then call `float_as_int` as it can lose precision for large ints, therefore + // we strip `.0+` manually instead if let Some(i) = s.find('.') { if s[i + 1..].chars().all(|c| c == '0') { - return Some(&s[..i]); + s = &s[..i]; + } + } + + // remove underscores + if let Some(str_stripped) = strip_underscores(s) { + Some(str_stripped.into()) + } else { + match len_before == s.len() { + true => None, + false => Some(s.into()), } } - None } pub fn float_as_int<'py>(input: &(impl Input<'py> + ?Sized), float: f64) -> ValResult> 
{ diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index baddb6289..775c65cc3 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -26,6 +26,11 @@ ('42', 42), (42.0, 42), ('42.0', 42), + ('042', 42), + ('4_2', 42), + ('4_2.0', 42), + ('04_2.0', 42), + ('000001', 1), ('123456789.0', 123_456_789), ('123456789123456.00001', Err('Input should be a valid integer, unable to parse string as an integer')), (int(1e10), int(1e10)),