Skip to content
48 changes: 5 additions & 43 deletions applications/tests/test_measure_channel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@ extern crate alloc;

use alloc::{format, sync::Arc, vec::Vec};
use awkernel_async_lib::{
channel::bounded,
pubsub::{self, Attribute, Publisher, Subscriber},
scheduler::SchedulerType,
spawn, uptime_nano,
channel::bounded, scheduler::SchedulerType, spawn, sync::barrier::Barrier, uptime_nano,
};
use core::{sync::atomic::AtomicUsize, time::Duration};
use core::time::Duration;
use serde::Serialize;

const NUM_TASKS: [usize; 11] = [1000, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000];
Expand All @@ -27,41 +24,6 @@ struct MeasureResult {
average: f64,
}

#[derive(Clone)]
struct Barrier {
count: Arc<AtomicUsize>,
tx: Arc<Publisher<()>>,
rx: Subscriber<()>,
}

impl Barrier {
fn new(count: usize) -> Self {
let attr = Attribute {
queue_size: 1,
..Attribute::default()
};
let (tx, rx) = pubsub::create_pubsub(attr);

Self {
count: Arc::new(AtomicUsize::new(count)),
tx: Arc::new(tx),
rx,
}
}

async fn wait(&mut self) {
if self
.count
.fetch_sub(1, core::sync::atomic::Ordering::Relaxed)
== 1
{
self.tx.send(()).await;
} else {
self.rx.recv().await;
}
}
}

pub async fn run() {
let mut result = alloc::vec::Vec::with_capacity(NUM_TASKS.len());
for num_task in NUM_TASKS.iter() {
Expand All @@ -76,15 +38,15 @@ pub async fn run() {
}

async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult {
let barrier = Barrier::new(num_task * 2);
let barrier = Arc::new(Barrier::new(num_task * 2));
let mut server_join = alloc::vec::Vec::new();
let mut client_join = alloc::vec::Vec::new();

for i in 0..num_task {
let (tx1, rx1) = bounded::new::<Vec<u8>>(bounded::Attribute::default());
let (tx2, rx2) = bounded::new::<Vec<u8>>(bounded::Attribute::default());

let mut barrier2 = barrier.clone();
let barrier2 = barrier.clone();
let hdl = spawn(
format!("{i}-server").into(),
async move {
Expand All @@ -108,7 +70,7 @@ async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult {

server_join.push(hdl);

let mut barrier2 = barrier.clone();
let barrier2 = barrier.clone();
let hdl = spawn(
format!("{i}-client").into(),
async move {
Expand Down
1 change: 1 addition & 0 deletions awkernel_async_lib/src/sync.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub use awkernel_lib::sync::mutex as raw_mutex;
pub mod barrier;
pub mod mutex;
102 changes: 102 additions & 0 deletions awkernel_async_lib/src/sync/barrier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use super::mutex::AsyncLock;
use crate::pubsub::{self, Attribute, Publisher, Subscriber};
use alloc::{vec, vec::Vec};

struct BarrierState {
count: usize,
}

/// A barrier enables multiple threads to synchronize the beginning of some computation.
pub struct Barrier {
lock: AsyncLock<BarrierState>,
num_threads: usize,
tx: Publisher<()>,
rxs: Vec<Subscriber<()>>,
}

/// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused.
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
pub fn is_reader(&self) -> bool {
self.0
}
}

impl Barrier {
/// Creates a new barrier that can block a given number of threads.
pub fn new(n: usize) -> Self {
let attr = Attribute {
queue_size: 1,
..Attribute::default()
};
let (tx, rx) = pubsub::create_pubsub(attr);

let mut rxs = vec![rx.clone(); n - 2];
rxs.push(rx);

Self {
lock: AsyncLock::new(BarrierState { count: 0 }),
num_threads: n,
tx,
rxs,
}
}

/// Blocks the current thread until all threads have redezvoused here.
/// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning from this function, and other threads will receive `BarrierWaitResult(false)`.
pub async fn wait(&self) -> BarrierWaitResult {
let mut lock = self.lock.lock().await;
let count = lock.count;
if count < (self.num_threads - 1) {
lock.count += 1;
drop(lock);
self.rxs[count].recv().await;
BarrierWaitResult(false)
} else {
lock.count = 0;
drop(lock);
self.tx.send(()).await;
BarrierWaitResult(true)
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use alloc::sync::Arc;
use core::sync::atomic::{AtomicUsize, Ordering};

#[test]
fn test_simple_async_barrier() {
let barrier = Arc::new(Barrier::new(10));
let num_waits = Arc::new(AtomicUsize::new(0));
let num_leaders = Arc::new(AtomicUsize::new(0));
let tasks = crate::mini_task::Tasks::new();

for _ in 0..10 {
let barrier = barrier.clone();
let num_waits = num_waits.clone();
let num_leaders = num_leaders.clone();
let task = async move {
num_waits.fetch_add(1, Ordering::Relaxed);

if barrier.wait().await.is_reader() {
num_leaders.fetch_add(1, Ordering::Relaxed);
}
// Verify that Barrier synchronizes the specified number of threads
assert_eq!(num_waits.load(Ordering::Relaxed), 10);

// It is safe to call Barrier::wait again
barrier.wait().await;
};

tasks.spawn(task);
}
tasks.run();

// Verify that only one thread receives BarrierWaitResult(true)
assert_eq!(num_leaders.load(Ordering::Relaxed), 1);
}
}
14 changes: 14 additions & 0 deletions specification/awkernel_async_lib/src/barrier/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Specification of Barrier implementation
## How to run

1. Download tla2tools (https://github.com/tlaplus/tlaplus/releases)

2. Translate PlusCal to TLA+
```bash
java -cp tla2tools.jar pcal.trans barrier.tla
```

3. Run TLC
```bash
java -jar tla2tools.jar -config barrier.cfg barrier.tla
```
5 changes: 5 additions & 0 deletions specification/awkernel_async_lib/src/barrier/barrier.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SPECIFICATION Spec
\* Add statements after this line.
CONSTANT Threads = {1, 2, 3, 4}
CONSTANT N = 2
INVARIANT BarrierInvariant
110 changes: 110 additions & 0 deletions specification/awkernel_async_lib/src/barrier/barrier.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
----------------- MODULE barrier ----------------
EXTENDS TLC, Integers, FiniteSets, Sequences
CONSTANTS Threads, N
ASSUME N \in Nat
ASSUME Threads \in SUBSET Nat

\* It is obvious that a deadlock wil ocuur if this conditon is not satisfied.
ASSUME Cardinality(Threads) % N = 0

(*--algorithm barrier

\* `count` records how many times `wait` has been called.
\* `num_blocked` records the number of blocked threads.
variables
count = 0,
num_blocked = 0,
blocked = FALSE;

\* If `count` < N, then the thread is blocked. otherwise, all blocked threads are awakened.
\* This property implies that Barrier does not block more than N threads.
define
BarrierInvariant == num_blocked = count % N
end define;

\* Note that `wait` is modeled as an atomic operation.
\* Therefore, the implementation needs to use mechanisms such as locks.
procedure wait() begin
Wait:
count := count + 1;
if count % N /= 0 then
num_blocked := num_blocked + 1;
blocked := TRUE;
Await:
await ~blocked;
return;
else
num_blocked := 0;
blocked := FALSE;
return;
end if;
end procedure;

fair process thread \in Threads begin
Body:
call wait();
end process;

end algorithm*)
\* BEGIN TRANSLATION (chksum(pcal) = "78d1002e" /\ chksum(tla) = "8098b806")
VARIABLES pc, count, num_blocked, blocked, stack

(* define statement *)
BarrierInvariant == num_blocked = count % N


vars == << pc, count, num_blocked, blocked, stack >>

ProcSet == (Threads)

Init == (* Global variables *)
/\ count = 0
/\ num_blocked = 0
/\ blocked = FALSE
/\ stack = [self \in ProcSet |-> << >>]
/\ pc = [self \in ProcSet |-> "Body"]

Wait(self) == /\ pc[self] = "Wait"
/\ count' = count + 1
/\ IF count' % N /= 0
THEN /\ num_blocked' = num_blocked + 1
/\ blocked' = TRUE
/\ pc' = [pc EXCEPT ![self] = "Await"]
/\ stack' = stack
ELSE /\ num_blocked' = 0
/\ blocked' = FALSE
/\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc]
/\ stack' = [stack EXCEPT ![self] = Tail(stack[self])]

Await(self) == /\ pc[self] = "Await"
/\ ~blocked
/\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc]
/\ stack' = [stack EXCEPT ![self] = Tail(stack[self])]
/\ UNCHANGED << count, num_blocked, blocked >>

wait(self) == Wait(self) \/ Await(self)

Body(self) == /\ pc[self] = "Body"
/\ stack' = [stack EXCEPT ![self] = << [ procedure |-> "wait",
pc |-> "Done" ] >>
\o stack[self]]
/\ pc' = [pc EXCEPT ![self] = "Wait"]
/\ UNCHANGED << count, num_blocked, blocked >>

thread(self) == Body(self)

(* Allow infinite stuttering to prevent deadlock on termination. *)
Terminating == /\ \A self \in ProcSet: pc[self] = "Done"
/\ UNCHANGED vars

Next == (\E self \in ProcSet: wait(self))
\/ (\E self \in Threads: thread(self))
\/ Terminating

Spec == /\ Init /\ [][Next]_vars
/\ \A self \in Threads : WF_vars(thread(self)) /\ WF_vars(wait(self))

Termination == <>(\A self \in ProcSet: pc[self] = "Done")

\* END TRANSLATION
====
Loading