From 2e5773a6fef7444e8959a992c7ed7d32f97af284 Mon Sep 17 00:00:00 2001 From: Alice Ryhl Date: Tue, 14 Nov 2023 14:01:10 +0100 Subject: [PATCH] runtime: handle missing context on wake (#6148) --- tokio/src/runtime/context.rs | 4 ++- tokio/tests/rt_common.rs | 62 ++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index c49c08932b2..07875a0723f 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -178,7 +178,9 @@ cfg_rt! { #[track_caller] pub(super) fn with_scheduler(f: impl FnOnce(Option<&scheduler::Context>) -> R) -> R { - CONTEXT.with(|c| c.scheduler.with(f)) + let mut f = Some(f); + CONTEXT.try_with(|c| c.scheduler.with(f.take().unwrap())) + .unwrap_or_else(|_| (f.take().unwrap())(None)) } cfg_taskdump! { diff --git a/tokio/tests/rt_common.rs b/tokio/tests/rt_common.rs index abca8dd667a..11c44a8d1c2 100644 --- a/tokio/tests/rt_common.rs +++ b/tokio/tests/rt_common.rs @@ -1366,4 +1366,66 @@ rt_test! { th.join().unwrap(); } } + + #[test] + #[cfg_attr(target_family = "wasm", ignore)] + fn wake_by_ref_from_thread_local() { + wake_from_thread_local(true); + } + + #[test] + #[cfg_attr(target_family = "wasm", ignore)] + fn wake_by_val_from_thread_local() { + wake_from_thread_local(false); + } + + fn wake_from_thread_local(by_ref: bool) { + use std::cell::RefCell; + use std::sync::mpsc::{channel, Sender}; + use std::task::Waker; + + struct TLData { + by_ref: bool, + waker: Option, + done: Sender, + } + + impl Drop for TLData { + fn drop(&mut self) { + if self.by_ref { + self.waker.take().unwrap().wake_by_ref(); + } else { + self.waker.take().unwrap().wake(); + } + let _ = self.done.send(true); + } + } + + std::thread_local! { + static TL_DATA: RefCell> = RefCell::new(None); + }; + + let (send, recv) = channel(); + + std::thread::spawn(move || { + let rt = rt(); + rt.block_on(rt.spawn(poll_fn(move |cx| { + let waker = cx.waker().clone(); + let send = send.clone(); + TL_DATA.with(|tl| { + tl.replace(Some(TLData { + by_ref, + waker: Some(waker), + done: send, + })); + }); + Poll::Ready(()) + }))) + .unwrap(); + }) + .join() + .unwrap(); + + assert!(recv.recv().unwrap()); + } }