Skip to content

Commit

Permalink
sync: add Sender::{try_,}reserve_many (#6205)
Browse files Browse the repository at this point in the history
  • Loading branch information
Totodore authored Jan 2, 2024
1 parent 2d2faf6 commit 7c606ab
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 28 deletions.
26 changes: 13 additions & 13 deletions tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct AcquireError(());
pub(crate) struct Acquire<'a> {
node: Waiter,
semaphore: &'a Semaphore,
num_permits: u32,
num_permits: usize,
queued: bool,
}

Expand Down Expand Up @@ -262,13 +262,13 @@ impl Semaphore {
self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
}

pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
assert!(
num_permits as usize <= Self::MAX_PERMITS,
num_permits <= Self::MAX_PERMITS,
"a semaphore may not have more than MAX_PERMITS permits ({})",
Self::MAX_PERMITS
);
let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
let num_permits = num_permits << Self::PERMIT_SHIFT;
let mut curr = self.permits.load(Acquire);
loop {
// Has the semaphore closed?
Expand All @@ -293,7 +293,7 @@ impl Semaphore {
}
}

pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> {
pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
Acquire::new(self, num_permits)
}

Expand Down Expand Up @@ -371,7 +371,7 @@ impl Semaphore {
fn poll_acquire(
&self,
cx: &mut Context<'_>,
num_permits: u32,
num_permits: usize,
node: Pin<&mut Waiter>,
queued: bool,
) -> Poll<Result<(), AcquireError>> {
Expand All @@ -380,7 +380,7 @@ impl Semaphore {
let needed = if queued {
node.state.load(Acquire) << Self::PERMIT_SHIFT
} else {
(num_permits as usize) << Self::PERMIT_SHIFT
num_permits << Self::PERMIT_SHIFT
};

let mut lock = None;
Expand Down Expand Up @@ -506,12 +506,12 @@ impl fmt::Debug for Semaphore {

impl Waiter {
fn new(
num_permits: u32,
num_permits: usize,
#[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
) -> Self {
Waiter {
waker: UnsafeCell::new(None),
state: AtomicUsize::new(num_permits as usize),
state: AtomicUsize::new(num_permits),
pointers: linked_list::Pointers::new(),
#[cfg(all(tokio_unstable, feature = "tracing"))]
ctx,
Expand Down Expand Up @@ -591,7 +591,7 @@ impl Future for Acquire<'_> {
}

impl<'a> Acquire<'a> {
fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self {
fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
#[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
return Self {
node: Waiter::new(num_permits),
Expand Down Expand Up @@ -635,14 +635,14 @@ impl<'a> Acquire<'a> {
});
}

fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) {
fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
fn is_unpin<T: Unpin>() {}
unsafe {
// Safety: all fields other than `node` are `Unpin`

is_unpin::<&Semaphore>();
is_unpin::<&mut bool>();
is_unpin::<u32>();
is_unpin::<usize>();

let this = self.get_unchecked_mut();
(
Expand Down Expand Up @@ -673,7 +673,7 @@ impl Drop for Acquire<'_> {
// Safety: we have locked the wait list.
unsafe { waiters.queue.remove(node) };

let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
let acquired_permits = self.num_permits - self.node.state.load(Acquire);
if acquired_permits > 0 {
self.semaphore.add_permits_locked(acquired_permits, waiters);
}
Expand Down
224 changes: 220 additions & 4 deletions tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ pub struct Permit<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
}

/// An [`Iterator`] of [`Permit`] that can be used to hold `n` slots in the channel.
///
/// `PermitIterator` values are returned by [`Sender::reserve_many()`] and [`Sender::try_reserve_many()`]
/// and are used to guarantee channel capacity before generating `n` messages to send.
///
/// [`Sender::reserve_many()`]: Sender::reserve_many
/// [`Sender::try_reserve_many()`]: Sender::try_reserve_many
pub struct PermitIterator<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
n: usize,
}

/// Owned permit to send one value into the channel.
///
/// This is identical to the [`Permit`] type, except that it moves the sender
Expand Down Expand Up @@ -926,10 +938,74 @@ impl<T> Sender<T> {
/// }
/// ```
pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(Permit { chan: &self.chan })
}

/// Waits for channel capacity. Once capacity to send `n` messages is
/// available, it is reserved for the caller.
///
/// If the channel is full or if there are fewer than `n` permits available, the function waits
/// for the number of unreceived messages to become `n` less than the channel capacity.
/// Capacity to send `n` message is then reserved for the caller.
///
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// You can call this [`Iterator`] until it is exhausted to
/// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
/// [`try_reserve_many`] except it awaits for the slots to become available.
///
/// If the channel is closed, the function returns a [`SendError`].
///
/// Dropping [`PermitIterator`] without consuming it entirely releases the remaining
/// permits back to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`Permit`]: Permit
/// [`send`]: Permit::send
/// [`try_reserve_many`]: Sender::try_reserve_many
///
/// # Cancel safety
///
/// This channel uses a queue to ensure that calls to `send` and `reserve_many`
/// complete in the order they were requested. Cancelling a call to
/// `reserve_many` makes you lose your place in the queue.
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.reserve_many(2).await.unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The value sent on the permit is received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
/// }
/// ```
pub async fn reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, SendError<()>> {
self.reserve_inner(n).await?;
Ok(PermitIterator {
chan: &self.chan,
n,
})
}

/// Waits for channel capacity, moving the `Sender` and returning an owned
/// permit. Once capacity to send one message is available, it is reserved
/// for the caller.
Expand Down Expand Up @@ -1011,16 +1087,19 @@ impl<T> Sender<T> {
/// [`send`]: OwnedPermit::send
/// [`Arc::clone`]: std::sync::Arc::clone
pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(OwnedPermit {
chan: Some(self.chan),
})
}

async fn reserve_inner(&self) -> Result<(), SendError<()>> {
async fn reserve_inner(&self, n: usize) -> Result<(), SendError<()>> {
crate::trace::async_trace_leaf().await;

match self.chan.semaphore().semaphore.acquire(1).await {
if n > self.max_capacity() {
return Err(SendError(()));
}
match self.chan.semaphore().semaphore.acquire(n).await {
Ok(()) => Ok(()),
Err(_) => Err(SendError(())),
}
Expand Down Expand Up @@ -1079,6 +1158,91 @@ impl<T> Sender<T> {
Ok(Permit { chan: &self.chan })
}

/// Tries to acquire `n` slots in the channel without waiting for the slot to become
/// available.
///
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// You can call this [`Iterator`] until it is exhausted to
/// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
/// [`reserve_many`] except it does not await for the slots to become available.
///
/// If there are fewer than `n` permits available on the channel, then
/// this function will return a [`TrySendError::Full`]. If the channel is closed
/// this function will return a [`TrySendError::Closed`].
///
/// Dropping [`PermitIterator`] without consuming it entirely releases the remaining
/// permits back to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`send`]: Permit::send
/// [`reserve_many`]: Sender::reserve_many
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.try_reserve_many(2).unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Trying to reserve an additional slot on the `tx` will
/// // fail because there is no capacity.
/// assert!(tx.try_reserve().is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The value sent on the permit is received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
///
/// // Trying to call try_reserve_many with 0 will return an empty iterator
/// let mut permit = tx.try_reserve_many(0).unwrap();
/// assert!(permit.next().is_none());
///
/// // Trying to call try_reserve_many with a number greater than the channel
/// // capacity will return an error
/// let permit = tx.try_reserve_many(3);
/// assert!(permit.is_err());
///
/// // Trying to call try_reserve_many on a closed channel will return an error
/// drop(rx);
/// let permit = tx.try_reserve_many(1);
/// assert!(permit.is_err());
///
/// let permit = tx.try_reserve_many(0);
/// assert!(permit.is_err());
/// }
/// ```
pub fn try_reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, TrySendError<()>> {
if n > self.max_capacity() {
return Err(TrySendError::Full(()));
}

match self.chan.semaphore().semaphore.try_acquire(n) {
Ok(()) => {}
Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())),
Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())),
}

Ok(PermitIterator {
chan: &self.chan,
n,
})
}

/// Tries to acquire a slot in the channel without waiting for the slot to become
/// available, returning an owned permit.
///
Expand Down Expand Up @@ -1355,6 +1519,58 @@ impl<T> fmt::Debug for Permit<'_, T> {
}
}

// ===== impl PermitIterator =====

impl<'a, T> Iterator for PermitIterator<'a, T> {
type Item = Permit<'a, T>;

fn next(&mut self) -> Option<Self::Item> {
if self.n == 0 {
return None;
}

self.n -= 1;
Some(Permit { chan: self.chan })
}

fn size_hint(&self) -> (usize, Option<usize>) {
let n = self.n;
(n, Some(n))
}
}
impl<T> ExactSizeIterator for PermitIterator<'_, T> {}
impl<T> std::iter::FusedIterator for PermitIterator<'_, T> {}

impl<T> Drop for PermitIterator<'_, T> {
fn drop(&mut self) {
use chan::Semaphore;

if self.n == 0 {
return;
}

let semaphore = self.chan.semaphore();

// Add the remaining permits back to the semaphore
semaphore.add_permits(self.n);

// If this is the last sender for this channel, wake the receiver so
// that it can be notified that the channel is closed.
if semaphore.is_closed() && semaphore.is_idle() {
self.chan.wake_rx();
}
}
}

impl<T> fmt::Debug for PermitIterator<'_, T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("PermitIterator")
.field("chan", &self.chan)
.field("capacity", &self.n)
.finish()
}
}

// ===== impl Permit =====

impl<T> OwnedPermit<T> {
Expand Down
4 changes: 3 additions & 1 deletion tokio/src/sync/mpsc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@
pub(super) mod block;

mod bounded;
pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender};
pub use self::bounded::{
channel, OwnedPermit, Permit, PermitIterator, Receiver, Sender, WeakSender,
};

mod chan;

Expand Down
Loading

0 comments on commit 7c606ab

Please sign in to comment.