From f98b0852ba83f9dca1971e2d28897fd10568507b Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 14 Jan 2024 20:28:54 +0000 Subject: [PATCH 1/7] improve performance of recursion guard --- Cargo.toml | 1 + src/recursion_guard.rs | 105 ++++++++++++++---- src/serializers/extra.rs | 20 ++-- src/serializers/infer.rs | 14 +-- .../type_serializers/definitions.rs | 8 +- src/validators/definitions.rs | 20 ++-- 6 files changed, 115 insertions(+), 53 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 66a8667e0..f391daa69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ extension-module = ["pyo3/extension-module"] lto = "fat" codegen-units = 1 strip = true +#debug = true [profile.bench] debug = true diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index 453f01a1d..f25104be8 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -1,4 +1,5 @@ use ahash::AHashSet; +use std::hash::Hash; type RecursionKey = ( // Identifier for the input object, e.g. the id() of a Python dict @@ -13,7 +14,7 @@ type RecursionKey = ( /// It's used in `validators/definition` to detect when a reference is reused within itself. #[derive(Debug, Clone, Default)] pub struct RecursionGuard { - ids: Option>, + ids: SmallContainer, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators depth: u16, @@ -31,38 +32,98 @@ pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(w }; impl RecursionGuard { - // insert a new id into the set, return whether the set already had the id in it - pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> bool { - match self.ids { - // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert - // "If the set did not have this value present, `true` is returned." - Some(ref mut set) => !set.insert((obj_id, node_id)), - None => { - let mut set: AHashSet = AHashSet::with_capacity(10); - set.insert((obj_id, node_id)); - self.ids = Some(set); - false - } - } + // insert a new value + // * return `None` if the array/set already had it in it + // * return `Some(index)` if the array didn't have it in it and it was inserted + pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> Option { + self.ids.contains_or_insert((obj_id, node_id)) } // see #143 this is used as a backup in case the identity check recursion guard fails #[must_use] pub fn incr_depth(&mut self) -> bool { - self.depth += 1; + // use saturating_add as it's faster (since there's no error path) + // and the RECURSION_GUARD_LIMIT check will be hit before it overflows + self.depth = self.depth.saturating_add(1); self.depth >= RECURSION_GUARD_LIMIT } pub fn decr_depth(&mut self) { - self.depth -= 1; + // for the same reason as incr_depth, use saturating_sub + self.depth = self.depth.saturating_sub(1); + } + + pub fn remove(&mut self, obj_id: usize, node_id: usize, index: usize) { + self.ids.remove(&(obj_id, node_id), index); + } +} + +// trial and error suggests this is a good value, going higher causes array lookups to get significantly slower +const ARRAY_SIZE: usize = 16; + +#[derive(Debug, Clone)] +enum SmallContainer { + Array([Option; ARRAY_SIZE]), + Set(AHashSet), +} + +impl Default for SmallContainer { + fn default() -> Self { + Self::Array([None; ARRAY_SIZE]) } +} - pub fn remove(&mut self, obj_id: usize, node_id: usize) { - match self.ids { - Some(ref mut set) => { - set.remove(&(obj_id, node_id)); +impl SmallContainer { + // insert a new value + // * return `None` if the array/set already had it in it + // * return `Some(index)` if the array didn't have it in it and it was inserted + pub fn contains_or_insert(&mut self, v: T) -> Option { + match self { + Self::Array(array) => { + let mut first_slot: Option = None; + for (index, op_value) in array.iter().enumerate() { + if let Some(existing) = op_value { + if existing == &v { + return None; + } + } else { + first_slot = first_slot.or(Some(index)); + } + } + if let Some(first_slot) = first_slot { + array[first_slot] = Some(v); + Some(first_slot) + } else { + let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); + for existing in array.iter_mut() { + set.insert(existing.take().unwrap()); + } + set.insert(v); + *self = Self::Set(set); + // id doesn't matter here as we'll be removing from a set + Some(0) + } + } + // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert + // "If the set did not have this value present, `true` is returned." + Self::Set(set) => { + if set.insert(v) { + Some(0) + } else { + None + } } - None => unreachable!(), - }; + } + } + + pub fn remove(&mut self, v: &T, index: usize) { + match self { + Self::Array(array) => { + array[index] = None; + } + Self::Set(set) => { + set.remove(v); + } + } } } diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index 37307055e..f9a560b78 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -345,24 +345,24 @@ pub struct SerRecursionGuard { } impl SerRecursionGuard { - pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult { - // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert - // "If the set did not have this value present, `true` is returned." + pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<(usize, usize)> { let id = value.as_ptr() as usize; let mut guard = self.guard.borrow_mut(); - if guard.contains_or_insert(id, def_ref_id) { - Err(PyValueError::new_err("Circular reference detected (id repeated)")) - } else if guard.incr_depth() { - Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) + if let Some(insert_index) = guard.contains_or_insert(id, def_ref_id) { + if guard.incr_depth() { + Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) + } else { + Ok((id, insert_index)) + } } else { - Ok(id) + Err(PyValueError::new_err("Circular reference detected (id repeated)")) } } - pub fn pop(&self, id: usize, def_ref_id: usize) { + pub fn pop(&self, id: usize, def_ref_id: usize, insert_index: usize) { let mut guard = self.guard.borrow_mut(); guard.decr_depth(); - guard.remove(id, def_ref_id); + guard.remove(id, def_ref_id, insert_index); } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index 13c20062b..a42950d5b 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -45,7 +45,7 @@ pub(crate) fn infer_to_python_known( extra: &Extra, ) -> PyResult { let py = value.py(); - let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { + let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { Ok(id) => id, Err(e) => { return match extra.mode { @@ -226,7 +226,7 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); return next_result; } else if extra.serialize_unknown { serialize_unknown(value).into_py(py) @@ -284,7 +284,7 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); return next_result; } value.into_py(py) @@ -292,7 +292,7 @@ pub(crate) fn infer_to_python_known( _ => value.into_py(py), }, }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); Ok(value) } @@ -351,7 +351,7 @@ pub(crate) fn infer_serialize_known( exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { + let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { Ok(v) => v, Err(e) => { return if extra.serialize_unknown { @@ -534,7 +534,7 @@ pub(crate) fn infer_serialize_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; let next_result = infer_serialize(next_value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); return next_result; } else if extra.serialize_unknown { serializer.serialize_str(&serialize_unknown(value)) @@ -548,7 +548,7 @@ pub(crate) fn infer_serialize_known( } } }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); ser_result } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 99dae5bcd..4a3017464 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -70,9 +70,9 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> PyResult { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let value_id = extra.rec_guard.add(value, self.definition.id())?; + let (value_id, insert_index) = extra.rec_guard.add(value, self.definition.id())?; let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); + extra.rec_guard.pop(value_id, self.definition.id(), insert_index); r }) } @@ -91,12 +91,12 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> Result { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let value_id = extra + let (value_id, insert_index) = extra .rec_guard .add(value, self.definition.id()) .map_err(py_err_se_err)?; let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); + extra.rec_guard.pop(value_id, self.definition.id(), insert_index); r }) } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 0b5f78c10..7a7cf7b3c 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -76,17 +76,17 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } else { + if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id()); + state.recursion_guard.remove(id, self.definition.id(), insert_index); state.recursion_guard.decr_depth(); output + } else { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) } } else { validator.validate(py, input, state) @@ -105,17 +105,17 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } else { + if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id()); + state.recursion_guard.remove(id, self.definition.id(), insert_index); state.recursion_guard.decr_depth(); output + } else { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) } } else { validator.validate_assignment(py, obj, field_name, field_value, state) From 5d14e700a587354b7d545c9301b01127d838055f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 14 Jan 2024 20:44:44 +0000 Subject: [PATCH 2/7] tweak depth limit logic --- src/recursion_guard.rs | 32 ++++++++++++++++++++++++-------- tests/serializers/test_any.py | 2 +- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index f25104be8..16dde4637 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -17,16 +17,16 @@ pub struct RecursionGuard { ids: SmallContainer, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators - depth: u16, + depth: u8, } // A hard limit to avoid stack overflows when rampant recursion occurs -pub const RECURSION_GUARD_LIMIT: u16 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { +pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(windows, PyPy))) { // wasm and windows PyPy have very limited stack sizes - 50 + 49 } else if cfg!(any(PyPy, windows)) { // PyPy and Windows in general have more restricted stack space - 100 + 99 } else { 255 }; @@ -41,11 +41,26 @@ impl RecursionGuard { // see #143 this is used as a backup in case the identity check recursion guard fails #[must_use] + #[cfg(any(target_family = "wasm", windows, PyPy))] pub fn incr_depth(&mut self) -> bool { // use saturating_add as it's faster (since there's no error path) // and the RECURSION_GUARD_LIMIT check will be hit before it overflows + debug_assert!(RECURSION_GUARD_LIMIT < 255); self.depth = self.depth.saturating_add(1); - self.depth >= RECURSION_GUARD_LIMIT + self.depth > RECURSION_GUARD_LIMIT + } + + #[must_use] + #[cfg(not(any(target_family = "wasm", windows, PyPy)))] + pub fn incr_depth(&mut self) -> bool { + debug_assert_eq!(RECURSION_GUARD_LIMIT, 255); + // use checked_add to check if we've hit the limit + if let Some(depth) = self.depth.checked_add(1) { + self.depth = depth; + false + } else { + true + } } pub fn decr_depth(&mut self) { @@ -90,9 +105,9 @@ impl SmallContainer { first_slot = first_slot.or(Some(index)); } } - if let Some(first_slot) = first_slot { - array[first_slot] = Some(v); - Some(first_slot) + if let Some(index) = first_slot { + array[index] = Some(v); + first_slot } else { let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); for existing in array.iter_mut() { @@ -108,6 +123,7 @@ impl SmallContainer { // "If the set did not have this value present, `true` is returned." Self::Set(set) => { if set.insert(v) { + // again id doesn't matter here as we'll be removing from a set Some(0) } else { None diff --git a/tests/serializers/test_any.py b/tests/serializers/test_any.py index 98ec22c1f..fa6e702fe 100644 --- a/tests/serializers/test_any.py +++ b/tests/serializers/test_any.py @@ -371,7 +371,7 @@ def fallback_func(obj): f = FoobarCount(0) v = 0 # when recursion is detected and we're in mode python, we just return the value - expected_visits = pydantic_core._pydantic_core._recursion_limit - 1 + expected_visits = pydantic_core._pydantic_core._recursion_limit assert any_serializer.to_python(f, fallback=fallback_func) == HasRepr(f'') with pytest.raises(ValueError, match=r'Circular reference detected \(depth exceeded\)'): From 3e73e52202d2640249cf545a722ebbe852dc1c28 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 14 Jan 2024 23:41:56 +0000 Subject: [PATCH 3/7] bump From 3a621b741f324201431a764a069fdb369a207434 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 09:41:15 +0000 Subject: [PATCH 4/7] bump From 3f7252202a11670b0ee21bf39e182d0caab9c841 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 15 Jan 2024 11:32:05 +0000 Subject: [PATCH 5/7] tidy up `contains_or_insert` to be more efficient --- src/recursion_guard.rs | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index 16dde4637..0dd56f74e 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -95,29 +95,26 @@ impl SmallContainer { pub fn contains_or_insert(&mut self, v: T) -> Option { match self { Self::Array(array) => { - let mut first_slot: Option = None; - for (index, op_value) in array.iter().enumerate() { + for (index, op_value) in array.iter_mut().enumerate() { if let Some(existing) = op_value { if existing == &v { return None; } } else { - first_slot = first_slot.or(Some(index)); + *op_value = Some(v); + return Some(index); } } - if let Some(index) = first_slot { - array[index] = Some(v); - first_slot - } else { - let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); - for existing in array.iter_mut() { - set.insert(existing.take().unwrap()); - } - set.insert(v); - *self = Self::Set(set); - // id doesn't matter here as we'll be removing from a set - Some(0) + + // No array slots exist; convert to set + let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); + for existing in array.iter_mut() { + set.insert(existing.take().unwrap()); } + set.insert(v); + *self = Self::Set(set); + // id doesn't matter here as we'll be removing from a set + Some(0) } // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert // "If the set did not have this value present, `true` is returned." @@ -135,6 +132,7 @@ impl SmallContainer { pub fn remove(&mut self, v: &T, index: usize) { match self { Self::Array(array) => { + debug_assert!(array[index].as_ref() == Some(v), "remove did not match insert"); array[index] = None; } Self::Set(set) => { From e1190ed50bd888fa36fa4b4af6d65c599d529588 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 15 Jan 2024 12:27:55 +0000 Subject: [PATCH 6/7] remove cargo.toml release profile change Co-authored-by: David Hewitt --- Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index f391daa69..66a8667e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,6 @@ extension-module = ["pyo3/extension-module"] lto = "fat" codegen-units = 1 strip = true -#debug = true [profile.bench] debug = true From 4b82950ff8be219c919bc6b445874a32c9fb4bec Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 15 Jan 2024 13:50:07 +0000 Subject: [PATCH 7/7] codify the stack-based nature of the guard --- src/recursion_guard.rs | 109 ++++++++++-------- src/serializers/extra.rs | 10 +- src/serializers/infer.rs | 14 +-- .../type_serializers/definitions.rs | 8 +- src/validators/definitions.rs | 8 +- 5 files changed, 83 insertions(+), 66 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index 0dd56f74e..fe5b1bcdd 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -1,5 +1,5 @@ use ahash::AHashSet; -use std::hash::Hash; +use std::mem::MaybeUninit; type RecursionKey = ( // Identifier for the input object, e.g. the id() of a Python dict @@ -14,7 +14,7 @@ type RecursionKey = ( /// It's used in `validators/definition` to detect when a reference is reused within itself. #[derive(Debug, Clone, Default)] pub struct RecursionGuard { - ids: SmallContainer, + ids: RecursionStack, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators depth: u8, @@ -33,10 +33,10 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi impl RecursionGuard { // insert a new value - // * return `None` if the array/set already had it in it - // * return `Some(index)` if the array didn't have it in it and it was inserted - pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> Option { - self.ids.contains_or_insert((obj_id, node_id)) + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { + self.ids.insert((obj_id, node_id)) } // see #143 this is used as a backup in case the identity check recursion guard fails @@ -68,8 +68,8 @@ impl RecursionGuard { self.depth = self.depth.saturating_sub(1); } - pub fn remove(&mut self, obj_id: usize, node_id: usize, index: usize) { - self.ids.remove(&(obj_id, node_id), index); + pub fn remove(&mut self, obj_id: usize, node_id: usize) { + self.ids.remove(&(obj_id, node_id)); } } @@ -77,63 +77,67 @@ impl RecursionGuard { const ARRAY_SIZE: usize = 16; #[derive(Debug, Clone)] -enum SmallContainer { - Array([Option; ARRAY_SIZE]), - Set(AHashSet), +enum RecursionStack { + Array { + data: [MaybeUninit; ARRAY_SIZE], + len: usize, + }, + Set(AHashSet), } -impl Default for SmallContainer { +impl Default for RecursionStack { fn default() -> Self { - Self::Array([None; ARRAY_SIZE]) + Self::Array { + data: std::array::from_fn(|_| MaybeUninit::uninit()), + len: 0, + } } } -impl SmallContainer { +impl RecursionStack { // insert a new value - // * return `None` if the array/set already had it in it - // * return `Some(index)` if the array didn't have it in it and it was inserted - pub fn contains_or_insert(&mut self, v: T) -> Option { + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, v: RecursionKey) -> bool { match self { - Self::Array(array) => { - for (index, op_value) in array.iter_mut().enumerate() { - if let Some(existing) = op_value { - if existing == &v { - return None; + Self::Array { data, len } => { + if *len < ARRAY_SIZE { + for value in data.iter().take(*len) { + // Safety: reading values within bounds + if unsafe { value.assume_init() } == v { + return false; } - } else { - *op_value = Some(v); - return Some(index); } - } - // No array slots exist; convert to set - let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); - for existing in array.iter_mut() { - set.insert(existing.take().unwrap()); + data[*len].write(v); + *len += 1; + true + } else { + let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1); + for existing in data.iter() { + // Safety: the array is fully initialized + set.insert(unsafe { existing.assume_init() }); + } + let inserted = set.insert(v); + *self = Self::Set(set); + inserted } - set.insert(v); - *self = Self::Set(set); - // id doesn't matter here as we'll be removing from a set - Some(0) } // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert // "If the set did not have this value present, `true` is returned." - Self::Set(set) => { - if set.insert(v) { - // again id doesn't matter here as we'll be removing from a set - Some(0) - } else { - None - } - } + Self::Set(set) => set.insert(v), } } - pub fn remove(&mut self, v: &T, index: usize) { + pub fn remove(&mut self, v: &RecursionKey) { match self { - Self::Array(array) => { - debug_assert!(array[index].as_ref() == Some(v), "remove did not match insert"); - array[index] = None; + Self::Array { data, len } => { + *len = len.checked_sub(1).expect("remove from empty recursion guard"); + // Safety: this is reading what was the back of the initialized array + let removed = unsafe { data.get_unchecked_mut(*len) }; + assert!(unsafe { removed.assume_init_ref() } == v, "remove did not match insert"); + // this should compile away to a noop + unsafe { std::ptr::drop_in_place(removed.as_mut_ptr()) } } Self::Set(set) => { set.remove(v); @@ -141,3 +145,16 @@ impl SmallContainer { } } } + +impl Drop for RecursionStack { + fn drop(&mut self) { + // This should compile away to a noop as Recursion>Key doesn't implement Drop, but it seemed + // desirable to leave this in for safety in case that should change in the future + if let Self::Array { data, len } = self { + for value in data.iter_mut().take(*len) { + // Safety: reading values within bounds + unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) }; + } + } + } +} diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index f9a560b78..b3978613a 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -345,24 +345,24 @@ pub struct SerRecursionGuard { } impl SerRecursionGuard { - pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<(usize, usize)> { + pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult { let id = value.as_ptr() as usize; let mut guard = self.guard.borrow_mut(); - if let Some(insert_index) = guard.contains_or_insert(id, def_ref_id) { + if guard.insert(id, def_ref_id) { if guard.incr_depth() { Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) } else { - Ok((id, insert_index)) + Ok(id) } } else { Err(PyValueError::new_err("Circular reference detected (id repeated)")) } } - pub fn pop(&self, id: usize, def_ref_id: usize, insert_index: usize) { + pub fn pop(&self, id: usize, def_ref_id: usize) { let mut guard = self.guard.borrow_mut(); guard.decr_depth(); - guard.remove(id, def_ref_id, insert_index); + guard.remove(id, def_ref_id); } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index a42950d5b..13c20062b 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -45,7 +45,7 @@ pub(crate) fn infer_to_python_known( extra: &Extra, ) -> PyResult { let py = value.py(); - let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { + let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { Ok(id) => id, Err(e) => { return match extra.mode { @@ -226,7 +226,7 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serialize_unknown(value).into_py(py) @@ -284,7 +284,7 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } value.into_py(py) @@ -292,7 +292,7 @@ pub(crate) fn infer_to_python_known( _ => value.into_py(py), }, }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); Ok(value) } @@ -351,7 +351,7 @@ pub(crate) fn infer_serialize_known( exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { + let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { Ok(v) => v, Err(e) => { return if extra.serialize_unknown { @@ -534,7 +534,7 @@ pub(crate) fn infer_serialize_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; let next_result = infer_serialize(next_value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serializer.serialize_str(&serialize_unknown(value)) @@ -548,7 +548,7 @@ pub(crate) fn infer_serialize_known( } } }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); ser_result } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4a3017464..99dae5bcd 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -70,9 +70,9 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> PyResult { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let (value_id, insert_index) = extra.rec_guard.add(value, self.definition.id())?; + let value_id = extra.rec_guard.add(value, self.definition.id())?; let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id(), insert_index); + extra.rec_guard.pop(value_id, self.definition.id()); r }) } @@ -91,12 +91,12 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> Result { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let (value_id, insert_index) = extra + let value_id = extra .rec_guard .add(value, self.definition.id()) .map_err(py_err_se_err)?; let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id(), insert_index); + extra.rec_guard.pop(value_id, self.definition.id()); r }) } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 7a7cf7b3c..e8c67a690 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -76,12 +76,12 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id(), insert_index); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } else { @@ -105,12 +105,12 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id(), insert_index); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } else {