diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 4b36452cec3..b168be7ee94 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -444,42 +444,13 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// This will panic if `capacity` is equal to `0` or larger /// than `usize::MAX / 2`. #[track_caller] -pub fn channel(mut capacity: usize) -> (Sender, Receiver) { - assert!(capacity > 0, "capacity is empty"); - assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); - - // Round to a power of two - capacity = capacity.next_power_of_two(); - - let mut buffer = Vec::with_capacity(capacity); - - for i in 0..capacity { - buffer.push(RwLock::new(Slot { - rem: AtomicUsize::new(0), - pos: (i as u64).wrapping_sub(capacity as u64), - val: UnsafeCell::new(None), - })); - } - - let shared = Arc::new(Shared { - buffer: buffer.into_boxed_slice(), - mask: capacity - 1, - tail: Mutex::new(Tail { - pos: 0, - rx_cnt: 1, - closed: false, - waiters: LinkedList::new(), - }), - num_tx: AtomicUsize::new(1), - }); - +pub fn channel(capacity: usize) -> (Sender, Receiver) { + // SAFETY: In the line below we are creating one extra receiver, so there will be 1 in total. + let tx = unsafe { Sender::new_with_receiver_count(1, capacity) }; let rx = Receiver { - shared: shared.clone(), + shared: tx.shared.clone(), next: 0, }; - - let tx = Sender { shared }; - (tx, rx) } @@ -490,6 +461,65 @@ unsafe impl Send for Receiver {} unsafe impl Sync for Receiver {} impl Sender { + /// Creates the sending-half of the [`broadcast`] channel. + /// + /// See documentation of [`broadcast::channel`] for errors when calling this function. + /// + /// [`broadcast`]: crate::sync::broadcast + /// [`broadcast::channel`]: crate::sync::broadcast + #[track_caller] + pub fn new(capacity: usize) -> Self { + // SAFETY: We don't create extra receivers, so there are 0. + unsafe { Self::new_with_receiver_count(0, capacity) } + } + + /// Creates the sending-half of the [`broadcast`](self) channel, and provide the receiver + /// count. + /// + /// See the documentation of [`broadcast::channel`](self::channel) for more errors when + /// calling this function. + /// + /// # Safety: + /// + /// The caller must ensure that the amount of receivers for this Sender is correct before + /// the channel functionalities are used, the count is zero by default, as this function + /// does not create any receivers by itself. + #[track_caller] + unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self { + assert!(capacity > 0, "broadcast channel capacity cannot be zero"); + assert!( + capacity <= usize::MAX >> 1, + "broadcast channel capacity exceeded `usize::MAX / 2`" + ); + + // Round to a power of two + capacity = capacity.next_power_of_two(); + + let mut buffer = Vec::with_capacity(capacity); + + for i in 0..capacity { + buffer.push(RwLock::new(Slot { + rem: AtomicUsize::new(0), + pos: (i as u64).wrapping_sub(capacity as u64), + val: UnsafeCell::new(None), + })); + } + + let shared = Arc::new(Shared { + buffer: buffer.into_boxed_slice(), + mask: capacity - 1, + tail: Mutex::new(Tail { + pos: 0, + rx_cnt: receiver_count, + closed: false, + waiters: LinkedList::new(), + }), + num_tx: AtomicUsize::new(1), + }); + + Sender { shared } + } + /// Attempts to send a value to all active [`Receiver`] handles, returning /// it back if it could not be sent. /// @@ -1369,3 +1399,41 @@ impl<'a, T> Drop for RecvGuard<'a, T> { } fn is_unpin() {} + +#[cfg(not(loom))] +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn receiver_count_on_sender_constructor() { + let sender = Sender::::new(16); + assert_eq!(sender.receiver_count(), 0); + + let rx_1 = sender.subscribe(); + assert_eq!(sender.receiver_count(), 1); + + let rx_2 = rx_1.resubscribe(); + assert_eq!(sender.receiver_count(), 2); + + let rx_3 = sender.subscribe(); + assert_eq!(sender.receiver_count(), 3); + + drop(rx_3); + drop(rx_1); + assert_eq!(sender.receiver_count(), 1); + + drop(rx_2); + assert_eq!(sender.receiver_count(), 0); + } + + #[cfg(not(loom))] + #[test] + fn receiver_count_on_channel_constructor() { + let (sender, rx) = channel::(16); + assert_eq!(sender.receiver_count(), 1); + + let _rx_2 = rx.resubscribe(); + assert_eq!(sender.receiver_count(), 2); + } +}