diff --git a/.gitignore b/.gitignore
index a9d37c5..894a336 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,5 @@
 target
 Cargo.lock
+
+# Ignore editor tempfiles so cargo watch does not trigger before save
+*.kate-swp
diff --git a/src/lib.rs b/src/lib.rs
index 15d5201..52f16ce 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -288,6 +288,7 @@ impl Builder {
             job_receiver: Mutex::new(rx),
             empty_condvar: Condvar::new(),
             empty_trigger: Mutex::new(()),
+            join_generation: AtomicUsize::new(0),
             queued_count: AtomicUsize::new(0),
             active_count: AtomicUsize::new(0),
             max_thread_count: AtomicUsize::new(num_threads),
@@ -312,6 +313,7 @@ struct ThreadPoolSharedData {
     job_receiver: Mutex<Receiver<Thunk<'static>>>,
     empty_trigger: Mutex<()>,
     empty_condvar: Condvar,
+    join_generation: AtomicUsize,
     queued_count: AtomicUsize,
     active_count: AtomicUsize,
     max_thread_count: AtomicUsize,
@@ -582,6 +584,9 @@ impl ThreadPool {
     ///
    /// Calling `join` on an empty pool will cause an immediate return.
    /// `join` may be called from multiple threads concurrently.
+    /// A `join` is an atomic point in time. All threads joining before the
+    /// join event will exit together, even if the pool is already processing
+    /// new jobs by the time they get scheduled.
     ///
     /// Calling `join` from a thread within the pool will cause a deadlock. This
     /// behavior is considered safe.
@@ -607,12 +612,20 @@
     /// assert_eq!(42, test_count.load(Ordering::Relaxed));
     /// ```
     pub fn join(&self) {
-        while self.shared_data.has_work() {
-            let mut lock = self.shared_data.empty_trigger.lock().unwrap();
-            while self.shared_data.has_work() {
-                lock = self.shared_data.empty_condvar.wait(lock).unwrap();
-            }
+        if !self.shared_data.has_work() {
+            return;
+        }
+
+        let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
+        let mut lock = self.shared_data.empty_trigger.lock().unwrap();
+
+        while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
+            && self.shared_data.has_work() {
+            lock = self.shared_data.empty_condvar.wait(lock).unwrap();
         }
+
+        // increase the generation if we are the first thread to come out of the loop
+        self.shared_data.join_generation.compare_and_swap(generation, generation.wrapping_add(1), Ordering::SeqCst);
     }
 }
@@ -763,7 +776,7 @@ fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
 #[cfg(test)]
 mod test {
-    use super::ThreadPool;
+    use super::{ThreadPool, Builder};
     use std::sync::{Arc, Barrier};
     use std::sync::atomic::{AtomicUsize, Ordering};
     use std::sync::mpsc::{sync_channel, channel};
@@ -1078,7 +1091,6 @@ mod test {
         let pool0_ = pool0.clone();
         let tx = tx.clone();
         pool0.execute(move || {
-            //sleep(Duration::from_millis(13*i));
             pool1.execute(move || {
                 error(format!("p1: {} -=- {:?}\n", i, pool0_));
                 pool0_.join();
@@ -1217,4 +1229,92 @@ mod test {
         assert_eq!(a, a.clone());
     }
+
+    #[test]
+    /// Joining threads must not get stuck once their wave of joins has
+    /// completed: as soon as one thread joining on a pool has succeeded,
+    /// every other thread joining on the same pool must get out as well,
+    /// even if the pool's workers pick up new jobs while the first group
+    /// is still finishing its join.
+    ///
+    /// Concretely, the waiting threads here exit the join in groups of
+    /// four, because the waiter pool has four workers.
+    fn test_join_wavesurfer() {
+        let n_cycles = 4;
+        let n_workers = 4;
+        let (tx, rx) = channel();
+        let builder = Builder::new().num_threads(n_workers)
+                                    .thread_name("join wavesurfer".into());
+        let p_waiter = builder.clone().build();
+        let p_clock = builder.build();
+
+        let barrier = Arc::new(Barrier::new(3));
+        let wave_clock = Arc::new(AtomicUsize::new(0));
+        let clock_thread = {
+            let barrier = barrier.clone();
+            let wave_clock = wave_clock.clone();
+            thread::spawn(move || {
+                barrier.wait();
+                for wave_num in 0..n_cycles {
+                    wave_clock.store(wave_num, Ordering::SeqCst);
+                    sleep(Duration::from_secs(1));
+                }
+            })
+        };
+
+        {
+            let barrier = barrier.clone();
+            p_clock.execute(move || {
+                barrier.wait();
+                // this sleep is for stabilisation on weaker platforms
+                sleep(Duration::from_millis(100));
+            });
+        }
+
+        // prepare three waves of jobs
+        for i in 0..3 * n_workers {
+            let p_clock = p_clock.clone();
+            let tx = tx.clone();
+            let wave_clock = wave_clock.clone();
+            p_waiter.execute(move || {
+                let now = wave_clock.load(Ordering::SeqCst);
+                p_clock.join();
+                // submit jobs for the second wave
+                p_clock.execute(|| sleep(Duration::from_secs(1)));
+                let clock = wave_clock.load(Ordering::SeqCst);
+                tx.send((now, clock, i)).unwrap();
+            });
+        }
+        println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
+        barrier.wait();
+
+        p_clock.join();
+        //p_waiter.join();
+
+        drop(tx);
+        let mut hist = vec![0; n_cycles];
+        let mut data = vec![];
+        for (now, after, i) in rx.iter() {
+            let mut dur = after - now;
+            if dur >= n_cycles - 1 {
+                dur = n_cycles - 1;
+            }
+            hist[dur] += 1;
+
+            data.push((now, after, i));
+        }
+        for (i, n) in hist.iter().enumerate() {
+            println!("\t{}: {} {}", i, n, &*(0..*n).fold("".to_owned(), |s, _| s + "*"));
+        }
+        assert!(data.iter()
+                    .all(|&(cycle, stop, i)| {
+                        if i < n_workers {
+                            cycle == stop
+                        } else {
+                            cycle < stop
+                        }
+                    }));
+
+        clock_thread.join().unwrap();
+    }
 }
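
Note: for readers unfamiliar with the pattern, below is a minimal stand-alone sketch of the
generation gate that the new `join` implements. The `Gate` type and its method names are
made up for illustration and are not part of the threadpool API; `compare_exchange` is used
here as the non-deprecated equivalent of the patch's `compare_and_swap`. Each waiter
snapshots the generation before sleeping, and the first waiter to observe the drained state
bumps the generation, which releases every other waiter from the same wave even if new work
has arrived in the meantime.

use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Condvar, Mutex};

/// Hypothetical stand-alone gate illustrating the join-generation pattern.
pub struct Gate {
    generation: AtomicUsize,
    trigger: Mutex<()>,
    condvar: Condvar,
}

impl Gate {
    pub fn new() -> Gate {
        Gate {
            generation: AtomicUsize::new(0),
            trigger: Mutex::new(()),
            condvar: Condvar::new(),
        }
    }

    /// Block until `done()` is true, or until another waiter from the same
    /// wave has already observed it become true and bumped the generation.
    pub fn wait<F: Fn() -> bool>(&self, done: F) {
        if done() {
            return;
        }
        // snapshot the generation before going to sleep
        let generation = self.generation.load(Ordering::SeqCst);
        let mut lock = self.trigger.lock().unwrap();
        while generation == self.generation.load(Ordering::Relaxed) && !done() {
            lock = self.condvar.wait(lock).unwrap();
        }
        // The first waiter out of the loop closes this wave; the remaining
        // waiters from the same wave then fail the generation check above and
        // fall through, even if `done()` has already flipped back to false.
        let _ = self.generation.compare_exchange(
            generation,
            generation.wrapping_add(1),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
    }

    /// Called by the worker side once the tracked work drains to zero.
    pub fn notify_all(&self) {
        // taking the lock before signalling ensures a waiter cannot check the
        // condition, lose the CPU, and then miss this wakeup
        let _lock = self.trigger.lock().unwrap();
        self.condvar.notify_all();
    }
}

The `compare_exchange` may fail when several waiters race out of the loop together; that is
fine, since exactly one of them advances the generation and the wave only needs to be closed
once.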