Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a way to reserve many permits on bounded mpsc channel #6205

Merged
merged 45 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
14d0af6
mpsc: add a way to reserve many permit on bounded mpsc chan
Totodore Dec 9, 2023
92c26b5
mpsc: apply clippy rules + fmt
Totodore Dec 9, 2023
b684f3c
mpsc: export the `ManyPermit` struct
Totodore Dec 9, 2023
c40178b
mpsc: fix fmt
Totodore Dec 9, 2023
53fb399
mpsc: return an `Iterator` impl for the `*_reserve_many` fns
Totodore Dec 10, 2023
242f171
mpsc: fix rustdoc
Totodore Dec 10, 2023
81c91da
mpsc: fix fmt
Totodore Dec 10, 2023
ad21c6c
mpsc: add an implementation for `ExactSizeIterator`
Totodore Dec 10, 2023
71c48ba
mpsc: apply doc suggestions
Totodore Dec 10, 2023
8dd3078
mpsc: fix `Debug` implementation for `PermitIterator`
Totodore Dec 10, 2023
9d9435b
mpsc: move the `u32` n to `usize` and only cast to `u32` for the inte…
Totodore Dec 10, 2023
28d8c51
mpsc: improve doc
Totodore Dec 10, 2023
08a8622
mpsc: change the internal `batch_semaphore` API to take usize permit …
Totodore Dec 10, 2023
5f9cb1a
mpsc: fix usize casting in special features
Totodore Dec 10, 2023
5bcc3f9
mpsc: fix usize casting in `try_reserve` fn for batch_semaphore
Totodore Dec 11, 2023
a3fd29e
mpsc: return a `SendError` if `n` > MAX_PERMIT for `*reserve_many` fns
Totodore Dec 11, 2023
419bfd3
mpsc: return a `SendError` if `n` > max_capacity for `*reserve_many` fns
Totodore Dec 11, 2023
006d7aa
Merge branch 'master' into feat-mpsc-many-permit
Totodore Dec 12, 2023
8f9d4c1
mpsc: add `try_reserve_many_fails` test
Totodore Dec 13, 2023
59944ba
mpsc: fix fmt
Totodore Dec 13, 2023
bffa2fa
mpsc: switch to `assert_ok` expr for testing
Totodore Dec 17, 2023
d9bf134
mpsc: return an empty iterator for `try_reserve_many(0)`
Totodore Dec 17, 2023
2ee4f76
mpsc: test `PermitIterator` and `reserve_many`
Totodore Dec 17, 2023
6a1ffbf
mpsc: fix fmt
Totodore Dec 17, 2023
d585434
Merge branch 'master' into feat-mpsc-many-permit
Totodore Dec 21, 2023
323e6b6
mpsc: fix `reserve_many_and_send` test
Totodore Dec 22, 2023
5496399
mpsc: impl `FusedIterator` for `PermitIterator`
Totodore Dec 22, 2023
ddd44d1
mpsc: test `reserve_many_on_closed_channel`
Totodore Dec 23, 2023
f103641
mpsc: doc mention `reserve_many` for Cancel Safety part
Totodore Dec 23, 2023
b688a05
mpsc: improve doc for `reserve_many` fn
Totodore Dec 23, 2023
3b705be
mpsc: fix formatting
Totodore Dec 23, 2023
430286b
mpsc: fix formatting
Totodore Dec 23, 2023
58f7648
mpsc: switch to `maybe_tokio_test`
Totodore Dec 23, 2023
a25fb7c
mpsc: add tests for `try_reserve_many`
Totodore Dec 23, 2023
2bd5c8b
Merge branch 'master' into feat-mpsc-many-permit
Totodore Dec 23, 2023
bd8c3e6
mpsc: add an early return if `n == 0` to avoid `acquire` mechanism
Totodore Dec 23, 2023
6988079
mpsc: improve doc for `reserve_many`
Totodore Dec 23, 2023
133f2c1
mpsc: remove early return for `reserve_inner`
Totodore Dec 23, 2023
c39878e
mpsc: remove useless empty iterator guard for `try_reserve_many`
Totodore Jan 2, 2024
f7025ce
mpsc: apply doc suggestion
Totodore Jan 2, 2024
0d19901
mpsc: Apply suggestions from code review
Totodore Jan 2, 2024
3fdb5d9
mpsc: fix `sync_mpsc` tests
Totodore Jan 2, 2024
e5f1df8
mpsc: fix `sync_mpsc` tests
Totodore Jan 2, 2024
830188e
Merge branch 'master' into feat-mpsc-many-permit
Totodore Jan 2, 2024
ec8b537
mpsc: early return for empty `PermitIterator` for drop logic
Totodore Jan 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 9 additions & 9 deletions tokio/src/sync/batch_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub struct AcquireError(());
pub(crate) struct Acquire<'a> {
node: Waiter,
semaphore: &'a Semaphore,
num_permits: u32,
num_permits: usize,
queued: bool,
}

Expand Down Expand Up @@ -293,7 +293,7 @@ impl Semaphore {
}
}

pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> {
pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
Acquire::new(self, num_permits)
}

Expand Down Expand Up @@ -371,7 +371,7 @@ impl Semaphore {
fn poll_acquire(
&self,
cx: &mut Context<'_>,
num_permits: u32,
num_permits: usize,
node: Pin<&mut Waiter>,
queued: bool,
) -> Poll<Result<(), AcquireError>> {
Expand All @@ -380,7 +380,7 @@ impl Semaphore {
let needed = if queued {
node.state.load(Acquire) << Self::PERMIT_SHIFT
} else {
(num_permits as usize) << Self::PERMIT_SHIFT
num_permits << Self::PERMIT_SHIFT
};

let mut lock = None;
Expand Down Expand Up @@ -506,12 +506,12 @@ impl fmt::Debug for Semaphore {

impl Waiter {
fn new(
num_permits: u32,
num_permits: usize,
#[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
) -> Self {
Waiter {
waker: UnsafeCell::new(None),
state: AtomicUsize::new(num_permits as usize),
state: AtomicUsize::new(num_permits),
pointers: linked_list::Pointers::new(),
#[cfg(all(tokio_unstable, feature = "tracing"))]
ctx,
Expand Down Expand Up @@ -591,7 +591,7 @@ impl Future for Acquire<'_> {
}

impl<'a> Acquire<'a> {
fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self {
fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
#[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
return Self {
node: Waiter::new(num_permits),
Expand Down Expand Up @@ -635,7 +635,7 @@ impl<'a> Acquire<'a> {
});
}

fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) {
fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
fn is_unpin<T: Unpin>() {}
unsafe {
// Safety: all fields other than `node` are `Unpin`
Expand Down Expand Up @@ -673,7 +673,7 @@ impl Drop for Acquire<'_> {
// Safety: we have locked the wait list.
unsafe { waiters.queue.remove(node) };

let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
let acquired_permits = self.num_permits - self.node.state.load(Acquire);
if acquired_permits > 0 {
self.semaphore.add_permits_locked(acquired_permits, waiters);
}
Expand Down
187 changes: 183 additions & 4 deletions tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ pub struct Permit<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
}

/// An [`Iterator`] of [`Permit`] that can be used to reserve `n` slots in the channel.
Totodore marked this conversation as resolved.
Show resolved Hide resolved
///
/// `PermitIterator` values are returned by [`Sender::reserve_many()`] and [`Sender::try_reserve_many()`]
/// and are used to guarantee channel capacity before generating `n` message to send.
Totodore marked this conversation as resolved.
Show resolved Hide resolved
///
/// [`Sender::reserve_many()`]: Sender::reserve_many
/// [`Sender::try_reserve_many()`]: Sender::try_reserve_many
pub struct PermitIterator<'a, T> {
chan: &'a chan::Tx<T, Semaphore>,
n: usize,
}

/// Owned permit to send one value into the channel.
///
/// This is identical to the [`Permit`] type, except that it moves the sender
Expand Down Expand Up @@ -849,10 +861,68 @@ impl<T> Sender<T> {
/// }
/// ```
pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(Permit { chan: &self.chan })
}

/// Waits for channel capacity. Once capacity to send `n` messages is
/// available, it is reserved for the caller.
///
/// If the channel is full, the function waits for the number of unreceived
/// messages to become `n` less than the channel capacity. Capacity to send `n`
/// message is reserved for the caller. A [`PermitIterator`] is returned to track
/// the reserved capacity. You can call this [`Iterator`] until it is exhausted to
/// get a [`Permit`] and then call [`Permit::send`].
///
/// Dropping [`PermitIterator`] without sending all messages releases the capacity back
/// to the channel.
Totodore marked this conversation as resolved.
Show resolved Hide resolved
///
/// [`PermitIterator`]: PermitIterator
/// [`Permit`]: Permit
/// [`send`]: Permit::send
///
/// # Cancel safety
///
/// This channel uses a queue to ensure that calls to `send` and `reserve`
/// complete in the order they were requested. Cancelling a call to
/// `reserve` makes you lose your place in the queue.
Totodore marked this conversation as resolved.
Show resolved Hide resolved
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.reserve_many(2).await.unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The value sent on the permit is received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
/// }
/// ```
pub async fn reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, SendError<()>> {
Totodore marked this conversation as resolved.
Show resolved Hide resolved
self.reserve_inner(n).await?;
Ok(PermitIterator {
chan: &self.chan,
n,
})
}
Totodore marked this conversation as resolved.
Show resolved Hide resolved

/// Waits for channel capacity, moving the `Sender` and returning an owned
/// permit. Once capacity to send one message is available, it is reserved
/// for the caller.
Expand Down Expand Up @@ -934,16 +1004,16 @@ impl<T> Sender<T> {
/// [`send`]: OwnedPermit::send
/// [`Arc::clone`]: std::sync::Arc::clone
pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
self.reserve_inner().await?;
self.reserve_inner(1).await?;
Ok(OwnedPermit {
chan: Some(self.chan),
})
}

async fn reserve_inner(&self) -> Result<(), SendError<()>> {
async fn reserve_inner(&self, n: usize) -> Result<(), SendError<()>> {
crate::trace::async_trace_leaf().await;

match self.chan.semaphore().semaphore.acquire(1).await {
match self.chan.semaphore().semaphore.acquire(n).await {
Ok(()) => Ok(()),
Err(_) => Err(SendError(())),
}
Expand Down Expand Up @@ -1002,6 +1072,68 @@ impl<T> Sender<T> {
Ok(Permit { chan: &self.chan })
}

/// Tries to acquire `n` slot in the channel without waiting for the slot to become
/// available.
///
/// If the channel is full this function will return a [`TrySendError`], otherwise
/// A [`PermitIterator`] is returned to track the reserved capacity.
/// You can call this [`Iterator`] until it is exhausted to
/// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
/// [`reserve_many`] except it does not await for the slot to become available.
Totodore marked this conversation as resolved.
Show resolved Hide resolved
///
/// Dropping [`PermitIterator`] without sending a message releases the capacity back
/// to the channel.
///
/// [`PermitIterator`]: PermitIterator
/// [`send`]: Permit::send
/// [`reserve_many`]: Sender::reserve_many
///
/// # Examples
///
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, mut rx) = mpsc::channel(2);
///
/// // Reserve capacity
/// let mut permit = tx.try_reserve_many(2).unwrap();
///
/// // Trying to send directly on the `tx` will fail due to no
/// // available capacity.
/// assert!(tx.try_send(123).is_err());
///
/// // Trying to reserve an additional slot on the `tx` will
/// // fail because there is no capacity.
/// assert!(tx.try_reserve().is_err());
///
/// // Sending with the permit iterator succeeds
/// permit.next().unwrap().send(456);
/// permit.next().unwrap().send(457);
///
/// // The iterator should now be exhausted
/// assert!(permit.next().is_none());
///
/// // The value sent on the permit is received
/// assert_eq!(rx.recv().await.unwrap(), 456);
/// assert_eq!(rx.recv().await.unwrap(), 457);
///
/// }
/// ```
pub fn try_reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, TrySendError<()>> {
match self.chan.semaphore().semaphore.try_acquire(n as u32) {
Ok(()) => {}
Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())),
Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())),
}

Ok(PermitIterator {
chan: &self.chan,
n,
})
}

/// Tries to acquire a slot in the channel without waiting for the slot to become
/// available, returning an owned permit.
///
Expand Down Expand Up @@ -1278,6 +1410,53 @@ impl<T> fmt::Debug for Permit<'_, T> {
}
}

// ===== impl PermitIterator =====

impl<'a, T> Iterator for PermitIterator<'a, T> {
type Item = Permit<'a, T>;

fn next(&mut self) -> Option<Self::Item> {
if self.n == 0 {
return None;
}

self.n -= 1;
Some(Permit { chan: self.chan })
}

fn size_hint(&self) -> (usize, Option<usize>) {
let n = self.n;
(n, Some(n))
}
}
impl<T> ExactSizeIterator for PermitIterator<'_, T> {}
Totodore marked this conversation as resolved.
Show resolved Hide resolved

impl<T> Drop for PermitIterator<'_, T> {
fn drop(&mut self) {
use chan::Semaphore;

let semaphore = self.chan.semaphore();

// Add the remaining permits back to the semaphore
semaphore.add_permits(self.n);

// If this is the last sender for this channel, wake the receiver so
// that it can be notified that the channel is closed.
if semaphore.is_closed() && semaphore.is_idle() {
self.chan.wake_rx();
}
}
}
Totodore marked this conversation as resolved.
Show resolved Hide resolved

impl<T> fmt::Debug for PermitIterator<'_, T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("PermitIterator")
.field("chan", &self.chan)
.field("capacity", &self.n)
.finish()
}
}

// ===== impl Permit =====

impl<T> OwnedPermit<T> {
Expand Down
4 changes: 3 additions & 1 deletion tokio/src/sync/mpsc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@
pub(super) mod block;

mod bounded;
pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender};
pub use self::bounded::{
channel, OwnedPermit, Permit, PermitIterator, Receiver, Sender, WeakSender,
};

mod chan;

Expand Down
4 changes: 2 additions & 2 deletions tokio/src/sync/rwlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ impl<T: ?Sized> RwLock<T> {
/// ```
pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
let acquire_fut = async {
self.s.acquire(self.mr).await.unwrap_or_else(|_| {
self.s.acquire(self.mr as usize).await.unwrap_or_else(|_| {
// The semaphore was closed. but, we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen.
unreachable!()
Expand Down Expand Up @@ -907,7 +907,7 @@ impl<T: ?Sized> RwLock<T> {
let resource_span = self.resource_span.clone();

let acquire_fut = async {
self.s.acquire(self.mr).await.unwrap_or_else(|_| {
self.s.acquire(self.mr as usize).await.unwrap_or_else(|_| {
// The semaphore was closed. but, we never explicitly close it, and we have a
// handle to it through the Arc, which means that this can never happen.
unreachable!()
Expand Down
8 changes: 4 additions & 4 deletions tokio/src/sync/semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ impl Semaphore {
pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
trace::async_op(
|| self.ll_sem.acquire(n),
|| self.ll_sem.acquire(n as usize),
self.resource_span.clone(),
"Semaphore::acquire_many",
"poll",
Expand All @@ -574,7 +574,7 @@ impl Semaphore {
.await?;

#[cfg(not(all(tokio_unstable, feature = "tracing")))]
self.ll_sem.acquire(n).await?;
self.ll_sem.acquire(n as usize).await?;

Ok(SemaphorePermit {
sem: self,
Expand Down Expand Up @@ -764,14 +764,14 @@ impl Semaphore {
) -> Result<OwnedSemaphorePermit, AcquireError> {
#[cfg(all(tokio_unstable, feature = "tracing"))]
let inner = trace::async_op(
|| self.ll_sem.acquire(n),
|| self.ll_sem.acquire(n as usize),
self.resource_span.clone(),
"Semaphore::acquire_many_owned",
"poll",
true,
);
#[cfg(not(all(tokio_unstable, feature = "tracing")))]
let inner = self.ll_sem.acquire(n);
let inner = self.ll_sem.acquire(n as usize);

inner.await?;
Ok(OwnedSemaphorePermit {
Expand Down