Skip to content

Commit

Permalink
use rw lock to make get_root_info cheaper
Browse files Browse the repository at this point in the history
  • Loading branch information
sokra committed Sep 28, 2023
1 parent ded66e9 commit e3ae906
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions crates/turbo-tasks-memory/src/aggregation_tree/bottom_tree.rs
@@ -1,7 +1,7 @@
use std::{hash::Hash, ops::ControlFlow, sync::Arc};

use nohash_hasher::{BuildNoHashHasher, IsEnabled};
use parking_lot::{Mutex, MutexGuard};
use parking_lot::{RwLock, RwLockWriteGuard};
use ref_cast::RefCast;
use smallvec::SmallVec;

Expand All @@ -23,7 +23,7 @@ use crate::count_hash_set::{CountHashSet, RemoveIfEntryResult};
pub struct BottomTree<T, I: IsEnabled> {
height: u8,
item: I,
state: Mutex<BottomTreeState<T, I>>,
state: RwLock<BottomTreeState<T, I>>,
}

pub struct BottomTreeState<T, I: IsEnabled> {
Expand All @@ -39,7 +39,7 @@ impl<T: Default, I: IsEnabled> BottomTree<T, I> {
Self {
height,
item,
state: Mutex::new(BottomTreeState {
state: RwLock::new(BottomTreeState {
data: T::default(),
bottom_upper: BottomConnection::new(),
top_upper: CountHashSet::new(),
Expand Down Expand Up @@ -83,7 +83,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
}

fn add_children_of_child_if_following(&self, children: &mut SmallVec<[&I; 16]>) {
let mut state = self.state.lock();
let mut state = self.state.write();
children.retain(|&mut child| !state.following.add_if_entry(child));
}

Expand All @@ -92,7 +92,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
mut children: SmallVec<[&I; 16]>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
children.retain(|&mut child| state.following.add_clonable(child));
if children.is_empty() {
return;
Expand Down Expand Up @@ -176,7 +176,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
}

fn add_child_of_child_if_following(&self, child_of_child: &I) -> bool {
let mut state = self.state.lock();
let mut state = self.state.write();
state.following.add_if_entry(child_of_child)
}

Expand All @@ -185,7 +185,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
child_of_child: &I,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
if !state.following.add_clonable(child_of_child) {
// Already connect, nothing more to do
return;
Expand Down Expand Up @@ -238,7 +238,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
child_of_child: &I,
) -> bool {
let mut state = self.state.lock();
let mut state = self.state.write();
match state.following.remove_if_entry(child_of_child) {
RemoveIfEntryResult::PartiallyRemoved => return true,
RemoveIfEntryResult::NotPresent => return false,
Expand All @@ -253,7 +253,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
children: &mut Vec<&'a I>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
let mut removed = SmallVec::<[_; 16]>::default();
children.retain(|&child| match state.following.remove_if_entry(child) {
RemoveIfEntryResult::PartiallyRemoved => false,
Expand All @@ -273,7 +273,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
child_of_child: &I,
) -> bool {
let mut state = self.state.lock();
let mut state = self.state.write();
if !state.following.remove_clonable(child_of_child) {
// no present, nothing to do
return false;
Expand All @@ -287,7 +287,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
mut children: SmallVec<[&I; 16]>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
children.retain(|&mut child| state.following.remove_clonable(child));
propagate_lost_followings_to_uppers(state, aggregation_context, children);
}
Expand Down Expand Up @@ -339,7 +339,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
upper: &Arc<BottomTree<T, I>>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
let old_inner = state.bottom_upper.set_left_upper(upper);
let add_change = aggregation_context.info_to_add_change(&state.data);
let children = state
Expand Down Expand Up @@ -402,7 +402,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
remove_change: &Option<C::ItemChange>,
following: &[I],
) {
let mut state = self.state.lock();
let mut state = self.state.write();
if count > 0 {
// add as following
if state.following.add_count(item.clone(), count as usize) {
Expand Down Expand Up @@ -430,7 +430,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
upper: &Arc<BottomTree<T, I>>,
nesting_level: u8,
) -> bool {
let mut state = self.state.lock();
let mut state = self.state.write();
let number_of_following = state.following.len();
let BottomConnection::Inner(inner) = &mut state.bottom_upper else {
return false;
Expand Down Expand Up @@ -466,7 +466,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
upper: &Arc<BottomTree<T, I>>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
state.bottom_upper.unset_left_upper(upper);
if let Some(change) = aggregation_context.info_to_remove_change(&state.data) {
upper.child_change(aggregation_context, &change);
Expand All @@ -491,7 +491,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
upper: &Arc<BottomTree<T, I>>,
) -> bool {
let mut state = self.state.lock();
let mut state = self.state.write();
let BottomConnection::Inner(inner) = &mut state.bottom_upper else {
return false;
};
Expand All @@ -517,7 +517,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
upper: &Arc<TopTree<T>>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
let new = state.top_upper.add_clonable(TopRef::ref_cast(upper));
if new {
if let Some(change) = aggregation_context.info_to_add_change(&state.data) {
Expand All @@ -535,7 +535,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
upper: &Arc<TopTree<T>>,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
let removed = state.top_upper.remove_clonable(TopRef::ref_cast(upper));
if removed {
if let Some(change) = aggregation_context.info_to_remove_change(&state.data) {
Expand Down Expand Up @@ -570,7 +570,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
aggregation_context: &C,
change: &C::ItemChange,
) {
let mut state = self.state.lock();
let mut state = self.state.write();
let change = aggregation_context.apply_change(&mut state.data, change);
propagate_change_to_upper(&mut state, aggregation_context, change);
}
Expand All @@ -582,7 +582,7 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
) -> C::RootInfo {
let mut result = aggregation_context.new_root_info(root_info_type);
let top_uppers = {
let state = self.state.lock();
let state = self.state.read();
state
.top_upper
.iter()
Expand All @@ -596,15 +596,15 @@ impl<T, I: Clone + Eq + Hash + IsEnabled> BottomTree<T, I> {
}
}
let bottom_uppers = {
let state = self.state.lock();
let state = self.state.read();
state.bottom_upper.as_cloned_uppers()
};
bottom_uppers.get_root_info(aggregation_context, root_info_type, result)
}
}

fn propagate_lost_following_to_uppers<C: AggregationContext>(
state: MutexGuard<'_, BottomTreeState<C::Info, C::ItemRef>>,
state: RwLockWriteGuard<'_, BottomTreeState<C::Info, C::ItemRef>>,
aggregation_context: &C,
child_of_child: &C::ItemRef,
) {
Expand All @@ -622,7 +622,7 @@ fn propagate_lost_following_to_uppers<C: AggregationContext>(
}

fn propagate_lost_followings_to_uppers<'a, C: AggregationContext>(
state: MutexGuard<'_, BottomTreeState<C::Info, C::ItemRef>>,
state: RwLockWriteGuard<'_, BottomTreeState<C::Info, C::ItemRef>>,
aggregation_context: &C,
children: impl IntoIterator<Item = &'a C::ItemRef> + Clone,
) where
Expand All @@ -638,7 +638,7 @@ fn propagate_lost_followings_to_uppers<'a, C: AggregationContext>(
}

fn propagate_new_following_to_uppers<C: AggregationContext>(
state: MutexGuard<'_, BottomTreeState<C::Info, C::ItemRef>>,
state: RwLockWriteGuard<'_, BottomTreeState<C::Info, C::ItemRef>>,
aggregation_context: &C,
child_of_child: &C::ItemRef,
) {
Expand All @@ -652,7 +652,7 @@ fn propagate_new_following_to_uppers<C: AggregationContext>(
}

fn propagate_change_to_upper<C: AggregationContext>(
state: &mut MutexGuard<BottomTreeState<C::Info, C::ItemRef>>,
state: &mut RwLockWriteGuard<BottomTreeState<C::Info, C::ItemRef>>,
aggregation_context: &C,
change: Option<C::ItemChange>,
) {
Expand Down Expand Up @@ -681,7 +681,7 @@ fn visit_graph<C: AggregationContext>(
let mut edges = 0;
while let Some(item) = queue.pop_front() {
let tree = bottom_tree(aggregation_context, &item, height);
let state = tree.state.lock();
let state = tree.state.read();
for next in state.following.iter() {
edges += 1;
if visited.insert(next.clone()) {
Expand Down Expand Up @@ -722,7 +722,7 @@ pub fn print_graph<C: AggregationContext>(
let tree = bottom_tree(aggregation_context, &item, height);
let name = name_fn(&item);
let label = format!("{}", name);
let state = tree.state.lock();
let state = tree.state.read();
if color_upper {
print!(r#""{} {}" [color=red];"#, height - 1, name);
} else {
Expand Down

0 comments on commit e3ae906

Please sign in to comment.