diff --git a/Cargo.lock b/Cargo.lock index e07369ccb..4d967814b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -405,6 +405,20 @@ name = "serde" version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.159" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] [[package]] name = "serde_json" diff --git a/Cargo.toml b/Cargo.toml index 8969a4b51..cb1bce09e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,9 +30,9 @@ pyo3 = { version = "0.19.1", features = ["generate-import-lib", "num-bigint"] } regex = "1.6.0" strum = { version = "0.25.0", features = ["derive"] } strum_macros = "0.24.3" -serde_json = {version = "1.0.87", features = ["preserve_order"]} +serde_json = {version = "1.0.87", features = ["arbitrary_precision", "preserve_order"]} enum_dispatch = "0.3.8" -serde = "1.0.147" +serde = { version = "1.0.147", features = ["derive"] } # disabled for benchmarks since it makes microbenchmark performance more flakey mimalloc = { version = "0.1.30", optional = true, default-features = false, features = ["local_dynamic_tls"] } speedate = "0.9.1" diff --git a/src/input/input_json.rs b/src/input/input_json.rs index bb7331ccb..409ce6d81 100644 --- a/src/input/input_json.rs +++ b/src/input/input_json.rs @@ -1,5 +1,8 @@ +use std::borrow::Cow; + use pyo3::prelude::*; use pyo3::types::PyDict; +use strum::EnumMessage; use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult}; @@ -118,6 +121,7 @@ impl<'a> Input<'a> for JsonInput { match self { JsonInput::Int(i) => Ok(EitherInt::I64(*i)), JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), + JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), _ => Err(ValError::new(ErrorType::IntType, self)), } } @@ -129,6 +133,7 @@ impl<'a> Input<'a> for JsonInput { }, JsonInput::Int(i) => Ok(EitherInt::I64(*i)), JsonInput::Uint(u) => Ok(EitherInt::U64(*u)), + JsonInput::BigInt(b) => Ok(EitherInt::BigInt(b.clone())), JsonInput::Float(f) => float_as_int(self, *f), JsonInput::String(str) => str_as_int(self, str), _ => Err(ValError::new(ErrorType::IntType, self)), @@ -270,6 +275,16 @@ impl<'a> Input<'a> for JsonInput { JsonInput::String(v) => bytes_as_time(self, v.as_bytes()), JsonInput::Int(v) => int_as_time(self, *v, 0), JsonInput::Float(v) => float_as_time(self, *v), + JsonInput::BigInt(_) => Err(ValError::new( + ErrorType::TimeParsing { + error: Cow::Borrowed( + speedate::ParseError::TimeTooLarge + .get_documentation() + .unwrap_or_default(), + ), + }, + self, + )), _ => Err(ValError::new(ErrorType::TimeType, self)), } } diff --git a/src/input/parse_json.rs b/src/input/parse_json.rs index 1505c64a7..c386bb1cc 100644 --- a/src/input/parse_json.rs +++ b/src/input/parse_json.rs @@ -1,5 +1,6 @@ use std::fmt; +use num_bigint::BigInt; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use serde::de::{Deserialize, DeserializeSeed, Error as SerdeError, MapAccess, SeqAccess, Visitor}; @@ -12,6 +13,7 @@ pub enum JsonInput { Null, Bool(bool), Int(i64), + BigInt(BigInt), Uint(u64), Float(f64), String(String), @@ -27,6 +29,7 @@ impl ToPyObject for JsonInput { Self::Null => py.None(), Self::Bool(b) => b.into_py(py), Self::Int(i) => i.into_py(py), + Self::BigInt(b) => b.to_object(py), Self::Uint(i) => i.into_py(py), Self::Float(f) => f.into_py(py), Self::String(s) => s.into_py(py), @@ -83,9 +86,8 @@ impl<'de> Deserialize<'de> for JsonInput { Ok(JsonInput::String(value.to_string())) } - #[cfg_attr(has_no_coverage, no_coverage)] - fn visit_string(self, _: String) -> Result { - unreachable!() + fn visit_string(self, value: String) -> Result { + Ok(JsonInput::String(value)) } #[cfg_attr(has_no_coverage, no_coverage)] @@ -122,11 +124,50 @@ impl<'de> Deserialize<'de> for JsonInput { where V: MapAccess<'de>, { + const SERDE_JSON_NUMBER: &str = "$serde_json::private::Number"; match visitor.next_key_seed(KeyDeserializer)? { Some(first_key) => { let mut values = LazyIndexMap::new(); + let first_value = visitor.next_value()?; + + // serde_json will parse arbitrary precision numbers into a map + // structure with a "number" key and a String value + 'try_number: { + if first_key == SERDE_JSON_NUMBER { + // Just in case someone tries to actually store that key in a real map, + // keep parsing and continue as a map if so + + if let Some((key, value)) = visitor.next_entry::()? { + // Important to preserve order of the keys + values.insert(first_key, first_value); + values.insert(key, value); + break 'try_number; + } + + if let JsonInput::String(s) = &first_value { + // Normalize the string to either an int or float + let normalized = if s.contains('.') { + JsonInput::Float( + s.parse() + .map_err(|e| V::Error::custom(format!("expected a float: {e}")))?, + ) + } else if let Ok(i) = s.parse::() { + JsonInput::Int(i) + } else if let Ok(big) = s.parse::() { + JsonInput::BigInt(big) + } else { + // Failed to normalize, just throw it in the map and continue + values.insert(first_key, first_value); + break 'try_number; + }; + + return Ok(normalized); + }; + } else { + values.insert(first_key, first_value); + } + } - values.insert(first_key, visitor.next_value()?); while let Some((key, value)) = visitor.next_entry()? { values.insert(key, value); } diff --git a/src/input/return_enums.rs b/src/input/return_enums.rs index a7db8e9d1..b47cacc36 100644 --- a/src/input/return_enums.rs +++ b/src/input/return_enums.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use std::cmp::Ordering; use std::ops::Rem; use std::slice::Iter as SliceIter; +use std::str::FromStr; use num_bigint::BigInt; @@ -17,6 +18,7 @@ use pyo3::{ffi, intern, AsPyPointer, PyNativeType}; use pyo3::types::PyFunction; #[cfg(not(PyPy))] use pyo3::PyTypeInfo; +use serde::{ser::Error, Serialize, Serializer}; use crate::errors::{py_err_string, ErrorType, InputValue, ValError, ValLineError, ValResult}; use crate::recursion_guard::RecursionGuard; @@ -926,12 +928,26 @@ impl<'a> IntoPy for EitherFloat<'a> { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] pub enum Int { I64(i64), + #[serde(serialize_with = "serialize_bigint_as_number")] Big(BigInt), } +// The default serialization for BigInt is some internal representation which roundtrips efficiently +// but is not the JSON value which users would expect to see. +fn serialize_bigint_as_number(big_int: &BigInt, serializer: S) -> Result +where + S: Serializer, +{ + serde_json::Number::from_str(&big_int.to_string()) + .map_err(S::Error::custom) + .expect("a valid number") + .serialize(serializer) +} + impl PartialOrd for Int { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { @@ -973,3 +989,12 @@ impl<'a> FromPyObject<'a> for Int { } } } + +impl ToPyObject for Int { + fn to_object(&self, py: Python<'_>) -> PyObject { + match self { + Self::I64(i) => i.to_object(py), + Self::Big(big_i) => big_i.to_object(py), + } + } +} diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 87cdc2ee8..d668ffba6 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -10,6 +10,7 @@ use pyo3::types::{ use serde::ser::{Error, Serialize, SerializeMap, SerializeSeq, Serializer}; +use crate::input::Int; use crate::serializers::errors::SERIALIZATION_ERR_MARKER; use crate::serializers::filter::SchemaFilter; use crate::serializers::shared::{PydanticSerializer, TypeSerializer}; @@ -406,7 +407,7 @@ pub(crate) fn infer_serialize_known( let ser_result = match ob_type { ObType::None => serializer.serialize_none(), - ObType::Int | ObType::IntSubclass => serialize!(i64), + ObType::Int | ObType::IntSubclass => serialize!(Int), ObType::Bool => serialize!(bool), ObType::Float | ObType::FloatSubclass => serialize!(f64), ObType::Decimal => value.to_string().serialize(serializer), diff --git a/src/serializers/type_serializers/simple.rs b/src/serializers/type_serializers/simple.rs index e430705b9..1007fa285 100644 --- a/src/serializers/type_serializers/simple.rs +++ b/src/serializers/type_serializers/simple.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use serde::Serialize; -use crate::definitions::DefinitionsBuilder; +use crate::{definitions::DefinitionsBuilder, input::Int}; use super::{ infer_json_key, infer_serialize, infer_to_python, BuildSerializer, CombinedSerializer, Extra, IsType, ObType, @@ -163,7 +163,7 @@ pub(crate) fn to_str_json_key(key: &PyAny) -> PyResult> { Ok(key.str()?.to_string_lossy()) } -build_simple_serializer!(IntSerializer, "int", i64, ObType::Int, to_str_json_key); +build_simple_serializer!(IntSerializer, "int", Int, ObType::Int, to_str_json_key); pub(crate) fn bool_json_key(key: &PyAny) -> PyResult> { let v = if key.is_true().unwrap_or(false) { diff --git a/tests/serializers/test_simple.py b/tests/serializers/test_simple.py index e7c4f9aab..43f88382e 100644 --- a/tests/serializers/test_simple.py +++ b/tests/serializers/test_simple.py @@ -19,11 +19,16 @@ class FloatSubClass(float): pass +# A number well outside of i64 range +_BIG_NUMBER_BYTES = b'1' + (b'0' * 40) + + @pytest.mark.parametrize('custom_type_schema', [None, 'any']) @pytest.mark.parametrize( 'schema_type,value,expected_python,expected_json', [ ('int', 1, 1, b'1'), + ('int', int(_BIG_NUMBER_BYTES), int(_BIG_NUMBER_BYTES), _BIG_NUMBER_BYTES), ('bool', True, True, b'true'), ('bool', False, False, b'false'), ('float', 1.0, 1.0, b'1.0'), diff --git a/tests/test_json.py b/tests/test_json.py index 9a70e454d..d098af6df 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -63,6 +63,10 @@ def test_bytes(): s.validate_json('123') +# A number well outside of i64 range +_BIG_NUMBER_STR = '1' + ('0' * 40) + + @pytest.mark.parametrize( 'input_value,expected', [ @@ -70,7 +74,9 @@ def test_bytes(): ('"123"', 123), ('123.0', 123), ('"123.0"', 123), + (_BIG_NUMBER_STR, int(_BIG_NUMBER_STR)), ('123.4', Err('Input should be a valid integer, got a number with a fractional part [type=int_from_float,')), + ('"123.4"', Err('Input should be a valid integer, unable to parse string as an integer [type=int_parsing,')), ('"string"', Err('Input should be a valid integer, unable to parse string as an integer [type=int_parsing,')), ], ) diff --git a/tests/validators/test_int.py b/tests/validators/test_int.py index e4ed26202..65142698b 100644 --- a/tests/validators/test_int.py +++ b/tests/validators/test_int.py @@ -162,13 +162,7 @@ def test_negative_int(input_value, expected): (i64_max, i64_max), (i64_max + 1, i64_max + 1), (i64_max * 2, i64_max * 2), - ( - int(1e30), - Err( - 'Unable to parse input string as an integer, exceeded maximum size ' - '[type=int_parsing_size, input_value=1e+30, input_type=float]' - ), - ), + (int(1e30), int(1e30)), (0, Err('Input should be greater than 0 [type=greater_than, input_value=0, input_type=int]')), (-1, Err('Input should be greater than 0 [type=greater_than, input_value=-1, input_type=int]')), pytest.param( @@ -197,10 +191,7 @@ def test_positive_json(input_value, expected): (0, Err('Input should be less than 0 [type=less_than, input_value=0, input_type=int]')), (-i64_max, -i64_max), (-i64_max - 1, -i64_max - 1), - ( - -i64_max * 2, - Err(' Unable to parse input string as an integer, exceeded maximum size [type=int_parsing_size'), - ), + (-i64_max * 2, -i64_max * 2), ], ) def test_negative_json(input_value, expected): @@ -391,8 +382,7 @@ def test_long_python_inequality(): def test_long_json(): v = SchemaValidator({'type': 'int'}) - with pytest.raises(ValidationError, match=r'number out of range at line 1 column 401 \[type=json_invalid,'): - v.validate_json('-' + '1' * 400) + assert v.validate_json('-' + '1' * 400) == int('-' + '1' * 400) with pytest.raises(ValidationError, match=r'expected ident at line 1 column 2 \[type=json_invalid,'): v.validate_json('nan')