Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a way to reserve many permits on bounded mpsc channel #6205

Merged
merged 45 commits into from Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
14d0af6
mpsc: add a way to reserve many permit on bounded mpsc chan
Totodore Dec 9, 2023
92c26b5
mpsc: apply clippy rules + fmt
Totodore Dec 9, 2023
b684f3c
mpsc: export the `ManyPermit` struct
Totodore Dec 9, 2023
c40178b
mpsc: fix fmt
Totodore Dec 9, 2023
53fb399
mpsc: return an `Iterator` impl for the `*_reserve_many` fns
Totodore Dec 10, 2023
242f171
mpsc: fix rustdoc
Totodore Dec 10, 2023
81c91da
mpsc: fix fmt
Totodore Dec 10, 2023
ad21c6c
mpsc: add an implementation for `ExactSizeIterator`
Totodore Dec 10, 2023
71c48ba
mpsc: apply doc suggestions
Totodore Dec 10, 2023
8dd3078
mpsc: fix `Debug` implementation for `PermitIterator`
Totodore Dec 10, 2023
9d9435b
mpsc: move the `u32` n to `usize` and only cast to `u32` for the inte…
Totodore Dec 10, 2023
28d8c51
mpsc: improve doc
Totodore Dec 10, 2023
08a8622
mpsc: change the internal `batch_semaphore` API to take usize permit …
Totodore Dec 10, 2023
5f9cb1a
mpsc: fix usize casting in special features
Totodore Dec 10, 2023
5bcc3f9
mpsc: fix usize casting in `try_reserve` fn for batch_semaphore
Totodore Dec 11, 2023
a3fd29e
mpsc: return a `SendError` if `n` > MAX_PERMIT for `*reserve_many` fns
Totodore Dec 11, 2023
419bfd3
mpsc: return a `SendError` if `n` > max_capacity for `*reserve_many` fns
Totodore Dec 11, 2023
006d7aa
Merge branch 'master' into feat-mpsc-many-permit
Totodore Dec 12, 2023
8f9d4c1
mpsc: add `try_reserve_many_fails` test
Totodore Dec 13, 2023
59944ba
mpsc: fix fmt
Totodore Dec 13, 2023
bffa2fa
mpsc: switch to `assert_ok` expr for testing
Totodore Dec 17, 2023
d9bf134
mpsc: return an empty iterator for `try_reserve_many(0)`
Totodore Dec 17, 2023
2ee4f76
mpsc: test `PermitIterator` and `reserve_many`
Totodore Dec 17, 2023
6a1ffbf
mpsc: fix fmt
Totodore Dec 17, 2023
d585434
Merge branch 'master' into feat-mpsc-many-permit
Totodore Dec 21, 2023
323e6b6
mpsc: fix `reserve_many_and_send` test
Totodore Dec 22, 2023
5496399
mpsc: impl `FusedIterator` for `PermitIterator`
Totodore Dec 22, 2023
ddd44d1
mpsc: test `reserve_many_on_closed_channel`
Totodore Dec 23, 2023
f103641
mpsc: doc mention `reserve_many` for Cancel Safety part
Totodore Dec 23, 2023
b688a05
mpsc: improve doc for `reserve_many` fn
Totodore Dec 23, 2023
3b705be
mpsc: fix formatting
Totodore Dec 23, 2023
430286b
mpsc: fix formatting
Totodore Dec 23, 2023
58f7648
mpsc: switch to `maybe_tokio_test`
Totodore Dec 23, 2023
a25fb7c
mpsc: add tests for `try_reserve_many`
Totodore Dec 23, 2023
2bd5c8b
Merge branch 'master' into feat-mpsc-many-permit
Totodore Dec 23, 2023
bd8c3e6
mpsc: add an early return if `n == 0` to avoid `acquire` mechanism
Totodore Dec 23, 2023
6988079
mpsc: improve doc for `reserve_many`
Totodore Dec 23, 2023
133f2c1
mpsc: remove early return for `reserve_inner`
Totodore Dec 23, 2023
c39878e
mpsc: remove useless empty iterator guard for `try_reserve_many`
Totodore Jan 2, 2024
f7025ce
mpsc: apply doc suggestion
Totodore Jan 2, 2024
0d19901
mpsc: Apply suggestions from code review
Totodore Jan 2, 2024
3fdb5d9
mpsc: fix `sync_mpsc` tests
Totodore Jan 2, 2024
e5f1df8
mpsc: fix `sync_mpsc` tests
Totodore Jan 2, 2024
830188e
Merge branch 'master' into feat-mpsc-many-permit
Totodore Jan 2, 2024
ec8b537
mpsc: early return for empty `PermitIterator` for drop logic
Totodore Jan 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 13 additions & 13 deletions tokio/src/sync/batch_semaphore.rs
Expand Up @@ -71,7 +71,7 @@ pub struct AcquireError(());
pub(crate) struct Acquire<'a> {
node: Waiter,
semaphore: &'a Semaphore,
num_permits: u32,
num_permits: usize,
queued: bool,
}

Expand Down Expand Up @@ -262,13 +262,13 @@ impl Semaphore {
self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
}

pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
assert!(
num_permits as usize <= Self::MAX_PERMITS,
num_permits <= Self::MAX_PERMITS,
"a semaphore may not have more than MAX_PERMITS permits ({})",
Self::MAX_PERMITS
);
let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
let num_permits = num_permits << Self::PERMIT_SHIFT;
let mut curr = self.permits.load(Acquire);
loop {
// Has the semaphore closed?
Expand All @@ -293,7 +293,7 @@ impl Semaphore {
}
}

pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> {
pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
Acquire::new(self, num_permits)
}

Expand Down Expand Up @@ -371,7 +371,7 @@ impl Semaphore {
fn poll_acquire(
&self,
cx: &mut Context<'_>,
num_permits: u32,
num_permits: usize,
node: Pin<&mut Waiter>,
queued: bool,
) -> Poll<Result<(), AcquireError>> {
Expand All @@ -380,7 +380,7 @@ impl Semaphore {
let needed = if queued {
node.state.load(Acquire) << Self::PERMIT_SHIFT
} else {
(num_permits as usize) << Self::PERMIT_SHIFT
num_permits << Self::PERMIT_SHIFT
};

let mut lock = None;
Expand Down Expand Up @@ -506,12 +506,12 @@ impl fmt::Debug for Semaphore {

impl Waiter {
fn new(
num_permits: u32,
num_permits: usize,
#[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
) -> Self {
Waiter {
waker: UnsafeCell::new(None),
state: AtomicUsize::new(num_permits as usize),
state: AtomicUsize::new(num_permits),
pointers: linked_list::Pointers::new(),
#[cfg(all(tokio_unstable, feature = "tracing"))]
ctx,
Expand Down Expand Up @@ -591,7 +591,7 @@ impl Future for Acquire<'_> {
}

impl<'a> Acquire<'a> {
fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self {
fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
#[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
return Self {
node: Waiter::new(num_permits),
Expand Down Expand Up @@ -635,14 +635,14 @@ impl<'a> Acquire<'a> {
});
}

fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) {
fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
fn is_unpin<T: Unpin>() {}
unsafe {
// Safety: all fields other than `node` are `Unpin`

is_unpin::<&Semaphore>();
is_unpin::<&mut bool>();
is_unpin::<u32>();
is_unpin::<usize>();

let this = self.get_unchecked_mut();
(
Expand Down Expand Up @@ -673,7 +673,7 @@ impl Drop for Acquire<'_> {
// Safety: we have locked the wait list.
unsafe { waiters.queue.remove(node) };

let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
let acquired_permits = self.num_permits - self.node.state.load(Acquire);
if acquired_permits > 0 {
self.semaphore.add_permits_locked(acquired_permits, waiters);
}
Expand Down
220 changes: 216 additions & 4 deletions tokio/src/sync/mpsc/bounded.rs
Expand Up @@ -68,6 +68,18 @@ pub struct Permit<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
}

/// An [`Iterator`] of [`Permit`] that can be used to hold `n` slots in the channel.
///
/// `PermitIterator` values are returned by [`Sender::reserve_many()`] and [`Sender::try_reserve_many()`]
/// and are used to guarantee channel capacity before generating `n` messages to send.
///
/// [`Sender::reserve_many()`]: Sender::reserve_many
/// [`Sender::try_reserve_many()`]: Sender::try_reserve_many
pub struct PermitIterator<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
n: usize,
}

/// Owned permit to send one value into the channel.
///
/// This is identical to the [`Permit`] type, except that it moves the sender
Expand Down Expand Up @@ -926,10 +938,74 @@ impl<T> Sender<T> {
/// }
/// ```
pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(Permit { chan: &self.chan })
}

/// Waits for channel capacity. Once capacity to send `n` messages is
/// available, it is reserved for the caller.
///
/// If the channel is full or if there are fewer than `n` permits available, the function waits
/// for the number of unreceived messages to become `n` less than the channel capacity.
/// Capacity to send `n` message is then reserved for the caller.
///
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// You can call this [`Iterator`] until it is exhausted to
/// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
/// [`try_reserve_many`] except it awaits for the slots to become available.
///
/// If the channel is closed, the function returns a [`SendError`].
///
/// Dropping [`PermitIterator`] without consuming it entirely releases the remaining
/// permits back to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`Permit`]: Permit
/// [`send`]: Permit::send
/// [`try_reserve_many`]: Sender::try_reserve_many
///
/// # Cancel safety
///
/// This channel uses a queue to ensure that calls to `send` and `reserve_many`
/// complete in the order they were requested. Cancelling a call to
/// `reserve_many` makes you lose your place in the queue.
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.reserve_many(2).await.unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The value sent on the permit is received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
/// }
/// ```
pub async fn reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, SendError<()>> {
Totodore marked this conversation as resolved.
Show resolved Hide resolved
self.reserve_inner(n).await?;
Ok(PermitIterator {
chan: &self.chan,
n,
})
}
Totodore marked this conversation as resolved.
Show resolved Hide resolved

/// Waits for channel capacity, moving the `Sender` and returning an owned
/// permit. Once capacity to send one message is available, it is reserved
/// for the caller.
Expand Down Expand Up @@ -1011,16 +1087,19 @@ impl<T> Sender<T> {
/// [`send`]: OwnedPermit::send
/// [`Arc::clone`]: std::sync::Arc::clone
pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(OwnedPermit {
chan: Some(self.chan),
})
}

async fn reserve_inner(&self) -> Result<(), SendError<()>> {
async fn reserve_inner(&self, n: usize) -> Result<(), SendError<()>> {
crate::trace::async_trace_leaf().await;

match self.chan.semaphore().semaphore.acquire(1).await {
if n > self.max_capacity() {
return Err(SendError(()));
}
match self.chan.semaphore().semaphore.acquire(n).await {
Ok(()) => Ok(()),
Err(_) => Err(SendError(())),
}
Expand Down Expand Up @@ -1079,6 +1158,91 @@ impl<T> Sender<T> {
Ok(Permit { chan: &self.chan })
}

/// Tries to acquire `n` slots in the channel without waiting for the slot to become
/// available.
///
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// You can call this [`Iterator`] until it is exhausted to
/// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
/// [`reserve_many`] except it does not await for the slots to become available.
///
/// If there are fewer than `n` permits available on the channel, then
/// this function will return a [`TrySendError::Full`]. If the channel is closed
/// this function will return a [`TrySendError::Closed`].
///
/// Dropping [`PermitIterator`] without consuming it entirely releases the remaining
/// permits back to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`send`]: Permit::send
/// [`reserve_many`]: Sender::reserve_many
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.try_reserve_many(2).unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Trying to reserve an additional slot on the `tx` will
/// // fail because there is no capacity.
/// assert!(tx.try_reserve().is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The value sent on the permit is received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
///
/// // Trying to call try_reserve_many with 0 will return an empty iterator
/// let mut permit = tx.try_reserve_many(0).unwrap();
/// assert!(permit.next().is_none());
///
/// // Trying to call try_reserve_many with a number greater than the channel
/// // capacity will return an error
/// let permit = tx.try_reserve_many(3);
/// assert!(permit.is_err());
///
/// // Trying to call try_reserve_many on a closed channel will return an error
/// drop(rx);
/// let permit = tx.try_reserve_many(1);
/// assert!(permit.is_err());
///
/// let permit = tx.try_reserve_many(0);
/// assert!(permit.is_err());
/// }
/// ```
pub fn try_reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, TrySendError<()>> {
if n > self.max_capacity() {
return Err(TrySendError::Full(()));
}

match self.chan.semaphore().semaphore.try_acquire(n) {
Ok(()) => {}
Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())),
Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())),
}

Ok(PermitIterator {
chan: &self.chan,
n,
})
}

/// Tries to acquire a slot in the channel without waiting for the slot to become
/// available, returning an owned permit.
///
Expand Down Expand Up @@ -1355,6 +1519,54 @@ impl<T> fmt::Debug for Permit<'_, T> {
}
}

// ===== impl PermitIterator =====

impl<'a, T> Iterator for PermitIterator<'a, T> {
type Item = Permit<'a, T>;

fn next(&mut self) -> Option<Self::Item> {
if self.n == 0 {
return None;
}

self.n -= 1;
Some(Permit { chan: self.chan })
}

fn size_hint(&self) -> (usize, Option<usize>) {
let n = self.n;
(n, Some(n))
}
}
impl<T> ExactSizeIterator for PermitIterator<'_, T> {}
Totodore marked this conversation as resolved.
Show resolved Hide resolved
impl<T> std::iter::FusedIterator for PermitIterator<'_, T> {}

impl<T> Drop for PermitIterator<'_, T> {
fn drop(&mut self) {
use chan::Semaphore;

let semaphore = self.chan.semaphore();

// Add the remaining permits back to the semaphore
semaphore.add_permits(self.n);

// If this is the last sender for this channel, wake the receiver so
// that it can be notified that the channel is closed.
if semaphore.is_closed() && semaphore.is_idle() {
self.chan.wake_rx();
}
}
}
Totodore marked this conversation as resolved.
Show resolved Hide resolved

impl<T> fmt::Debug for PermitIterator<'_, T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("PermitIterator")
.field("chan", &self.chan)
.field("capacity", &self.n)
.finish()
}
}

// ===== impl Permit =====

impl<T> OwnedPermit<T> {
Expand Down
4 changes: 3 additions & 1 deletion tokio/src/sync/mpsc/mod.rs
Expand Up @@ -95,7 +95,9 @@
pub(super) mod block;

mod bounded;
pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender};
pub use self::bounded::{
channel, OwnedPermit, Permit, PermitIterator, Receiver, Sender, WeakSender,
};

mod chan;

Expand Down