diff --git a/applications/tests/test_measure_channel/src/lib.rs b/applications/tests/test_measure_channel/src/lib.rs
index 5561d8a20..3efbdd5ae 100644
--- a/applications/tests/test_measure_channel/src/lib.rs
+++ b/applications/tests/test_measure_channel/src/lib.rs
@@ -4,12 +4,9 @@ extern crate alloc;
 
 use alloc::{format, sync::Arc, vec::Vec};
 use awkernel_async_lib::{
-    channel::bounded,
-    pubsub::{self, Attribute, Publisher, Subscriber},
-    scheduler::SchedulerType,
-    spawn, uptime_nano,
+    channel::bounded, scheduler::SchedulerType, spawn, sync::barrier::Barrier, uptime_nano,
 };
-use core::{sync::atomic::AtomicUsize, time::Duration};
+use core::time::Duration;
 use serde::Serialize;
 
 const NUM_TASKS: [usize; 11] = [1000, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000];
@@ -27,41 +24,6 @@ struct MeasureResult {
     average: f64,
 }
 
-#[derive(Clone)]
-struct Barrier {
-    count: Arc<AtomicUsize>,
-    tx: Arc<Publisher<()>>,
-    rx: Subscriber<()>,
-}
-
-impl Barrier {
-    fn new(count: usize) -> Self {
-        let attr = Attribute {
-            queue_size: 1,
-            ..Attribute::default()
-        };
-        let (tx, rx) = pubsub::create_pubsub(attr);
-
-        Self {
-            count: Arc::new(AtomicUsize::new(count)),
-            tx: Arc::new(tx),
-            rx,
-        }
-    }
-
-    async fn wait(&mut self) {
-        if self
-            .count
-            .fetch_sub(1, core::sync::atomic::Ordering::Relaxed)
-            == 1
-        {
-            self.tx.send(()).await;
-        } else {
-            self.rx.recv().await;
-        }
-    }
-}
-
 pub async fn run() {
     let mut result = alloc::vec::Vec::with_capacity(NUM_TASKS.len());
     for num_task in NUM_TASKS.iter() {
@@ -76,7 +38,7 @@ pub async fn run() {
 }
 
 async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult {
-    let barrier = Barrier::new(num_task * 2);
+    let barrier = Arc::new(Barrier::new(num_task * 2));
 
     let mut server_join = alloc::vec::Vec::new();
     let mut client_join = alloc::vec::Vec::new();
@@ -84,7 +46,7 @@ async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult {
         let (tx1, rx1) = bounded::new::<Vec<u8>>(bounded::Attribute::default());
         let (tx2, rx2) = bounded::new::<Vec<u8>>(bounded::Attribute::default());
 
-        let mut barrier2 = barrier.clone();
+        let barrier2 = barrier.clone();
         let hdl = spawn(
            format!("{i}-server").into(),
            async move {
@@ -108,7 +70,7 @@
 
         server_join.push(hdl);
 
-        let mut barrier2 = barrier.clone();
+        let barrier2 = barrier.clone();
         let hdl = spawn(
            format!("{i}-client").into(),
            async move {
diff --git a/awkernel_async_lib/src/sync.rs b/awkernel_async_lib/src/sync.rs
index 81dd26c28..002ecac11 100644
--- a/awkernel_async_lib/src/sync.rs
+++ b/awkernel_async_lib/src/sync.rs
@@ -1,2 +1,3 @@
 pub use awkernel_lib::sync::mutex as raw_mutex;
+pub mod barrier;
 pub mod mutex;
diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs
new file mode 100644
index 000000000..09954e26a
--- /dev/null
+++ b/awkernel_async_lib/src/sync/barrier.rs
@@ -0,0 +1,102 @@
+use super::mutex::AsyncLock;
+use crate::pubsub::{self, Attribute, Publisher, Subscriber};
+use alloc::{vec, vec::Vec};
+
+struct BarrierState {
+    count: usize,
+}
+
+/// A barrier enables multiple threads to synchronize the beginning of some computation.
+pub struct Barrier {
+    lock: AsyncLock<BarrierState>,
+    num_threads: usize,
+    tx: Publisher<()>,
+    rxs: Vec<Subscriber<()>>,
+}
+
+/// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused.
+pub struct BarrierWaitResult(bool);
+
+impl BarrierWaitResult {
+    pub fn is_reader(&self) -> bool {
+        self.0
+    }
+}
+
+impl Barrier {
+    /// Creates a new barrier that can block a given number of threads.
+    pub fn new(n: usize) -> Self {
+        let attr = Attribute {
+            queue_size: 1,
+            ..Attribute::default()
+        };
+        let (tx, rx) = pubsub::create_pubsub(attr);
+
+        let mut rxs = vec![rx.clone(); n - 2];
+        rxs.push(rx);
+
+        Self {
+            lock: AsyncLock::new(BarrierState { count: 0 }),
+            num_threads: n,
+            tx,
+            rxs,
+        }
+    }
+
+    /// Blocks the current thread until all threads have rendezvoused here.
+    /// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning from this function, and all other threads will receive `BarrierWaitResult(false)`.
+    pub async fn wait(&self) -> BarrierWaitResult {
+        let mut lock = self.lock.lock().await;
+        let count = lock.count;
+        if count < (self.num_threads - 1) {
+            lock.count += 1;
+            drop(lock);
+            self.rxs[count].recv().await;
+            BarrierWaitResult(false)
+        } else {
+            lock.count = 0;
+            drop(lock);
+            self.tx.send(()).await;
+            BarrierWaitResult(true)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use alloc::sync::Arc;
+    use core::sync::atomic::{AtomicUsize, Ordering};
+
+    #[test]
+    fn test_simple_async_barrier() {
+        let barrier = Arc::new(Barrier::new(10));
+        let num_waits = Arc::new(AtomicUsize::new(0));
+        let num_leaders = Arc::new(AtomicUsize::new(0));
+        let tasks = crate::mini_task::Tasks::new();
+
+        for _ in 0..10 {
+            let barrier = barrier.clone();
+            let num_waits = num_waits.clone();
+            let num_leaders = num_leaders.clone();
+            let task = async move {
+                num_waits.fetch_add(1, Ordering::Relaxed);
+
+                if barrier.wait().await.is_reader() {
+                    num_leaders.fetch_add(1, Ordering::Relaxed);
+                }
+                // Verify that Barrier synchronizes the specified number of threads.
+                assert_eq!(num_waits.load(Ordering::Relaxed), 10);
+
+                // It is safe to call Barrier::wait again.
+                barrier.wait().await;
+            };
+
+            tasks.spawn(task);
+        }
+        tasks.run();
+
+        // Verify that only one thread receives BarrierWaitResult(true).
+        assert_eq!(num_leaders.load(Ordering::Relaxed), 1);
+    }
+}
diff --git a/specification/awkernel_async_lib/src/barrier/README.md b/specification/awkernel_async_lib/src/barrier/README.md
new file mode 100644
index 000000000..4fe2068d5
--- /dev/null
+++ b/specification/awkernel_async_lib/src/barrier/README.md
@@ -0,0 +1,14 @@
+# Specification of the Barrier implementation
+## How to run
+
+1. Download tla2tools (https://github.com/tlaplus/tlaplus/releases)
+
+2. Translate PlusCal to TLA+
+```bash
+java -cp tla2tools.jar pcal.trans barrier.tla
+```
+
+3. Run TLC
+```bash
+java -jar tla2tools.jar -config barrier.cfg barrier.tla
+```
\ No newline at end of file
diff --git a/specification/awkernel_async_lib/src/barrier/barrier.cfg b/specification/awkernel_async_lib/src/barrier/barrier.cfg
new file mode 100644
index 000000000..870858e0d
--- /dev/null
+++ b/specification/awkernel_async_lib/src/barrier/barrier.cfg
@@ -0,0 +1,5 @@
+SPECIFICATION Spec
+\* Add statements after this line.
+CONSTANT Threads = {1, 2, 3, 4}
+CONSTANT N = 2
+INVARIANT BarrierInvariant
diff --git a/specification/awkernel_async_lib/src/barrier/barrier.tla b/specification/awkernel_async_lib/src/barrier/barrier.tla
new file mode 100644
index 000000000..f13979394
--- /dev/null
+++ b/specification/awkernel_async_lib/src/barrier/barrier.tla
@@ -0,0 +1,110 @@
+----------------- MODULE barrier ----------------
+EXTENDS TLC, Integers, FiniteSets, Sequences
+CONSTANTS Threads, N
+ASSUME N \in Nat
+ASSUME Threads \in SUBSET Nat
+
+\* It is obvious that a deadlock will occur if this condition is not satisfied.
+ASSUME Cardinality(Threads) % N = 0
+
+(*--algorithm barrier
+
+\* `count` records how many times `wait` has been called.
+\* `num_blocked` records the number of blocked threads.
+variables
+    count = 0,
+    num_blocked = 0,
+    blocked = FALSE;
+
+\* If `count` < N, then the thread is blocked. Otherwise, all blocked threads are awakened.
+\* This property implies that Barrier does not block more than N threads.
+define
+    BarrierInvariant == num_blocked = count % N
+end define;
+
+\* Note that `wait` is modeled as an atomic operation.
+\* Therefore, the implementation needs to use mechanisms such as locks.
+procedure wait() begin
+    Wait:
+        count := count + 1;
+        if count % N /= 0 then
+            num_blocked := num_blocked + 1;
+            blocked := TRUE;
+            Await:
+                await ~blocked;
+                return;
+        else
+            num_blocked := 0;
+            blocked := FALSE;
+            return;
+        end if;
+end procedure;
+
+fair process thread \in Threads begin
+    Body:
+        call wait();
+end process;
+
+end algorithm*)
+\* BEGIN TRANSLATION (chksum(pcal) = "78d1002e" /\ chksum(tla) = "8098b806")
+VARIABLES pc, count, num_blocked, blocked, stack
+
+(* define statement *)
+BarrierInvariant == num_blocked = count % N
+
+
+vars == << pc, count, num_blocked, blocked, stack >>
+
+ProcSet == (Threads)
+
+Init == (* Global variables *)
+        /\ count = 0
+        /\ num_blocked = 0
+        /\ blocked = FALSE
+        /\ stack = [self \in ProcSet |-> << >>]
+        /\ pc = [self \in ProcSet |-> "Body"]
+
+Wait(self) == /\ pc[self] = "Wait"
+              /\ count' = count + 1
+              /\ IF count' % N /= 0
+                    THEN /\ num_blocked' = num_blocked + 1
+                         /\ blocked' = TRUE
+                         /\ pc' = [pc EXCEPT ![self] = "Await"]
+                         /\ stack' = stack
+                    ELSE /\ num_blocked' = 0
+                         /\ blocked' = FALSE
+                         /\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc]
+                         /\ stack' = [stack EXCEPT ![self] = Tail(stack[self])]
+
+Await(self) == /\ pc[self] = "Await"
+               /\ ~blocked
+               /\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc]
+               /\ stack' = [stack EXCEPT ![self] = Tail(stack[self])]
+               /\ UNCHANGED << count, num_blocked, blocked >>
+
+wait(self) == Wait(self) \/ Await(self)
+
+Body(self) == /\ pc[self] = "Body"
+              /\ stack' = [stack EXCEPT ![self] = << [ procedure |-> "wait",
+                                                       pc        |-> "Done" ] >>
+                                                   \o stack[self]]
+              /\ pc' = [pc EXCEPT ![self] = "Wait"]
+              /\ UNCHANGED << count, num_blocked, blocked >>
+
+thread(self) == Body(self)
+
+(* Allow infinite stuttering to prevent deadlock on termination. *)
+Terminating == /\ \A self \in ProcSet: pc[self] = "Done"
+               /\ UNCHANGED vars
+
+Next == (\E self \in ProcSet: wait(self))
+           \/ (\E self \in Threads: thread(self))
+           \/ Terminating
+
+Spec == /\ Init /\ [][Next]_vars
+        /\ \A self \in Threads : WF_vars(thread(self)) /\ WF_vars(wait(self))
+
+Termination == <>(\A self \in ProcSet: pc[self] = "Done")
+
+\* END TRANSLATION
+====
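Below is a minimal usage sketch of the `Barrier` introduced in `awkernel_async_lib::sync::barrier`. It relies only on the API added by this patch (`Barrier::new`, `Barrier::wait`, and `BarrierWaitResult::is_reader`); the `worker` function and the task count are hypothetical. As in `test_measure_channel`, each task would be spawned with `awkernel_async_lib::spawn` and share the barrier behind an `Arc`.

```rust
use alloc::sync::Arc;
use awkernel_async_lib::sync::barrier::Barrier;

/// Hypothetical worker task: every task blocks in `wait` until the last one
/// arrives, and exactly one task per rendezvous observes `is_reader() == true`.
async fn worker(barrier: Arc<Barrier>) {
    // ... per-task setup, e.g. creating channels ...

    if barrier.wait().await.is_reader() {
        // Only one task runs this branch, e.g. to record the start time
        // of a measurement or to reset shared state.
    }

    // Every task reaches this point only after all tasks have called `wait`.
}

// The barrier is created once for a known number of tasks and cloned into
// each spawned task, mirroring `measure_task` above:
//
//     let barrier = Arc::new(Barrier::new(num_tasks));
//     let b = barrier.clone();
//     // spawn a task that runs `worker(b)` with the desired SchedulerType
```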