From bf13531e600ae03f9cbc66dc923f09a89b2b6453 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 12 Dec 2024 16:36:51 +0900 Subject: [PATCH 01/11] Fix #230: Modify Barrier::wait to return BarrierWaitResult --- applications/measure_channel/src/lib.rs | 28 +++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/applications/measure_channel/src/lib.rs b/applications/measure_channel/src/lib.rs index 5561d8a20..8a4af7c9e 100644 --- a/applications/measure_channel/src/lib.rs +++ b/applications/measure_channel/src/lib.rs @@ -27,13 +27,21 @@ struct MeasureResult { average: f64, } -#[derive(Clone)] struct Barrier { - count: Arc, - tx: Arc>, + count: AtomicUsize, + tx: Publisher<()>, rx: Subscriber<()>, } +/// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused. +pub struct BarrierWaitResult(bool); + +impl BarrierWaitResult { + pub fn is_reader(&self) -> bool { + self.0 + } +} + impl Barrier { fn new(count: usize) -> Self { let attr = Attribute { @@ -43,21 +51,23 @@ impl Barrier { let (tx, rx) = pubsub::create_pubsub(attr); Self { - count: Arc::new(AtomicUsize::new(count)), - tx: Arc::new(tx), + count: AtomicUsize::new(count), + tx, rx, } } - async fn wait(&mut self) { + async fn wait(&self) -> BarrierWaitResult { if self .count .fetch_sub(1, core::sync::atomic::Ordering::Relaxed) == 1 { self.tx.send(()).await; + BarrierWaitResult(true) } else { self.rx.recv().await; + BarrierWaitResult(false) } } } @@ -76,7 +86,7 @@ pub async fn run() { } async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult { - let barrier = Barrier::new(num_task * 2); + let barrier = Arc::new(Barrier::new(num_task * 2)); let mut server_join = alloc::vec::Vec::new(); let mut client_join = alloc::vec::Vec::new(); @@ -84,7 +94,7 @@ async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult { let (tx1, rx1) = bounded::new::>(bounded::Attribute::default()); let (tx2, rx2) = bounded::new::>(bounded::Attribute::default()); - let mut barrier2 = barrier.clone(); + let barrier2 = barrier.clone(); let hdl = spawn( format!("{i}-server").into(), async move { @@ -108,7 +118,7 @@ async fn measure_task(num_task: usize, num_bytes: usize) -> MeasureResult { server_join.push(hdl); - let mut barrier2 = barrier.clone(); + let barrier2 = barrier.clone(); let hdl = spawn( format!("{i}-client").into(), async move { From fecabaf7d1ea7b6275203de51a3a6c399b4ce540 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 12 Dec 2024 19:05:17 +0900 Subject: [PATCH 02/11] Move Barrier implementation to barrier.rs --- .../tests/test_measure_channel/src/lib.rs | 52 +------------------ awkernel_async_lib/src/sync.rs | 1 + awkernel_async_lib/src/sync/barrier.rs | 51 ++++++++++++++++++ 3 files changed, 54 insertions(+), 50 deletions(-) create mode 100644 awkernel_async_lib/src/sync/barrier.rs diff --git a/applications/tests/test_measure_channel/src/lib.rs b/applications/tests/test_measure_channel/src/lib.rs index 8a4af7c9e..3efbdd5ae 100644 --- a/applications/tests/test_measure_channel/src/lib.rs +++ b/applications/tests/test_measure_channel/src/lib.rs @@ -4,12 +4,9 @@ extern crate alloc; use alloc::{format, sync::Arc, vec::Vec}; use awkernel_async_lib::{ - channel::bounded, - pubsub::{self, Attribute, Publisher, Subscriber}, - scheduler::SchedulerType, - spawn, uptime_nano, + channel::bounded, scheduler::SchedulerType, spawn, sync::barrier::Barrier, uptime_nano, }; -use core::{sync::atomic::AtomicUsize, time::Duration}; +use core::time::Duration; use serde::Serialize; const NUM_TASKS: [usize; 11] = [1000, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]; @@ -27,51 +24,6 @@ struct MeasureResult { average: f64, } -struct Barrier { - count: AtomicUsize, - tx: Publisher<()>, - rx: Subscriber<()>, -} - -/// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused. -pub struct BarrierWaitResult(bool); - -impl BarrierWaitResult { - pub fn is_reader(&self) -> bool { - self.0 - } -} - -impl Barrier { - fn new(count: usize) -> Self { - let attr = Attribute { - queue_size: 1, - ..Attribute::default() - }; - let (tx, rx) = pubsub::create_pubsub(attr); - - Self { - count: AtomicUsize::new(count), - tx, - rx, - } - } - - async fn wait(&self) -> BarrierWaitResult { - if self - .count - .fetch_sub(1, core::sync::atomic::Ordering::Relaxed) - == 1 - { - self.tx.send(()).await; - BarrierWaitResult(true) - } else { - self.rx.recv().await; - BarrierWaitResult(false) - } - } -} - pub async fn run() { let mut result = alloc::vec::Vec::with_capacity(NUM_TASKS.len()); for num_task in NUM_TASKS.iter() { diff --git a/awkernel_async_lib/src/sync.rs b/awkernel_async_lib/src/sync.rs index 81dd26c28..002ecac11 100644 --- a/awkernel_async_lib/src/sync.rs +++ b/awkernel_async_lib/src/sync.rs @@ -1,2 +1,3 @@ pub use awkernel_lib::sync::mutex as raw_mutex; +pub mod barrier; pub mod mutex; diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs new file mode 100644 index 000000000..532e6b1c2 --- /dev/null +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -0,0 +1,51 @@ +use crate::pubsub::{self, Attribute, Publisher, Subscriber}; +use core::sync::atomic::AtomicUsize; + +/// A barrier enables multiple threads to synchronize the beginning of some computation. +pub struct Barrier { + count: AtomicUsize, + tx: Publisher<()>, + rx: Subscriber<()>, +} + +/// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused. +pub struct BarrierWaitResult(bool); + +impl BarrierWaitResult { + pub fn is_reader(&self) -> bool { + self.0 + } +} + +impl Barrier { + /// Creates a new barrier that can block a given number of threads. + pub fn new(count: usize) -> Self { + let attr = Attribute { + queue_size: 1, + ..Attribute::default() + }; + let (tx, rx) = pubsub::create_pubsub(attr); + + Self { + count: AtomicUsize::new(count), + tx, + rx, + } + } + + /// Blocks the current thread until all threads have redezvoused here. + /// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning fron this function, and other threads will receive `BarrierWaitResult(false)`. + pub async fn wait(&self) -> BarrierWaitResult { + if self + .count + .fetch_sub(1, core::sync::atomic::Ordering::Relaxed) + == 1 + { + self.tx.send(()).await; + BarrierWaitResult(true) + } else { + self.rx.recv().await; + BarrierWaitResult(false) + } + } +} From 32fc9b4e49e12056f9d0047ee89f4196077e3d21 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 12 Dec 2024 20:06:36 +0900 Subject: [PATCH 03/11] Create simple test --- awkernel_async_lib/src/sync/barrier.rs | 32 ++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs index 532e6b1c2..33327583d 100644 --- a/awkernel_async_lib/src/sync/barrier.rs +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -49,3 +49,35 @@ impl Barrier { } } } + +#[cfg(test)] +mod tests { + use super::*; + use alloc::sync::Arc; + use core::sync::atomic::{AtomicUsize, Ordering}; + + #[test] + fn test_simple_async_barrier() { + let barrier = Arc::new(Barrier::new(10)); + let num_waits = Arc::new(AtomicUsize::new(0)); + let num_leaders = Arc::new(AtomicUsize::new(0)); + let tasks = crate::mini_task::Tasks::new(); + + for _ in 0..10 { + let barrier = barrier.clone(); + let num_waits = num_waits.clone(); + let num_leaders = num_leaders.clone(); + let task = async move { + num_waits.fetch_add(1, Ordering::Relaxed); + if barrier.wait().await.is_reader() { + num_leaders.fetch_add(1, Ordering::Relaxed); + } + assert_eq!(num_waits.load(Ordering::Relaxed), 10); + }; + tasks.spawn(task); + } + tasks.run(); + + assert_eq!(num_leaders.load(Ordering::Relaxed), 1); + } +} From 220bf374f2f57fdf5c2230bff49cf8a272a14932 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 19 Dec 2024 10:10:59 +0900 Subject: [PATCH 04/11] Bug fix: resolve infinite loop --- awkernel_async_lib/src/sync/barrier.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs index 33327583d..7b6742dce 100644 --- a/awkernel_async_lib/src/sync/barrier.rs +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -44,7 +44,7 @@ impl Barrier { self.tx.send(()).await; BarrierWaitResult(true) } else { - self.rx.recv().await; + self.rx.clone().recv().await; BarrierWaitResult(false) } } From 5f360a1a80f2b0ea573696310ccf6883451846e6 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 19 Dec 2024 13:16:19 +0900 Subject: [PATCH 05/11] Handles the case where Barrier::wait is called more than the specified number of times --- awkernel_async_lib/src/sync/barrier.rs | 36 ++++++++++++++++++-------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs index 7b6742dce..f4f94e980 100644 --- a/awkernel_async_lib/src/sync/barrier.rs +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -1,11 +1,13 @@ use crate::pubsub::{self, Attribute, Publisher, Subscriber}; +use alloc::{vec, vec::Vec}; use core::sync::atomic::AtomicUsize; /// A barrier enables multiple threads to synchronize the beginning of some computation. pub struct Barrier { count: AtomicUsize, + num_threads: usize, tx: Publisher<()>, - rx: Subscriber<()>, + rxs: Vec>, } /// `BarrierWaitResult` is returned by `Barrier::wait` when all threads in `Barrier` have rendezvoused. @@ -19,33 +21,38 @@ impl BarrierWaitResult { impl Barrier { /// Creates a new barrier that can block a given number of threads. - pub fn new(count: usize) -> Self { + pub fn new(n: usize) -> Self { let attr = Attribute { queue_size: 1, ..Attribute::default() }; let (tx, rx) = pubsub::create_pubsub(attr); + let mut rxs = vec![rx.clone(); n - 2]; + rxs.push(rx); + Self { - count: AtomicUsize::new(count), + num_threads: n, + count: AtomicUsize::new(0), tx, - rx, + rxs, } } /// Blocks the current thread until all threads have redezvoused here. /// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning fron this function, and other threads will receive `BarrierWaitResult(false)`. pub async fn wait(&self) -> BarrierWaitResult { - if self + let count = self .count - .fetch_sub(1, core::sync::atomic::Ordering::Relaxed) - == 1 - { + .fetch_add(1, core::sync::atomic::Ordering::Relaxed); + if count < self.num_threads - 1 { + self.rxs[count].recv().await; + BarrierWaitResult(false) + } else { + // Safety: count mut be set to 0 before calling Sender::poll, as it switches to a task waiting to receive. + self.count.store(0, core::sync::atomic::Ordering::Relaxed); self.tx.send(()).await; BarrierWaitResult(true) - } else { - self.rx.clone().recv().await; - BarrierWaitResult(false) } } } @@ -69,15 +76,22 @@ mod tests { let num_leaders = num_leaders.clone(); let task = async move { num_waits.fetch_add(1, Ordering::Relaxed); + if barrier.wait().await.is_reader() { num_leaders.fetch_add(1, Ordering::Relaxed); } + // Verify that Barrier synchronizes the specified number of threads assert_eq!(num_waits.load(Ordering::Relaxed), 10); + + // it is safe to call Barrier::wait more than count times + barrier.wait().await; }; + tasks.spawn(task); } tasks.run(); + // Vefify that only one thread receives BarrierWaitResult(true) assert_eq!(num_leaders.load(Ordering::Relaxed), 1); } } From 86edfc0234a05ff132eb5ee0ae072b5d0c877c41 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 19 Dec 2024 13:20:24 +0900 Subject: [PATCH 06/11] Fix typo --- awkernel_async_lib/src/sync/barrier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs index f4f94e980..fb0048ccd 100644 --- a/awkernel_async_lib/src/sync/barrier.rs +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -83,7 +83,7 @@ mod tests { // Verify that Barrier synchronizes the specified number of threads assert_eq!(num_waits.load(Ordering::Relaxed), 10); - // it is safe to call Barrier::wait more than count times + // it is safe to call Barrier::wait again barrier.wait().await; }; @@ -91,7 +91,7 @@ mod tests { } tasks.run(); - // Vefify that only one thread receives BarrierWaitResult(true) + // Verify that only one thread receives BarrierWaitResult(true) assert_eq!(num_leaders.load(Ordering::Relaxed), 1); } } From 64ffd480ff8b34c7bb49a893b8e1ebbaff5800b5 Mon Sep 17 00:00:00 2001 From: r1ru Date: Tue, 24 Dec 2024 13:16:23 +0900 Subject: [PATCH 07/11] Fix typo --- awkernel_async_lib/src/sync/barrier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs index fb0048ccd..26634a80d 100644 --- a/awkernel_async_lib/src/sync/barrier.rs +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -40,7 +40,7 @@ impl Barrier { } /// Blocks the current thread until all threads have redezvoused here. - /// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning fron this function, and other threads will receive `BarrierWaitResult(false)`. + /// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning from this function, and other threads will receive `BarrierWaitResult(false)`. pub async fn wait(&self) -> BarrierWaitResult { let count = self .count @@ -49,7 +49,7 @@ impl Barrier { self.rxs[count].recv().await; BarrierWaitResult(false) } else { - // Safety: count mut be set to 0 before calling Sender::poll, as it switches to a task waiting to receive. + // Safety: count must be set to 0 before calling Sender::poll, as it switches to a task waiting to receive. self.count.store(0, core::sync::atomic::Ordering::Relaxed); self.tx.send(()).await; BarrierWaitResult(true) From f2e57c1bbb8dce4775365f2c74bfba49e4258518 Mon Sep 17 00:00:00 2001 From: r1ru Date: Tue, 24 Dec 2024 17:04:37 +0900 Subject: [PATCH 08/11] Verify the algorithm of Barrier::wait with TLA+ --- .../awkernel_async_lib/src/barrier/README.md | 14 +++ .../src/barrier/barrier.cfg | 5 + .../src/barrier/barrier.tla | 110 ++++++++++++++++++ 3 files changed, 129 insertions(+) create mode 100644 specification/awkernel_async_lib/src/barrier/README.md create mode 100644 specification/awkernel_async_lib/src/barrier/barrier.cfg create mode 100644 specification/awkernel_async_lib/src/barrier/barrier.tla diff --git a/specification/awkernel_async_lib/src/barrier/README.md b/specification/awkernel_async_lib/src/barrier/README.md new file mode 100644 index 000000000..dbe022557 --- /dev/null +++ b/specification/awkernel_async_lib/src/barrier/README.md @@ -0,0 +1,14 @@ +# Specification of Barrier implementation +## How to run + +1. download tla2tools (https://github.com/tlaplus/tlaplus/releases) + +2. Translate PlusCal to TLA+ +```bash +java -cp tla2tools.jar pcal.trans barrier.tla +``` + +3. Run TLC +```bash +java -jar tla2tools.jar -config barrier.cfg barrier.tla +``` \ No newline at end of file diff --git a/specification/awkernel_async_lib/src/barrier/barrier.cfg b/specification/awkernel_async_lib/src/barrier/barrier.cfg new file mode 100644 index 000000000..870858e0d --- /dev/null +++ b/specification/awkernel_async_lib/src/barrier/barrier.cfg @@ -0,0 +1,5 @@ +SPECIFICATION Spec +\* Add statements after this line. +CONSTANT Threads = {1, 2, 3, 4} +CONSTANT N = 2 +INVARIANT BarrierInvariant diff --git a/specification/awkernel_async_lib/src/barrier/barrier.tla b/specification/awkernel_async_lib/src/barrier/barrier.tla new file mode 100644 index 000000000..acf1ea6e9 --- /dev/null +++ b/specification/awkernel_async_lib/src/barrier/barrier.tla @@ -0,0 +1,110 @@ +----------------- MODULE barrier ---------------- +EXTENDS TLC, Integers, FiniteSets, Sequences +CONSTANTS Threads, N +ASSUME N \in Nat +ASSUME Threads \in SUBSET Nat + +\* It is obvious that a deadlock wil ocuur if this conditon is not satisfied. +ASSUME Cardinality(Threads) % N = 0 + +(*--algorithm barrier + +\* `count` records how many times `wait` has been called. +\* `num_blocked` records the number of blocked threads. +variables + count = 0, + num_blocked = 0, + blocked = FALSE; + +\* If `count` < N, then the thread is blocked. otherwise, all blocked threads are awakened. +\* This property implies that Barrier does not block more than N threads. +define + BarrierCorrectness == num_blocked = count % N +end define; + +\* Note that `wait` is modeled as an atomic operation. +\* Therefore, the implementation needs to use mechanisms such as locks. +procedure wait() begin + Wait: + count := count + 1; + if count % N /= 0 then + num_blocked := num_blocked + 1; + blocked := TRUE; + Await: + await ~blocked; + return; + else + num_blocked := 0; + blocked := FALSE; + return; + end if; +end procedure; + +fair process thread \in Threads begin + Body: + call wait(); +end process; + +end algorithm*) +\* BEGIN TRANSLATION (chksum(pcal) = "26f45583" /\ chksum(tla) = "34eb2117") +VARIABLES pc, count, num_blocked, blocked, stack + +(* define statement *) +BarrierCorrectness == num_blocked = count % N + + +vars == << pc, count, num_blocked, blocked, stack >> + +ProcSet == (Threads) + +Init == (* Global variables *) + /\ count = 0 + /\ num_blocked = 0 + /\ blocked = FALSE + /\ stack = [self \in ProcSet |-> << >>] + /\ pc = [self \in ProcSet |-> "Body"] + +Wait(self) == /\ pc[self] = "Wait" + /\ count' = count + 1 + /\ IF count' % N /= 0 + THEN /\ num_blocked' = num_blocked + 1 + /\ blocked' = TRUE + /\ pc' = [pc EXCEPT ![self] = "Await"] + /\ stack' = stack + ELSE /\ num_blocked' = 0 + /\ blocked' = FALSE + /\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc] + /\ stack' = [stack EXCEPT ![self] = Tail(stack[self])] + +Await(self) == /\ pc[self] = "Await" + /\ ~blocked + /\ pc' = [pc EXCEPT ![self] = Head(stack[self]).pc] + /\ stack' = [stack EXCEPT ![self] = Tail(stack[self])] + /\ UNCHANGED << count, num_blocked, blocked >> + +wait(self) == Wait(self) \/ Await(self) + +Body(self) == /\ pc[self] = "Body" + /\ stack' = [stack EXCEPT ![self] = << [ procedure |-> "wait", + pc |-> "Done" ] >> + \o stack[self]] + /\ pc' = [pc EXCEPT ![self] = "Wait"] + /\ UNCHANGED << count, num_blocked, blocked >> + +thread(self) == Body(self) + +(* Allow infinite stuttering to prevent deadlock on termination. *) +Terminating == /\ \A self \in ProcSet: pc[self] = "Done" + /\ UNCHANGED vars + +Next == (\E self \in ProcSet: wait(self)) + \/ (\E self \in Threads: thread(self)) + \/ Terminating + +Spec == /\ Init /\ [][Next]_vars + /\ \A self \in Threads : WF_vars(thread(self)) /\ WF_vars(wait(self)) + +Termination == <>(\A self \in ProcSet: pc[self] = "Done") + +\* END TRANSLATION +==== From f0e5a23bb4b6f1a41879ef9f0e079d76be8c0cfc Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 26 Dec 2024 14:12:13 +0900 Subject: [PATCH 09/11] Fix typo --- specification/awkernel_async_lib/src/barrier/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specification/awkernel_async_lib/src/barrier/README.md b/specification/awkernel_async_lib/src/barrier/README.md index dbe022557..4fe2068d5 100644 --- a/specification/awkernel_async_lib/src/barrier/README.md +++ b/specification/awkernel_async_lib/src/barrier/README.md @@ -1,7 +1,7 @@ # Specification of Barrier implementation ## How to run -1. download tla2tools (https://github.com/tlaplus/tlaplus/releases) +1. Download tla2tools (https://github.com/tlaplus/tlaplus/releases) 2. Translate PlusCal to TLA+ ```bash From 5a4d8d7847400cdddcca4c55291683fbd79f331c Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 9 Jan 2025 09:12:42 +0900 Subject: [PATCH 10/11] Fix typo --- specification/awkernel_async_lib/src/barrier/barrier.tla | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/specification/awkernel_async_lib/src/barrier/barrier.tla b/specification/awkernel_async_lib/src/barrier/barrier.tla index acf1ea6e9..f13979394 100644 --- a/specification/awkernel_async_lib/src/barrier/barrier.tla +++ b/specification/awkernel_async_lib/src/barrier/barrier.tla @@ -19,7 +19,7 @@ variables \* If `count` < N, then the thread is blocked. otherwise, all blocked threads are awakened. \* This property implies that Barrier does not block more than N threads. define - BarrierCorrectness == num_blocked = count % N + BarrierInvariant == num_blocked = count % N end define; \* Note that `wait` is modeled as an atomic operation. @@ -46,11 +46,11 @@ fair process thread \in Threads begin end process; end algorithm*) -\* BEGIN TRANSLATION (chksum(pcal) = "26f45583" /\ chksum(tla) = "34eb2117") +\* BEGIN TRANSLATION (chksum(pcal) = "78d1002e" /\ chksum(tla) = "8098b806") VARIABLES pc, count, num_blocked, blocked, stack (* define statement *) -BarrierCorrectness == num_blocked = count % N +BarrierInvariant == num_blocked = count % N vars == << pc, count, num_blocked, blocked, stack >> From 6570b17bf512d93f9838aff13d5a588608b75924 Mon Sep 17 00:00:00 2001 From: r1ru Date: Thu, 9 Jan 2025 12:10:11 +0900 Subject: [PATCH 11/11] Fix the implementation of Barrier::wait --- awkernel_async_lib/src/sync/barrier.rs | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/awkernel_async_lib/src/sync/barrier.rs b/awkernel_async_lib/src/sync/barrier.rs index 26634a80d..09954e26a 100644 --- a/awkernel_async_lib/src/sync/barrier.rs +++ b/awkernel_async_lib/src/sync/barrier.rs @@ -1,10 +1,14 @@ +use super::mutex::AsyncLock; use crate::pubsub::{self, Attribute, Publisher, Subscriber}; use alloc::{vec, vec::Vec}; -use core::sync::atomic::AtomicUsize; + +struct BarrierState { + count: usize, +} /// A barrier enables multiple threads to synchronize the beginning of some computation. pub struct Barrier { - count: AtomicUsize, + lock: AsyncLock, num_threads: usize, tx: Publisher<()>, rxs: Vec>, @@ -32,8 +36,8 @@ impl Barrier { rxs.push(rx); Self { + lock: AsyncLock::new(BarrierState { count: 0 }), num_threads: n, - count: AtomicUsize::new(0), tx, rxs, } @@ -42,15 +46,16 @@ impl Barrier { /// Blocks the current thread until all threads have redezvoused here. /// A single (arbitrary) thread will receive `BarrierWaitResult(true)` when returning from this function, and other threads will receive `BarrierWaitResult(false)`. pub async fn wait(&self) -> BarrierWaitResult { - let count = self - .count - .fetch_add(1, core::sync::atomic::Ordering::Relaxed); - if count < self.num_threads - 1 { + let mut lock = self.lock.lock().await; + let count = lock.count; + if count < (self.num_threads - 1) { + lock.count += 1; + drop(lock); self.rxs[count].recv().await; BarrierWaitResult(false) } else { - // Safety: count must be set to 0 before calling Sender::poll, as it switches to a task waiting to receive. - self.count.store(0, core::sync::atomic::Ordering::Relaxed); + lock.count = 0; + drop(lock); self.tx.send(()).await; BarrierWaitResult(true) } @@ -83,7 +88,7 @@ mod tests { // Verify that Barrier synchronizes the specified number of threads assert_eq!(num_waits.load(Ordering::Relaxed), 10); - // it is safe to call Barrier::wait again + // It is safe to call Barrier::wait again barrier.wait().await; };