Skip to content

Commit

Permalink
Parallel comemo & optimizations (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dherse committed Dec 15, 2023
1 parent 2b3b8ee commit ddb3773
Show file tree
Hide file tree
Showing 12 changed files with 648 additions and 372 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
- run: cargo build
- run: cargo test
- run: cargo build --all-features
- run: cargo test --all-features
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.vscode
.DS_Store
/target
macros/target
Cargo.lock
14 changes: 14 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ license = "MIT OR Apache-2.0"
categories = ["caching"]
keywords = ["incremental", "memoization", "tracked", "constraints"]

[features]
default = []
testing = []

[dependencies]
comemo-macros = { version = "0.3.1", path = "macros" }
once_cell = "1.18"
parking_lot = "0.12"
siphasher = "1"

[dev-dependencies]
serial_test = "2.0.0"

[[test]]
name = "tests"
path = "tests/tests.rs"
required-features = ["testing"]
30 changes: 21 additions & 9 deletions macros/src/memoize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub fn expand(item: &syn::Item) -> Result<proc_macro2::TokenStream> {
};

// Preprocess and validate the function.
let function = prepare(&item)?;
let function = prepare(item)?;

// Rewrite the function's body to memoize it.
process(&function)
Expand All @@ -23,7 +23,7 @@ struct Function {
/// An argument to a memoized function.
enum Argument {
Receiver(syn::Token![self]),
Ident(Option<syn::Token![mut]>, syn::Ident),
Ident(Box<syn::Type>, Option<syn::Token![mut]>, syn::Ident),
}

/// Preprocess and validate a function.
Expand Down Expand Up @@ -71,7 +71,7 @@ fn prepare_arg(input: &syn::FnArg) -> Result<Argument> {
bail!(typed.ty, "memoized functions cannot have mutable parameters")
}

Argument::Ident(mutability.clone(), ident.clone())
Argument::Ident(typed.ty.clone(), *mutability, ident.clone())
}
})
}
Expand All @@ -82,7 +82,7 @@ fn process(function: &Function) -> Result<TokenStream> {
let bounds = function.args.iter().map(|arg| {
let val = match arg {
Argument::Receiver(token) => quote! { #token },
Argument::Ident(_, ident) => quote! { #ident },
Argument::Ident(_, _, ident) => quote! { #ident },
};
quote_spanned! { function.item.span() =>
::comemo::internal::assert_hashable_or_trackable(&#val);
Expand All @@ -94,14 +94,20 @@ fn process(function: &Function) -> Result<TokenStream> {
Argument::Receiver(token) => quote! {
::comemo::internal::hash(&#token)
},
Argument::Ident(_, ident) => quote! { #ident },
Argument::Ident(_, _, ident) => quote! { #ident },
});
let arg_tuple = quote! { (#(#args,)*) };

let arg_tys = function.args.iter().map(|arg| match arg {
Argument::Receiver(_) => quote! { () },
Argument::Ident(ty, _, _) => quote! { #ty },
});
let arg_ty_tuple = quote! { (#(#arg_tys,)*) };

// Construct a tuple for all parameters.
let params = function.args.iter().map(|arg| match arg {
Argument::Receiver(_) => quote! { _ },
Argument::Ident(mutability, ident) => quote! { #mutability #ident },
Argument::Ident(_, mutability, ident) => quote! { #mutability #ident },
});
let param_tuple = quote! { (#(#params,)*) };

Expand All @@ -118,14 +124,20 @@ fn process(function: &Function) -> Result<TokenStream> {
ident.mutability = None;
}

let unique = quote! { __ComemoUnique };
wrapped.block = parse_quote! { {
struct #unique;
static __CACHE: ::comemo::internal::Cache<
<::comemo::internal::Args<#arg_ty_tuple> as ::comemo::internal::Input>::Constraint,
#output,
> = ::comemo::internal::Cache::new(|| {
::comemo::internal::register_evictor(|max_age| __CACHE.evict(max_age));
::core::default::Default::default()
});

#(#bounds;)*
::comemo::internal::memoized(
::core::any::TypeId::of::<#unique>(),
::comemo::internal::Args(#arg_tuple),
&::core::default::Default::default(),
&__CACHE,
#closure,
)
} };
Expand Down
83 changes: 59 additions & 24 deletions macros/src/track.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,38 @@ pub fn expand(item: &syn::Item) -> Result<TokenStream> {
}

for item in &item.items {
methods.push(prepare_impl_method(&item)?);
methods.push(prepare_impl_method(item)?);
}

let ty = item.self_ty.as_ref().clone();
(ty, &item.generics, None)
}
syn::Item::Trait(item) => {
for param in item.generics.params.iter() {
bail!(param, "tracked traits cannot be generic")
if let Some(first) = item.generics.params.first() {
bail!(first, "tracked traits cannot be generic")
}

for item in &item.items {
methods.push(prepare_trait_method(&item)?);
methods.push(prepare_trait_method(item)?);
}

let name = &item.ident;
let ty = parse_quote! { dyn #name + '__comemo_dynamic };
(ty, &item.generics, Some(name.clone()))
(ty, &item.generics, Some(item.ident.clone()))
}
_ => bail!(item, "`track` can only be applied to impl blocks and traits"),
};

// Produce the necessary items for the type to become trackable.
let variants = create_variants(&methods);
let scope = create(&ty, generics, trait_, &methods)?;

Ok(quote! {
#item
const _: () = { #scope };
const _: () = {
#variants
#scope
};
})
}

Expand Down Expand Up @@ -175,6 +179,43 @@ fn prepare_method(vis: syn::Visibility, sig: &syn::Signature) -> Result<Method>
})
}

/// Produces the variants for the constraint.
fn create_variants(methods: &[Method]) -> TokenStream {
let variants = methods.iter().map(create_variant);
let is_mutable_variants = methods.iter().map(|m| {
let name = &m.sig.ident;
let mutable = m.mutable;
quote! { __ComemoVariant::#name(..) => #mutable }
});

let is_mutable = (!methods.is_empty())
.then(|| {
quote! {
match &self.0 {
#(#is_mutable_variants),*
}
}
})
.unwrap_or_else(|| quote! { false });

quote! {
#[derive(Clone, PartialEq, Hash)]
pub struct __ComemoCall(__ComemoVariant);

impl ::comemo::internal::Call for __ComemoCall {
fn is_mutable(&self) -> bool {
#is_mutable
}
}

#[derive(Clone, PartialEq, Hash)]
#[allow(non_camel_case_types)]
enum __ComemoVariant {
#(#variants,)*
}
}
}

/// Produce the necessary items for a type to become trackable.
fn create(
ty: &syn::Type,
Expand Down Expand Up @@ -229,26 +270,32 @@ fn create(
};

// Prepare replying.
let immutable = methods.iter().all(|m| !m.mutable);
let replays = methods.iter().map(create_replay);
let replay = methods.iter().any(|m| m.mutable).then(|| {
let replay = (!immutable).then(|| {
quote! {
constraint.replay(|call| match &call.0 { #(#replays,)* });
}
});

// Prepare variants and wrapper methods.
let variants = methods.iter().map(create_variant);
let wrapper_methods = methods
.iter()
.filter(|m| !m.mutable)
.map(|m| create_wrapper(m, false));
let wrapper_methods_mut = methods.iter().map(|m| create_wrapper(m, true));

let constraint = if immutable {
quote! { ImmutableConstraint }
} else {
quote! { MutableConstraint }
};

Ok(quote! {
impl #impl_params ::comemo::Track for #ty #where_clause {}
impl #impl_params ::comemo::Track for #ty #where_clause {}

impl #impl_params ::comemo::Validate for #ty #where_clause {
type Constraint = ::comemo::internal::Constraint<__ComemoCall>;
impl #impl_params ::comemo::Validate for #ty #where_clause {
type Constraint = ::comemo::internal::#constraint<__ComemoCall>;

#[inline]
fn validate(&self, constraint: &Self::Constraint) -> bool {
Expand All @@ -267,15 +314,6 @@ fn create(
}
}

#[derive(Clone, PartialEq, Hash)]
pub struct __ComemoCall(__ComemoVariant);

#[derive(Clone, PartialEq, Hash)]
#[allow(non_camel_case_types)]
enum __ComemoVariant {
#(#variants,)*
}

#[doc(hidden)]
impl #impl_params ::comemo::internal::Surfaces for #ty #where_clause {
type Surface<#t> = __ComemoSurface #type_params_t where Self: #t;
Expand Down Expand Up @@ -323,7 +361,6 @@ fn create(
impl #impl_params_t #prefix __ComemoSurfaceMut #type_params_t {
#(#wrapper_methods_mut)*
}

})
}

Expand Down Expand Up @@ -370,10 +407,9 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
let vis = &method.vis;
let sig = &method.sig;
let args = &method.args;
let mutable = method.mutable;
let to_parts = if !tracked_mut {
quote! { to_parts_ref(self.0) }
} else if !mutable {
} else if !method.mutable {
quote! { to_parts_mut_ref(&self.0) }
} else {
quote! { to_parts_mut_mut(&mut self.0) }
Expand All @@ -389,7 +425,6 @@ fn create_wrapper(method: &Method, tracked_mut: bool) -> TokenStream {
constraint.push(
__ComemoCall(__comemo_variant),
::comemo::internal::hash(&output),
#mutable,
);
}
output
Expand Down
63 changes: 63 additions & 0 deletions src/accelerate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

use parking_lot::{MappedRwLockReadGuard, Mutex, RwLock, RwLockReadGuard};

/// The global list of currently alive accelerators.
static ACCELERATORS: RwLock<(usize, Vec<Accelerator>)> = RwLock::new((0, Vec::new()));

/// The current ID of the accelerator.
static ID: AtomicUsize = AtomicUsize::new(0);

/// The type of each individual accelerator.
///
/// Maps from call hashes to return hashes.
type Accelerator = Mutex<HashMap<u128, u128>>;

/// Generate a new accelerator.
pub fn id() -> usize {
// Get the next ID.
ID.fetch_add(1, Ordering::SeqCst)
}

/// Evict the accelerators.
pub fn evict() {
let mut accelerators = ACCELERATORS.write();
let (offset, vec) = &mut *accelerators;

// Update the offset.
*offset = ID.load(Ordering::SeqCst);

// Clear all accelerators while keeping the memory allocated.
vec.iter_mut().for_each(|accelerator| accelerator.lock().clear())
}

/// Get an accelerator by ID.
pub fn get(id: usize) -> Option<MappedRwLockReadGuard<'static, Accelerator>> {
// We always lock the accelerators, as we need to make sure that the
// accelerator is not removed while we are reading it.
let mut accelerators = ACCELERATORS.read();

let mut i = id.checked_sub(accelerators.0)?;
if i >= accelerators.1.len() {
drop(accelerators);
resize(i + 1);
accelerators = ACCELERATORS.read();

// Because we release the lock before resizing the accelerator, we need
// to check again whether the ID is still valid because another thread
// might evicted the cache.
i = id.checked_sub(accelerators.0)?;
}

Some(RwLockReadGuard::map(accelerators, move |(_, vec)| &vec[i]))
}

/// Adjusts the amount of accelerators.
#[cold]
fn resize(len: usize) {
let mut pair = ACCELERATORS.write();
if len > pair.1.len() {
pair.1.resize_with(len, || Mutex::new(HashMap::new()));
}
}
Loading

0 comments on commit ddb3773

Please sign in to comment.