Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

macros: join! start by polling a different future each time poll_fn is polled #4624

Merged
merged 1 commit into from May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 46 additions & 11 deletions tokio/src/macros/join.rs
Expand Up @@ -60,6 +60,9 @@ macro_rules! join {
// normalization is complete.
( $($count:tt)* )

// The expression `0+1+1+ ... +1` equal to the number of branches.
( $($total:tt)* )

// Normalized join! branches
$( ( $($skip:tt)* ) $e:expr, )*

Expand All @@ -71,22 +74,54 @@ macro_rules! join {
// the requirement of `Pin::new_unchecked` called below.
let mut futures = ( $( maybe_done($e), )* );

// Each time the future created by poll_fn is polled, a different future will be polled first
// to ensure every future passed to join! gets a chance to make progress even if
// one of the futures consumes the whole budget.
//
// This is number of futures that will be skipped in the first loop
// iteration the next time.
let mut skip_next_time: u32 = 0;

poll_fn(move |cx| {
const COUNT: u32 = $($total)*;

let mut is_pending = false;

let mut to_run = COUNT;

// The number of futures that will be skipped in the first loop iteration.
let mut skip = skip_next_time;

skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

// This loop runs twice and the first `skip` futures
// are not polled in the first iteration.
loop {
$(
// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;
if skip == 0 {
if to_run == 0 {
// Every future has been polled
break;
}
to_run -= 1;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };
// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };

// Try polling
if fut.poll(cx).is_pending() {
is_pending = true;
// Try polling
if fut.poll(cx).is_pending() {
is_pending = true;
}
} else {
// Future skipped, one less future to skip in the next iteration
skip -= 1;
}
)*
}

if is_pending {
Pending
Expand All @@ -107,13 +142,13 @@ macro_rules! join {

// ===== Normalize =====

(@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*)
(@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
};

// ===== Entry point =====

( $($e:expr),* $(,)?) => {
$crate::join!(@{ () } $($e,)*)
$crate::join!(@{ () (0) } $($e,)*)
};
}
65 changes: 50 additions & 15 deletions tokio/src/macros/try_join.rs
Expand Up @@ -106,6 +106,9 @@ macro_rules! try_join {
// normalization is complete.
( $($count:tt)* )

// The expression `0+1+1+ ... +1` equal to the number of branches.
( $($total:tt)* )

// Normalized try_join! branches
$( ( $($skip:tt)* ) $e:expr, )*

Expand All @@ -117,24 +120,56 @@ macro_rules! try_join {
// the requirement of `Pin::new_unchecked` called below.
let mut futures = ( $( maybe_done($e), )* );

// Each time the future created by poll_fn is polled, a different future will be polled first
// to ensure every future passed to join! gets a chance to make progress even if
// one of the futures consumes the whole budget.
//
// This is number of futures that will be skipped in the first loop
// iteration the next time.
let mut skip_next_time: u32 = 0;

poll_fn(move |cx| {
const COUNT: u32 = $($total)*;

let mut is_pending = false;

let mut to_run = COUNT;

// The number of futures that will be skipped in the first loop iteration
let mut skip = skip_next_time;

skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };

// This loop runs twice and the first `skip` futures
// are not polled in the first iteration.
loop {
$(
// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };

// Try polling
if fut.as_mut().poll(cx).is_pending() {
is_pending = true;
} else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
if skip == 0 {
if to_run == 0 {
// Every future has been polled
break;
}
to_run -= 1;

// Extract the future for this branch from the tuple.
let ( $($skip,)* fut, .. ) = &mut futures;

// Safety: future is stored on the stack above
// and never moved.
let mut fut = unsafe { Pin::new_unchecked(fut) };

// Try polling
if fut.as_mut().poll(cx).is_pending() {
is_pending = true;
} else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
}
} else {
// Future skipped, one less future to skip in the next iteration
skip -= 1;
}
)*
}

if is_pending {
Pending
Expand All @@ -159,13 +194,13 @@ macro_rules! try_join {

// ===== Normalize =====

(@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::try_join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*)
(@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
$crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
};

// ===== Entry point =====

( $($e:expr),* $(,)?) => {
$crate::try_join!(@{ () } $($e,)*)
$crate::try_join!(@{ () (0) } $($e,)*)
};
}
77 changes: 74 additions & 3 deletions tokio/tests/macros_join.rs
@@ -1,5 +1,6 @@
#![cfg(feature = "macros")]
#![allow(clippy::blacklisted_name)]
use std::sync::Arc;

#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test as test;
Expand All @@ -9,7 +10,7 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test;
#[cfg(not(target_arch = "wasm32"))]
use tokio::test as maybe_tokio_test;

use tokio::sync::oneshot;
use tokio::sync::{oneshot, Semaphore};
use tokio_test::{assert_pending, assert_ready, task};

#[maybe_tokio_test]
Expand Down Expand Up @@ -71,12 +72,82 @@ fn join_size() {
let ready = future::ready(0i32);
tokio::join!(ready)
};
assert_eq!(mem::size_of_val(&fut), 16);
assert_eq!(mem::size_of_val(&fut), 20);

let fut = async {
let ready1 = future::ready(0i32);
let ready2 = future::ready(0i32);
tokio::join!(ready1, ready2)
};
assert_eq!(mem::size_of_val(&fut), 28);
assert_eq!(mem::size_of_val(&fut), 32);
}

async fn non_cooperative_task(permits: Arc<Semaphore>) -> usize {
let mut exceeded_budget = 0;

for _ in 0..5 {
// Another task should run after after this task uses its whole budget
for _ in 0..128 {
let _permit = permits.clone().acquire_owned().await.unwrap();
}

exceeded_budget += 1;
}

exceeded_budget
}

async fn poor_little_task(permits: Arc<Semaphore>) -> usize {
let mut how_many_times_i_got_to_run = 0;

for _ in 0..5 {
let _permit = permits.clone().acquire_owned().await.unwrap();
how_many_times_i_got_to_run += 1;
}

how_many_times_i_got_to_run
}

#[tokio::test]
async fn join_does_not_allow_tasks_to_starve() {
let permits = Arc::new(Semaphore::new(1));

// non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run.
let (non_cooperative_result, little_task_result) = tokio::join!(
non_cooperative_task(Arc::clone(&permits)),
poor_little_task(permits)
);

assert_eq!(5, non_cooperative_result);
assert_eq!(5, little_task_result);
}

#[tokio::test]
async fn a_different_future_is_polled_first_every_time_poll_fn_is_polled() {
let poll_order = Arc::new(std::sync::Mutex::new(vec![]));

let fut = |x, poll_order: Arc<std::sync::Mutex<Vec<i32>>>| async move {
for _ in 0..4 {
{
let mut guard = poll_order.lock().unwrap();

guard.push(x);
}

tokio::task::yield_now().await;
}
};

tokio::join!(
fut(1, Arc::clone(&poll_order)),
fut(2, Arc::clone(&poll_order)),
fut(3, Arc::clone(&poll_order)),
);

// Each time the future created by join! is polled, it should start
// by polling a different future first.
assert_eq!(
vec![1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3],
*poll_order.lock().unwrap()
);
}
81 changes: 78 additions & 3 deletions tokio/tests/macros_try_join.rs
@@ -1,7 +1,9 @@
#![cfg(feature = "macros")]
#![allow(clippy::blacklisted_name)]

use tokio::sync::oneshot;
use std::sync::Arc;

use tokio::sync::{oneshot, Semaphore};
use tokio_test::{assert_pending, assert_ready, task};

#[cfg(target_arch = "wasm32")]
Expand Down Expand Up @@ -94,16 +96,89 @@ fn join_size() {
let ready = future::ready(ok(0i32));
tokio::try_join!(ready)
};
assert_eq!(mem::size_of_val(&fut), 16);
assert_eq!(mem::size_of_val(&fut), 20);

let fut = async {
let ready1 = future::ready(ok(0i32));
let ready2 = future::ready(ok(0i32));
tokio::try_join!(ready1, ready2)
};
assert_eq!(mem::size_of_val(&fut), 28);
assert_eq!(mem::size_of_val(&fut), 32);
}

fn ok<T>(val: T) -> Result<T, ()> {
Ok(val)
}

async fn non_cooperative_task(permits: Arc<Semaphore>) -> Result<usize, String> {
let mut exceeded_budget = 0;

for _ in 0..5 {
// Another task should run after after this task uses its whole budget
for _ in 0..128 {
let _permit = permits.clone().acquire_owned().await.unwrap();
}

exceeded_budget += 1;
}

Ok(exceeded_budget)
}

async fn poor_little_task(permits: Arc<Semaphore>) -> Result<usize, String> {
let mut how_many_times_i_got_to_run = 0;

for _ in 0..5 {
let _permit = permits.clone().acquire_owned().await.unwrap();

how_many_times_i_got_to_run += 1;
}

Ok(how_many_times_i_got_to_run)
}

#[tokio::test]
async fn try_join_does_not_allow_tasks_to_starve() {
let permits = Arc::new(Semaphore::new(10));

// non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run.
let result = tokio::try_join!(
non_cooperative_task(Arc::clone(&permits)),
poor_little_task(permits)
);

let (non_cooperative_result, little_task_result) = result.unwrap();

assert_eq!(5, non_cooperative_result);
assert_eq!(5, little_task_result);
}

#[tokio::test]
async fn a_different_future_is_polled_first_every_time_poll_fn_is_polled() {
let poll_order = Arc::new(std::sync::Mutex::new(vec![]));

let fut = |x, poll_order: Arc<std::sync::Mutex<Vec<i32>>>| async move {
for _ in 0..4 {
{
let mut guard = poll_order.lock().unwrap();

guard.push(x);
}

tokio::task::yield_now().await;
}
};

tokio::join!(
fut(1, Arc::clone(&poll_order)),
fut(2, Arc::clone(&poll_order)),
fut(3, Arc::clone(&poll_order)),
);

// Each time the future created by join! is polled, it should start
// by polling a different future first.
assert_eq!(
vec![1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3],
*poll_order.lock().unwrap()
);
}