diff --git a/src/definitions.rs b/src/definitions.rs index 4627fd2d1..46a77196d 100644 --- a/src/definitions.rs +++ b/src/definitions.rs @@ -8,7 +8,7 @@ use std::{ fmt::Debug, sync::{ atomic::{AtomicBool, Ordering}, - Arc, OnceLock, + Arc, OnceLock, Weak, }, }; @@ -28,47 +28,50 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse}; /// They get indexed by a ReferenceId, which are integer identifiers /// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer} /// gets build. -#[derive(Clone)] pub struct Definitions(AHashMap, Definition>); -/// Internal type which contains a definition to be filled -pub struct Definition(Arc>); - -struct DefinitionInner { - value: OnceLock, - name: LazyName, +struct Definition { + value: Arc>, + name: Arc, } /// Reference to a definition. pub struct DefinitionRef { - name: Arc, - value: Definition, + reference: Arc, + // We use a weak reference to the definition to avoid a reference cycle + // when recursive definitions are used. + value: Weak>, + name: Arc, } // DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone) impl Clone for DefinitionRef { fn clone(&self) -> Self { Self { - name: self.name.clone(), + reference: self.reference.clone(), value: self.value.clone(), + name: self.name.clone(), } } } impl DefinitionRef { pub fn id(&self) -> usize { - Arc::as_ptr(&self.value.0) as usize + Weak::as_ptr(&self.value) as usize } pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str { - match self.value.0.value.get() { - Some(value) => self.value.0.name.get_or_init(|| init(value)), + let Some(definition) = self.value.upgrade() else { + return "..."; + }; + match definition.get() { + Some(value) => self.name.get_or_init(|| init(value)), None => "...", } } - pub fn get(&self) -> Option<&T> { - self.value.0.value.get() + pub fn read(&self, f: impl FnOnce(Option<&T>) -> R) -> R { + f(self.value.upgrade().as_ref().and_then(|value| value.get())) } } @@ -96,15 +99,9 @@ impl Debug for Definitions { } } -impl Clone for Definition { - fn clone(&self) -> Self { - Self(self.0.clone()) - } -} - impl Debug for Definition { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.0.value.get() { + match self.value.get() { Some(value) => value.fmt(f), None => "...".fmt(f), } @@ -113,7 +110,7 @@ impl Debug for Definition { impl PyGcTraverse for DefinitionRef { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { - if let Some(value) = self.value.0.value.get() { + if let Some(value) = self.value.upgrade().as_ref().and_then(|v| v.get()) { value.py_gc_traverse(visit)?; } Ok(()) @@ -123,7 +120,7 @@ impl PyGcTraverse for DefinitionRef { impl PyGcTraverse for Definitions { fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> { for value in self.0.values() { - if let Some(value) = value.0.value.get() { + if let Some(value) = value.value.get() { value.py_gc_traverse(visit)?; } } @@ -131,7 +128,7 @@ impl PyGcTraverse for Definitions { } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct DefinitionsBuilder { definitions: Definitions, } @@ -148,45 +145,48 @@ impl DefinitionsBuilder { // We either need a String copy or two hashmap lookups // Neither is better than the other // We opted for the easier outward facing API - let name = Arc::new(reference.to_string()); - let value = match self.definitions.0.entry(name.clone()) { + let reference = Arc::new(reference.to_string()); + let value = match self.definitions.0.entry(reference.clone()) { Entry::Occupied(entry) => entry.into_mut(), - Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner { - value: OnceLock::new(), - name: LazyName::new(), - }))), + Entry::Vacant(entry) => entry.insert(Definition { + value: Arc::new(OnceLock::new()), + name: Arc::new(LazyName::new()), + }), }; DefinitionRef { - name, - value: value.clone(), + reference, + value: Arc::downgrade(&value.value), + name: value.name.clone(), } } /// Add a definition, returning the ReferenceId that maps to it pub fn add_definition(&mut self, reference: String, value: T) -> PyResult> { - let name = Arc::new(reference); - let value = match self.definitions.0.entry(name.clone()) { + let reference = Arc::new(reference); + let value = match self.definitions.0.entry(reference.clone()) { Entry::Occupied(entry) => { let definition = entry.into_mut(); - match definition.0.value.set(value) { - Ok(()) => definition.clone(), - Err(_) => return py_schema_err!("Duplicate ref: `{}`", name), + match definition.value.set(value) { + Ok(()) => definition, + Err(_) => return py_schema_err!("Duplicate ref: `{}`", reference), } } - Entry::Vacant(entry) => entry - .insert(Definition(Arc::new(DefinitionInner { - value: OnceLock::from(value), - name: LazyName::new(), - }))) - .clone(), + Entry::Vacant(entry) => entry.insert(Definition { + value: Arc::new(OnceLock::from(value)), + name: Arc::new(LazyName::new()), + }), }; - Ok(DefinitionRef { name, value }) + Ok(DefinitionRef { + reference, + value: Arc::downgrade(&value.value), + name: value.name.clone(), + }) } /// Consume this Definitions into a vector of items, indexed by each items ReferenceId pub fn finish(self) -> PyResult> { for (reference, def) in &self.definitions.0 { - if def.0.value.get().is_none() { + if def.value.get().is_none() { return py_schema_err!("Definitions error: definition `{}` was never filled", reference); } } diff --git a/src/serializers/type_serializers/definitions.rs b/src/serializers/type_serializers/definitions.rs index b7bf63365..99dae5bcd 100644 --- a/src/serializers/type_serializers/definitions.rs +++ b/src/serializers/type_serializers/definitions.rs @@ -68,15 +68,17 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> PyResult { - let comb_serializer = self.definition.get().unwrap(); - let value_id = extra.rec_guard.add(value, self.definition.id())?; - let r = comb_serializer.to_python(value, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + self.definition.read(|comb_serializer| { + let comb_serializer = comb_serializer.unwrap(); + let value_id = extra.rec_guard.add(value, self.definition.id())?; + let r = comb_serializer.to_python(value, include, exclude, extra); + extra.rec_guard.pop(value_id, self.definition.id()); + r + }) } fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult> { - self.definition.get().unwrap().json_key(key, extra) + self.definition.read(|s| s.unwrap().json_key(key, extra)) } fn serde_serialize( @@ -87,14 +89,16 @@ impl TypeSerializer for DefinitionRefSerializer { exclude: Option<&PyAny>, extra: &Extra, ) -> Result { - let comb_serializer = self.definition.get().unwrap(); - let value_id = extra - .rec_guard - .add(value, self.definition.id()) - .map_err(py_err_se_err)?; - let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); - extra.rec_guard.pop(value_id, self.definition.id()); - r + self.definition.read(|comb_serializer| { + let comb_serializer = comb_serializer.unwrap(); + let value_id = extra + .rec_guard + .add(value, self.definition.id()) + .map_err(py_err_se_err)?; + let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra); + extra.rec_guard.pop(value_id, self.definition.id()); + r + }) } fn get_name(&self) -> &str { @@ -102,6 +106,6 @@ impl TypeSerializer for DefinitionRefSerializer { } fn retry_with_lax_check(&self) -> bool { - self.definition.get().unwrap().retry_with_lax_check() + self.definition.read(|s| s.unwrap().retry_with_lax_check()) } } diff --git a/src/validators/definitions.rs b/src/validators/definitions.rs index 7297bd27a..0b5f78c10 100644 --- a/src/validators/definitions.rs +++ b/src/validators/definitions.rs @@ -73,23 +73,25 @@ impl Validator for DefinitionRefValidator { input: &'data impl Input<'data>, state: &mut ValidationState, ) -> ValResult { - let validator = self.definition.get().unwrap(); - if let Some(id) = input.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) - } else { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + self.definition.read(|validator| { + let validator = validator.unwrap(); + if let Some(id) = input.identity() { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)) + } else { + if state.recursion_guard.incr_depth() { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input)); + } + let output = validator.validate(py, input, state); + state.recursion_guard.remove(id, self.definition.id()); + state.recursion_guard.decr_depth(); + output } - let output = validator.validate(py, input, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output + } else { + validator.validate(py, input, state) } - } else { - validator.validate(py, input, state) - } + }) } fn validate_assignment<'data>( @@ -100,23 +102,25 @@ impl Validator for DefinitionRefValidator { field_value: &'data PyAny, state: &mut ValidationState, ) -> ValResult { - let validator = self.definition.get().unwrap(); - if let Some(id) = obj.identity() { - if state.recursion_guard.contains_or_insert(id, self.definition.id()) { - // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` - Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) - } else { - if state.recursion_guard.incr_depth() { - return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + self.definition.read(|validator| { + let validator = validator.unwrap(); + if let Some(id) = obj.identity() { + if state.recursion_guard.contains_or_insert(id, self.definition.id()) { + // we don't remove id here, we leave that to the validator which originally added id to `recursion_guard` + Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)) + } else { + if state.recursion_guard.incr_depth() { + return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj)); + } + let output = validator.validate_assignment(py, obj, field_name, field_value, state); + state.recursion_guard.remove(id, self.definition.id()); + state.recursion_guard.decr_depth(); + output } - let output = validator.validate_assignment(py, obj, field_name, field_value, state); - state.recursion_guard.remove(id, self.definition.id()); - state.recursion_guard.decr_depth(); - output + } else { + validator.validate_assignment(py, obj, field_name, field_value, state) } - } else { - validator.validate_assignment(py, obj, field_name, field_value, state) - } + }) } fn get_name(&self) -> &str {