Skip to content

Commit

Permalink
sync: add is_closed, is_empty, and len to mpsc receivers (#6348)
Browse files Browse the repository at this point in the history
Fixes: #4638
  • Loading branch information
balliegojr committed Mar 24, 2024
1 parent 8342e4b commit 3ce4720
Show file tree
Hide file tree
Showing 8 changed files with 661 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tokio/src/sync/mpsc/block.rs
Expand Up @@ -168,6 +168,19 @@ impl<T> Block<T> {
Some(Read::Value(value.assume_init()))
}

/// Returns true if there is a value in the slot to be consumed
///
/// # Safety
///
/// To maintain safety, the caller must ensure:
///
/// * No concurrent access to the slot.
pub(crate) fn has_value(&self, slot_index: usize) -> bool {
let offset = offset(slot_index);
let ready_bits = self.header.ready_slots.load(Acquire);
is_ready(ready_bits, offset)
}

/// Writes a value to the block at the given offset.
///
/// # Safety
Expand Down Expand Up @@ -195,6 +208,11 @@ impl<T> Block<T> {
self.header.ready_slots.fetch_or(TX_CLOSED, Release);
}

pub(crate) unsafe fn is_closed(&self) -> bool {
let ready_bits = self.header.ready_slots.load(Acquire);
is_tx_closed(ready_bits)
}

/// Resets the block to a blank state. This enables reusing blocks in the
/// channel.
///
Expand Down
67 changes: 67 additions & 0 deletions tokio/src/sync/mpsc/bounded.rs
Expand Up @@ -463,6 +463,73 @@ impl<T> Receiver<T> {
self.chan.close();
}

/// Checks if a channel is closed.
///
/// This method returns `true` if the channel has been closed. The channel is closed
/// when all [`Sender`] have been dropped, or when [`Receiver::close`] is called.
///
/// [`Sender`]: crate::sync::mpsc::Sender
/// [`Receiver::close`]: crate::sync::mpsc::Receiver::close
///
/// # Examples
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (_tx, mut rx) = mpsc::channel::<()>(10);
/// assert!(!rx.is_closed());
///
/// rx.close();
///
/// assert!(rx.is_closed());
/// }
/// ```
pub fn is_closed(&self) -> bool {
self.chan.is_closed()
}

/// Checks if a channel is empty.
///
/// This method returns `true` if the channel has no messages.
///
/// # Examples
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = mpsc::channel(10);
/// assert!(rx.is_empty());
///
/// tx.send(0).await.unwrap();
/// assert!(!rx.is_empty());
/// }
///
/// ```
pub fn is_empty(&self) -> bool {
self.chan.is_empty()
}

/// Returns the number of messages in the channel.
///
/// # Examples
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = mpsc::channel(10);
/// assert_eq!(0, rx.len());
///
/// tx.send(0).await.unwrap();
/// assert_eq!(1, rx.len());
/// }
/// ```
pub fn len(&self) -> usize {
self.chan.len()
}

/// Polls to receive the next message on this channel.
///
/// This method returns:
Expand Down
27 changes: 27 additions & 0 deletions tokio/src/sync/mpsc/chan.rs
Expand Up @@ -255,6 +255,33 @@ impl<T, S: Semaphore> Rx<T, S> {
self.inner.notify_rx_closed.notify_waiters();
}

pub(crate) fn is_closed(&self) -> bool {
// There two internal states that can represent a closed channel
//
// 1. When `close` is called.
// In this case, the inner semaphore will be closed.
//
// 2. When all senders are dropped.
// In this case, the semaphore remains unclosed, and the `index` in the list won't
// reach the tail position. It is necessary to check the list if the last block is
// `closed`.
self.inner.semaphore.is_closed() || self.inner.tx_count.load(Acquire) == 0
}

pub(crate) fn is_empty(&self) -> bool {
self.inner.rx_fields.with(|rx_fields_ptr| {
let rx_fields = unsafe { &*rx_fields_ptr };
rx_fields.list.is_empty(&self.inner.tx)
})
}

pub(crate) fn len(&self) -> usize {
self.inner.rx_fields.with(|rx_fields_ptr| {
let rx_fields = unsafe { &*rx_fields_ptr };
rx_fields.list.len(&self.inner.tx)
})
}

/// Receive the next value
pub(crate) fn recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
use super::block::Read;
Expand Down
27 changes: 27 additions & 0 deletions tokio/src/sync/mpsc/list.rs
Expand Up @@ -218,6 +218,15 @@ impl<T> Tx<T> {
let _ = Box::from_raw(block.as_ptr());
}
}

pub(crate) fn is_closed(&self) -> bool {
let tail = self.block_tail.load(Acquire);

unsafe {
let tail_block = &*tail;
tail_block.is_closed()
}
}
}

impl<T> fmt::Debug for Tx<T> {
Expand All @@ -230,6 +239,24 @@ impl<T> fmt::Debug for Tx<T> {
}

impl<T> Rx<T> {
pub(crate) fn is_empty(&self, tx: &Tx<T>) -> bool {
let block = unsafe { self.head.as_ref() };
if block.has_value(self.index) {
return false;
}

// It is possible that a block has no value "now" but the list is still not empty.
// To be sure, it is necessary to check the length of the list.
self.len(tx) == 0
}

pub(crate) fn len(&self, tx: &Tx<T>) -> usize {
// When all the senders are dropped, there will be a last block in the tail position,
// but it will be closed
let tail_position = tx.tail_position.load(Acquire);
tail_position - self.index - (tx.is_closed() as usize)
}

/// Pops the next value off the queue.
pub(crate) fn pop(&mut self, tx: &Tx<T>) -> Option<block::Read<T>> {
// Advance `head`, if needed
Expand Down
67 changes: 67 additions & 0 deletions tokio/src/sync/mpsc/unbounded.rs
Expand Up @@ -330,6 +330,73 @@ impl<T> UnboundedReceiver<T> {
self.chan.close();
}

/// Checks if a channel is closed.
///
/// This method returns `true` if the channel has been closed. The channel is closed
/// when all [`UnboundedSender`] have been dropped, or when [`UnboundedReceiver::close`] is called.
///
/// [`UnboundedSender`]: crate::sync::mpsc::UnboundedSender
/// [`UnboundedReceiver::close`]: crate::sync::mpsc::UnboundedReceiver::close
///
/// # Examples
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (_tx, mut rx) = mpsc::unbounded_channel::<()>();
/// assert!(!rx.is_closed());
///
/// rx.close();
///
/// assert!(rx.is_closed());
/// }
/// ```
pub fn is_closed(&self) -> bool {
self.chan.is_closed()
}

/// Checks if a channel is empty.
///
/// This method returns `true` if the channel has no messages.
///
/// # Examples
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = mpsc::unbounded_channel();
/// assert!(rx.is_empty());
///
/// tx.send(0).unwrap();
/// assert!(!rx.is_empty());
/// }
///
/// ```
pub fn is_empty(&self) -> bool {
self.chan.is_empty()
}

/// Returns the number of messages in the channel.
///
/// # Examples
/// ```
/// use tokio::sync::mpsc;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = mpsc::unbounded_channel();
/// assert_eq!(0, rx.len());
///
/// tx.send(0).unwrap();
/// assert_eq!(1, rx.len());
/// }
/// ```
pub fn len(&self) -> usize {
self.chan.len()
}

/// Polls to receive the next message on this channel.
///
/// This method returns:
Expand Down
34 changes: 34 additions & 0 deletions tokio/src/sync/tests/loom_mpsc.rs
Expand Up @@ -188,3 +188,37 @@ fn try_recv() {
}
});
}

#[test]
fn len_nonzero_after_send() {
loom::model(|| {
let (send, recv) = mpsc::channel(10);
let send2 = send.clone();

let join = thread::spawn(move || {
block_on(send2.send("message2")).unwrap();
});

block_on(send.send("message1")).unwrap();
assert!(recv.len() != 0);

join.join().unwrap();
});
}

#[test]
fn nonempty_after_send() {
loom::model(|| {
let (send, recv) = mpsc::channel(10);
let send2 = send.clone();

let join = thread::spawn(move || {
block_on(send2.send("message2")).unwrap();
});

block_on(send.send("message1")).unwrap();
assert!(!recv.is_empty());

join.join().unwrap();
});
}

0 comments on commit 3ce4720

Please sign in to comment.