Skip to content

Commit

Permalink
fix memory leak with recursive definitions creating reference cycles
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Dec 19, 2023
1 parent bec63db commit f9d1e1a
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 92 deletions.
94 changes: 47 additions & 47 deletions src/definitions.rs
Expand Up @@ -8,7 +8,7 @@ use std::{
fmt::Debug,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
Arc, OnceLock, Weak,
},
};

Expand All @@ -28,47 +28,50 @@ use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};
/// They get indexed by a ReferenceId, which is an integer identifier
/// that is handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets built.
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

/// Internal type which contains a definition to be filled
pub struct Definition<T>(Arc<DefinitionInner<T>>);

struct DefinitionInner<T> {
value: OnceLock<T>,
name: LazyName,
struct Definition<T> {
value: Arc<OnceLock<T>>,
name: Arc<LazyName>,
}

/// Reference to a definition.
pub struct DefinitionRef<T> {
name: Arc<String>,
value: Definition<T>,
reference: Arc<String>,
// We use a weak reference to the definition to avoid a reference cycle
// when recursive definitions are used.
value: Weak<OnceLock<T>>,
name: Arc<LazyName>,
}

// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
impl<T> Clone for DefinitionRef<T> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
reference: self.reference.clone(),
value: self.value.clone(),
name: self.name.clone(),
}
}
}

impl<T> DefinitionRef<T> {
pub fn id(&self) -> usize {
Arc::as_ptr(&self.value.0) as usize
Weak::as_ptr(&self.value) as usize
}

pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
match self.value.0.value.get() {
Some(value) => self.value.0.name.get_or_init(|| init(value)),
let Some(definition) = self.value.upgrade() else {
return "...";
};
match definition.get() {
Some(value) => self.name.get_or_init(|| init(value)),
None => "...",
}
}

pub fn get(&self) -> Option<&T> {
self.value.0.value.get()
pub fn read<R>(&self, f: impl FnOnce(Option<&T>) -> R) -> R {
f(self.value.upgrade().as_ref().and_then(|value| value.get()))
}
}

Expand Down Expand Up @@ -96,15 +99,9 @@ impl<T: Debug> Debug for Definitions<T> {
}
}

impl<T> Clone for Definition<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T: Debug> Debug for Definition<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.0.value.get() {
match self.value.get() {
Some(value) => value.fmt(f),
None => "...".fmt(f),
}
Expand All @@ -113,7 +110,7 @@ impl<T: Debug> Debug for Definition<T> {

impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
if let Some(value) = self.value.0.value.get() {
if let Some(value) = self.value.upgrade().as_ref().and_then(|v| v.get()) {
value.py_gc_traverse(visit)?;
}
Ok(())
Expand All @@ -123,15 +120,15 @@ impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
for value in self.0.values() {
if let Some(value) = value.0.value.get() {
if let Some(value) = value.value.get() {
value.py_gc_traverse(visit)?;
}
}
Ok(())
}
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct DefinitionsBuilder<T> {
definitions: Definitions<T>,
}
Expand All @@ -148,45 +145,48 @@ impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
let name = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(name.clone()) {
let reference = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(reference.clone()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::new(),
name: LazyName::new(),
}))),
Entry::Vacant(entry) => entry.insert(Definition {
value: Arc::new(OnceLock::new()),
name: Arc::new(LazyName::new()),
}),
};
DefinitionRef {
name,
value: value.clone(),
reference,
value: Arc::downgrade(&value.value),
name: value.name.clone(),
}
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
let name = Arc::new(reference);
let value = match self.definitions.0.entry(name.clone()) {
let reference = Arc::new(reference);
let value = match self.definitions.0.entry(reference.clone()) {
Entry::Occupied(entry) => {
let definition = entry.into_mut();
match definition.0.value.set(value) {
Ok(()) => definition.clone(),
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
match definition.value.set(value) {
Ok(()) => definition,
Err(_) => return py_schema_err!("Duplicate ref: `{}`", reference),
}
}
Entry::Vacant(entry) => entry
.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::from(value),
name: LazyName::new(),
})))
.clone(),
Entry::Vacant(entry) => entry.insert(Definition {
value: Arc::new(OnceLock::from(value)),
name: Arc::new(LazyName::new()),
}),
};
Ok(DefinitionRef { name, value })
Ok(DefinitionRef {
reference,
value: Arc::downgrade(&value.value),
name: value.name.clone(),
})
}

/// Consume this Definitions into a vector of items, indexed by each item's ReferenceId
pub fn finish(self) -> PyResult<Definitions<T>> {
for (reference, def) in &self.definitions.0 {
if def.0.value.get().is_none() {
if def.value.get().is_none() {
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
}
}
Expand Down
34 changes: 19 additions & 15 deletions src/serializers/type_serializers/definitions.rs
Expand Up @@ -68,15 +68,17 @@ impl TypeSerializer for DefinitionRefSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let comb_serializer = self.definition.get().unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra.rec_guard.add(value, self.definition.id())?;
let r = comb_serializer.to_python(value, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
})
}

fn json_key<'py>(&self, key: &'py PyAny, extra: &Extra) -> PyResult<Cow<'py, str>> {
self.definition.get().unwrap().json_key(key, extra)
self.definition.read(|s| s.unwrap().json_key(key, extra))
}

fn serde_serialize<S: serde::ser::Serializer>(
Expand All @@ -87,21 +89,23 @@ impl TypeSerializer for DefinitionRefSerializer {
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let comb_serializer = self.definition.get().unwrap();
let value_id = extra
.rec_guard
.add(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
self.definition.read(|comb_serializer| {
let comb_serializer = comb_serializer.unwrap();
let value_id = extra
.rec_guard
.add(value, self.definition.id())
.map_err(py_err_se_err)?;
let r = comb_serializer.serde_serialize(value, serializer, include, exclude, extra);
extra.rec_guard.pop(value_id, self.definition.id());
r
})
}

fn get_name(&self) -> &str {
Self::EXPECTED_TYPE
}

fn retry_with_lax_check(&self) -> bool {
self.definition.get().unwrap().retry_with_lax_check()
self.definition.read(|s| s.unwrap().retry_with_lax_check())
}
}
64 changes: 34 additions & 30 deletions src/validators/definitions.rs
Expand Up @@ -73,23 +73,25 @@ impl Validator for DefinitionRefValidator {
input: &'data impl Input<'data>,
state: &mut ValidationState,
) -> ValResult<PyObject> {
let validator = self.definition.get().unwrap();
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = input.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
}
let output = validator.validate(py, input, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
validator.validate(py, input, state)
}
} else {
validator.validate(py, input, state)
}
})
}

fn validate_assignment<'data>(
Expand All @@ -100,23 +102,25 @@ impl Validator for DefinitionRefValidator {
field_value: &'data PyAny,
state: &mut ValidationState,
) -> ValResult<PyObject> {
let validator = self.definition.get().unwrap();
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
self.definition.read(|validator| {
let validator = validator.unwrap();
if let Some(id) = obj.identity() {
if state.recursion_guard.contains_or_insert(id, self.definition.id()) {
// we don't remove id here, we leave that to the validator which originally added id to `recursion_guard`
Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj))
} else {
if state.recursion_guard.incr_depth() {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, obj));
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
}
let output = validator.validate_assignment(py, obj, field_name, field_value, state);
state.recursion_guard.remove(id, self.definition.id());
state.recursion_guard.decr_depth();
output
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
}
} else {
validator.validate_assignment(py, obj, field_name, field_value, state)
}
})
}

fn get_name(&self) -> &str {
Expand Down

0 comments on commit f9d1e1a

Please sign in to comment.