diff --git a/tokio/src/macros/join.rs b/tokio/src/macros/join.rs index f91b5f19145..ab85cdec26f 100644 --- a/tokio/src/macros/join.rs +++ b/tokio/src/macros/join.rs @@ -60,6 +60,9 @@ macro_rules! join { // normalization is complete. ( $($count:tt)* ) + // The expression `0+1+1+ ... +1` equal to the number of branches. + ( $($total:tt)* ) + // Normalized join! branches $( ( $($skip:tt)* ) $e:expr, )* @@ -71,22 +74,54 @@ macro_rules! join { // the requirement of `Pin::new_unchecked` called below. let mut futures = ( $( maybe_done($e), )* ); + // Each time the future created by poll_fn is polled, a different future will be polled first + // to ensure every future passed to join! gets a chance to make progress even if + // one of the futures consumes the whole budget. + // + // This is number of futures that will be skipped in the first loop + // iteration the next time. + let mut skip_next_time: u32 = 0; + poll_fn(move |cx| { + const COUNT: u32 = $($total)*; + let mut is_pending = false; + let mut to_run = COUNT; + + // The number of futures that will be skipped in the first loop iteration. + let mut skip = skip_next_time; + + skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 }; + + // This loop runs twice and the first `skip` futures + // are not polled in the first iteration. + loop { $( - // Extract the future for this branch from the tuple. - let ( $($skip,)* fut, .. ) = &mut futures; + if skip == 0 { + if to_run == 0 { + // Every future has been polled + break; + } + to_run -= 1; - // Safety: future is stored on the stack above - // and never moved. - let mut fut = unsafe { Pin::new_unchecked(fut) }; + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; - // Try polling - if fut.poll(cx).is_pending() { - is_pending = true; + // Try polling + if fut.poll(cx).is_pending() { + is_pending = true; + } + } else { + // Future skipped, one less future to skip in the next iteration + skip -= 1; } )* + } if is_pending { Pending @@ -107,13 +142,13 @@ macro_rules! join { // ===== Normalize ===== - (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { - $crate::join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*) }; // ===== Entry point ===== ( $($e:expr),* $(,)?) => { - $crate::join!(@{ () } $($e,)*) + $crate::join!(@{ () (0) } $($e,)*) }; } diff --git a/tokio/src/macros/try_join.rs b/tokio/src/macros/try_join.rs index 6d3a893b7e1..c80395c9097 100644 --- a/tokio/src/macros/try_join.rs +++ b/tokio/src/macros/try_join.rs @@ -106,6 +106,9 @@ macro_rules! try_join { // normalization is complete. ( $($count:tt)* ) + // The expression `0+1+1+ ... +1` equal to the number of branches. + ( $($total:tt)* ) + // Normalized try_join! branches $( ( $($skip:tt)* ) $e:expr, )* @@ -117,24 +120,56 @@ macro_rules! try_join { // the requirement of `Pin::new_unchecked` called below. let mut futures = ( $( maybe_done($e), )* ); + // Each time the future created by poll_fn is polled, a different future will be polled first + // to ensure every future passed to join! gets a chance to make progress even if + // one of the futures consumes the whole budget. + // + // This is number of futures that will be skipped in the first loop + // iteration the next time. + let mut skip_next_time: u32 = 0; + poll_fn(move |cx| { + const COUNT: u32 = $($total)*; + let mut is_pending = false; + let mut to_run = COUNT; + + // The number of futures that will be skipped in the first loop iteration + let mut skip = skip_next_time; + + skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 }; + + // This loop runs twice and the first `skip` futures + // are not polled in the first iteration. + loop { $( - // Extract the future for this branch from the tuple. - let ( $($skip,)* fut, .. ) = &mut futures; - - // Safety: future is stored on the stack above - // and never moved. - let mut fut = unsafe { Pin::new_unchecked(fut) }; - - // Try polling - if fut.as_mut().poll(cx).is_pending() { - is_pending = true; - } else if fut.as_mut().output_mut().expect("expected completed future").is_err() { - return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap())) + if skip == 0 { + if to_run == 0 { + // Every future has been polled + break; + } + to_run -= 1; + + // Extract the future for this branch from the tuple. + let ( $($skip,)* fut, .. ) = &mut futures; + + // Safety: future is stored on the stack above + // and never moved. + let mut fut = unsafe { Pin::new_unchecked(fut) }; + + // Try polling + if fut.as_mut().poll(cx).is_pending() { + is_pending = true; + } else if fut.as_mut().output_mut().expect("expected completed future").is_err() { + return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap())) + } + } else { + // Future skipped, one less future to skip in the next iteration + skip -= 1; } )* + } if is_pending { Pending @@ -159,13 +194,13 @@ macro_rules! try_join { // ===== Normalize ===== - (@ { ( $($s:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { - $crate::try_join!(@{ ($($s)* _) $($t)* ($($s)*) $e, } $($r)*) + (@ { ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => { + $crate::try_join!(@{ ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*) }; // ===== Entry point ===== ( $($e:expr),* $(,)?) => { - $crate::try_join!(@{ () } $($e,)*) + $crate::try_join!(@{ () (0) } $($e,)*) }; } diff --git a/tokio/tests/macros_join.rs b/tokio/tests/macros_join.rs index d4f20b3862c..56bd9c58f3c 100644 --- a/tokio/tests/macros_join.rs +++ b/tokio/tests/macros_join.rs @@ -1,5 +1,6 @@ #![cfg(feature = "macros")] #![allow(clippy::blacklisted_name)] +use std::sync::Arc; #[cfg(target_arch = "wasm32")] use wasm_bindgen_test::wasm_bindgen_test as test; @@ -9,7 +10,7 @@ use wasm_bindgen_test::wasm_bindgen_test as maybe_tokio_test; #[cfg(not(target_arch = "wasm32"))] use tokio::test as maybe_tokio_test; -use tokio::sync::oneshot; +use tokio::sync::{oneshot, Semaphore}; use tokio_test::{assert_pending, assert_ready, task}; #[maybe_tokio_test] @@ -71,12 +72,82 @@ fn join_size() { let ready = future::ready(0i32); tokio::join!(ready) }; - assert_eq!(mem::size_of_val(&fut), 16); + assert_eq!(mem::size_of_val(&fut), 20); let fut = async { let ready1 = future::ready(0i32); let ready2 = future::ready(0i32); tokio::join!(ready1, ready2) }; - assert_eq!(mem::size_of_val(&fut), 28); + assert_eq!(mem::size_of_val(&fut), 32); +} + +async fn non_cooperative_task(permits: Arc) -> usize { + let mut exceeded_budget = 0; + + for _ in 0..5 { + // Another task should run after after this task uses its whole budget + for _ in 0..128 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + } + + exceeded_budget += 1; + } + + exceeded_budget +} + +async fn poor_little_task(permits: Arc) -> usize { + let mut how_many_times_i_got_to_run = 0; + + for _ in 0..5 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + how_many_times_i_got_to_run += 1; + } + + how_many_times_i_got_to_run +} + +#[tokio::test] +async fn join_does_not_allow_tasks_to_starve() { + let permits = Arc::new(Semaphore::new(1)); + + // non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run. + let (non_cooperative_result, little_task_result) = tokio::join!( + non_cooperative_task(Arc::clone(&permits)), + poor_little_task(permits) + ); + + assert_eq!(5, non_cooperative_result); + assert_eq!(5, little_task_result); +} + +#[tokio::test] +async fn a_different_future_is_polled_first_every_time_poll_fn_is_polled() { + let poll_order = Arc::new(std::sync::Mutex::new(vec![])); + + let fut = |x, poll_order: Arc>>| async move { + for _ in 0..4 { + { + let mut guard = poll_order.lock().unwrap(); + + guard.push(x); + } + + tokio::task::yield_now().await; + } + }; + + tokio::join!( + fut(1, Arc::clone(&poll_order)), + fut(2, Arc::clone(&poll_order)), + fut(3, Arc::clone(&poll_order)), + ); + + // Each time the future created by join! is polled, it should start + // by polling a different future first. + assert_eq!( + vec![1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3], + *poll_order.lock().unwrap() + ); } diff --git a/tokio/tests/macros_try_join.rs b/tokio/tests/macros_try_join.rs index 60a726b659a..556436d2207 100644 --- a/tokio/tests/macros_try_join.rs +++ b/tokio/tests/macros_try_join.rs @@ -1,7 +1,9 @@ #![cfg(feature = "macros")] #![allow(clippy::blacklisted_name)] -use tokio::sync::oneshot; +use std::sync::Arc; + +use tokio::sync::{oneshot, Semaphore}; use tokio_test::{assert_pending, assert_ready, task}; #[cfg(target_arch = "wasm32")] @@ -94,16 +96,89 @@ fn join_size() { let ready = future::ready(ok(0i32)); tokio::try_join!(ready) }; - assert_eq!(mem::size_of_val(&fut), 16); + assert_eq!(mem::size_of_val(&fut), 20); let fut = async { let ready1 = future::ready(ok(0i32)); let ready2 = future::ready(ok(0i32)); tokio::try_join!(ready1, ready2) }; - assert_eq!(mem::size_of_val(&fut), 28); + assert_eq!(mem::size_of_val(&fut), 32); } fn ok(val: T) -> Result { Ok(val) } + +async fn non_cooperative_task(permits: Arc) -> Result { + let mut exceeded_budget = 0; + + for _ in 0..5 { + // Another task should run after after this task uses its whole budget + for _ in 0..128 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + } + + exceeded_budget += 1; + } + + Ok(exceeded_budget) +} + +async fn poor_little_task(permits: Arc) -> Result { + let mut how_many_times_i_got_to_run = 0; + + for _ in 0..5 { + let _permit = permits.clone().acquire_owned().await.unwrap(); + + how_many_times_i_got_to_run += 1; + } + + Ok(how_many_times_i_got_to_run) +} + +#[tokio::test] +async fn try_join_does_not_allow_tasks_to_starve() { + let permits = Arc::new(Semaphore::new(10)); + + // non_cooperative_task should yield after its budget is exceeded and then poor_little_task should run. + let result = tokio::try_join!( + non_cooperative_task(Arc::clone(&permits)), + poor_little_task(permits) + ); + + let (non_cooperative_result, little_task_result) = result.unwrap(); + + assert_eq!(5, non_cooperative_result); + assert_eq!(5, little_task_result); +} + +#[tokio::test] +async fn a_different_future_is_polled_first_every_time_poll_fn_is_polled() { + let poll_order = Arc::new(std::sync::Mutex::new(vec![])); + + let fut = |x, poll_order: Arc>>| async move { + for _ in 0..4 { + { + let mut guard = poll_order.lock().unwrap(); + + guard.push(x); + } + + tokio::task::yield_now().await; + } + }; + + tokio::join!( + fut(1, Arc::clone(&poll_order)), + fut(2, Arc::clone(&poll_order)), + fut(3, Arc::clone(&poll_order)), + ); + + // Each time the future created by join! is polled, it should start + // by polling a different future first. + assert_eq!( + vec![1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3], + *poll_order.lock().unwrap() + ); +}