diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index f25104be8..16dde4637 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -17,16 +17,16 @@ pub struct RecursionGuard { ids: SmallContainer, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators - depth: u16, + depth: u8, } // A hard limit to avoid stack overflows when rampant recursion occurs -pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { +pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { // wasm and windows PyPy have very limited stack sizes - 50 + 49 } else if cfg!(any(PyPy, windows)) { // PyPy and Windows in general have more restricted stack space - 100 + 99 } else { 255 }; @@ -41,11 +41,26 @@ impl RecursionGuard { // see #143 this is used as a backup in case the identity check recursion guard fails #[must_use] + #[cfg(any(target_family = "wasm", windows, PyPy))] pub fn incr_depth(&mut self) -> bool { // use saturating_add as it's faster (since there's no error path) // and the RECURSION_GUARD_LIMIT check will be hit before it overflows + debug_assert!(RECURSION_GUARD_LIMIT < 255); self.depth = self.depth.saturating_add(1); - self.depth >= RECURSION_GUARD_LIMIT + self.depth > RECURSION_GUARD_LIMIT + } + + #[must_use] + #[cfg(not(any(target_family = "wasm", windows, PyPy)))] + pub fn incr_depth(&mut self) -> bool { + debug_assert_eq!(RECURSION_GUARD_LIMIT, 255); + // use checked_add to check if we've hit the limit + if let Some(depth) = self.depth.checked_add(1) { + self.depth = depth; + false + } else { + true + } } pub fn decr_depth(&mut self) { @@ -90,9 +105,9 @@ impl SmallContainer { first_slot = first_slot.or(Some(index)); } } - if let Some(first_slot) = first_slot { - array[first_slot] = Some(v); - Some(first_slot) + if let Some(index) = first_slot { + array[index] = Some(v); + first_slot } else { let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); for existing in array.iter_mut() { @@ -108,6 +123,7 @@ impl SmallContainer { // "If the set did not have this value present, `true` is returned." Self::Set(set) => { if set.insert(v) { + // again id doesn't matter here as we'll be removing from a set Some(0) } else { None diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 98ec22c1f..fa6e702fe 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -371,7 +371,7 @@ def fallback_func(obj): f = FoobarCount(0) v = 0 # when recursion is detected and we're in mode python, we just return the value - expected_visits = pydantic_core._pydantic_core._recursion_limit - 1 + expected_visits = pydantic_core._pydantic_core._recursion_limit assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'') with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'):