From 67ab20a075dbfe30544d3015387db6ed84913294 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20M=2E=20Bezerra?= Date: Tue, 11 Jul 2023 15:01:15 -0300 Subject: [PATCH] sync::broadcast: don't lock in `channel()` --- tokio/src/sync/broadcast.rs | 55 +++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index ff628c1f2bf..ee3702addba 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -445,8 +445,12 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// than `usize::MAX / 2`. #[track_caller] pub fn channel(capacity: usize) -> (Sender, Receiver) { - let tx = Sender::new(capacity); - let rx = tx.subscribe(); + // SAFETY: In the line below we are creating one extra receiver, so will be 1 in total. + let tx = unsafe { Sender::new_with_receiver_count(1, capacity) }; + let rx = Receiver { + shared: tx.shared.clone(), + next: 0, + }; (tx, rx) } @@ -464,9 +468,28 @@ impl Sender { /// [`broadcast`]: crate::sync::broadcast /// [`broadcast::channel`]: crate::sync::broadcast #[track_caller] - pub fn new(mut capacity: usize) -> Self { - assert!(capacity > 0, "capacity is empty"); - assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); + pub fn new(capacity: usize) -> Self { + // SAFETY: We don't create extra receivers, so there are 0. + unsafe { Self::new_with_receiver_count(0, capacity) } + } + + /// Creates the sending-half of the [`broadcast`] channel, and provide the receiver count. + /// + /// See the documentation of [`broadcast::channel`] for more errors when calling this + /// function. + /// + /// # Safety: + /// + /// The caller must ensure that the amount of receivers for this Sender is correct before + /// the channel functionalities are used, the count is zero by default, as this function + /// does not create any receivers by itself. + #[track_caller] + unsafe fn new_with_receiver_count(receiver_count: usize, mut capacity: usize) -> Self { + assert!(capacity > 0, "broadcast channel capacity cannot be zero"); + assert!( + capacity <= usize::MAX >> 1, + "broadcast channel capacity exceeded `usize::MAX / 2`" + ); // Round to a power of two capacity = capacity.next_power_of_two(); @@ -486,7 +509,7 @@ impl Sender { mask: capacity - 1, tail: Mutex::new(Tail { pos: 0, - rx_cnt: 0, + rx_cnt: receiver_count, closed: false, waiters: LinkedList::new(), }), @@ -1383,37 +1406,33 @@ mod tests { #[test] fn receiver_count_on_sender_constructor() { - let count_of = |sender: &Sender| sender.shared.tail.lock().rx_cnt; - let sender = Sender::::new(16); - assert_eq!(count_of(&sender), 0); + assert_eq!(sender.receiver_count(), 0); let rx_1 = sender.subscribe(); - assert_eq!(count_of(&sender), 1); + assert_eq!(sender.receiver_count(), 1); let rx_2 = rx_1.resubscribe(); - assert_eq!(count_of(&sender), 2); + assert_eq!(sender.receiver_count(), 2); let rx_3 = sender.subscribe(); - assert_eq!(count_of(&sender), 3); + assert_eq!(sender.receiver_count(), 3); drop(rx_3); drop(rx_1); - assert_eq!(count_of(&sender), 1); + assert_eq!(sender.receiver_count(), 1); drop(rx_2); - assert_eq!(count_of(&sender), 0); + assert_eq!(sender.receiver_count(), 0); } #[cfg(not(loom))] #[test] fn receiver_count_on_channel_constructor() { - let count_of = |sender: &Sender| sender.shared.tail.lock().rx_cnt; - let (sender, rx) = channel::(16); - assert_eq!(count_of(&sender), 1); + assert_eq!(sender.receiver_count(), 1); let _rx_2 = rx.resubscribe(); - assert_eq!(count_of(&sender), 2); + assert_eq!(sender.receiver_count(), 2); } }