diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs
index 35de9a57436..aa23dea7d3c 100644
--- a/tokio/src/sync/batch_semaphore.rs
+++ b/tokio/src/sync/batch_semaphore.rs
@@ -71,7 +71,7 @@ pub struct AcquireError(());
 pub(crate) struct Acquire<'a> {
     node: Waiter,
     semaphore: &'a Semaphore,
-    num_permits: u32,
+    num_permits: usize,
     queued: bool,
 }
 
@@ -262,13 +262,13 @@ impl Semaphore {
         self.permits.load(Acquire) & Self::CLOSED == Self::CLOSED
     }
 
-    pub(crate) fn try_acquire(&self, num_permits: u32) -> Result<(), TryAcquireError> {
+    pub(crate) fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
         assert!(
-            num_permits as usize <= Self::MAX_PERMITS,
+            num_permits <= Self::MAX_PERMITS,
             "a semaphore may not have more than MAX_PERMITS permits ({})",
             Self::MAX_PERMITS
         );
-        let num_permits = (num_permits as usize) << Self::PERMIT_SHIFT;
+        let num_permits = num_permits << Self::PERMIT_SHIFT;
         let mut curr = self.permits.load(Acquire);
         loop {
             // Has the semaphore closed?
@@ -293,7 +293,7 @@ impl Semaphore {
         }
     }
 
-    pub(crate) fn acquire(&self, num_permits: u32) -> Acquire<'_> {
+    pub(crate) fn acquire(&self, num_permits: usize) -> Acquire<'_> {
         Acquire::new(self, num_permits)
     }
 
@@ -371,7 +371,7 @@ impl Semaphore {
     fn poll_acquire(
         &self,
         cx: &mut Context<'_>,
-        num_permits: u32,
+        num_permits: usize,
         node: Pin<&mut Waiter>,
         queued: bool,
     ) -> Poll<Result<(), AcquireError>> {
@@ -380,7 +380,7 @@ impl Semaphore {
         let needed = if queued {
             node.state.load(Acquire) << Self::PERMIT_SHIFT
         } else {
-            (num_permits as usize) << Self::PERMIT_SHIFT
+            num_permits << Self::PERMIT_SHIFT
         };
 
         let mut lock = None;
@@ -506,12 +506,12 @@ impl fmt::Debug for Semaphore {
 impl Waiter {
     fn new(
-        num_permits: u32,
+        num_permits: usize,
         #[cfg(all(tokio_unstable, feature = "tracing"))] ctx: trace::AsyncOpTracingCtx,
     ) -> Self {
         Waiter {
             waker: UnsafeCell::new(None),
-            state: AtomicUsize::new(num_permits as usize),
+            state: AtomicUsize::new(num_permits),
             pointers: linked_list::Pointers::new(),
             #[cfg(all(tokio_unstable, feature = "tracing"))]
             ctx,
@@ -591,7 +591,7 @@ impl Future for Acquire<'_> {
 }
 
 impl<'a> Acquire<'a> {
-    fn new(semaphore: &'a Semaphore, num_permits: u32) -> Self {
+    fn new(semaphore: &'a Semaphore, num_permits: usize) -> Self {
         #[cfg(any(not(tokio_unstable), not(feature = "tracing")))]
         return Self {
             node: Waiter::new(num_permits),
@@ -635,14 +635,14 @@ impl<'a> Acquire<'a> {
         });
     }
 
-    fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, u32, &mut bool) {
+    fn project(self: Pin<&mut Self>) -> (Pin<&mut Waiter>, &Semaphore, usize, &mut bool) {
         fn is_unpin<T: Unpin>() {}
 
         unsafe {
             // Safety: all fields other than `node` are `Unpin`
             is_unpin::<&Semaphore>();
             is_unpin::<&mut bool>();
-            is_unpin::<u32>();
+            is_unpin::<usize>();
 
             let this = self.get_unchecked_mut();
             (
@@ -673,7 +673,7 @@ impl Drop for Acquire<'_> {
             // Safety: we have locked the wait list.
             unsafe { waiters.queue.remove(node) };
 
-            let acquired_permits = self.num_permits as usize - self.node.state.load(Acquire);
+            let acquired_permits = self.num_permits - self.node.state.load(Acquire);
             if acquired_permits > 0 {
                 self.semaphore.add_permits_locked(acquired_permits, waiters);
             }
diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs
index 3a795d55774..a1e0a82d9e2 100644
--- a/tokio/src/sync/mpsc/bounded.rs
+++ b/tokio/src/sync/mpsc/bounded.rs
@@ -68,6 +68,18 @@ pub struct Permit<'a, T> {
     chan: &'a chan::Tx<T>,
 }
 
+/// An [`Iterator`] of [`Permit`] that can be used to hold `n` slots in the channel.
+///
+/// `PermitIterator` values are returned by [`Sender::reserve_many()`] and [`Sender::try_reserve_many()`]
+/// and are used to guarantee channel capacity before generating `n` messages to send.
+///
+/// [`Sender::reserve_many()`]: Sender::reserve_many
+/// [`Sender::try_reserve_many()`]: Sender::try_reserve_many
+pub struct PermitIterator<'a, T> {
+    chan: &'a chan::Tx<T>,
+    n: usize,
+}
+
 /// Owned permit to send one value into the channel.
 ///
 /// This is identical to the [`Permit`] type, except that it moves the sender
@@ -926,10 +938,74 @@ impl<T> Sender<T> {
     /// }
     /// ```
     pub async fn reserve(&self) -> Result<Permit<'_, T>, SendError<()>> {
-        self.reserve_inner().await?;
+        self.reserve_inner(1).await?;
         Ok(Permit { chan: &self.chan })
     }
 
+    /// Waits for channel capacity. Once capacity to send `n` messages is
+    /// available, it is reserved for the caller.
+    ///
+    /// If the channel is full or if there are fewer than `n` permits available, the function waits
+    /// for the number of unreceived messages to become `n` less than the channel capacity.
+    /// Capacity to send `n` messages is then reserved for the caller.
+    ///
+    /// A [`PermitIterator`] is returned to track the reserved capacity.
+    /// You can consume this [`Iterator`] until it is exhausted to
+    /// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
+    /// [`try_reserve_many`] except it waits for the slots to become available.
+    ///
+    /// If the channel is closed, the function returns a [`SendError`].
+    ///
+    /// Dropping the [`PermitIterator`] without consuming it entirely releases the remaining
+    /// permits back to the channel.
+    ///
+    /// [`PermitIterator`]: PermitIterator
+    /// [`Permit`]: Permit
+    /// [`send`]: Permit::send
+    /// [`try_reserve_many`]: Sender::try_reserve_many
+    ///
+    /// # Cancel safety
+    ///
+    /// This channel uses a queue to ensure that calls to `send` and `reserve_many`
+    /// complete in the order they were requested. Cancelling a call to
+    /// `reserve_many` makes you lose your place in the queue.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use tokio::sync::mpsc;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let (tx, mut rx) = mpsc::channel(2);
+    ///
+    ///     // Reserve capacity
+    ///     let mut permit = tx.reserve_many(2).await.unwrap();
+    ///
+    ///     // Trying to send directly on the `tx` will fail due to no
+    ///     // available capacity.
+    ///     assert!(tx.try_send(123).is_err());
+    ///
+    ///     // Sending with the permit iterator succeeds
+    ///     permit.next().unwrap().send(456);
+    ///     permit.next().unwrap().send(457);
+    ///
+    ///     // The iterator should now be exhausted
+    ///     assert!(permit.next().is_none());
+    ///
+    ///     // The values sent on the permits are received
+    ///     assert_eq!(rx.recv().await.unwrap(), 456);
+    ///     assert_eq!(rx.recv().await.unwrap(), 457);
+    /// }
+    /// ```
+    pub async fn reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, SendError<()>> {
+        self.reserve_inner(n).await?;
+        Ok(PermitIterator {
+            chan: &self.chan,
+            n,
+        })
+    }
+
     /// Waits for channel capacity, moving the `Sender` and returning an owned
     /// permit. Once capacity to send one message is available, it is reserved
     /// for the caller.
@@ -1011,16 +1087,19 @@ impl<T> Sender<T> {
     /// [`send`]: OwnedPermit::send
     /// [`Arc::clone`]: std::sync::Arc::clone
     pub async fn reserve_owned(self) -> Result<OwnedPermit<T>, SendError<()>> {
-        self.reserve_inner().await?;
+        self.reserve_inner(1).await?;
         Ok(OwnedPermit {
             chan: Some(self.chan),
         })
     }
 
-    async fn reserve_inner(&self) -> Result<(), SendError<()>> {
+    async fn reserve_inner(&self, n: usize) -> Result<(), SendError<()>> {
         crate::trace::async_trace_leaf().await;
 
-        match self.chan.semaphore().semaphore.acquire(1).await {
+        if n > self.max_capacity() {
+            return Err(SendError(()));
+        }
+        match self.chan.semaphore().semaphore.acquire(n).await {
             Ok(()) => Ok(()),
             Err(_) => Err(SendError(())),
         }
@@ -1079,6 +1158,91 @@ impl<T> Sender<T> {
         Ok(Permit { chan: &self.chan })
     }
 
+    /// Tries to acquire `n` slots in the channel without waiting for slots to become
+    /// available.
+    ///
+    /// A [`PermitIterator`] is returned to track the reserved capacity.
+    /// You can consume this [`Iterator`] until it is exhausted to
+    /// get a [`Permit`] and then call [`Permit::send`]. This function is similar to
+    /// [`reserve_many`] except it does not wait for the slots to become available.
+    ///
+    /// If there are fewer than `n` permits available on the channel, then
+    /// this function will return a [`TrySendError::Full`]. If the channel is closed,
+    /// this function will return a [`TrySendError::Closed`].
+    ///
+    /// Dropping the [`PermitIterator`] without consuming it entirely releases the remaining
+    /// permits back to the channel.
+    ///
+    /// [`PermitIterator`]: PermitIterator
+    /// [`send`]: Permit::send
+    /// [`reserve_many`]: Sender::reserve_many
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use tokio::sync::mpsc;
+    ///
+    /// #[tokio::main]
+    /// async fn main() {
+    ///     let (tx, mut rx) = mpsc::channel(2);
+    ///
+    ///     // Reserve capacity
+    ///     let mut permit = tx.try_reserve_many(2).unwrap();
+    ///
+    ///     // Trying to send directly on the `tx` will fail due to no
+    ///     // available capacity.
+    ///     assert!(tx.try_send(123).is_err());
+    ///
+    ///     // Trying to reserve an additional slot on the `tx` will
+    ///     // fail because there is no capacity.
+    ///     assert!(tx.try_reserve().is_err());
+    ///
+    ///     // Sending with the permit iterator succeeds
+    ///     permit.next().unwrap().send(456);
+    ///     permit.next().unwrap().send(457);
+    ///
+    ///     // The iterator should now be exhausted
+    ///     assert!(permit.next().is_none());
+    ///
+    ///     // The values sent on the permits are received
+    ///     assert_eq!(rx.recv().await.unwrap(), 456);
+    ///     assert_eq!(rx.recv().await.unwrap(), 457);
+    ///
+    ///     // Trying to call try_reserve_many with 0 will return an empty iterator
+    ///     let mut permit = tx.try_reserve_many(0).unwrap();
+    ///     assert!(permit.next().is_none());
+    ///
+    ///     // Trying to call try_reserve_many with a number greater than the channel
+    ///     // capacity will return an error
+    ///     let permit = tx.try_reserve_many(3);
+    ///     assert!(permit.is_err());
+    ///
+    ///     // Trying to call try_reserve_many on a closed channel will return an error
+    ///     drop(rx);
+    ///     let permit = tx.try_reserve_many(1);
+    ///     assert!(permit.is_err());
+    ///
+    ///     let permit = tx.try_reserve_many(0);
+    ///     assert!(permit.is_err());
+    /// }
+    /// ```
+    pub fn try_reserve_many(&self, n: usize) -> Result<PermitIterator<'_, T>, TrySendError<()>> {
+        if n > self.max_capacity() {
+            return Err(TrySendError::Full(()));
+        }
+
+        match self.chan.semaphore().semaphore.try_acquire(n) {
+            Ok(()) => {}
+            Err(TryAcquireError::Closed) => return Err(TrySendError::Closed(())),
+            Err(TryAcquireError::NoPermits) => return Err(TrySendError::Full(())),
+        }
+
+        Ok(PermitIterator {
+            chan: &self.chan,
+            n,
+        })
+    }
+
     /// Tries to acquire a slot in the channel without waiting for the slot to become
     /// available, returning an owned permit.
     ///
@@ -1355,6 +1519,58 @@ impl<T> fmt::Debug for Permit<'_, T> {
     }
 }
 
+// ===== impl PermitIterator =====
+
+impl<'a, T> Iterator for PermitIterator<'a, T> {
+    type Item = Permit<'a, T>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.n == 0 {
+            return None;
+        }
+
+        self.n -= 1;
+        Some(Permit { chan: self.chan })
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let n = self.n;
+        (n, Some(n))
+    }
+}
+impl<T> ExactSizeIterator for PermitIterator<'_, T> {}
+impl<T> std::iter::FusedIterator for PermitIterator<'_, T> {}
+
+impl<T> Drop for PermitIterator<'_, T> {
+    fn drop(&mut self) {
+        use chan::Semaphore;
+
+        if self.n == 0 {
+            return;
+        }
+
+        let semaphore = self.chan.semaphore();
+
+        // Add the remaining permits back to the semaphore
+        semaphore.add_permits(self.n);
+
+        // If this is the last sender for this channel, wake the receiver so
+        // that it can be notified that the channel is closed.
+        if semaphore.is_closed() && semaphore.is_idle() {
+            self.chan.wake_rx();
+        }
+    }
+}
+
+impl<T> fmt::Debug for PermitIterator<'_, T> {
+    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt.debug_struct("PermitIterator")
+            .field("chan", &self.chan)
+            .field("capacity", &self.n)
+            .finish()
+    }
+}
+
 // ===== impl Permit =====
 
 impl<T> OwnedPermit<T> {
diff --git a/tokio/src/sync/mpsc/mod.rs b/tokio/src/sync/mpsc/mod.rs
index b2af084b2ae..052620be1a9 100644
--- a/tokio/src/sync/mpsc/mod.rs
+++ b/tokio/src/sync/mpsc/mod.rs
@@ -95,7 +95,9 @@ pub(super) mod block;
 
 mod bounded;
-pub use self::bounded::{channel, OwnedPermit, Permit, Receiver, Sender, WeakSender};
+pub use self::bounded::{
+    channel, OwnedPermit, Permit, PermitIterator, Receiver, Sender, WeakSender,
+};
 
 mod chan;
diff --git a/tokio/src/sync/rwlock.rs b/tokio/src/sync/rwlock.rs
index 877458a57fb..37cf73c5905 100644
--- a/tokio/src/sync/rwlock.rs
+++ b/tokio/src/sync/rwlock.rs
@@ -772,7 +772,7 @@ impl<T: ?Sized> RwLock<T> {
     /// ```
     pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
         let acquire_fut = async {
-            self.s.acquire(self.mr).await.unwrap_or_else(|_| {
+            self.s.acquire(self.mr as usize).await.unwrap_or_else(|_| {
                 // The semaphore was closed. but, we never explicitly close it, and we have a
                 // handle to it through the Arc, which means that this can never happen.
                 unreachable!()
@@ -907,7 +907,7 @@ impl<T: ?Sized> RwLock<T> {
         let resource_span = self.resource_span.clone();
         let acquire_fut = async {
-            self.s.acquire(self.mr).await.unwrap_or_else(|_| {
+            self.s.acquire(self.mr as usize).await.unwrap_or_else(|_| {
                 // The semaphore was closed. but, we never explicitly close it, and we have a
                 // handle to it through the Arc, which means that this can never happen.
                 unreachable!()
@@ -971,7 +971,7 @@ impl<T: ?Sized> RwLock<T> {
     /// }
     /// ```
     pub fn try_write(&self) -> Result<RwLockWriteGuard<'_, T>, TryLockError> {
-        match self.s.try_acquire(self.mr) {
+        match self.s.try_acquire(self.mr as usize) {
             Ok(permit) => permit,
             Err(TryAcquireError::NoPermits) => return Err(TryLockError(())),
             Err(TryAcquireError::Closed) => unreachable!(),
@@ -1029,7 +1029,7 @@ impl<T: ?Sized> RwLock<T> {
     /// }
     /// ```
     pub fn try_write_owned(self: Arc<Self>) -> Result<OwnedRwLockWriteGuard<T>, TryLockError> {
-        match self.s.try_acquire(self.mr) {
+        match self.s.try_acquire(self.mr as usize) {
             Ok(permit) => permit,
             Err(TryAcquireError::NoPermits) => return Err(TryLockError(())),
             Err(TryAcquireError::Closed) => unreachable!(),
diff --git a/tokio/src/sync/semaphore.rs b/tokio/src/sync/semaphore.rs
index 8b8fdb23871..25e4134373c 100644
--- a/tokio/src/sync/semaphore.rs
+++ b/tokio/src/sync/semaphore.rs
@@ -565,7 +565,7 @@ impl Semaphore {
     pub async fn acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, AcquireError> {
         #[cfg(all(tokio_unstable, feature = "tracing"))]
         trace::async_op(
-            || self.ll_sem.acquire(n),
+            || self.ll_sem.acquire(n as usize),
             self.resource_span.clone(),
             "Semaphore::acquire_many",
             "poll",
             true,
         )
         .await?;
 
         #[cfg(not(all(tokio_unstable, feature = "tracing")))]
-        self.ll_sem.acquire(n).await?;
+        self.ll_sem.acquire(n as usize).await?;
 
         Ok(SemaphorePermit {
             sem: self,
@@ -646,7 +646,7 @@ impl Semaphore {
     /// [`TryAcquireError::NoPermits`]: crate::sync::TryAcquireError::NoPermits
     /// [`SemaphorePermit`]: crate::sync::SemaphorePermit
    pub fn try_acquire_many(&self, n: u32) -> Result<SemaphorePermit<'_>, TryAcquireError> {
-        match self.ll_sem.try_acquire(n) {
+        match self.ll_sem.try_acquire(n as usize) {
             Ok(()) => Ok(SemaphorePermit {
                 sem: self,
                 permits: n,
@@ -764,14 +764,14 @@ impl Semaphore {
     ) -> Result<OwnedSemaphorePermit, AcquireError> {
         #[cfg(all(tokio_unstable, feature = "tracing"))]
         let inner = trace::async_op(
-            || self.ll_sem.acquire(n),
+            || self.ll_sem.acquire(n as usize),
             self.resource_span.clone(),
             "Semaphore::acquire_many_owned",
             "poll",
             true,
         );
         #[cfg(not(all(tokio_unstable, feature = "tracing")))]
-        let inner = self.ll_sem.acquire(n);
+        let inner = self.ll_sem.acquire(n as usize);
 
         inner.await?;
         Ok(OwnedSemaphorePermit {
@@ -855,7 +855,7 @@ impl Semaphore {
         self: Arc<Self>,
         n: u32,
     ) -> Result<OwnedSemaphorePermit, TryAcquireError> {
-        match self.ll_sem.try_acquire(n) {
+        match self.ll_sem.try_acquire(n as usize) {
             Ok(()) => Ok(OwnedSemaphorePermit {
                 sem: self,
                 permits: n,
diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs
index a5c15a4cfc6..1b581ce98c1 100644
--- a/tokio/tests/sync_mpsc.rs
+++ b/tokio/tests/sync_mpsc.rs
@@ -522,6 +522,79 @@ async fn try_send_fail_with_try_recv() {
     assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
 }
 
+#[maybe_tokio_test]
+async fn reserve_many_above_cap() {
+    const MAX_PERMITS: usize = tokio::sync::Semaphore::MAX_PERMITS;
+    let (tx, _rx) = mpsc::channel::<()>(1);
+
+    assert_err!(tx.reserve_many(2).await);
+    assert_err!(tx.reserve_many(MAX_PERMITS + 1).await);
+    assert_err!(tx.reserve_many(usize::MAX).await);
+}
+
+#[test]
+fn try_reserve_many_zero() {
+    let (tx, rx) = mpsc::channel::<()>(1);
+
+    // Succeeds when not closed.
+    assert!(assert_ok!(tx.try_reserve_many(0)).next().is_none());
+
+    // Even when channel is full.
+    tx.try_send(()).unwrap();
+    assert!(assert_ok!(tx.try_reserve_many(0)).next().is_none());
+
+    drop(rx);
+
+    // Closed error when closed.
+    assert_eq!(
+        assert_err!(tx.try_reserve_many(0)),
+        TrySendError::Closed(())
+    );
+}
+
+#[maybe_tokio_test]
+async fn reserve_many_zero() {
+    let (tx, rx) = mpsc::channel::<()>(1);
+
+    // Succeeds when not closed.
+    assert!(assert_ok!(tx.reserve_many(0).await).next().is_none());
+
+    // Even when channel is full.
+    tx.send(()).await.unwrap();
+    assert!(assert_ok!(tx.reserve_many(0).await).next().is_none());
+
+    drop(rx);
+
+    // Closed error when closed.
+    assert_err!(tx.reserve_many(0).await);
+}
+
+#[maybe_tokio_test]
+async fn try_reserve_many_edge_cases() {
+    const MAX_PERMITS: usize = tokio::sync::Semaphore::MAX_PERMITS;
+
+    let (tx, rx) = mpsc::channel::<()>(1);
+
+    let mut permit = assert_ok!(tx.try_reserve_many(0));
+    assert!(permit.next().is_none());
+
+    let permit = tx.try_reserve_many(MAX_PERMITS + 1);
+    match assert_err!(permit) {
+        TrySendError::Full(..) => {}
+        _ => panic!(),
+    }
+
+    let permit = tx.try_reserve_many(usize::MAX);
+    match assert_err!(permit) {
+        TrySendError::Full(..) => {}
+        _ => panic!(),
+    }
+
+    // Dropping the receiver should close the channel
+    drop(rx);
+    assert_err!(tx.reserve_many(0).await);
+}
+
 #[maybe_tokio_test]
 async fn try_reserve_fails() {
     let (tx, mut rx) = mpsc::channel(1);
@@ -545,6 +618,87 @@ async fn try_reserve_fails() {
     let _permit = tx.try_reserve().unwrap();
 }
 
+#[maybe_tokio_test]
+async fn reserve_many_and_send() {
+    let (tx, mut rx) = mpsc::channel(100);
+    for i in 0..100 {
+        for permit in assert_ok!(tx.reserve_many(i).await) {
+            permit.send("foo");
+            assert_eq!(rx.recv().await, Some("foo"));
+        }
+        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
+    }
+}
+
+#[maybe_tokio_test]
+async fn try_reserve_many_and_send() {
+    let (tx, mut rx) = mpsc::channel(100);
+    for i in 0..100 {
+        for permit in assert_ok!(tx.try_reserve_many(i)) {
+            permit.send("foo");
+            assert_eq!(rx.recv().await, Some("foo"));
+        }
+        assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
+    }
+}
+
+#[maybe_tokio_test]
+async fn reserve_many_on_closed_channel() {
+    let (tx, rx) = mpsc::channel::<()>(100);
+    drop(rx);
+    assert_err!(tx.reserve_many(10).await);
+}
+
+#[maybe_tokio_test]
+async fn try_reserve_many_on_closed_channel() {
+    let (tx, rx) = mpsc::channel::<usize>(100);
+    drop(rx);
+    match assert_err!(tx.try_reserve_many(10)) {
+        TrySendError::Closed(()) => {}
+        _ => panic!(),
+    };
+}
+
+#[maybe_tokio_test]
+async fn try_reserve_many_full() {
+    // Reserve n capacity and send k messages
+    for n in 1..100 {
+        for k in 0..n {
+            let (tx, mut rx) = mpsc::channel::<usize>(n);
+            let permits = assert_ok!(tx.try_reserve_many(n));
+
+            assert_eq!(permits.len(), n);
+            assert_eq!(tx.capacity(), 0);
+
+            match assert_err!(tx.try_reserve_many(1)) {
+                TrySendError::Full(..) => {}
+                _ => panic!(),
+            };
+
+            for permit in permits.take(k) {
+                permit.send(0);
+            }
+            // We only used k permits on the n reserved
+            assert_eq!(tx.capacity(), n - k);
+
+            // We can reserve more permits
+            assert_ok!(tx.try_reserve_many(1));
+
+            // But not more than the current capacity
+            match assert_err!(tx.try_reserve_many(n - k + 1)) {
+                TrySendError::Full(..) => {}
+                _ => panic!(),
+            };
+
+            for _i in 0..k {
+                assert_eq!(rx.recv().await, Some(0));
+            }
+
+            // Now that we've received everything, capacity should be back to n
+            assert_eq!(tx.capacity(), n);
+        }
+    }
+}
+
 #[tokio::test]
 #[cfg(feature = "full")]
 async fn drop_permit_releases_permit() {
@@ -564,6 +718,30 @@ async fn drop_permit_releases_permit() {
     assert_ready_ok!(reserve2.poll());
 }
 
+#[maybe_tokio_test]
+async fn drop_permit_iterator_releases_permits() {
+    // reserve_many reserves capacity; ensure that the capacity is released
+    // if the permit iterator is dropped w/o sending a value.
+    for n in 1..100 {
+        let (tx1, _rx) = mpsc::channel::<i32>(n);
+        let tx2 = tx1.clone();
+
+        let permits = assert_ok!(tx1.reserve_many(n).await);
+
+        let mut reserve2 = tokio_test::task::spawn(tx2.reserve_many(n));
+        assert_pending!(reserve2.poll());
+
+        drop(permits);
+
+        assert!(reserve2.is_woken());
+
+        let permits = assert_ready_ok!(reserve2.poll());
+        drop(permits);
+
+        assert_eq!(tx1.capacity(), n);
+    }
+}
+
 #[maybe_tokio_test]
 async fn dropping_rx_closes_channel() {
     let (tx, rx) = mpsc::channel(100);
@@ -573,6 +751,7 @@ async fn dropping_rx_closes_channel() {
     drop(rx);
 
     assert_err!(tx.reserve().await);
+    assert_err!(tx.reserve_many(10).await);
 
     assert_eq!(1, Arc::strong_count(&msg));
 }
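A minimal end-to-end sketch of the API this diff introduces (illustrative only, not part of the patch). It assumes the `Sender::reserve_many`/`Sender::try_reserve_many` signatures and the `PermitIterator` drop semantics shown above, together with the existing `Sender::capacity` accessor:

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::channel::<i32>(4);

    // Hold two of the four slots; each yielded `Permit` sends exactly one message.
    let permits = tx.reserve_many(2).await.unwrap();
    for (i, permit) in permits.enumerate() {
        permit.send(i as i32);
    }
    assert_eq!(rx.recv().await, Some(0));
    assert_eq!(rx.recv().await, Some(1));

    // Dropping an unconsumed (or partially consumed) iterator returns
    // the remaining permits: capacity goes back to 4.
    let unused = tx.try_reserve_many(3).unwrap();
    assert_eq!(tx.capacity(), 1);
    drop(unused);
    assert_eq!(tx.capacity(), 4);

    // Requests larger than the channel's capacity fail up front,
    // per the `n > self.max_capacity()` check in `try_reserve_many`.
    assert!(tx.try_reserve_many(5).is_err());
}
```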