Skip to content

Commit

Permalink
Moved LRTD from tokio-rs/tokio#6256
Browse files Browse the repository at this point in the history
  • Loading branch information
zolyfarkas committed Jan 2, 2024
1 parent a9e84ad commit 01bf082
Show file tree
Hide file tree
Showing 4 changed files with 456 additions and 0 deletions.
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ tokio-stream = "0.1.11"
futures-util = "0.3.19"
pin-project-lite = "0.2.7"
tokio = { version = "1.31.0", features = ["rt", "time", "net"], optional = true }
rand = "0.8.5"

[target.'cfg(unix)'.dependencies]
libc = { version = "0.2.149" }

[dev-dependencies]
axum = "0.6"
Expand All @@ -33,6 +37,9 @@ serde = { version = "1.0.136", features = ["derive"] }
serde_json = "1.0.79"
tokio = { version = "1.26.0", features = ["full", "rt", "time", "macros", "test-util"] }

[target.'cfg(unix)'.dev-dependencies]
libc = { version = "0.2.149"}

[[example]]
name = "runtime"
required-features = ["rt"]
Expand Down
8 changes: 8 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,11 @@ cfg_rt! {

mod task;
pub use task::{Instrumented, TaskMetrics, TaskMonitor};

#[cfg(unix)]
pub mod lrtd;
#[cfg(unix)]
pub use lrtd::{
LongRunningTaskDetector,
BlockingActionHandler
};
256 changes: 256 additions & 0 deletions src/lrtd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
//! Long-running task detector (LRTD): a watchdog that periodically probes a tokio
//! runtime and reports when its worker threads appear to be blocked.
use libc;
use rand::thread_rng;
use rand::Rng;
use std::collections::HashSet;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use std::{env, thread};
use tokio::runtime::{Builder, Runtime};

/// Value returned by [`get_panic_worker_block_duration`] when the environment
/// does not provide a usable override.
const PANIC_WORKER_BLOCK_DURATION_DEFAULT: Duration = Duration::from_secs(60);

/// Returns the duration configured through the `MY_DURATION_ENV` environment
/// variable, interpreted as a whole number of seconds.
///
/// Falls back to [`PANIC_WORKER_BLOCK_DURATION_DEFAULT`] when the variable is
/// unset, is not valid Unicode, or does not parse as a `u64`.
fn get_panic_worker_block_duration() -> Duration {
    match env::var("MY_DURATION_ENV") {
        Ok(raw) => match raw.parse::<u64>() {
            Ok(secs) => Duration::from_secs(secs),
            Err(_) => PANIC_WORKER_BLOCK_DURATION_DEFAULT,
        },
        // An unreadable variable behaves exactly like the literal "60".
        Err(_) => PANIC_WORKER_BLOCK_DURATION_DEFAULT,
    }
}

/// Returns the pthread identifier of the calling thread.
fn get_thread_id() -> libc::pthread_t {
// SAFETY: `pthread_self` takes no arguments and, per POSIX, always succeeds;
// it only reads the identity of the calling thread.
unsafe { libc::pthread_self() }
}

/// A trait for handling actions when blocking is detected.
///
/// Handlers are shared with the monitoring thread, hence the `Send + Sync`
/// supertrait bounds.
pub trait BlockingActionHandler: Send + Sync {
/// Called when a probe task did not complete within the detection time.
///
/// # Arguments
///
/// * `workers` - The thread IDs of the tokio runtime worker threads that
///   were registered with the detector at the time of detection.
fn blocking_detected(&self, workers: &[libc::pthread_t]);
}

/// [`BlockingActionHandler`] implementation that writes blocker details to
/// standard error.
struct StdErrBlockingActionHandler;

impl BlockingActionHandler for StdErrBlockingActionHandler {
    /// Prints the IDs of the registered worker threads to standard error.
    fn blocking_detected(&self, worker_ids: &[libc::pthread_t]) {
        eprintln!("Detected blocking in worker threads: {:?}", worker_ids);
    }
}

/// Mutex-protected set of pthread IDs.
///
/// Every method locks the mutex for the duration of the call and panics if
/// the mutex has been poisoned.
#[derive(Debug)]
struct WorkerSet {
    inner: Mutex<HashSet<libc::pthread_t>>,
}

impl WorkerSet {
    /// Creates an empty set.
    fn new() -> Self {
        Self {
            inner: Mutex::new(HashSet::new()),
        }
    }

    /// Registers `pid`; adding an ID that is already present is a no-op.
    fn add(&self, pid: libc::pthread_t) {
        self.inner.lock().unwrap().insert(pid);
    }

    /// Unregisters `pid`; removing an unknown ID is a no-op.
    fn remove(&self, pid: libc::pthread_t) {
        self.inner.lock().unwrap().remove(&pid);
    }

    /// Returns a snapshot of the currently registered IDs, in the set's
    /// (unspecified) iteration order.
    fn get_all(&self) -> Vec<libc::pthread_t> {
        let guard = self.inner.lock().unwrap();
        let mut ids = Vec::with_capacity(guard.len());
        ids.extend(guard.iter().copied());
        ids
    }
}

/// Watchdog that periodically spawns a trivial probe task on a tokio runtime
/// and reports when the probe does not complete in time, which indicates that
/// the runtime's worker threads are blocked.
#[derive(Debug)]
pub struct LongRunningTaskDetector {
/// Upper bound of the randomized pause between two probes.
interval: Duration,
/// Maximum time a probe may take before blocking is reported.
detection_time: Duration,
/// `true` while monitoring is stopped (the initial state); shared with the
/// monitoring thread, which exits its loop once it reads `true`.
stop_flag: Arc<Mutex<bool>>,
/// Thread IDs handed to the `BlockingActionHandler` on detection.
workers: Arc<WorkerSet>,
}

/// Probe task body: completes immediately and notifies the monitor through
/// `tx` that the runtime was able to run it.
///
/// # Arguments
///
/// * `tx` - Sending half of the channel the monitor is waiting on.
async fn do_nothing(tx: mpsc::Sender<()>) {
    // A send error only means the receiving side is gone (the monitor stopped
    // waiting for this probe). Nobody is left to notify, so panicking inside
    // a runtime task would be pure noise; ignore the error instead.
    let _ = tx.send(());
}

/// Runs a single probe against `tokio_runtime`.
///
/// Spawns [`do_nothing`] on the runtime and waits up to `detection_time` for
/// it to report completion. If it does not, the IDs currently registered in
/// `workers` are passed to `action`, and the function then keeps waiting for
/// the probe for up to [`get_panic_worker_block_duration`].
///
/// # Panics
///
/// Panics if the probe still has not completed after that second wait.
fn probe(
    tokio_runtime: &Arc<Runtime>,
    detection_time: Duration,
    workers: &Arc<WorkerSet>,
    action: &Arc<dyn BlockingActionHandler>,
) {
    let (done_tx, done_rx) = mpsc::channel();
    let _probe_task = tokio_runtime.spawn(do_nothing(done_tx));

    // Fast path: the runtime picked up the probe in time.
    if done_rx.recv_timeout(detection_time).is_ok() {
        return;
    }

    let blocked = workers.get_all();
    action.blocking_detected(&blocked);
    // Give the runtime a (much longer) second chance before giving up.
    done_rx
        .recv_timeout(get_panic_worker_block_duration())
        .unwrap();
}

/// Utility to help with "really nice to add a warning for tasks that might be blocking"
/// Example use:
/// ```
/// use std::sync::Arc;
/// use tokio::runtime::lrtd::LongRunningTaskDetector;
///
/// let mut builder = tokio::runtime::Builder::new_multi_thread();
/// let mutable_builder = builder.worker_threads(2);
/// let lrtd = LongRunningTaskDetector::new(
/// std::time::Duration::from_millis(10),
/// std::time::Duration::from_millis(100),
/// mutable_builder,
/// );
/// let runtime = builder.enable_all().build().unwrap();
/// let arc_runtime = Arc::new(runtime);
/// let arc_runtime2 = arc_runtime.clone();
/// lrtd.start(arc_runtime);
/// arc_runtime2.block_on(async {
/// print!("my async code")
/// });
///
/// ```
///
/// The above will allow you to get details on what is blocking your tokio worker threads for longer that 100ms.
/// The detail will look like:
///
/// ```text
/// Detected blocking in worker threads: [123145318232064, 123145320341504]
/// ```
///
/// To get more details(like stack traces) start LongRunningTaskDetector with start_with_custom_action and provide a custom handler.
///
impl LongRunningTaskDetector {
/// Creates a new `LongRunningTaskDetector` instance.
///
/// # Arguments
///
/// * `interval` - The interval between probes. This interval is randomized.
/// * `detection_time` - The maximum time allowed for a probe to succeed.
/// A probe running for longer indicates something is blocking the worker threads.
/// * `runtime_builder` - A mutable reference to a `tokio::runtime::Builder`.
///
/// # Returns
///
/// Returns a new `LongRunningTaskDetector` instance.
pub fn new(
interval: Duration,
detection_time: Duration,
current_threaded: bool,
) -> (Self, Builder) {
let workers = Arc::new(WorkerSet::new());
if current_threaded {
workers.add(get_thread_id());
let runtime_builder = tokio::runtime::Builder::new_current_thread();
(
LongRunningTaskDetector {
interval,
detection_time,
stop_flag: Arc::new(Mutex::new(true)),
workers,
},
runtime_builder,
)
} else {
let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
let workers_clone = Arc::clone(&workers);
let workers_clone2 = Arc::clone(&workers);
runtime_builder
.on_thread_start(move || {
let pid = get_thread_id();
workers_clone.add(pid);
})
.on_thread_stop(move || {
let pid = get_thread_id();
workers_clone2.remove(pid);
});
(
LongRunningTaskDetector {
interval,
detection_time,
stop_flag: Arc::new(Mutex::new(true)),
workers,
},
runtime_builder,
)
}
}

pub fn new_single_threaded(interval: Duration, detection_time: Duration) -> (Self, Builder) {
LongRunningTaskDetector::new(interval, detection_time, true)
}

pub fn new_multi_threaded(interval: Duration, detection_time: Duration) -> (Self, Builder) {
LongRunningTaskDetector::new(interval, detection_time, false)
}

/// Starts the monitoring thread with default action handlers (that write details to std err).
///
/// # Arguments
///
/// * `runtime` - An `Arc` reference to a `tokio::runtime::Runtime`.
pub fn start(&self, runtime: Arc<Runtime>) {
self.start_with_custom_action(runtime, Arc::new(StdErrBlockingActionHandler))
}

/// Starts the monitoring process with custom action handlers that
/// allow you to customize what happens when blocking is detected.
///
/// # Arguments
///
/// * `runtime` - An `Arc` reference to a `tokio::runtime::Runtime`.
/// * `action` - An `Arc` reference to a custom `BlockingActionHandler`.
/// * `thread_action` - An `Arc` reference to a custom `ThreadStateHandler`.
pub fn start_with_custom_action(
&self,
runtime: Arc<Runtime>,
action: Arc<dyn BlockingActionHandler>,
) {
*self.stop_flag.lock().unwrap() = false;
let stop_flag = Arc::clone(&self.stop_flag);
let detection_time = self.detection_time;
let interval = self.interval;
let workers = Arc::clone(&self.workers);
thread::spawn(move || {
let mut rng = thread_rng();
while !*stop_flag.lock().unwrap() {
probe(&runtime, detection_time, &workers, &action);
thread::sleep(Duration::from_micros(
rng.gen_range(10..=interval.as_micros().try_into().unwrap()),
));
}
});
}

/// Stops the monitoring thread. Does nothing if LRTD is already stopped.
pub fn stop(&self) {
let mut sf = self.stop_flag.lock().unwrap();
if !(*sf) {
*sf = true;
}
}
}

/// Stops monitoring when the detector goes out of scope.
impl Drop for LongRunningTaskDetector {
fn drop(&mut self) {
// Raise the stop flag so a running monitoring thread leaves its loop.
self.stop();
}
}
Loading

0 comments on commit 01bf082

Please sign in to comment.