Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
b-naber committed May 12, 2022
1 parent c1c024b commit 2c23c9c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 32 deletions.
2 changes: 1 addition & 1 deletion tokio-util/src/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#[cfg(tokio_unstable)]
mod join_map;
mod spawn_pinned;
pub use spawn_pinned::LocalPoolHandle;
pub use spawn_pinned::{LocalPoolHandle, WorkerIdxError};

#[cfg(tokio_unstable)]
#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))]
Expand Down
64 changes: 42 additions & 22 deletions tokio-util/src/task/spawn_pinned.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use futures_util::future::{AbortHandle, Abortable};
use std::error::Error;
use std::fmt;
use std::fmt::{Debug, Formatter};
use std::future::Future;
Expand All @@ -9,19 +10,27 @@ use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot;
use tokio::task::{spawn_local, JoinHandle, LocalSet};

/// Error Type for out-of-bounds indexing error in [`LocalPoolHandle::spawn_pinned_by_idx`].
///
/// [`LocalPoolHandle::spawn_pinned_by_idx`]: LocalPoolHandle::spawn_pinned_by_idx
#[derive(Debug)]
pub struct WorkerIdxError(usize, usize);
pub struct WorkerIdxError {
idx: usize,
num_workers: usize,
}

impl fmt::Display for WorkerIdxError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"Index {} out of bounds, only {} workers in pool",
self.0, self.1
self.idx, self.num_workers
)
}
}

impl Error for WorkerIdxError {}

/// A cloneable handle to a local pool, used for spawning `!Send` tasks.
///
/// Internally the local pool uses a [`tokio::task::LocalSet`] for each worker thread
Expand Down Expand Up @@ -159,7 +168,10 @@ impl LocalPoolHandle {
Fut::Output: Send + 'static,
{
if idx >= self.pool.workers.len() {
return Err(WorkerIdxError(idx, self.pool.workers.len()));
return Err(WorkerIdxError {
idx,
num_workers: self.pool.workers.len(),
});
}

Ok(self
Expand All @@ -169,12 +181,32 @@ impl LocalPoolHandle {

/// Spawn a task on every worker thread in the pool and pin it so that it
/// can't be moved off of the thread.
///
/// # Examples
///
/// ```
/// use std::rc::Rc;
/// use tokio_util::task::LocalPoolHandle;
///
/// #[tokio::main]
/// async fn main() {
/// let pool = LocalPoolHandle::new(3);
///
/// let _ = pool.spawn_pinned_on_all_workers(|| {
/// // Rc is !Send + !Sync
/// let local_data = Rc::new("test");
///
/// // This future holds an Rc, so it is !Send
/// async move { local_data.to_string() }
/// });
/// }
/// ```
pub fn spawn_pinned_on_all_workers<F, Fut>(
&self,
create_task: F,
) -> Vec<JoinHandle<Fut::Output>>
where
F: Fn() -> Fut,
F: FnOnce() -> Fut,
F: Send + Clone + 'static,
Fut: Future + 'static,
Fut::Output: Send + 'static,
Expand Down Expand Up @@ -331,25 +363,13 @@ impl LocalPool {
}

fn find_worker_by_idx(&self, idx: usize) -> (&LocalWorkerHandle, JobCountGuard) {
loop {
let worker = &self.workers[idx];
let task_count = worker.task_count.load(Ordering::SeqCst);
let worker = &self.workers[idx];
let task_count = worker.task_count.load(Ordering::SeqCst);
worker
.task_count
.fetch_add(task_count + 1, Ordering::SeqCst);

// Make sure the task count hasn't changed since when we choose this
// worker. Otherwise, restart the search.
if worker
.task_count
.compare_exchange(
task_count,
task_count + 1,
Ordering::SeqCst,
Ordering::Relaxed,
)
.is_ok()
{
return (worker, JobCountGuard(Arc::clone(&worker.task_count)));
}
}
(worker, JobCountGuard(Arc::clone(&worker.task_count)))
}
}

Expand Down
37 changes: 28 additions & 9 deletions tokio-util/tests/spawn_pinned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Barrier;
use tokio_util::task;

/// Simple test of running a !Send future via spawn_pinned
Expand Down Expand Up @@ -195,11 +196,27 @@ async fn tasks_are_balanced() {
#[tokio::test]
async fn spawn_by_idx() {
let pool = task::LocalPoolHandle::new(3);

let handle1 = pool.spawn_pinned_by_idx(|| async { std::thread::current().id() }, 0);
let handle2 = pool.spawn_pinned_by_idx(|| async { std::thread::current().id() }, 1);
let barrier = Arc::new(Barrier::new(3));
let barrier1 = barrier.clone();
let barrier2 = barrier.clone();

let handle1 = pool.spawn_pinned_by_idx(
|| async move {
barrier1.wait().await;
std::thread::current().id()
},
0,
);
let handle2 = pool.spawn_pinned_by_idx(
|| async move {
barrier2.wait().await;
std::thread::current().id()
},
1,
);

let loads = pool.get_task_loads_for_each_worker();
barrier.wait().await;
assert_eq!(loads[0], 1);
assert_eq!(loads[1], 1);
assert_eq!(loads[2], 0);
Expand All @@ -212,17 +229,19 @@ async fn spawn_by_idx() {

#[tokio::test]
async fn spawn_on_all_workers() {
let pool = task::LocalPoolHandle::new(3);
const NUM_WORKERS: usize = 3;
let pool = task::LocalPoolHandle::new(NUM_WORKERS);
let barrier = Arc::new(Barrier::new(2));
let barrier_clone = barrier.clone();

let _ = pool.spawn_pinned_on_all_workers(|| {
// Rc is !Send + !Sync
let local_data = Rc::new("test");
let _ = pool.spawn_pinned_on_all_workers(|| async move {
barrier_clone.wait().await;

// This future holds an Rc, so it is !Send
async move { local_data.to_string() }
"test"
});

let loads = pool.get_task_loads_for_each_worker();
barrier.wait().await;
assert_eq!(loads[0], 1);
assert_eq!(loads[1], 1);
assert_eq!(loads[2], 1);
Expand Down

0 comments on commit 2c23c9c

Please sign in to comment.