From b39fe195cca587869c0686b366e36e077cc789c4 Mon Sep 17 00:00:00 2001 From: Matheus Consoli Date: Fri, 18 Nov 2022 03:25:47 -0300 Subject: [PATCH] Impl "perfect" waking for `tuple::merge` --- src/stream/merge/tuple.rs | 208 ++++++++++++++++++++++++++++---------- src/utils/tuple.rs | 1 - 2 files changed, 153 insertions(+), 56 deletions(-) diff --git a/src/stream/merge/tuple.rs b/src/stream/merge/tuple.rs index 814396e..a071ab8 100644 --- a/src/stream/merge/tuple.rs +++ b/src/stream/merge/tuple.rs @@ -1,15 +1,41 @@ use super::Merge as MergeTrait; use crate::stream::IntoStream; -use crate::utils; +use crate::utils::{self, PollArray, WakerArray}; use core::fmt; use futures_core::Stream; use std::pin::Pin; use std::task::{Context, Poll}; -// TODO: handle none case +macro_rules! poll_stream { + ($stream_idx:tt, $iteration:ident, $this:ident, $streams:ident, $cx:ident, $len_streams:ident) => { + if $stream_idx == $iteration { + match unsafe { Pin::new_unchecked(&mut $streams.$stream_idx) }.poll_next(&mut $cx) { + Poll::Ready(Some(item)) => { + // Mark ourselves as ready again because we need to poll for the next item. + $this + .wakers + .readiness() + .lock() + .unwrap() + .set_ready($stream_idx); + return Poll::Ready(Some(item)); + } + Poll::Ready(None) => { + *$this.completed += 1; + $this.state[$stream_idx].set_consumed(); + if *$this.completed == $len_streams { + return Poll::Ready(None); + } + } + Poll::Pending => {} + } + } + }; +} + macro_rules! impl_merge_tuple { - ($StructName:ident) => { + ($ignore:ident $StructName:ident) => { /// A stream that merges multiple streams into a single stream. /// /// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its @@ -17,10 +43,7 @@ macro_rules! impl_merge_tuple { /// /// [`merge`]: trait.Merge.html#method.merge /// [`Merge`]: trait.Merge.html - #[pin_project::pin_project] - pub struct $StructName { - rng: utils::RandomGenerator, - } + pub struct $StructName {} impl fmt::Debug for $StructName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -29,7 +52,7 @@ macro_rules! impl_merge_tuple { } impl Stream for $StructName { - type Item = std::convert::Infallible; // TODO: convert to `never` type in the stdlib + type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(None) @@ -37,17 +60,21 @@ macro_rules! impl_merge_tuple { } impl MergeTrait for () { - type Item = std::convert::Infallible; // TODO: convert to `never` type in the stdlib + type Item = core::convert::Infallible; // TODO: convert to `never` type in the stdlib type Stream = $StructName; fn merge(self) -> Self::Stream { - $StructName { - rng: utils::RandomGenerator::new(), - } + $StructName { } } } }; - ($StructName:ident $($F:ident)+) => { + ($mod_name:ident $StructName:ident $($F:ident)+) => { + mod $mod_name { + #[derive(Debug)] + #[pin_project::pin_project] + pub(super) struct Streams<$($F,)+>($(#[pin] pub(super) $F,)+); + } + /// A stream that merges multiple streams into a single stream. /// /// This `struct` is created by the [`merge`] method on the [`Merge`] trait. See its @@ -60,9 +87,12 @@ macro_rules! impl_merge_tuple { where $( $F: Stream, )* { - done: bool, - $(#[pin] $F: $F,)* + #[pin] streams: $mod_name::Streams<$($F,)+>, rng: utils::RandomGenerator, + wakers: WakerArray<{utils::tuple_len!($($F,)+)}>, + state: PollArray<{utils::tuple_len!($($F,)+)}>, + completed: u8, + done: bool, } impl fmt::Debug for $StructName @@ -72,7 +102,7 @@ macro_rules! impl_merge_tuple { )* { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_tuple("Merge") - $(.field(&self.$F))* + .field(&self.streams) .finish() } } @@ -84,33 +114,41 @@ macro_rules! impl_merge_tuple { type Item = T; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); + let this = self.project(); - // Return early in case we're polled again after completion. - if *this.done { - return Poll::Ready(None); - } + let mut readiness = this.wakers.readiness().lock().unwrap(); + readiness.set_waker(cx.waker()); - const LEN: u32 = utils::tuple_len!($($F,)*); - const PERMUTATIONS: u32 = utils::permutations(LEN); - let r = this.rng.generate(PERMUTATIONS); - let mut pending = false; - for i in 0..LEN { - utils::gen_conditions!(LEN, i, r, this, cx, poll_next, { - Poll::Ready(Some(value)) => return Poll::Ready(Some(value)), - Poll::Ready(None) => continue, - Poll::Pending => { - pending = true; - continue - }, - }, $($F,)*); - } - if pending { - Poll::Pending - } else { - *this.done = true; - Poll::Ready(None) + const LEN: u8 = utils::tuple_len!($($F,)*); + let r = this.rng.generate(LEN as u32) as u8; + + let mut streams = this.streams.project(); + + // Iterate over our streams one-by-one. If a stream yields a value, + // we exit early. By default we'll return `Poll::Ready(None)`, but + // this changes if we encounter a `Poll::Pending`. + for index in (0..LEN).map(|n| (r + n).wrapping_rem(LEN) as usize) { + if !readiness.any_ready() { + // Nothing is ready yet + return Poll::Pending; + } else if !readiness.clear_ready(index) || this.state[index].is_consumed() { + continue; + } + + // unlock readiness so we don't deadlock when polling + drop(readiness); + + // Obtain the intermediate waker. + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); + + // poll the `streams.{index}` stream + utils::tuple_for_each!(poll_stream (index, this, streams, cx, LEN) $($F)*); + + // Lock readiness so we can use it again + readiness = this.wakers.readiness().lock().unwrap(); } + + Poll::Pending } } @@ -124,32 +162,36 @@ macro_rules! impl_merge_tuple { fn merge(self) -> Self::Stream { let ($($F,)*): ($($F,)*) = self; $StructName { - done: false, + streams: $mod_name::Streams($($F.into_stream(),)+), rng: utils::RandomGenerator::new(), - $($F: $F.into_stream()),* + wakers: WakerArray::new(), + state: PollArray::new(), + completed: 0, + done: false, } } } }; } -impl_merge_tuple! { Merge0 } -impl_merge_tuple! { Merge1 A } -impl_merge_tuple! { Merge2 A B } -impl_merge_tuple! { Merge3 A B C } -impl_merge_tuple! { Merge4 A B C D } -impl_merge_tuple! { Merge5 A B C D E } -impl_merge_tuple! { Merge6 A B C D E F } -impl_merge_tuple! { Merge7 A B C D E F G } -impl_merge_tuple! { Merge8 A B C D E F G H } -impl_merge_tuple! { Merge9 A B C D E F G H I } -impl_merge_tuple! { Merge10 A B C D E F G H I J } -impl_merge_tuple! { Merge11 A B C D E F G H I J K } -impl_merge_tuple! { Merge12 A B C D E F G H I J K L } +impl_merge_tuple! { merge0 Merge0 } +impl_merge_tuple! { merge1 Merge1 A } +impl_merge_tuple! { merge2 Merge2 A B } +impl_merge_tuple! { merge3 Merge3 A B C } +impl_merge_tuple! { merge4 Merge4 A B C D } +impl_merge_tuple! { merge5 Merge5 A B C D E } +impl_merge_tuple! { merge6 Merge6 A B C D E F } +impl_merge_tuple! { merge7 Merge7 A B C D E F G } +impl_merge_tuple! { merge8 Merge8 A B C D E F G H } +impl_merge_tuple! { merge9 Merge9 A B C D E F G H I } +impl_merge_tuple! { merge10 Merge10 A B C D E F G H I J } +impl_merge_tuple! { merge11 Merge11 A B C D E F G H I J K } +impl_merge_tuple! { merge12 Merge12 A B C D E F G H I J K L } #[cfg(test)] mod tests { use super::*; + use futures::task::LocalSpawnExt; use futures_lite::future::block_on; use futures_lite::prelude::*; use futures_lite::stream; @@ -228,4 +270,60 @@ mod tests { assert_eq!(counter, 10); }) } + + /// This test case uses channels so we'll have streams that return Pending from time to time. + /// + /// The purpose of this test is to make sure we have the waking logic working. + #[test] + fn merge_channels() { + use std::cell::RefCell; + use std::rc::Rc; + + use futures::executor::LocalPool; + + use crate::future::Join; + use crate::utils::channel::local_channel; + + let mut pool = LocalPool::new(); + + let done = Rc::new(RefCell::new(false)); + let done2 = done.clone(); + + pool.spawner() + .spawn_local(async move { + let (send1, receive1) = local_channel(); + let (send2, receive2) = local_channel(); + let (send3, receive3) = local_channel(); + + let (count, ()) = ( + async { + (receive1, receive2, receive3) + .merge() + .fold(0, |a, b| a + b) + .await + }, + async { + for i in 1..=4 { + send1.send(i); + send2.send(i); + send3.send(i); + } + drop(send1); + drop(send2); + drop(send3); + }, + ) + .join() + .await; + + assert_eq!(count, 30); + + *done2.borrow_mut() = true; + }) + .unwrap(); + + while !*done.borrow() { + pool.run_until_stalled() + } + } } diff --git a/src/utils/tuple.rs b/src/utils/tuple.rs index 6fef9aa..9fc4b76 100644 --- a/src/utils/tuple.rs +++ b/src/utils/tuple.rs @@ -56,7 +56,6 @@ macro_rules! gen_conditions { } pub(crate) use gen_conditions; - /// Repeats a given macro for each element a tuple has, passing the iteration /// number (aka the element index) to the macro called macro. ///