Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sync: implement Weak version of mpsc::UnboundedSender #5189

Merged
merged 3 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tokio/src/sync/mpsc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ mod chan;
pub(super) mod list;

mod unbounded;
pub use self::unbounded::{unbounded_channel, UnboundedReceiver, UnboundedSender};
pub use self::unbounded::{
unbounded_channel, UnboundedReceiver, UnboundedSender, WeakUnboundedSender,
};

pub mod error;

Expand Down
69 changes: 68 additions & 1 deletion tokio/src/sync/mpsc/unbounded.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{atomic::AtomicUsize, Arc};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TryRecvError};

Expand All @@ -13,6 +13,40 @@ pub struct UnboundedSender<T> {
chan: chan::Tx<T, Semaphore>,
}

/// An unbounded sender that does not prevent the channel from being closed.
///
/// If all [`UnboundedSender`] instances of a channel were dropped and only
/// `WeakUnboundedSender` instances remain, the channel is closed.
///
/// In order to send messages, the `WeakUnboundedSender` needs to be upgraded using
/// [`WeakUnboundedSender::upgrade`], which returns `Option<UnboundedSender>`. It returns `None`
/// if all `UnboundedSender`s have been dropped, and otherwise it returns an `UnboundedSender`.
///
/// [`UnboundedSender`]: UnboundedSender
/// [`WeakUnboundedSender::upgrade`]: WeakUnboundedSender::upgrade
///
/// #Examples
///
/// ```
/// use tokio::sync::mpsc::unbounded_channel;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, _rx) = unbounded_channel::<i32>();
/// let tx_weak = tx.downgrade();
///
/// // Upgrading will succeed because `tx` still exists.
/// assert!(tx_weak.upgrade().is_some());
///
/// // If we drop `tx`, then it will fail.
/// drop(tx);
/// assert!(tx_weak.clone().upgrade().is_none());
/// }
/// ```
pub struct WeakUnboundedSender<T> {
chan: Arc<chan::Chan<T, Semaphore>>,
}

impl<T> Clone for UnboundedSender<T> {
fn clone(&self) -> Self {
UnboundedSender {
Expand Down Expand Up @@ -384,4 +418,37 @@ impl<T> UnboundedSender<T> {
pub fn same_channel(&self, other: &Self) -> bool {
self.chan.same_channel(&other.chan)
}

/// Converts the `UnboundedSender` to a [`WeakUnboundedSender`] that does not count
/// towards RAII semantics, i.e. if all `UnboundedSender` instances of the
/// channel were dropped and only `WeakUnboundedSender` instances remain,
/// the channel is closed.
pub fn downgrade(&self) -> WeakUnboundedSender<T> {
WeakUnboundedSender {
chan: self.chan.downgrade(),
}
}
}

impl<T> Clone for WeakUnboundedSender<T> {
fn clone(&self) -> Self {
WeakUnboundedSender {
chan: self.chan.clone(),
}
}
}

impl<T> WeakUnboundedSender<T> {
/// Tries to convert a WeakUnboundedSender into an [`UnboundedSender`].
/// This will return `Some` if there are other `Sender` instances alive and
/// the channel wasn't previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
chan::Tx::upgrade(self.chan.clone()).map(UnboundedSender::new)
}
}

impl<T> fmt::Debug for WeakUnboundedSender<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("WeakUnboundedSender").finish()
}
}
274 changes: 4 additions & 270 deletions tokio/tests/sync_mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,14 @@ use wasm_bindgen_test::wasm_bindgen_test as test;
#[cfg(tokio_wasm_not_wasi)]
use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test;

use std::fmt;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
#[cfg(not(tokio_wasm_not_wasi))]
use tokio::test as maybe_tokio_test;

use tokio::sync::mpsc::error::{TryRecvError, TrySendError};
use tokio::sync::mpsc::{self, channel};
use tokio::sync::oneshot;
use tokio_test::*;

use std::fmt;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Acquire, Release};
use std::sync::Arc;

#[cfg(not(tokio_wasm))]
mod support {
pub(crate) mod mpsc_stream;
Expand Down Expand Up @@ -662,267 +657,6 @@ fn recv_timeout_panic() {
tx.send_timeout(10, Duration::from_secs(1)).now_or_never();
}

#[tokio::test]
async fn weak_sender() {
let (tx, mut rx) = channel(11);

let tx_weak = tokio::spawn(async move {
let tx_weak = tx.clone().downgrade();

for i in 0..10 {
if tx.send(i).await.is_err() {
return None;
}
}

let tx2 = tx_weak
.upgrade()
.expect("expected to be able to upgrade tx_weak");
let _ = tx2.send(20).await;
let tx_weak = tx2.downgrade();

Some(tx_weak)
})
.await
.unwrap();

for i in 0..12 {
let recvd = rx.recv().await;

match recvd {
Some(msg) => {
if i == 10 {
assert_eq!(msg, 20);
}
}
None => {
assert_eq!(i, 11);
break;
}
}
}

let tx_weak = tx_weak.unwrap();
let upgraded = tx_weak.upgrade();
assert!(upgraded.is_none());
}

#[tokio::test]
async fn actor_weak_sender() {
pub struct MyActor {
receiver: mpsc::Receiver<ActorMessage>,
sender: mpsc::WeakSender<ActorMessage>,
next_id: u32,
pub received_self_msg: bool,
}

enum ActorMessage {
GetUniqueId { respond_to: oneshot::Sender<u32> },
SelfMessage {},
}

impl MyActor {
fn new(
receiver: mpsc::Receiver<ActorMessage>,
sender: mpsc::WeakSender<ActorMessage>,
) -> Self {
MyActor {
receiver,
sender,
next_id: 0,
received_self_msg: false,
}
}

fn handle_message(&mut self, msg: ActorMessage) {
match msg {
ActorMessage::GetUniqueId { respond_to } => {
self.next_id += 1;

// The `let _ =` ignores any errors when sending.
//
// This can happen if the `select!` macro is used
// to cancel waiting for the response.
let _ = respond_to.send(self.next_id);
}
ActorMessage::SelfMessage { .. } => {
self.received_self_msg = true;
}
}
}

async fn send_message_to_self(&mut self) {
let msg = ActorMessage::SelfMessage {};

let sender = self.sender.clone();

// cannot move self.sender here
if let Some(sender) = sender.upgrade() {
let _ = sender.send(msg).await;
self.sender = sender.downgrade();
}
}

async fn run(&mut self) {
let mut i = 0;
while let Some(msg) = self.receiver.recv().await {
self.handle_message(msg);

if i == 0 {
self.send_message_to_self().await;
}

i += 1
}

assert!(self.received_self_msg);
}
}

#[derive(Clone)]
pub struct MyActorHandle {
sender: mpsc::Sender<ActorMessage>,
}

impl MyActorHandle {
pub fn new() -> (Self, MyActor) {
let (sender, receiver) = mpsc::channel(8);
let actor = MyActor::new(receiver, sender.clone().downgrade());

(Self { sender }, actor)
}

pub async fn get_unique_id(&self) -> u32 {
let (send, recv) = oneshot::channel();
let msg = ActorMessage::GetUniqueId { respond_to: send };

// Ignore send errors. If this send fails, so does the
// recv.await below. There's no reason to check the
// failure twice.
let _ = self.sender.send(msg).await;
recv.await.expect("Actor task has been killed")
}
}

let (handle, mut actor) = MyActorHandle::new();

let actor_handle = tokio::spawn(async move { actor.run().await });

let _ = tokio::spawn(async move {
let _ = handle.get_unique_id().await;
drop(handle);
})
.await;

let _ = actor_handle.await;
}

static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0);

#[derive(Debug)]
struct Msg;

impl Drop for Msg {
fn drop(&mut self) {
NUM_DROPPED.fetch_add(1, Release);
}
}

// Tests that no pending messages are put onto the channel after `Rx` was
// dropped.
//
// Note: After the introduction of `WeakSender`, which internally
// used `Arc` and doesn't call a drop of the channel after the last strong
// `Sender` was dropped while more than one `WeakSender` remains, we want to
// ensure that no messages are kept in the channel, which were sent after
// the receiver was dropped.
#[tokio::test]
async fn test_msgs_dropped_on_rx_drop() {
let (tx, mut rx) = mpsc::channel(3);

tx.send(Msg {}).await.unwrap();
tx.send(Msg {}).await.unwrap();

// This msg will be pending and should be dropped when `rx` is dropped
let sent_fut = tx.send(Msg {});

let _ = rx.recv().await.unwrap();
let _ = rx.recv().await.unwrap();

sent_fut.await.unwrap();

drop(rx);

assert_eq!(NUM_DROPPED.load(Acquire), 3);

// This msg will not be put onto `Tx` list anymore, since `Rx` is closed.
assert!(tx.send(Msg {}).await.is_err());

assert_eq!(NUM_DROPPED.load(Acquire), 4);
}

// Tests that a `WeakSender` is upgradeable when other `Sender`s exist.
#[tokio::test]
async fn downgrade_upgrade_sender_success() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
assert!(weak_tx.upgrade().is_some());
}

// Tests that a `WeakSender` fails to upgrade when no other `Sender` exists.
#[tokio::test]
async fn downgrade_upgrade_sender_failure() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

// Tests that a `WeakSender` cannot be upgraded after a `Sender` was dropped,
// which existed at the time of the `downgrade` call.
#[tokio::test]
async fn downgrade_drop_upgrade() {
let (tx, _rx) = mpsc::channel::<i32>(1);

// the cloned `Tx` is dropped right away
let weak_tx = tx.clone().downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

// Tests that we can upgrade a weak sender with an outstanding permit
// but no other strong senders.
#[tokio::test]
async fn downgrade_get_permit_upgrade_no_senders() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
let _permit = tx.reserve_owned().await.unwrap();
assert!(weak_tx.upgrade().is_some());
}

// Tests that you can downgrade and upgrade a sender with an outstanding permit
// but no other senders left.
#[tokio::test]
async fn downgrade_upgrade_get_permit_no_senders() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let tx2 = tx.clone();
let _permit = tx.reserve_owned().await.unwrap();
let weak_tx = tx2.downgrade();
drop(tx2);
assert!(weak_tx.upgrade().is_some());
}

// Tests that `downgrade` does not change the `tx_count` of the channel.
#[tokio::test]
async fn test_tx_count_weak_sender() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let tx_weak = tx.downgrade();
let tx_weak2 = tx.downgrade();
drop(tx);

assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none());
}

// Tests that channel `capacity` changes and `max_capacity` stays the same
#[tokio::test]
async fn test_tx_capacity() {
Expand Down
Loading