From cf8dddddd77bd76e470776db4eb8505c0947144b Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Mon, 15 Jan 2024 13:50:07 +0000 Subject: [PATCH] codify the stack-based nature of the guard --- src/recursion_guard.rs | 109 ++++++++++-------- src/serializers/extra.rs | 10 +- src/serializers/infer.rs | 14 +-- .../type_serializers/definitions.rs | 8 +- src/validators/definitions.rs | 8 +- 5 files changed, 83 insertions(+), 66 deletions(-) diff --git a/src/recursion_guard.rs b/src/recursion_guard.rs index 0dd56f74e..3b6ab7924 100644 --- a/src/recursion_guard.rs +++ b/src/recursion_guard.rs @@ -1,5 +1,5 @@ use ahash::AHashSet; -use std::hash::Hash; +use std::mem::MaybeUninit; type RecursionKey = ( // Identifier for the input object, e.g. the id() of a Python dict @@ -14,7 +14,7 @@ type RecursionKey = ( /// It's used in `validators/definition` to detect when a reference is reused within itself. #[derive(Debug, Clone, Default)] pub struct RecursionGuard { - ids: SmallContainer, + ids: RecursionStack, // depth could be a hashmap {validator_id => depth} but for simplicity and performance it's easier to just // use one number for all validators depth: u8, @@ -33,10 +33,10 @@ pub const RECURSION_GUARD_LIMIT: u8 = if cfg!(any(target_family = "wasm", all(wi impl RecursionGuard { // insert a new value - // * return `None` if the array/set already had it in it - // * return `Some(index)` if the array didn't have it in it and it was inserted - pub fn contains_or_insert(&mut self, obj_id: usize, node_id: usize) -> Option { - self.ids.contains_or_insert((obj_id, node_id)) + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, obj_id: usize, node_id: usize) -> bool { + self.ids.insert((obj_id, node_id)) } // see #143 this is used as a backup in case the identity check recursion guard fails @@ -68,8 +68,8 @@ impl RecursionGuard { self.depth = self.depth.saturating_sub(1); } - pub fn 
remove(&mut self, obj_id: usize, node_id: usize, index: usize) { - self.ids.remove(&(obj_id, node_id), index); + pub fn remove(&mut self, obj_id: usize, node_id: usize) { + self.ids.remove(&(obj_id, node_id)); } } @@ -77,63 +77,67 @@ impl RecursionGuard { const ARRAY_SIZE: usize = 16; #[derive(Debug, Clone)] -enum SmallContainer { - Array([Option; ARRAY_SIZE]), - Set(AHashSet), +enum RecursionStack { + Array { + data: [MaybeUninit; ARRAY_SIZE], + len: usize, + }, + Set(AHashSet), } -impl Default for SmallContainer { +impl Default for RecursionStack { fn default() -> Self { - Self::Array([None; ARRAY_SIZE]) + Self::Array { + data: std::array::from_fn(|_| MaybeUninit::uninit()), + len: 0, + } } } -impl SmallContainer { +impl RecursionStack { // insert a new value - // * return `None` if the array/set already had it in it - // * return `Some(index)` if the array didn't have it in it and it was inserted - pub fn contains_or_insert(&mut self, v: T) -> Option { + // * return `false` if the stack already had it in it + // * return `true` if the stack didn't have it in it and it was inserted + pub fn insert(&mut self, v: RecursionKey) -> bool { match self { - Self::Array(array) => { - for (index, op_value) in array.iter_mut().enumerate() { - if let Some(existing) = op_value { - if existing == &v { - return None; + Self::Array { data, len } => { + if *len < ARRAY_SIZE { + for value in data.iter().take(*len) { + // Safety: reading values within bounds + if unsafe { value.assume_init() } == v { + return false; } - } else { - *op_value = Some(v); - return Some(index); } - } - // No array slots exist; convert to set - let mut set: AHashSet = AHashSet::with_capacity(ARRAY_SIZE + 1); - for existing in array.iter_mut() { - set.insert(existing.take().unwrap()); + data[*len].write(v); + *len += 1; + true + } else { + let mut set = AHashSet::with_capacity(ARRAY_SIZE + 1); + for existing in data.iter() { + // Safety: the array is fully initialized + set.insert(unsafe { 
existing.assume_init() }); + } + let inserted = set.insert(v); + *self = Self::Set(set); + inserted } - set.insert(v); - *self = Self::Set(set); - // id doesn't matter here as we'll be removing from a set - Some(0) } // https://doc.rust-lang.org/std/collections/struct.HashSet.html#method.insert // "If the set did not have this value present, `true` is returned." - Self::Set(set) => { - if set.insert(v) { - // again id doesn't matter here as we'll be removing from a set - Some(0) - } else { - None - } - } + Self::Set(set) => set.insert(v), } } - pub fn remove(&mut self, v: &T, index: usize) { + pub fn remove(&mut self, v: &RecursionKey) { match self { - Self::Array(array) => { - debug_assert!(array[index].as_ref() == Some(v), "remove did not match insert"); - array[index] = None; + Self::Array { data, len } => { + *len = len.checked_sub(1).expect("remove from empty recursion guard"); + assert!( + // Safety: this is reading the back of the initialized array + unsafe { data[*len].assume_init() } == *v, + "remove did not match insert" + ); } Self::Set(set) => { set.remove(v); @@ -141,3 +145,16 @@ impl SmallContainer { } } } + +impl Drop for RecursionStack { + fn drop(&mut self) { + // This should compile away to a noop as RecursionKey doesn't implement Drop, but it seemed + // desirable to leave this in for safety in case that should change in the future + if let Self::Array { data, len } = self { + for value in data.iter_mut().take(*len) { + // Safety: reading values within bounds + unsafe { std::ptr::drop_in_place(value.as_mut_ptr()) }; + } + } + } +} diff --git a/src/serializers/extra.rs b/src/serializers/extra.rs index f9a560b78..b3978613a 100644 --- a/src/serializers/extra.rs +++ b/src/serializers/extra.rs @@ -345,24 +345,24 @@ pub struct SerRecursionGuard { } impl SerRecursionGuard { - pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult<(usize, usize)> { + pub fn add(&self, value: &PyAny, def_ref_id: usize) -> PyResult { let id = value.as_ptr() as 
usize; let mut guard = self.guard.borrow_mut(); - if let Some(insert_index) = guard.contains_or_insert(id, def_ref_id) { + if guard.insert(id, def_ref_id) { if guard.incr_depth() { Err(PyValueError::new_err("Circular reference detected (depth exceeded)")) } else { - Ok((id, insert_index)) + Ok(id) } } else { Err(PyValueError::new_err("Circular reference detected (id repeated)")) } } - pub fn pop(&self, id: usize, def_ref_id: usize, insert_index: usize) { + pub fn pop(&self, id: usize, def_ref_id: usize) { let mut guard = self.guard.borrow_mut(); guard.decr_depth(); - guard.remove(id, def_ref_id, insert_index); + guard.remove(id, def_ref_id); } } diff --git a/src/serializers/infer.rs b/src/serializers/infer.rs index a42950d5b..13c20062b 100644 --- a/src/serializers/infer.rs +++ b/src/serializers/infer.rs @@ -45,7 +45,7 @@ pub(crate) fn infer_to_python_known( extra: &Extra, ) -> PyResult { let py = value.py(); - let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { + let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID) { Ok(id) => id, Err(e) => { return match extra.mode { @@ -226,7 +226,7 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serialize_unknown(value).into_py(py) @@ -284,7 +284,7 @@ pub(crate) fn infer_to_python_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,))?; let next_result = infer_to_python(next_value, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } value.into_py(py) @@ -292,7 +292,7 @@ pub(crate) fn infer_to_python_known( _ => 
value.into_py(py), }, }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); Ok(value) } @@ -351,7 +351,7 @@ pub(crate) fn infer_serialize_known( exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let (value_id, insert_index) = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { + let value_id = match extra.rec_guard.add(value, INFER_DEF_REF_ID).map_err(py_err_se_err) { Ok(v) => v, Err(e) => { return if extra.serialize_unknown { @@ -534,7 +534,7 @@ pub(crate) fn infer_serialize_known( if let Some(fallback) = extra.fallback { let next_value = fallback.call1((value,)).map_err(py_err_se_err)?; let next_result = infer_serialize(next_value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); return next_result; } else if extra.serialize_unknown { serializer.serialize_str(&serialize_unknown(value)) @@ -548,7 +548,7 @@ pub(crate) fn infer_serialize_known( } } }; - extra.rec_guard.pop(value_id, INFER_DEF_REF_ID, insert_index); + extra.rec_guard.pop(value_id, INFER_DEF_REF_ID); ser_result } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index 4a3017464..99dae5bcd 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -70,9 +70,9 @@ impl TypeSerializer for DefinitionRefSerializer { ) -> PyResult { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let (value_id, insert_index) = extra.rec_guard.add(value, self.definition.id())?; + let value_id = extra.rec_guard.add(value, self.definition.id())?; let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id(), insert_index); + extra.rec_guard.pop(value_id, self.definition.id()); r }) } @@ -91,12 +91,12 @@ impl 
TypeSerializer for DefinitionRefSerializer { ) -> Result { self.definition.read(|comb_serializer| { let comb_serializer = comb_serializer.unwrap(); - let (value_id, insert_index) = extra + let value_id = extra .rec_guard .add(value, self.definition.id()) .map_err(py_err_se_err)?; let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id(), insert_index); + extra.rec_guard.pop(value_id, self.definition.id()); r }) } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 7a7cf7b3c..e8c67a690 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -76,12 +76,12 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = input.identity() { - if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); } let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id(), insert_index); + state.recursion_guard.remove(id, self.definition.id()); state.recursion_guard.decr_depth(); output } else { @@ -105,12 +105,12 @@ impl Validator for DefinitionRefValidator { self.definition.read(|validator| { let validator = validator.unwrap(); if let Some(id) = obj.identity() { - if let Some(insert_index) = state.recursion_guard.contains_or_insert(id, self.definition.id()) { + if state.recursion_guard.insert(id, self.definition.id()) { if state.recursion_guard.incr_depth() { return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); } let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id(), insert_index); + state.recursion_guard.remove(id, 
self.definition.id()); state.recursion_guard.decr_depth(); output } else {