Skip to content

Commit

Permalink
Replace definitions Vec with OnceLock slots (#992)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Sep 28, 2023
1 parent 1a966d5 commit a8fb1e3
Show file tree
Hide file tree
Showing 54 changed files with 793 additions and 690 deletions.
268 changes: 204 additions & 64 deletions src/definitions.rs
Expand Up @@ -3,16 +3,20 @@
/// Unlike json schema we let you put definitions inline, not just in a single '#/$defs/' block or similar.
/// We use DefinitionsBuilder to collect the references / definitions into a single vector
/// and then get a definition from a reference using an integer id (just for performance of not using a HashMap)
use std::collections::hash_map::Entry;
use std::{
collections::hash_map::Entry,
fmt::Debug,
sync::{
atomic::{AtomicBool, Ordering},
Arc, OnceLock,
},
};

use pyo3::prelude::*;
use pyo3::{prelude::*, PyTraverseError, PyVisit};

use ahash::AHashMap;

use crate::build_tools::py_schema_err;

// An integer id for the reference
pub type ReferenceId = usize;
use crate::{build_tools::py_schema_err, py_gc::PyGcTraverse};

/// Definitions are validators and serializers that are
/// shared by reference.
Expand All @@ -24,91 +28,227 @@ pub type ReferenceId = usize;
/// They get indexed by a ReferenceId, which are integer identifiers
/// that are handed out and managed by DefinitionsBuilder when the Schema{Validator,Serializer}
/// gets built.
pub type Definitions<T> = [T];
// Map from reference name to the write-once slot holding that definition.
#[derive(Clone)]
pub struct Definitions<T>(AHashMap<Arc<String>, Definition<T>>);

#[derive(Clone, Debug)]
struct Definition<T> {
pub id: ReferenceId,
pub value: Option<T>,
impl<T> Definitions<T> {
pub fn values(&self) -> impl Iterator<Item = &Definition<T>> {
self.0.values()
}
}

/// Internal type which contains a definition to be filled.
///
/// This is a handle around an `Arc`, so clones share the same underlying slot.
pub struct Definition<T>(Arc<DefinitionInner<T>>);

impl<T> Definition<T> {
pub fn get(&self) -> Option<&T> {
self.0.value.get()
}
}

/// Shared state behind a `Definition`; every clone of the handle points at one instance.
struct DefinitionInner<T> {
/// Write-once slot for the definition itself; empty until the definition is added.
value: OnceLock<T>,
/// Name computed on first request from the filled value and cached afterwards.
name: LazyName,
}

/// Reference to a definition.
///
/// Pairs the reference name with a handle to the (possibly not yet filled) definition slot.
pub struct DefinitionRef<T> {
/// Reference string this definition was registered or looked up under.
name: Arc<String>,
/// Handle to the shared definition slot.
value: Definition<T>,
}

// DefinitionRef can always be cloned (#[derive(Clone)] would require T: Clone)
impl<T> Clone for DefinitionRef<T> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
value: self.value.clone(),
}
}
}

impl<T> DefinitionRef<T> {
    /// Identifier for the referenced definition: the address of its shared allocation,
    /// so every reference to the same definition yields the same value.
    pub fn id(&self) -> usize {
        let inner_ptr = Arc::as_ptr(&self.value.0);
        inner_ptr as usize
    }

    /// Returns the cached name, computing it with `init` from the stored value when needed.
    /// Yields `"..."` while the definition has not been filled yet.
    pub fn get_or_init_name(&self, init: impl FnOnce(&T) -> String) -> &str {
        let inner = &self.value.0;
        let Some(filled) = inner.value.get() else {
            return "...";
        };
        inner.name.get_or_init(|| init(filled))
    }

    /// The definition's value, or `None` if it has not been filled yet.
    pub fn get(&self) -> Option<&T> {
        let inner = &self.value.0;
        inner.value.get()
    }
}

impl<T: Debug> Debug for DefinitionRef<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the name is printed: following the reference into its value could
        // recurse forever when definitions refer to themselves.
        Debug::fmt(&self.name, f)
    }
}

impl<T: Debug> Debug for Definitions<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Rendered as a list rather than a map to keep the output backwards
        // compatible; a future minor release of pydantic may switch to a map.
        write!(f, "[")?;
        for (index, definition) in self.0.values().enumerate() {
            if index > 0 {
                write!(f, ", ")?;
            }
            write!(f, "{definition:?}")?;
        }
        write!(f, "]")
    }
}

impl<T> Clone for Definition<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<T: Debug> Debug for Definition<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A filled slot prints its value; an empty one prints the placeholder string.
        if let Some(filled) = self.0.value.get() {
            Debug::fmt(filled, f)
        } else {
            Debug::fmt("...", f)
        }
    }
}

impl<T: PyGcTraverse> PyGcTraverse for DefinitionRef<T> {
    /// Visits the referenced value when the slot is filled; an empty slot has nothing to visit.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        match self.value.0.value.get() {
            Some(filled) => filled.py_gc_traverse(visit),
            None => Ok(()),
        }
    }
}

impl<T: PyGcTraverse> PyGcTraverse for Definitions<T> {
    /// Visits every filled definition, stopping at the first traversal error.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        self.0
            .values()
            .filter_map(|slot| slot.0.value.get())
            .try_for_each(|filled| filled.py_gc_traverse(visit))
    }
}

#[derive(Clone, Debug)]
pub struct DefinitionsBuilder<T> {
definitions: AHashMap<String, Definition<T>>,
definitions: Definitions<T>,
}

impl<T: Clone + std::fmt::Debug> DefinitionsBuilder<T> {
impl<T: std::fmt::Debug> DefinitionsBuilder<T> {
pub fn new() -> Self {
Self {
definitions: AHashMap::new(),
definitions: Definitions(AHashMap::new()),
}
}

/// Get a ReferenceId for the given reference string.
// This ReferenceId can later be used to retrieve a definition
pub fn get_reference_id(&mut self, reference: &str) -> ReferenceId {
let next_id = self.definitions.len();
pub fn get_definition(&mut self, reference: &str) -> DefinitionRef<T> {
// We either need a String copy or two hashmap lookups
// Neither is better than the other
// We opted for the easier outward facing API
match self.definitions.entry(reference.to_string()) {
Entry::Occupied(entry) => entry.get().id,
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: None,
});
next_id
}
let name = Arc::new(reference.to_string());
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => entry.into_mut(),
Entry::Vacant(entry) => entry.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::new(),
name: LazyName::new(),
}))),
};
DefinitionRef {
name,
value: value.clone(),
}
}

/// Add a definition, returning the ReferenceId that maps to it
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<ReferenceId> {
let next_id = self.definitions.len();
match self.definitions.entry(reference.clone()) {
Entry::Occupied(mut entry) => match entry.get_mut().value.replace(value) {
Some(_) => py_schema_err!("Duplicate ref: `{}`", reference),
None => Ok(entry.get().id),
},
Entry::Vacant(entry) => {
entry.insert(Definition {
id: next_id,
value: Some(value),
});
Ok(next_id)
pub fn add_definition(&mut self, reference: String, value: T) -> PyResult<DefinitionRef<T>> {
let name = Arc::new(reference);
let value = match self.definitions.0.entry(name.clone()) {
Entry::Occupied(entry) => {
let definition = entry.into_mut();
match definition.0.value.set(value) {
Ok(()) => definition.clone(),
Err(_) => return py_schema_err!("Duplicate ref: `{}`", name),
}
}
Entry::Vacant(entry) => entry
.insert(Definition(Arc::new(DefinitionInner {
value: OnceLock::from(value),
name: LazyName::new(),
})))
.clone(),
};
Ok(DefinitionRef { name, value })
}

/// Consume this builder and return the collected definitions, erroring if any referenced definition was never filled
pub fn finish(self) -> PyResult<Definitions<T>> {
for (reference, def) in &self.definitions.0 {
if def.0.value.get().is_none() {
return py_schema_err!("Definitions error: definition `{}` was never filled", reference);
}
}
Ok(self.definitions)
}
}

/// Retrieve an item definition using a ReferenceId
/// If the definition doesn't yet exist (as happens in recursive types) then we create it
/// At the end (in finish()) we check that there are no undefined definitions
pub fn get_definition(&self, reference_id: ReferenceId) -> PyResult<&T> {
let (reference, def) = match self.definitions.iter().find(|(_, def)| def.id == reference_id) {
Some(v) => v,
None => return py_schema_err!("Definitions error: no definition for ReferenceId `{}`", reference_id),
};
match def.value.as_ref() {
Some(v) => Ok(v),
None => py_schema_err!(
"Definitions error: attempted to use `{}` before it was filled",
reference
),
/// Write-once name cache with a flag used to detect re-entrant initialisation.
struct LazyName {
/// The name, once it has been computed.
initialized: OnceLock<String>,
/// Set while `get_or_init` is running its initialiser; a call that finds it already set returns "...".
in_recursion: AtomicBool,
}

impl LazyName {
fn new() -> Self {
Self {
initialized: OnceLock::new(),
in_recursion: AtomicBool::new(false),
}
}

/// Consume this Definitions into a vector of items, indexed by each items ReferenceId
pub fn finish(self) -> PyResult<Vec<T>> {
// We need to create a vec of defs according to the order in their ids
let mut defs: Vec<(usize, T)> = Vec::new();
for (reference, def) in self.definitions {
match def.value {
None => return py_schema_err!("Definitions error: definition {} was never filled", reference),
Some(v) => defs.push((def.id, v)),
}
/// Gets the validator name, returning the default in the case of recursion loops
fn get_or_init(&self, init: impl FnOnce() -> String) -> &str {
if let Some(s) = self.initialized.get() {
return s.as_str();
}

if self
.in_recursion
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return "...";
}
let result = self.initialized.get_or_init(init).as_str();
self.in_recursion.store(false, Ordering::SeqCst);
result
}
}

impl Debug for LazyName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show the cached name if there is one, otherwise the placeholder.
        let shown = match self.initialized.get() {
            Some(name) => name.as_str(),
            None => "...",
        };
        Debug::fmt(shown, f)
    }
}

impl Clone for LazyName {
fn clone(&self) -> Self {
Self {
initialized: OnceLock::new(),
in_recursion: AtomicBool::new(false),
}
defs.sort_by_key(|(id, _)| *id);
Ok(defs.into_iter().map(|(_, v)| v).collect())
}
}
8 changes: 8 additions & 0 deletions src/py_gc.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
use pyo3::{AsPyPointer, Py, PyTraverseError, PyVisit};
Expand Down Expand Up @@ -35,6 +37,12 @@ impl<T: PyGcTraverse> PyGcTraverse for AHashMap<String, T> {
}
}

impl<T: PyGcTraverse> PyGcTraverse for Arc<T> {
    /// Delegates traversal to the value behind the `Arc`.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        (**self).py_gc_traverse(visit)
    }
}

impl<T: PyGcTraverse> PyGcTraverse for Box<T> {
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
T::py_gc_traverse(self, visit)
Expand Down
9 changes: 0 additions & 9 deletions src/serializers/extra.rs
Expand Up @@ -10,8 +10,6 @@ use serde::ser::Error;
use super::config::SerializationConfig;
use super::errors::{PydanticSerializationUnexpectedValue, UNEXPECTED_TYPE_SER_MARKER};
use super::ob_type::ObTypeLookup;
use super::shared::CombinedSerializer;
use crate::definitions::Definitions;
use crate::recursion_guard::RecursionGuard;

/// this is ugly, would be much better if extra could be stored in `SerializationState`
Expand Down Expand Up @@ -48,7 +46,6 @@ impl SerializationState {
Extra::new(
py,
mode,
&[],
by_alias,
&self.warnings,
false,
Expand All @@ -72,7 +69,6 @@ impl SerializationState {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct Extra<'a> {
pub mode: &'a SerMode,
pub definitions: &'a Definitions<CombinedSerializer>,
pub ob_type_lookup: &'a ObTypeLookup,
pub warnings: &'a CollectWarnings,
pub by_alias: bool,
Expand All @@ -98,7 +94,6 @@ impl<'a> Extra<'a> {
pub fn new(
py: Python<'a>,
mode: &'a SerMode,
definitions: &'a Definitions<CombinedSerializer>,
by_alias: bool,
warnings: &'a CollectWarnings,
exclude_unset: bool,
Expand All @@ -112,7 +107,6 @@ impl<'a> Extra<'a> {
) -> Self {
Self {
mode,
definitions,
ob_type_lookup: ObTypeLookup::cached(py),
warnings,
by_alias,
Expand Down Expand Up @@ -156,7 +150,6 @@ impl SerCheck {
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) struct ExtraOwned {
mode: SerMode,
definitions: Vec<CombinedSerializer>,
warnings: CollectWarnings,
by_alias: bool,
exclude_unset: bool,
Expand All @@ -176,7 +169,6 @@ impl ExtraOwned {
pub fn new(extra: &Extra) -> Self {
Self {
mode: extra.mode.clone(),
definitions: extra.definitions.to_vec(),
warnings: extra.warnings.clone(),
by_alias: extra.by_alias,
exclude_unset: extra.exclude_unset,
Expand All @@ -196,7 +188,6 @@ impl ExtraOwned {
pub fn to_extra<'py>(&'py self, py: Python<'py>) -> Extra<'py> {
Extra {
mode: &self.mode,
definitions: &self.definitions,
ob_type_lookup: ObTypeLookup::cached(py),
warnings: &self.warnings,
by_alias: self.by_alias,
Expand Down

0 comments on commit a8fb1e3

Please sign in to comment.