From ddb3773d062369e120d09c9a0a7909faf29d8fe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20d=27Herbais=20de=20Thun?= Date: Fri, 15 Dec 2023 11:47:53 +0100 Subject: [PATCH] Parallel comemo & optimizations (#5) --- .github/workflows/ci.yml | 4 +- .gitignore | 1 + Cargo.toml | 14 ++ macros/src/memoize.rs | 30 ++- macros/src/track.rs | 83 +++++--- src/accelerate.rs | 63 ++++++ src/cache.rs | 404 ++++++++++++--------------------------- src/constraint.rs | 301 +++++++++++++++++++++++++++++ src/input.rs | 2 +- src/lib.rs | 12 +- src/track.rs | 15 +- tests/tests.rs | 91 +++++---- 12 files changed, 648 insertions(+), 372 deletions(-) create mode 100644 src/accelerate.rs create mode 100644 src/constraint.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9e6f0a2..98b71ca 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,5 +7,5 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo build - - run: cargo test + - run: cargo build --all-features + - run: cargo test --all-features diff --git a/.gitignore b/.gitignore index 360ab70..dbb723c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .vscode .DS_Store /target +macros/target Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 2bad481..4cd252f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,20 @@ license = "MIT OR Apache-2.0" categories = ["caching"] keywords = ["incremental", "memoization", "tracked", "constraints"] +[features] +default = [] +testing = [] + [dependencies] comemo-macros = { version = "0.3.1", path = "macros" } +once_cell = "1.18" +parking_lot = "0.12" siphasher = "1" + +[dev-dependencies] +serial_test = "2.0.0" + +[[test]] +name = "tests" +path = "tests/tests.rs" +required-features = ["testing"] diff --git a/macros/src/memoize.rs b/macros/src/memoize.rs index 0e24f82..09afbce 100644 --- a/macros/src/memoize.rs +++ b/macros/src/memoize.rs @@ -7,7 +7,7 @@ pub fn expand(item: &syn::Item) -> Result { }; // Preprocess and validate the function. - let function = prepare(&item)?; + let function = prepare(item)?; // Rewrite the function's body to memoize it. process(&function) @@ -23,7 +23,7 @@ struct Function { /// An argument to a memoized function. enum Argument { Receiver(syn::Token![self]), - Ident(Option, syn::Ident), + Ident(Box, Option, syn::Ident), } /// Preprocess and validate a function. @@ -71,7 +71,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result { bail!(typed.ty, "memoized functions cannot have mutable parameters") } - Argument::Ident(mutability.clone(), ident.clone()) + Argument::Ident(typed.ty.clone(), *mutability, ident.clone()) } }) } @@ -82,7 +82,7 @@ fn process(function: &Function) -> Result { let bounds = function.args.iter().map(|arg| { let val = match arg { Argument::Receiver(token) => quote! { #token }, - Argument::Ident(_, ident) => quote! { #ident }, + Argument::Ident(_, _, ident) => quote! { #ident }, }; quote_spanned! { function.item.span() => ::comemo::internal::assert_hashable_or_trackable(&#val); @@ -94,14 +94,20 @@ fn process(function: &Function) -> Result { Argument::Receiver(token) => quote! { ::comemo::internal::hash(&#token) }, - Argument::Ident(_, ident) => quote! { #ident }, + Argument::Ident(_, _, ident) => quote! { #ident }, }); let arg_tuple = quote! { (#(#args,)*) }; + let arg_tys = function.args.iter().map(|arg| match arg { + Argument::Receiver(_) => quote! { () }, + Argument::Ident(ty, _, _) => quote! { #ty }, + }); + let arg_ty_tuple = quote! { (#(#arg_tys,)*) }; + // Construct a tuple for all parameters. let params = function.args.iter().map(|arg| match arg { Argument::Receiver(_) => quote! { _ }, - Argument::Ident(mutability, ident) => quote! { #mutability #ident }, + Argument::Ident(_, mutability, ident) => quote! { #mutability #ident }, }); let param_tuple = quote! { (#(#params,)*) }; @@ -118,14 +124,20 @@ fn process(function: &Function) -> Result { ident.mutability = None; } - let unique = quote! { __ComemoUnique }; wrapped.block = parse_quote! { { - struct #unique; + static __CACHE: ::comemo::internal::Cache< + <::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint, + #output, + > = ::comemo::internal::Cache::new(|| { + ::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age)); + ::core::default::Default::default() + }); + #(#bounds;)* ::comemo::internal::memoized( - ::core::any::TypeId::of::<#unique>(), ::comemo::internal::Args(#arg_tuple), &::core::default::Default::default(), + &__CACHE, #closure, ) } }; diff --git a/macros/src/track.rs b/macros/src/track.rs index 60e80ec..bfc90be 100644 --- a/macros/src/track.rs +++ b/macros/src/track.rs @@ -20,34 +20,38 @@ pub fn expand(item: &syn::Item) -> Result { } for item in &item.items { - methods.push(prepare_impl_method(&item)?); + methods.push(prepare_impl_method(item)?); } let ty = item.self_ty.as_ref().clone(); (ty, &item.generics, None) } syn::Item::Trait(item) => { - for param in item.generics.params.iter() { - bail!(param, "tracked traits cannot be generic") + if let Some(first) = item.generics.params.first() { + bail!(first, "tracked traits cannot be generic") } for item in &item.items { - methods.push(prepare_trait_method(&item)?); + methods.push(prepare_trait_method(item)?); } let name = &item.ident; let ty = parse_quote! { dyn #name + '__comemo_dynamic }; - (ty, &item.generics, Some(name.clone())) + (ty, &item.generics, Some(item.ident.clone())) } _ => bail!(item, "`track` can only be applied to impl blocks and traits"), }; // Produce the necessary items for the type to become trackable. + let variants = create_variants(&methods); let scope = create(&ty, generics, trait_, &methods)?; Ok(quote! { #item - const _: () = { #scope }; + const _: () = { + #variants + #scope + }; }) } @@ -175,6 +179,43 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result }) } +/// Produces the variants for the constraint. +fn create_variants(methods: &[Method]) -> TokenStream { + let variants = methods.iter().map(create_variant); + let is_mutable_variants = methods.iter().map(|m| { + let name = &m.sig.ident; + let mutable = m.mutable; + quote! { __ComemoVariant::#name(..) => #mutable } + }); + + let is_mutable = (!methods.is_empty()) + .then(|| { + quote! { + match &self.0 { + #(#is_mutable_variants),* + } + } + }) + .unwrap_or_else(|| quote! { false }); + + quote! { + #[derive(Clone, PartialEq, Hash)] + pub struct __ComemoCall(__ComemoVariant); + + impl ::comemo::internal::Call for __ComemoCall { + fn is_mutable(&self) -> bool { + #is_mutable + } + } + + #[derive(Clone, PartialEq, Hash)] + #[allow(non_camel_case_types)] + enum __ComemoVariant { + #(#variants,)* + } + } +} + /// Produce the necessary items for a type to become trackable. fn create( ty: &syn::Type, @@ -229,26 +270,32 @@ fn create( }; // Prepare replying. + let immutable = methods.iter().all(|m| !m.mutable); let replays = methods.iter().map(create_replay); - let replay = methods.iter().any(|m| m.mutable).then(|| { + let replay = (!immutable).then(|| { quote! { constraint.replay(|call| match &call.0 { #(#replays,)* }); } }); // Prepare variants and wrapper methods. - let variants = methods.iter().map(create_variant); let wrapper_methods = methods .iter() .filter(|m| !m.mutable) .map(|m| create_wrapper(m, false)); let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true)); + let constraint = if immutable { + quote! { ImmutableConstraint } + } else { + quote! { MutableConstraint } + }; + Ok(quote! { - impl #impl_params ::comemo::Track for #ty #where_clause {} + impl #impl_params ::comemo::Track for #ty #where_clause {} - impl #impl_params ::comemo::Validate for #ty #where_clause { - type Constraint = ::comemo::internal::Constraint<__ComemoCall>; + impl #impl_params ::comemo::Validate for #ty #where_clause { + type Constraint = ::comemo::internal::#constraint<__ComemoCall>; #[inline] fn validate(&self, constraint: &Self::Constraint) -> bool { @@ -267,15 +314,6 @@ fn create( } } - #[derive(Clone, PartialEq, Hash)] - pub struct __ComemoCall(__ComemoVariant); - - #[derive(Clone, PartialEq, Hash)] - #[allow(non_camel_case_types)] - enum __ComemoVariant { - #(#variants,)* - } - #[doc(hidden)] impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause { type Surface<#t> = __ComemoSurface #type_params_t where Self: #t; @@ -323,7 +361,6 @@ fn create( impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t { #(#wrapper_methods_mut)* } - }) } @@ -370,10 +407,9 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream { let vis = &method.vis; let sig = &method.sig; let args = &method.args; - let mutable = method.mutable; let to_parts = if !tracked_mut { quote! { to_parts_ref(self.0) } - } else if !mutable { + } else if !method.mutable { quote! { to_parts_mut_ref(&self.0) } } else { quote! { to_parts_mut_mut(&mut self.0) } @@ -389,7 +425,6 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream { constraint.push( __ComemoCall(__comemo_variant), ::comemo::internal::hash(&output), - #mutable, ); } output diff --git a/src/accelerate.rs b/src/accelerate.rs new file mode 100644 index 0000000..611d3a0 --- /dev/null +++ b/src/accelerate.rs @@ -0,0 +1,63 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard}; + +/// The global list of currently alive accelerators. +static ACCELERATORS: RwLock<(usize, Vec)> = RwLock::new((0, Vec::new())); + +/// The current ID of the accelerator. +static ID: AtomicUsize = AtomicUsize::new(0); + +/// The type of each individual accelerator. +/// +/// Maps from call hashes to return hashes. +type Accelerator = Mutex>; + +/// Generate a new accelerator. +pub fn id() -> usize { + // Get the next ID. + ID.fetch_add(1, Ordering::SeqCst) +} + +/// Evict the accelerators. +pub fn evict() { + let mut accelerators = ACCELERATORS.write(); + let (offset, vec) = &mut *accelerators; + + // Update the offset. + *offset = ID.load(Ordering::SeqCst); + + // Clear all accelerators while keeping the memory allocated. + vec.iter_mut().for_each(|accelerator| accelerator.lock().clear()) +} + +/// Get an accelerator by ID. +pub fn get(id: usize) -> Option> { + // We always lock the accelerators, as we need to make sure that the + // accelerator is not removed while we are reading it. + let mut accelerators = ACCELERATORS.read(); + + let mut i = id.checked_sub(accelerators.0)?; + if i >= accelerators.1.len() { + drop(accelerators); + resize(i + 1); + accelerators = ACCELERATORS.read(); + + // Because we release the lock before resizing the accelerator, we need + // to check again whether the ID is still valid because another thread + // might evicted the cache. + i = id.checked_sub(accelerators.0)?; + } + + Some(RwLockReadGuard::map(accelerators, move |(_, vec)| &vec[i])) +} + +/// Adjusts the amount of accelerators. +#[cold] +fn resize(len: usize) { + let mut pair = ACCELERATORS.write(); + if len > pair.1.len() { + pair.1.resize_with(len, || Mutex::new(HashMap::new())); + } +} diff --git a/src/cache.rs b/src/cache.rs index df7e702..b8fec40 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,31 +1,28 @@ -use std::any::{Any, TypeId}; -use std::cell::{Cell, RefCell}; use std::collections::HashMap; -use std::hash::Hash; +use std::sync::atomic::{AtomicUsize, Ordering}; +use once_cell::sync::Lazy; +use parking_lot::RwLock; use siphasher::sip128::{Hasher128, SipHasher13}; +use crate::accelerate; +use crate::constraint::Join; use crate::input::Input; -thread_local! { - /// The global, dynamic cache shared by all memoized functions. - static CACHE: RefCell = RefCell::new(Cache::default()); - - /// The global ID counter for tracked values. Each tracked value gets a - /// unqiue ID based on which its validations are cached in the accelerator. - /// IDs may only be reused upon eviction of the accelerator. - static ID: Cell = const { Cell::new(0) }; +/// The global list of eviction functions. +static EVICTORS: RwLock> = RwLock::new(Vec::new()); - /// The global, dynamic accelerator shared by all cached values. - static ACCELERATOR: RefCell> - = RefCell::new(HashMap::default()); +#[cfg(feature = "testing")] +thread_local! { + /// Whether the last call was a hit. + static LAST_WAS_HIT: std::cell::Cell = const { std::cell::Cell::new(false) }; } /// Execute a function or use a cached result for it. pub fn memoized<'c, In, Out, F>( - id: TypeId, mut input: In, constraint: &'c In::Constraint, + cache: &Cache, func: F, ) -> Out where @@ -33,61 +30,47 @@ where Out: Clone + 'static, F: FnOnce(In::Tracked<'c>) -> Out, { - CACHE.with(|cache| { - // Compute the hash of the input's key part. - let key = { - let mut state = SipHasher13::new(); - input.key(&mut state); - let hash = state.finish128().as_u128(); - (id, hash) - }; + // Compute the hash of the input's key part. + let key = { + let mut state = SipHasher13::new(); + input.key(&mut state); + state.finish128().as_u128() + }; - // Check if there is a cached output. - let mut borrow = cache.borrow_mut(); - if let Some(constrained) = borrow.lookup::(key, &input) { - // Replay the mutations. - input.replay(&constrained.constraint); + // Check if there is a cached output. + let borrow = cache.0.read(); + if let Some((constrained, value)) = borrow.lookup::(key, &input) { + // Replay the mutations. + input.replay(constrained); - // Add the cached constraints to the outer ones. - input.retrack(constraint).1.join(&constrained.constraint); + // Add the cached constraints to the outer ones. + input.retrack(constraint).1.join(constrained); - let value = constrained.output.clone(); - borrow.last_was_hit = true; - return value; - } + #[cfg(feature = "testing")] + LAST_WAS_HIT.with(|cell| cell.set(true)); - // Release the borrow so that nested memoized calls can access the - // cache without panicking. - drop(borrow); + return value.clone(); + } - // Execute the function with the new constraints hooked in. - let (input, outer) = input.retrack(constraint); - let output = func(input); + // Release the borrow so that nested memoized calls can access the + // cache without dead locking. + drop(borrow); - // Add the new constraints to the outer ones. - outer.join(constraint); + // Execute the function with the new constraints hooked in. + let (input, outer) = input.retrack(constraint); + let output = func(input); - // Insert the result into the cache. - borrow = cache.borrow_mut(); - borrow.insert::(key, constraint.take(), output.clone()); - borrow.last_was_hit = false; + // Add the new constraints to the outer ones. + outer.join(constraint); - output - }) -} + // Insert the result into the cache. + let mut borrow = cache.0.write(); + borrow.insert::(key, constraint.take(), output.clone()); -/// Whether the last call was a hit. -pub fn last_was_hit() -> bool { - CACHE.with(|cache| cache.borrow().last_was_hit) -} + #[cfg(feature = "testing")] + LAST_WAS_HIT.with(|cell| cell.set(false)); -/// Get the next ID. -pub fn id() -> usize { - ID.with(|cell| { - let current = cell.get(); - cell.set(current.wrapping_add(1)); - current - }) + output } /// Evict the cache. @@ -100,260 +83,119 @@ pub fn id() -> usize { /// Comemo's cache is thread-local, meaning that this only evicts this thread's /// cache. pub fn evict(max_age: usize) { - CACHE.with(|cache| { - let mut cache = cache.borrow_mut(); - cache.map.retain(|_, entries| { - entries.retain_mut(|entry| { - entry.age += 1; - entry.age <= max_age - }); - !entries.is_empty() - }); - }); - ACCELERATOR.with(|accelerator| accelerator.borrow_mut().clear()); -} - -/// The global cache. -#[derive(Default)] -struct Cache { - /// Maps from function IDs + hashes to memoized results. - map: HashMap<(TypeId, u128), Vec>, - /// Whether the last call was a hit. - last_was_hit: bool, -} - -impl Cache { - /// Look for a matching entry in the cache. - fn lookup( - &mut self, - key: (TypeId, u128), - input: &In, - ) -> Option<&Constrained> - where - In: Input, - Out: Clone + 'static, - { - self.map - .get_mut(&key)? - .iter_mut() - .rev() - .find_map(|entry| entry.lookup::(input)) + for subevict in EVICTORS.read().iter() { + subevict(max_age); } - /// Insert an entry into the cache. - fn insert( - &mut self, - key: (TypeId, u128), - constraint: In::Constraint, - output: Out, - ) where - In: Input, - Out: 'static, - { - self.map - .entry(key) - .or_default() - .push(CacheEntry::new::(constraint, output)); - } + accelerate::evict(); } -/// A memoized result. -struct CacheEntry { - /// The memoized function's constrained output. - /// - /// This is of type `Constrained`. - constrained: Box, - /// How many evictions have passed since the entry has been last used. - age: usize, +/// Register an eviction function in the global list. +pub fn register_evictor(evict: fn(usize)) { + EVICTORS.write().push(evict); } -/// A value with a constraint. -struct Constrained { - /// The constraint which must be fulfilled for the output to be used. - constraint: C, - /// The memoized function's output. - output: T, +/// Whether the last call was a hit. +#[cfg(feature = "testing")] +pub fn last_was_hit() -> bool { + LAST_WAS_HIT.with(|cell| cell.get()) } -impl CacheEntry { - /// Create a new entry. - fn new(constraint: In::Constraint, output: Out) -> Self - where - In: Input, - Out: 'static, - { - Self { - constrained: Box::new(Constrained { constraint, output }), - age: 0, - } - } +/// A cache for a single memoized function. +pub struct Cache(Lazy>>); - /// Return the entry's output if it is valid for the given input. - fn lookup(&mut self, input: &In) -> Option<&Constrained> - where - In: Input, - Out: Clone + 'static, - { - let constrained: &Constrained = - self.constrained.downcast_ref().expect("wrong entry type"); +impl Cache { + /// Create an empty cache. + /// + /// It must take an initialization function because the `evict` fn + /// pointer cannot be passed as an argument otherwise the function + /// passed to `Lazy::new` is a closure and not a function pointer. + pub const fn new(init: fn() -> RwLock>) -> Self { + Self(Lazy::new(init)) + } - input.validate(&constrained.constraint).then(|| { - self.age = 0; - constrained - }) + /// Evict all entries whose age is larger than or equal to `max_age`. + pub fn evict(&self, max_age: usize) { + self.0.write().evict(max_age) } } -/// Defines a constraint for a tracked type. -#[derive(Clone)] -pub struct Constraint(RefCell>>); - -/// A call entry. -#[derive(Clone)] -struct Call { - args: T, - ret: u128, - both: u128, - mutable: bool, +/// The internal data for a cache. +pub struct CacheData { + /// Maps from hashes to memoized results. + entries: HashMap>>, } -impl Constraint { - /// Create empty constraints. - pub fn new() -> Self { - Self::default() - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - pub fn push(&self, args: T, ret: u128, mutable: bool) { - let both = hash(&(&args, ret)); - self.push_inner(Call { args, ret, both, mutable }); - } - - /// Enter a constraint for a call to an immutable function. - #[inline] - fn push_inner(&self, call: Call) { - let mut calls = self.0.borrow_mut(); - - if !call.mutable { - for prev in calls.iter().rev() { - if prev.mutable { - break; - } - - #[cfg(debug_assertions)] - if prev.args == call.args { - check(prev.ret, call.ret); - } - - if prev.both == call.both { - return; - } - } - } - - calls.push(call); - } - - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate(&self, mut f: F) -> bool - where - F: FnMut(&T) -> u128, - { - self.0.borrow().iter().all(|entry| f(&entry.args) == entry.ret) +impl CacheData { + /// Evict all entries whose age is larger than or equal to `max_age`. + fn evict(&mut self, max_age: usize) { + self.entries.retain(|_, entries| { + entries.retain_mut(|entry| { + let age = entry.age.get_mut(); + *age += 1; + *age <= max_age + }); + !entries.is_empty() + }); } - /// Whether the method satisfies as all input-output pairs. - #[inline] - pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + /// Look for a matching entry in the cache. + fn lookup(&self, key: u128, input: &In) -> Option<(&In::Constraint, &Out)> where - F: FnMut(&T) -> u128, + In: Input, { - let calls = self.0.borrow(); - ACCELERATOR.with(|accelerator| { - let mut map = accelerator.borrow_mut(); - calls.iter().all(|entry| { - *map.entry((id, entry.both)).or_insert_with(|| f(&entry.args)) - == entry.ret - }) - }) + self.entries + .get(&key)? + .iter() + .rev() + .find_map(|entry| entry.lookup::(input)) } - /// Replay all input-output pairs. - #[inline] - pub fn replay(&self, mut f: F) + /// Insert an entry into the cache. + fn insert(&mut self, key: u128, constraint: In::Constraint, output: Out) where - F: FnMut(&T), + In: Input, { - for entry in self.0.borrow().iter() { - if entry.mutable { - f(&entry.args); - } - } + self.entries + .entry(key) + .or_default() + .push(CacheEntry::new::(constraint, output)); } } -impl Default for Constraint { +impl Default for CacheData { fn default() -> Self { - Self(RefCell::new(vec![])) + Self { entries: HashMap::new() } } } -/// Extend an outer constraint by an inner one. -pub trait Join { - /// Join this constraint with the `inner` one. - fn join(&self, inner: &T); - - /// Take out the constraint. - fn take(&self) -> Self; -} - -impl Join for Option<&T> { - #[inline] - fn join(&self, inner: &T) { - if let Some(outer) = self { - outer.join(inner); - } - } - - #[inline] - fn take(&self) -> Self { - unimplemented!("cannot call `Join::take` on optional constraint") - } +/// A memoized result. +struct CacheEntry { + /// The memoized function's constraint. + constraint: C, + /// The memoized function's output. + output: Out, + /// How many evictions have passed since the entry has been last used. + age: AtomicUsize, } -impl Join for Constraint { - #[inline] - fn join(&self, inner: &Self) { - for call in inner.0.borrow().iter() { - self.push_inner(call.clone()); - } - } - - #[inline] - fn take(&self) -> Self { - Self(RefCell::new(std::mem::take(&mut *self.0.borrow_mut()))) +impl CacheEntry { + /// Create a new entry. + fn new(constraint: In::Constraint, output: Out) -> Self + where + In: Input, + { + Self { constraint, output, age: AtomicUsize::new(0) } } -} -/// Produce a 128-bit hash of a value. -#[inline] -pub fn hash(value: &T) -> u128 { - let mut state = SipHasher13::new(); - value.hash(&mut state); - state.finish128().as_u128() -} - -/// Check for a constraint violation. -#[inline] -#[track_caller] -#[allow(dead_code)] -fn check(left_hash: u128, right_hash: u128) { - if left_hash != right_hash { - panic!( - "comemo: found conflicting constraints. \ - is this tracked function pure?" - ) + /// Return the entry's output if it is valid for the given input. + fn lookup(&self, input: &In) -> Option<(&In::Constraint, &Out)> + where + In: Input, + { + input.validate(&self.constraint).then(|| { + self.age.store(0, Ordering::SeqCst); + (&self.constraint, &self.output) + }) } } diff --git a/src/constraint.rs b/src/constraint.rs new file mode 100644 index 0000000..8e34878 --- /dev/null +++ b/src/constraint.rs @@ -0,0 +1,301 @@ +use std::borrow::Cow; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::hash::Hash; + +use parking_lot::RwLock; +use siphasher::sip128::{Hasher128, SipHasher13}; + +use crate::accelerate; + +/// A call to a tracked function. +pub trait Call: Hash + PartialEq + Clone { + /// Whether the call is mutable. + fn is_mutable(&self) -> bool; +} + +/// A constraint entry for a single call. +#[derive(Clone)] +struct ConstraintEntry { + call: T, + call_hash: u128, + ret_hash: u128, +} + +/// Defines a constraint for an immutably tracked type. +pub struct ImmutableConstraint(RwLock>); + +impl ImmutableConstraint { + /// Create an empty constraint. + pub fn new() -> Self { + Self::default() + } + + /// Enter a constraint for a call to an immutable function. + #[inline] + pub fn push(&self, call: T, ret_hash: u128) { + let call_hash = hash(&call); + let entry = ConstraintEntry { call, call_hash, ret_hash }; + self.0.write().push_inner(Cow::Owned(entry)); + } + + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate(&self, mut f: F) -> bool + where + F: FnMut(&T) -> u128, + { + self.0.read().0.values().all(|entry| f(&entry.call) == entry.ret_hash) + } + + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate_with_id(&self, mut f: F, id: usize) -> bool + where + F: FnMut(&T) -> u128, + { + let guard = self.0.read(); + if let Some(accelerator) = accelerate::get(id) { + let mut map = accelerator.lock(); + guard.0.values().all(|entry| { + *map.entry(entry.call_hash).or_insert_with(|| f(&entry.call)) + == entry.ret_hash + }) + } else { + guard.0.values().all(|entry| f(&entry.call) == entry.ret_hash) + } + } + + /// Replay all input-output pairs. + #[inline] + pub fn replay(&self, _: F) + where + F: FnMut(&T), + { + #[cfg(debug_assertions)] + for entry in self.0.read().0.values() { + assert!(!entry.call.is_mutable()); + } + } +} + +impl Clone for ImmutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) + } +} + +impl Default for ImmutableConstraint { + fn default() -> Self { + Self(RwLock::new(EntryMap::default())) + } +} + +/// Defines a constraint for a mutably tracked type. +pub struct MutableConstraint(RwLock>); + +impl MutableConstraint { + /// Create an empty constraint. + pub fn new() -> Self { + Self::default() + } + + /// Enter a constraint for a call to a mutable function. + #[inline] + pub fn push(&self, call: T, ret_hash: u128) { + let call_hash = hash(&call); + let entry = ConstraintEntry { call, call_hash, ret_hash }; + self.0.write().push_inner(Cow::Owned(entry)); + } + + /// Whether the method satisfies as all input-output pairs. + #[inline] + pub fn validate(&self, mut f: F) -> bool + where + F: FnMut(&T) -> u128, + { + self.0.read().0.iter().all(|entry| f(&entry.call) == entry.ret_hash) + } + + /// Whether the method satisfies as all input-output pairs. + /// + /// On mutable tracked types, this does not use an accelerator as it is + /// rarely, if ever used. Therefore, it is not worth the overhead. + #[inline] + pub fn validate_with_id(&self, f: F, _: usize) -> bool + where + F: FnMut(&T) -> u128, + { + self.validate(f) + } + + /// Replay all input-output pairs. + #[inline] + pub fn replay(&self, mut f: F) + where + F: FnMut(&T), + { + for entry in &self.0.read().0 { + if entry.call.is_mutable() { + f(&entry.call); + } + } + } +} + +impl Clone for MutableConstraint { + fn clone(&self) -> Self { + Self(RwLock::new(self.0.read().clone())) + } +} + +impl Default for MutableConstraint { + fn default() -> Self { + Self(RwLock::new(EntryVec::default())) + } +} + +/// A map of calls. +#[derive(Clone)] +struct EntryMap(HashMap>); + +impl EntryMap { + /// Enter a constraint for a call to a function. + #[inline] + fn push_inner(&mut self, entry: Cow>) { + match self.0.entry(entry.call_hash) { + Entry::Occupied(_occupied) => { + #[cfg(debug_assertions)] + check(_occupied.get(), &entry); + } + Entry::Vacant(vacant) => { + vacant.insert(entry.into_owned()); + } + } + } +} + +impl Default for EntryMap { + fn default() -> Self { + Self(HashMap::new()) + } +} + +/// A list of calls. +/// +/// Order matters here, as those are mutable & immutable calls. +#[derive(Clone)] +struct EntryVec(Vec>); + +impl EntryVec { + /// Enter a constraint for a call to a function. + #[inline] + fn push_inner(&mut self, entry: Cow>) { + // If the call is immutable check whether we already have a call + // with the same arguments and return value. + if !entry.call.is_mutable() { + for prev in self.0.iter().rev() { + if entry.call.is_mutable() { + break; + } + + if entry.call_hash == prev.call_hash && entry.ret_hash == prev.ret_hash { + #[cfg(debug_assertions)] + check(&entry, prev); + return; + } + } + } + + // Insert the call into the call list. + self.0.push(entry.into_owned()); + } +} + +impl Default for EntryVec { + fn default() -> Self { + Self(Vec::new()) + } +} + +/// Extend an outer constraint by an inner one. +pub trait Join { + /// Join this constraint with the `inner` one. + fn join(&self, inner: &T); + + /// Take out the constraint. + fn take(&self) -> Self; +} + +impl Join for Option<&T> { + #[inline] + fn join(&self, inner: &T) { + if let Some(outer) = self { + outer.join(inner); + } + } + + #[inline] + fn take(&self) -> Self { + unimplemented!("cannot call `Join::take` on optional constraint") + } +} + +impl Join for ImmutableConstraint { + #[inline] + fn join(&self, inner: &Self) { + let mut this = self.0.write(); + for entry in inner.0.read().0.values() { + this.push_inner(Cow::Borrowed(entry)); + } + } + + #[inline] + fn take(&self) -> Self { + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) + } +} + +impl Join for MutableConstraint { + #[inline] + fn join(&self, inner: &Self) { + let mut this = self.0.write(); + for entry in inner.0.read().0.iter() { + this.push_inner(Cow::Borrowed(entry)); + } + } + + #[inline] + fn take(&self) -> Self { + Self(RwLock::new(std::mem::take(&mut *self.0.write()))) + } +} + +/// Produce a 128-bit hash of a value. +#[inline] +pub fn hash(value: &T) -> u128 { + let mut state = SipHasher13::new(); + value.hash(&mut state); + state.finish128().as_u128() +} + +/// Check for a constraint violation. +#[inline] +#[track_caller] +#[allow(dead_code)] +fn check(lhs: &ConstraintEntry, rhs: &ConstraintEntry) { + if lhs.ret_hash != rhs.ret_hash { + panic!( + "comemo: found conflicting constraints. \ + is this tracked function pure?" + ) + } + + // Additional checks for debugging. + if lhs.call_hash != rhs.call_hash || lhs.call != rhs.call { + panic!( + "comemo: found conflicting `check` arguments. \ + this is a bug in comemo" + ) + } +} diff --git a/src/input.rs b/src/input.rs index d4185ea..a535ef2 100644 --- a/src/input.rs +++ b/src/input.rs @@ -1,6 +1,6 @@ use std::hash::{Hash, Hasher}; -use crate::cache::Join; +use crate::constraint::Join; use crate::track::{Track, Tracked, TrackedMut, Validate}; /// Ensure a type is suitable as input. diff --git a/src/lib.rs b/src/lib.rs index 9c5b77b..5921b7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -82,7 +82,9 @@ For the full example see [`examples/calc.rs`][calc]. [calc]: https://github.com/typst/comemo/blob/main/examples/calc.rs */ +mod accelerate; mod cache; +mod constraint; mod input; mod prehashed; mod track; @@ -95,7 +97,13 @@ pub use comemo_macros::{memoize, track}; /// These are implementation details. Do not rely on them! #[doc(hidden)] pub mod internal { - pub use crate::cache::{hash, last_was_hit, memoized, Constraint}; - pub use crate::input::{assert_hashable_or_trackable, Args}; + pub use parking_lot::RwLock; + + pub use crate::cache::{memoized, register_evictor, Cache, CacheData}; + pub use crate::constraint::{hash, Call, ImmutableConstraint, MutableConstraint}; + pub use crate::input::{assert_hashable_or_trackable, Args, Input}; pub use crate::track::{to_parts_mut_mut, to_parts_mut_ref, to_parts_ref, Surfaces}; + + #[cfg(feature = "testing")] + pub use crate::cache::last_was_hit; } diff --git a/src/track.rs b/src/track.rs index e707057..8c7e340 100644 --- a/src/track.rs +++ b/src/track.rs @@ -1,7 +1,8 @@ use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; -use crate::cache::{id, Join}; +use crate::accelerate; +use crate::constraint::Join; /// A trackable type. /// @@ -12,7 +13,11 @@ pub trait Track: Validate + Surfaces { /// Start tracking all accesses to a value. #[inline] fn track(&self) -> Tracked { - Tracked { value: self, constraint: None, id: id() } + Tracked { + value: self, + constraint: None, + id: accelerate::id(), + } } /// Start tracking all accesses and mutations to a value. @@ -27,7 +32,7 @@ pub trait Track: Validate + Surfaces { Tracked { value: self, constraint: Some(constraint), - id: id(), + id: accelerate::id(), } } @@ -227,7 +232,7 @@ where Tracked { value: this.value, constraint: this.constraint, - id: id(), + id: accelerate::id(), } } @@ -240,7 +245,7 @@ where Tracked { value: this.value, constraint: this.constraint, - id: id(), + id: accelerate::id(), } } diff --git a/tests/tests.rs b/tests/tests.rs index 9674272..db6bf31 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,8 +1,11 @@ +//! Run with `cargo test --all-features`. + use std::collections::HashMap; use std::hash::Hash; use std::path::{Path, PathBuf}; use comemo::{evict, memoize, track, Track, Tracked, TrackedMut, Validate}; +use serial_test::serial; macro_rules! test { (miss: $call:expr, $result:expr) => {{ @@ -17,6 +20,7 @@ macro_rules! test { /// Test basic memoization. #[test] +#[serial] fn test_basic() { #[memoize] fn empty() -> String { @@ -71,6 +75,7 @@ fn test_basic() { /// Test the calc language. #[test] +#[serial] fn test_calc() { #[memoize] fn evaluate(script: &str, files: Tracked) -> i32 { @@ -116,6 +121,7 @@ impl Files { /// Test cache eviction. #[test] +#[serial] fn test_evict() { #[memoize] fn null() -> u8 { @@ -141,17 +147,18 @@ fn test_evict() { /// Test tracking a trait object. #[test] +#[serial] fn test_tracked_trait() { #[memoize] fn traity(loader: Tracked, path: &Path) -> Vec { loader.load(path).unwrap() } - fn wrapper(loader: &dyn Loader, path: &Path) -> Vec { + fn wrapper(loader: &(dyn Loader), path: &Path) -> Vec { traity(loader.track(), path) } - let loader: &dyn Loader = &StaticLoader; + let loader: &(dyn Loader) = &StaticLoader; test!(miss: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(hit: traity(loader.track(), Path::new("hi.rs")), [1, 2, 3]); test!(miss: traity(loader.track(), Path::new("bye.rs")), [1, 2, 3]); @@ -159,7 +166,7 @@ fn test_tracked_trait() { } #[track] -trait Loader { +trait Loader: Send + Sync { fn load(&self, path: &Path) -> Result, String>; } @@ -172,6 +179,7 @@ impl Loader for StaticLoader { /// Test memoized methods. #[test] +#[serial] fn test_memoized_methods() { #[derive(Hash)] struct Taker(String); @@ -197,6 +205,7 @@ fn test_memoized_methods() { /// Test different kinds of arguments. #[test] +#[serial] fn test_kinds() { #[memoize] fn selfie(tester: Tracky) -> String { @@ -212,27 +221,12 @@ fn test_kinds() { } } - #[memoize] - fn generic(tester: Tracky, name: T) -> String - where - T: AsRef + Hash, - { - tester.double_ref(name.as_ref()).to_string() - } - - #[memoize] - fn ignorant(tester: Tracky, name: impl AsRef + Hash) -> String { - tester.arg_ref(name.as_ref()).to_string() - } - let mut tester = Tester { data: "Hi".to_string() }; let tracky = tester.track(); test!(miss: selfie(tracky), "Hi"); test!(miss: unconditional(tracky), "Short"); test!(hit: unconditional(tracky), "Short"); - test!(miss: generic(tracky, "World"), "World"); - test!(miss: ignorant(tracky, "Ignorant"), "Ignorant"); test!(hit: selfie(tracky), "Hi"); tester.data.push('!'); @@ -240,15 +234,11 @@ fn test_kinds() { let tracky = tester.track(); test!(miss: selfie(tracky), "Hi!"); test!(miss: unconditional(tracky), "Short"); - test!(hit: generic(tracky, "World"), "World"); - test!(hit: ignorant(tracky, "Ignorant"), "Ignorant"); tester.data.push_str(" Let's go."); let tracky = tester.track(); test!(miss: unconditional(tracky), "Long"); - test!(miss: generic(tracky, "World"), "Hi! Let's go."); - test!(hit: ignorant(tracky, "Ignorant"), "Ignorant"); } /// Test with type alias. @@ -296,6 +286,7 @@ impl Empty {} /// Test tracking a type with a lifetime. #[test] +#[serial] fn test_lifetime() { #[comemo::memoize] fn contains_hello(lifeful: Tracked) -> bool { @@ -323,6 +314,7 @@ impl<'a> Lifeful<'a> { /// Test tracking a type with a chain of tracked values. #[test] +#[serial] fn test_chain() { #[comemo::memoize] fn process(chain: Tracked, value: u32) -> bool { @@ -348,6 +340,7 @@ fn test_chain() { /// Test that `Tracked` is covariant over `T`. #[test] +#[serial] #[allow(unused, clippy::needless_lifetimes)] fn test_variance() { fn foo<'a>(_: Tracked<'a, Chain<'a>>) {} @@ -384,33 +377,34 @@ impl<'a> Chain<'a> { } /// Test mutable tracking. -#[test] -#[rustfmt::skip] -fn test_mutable() { - #[comemo::memoize] - fn dump(mut sink: TrackedMut) { - sink.emit("a"); - sink.emit("b"); - let c = sink.len_or_ten().to_string(); - sink.emit(&c); - } + #[test] + #[serial] + #[rustfmt::skip] + fn test_mutable() { + #[comemo::memoize] + fn dump(mut sink: TrackedMut) { + sink.emit("a"); + sink.emit("b"); + let c = sink.len_or_ten().to_string(); + sink.emit(&c); + } - let mut emitter = Emitter(vec![]); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(miss: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - test!(hit: dump(emitter.track_mut()), ()); - assert_eq!(emitter.0, [ - "a", "b", "2", - "a", "b", "5", - "a", "b", "8", - "a", "b", "10", - "a", "b", "10", - "a", "b", "10", - ]) -} + let mut emitter = Emitter(vec![]); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(miss: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + test!(hit: dump(emitter.track_mut()), ()); + assert_eq!(emitter.0, [ + "a", "b", "2", + "a", "b", "5", + "a", "b", "8", + "a", "b", "10", + "a", "b", "10", + "a", "b", "10", + ]) + } /// A tracked type with a mutable and an immutable method. #[derive(Clone)] @@ -433,6 +427,7 @@ struct Heavy(String); /// Test a tracked method that is impure. #[test] +#[serial] #[cfg(debug_assertions)] #[should_panic( expected = "comemo: found conflicting constraints. is this tracked function pure?"