Skip to content

Commit

Permalink
feature: Add Semaphore::add_permits()
Browse files Browse the repository at this point in the history
This function allows increasing the permit count after creation.
  • Loading branch information
Jules-Bertholet committed May 27, 2023
1 parent 6bac58a commit 99a44a5
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
26 changes: 24 additions & 2 deletions src/semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,7 @@ impl Semaphore {
listener: None,
}
}
}

impl Semaphore {
/// Attempts to get an owned permit for a concurrent operation.
///
/// If the permit could not be acquired at this time, then [`None`] is returned. Otherwise, an
Expand Down Expand Up @@ -152,6 +150,30 @@ impl Semaphore {
listener: None,
}
}

/// Adds `n` additional permits to the semaphore.
///
/// # Examples
///
/// ```
/// use async_lock::Semaphore;
///
/// # futures_lite::future::block_on(async {
/// let s = Semaphore::new(1);
///
/// let _guard = s.acquire().await;
/// assert!(s.try_acquire().is_none());
///
/// s.add_permits(2);
///
/// let _guard = s.acquire().await;
/// let _guard = s.acquire().await;
/// # });
/// ```
pub fn add_permits(&self, n: usize) {
self.count.fetch_add(n, Ordering::AcqRel);
self.event.notify(n);
}
}

/// The future returned by [`Semaphore::acquire`].
Expand Down
69 changes: 67 additions & 2 deletions tests/semaphore.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
mod common;

use std::sync::{mpsc, Arc};
use std::future::Future;
use std::mem::forget;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering},
mpsc, Arc,
};
use std::task::Context;
use std::task::Poll;
use std::thread;

use common::check_yields_when_contended;

use async_lock::Semaphore;
use futures_lite::future;
use futures_lite::{future, pin};

#[test]
fn try_acquire() {
Expand Down Expand Up @@ -105,3 +113,60 @@ fn yields_when_contended() {
let s = Arc::new(s);
check_yields_when_contended(s.try_acquire_arc().unwrap(), s.acquire_arc());
}

#[test]
fn add_permits() {
static COUNTER: AtomicUsize = AtomicUsize::new(0);

let s = Arc::new(Semaphore::new(0));
let (tx, rx) = mpsc::channel::<()>();

for _ in 0..50 {
let s = s.clone();
let tx = tx.clone();

thread::spawn(move || {
future::block_on(async {
let perm = s.acquire().await;
forget(perm);
COUNTER.fetch_add(1, Ordering::Relaxed);
drop(tx);
})
});
}

assert_eq!(COUNTER.load(Ordering::Relaxed), 0);

s.add_permits(50);

drop(tx);
let _ = rx.recv();

assert_eq!(COUNTER.load(Ordering::Relaxed), 50);
}

#[test]
fn add_permits_2() {
future::block_on(AddPermitsTest);
}

struct AddPermitsTest;

impl Future for AddPermitsTest {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
let s = Semaphore::new(0);
let acq = s.acquire();
pin!(acq);
let acq_2 = s.acquire();
pin!(acq_2);
assert!(acq.as_mut().poll(cx).is_pending());
assert!(acq_2.as_mut().poll(cx).is_pending());
s.add_permits(1);
let g = acq.poll(cx);
assert!(g.is_ready());
assert!(acq_2.poll(cx).is_pending());

Poll::Ready(())
}
}

0 comments on commit 99a44a5

Please sign in to comment.