diff --git a/Cargo.lock b/Cargo.lock index 20790f93fd945..d11866da375e2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7406,6 +7406,7 @@ dependencies = [ "futures-core", "pin-project-lite 0.2.4", "tokio 1.3.0", + "tokio-util 0.6.3", ] [[package]] @@ -8245,6 +8246,7 @@ dependencies = [ "serde", "serde_json", "tokio 1.3.0", + "tokio-stream", "tokio-tungstenite", "url", "uuid 0.8.2", diff --git a/Cargo.toml b/Cargo.toml index 4da2ba7bfa7c4..1a5af9f34ed48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -100,7 +100,7 @@ futures = { version = "0.3", default-features = false, features = ["compat", "io futures01 = { package = "futures", version = "0.1.25" } tokio = { version = "1.3.0", features = ["full"] } tokio-openssl = "0.6.1" -tokio-stream = { version = "0.1.2", features = ["net"] } +tokio-stream = { version = "0.1.3", features = ["net", "sync"] } tokio-util = { version = "0.6.2", features = ["codec", "time"] } # Tracing diff --git a/lib/vector-api-client/Cargo.toml b/lib/vector-api-client/Cargo.toml index b4381f8177b0e..0d2df50afeeaf 100644 --- a/lib/vector-api-client/Cargo.toml +++ b/lib/vector-api-client/Cargo.toml @@ -20,6 +20,7 @@ async-stream = "0.3.0" async-trait = "0.1" futures = { version = "0.3", default-features = false, features = ["compat", "io-compat"] } tokio = { version = "1.3.0", features = ["full"] } +tokio-stream = { version = "0.1.3", features = ["sync"] } # GraphQL graphql_client = "0.9.0" diff --git a/lib/vector-api-client/src/subscription.rs b/lib/vector-api-client/src/subscription.rs index 9fe3810d60f91..70bf285869e1a 100644 --- a/lib/vector-api-client/src/subscription.rs +++ b/lib/vector-api-client/src/subscription.rs @@ -1,8 +1,4 @@ -use async_stream::stream; -use futures::{ - stream::{Stream, StreamExt}, - SinkExt, -}; +use futures::SinkExt; use graphql_client::GraphQLQuery; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -11,6 +7,7 @@ use std::{ sync::{Arc, Mutex, Weak}, }; use tokio::sync::{broadcast, mpsc, oneshot}; +use tokio_stream::{wrappers::BroadcastStream, Stream, StreamExt}; use tokio_tungstenite::{connect_async, tungstenite::Message}; use url::Url; use uuid::Uuid; @@ -133,14 +130,11 @@ where { /// Returns a stream of `Payload` responses, received from the GraphQL server fn stream(&self) -> StreamResponse { - let mut rx = self.tx.subscribe(); - Box::pin(stream! { - loop { - if let Ok(p) = rx.recv().await { - yield p.response::() - } - } - }) + Box::pin( + BroadcastStream::new(self.tx.subscribe()) + .filter(Result::is_ok) + .map(|p| p.unwrap().response::()), + ) } } diff --git a/src/api/schema/components/mod.rs b/src/api/schema/components/mod.rs index 24b998aeab78a..67ef05b60de18 100644 --- a/src/api/schema/components/mod.rs +++ b/src/api/schema/components/mod.rs @@ -13,13 +13,12 @@ use crate::{ filter_check, }; use async_graphql::{Enum, InputObject, Interface, Object, Subscription}; -use async_stream::stream; use lazy_static::lazy_static; use std::{ cmp, collections::{HashMap, HashSet}, }; -use tokio_stream::Stream; +use tokio_stream::{wrappers::BroadcastStream, Stream, StreamExt}; #[derive(Debug, Clone, Interface)] #[graphql( @@ -230,28 +229,18 @@ pub struct ComponentsSubscription; impl ComponentsSubscription { /// Subscribes to all newly added components async fn component_added(&self) -> impl Stream { - let mut rx = COMPONENT_CHANGED.subscribe(); - stream! { - loop { - match rx.recv().await { - Ok(ComponentChanged::Added(c)) => yield c, - _ => {}, - } - } - } + BroadcastStream::new(COMPONENT_CHANGED.subscribe()).filter_map(|c| match c { + Ok(ComponentChanged::Added(c)) => Some(c), + _ => None, + }) } /// Subscribes to all removed components async fn component_removed(&self) -> impl Stream { - let mut rx = COMPONENT_CHANGED.subscribe(); - stream! { - loop { - match rx.recv().await { - Ok(ComponentChanged::Removed(c)) => yield c, - _ => {}, - } - } - } + BroadcastStream::new(COMPONENT_CHANGED.subscribe()).filter_map(|c| match c { + Ok(ComponentChanged::Removed(c)) => Some(c), + _ => None, + }) } } diff --git a/src/sources/util/unix_stream.rs b/src/sources/util/unix_stream.rs index 4708ab0910b42..dc7fac79085ad 100644 --- a/src/sources/util/unix_stream.rs +++ b/src/sources/util/unix_stream.rs @@ -7,12 +7,12 @@ use crate::{ sources::Source, Pipeline, }; -use async_stream::stream; use bytes::Bytes; use futures::{FutureExt, SinkExt, StreamExt}; use std::{future::ready, path::PathBuf}; use tokio::io::AsyncWriteExt; use tokio::net::{UnixListener, UnixStream}; +use tokio_stream::wrappers::UnixListenerStream; use tokio_util::codec::{Decoder, FramedRead}; use tracing::field; use tracing_futures::Instrument; @@ -40,12 +40,7 @@ where info!(message = "Listening.", path = ?listen_path, r#type = "unix"); let connection_open = OpenGauge::new(); - let stream = stream! { - loop { - yield listener.accept().await.map(|(stream, _addr)| stream) - } - } - .take_until(shutdown.clone()); + let stream = UnixListenerStream::new(listener).take_until(shutdown.clone()); tokio::pin!(stream); while let Some(socket) = stream.next().await { let socket = match socket { diff --git a/src/test_util/mod.rs b/src/test_util/mod.rs index e2091bb32135c..8fabff9d57145 100644 --- a/src/test_util/mod.rs +++ b/src/test_util/mod.rs @@ -3,7 +3,6 @@ use crate::{ topology::{self, RunningTopology}, trace, Event, }; -use async_stream::stream; use flate2::read::GzDecoder; use futures::{ ready, stream, task::noop_waker_ref, FutureExt, SinkExt, Stream, StreamExt, TryStreamExt, @@ -36,6 +35,9 @@ use tokio::{ task::JoinHandle, time::{sleep, Duration, Instant}, }; +use tokio_stream::wrappers::TcpListenerStream; +#[cfg(unix)] +use tokio_stream::wrappers::UnixListenerStream; use tokio_util::codec::{Encoder, FramedRead, FramedWrite, LinesCodec}; const WAIT_FOR_SECS: u64 = 5; // The default time to wait in `wait_for` @@ -445,12 +447,13 @@ impl CountReceiver { pub fn receive_lines(addr: SocketAddr) -> CountReceiver { CountReceiver::new(|count, tripwire, connected| async move { let listener = TcpListener::bind(addr).await.unwrap(); - let stream = stream! { - loop { - yield listener.accept().await.map(|(stream, _addr)| stream) - } - }; - CountReceiver::receive_lines_stream(stream, count, tripwire, Some(connected)).await + CountReceiver::receive_lines_stream( + TcpListenerStream::new(listener), + count, + tripwire, + Some(connected), + ) + .await }) } @@ -461,12 +464,13 @@ impl CountReceiver { { CountReceiver::new(|count, tripwire, connected| async move { let listener = tokio::net::UnixListener::bind(path).unwrap(); - let stream = stream! { - loop { - yield listener.accept().await.map(|(stream, _addr)| stream) - } - }; - CountReceiver::receive_lines_stream(stream, count, tripwire, Some(connected)).await + CountReceiver::receive_lines_stream( + UnixListenerStream::new(listener), + count, + tripwire, + Some(connected), + ) + .await }) }