diff --git a/Cargo.lock b/Cargo.lock index 511b7da62fe..246df9ba837 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6434,6 +6434,7 @@ dependencies = [ "reth-revm", "reth-rpc-types", "reth-stages-api", + "reth-tokio-util", "reth-transaction-pool", "tokio", "tokio-stream", @@ -6769,6 +6770,7 @@ dependencies = [ "reth-provider", "reth-rpc", "reth-rpc-layer", + "reth-tokio-util", "reth-tracing", "secp256k1 0.28.2", "serde_json", @@ -7613,7 +7615,6 @@ dependencies = [ "reth-tracing", "thiserror", "tokio", - "tokio-stream", "tracing", ] @@ -7740,6 +7741,7 @@ dependencies = [ "reth-rpc-types", "reth-rpc-types-compat", "reth-tasks", + "reth-tokio-util", "reth-tracing", "reth-transaction-pool", "serde", @@ -7774,6 +7776,7 @@ dependencies = [ "reth-rpc-types", "reth-rpc-types-compat", "reth-tasks", + "reth-tokio-util", "serde", "thiserror", "tokio", @@ -7906,6 +7909,7 @@ dependencies = [ "reth-stages", "reth-tokio-util", "tempfile", + "tokio", "tokio-stream", "tracing", ] @@ -7941,6 +7945,7 @@ version = "0.2.0-beta.7" dependencies = [ "tokio", "tokio-stream", + "tracing", ] [[package]] diff --git a/bin/reth/src/commands/import.rs b/bin/reth/src/commands/import.rs index 235ada84854..1108f8aa785 100644 --- a/bin/reth/src/commands/import.rs +++ b/bin/reth/src/commands/import.rs @@ -257,7 +257,7 @@ where let max_block = file_client.max_block().unwrap_or(0); - let mut pipeline = Pipeline::builder() + let pipeline = Pipeline::builder() .with_tip_sender(tip_tx) // we want to sync all blocks the file client provides or 0 if empty .with_max_block(max_block) diff --git a/crates/consensus/auto-seal/Cargo.toml b/crates/consensus/auto-seal/Cargo.toml index 435ade53db3..ccbc1e06a32 100644 --- a/crates/consensus/auto-seal/Cargo.toml +++ b/crates/consensus/auto-seal/Cargo.toml @@ -25,6 +25,7 @@ reth-engine-primitives.workspace = true reth-consensus.workspace = true reth-rpc-types.workspace = true reth-network-types.workspace = true +reth-tokio-util.workspace = true # async futures-util.workspace = true diff --git a/crates/consensus/auto-seal/src/task.rs b/crates/consensus/auto-seal/src/task.rs index 42f1268f331..2a5ec4433e4 100644 --- a/crates/consensus/auto-seal/src/task.rs +++ b/crates/consensus/auto-seal/src/task.rs @@ -9,6 +9,7 @@ use reth_primitives::{ use reth_provider::{CanonChainTracker, CanonStateNotificationSender, Chain, StateProviderFactory}; use reth_rpc_types::engine::ForkchoiceState; use reth_stages_api::PipelineEvent; +use reth_tokio_util::EventStream; use reth_transaction_pool::{TransactionPool, ValidPoolTransaction}; use std::{ collections::VecDeque, @@ -18,7 +19,6 @@ use std::{ task::{Context, Poll}, }; use tokio::sync::{mpsc::UnboundedSender, oneshot}; -use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, error, warn}; /// A Future that listens for new ready transactions and puts new blocks into storage @@ -30,7 +30,7 @@ pub struct MiningTask>>>, + insert_task: Option>>>, /// Shared storage to insert new blocks storage: Storage, /// Pool where transactions are stored @@ -42,7 +42,7 @@ pub struct MiningTask>, + pipe_line_events: Option>, /// The type used for block execution block_executor: Executor, } @@ -80,7 +80,7 @@ impl } /// Sets the pipeline events to listen on. - pub fn set_pipeline_events(&mut self, events: UnboundedReceiverStream) { + pub fn set_pipeline_events(&mut self, events: EventStream) { self.pipe_line_events = Some(events); } } diff --git a/crates/consensus/beacon/src/engine/handle.rs b/crates/consensus/beacon/src/engine/handle.rs index 121a8fac070..bec289bf4a7 100644 --- a/crates/consensus/beacon/src/engine/handle.rs +++ b/crates/consensus/beacon/src/engine/handle.rs @@ -10,28 +10,20 @@ use reth_interfaces::RethResult; use reth_rpc_types::engine::{ CancunPayloadFields, ExecutionPayload, ForkchoiceState, ForkchoiceUpdated, PayloadStatus, }; -use tokio::sync::{mpsc, mpsc::UnboundedSender, oneshot}; -use tokio_stream::wrappers::UnboundedReceiverStream; +use reth_tokio_util::{EventSender, EventStream}; +use tokio::sync::{mpsc::UnboundedSender, oneshot}; /// A _shareable_ beacon consensus frontend type. Used to interact with the spawned beacon consensus /// engine task. /// /// See also `BeaconConsensusEngine` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BeaconConsensusEngineHandle where Engine: EngineTypes, { pub(crate) to_engine: UnboundedSender>, -} - -impl Clone for BeaconConsensusEngineHandle -where - Engine: EngineTypes, -{ - fn clone(&self) -> Self { - Self { to_engine: self.to_engine.clone() } - } + event_sender: EventSender, } // === impl BeaconConsensusEngineHandle === @@ -41,8 +33,11 @@ where Engine: EngineTypes, { /// Creates a new beacon consensus engine handle. - pub fn new(to_engine: UnboundedSender>) -> Self { - Self { to_engine } + pub fn new( + to_engine: UnboundedSender>, + event_sender: EventSender, + ) -> Self { + Self { to_engine, event_sender } } /// Sends a new payload message to the beacon consensus engine and waits for a response. @@ -97,9 +92,7 @@ where } /// Creates a new [`BeaconConsensusEngineEvent`] listener stream. - pub fn event_listener(&self) -> UnboundedReceiverStream { - let (tx, rx) = mpsc::unbounded_channel(); - let _ = self.to_engine.send(BeaconEngineMessage::EventListener(tx)); - UnboundedReceiverStream::new(rx) + pub fn event_listener(&self) -> EventStream { + self.event_sender.new_listener() } } diff --git a/crates/consensus/beacon/src/engine/hooks/static_file.rs b/crates/consensus/beacon/src/engine/hooks/static_file.rs index 2cff68e1d26..01b7056c37f 100644 --- a/crates/consensus/beacon/src/engine/hooks/static_file.rs +++ b/crates/consensus/beacon/src/engine/hooks/static_file.rs @@ -91,8 +91,7 @@ impl StaticFileHook { return Ok(None) }; - let Some(mut locked_static_file_producer) = static_file_producer.try_lock_arc() - else { + let Some(locked_static_file_producer) = static_file_producer.try_lock_arc() else { trace!(target: "consensus::engine::hooks::static_file", "StaticFileProducer lock is already taken"); return Ok(None) }; diff --git a/crates/consensus/beacon/src/engine/message.rs b/crates/consensus/beacon/src/engine/message.rs index f9f1a84d46f..108dab41eb0 100644 --- a/crates/consensus/beacon/src/engine/message.rs +++ b/crates/consensus/beacon/src/engine/message.rs @@ -1,7 +1,4 @@ -use crate::{ - engine::{error::BeaconOnNewPayloadError, forkchoice::ForkchoiceStatus}, - BeaconConsensusEngineEvent, -}; +use crate::engine::{error::BeaconOnNewPayloadError, forkchoice::ForkchoiceStatus}; use futures::{future::Either, FutureExt}; use reth_engine_primitives::EngineTypes; use reth_interfaces::RethResult; @@ -15,7 +12,7 @@ use std::{ pin::Pin, task::{ready, Context, Poll}, }; -use tokio::sync::{mpsc::UnboundedSender, oneshot}; +use tokio::sync::oneshot; /// Represents the outcome of forkchoice update. /// @@ -162,6 +159,4 @@ pub enum BeaconEngineMessage { }, /// Message with exchanged transition configuration. TransitionConfigurationExchanged, - /// Add a new listener for [`BeaconEngineMessage`]. - EventListener(UnboundedSender), } diff --git a/crates/consensus/beacon/src/engine/mod.rs b/crates/consensus/beacon/src/engine/mod.rs index 1057457c779..8139a0c5773 100644 --- a/crates/consensus/beacon/src/engine/mod.rs +++ b/crates/consensus/beacon/src/engine/mod.rs @@ -29,7 +29,7 @@ use reth_rpc_types::engine::{ }; use reth_stages_api::{ControlFlow, Pipeline}; use reth_tasks::TaskSpawner; -use reth_tokio_util::EventListeners; +use reth_tokio_util::EventSender; use std::{ pin::Pin, sync::Arc, @@ -202,8 +202,8 @@ where /// be used to download and execute the missing blocks. pipeline_run_threshold: u64, hooks: EngineHooksController, - /// Listeners for engine events. - listeners: EventListeners, + /// Sender for engine events. + event_sender: EventSender, /// Consensus engine metrics. metrics: EngineMetrics, } @@ -282,8 +282,8 @@ where engine_message_stream: BoxStream<'static, BeaconEngineMessage>, hooks: EngineHooks, ) -> RethResult<(Self, BeaconConsensusEngineHandle)> { - let handle = BeaconConsensusEngineHandle { to_engine }; - let listeners = EventListeners::default(); + let event_sender = EventSender::default(); + let handle = BeaconConsensusEngineHandle::new(to_engine, event_sender.clone()); let sync = EngineSyncController::new( pipeline, client, @@ -291,7 +291,7 @@ where run_pipeline_continuously, max_block, blockchain.chain_spec(), - listeners.clone(), + event_sender.clone(), ); let mut this = Self { sync, @@ -306,7 +306,7 @@ where blockchain_tree_action: None, pipeline_run_threshold, hooks: EngineHooksController::new(hooks), - listeners, + event_sender, metrics: EngineMetrics::default(), }; @@ -406,7 +406,7 @@ where if should_update_head { let head = outcome.header(); let _ = self.update_head(head.clone()); - self.listeners.notify(BeaconConsensusEngineEvent::CanonicalChainCommitted( + self.event_sender.notify(BeaconConsensusEngineEvent::CanonicalChainCommitted( Box::new(head.clone()), elapsed, )); @@ -543,7 +543,7 @@ where } // notify listeners about new processed FCU - self.listeners.notify(BeaconConsensusEngineEvent::ForkchoiceUpdated(state, status)); + self.event_sender.notify(BeaconConsensusEngineEvent::ForkchoiceUpdated(state, status)); } /// Check if the pipeline is consistent (all stages have the checkpoint block numbers no less @@ -597,13 +597,6 @@ where self.handle.clone() } - /// Pushes an [UnboundedSender] to the engine's listeners. Also pushes an [UnboundedSender] to - /// the sync controller's listeners. - pub(crate) fn push_listener(&mut self, listener: UnboundedSender) { - self.listeners.push_listener(listener.clone()); - self.sync.push_listener(listener); - } - /// Returns true if the distance from the local tip to the block is greater than the configured /// threshold. /// @@ -1255,7 +1248,7 @@ where } else { BeaconConsensusEngineEvent::ForkBlockAdded(block) }; - self.listeners.notify(event); + self.event_sender.notify(event); PayloadStatusEnum::Valid } InsertPayloadOk::AlreadySeen(BlockStatus::Valid(_)) => { @@ -1429,7 +1422,7 @@ where match make_canonical_result { Ok(outcome) => { if let CanonicalOutcome::Committed { head } = &outcome { - self.listeners.notify(BeaconConsensusEngineEvent::CanonicalChainCommitted( + self.event_sender.notify(BeaconConsensusEngineEvent::CanonicalChainCommitted( Box::new(head.clone()), elapsed, )); @@ -1878,7 +1871,6 @@ where BeaconEngineMessage::TransitionConfigurationExchanged => { this.blockchain.on_transition_configuration_exchanged(); } - BeaconEngineMessage::EventListener(tx) => this.push_listener(tx), } continue } diff --git a/crates/consensus/beacon/src/engine/sync.rs b/crates/consensus/beacon/src/engine/sync.rs index 09c6d208b6e..441c3ce0362 100644 --- a/crates/consensus/beacon/src/engine/sync.rs +++ b/crates/consensus/beacon/src/engine/sync.rs @@ -14,14 +14,14 @@ use reth_interfaces::p2p::{ use reth_primitives::{stage::PipelineTarget, BlockNumber, ChainSpec, SealedBlock, B256}; use reth_stages_api::{ControlFlow, Pipeline, PipelineError, PipelineWithResult}; use reth_tasks::TaskSpawner; -use reth_tokio_util::EventListeners; +use reth_tokio_util::EventSender; use std::{ cmp::{Ordering, Reverse}, collections::{binary_heap::PeekMut, BinaryHeap}, sync::Arc, task::{ready, Context, Poll}, }; -use tokio::sync::{mpsc::UnboundedSender, oneshot}; +use tokio::sync::oneshot; use tracing::trace; /// Manages syncing under the control of the engine. @@ -49,8 +49,8 @@ where inflight_full_block_requests: Vec>, /// In-flight full block _range_ requests in progress. inflight_block_range_requests: Vec>, - /// Listeners for engine events. - listeners: EventListeners, + /// Sender for engine events. + event_sender: EventSender, /// Buffered blocks from downloads - this is a min-heap of blocks, using the block number for /// ordering. This means the blocks will be popped from the heap with ascending block numbers. range_buffered_blocks: BinaryHeap>, @@ -76,7 +76,7 @@ where run_pipeline_continuously: bool, max_block: Option, chain_spec: Arc, - listeners: EventListeners, + event_sender: EventSender, ) -> Self { Self { full_block_client: FullBlockClient::new( @@ -90,7 +90,7 @@ where inflight_block_range_requests: Vec::new(), range_buffered_blocks: BinaryHeap::new(), run_pipeline_continuously, - listeners, + event_sender, max_block, metrics: EngineSyncMetrics::default(), } @@ -127,11 +127,6 @@ where self.run_pipeline_continuously } - /// Pushes an [UnboundedSender] to the sync controller's listeners. - pub(crate) fn push_listener(&mut self, listener: UnboundedSender) { - self.listeners.push_listener(listener); - } - /// Returns `true` if a pipeline target is queued and will be triggered on the next `poll`. #[allow(dead_code)] pub(crate) fn is_pipeline_sync_pending(&self) -> bool { @@ -169,7 +164,7 @@ where ); // notify listeners that we're downloading a block range - self.listeners.notify(BeaconConsensusEngineEvent::LiveSyncProgress( + self.event_sender.notify(BeaconConsensusEngineEvent::LiveSyncProgress( ConsensusEngineLiveSyncProgress::DownloadingBlocks { remaining_blocks: count, target: hash, @@ -198,7 +193,7 @@ where ); // notify listeners that we're downloading a block - self.listeners.notify(BeaconConsensusEngineEvent::LiveSyncProgress( + self.event_sender.notify(BeaconConsensusEngineEvent::LiveSyncProgress( ConsensusEngineLiveSyncProgress::DownloadingBlocks { remaining_blocks: 1, target: hash, diff --git a/crates/e2e-test-utils/Cargo.toml b/crates/e2e-test-utils/Cargo.toml index 59424cac98f..4165044ae2e 100644 --- a/crates/e2e-test-utils/Cargo.toml +++ b/crates/e2e-test-utils/Cargo.toml @@ -20,6 +20,7 @@ reth-rpc-layer.workspace = true reth-payload-builder = { workspace = true, features = ["test-utils"] } reth-provider.workspace = true reth-node-builder.workspace = true +reth-tokio-util.workspace = true jsonrpsee.workspace = true diff --git a/crates/e2e-test-utils/src/network.rs b/crates/e2e-test-utils/src/network.rs index 92e9b316a9a..5b148b09f55 100644 --- a/crates/e2e-test-utils/src/network.rs +++ b/crates/e2e-test-utils/src/network.rs @@ -1,12 +1,12 @@ use futures_util::StreamExt; use reth::network::{NetworkEvent, NetworkEvents, NetworkHandle, PeersInfo}; use reth_primitives::NodeRecord; +use reth_tokio_util::EventStream; use reth_tracing::tracing::info; -use tokio_stream::wrappers::UnboundedReceiverStream; /// Helper for network operations pub struct NetworkTestContext { - network_events: UnboundedReceiverStream, + network_events: EventStream, network: NetworkHandle, } diff --git a/crates/net/network/src/manager.rs b/crates/net/network/src/manager.rs index d516625c640..b6b1d4d1ecb 100644 --- a/crates/net/network/src/manager.rs +++ b/crates/net/network/src/manager.rs @@ -49,7 +49,7 @@ use reth_primitives::{ForkId, NodeRecord}; use reth_provider::{BlockNumReader, BlockReader}; use reth_rpc_types::{admin::EthProtocolInfo, NetworkStatus}; use reth_tasks::shutdown::GracefulShutdown; -use reth_tokio_util::EventListeners; +use reth_tokio_util::EventSender; use secp256k1::SecretKey; use std::{ net::SocketAddr, @@ -84,8 +84,8 @@ pub struct NetworkManager { from_handle_rx: UnboundedReceiverStream, /// Handles block imports according to the `eth` protocol. block_import: Box, - /// All listeners for high level network events. - event_listeners: EventListeners, + /// Sender for high level network events. + event_sender: EventSender, /// Sender half to send events to the /// [`TransactionsManager`](crate::transactions::TransactionsManager) task, if configured. to_transactions_manager: Option>, @@ -246,6 +246,8 @@ where let (to_manager_tx, from_handle_rx) = mpsc::unbounded_channel(); + let event_sender: EventSender = Default::default(); + let handle = NetworkHandle::new( Arc::clone(&num_active_peers), listener_address, @@ -258,6 +260,7 @@ where Arc::new(AtomicU64::new(chain_spec.chain.id())), tx_gossip_disabled, discv4, + event_sender.clone(), ); Ok(Self { @@ -265,7 +268,7 @@ where handle, from_handle_rx: UnboundedReceiverStream::new(from_handle_rx), block_import, - event_listeners: Default::default(), + event_sender, to_transactions_manager: None, to_eth_request_handler: None, num_active_peers, @@ -528,9 +531,6 @@ where /// Handler for received messages from a handle fn on_handle_message(&mut self, msg: NetworkHandleMessage) { match msg { - NetworkHandleMessage::EventListener(tx) => { - self.event_listeners.push_listener(tx); - } NetworkHandleMessage::DiscoveryListener(tx) => { self.swarm.state_mut().discovery_mut().add_listener(tx); } @@ -690,7 +690,7 @@ where self.update_active_connection_metrics(); - self.event_listeners.notify(NetworkEvent::SessionEstablished { + self.event_sender.notify(NetworkEvent::SessionEstablished { peer_id, remote_addr, client_version, @@ -702,12 +702,12 @@ where } SwarmEvent::PeerAdded(peer_id) => { trace!(target: "net", ?peer_id, "Peer added"); - self.event_listeners.notify(NetworkEvent::PeerAdded(peer_id)); + self.event_sender.notify(NetworkEvent::PeerAdded(peer_id)); self.metrics.tracked_peers.set(self.swarm.state().peers().num_known_peers() as f64); } SwarmEvent::PeerRemoved(peer_id) => { trace!(target: "net", ?peer_id, "Peer dropped"); - self.event_listeners.notify(NetworkEvent::PeerRemoved(peer_id)); + self.event_sender.notify(NetworkEvent::PeerRemoved(peer_id)); self.metrics.tracked_peers.set(self.swarm.state().peers().num_known_peers() as f64); } SwarmEvent::SessionClosed { peer_id, remote_addr, error } => { @@ -750,7 +750,7 @@ where .saturating_sub(1) as f64, ); - self.event_listeners.notify(NetworkEvent::SessionClosed { peer_id, reason }); + self.event_sender.notify(NetworkEvent::SessionClosed { peer_id, reason }); } SwarmEvent::IncomingPendingSessionClosed { remote_addr, error } => { trace!( diff --git a/crates/net/network/src/network.rs b/crates/net/network/src/network.rs index 86669bf19f4..8d9b277f419 100644 --- a/crates/net/network/src/network.rs +++ b/crates/net/network/src/network.rs @@ -16,6 +16,7 @@ use reth_network_api::{ use reth_network_types::PeerId; use reth_primitives::{Head, NodeRecord, TransactionSigned, B256}; use reth_rpc_types::NetworkStatus; +use reth_tokio_util::{EventSender, EventStream}; use secp256k1::SecretKey; use std::{ net::SocketAddr, @@ -24,7 +25,10 @@ use std::{ Arc, }, }; -use tokio::sync::{mpsc, mpsc::UnboundedSender, oneshot}; +use tokio::sync::{ + mpsc::{self, UnboundedSender}, + oneshot, +}; use tokio_stream::wrappers::UnboundedReceiverStream; /// A _shareable_ network frontend. Used to interact with the network. @@ -53,6 +57,7 @@ impl NetworkHandle { chain_id: Arc, tx_gossip_disabled: bool, discv4: Option, + event_sender: EventSender, ) -> Self { let inner = NetworkInner { num_active_peers, @@ -68,6 +73,7 @@ impl NetworkHandle { chain_id, tx_gossip_disabled, discv4, + event_sender, }; Self { inner: Arc::new(inner) } } @@ -196,10 +202,8 @@ impl NetworkHandle { // === API Implementations === impl NetworkEvents for NetworkHandle { - fn event_listener(&self) -> UnboundedReceiverStream { - let (tx, rx) = mpsc::unbounded_channel(); - let _ = self.manager().send(NetworkHandleMessage::EventListener(tx)); - UnboundedReceiverStream::new(rx) + fn event_listener(&self) -> EventStream { + self.inner.event_sender.new_listener() } fn discovery_listener(&self) -> UnboundedReceiverStream { @@ -401,12 +405,14 @@ struct NetworkInner { tx_gossip_disabled: bool, /// The instance of the discv4 service discv4: Option, + /// Sender for high level network events. + event_sender: EventSender, } /// Provides event subscription for the network. pub trait NetworkEvents: Send + Sync { /// Creates a new [`NetworkEvent`] listener channel. - fn event_listener(&self) -> UnboundedReceiverStream; + fn event_listener(&self) -> EventStream; /// Returns a new [`DiscoveryEvent`] stream. /// /// This stream yields [`DiscoveryEvent`]s for each peer that is discovered. @@ -430,8 +436,6 @@ pub(crate) enum NetworkHandleMessage { RemovePeer(PeerId, PeerKind), /// Disconnects a connection to a peer if it exists, optionally providing a disconnect reason. DisconnectPeer(PeerId, Option), - /// Adds a new listener for `NetworkEvent`. - EventListener(UnboundedSender), /// Broadcasts an event to announce a new block to all nodes. AnnounceBlock(NewBlock, B256), /// Sends a list of transactions to the given peer. diff --git a/crates/net/network/src/test_utils/testnet.rs b/crates/net/network/src/test_utils/testnet.rs index a92934c0cbc..99c98db55d5 100644 --- a/crates/net/network/src/test_utils/testnet.rs +++ b/crates/net/network/src/test_utils/testnet.rs @@ -20,6 +20,7 @@ use reth_provider::{ test_utils::NoopProvider, BlockReader, BlockReaderIdExt, HeaderProvider, StateProviderFactory, }; use reth_tasks::TokioTaskExecutor; +use reth_tokio_util::EventStream; use reth_transaction_pool::{ blobstore::InMemoryBlobStore, test_utils::{TestPool, TestPoolBuilder}, @@ -40,7 +41,6 @@ use tokio::{ }, task::JoinHandle, }; -use tokio_stream::wrappers::UnboundedReceiverStream; /// A test network consisting of multiple peers. pub struct Testnet { @@ -503,7 +503,7 @@ impl PeerHandle { } /// Creates a new [`NetworkEvent`] listener channel. - pub fn event_listener(&self) -> UnboundedReceiverStream { + pub fn event_listener(&self) -> EventStream { self.network.event_listener() } @@ -591,14 +591,14 @@ impl Default for PeerConfig { /// This makes it easier to await established connections #[derive(Debug)] pub struct NetworkEventStream { - inner: UnboundedReceiverStream, + inner: EventStream, } // === impl NetworkEventStream === impl NetworkEventStream { /// Create a new [`NetworkEventStream`] from the given network event receiver stream. - pub fn new(inner: UnboundedReceiverStream) -> Self { + pub fn new(inner: EventStream) -> Self { Self { inner } } diff --git a/crates/net/network/src/transactions/mod.rs b/crates/net/network/src/transactions/mod.rs index 070b9c7a147..b6b2328e4f8 100644 --- a/crates/net/network/src/transactions/mod.rs +++ b/crates/net/network/src/transactions/mod.rs @@ -30,6 +30,7 @@ use reth_network_types::PeerId; use reth_primitives::{ FromRecoveredPooledTransaction, PooledTransactionsElement, TransactionSigned, TxHash, B256, }; +use reth_tokio_util::EventStream; use reth_transaction_pool::{ error::{PoolError, PoolResult}, GetPooledTransactionLimit, PoolTransaction, PropagateKind, PropagatedTransactions, @@ -197,7 +198,7 @@ pub struct TransactionsManager { /// Subscriptions to all network related events. /// /// From which we get all new incoming transaction related messages. - network_events: UnboundedReceiverStream, + network_events: EventStream, /// Transaction fetcher to handle inflight and missing transaction requests. transaction_fetcher: TransactionFetcher, /// All currently pending transactions grouped by peers. @@ -880,8 +881,8 @@ where } /// Handles a received event related to common network events. - fn on_network_event(&mut self, event: NetworkEvent) { - match event { + fn on_network_event(&mut self, event_result: NetworkEvent) { + match event_result { NetworkEvent::SessionClosed { peer_id, .. } => { // remove the peer self.peers.remove(&peer_id); @@ -1626,6 +1627,7 @@ mod tests { use secp256k1::SecretKey; use std::{fmt, future::poll_fn, hash}; use tests::fetcher::TxFetchMetadata; + use tracing::error; async fn new_tx_manager() -> TransactionsManager { let secret_key = SecretKey::new(&mut rand::thread_rng()); @@ -1734,7 +1736,7 @@ mod tests { } NetworkEvent::PeerAdded(_peer_id) => continue, ev => { - panic!("unexpected event {ev:?}") + error!("unexpected event {ev:?}") } } } @@ -1820,7 +1822,7 @@ mod tests { } NetworkEvent::PeerAdded(_peer_id) => continue, ev => { - panic!("unexpected event {ev:?}") + error!("unexpected event {ev:?}") } } } @@ -1904,7 +1906,7 @@ mod tests { } NetworkEvent::PeerAdded(_peer_id) => continue, ev => { - panic!("unexpected event {ev:?}") + error!("unexpected event {ev:?}") } } } @@ -1992,7 +1994,7 @@ mod tests { }), NetworkEvent::PeerAdded(_peer_id) => continue, ev => { - panic!("unexpected event {ev:?}") + error!("unexpected event {ev:?}") } } } diff --git a/crates/node-core/src/engine/engine_store.rs b/crates/node-core/src/engine/engine_store.rs index 2a1ffc3b0ed..d59651ce9ca 100644 --- a/crates/node-core/src/engine/engine_store.rs +++ b/crates/node-core/src/engine/engine_store.rs @@ -89,8 +89,7 @@ impl EngineMessageStore { )?; } // noop - BeaconEngineMessage::TransitionConfigurationExchanged | - BeaconEngineMessage::EventListener(_) => (), + BeaconEngineMessage::TransitionConfigurationExchanged => (), }; Ok(()) } diff --git a/crates/node/builder/src/launch/mod.rs b/crates/node/builder/src/launch/mod.rs index ece149e31dc..4987586bc9f 100644 --- a/crates/node/builder/src/launch/mod.rs +++ b/crates/node/builder/src/launch/mod.rs @@ -282,7 +282,7 @@ where // Configure the pipeline let pipeline_exex_handle = exex_manager_handle.clone().unwrap_or_else(ExExManagerHandle::empty); - let (mut pipeline, client) = if ctx.is_dev() { + let (pipeline, client) = if ctx.is_dev() { info!(target: "reth::cli", "Starting Reth in dev mode"); for (idx, (address, alloc)) in ctx.chain_spec().genesis.alloc.iter().enumerate() { @@ -305,7 +305,7 @@ where ) .build(); - let mut pipeline = crate::setup::build_networked_pipeline( + let pipeline = crate::setup::build_networked_pipeline( ctx.node_config(), &ctx.toml_config().stages, client.clone(), @@ -358,7 +358,7 @@ where pruner_builder.finished_exex_height(exex_manager_handle.finished_height()); } - let mut pruner = pruner_builder.build(ctx.provider_factory().clone()); + let pruner = pruner_builder.build(ctx.provider_factory().clone()); let pruner_events = pruner.events(); info!(target: "reth::cli", prune_config=?ctx.prune_config().unwrap_or_default(), "Pruner initialized"); @@ -395,7 +395,7 @@ where Either::Right(stream::empty()) }, pruner_events.map(Into::into), - static_file_producer_events.map(Into::into) + static_file_producer_events.map(Into::into), ); ctx.task_executor().spawn_critical( "events task", diff --git a/crates/node/events/src/node.rs b/crates/node/events/src/node.rs index ba7ae8da460..383da986b4a 100644 --- a/crates/node/events/src/node.rs +++ b/crates/node/events/src/node.rs @@ -392,6 +392,9 @@ pub enum NodeEvent { Pruner(PrunerEvent), /// A static_file_producer event StaticFileProducer(StaticFileProducerEvent), + /// Used to encapsulate various conditions or situations that do not + /// naturally fit into the other more specific variants. + Other(String), } impl From for NodeEvent { @@ -575,6 +578,9 @@ where NodeEvent::StaticFileProducer(event) => { this.state.handle_static_file_producer_event(event); } + NodeEvent::Other(event_description) => { + warn!("{event_description}"); + } } } diff --git a/crates/prune/Cargo.toml b/crates/prune/Cargo.toml index cc24e68b834..65b4ba19c6c 100644 --- a/crates/prune/Cargo.toml +++ b/crates/prune/Cargo.toml @@ -30,7 +30,6 @@ thiserror.workspace = true itertools.workspace = true rayon.workspace = true tokio.workspace = true -tokio-stream.workspace = true [dev-dependencies] # reth diff --git a/crates/prune/src/pruner.rs b/crates/prune/src/pruner.rs index 55a998709d8..f4111f131a5 100644 --- a/crates/prune/src/pruner.rs +++ b/crates/prune/src/pruner.rs @@ -13,13 +13,12 @@ use reth_primitives::{ use reth_provider::{ DatabaseProviderRW, ProviderFactory, PruneCheckpointReader, StaticFileProviderFactory, }; -use reth_tokio_util::EventListeners; +use reth_tokio_util::{EventSender, EventStream}; use std::{ collections::BTreeMap, time::{Duration, Instant}, }; use tokio::sync::watch; -use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::debug; /// Result of [Pruner::run] execution. @@ -53,7 +52,7 @@ pub struct Pruner { finished_exex_height: watch::Receiver, #[doc(hidden)] metrics: Metrics, - listeners: EventListeners, + event_sender: EventSender, } impl Pruner { @@ -77,13 +76,13 @@ impl Pruner { timeout, finished_exex_height, metrics: Metrics::default(), - listeners: Default::default(), + event_sender: Default::default(), } } /// Listen for events on the pruner. - pub fn events(&mut self) -> UnboundedReceiverStream { - self.listeners.new_listener() + pub fn events(&self) -> EventStream { + self.event_sender.new_listener() } /// Run the pruner @@ -100,7 +99,7 @@ impl Pruner { return Ok(PruneProgress::Finished) } - self.listeners.notify(PrunerEvent::Started { tip_block_number }); + self.event_sender.notify(PrunerEvent::Started { tip_block_number }); debug!(target: "pruner", %tip_block_number, "Pruner started"); let start = Instant::now(); @@ -154,7 +153,7 @@ impl Pruner { "{message}", ); - self.listeners.notify(PrunerEvent::Finished { tip_block_number, elapsed, stats }); + self.event_sender.notify(PrunerEvent::Finished { tip_block_number, elapsed, stats }); Ok(progress) } diff --git a/crates/rpc/rpc-builder/Cargo.toml b/crates/rpc/rpc-builder/Cargo.toml index 9087ff7c7ff..e3b5f4766d9 100644 --- a/crates/rpc/rpc-builder/Cargo.toml +++ b/crates/rpc/rpc-builder/Cargo.toml @@ -56,6 +56,7 @@ reth-rpc-types.workspace = true reth-rpc-types-compat.workspace = true reth-tracing.workspace = true reth-transaction-pool = { workspace = true, features = ["test-utils"] } +reth-tokio-util.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } serde_json.workspace = true diff --git a/crates/rpc/rpc-builder/tests/it/utils.rs b/crates/rpc/rpc-builder/tests/it/utils.rs index a3272ac026a..dd58bf2de29 100644 --- a/crates/rpc/rpc-builder/tests/it/utils.rs +++ b/crates/rpc/rpc-builder/tests/it/utils.rs @@ -26,7 +26,8 @@ pub fn test_address() -> SocketAddr { pub async fn launch_auth(secret: JwtSecret) -> AuthServerHandle { let config = AuthServerConfig::builder(secret).socket_addr(test_address()).build(); let (tx, _rx) = unbounded_channel(); - let beacon_engine_handle = BeaconConsensusEngineHandle::::new(tx); + let beacon_engine_handle = + BeaconConsensusEngineHandle::::new(tx, Default::default()); let engine_api = EngineApi::new( NoopProvider::default(), MAINNET.clone(), diff --git a/crates/rpc/rpc-engine-api/Cargo.toml b/crates/rpc/rpc-engine-api/Cargo.toml index 5fe782a6ef5..83a5f85fcfa 100644 --- a/crates/rpc/rpc-engine-api/Cargo.toml +++ b/crates/rpc/rpc-engine-api/Cargo.toml @@ -43,6 +43,7 @@ reth-ethereum-engine-primitives.workspace = true reth-interfaces = { workspace = true, features = ["test-utils"] } reth-provider = { workspace = true, features = ["test-utils"] } reth-payload-builder = { workspace = true, features = ["test-utils"] } +reth-tokio-util.workspace = true alloy-rlp.workspace = true diff --git a/crates/rpc/rpc-engine-api/src/engine_api.rs b/crates/rpc/rpc-engine-api/src/engine_api.rs index 0e4476bb71b..a2275281e63 100644 --- a/crates/rpc/rpc-engine-api/src/engine_api.rs +++ b/crates/rpc/rpc-engine-api/src/engine_api.rs @@ -770,7 +770,7 @@ where mod tests { use super::*; use assert_matches::assert_matches; - use reth_beacon_consensus::BeaconEngineMessage; + use reth_beacon_consensus::{BeaconConsensusEngineEvent, BeaconEngineMessage}; use reth_ethereum_engine_primitives::EthEngineTypes; use reth_interfaces::test_utils::generators::random_block; use reth_payload_builder::test_utils::spawn_test_payload_service; @@ -778,6 +778,7 @@ mod tests { use reth_provider::test_utils::MockEthProvider; use reth_rpc_types_compat::engine::payload::execution_payload_from_sealed_block; use reth_tasks::TokioTaskExecutor; + use reth_tokio_util::EventSender; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; fn setup_engine_api() -> (EngineApiTestHandle, EngineApi, EthEngineTypes>) @@ -786,11 +787,12 @@ mod tests { let provider = Arc::new(MockEthProvider::default()); let payload_store = spawn_test_payload_service(); let (to_engine, engine_rx) = unbounded_channel(); + let event_sender: EventSender = Default::default(); let task_executor = Box::::default(); let api = EngineApi::new( provider.clone(), chain_spec.clone(), - BeaconConsensusEngineHandle::new(to_engine), + BeaconConsensusEngineHandle::new(to_engine, event_sender), payload_store.into(), task_executor, ); diff --git a/crates/stages-api/Cargo.toml b/crates/stages-api/Cargo.toml index 2101961fd2d..32c4258538a 100644 --- a/crates/stages-api/Cargo.toml +++ b/crates/stages-api/Cargo.toml @@ -27,7 +27,6 @@ metrics.workspace = true # async tokio = { workspace = true, features = ["sync"] } -tokio-stream.workspace = true futures-util.workspace = true # misc @@ -40,6 +39,7 @@ auto_impl.workspace = true assert_matches.workspace = true reth-provider = { workspace = true, features = ["test-utils"] } reth-interfaces = { workspace = true, features = ["test-utils"] } +tokio-stream.workspace = true [features] test-utils = [] diff --git a/crates/stages-api/src/error.rs b/crates/stages-api/src/error.rs index 37fe2b3fdbc..f6e528ca754 100644 --- a/crates/stages-api/src/error.rs +++ b/crates/stages-api/src/error.rs @@ -1,3 +1,4 @@ +use crate::PipelineEvent; use reth_consensus::ConsensusError; use reth_interfaces::{ db::DatabaseError as DbError, executor, p2p::error::DownloadError, RethError, @@ -5,9 +6,7 @@ use reth_interfaces::{ use reth_primitives::{BlockNumber, SealedHeader, StaticFileSegment, TxNumber}; use reth_provider::ProviderError; use thiserror::Error; - -use crate::PipelineEvent; -use tokio::sync::mpsc::error::SendError; +use tokio::sync::broadcast::error::SendError; /// Represents the specific error type within a block error. #[derive(Error, Debug)] diff --git a/crates/stages-api/src/pipeline/builder.rs b/crates/stages-api/src/pipeline/builder.rs index e76f76c604c..c059067259f 100644 --- a/crates/stages-api/src/pipeline/builder.rs +++ b/crates/stages-api/src/pipeline/builder.rs @@ -80,7 +80,7 @@ where max_block, static_file_producer, tip_tx, - listeners: Default::default(), + event_sender: Default::default(), progress: Default::default(), metrics_tx, } diff --git a/crates/stages-api/src/pipeline/mod.rs b/crates/stages-api/src/pipeline/mod.rs index 5aceb515b79..66a87a0f8a4 100644 --- a/crates/stages-api/src/pipeline/mod.rs +++ b/crates/stages-api/src/pipeline/mod.rs @@ -17,10 +17,9 @@ use reth_provider::{ }; use reth_prune::PrunerBuilder; use reth_static_file::StaticFileProducer; -use reth_tokio_util::EventListeners; +use reth_tokio_util::{EventSender, EventStream}; use std::pin::Pin; use tokio::sync::watch; -use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::*; mod builder; @@ -75,8 +74,8 @@ pub struct Pipeline { /// The maximum block number to sync to. max_block: Option, static_file_producer: StaticFileProducer, - /// All listeners for events the pipeline emits. - listeners: EventListeners, + /// Sender for events the pipeline emits. + event_sender: EventSender, /// Keeps track of the progress of the pipeline. progress: PipelineProgress, /// A receiver for the current chain tip to sync to. @@ -108,8 +107,8 @@ where } /// Listen for events on the pipeline. - pub fn events(&mut self) -> UnboundedReceiverStream { - self.listeners.new_listener() + pub fn events(&self) -> EventStream { + self.event_sender.new_listener() } /// Registers progress metrics for each registered stage @@ -251,7 +250,7 @@ where /// CAUTION: This method locks the static file producer Mutex, hence can block the thread if the /// lock is occupied. pub fn move_to_static_files(&self) -> RethResult<()> { - let mut static_file_producer = self.static_file_producer.lock(); + let static_file_producer = self.static_file_producer.lock(); // Copies data from database to static files let lowest_static_file_height = { @@ -312,7 +311,8 @@ where %to, "Unwind point too far for stage" ); - self.listeners.notify(PipelineEvent::Skipped { stage_id }); + self.event_sender.notify(PipelineEvent::Skipped { stage_id }); + continue } @@ -325,7 +325,7 @@ where ); while checkpoint.block_number > to { let input = UnwindInput { checkpoint, unwind_to: to, bad_block }; - self.listeners.notify(PipelineEvent::Unwind { stage_id, input }); + self.event_sender.notify(PipelineEvent::Unwind { stage_id, input }); let output = stage.unwind(&provider_rw, input); match output { @@ -350,7 +350,7 @@ where } provider_rw.save_stage_checkpoint(stage_id, checkpoint)?; - self.listeners + self.event_sender .notify(PipelineEvent::Unwound { stage_id, result: unwind_output }); self.provider_factory.static_file_provider().commit()?; @@ -359,7 +359,8 @@ where provider_rw = self.provider_factory.provider_rw()?; } Err(err) => { - self.listeners.notify(PipelineEvent::Error { stage_id }); + self.event_sender.notify(PipelineEvent::Error { stage_id }); + return Err(PipelineError::Stage(StageError::Fatal(Box::new(err)))) } } @@ -395,7 +396,7 @@ where prev_block = prev_checkpoint.map(|progress| progress.block_number), "Stage reached target block, skipping." ); - self.listeners.notify(PipelineEvent::Skipped { stage_id }); + self.event_sender.notify(PipelineEvent::Skipped { stage_id }); // We reached the maximum block, so we skip the stage return Ok(ControlFlow::NoProgress { @@ -405,7 +406,7 @@ where let exec_input = ExecInput { target, checkpoint: prev_checkpoint }; - self.listeners.notify(PipelineEvent::Prepare { + self.event_sender.notify(PipelineEvent::Prepare { pipeline_stages_progress: PipelineStagesProgress { current: stage_index + 1, total: total_stages, @@ -416,14 +417,15 @@ where }); if let Err(err) = stage.execute_ready(exec_input).await { - self.listeners.notify(PipelineEvent::Error { stage_id }); + self.event_sender.notify(PipelineEvent::Error { stage_id }); + match on_stage_error(&self.provider_factory, stage_id, prev_checkpoint, err)? { Some(ctrl) => return Ok(ctrl), None => continue, }; } - self.listeners.notify(PipelineEvent::Run { + self.event_sender.notify(PipelineEvent::Run { pipeline_stages_progress: PipelineStagesProgress { current: stage_index + 1, total: total_stages, @@ -448,7 +450,7 @@ where } provider_rw.save_stage_checkpoint(stage_id, checkpoint)?; - self.listeners.notify(PipelineEvent::Ran { + self.event_sender.notify(PipelineEvent::Ran { pipeline_stages_progress: PipelineStagesProgress { current: stage_index + 1, total: total_stages, @@ -471,7 +473,8 @@ where } Err(err) => { drop(provider_rw); - self.listeners.notify(PipelineEvent::Error { stage_id }); + self.event_sender.notify(PipelineEvent::Error { stage_id }); + if let Some(ctrl) = on_stage_error(&self.provider_factory, stage_id, prev_checkpoint, err)? { @@ -575,7 +578,7 @@ impl std::fmt::Debug for Pipeline { f.debug_struct("Pipeline") .field("stages", &self.stages.iter().map(|stage| stage.id()).collect::>()) .field("max_block", &self.max_block) - .field("listeners", &self.listeners) + .field("event_sender", &self.event_sender) .finish() } } diff --git a/crates/static-file/Cargo.toml b/crates/static-file/Cargo.toml index 1345b2f232f..0f6608c8084 100644 --- a/crates/static-file/Cargo.toml +++ b/crates/static-file/Cargo.toml @@ -21,6 +21,7 @@ reth-nippy-jar.workspace = true reth-tokio-util.workspace = true # async +tokio.workspace = true tokio-stream.workspace = true # misc diff --git a/crates/static-file/src/static_file_producer.rs b/crates/static-file/src/static_file_producer.rs index c7a365c9afa..4eb08256114 100644 --- a/crates/static-file/src/static_file_producer.rs +++ b/crates/static-file/src/static_file_producer.rs @@ -10,13 +10,12 @@ use reth_provider::{ providers::{StaticFileProvider, StaticFileWriter}, ProviderFactory, }; -use reth_tokio_util::EventListeners; +use reth_tokio_util::{EventSender, EventStream}; use std::{ ops::{Deref, RangeInclusive}, sync::Arc, time::Instant, }; -use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::{debug, trace}; /// Result of [StaticFileProducerInner::run] execution. @@ -64,7 +63,7 @@ pub struct StaticFileProducerInner { /// needed in [StaticFileProducerInner] to prevent attempting to move prunable data to static /// files. See [StaticFileProducerInner::get_static_file_targets]. prune_modes: PruneModes, - listeners: EventListeners, + event_sender: EventSender, } /// Static File targets, per data part, measured in [`BlockNumber`]. @@ -107,12 +106,17 @@ impl StaticFileProducerInner { static_file_provider: StaticFileProvider, prune_modes: PruneModes, ) -> Self { - Self { provider_factory, static_file_provider, prune_modes, listeners: Default::default() } + Self { + provider_factory, + static_file_provider, + prune_modes, + event_sender: Default::default(), + } } /// Listen for events on the static_file_producer. - pub fn events(&mut self) -> UnboundedReceiverStream { - self.listeners.new_listener() + pub fn events(&self) -> EventStream { + self.event_sender.new_listener() } /// Run the static_file_producer. @@ -123,7 +127,7 @@ impl StaticFileProducerInner { /// /// NOTE: it doesn't delete the data from database, and the actual deleting (aka pruning) logic /// lives in the `prune` crate. - pub fn run(&mut self, targets: StaticFileTargets) -> StaticFileProducerResult { + pub fn run(&self, targets: StaticFileTargets) -> StaticFileProducerResult { // If there are no targets, do not produce any static files and return early if !targets.any() { return Ok(targets) @@ -133,7 +137,7 @@ impl StaticFileProducerInner { self.static_file_provider.get_highest_static_files() )); - self.listeners.notify(StaticFileProducerEvent::Started { targets: targets.clone() }); + self.event_sender.notify(StaticFileProducerEvent::Started { targets: targets.clone() }); debug!(target: "static_file", ?targets, "StaticFileProducer started"); let start = Instant::now(); @@ -173,7 +177,7 @@ impl StaticFileProducerInner { let elapsed = start.elapsed(); // TODO(alexey): track in metrics debug!(target: "static_file", ?targets, ?elapsed, "StaticFileProducer finished"); - self.listeners + self.event_sender .notify(StaticFileProducerEvent::Finished { targets: targets.clone(), elapsed }); Ok(targets) @@ -304,7 +308,7 @@ mod tests { fn run() { let (provider_factory, static_file_provider, _temp_static_files_dir) = setup(); - let mut static_file_producer = StaticFileProducerInner::new( + let static_file_producer = StaticFileProducerInner::new( provider_factory, static_file_provider.clone(), PruneModes::default(), @@ -392,7 +396,7 @@ mod tests { let tx = tx.clone(); std::thread::spawn(move || { - let mut locked_producer = producer.lock(); + let locked_producer = producer.lock(); if i == 0 { // Let other threads spawn as well. std::thread::sleep(Duration::from_millis(100)); diff --git a/crates/tokio-util/Cargo.toml b/crates/tokio-util/Cargo.toml index e8c21e0fa05..ccace030c0f 100644 --- a/crates/tokio-util/Cargo.toml +++ b/crates/tokio-util/Cargo.toml @@ -12,7 +12,11 @@ description = "Additional utilities for working with Tokio in reth." workspace = true [dependencies] +tracing.workspace = true # async tokio = { workspace = true, features = ["sync"] } tokio-stream = { workspace = true, features = ["sync"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["full", "macros"] } \ No newline at end of file diff --git a/crates/tokio-util/src/event_listeners.rs b/crates/tokio-util/src/event_listeners.rs deleted file mode 100644 index 3c940e28022..00000000000 --- a/crates/tokio-util/src/event_listeners.rs +++ /dev/null @@ -1,46 +0,0 @@ -use tokio::sync::mpsc; -use tokio_stream::wrappers::UnboundedReceiverStream; - -/// A collection of event listeners for a task. -#[derive(Clone, Debug)] -pub struct EventListeners { - /// All listeners for events - listeners: Vec>, -} - -impl Default for EventListeners { - fn default() -> Self { - Self { listeners: Vec::new() } - } -} - -impl EventListeners { - /// Send an event to all listeners. - /// - /// Channels that were closed are removed. - pub fn notify(&mut self, event: T) { - self.listeners.retain(|listener| listener.send(event.clone()).is_ok()) - } - - /// Add a new event listener. - pub fn new_listener(&mut self) -> UnboundedReceiverStream { - let (sender, receiver) = mpsc::unbounded_channel(); - self.listeners.push(sender); - UnboundedReceiverStream::new(receiver) - } - - /// Push new event listener. - pub fn push_listener(&mut self, listener: mpsc::UnboundedSender) { - self.listeners.push(listener); - } - - /// Returns the number of registered listeners. - pub fn len(&self) -> usize { - self.listeners.len() - } - - /// Returns true if there are no registered listeners. - pub fn is_empty(&self) -> bool { - self.listeners.is_empty() - } -} diff --git a/crates/tokio-util/src/event_sender.rs b/crates/tokio-util/src/event_sender.rs new file mode 100644 index 00000000000..3ed6e85910d --- /dev/null +++ b/crates/tokio-util/src/event_sender.rs @@ -0,0 +1,42 @@ +use crate::EventStream; +use tokio::sync::broadcast::{self, Sender}; +use tracing::error; + +const DEFAULT_SIZE_BROADCAST_CHANNEL: usize = 2000; + +/// A bounded broadcast channel for a task. +#[derive(Debug, Clone)] +pub struct EventSender { + /// The sender part of the broadcast channel + sender: Sender, +} + +impl Default for EventSender +where + T: Clone + Send + Sync + 'static, +{ + fn default() -> Self { + Self::new(DEFAULT_SIZE_BROADCAST_CHANNEL) + } +} + +impl EventSender { + /// Creates a new `EventSender`. + pub fn new(events_channel_size: usize) -> Self { + let (sender, _) = broadcast::channel(events_channel_size); + Self { sender } + } + + /// Broadcasts an event to all listeners. + pub fn notify(&self, event: T) { + if self.sender.send(event).is_err() { + error!("channel closed"); + } + } + + /// Creates a new event stream with a subscriber to the sender as the + /// receiver. + pub fn new_listener(&self) -> EventStream { + EventStream::new(self.sender.subscribe()) + } +} diff --git a/crates/tokio-util/src/event_stream.rs b/crates/tokio-util/src/event_stream.rs new file mode 100644 index 00000000000..fc7e56a13bb --- /dev/null +++ b/crates/tokio-util/src/event_stream.rs @@ -0,0 +1,92 @@ +//! Event streams related functionality. + +use std::{ + pin::Pin, + task::{Context, Poll}, +}; +use tokio_stream::Stream; +use tracing::warn; + +/// Thin wrapper around tokio's BroadcastStream to allow skipping broadcast errors. +#[derive(Debug)] +pub struct EventStream { + inner: tokio_stream::wrappers::BroadcastStream, +} + +impl EventStream +where + T: Clone + Send + 'static, +{ + /// Creates a new `EventStream`. + pub fn new(receiver: tokio::sync::broadcast::Receiver) -> Self { + let inner = tokio_stream::wrappers::BroadcastStream::new(receiver); + EventStream { inner } + } +} + +impl Stream for EventStream +where + T: Clone + Send + 'static, +{ + type Item = T; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(Ok(item))) => return Poll::Ready(Some(item)), + Poll::Ready(Some(Err(e))) => { + warn!("BroadcastStream lagged: {e:?}"); + continue; + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::sync::broadcast; + use tokio_stream::StreamExt; + + #[tokio::test] + async fn test_event_stream_yields_items() { + let (tx, _) = broadcast::channel(16); + let my_stream = EventStream::new(tx.subscribe()); + + tx.send(1).unwrap(); + tx.send(2).unwrap(); + tx.send(3).unwrap(); + + // drop the sender to terminate the stream and allow collect to work. + drop(tx); + + let items: Vec = my_stream.collect().await; + + assert_eq!(items, vec![1, 2, 3]); + } + + #[tokio::test] + async fn test_event_stream_skips_lag_errors() { + let (tx, _) = broadcast::channel(2); + let my_stream = EventStream::new(tx.subscribe()); + + let mut _rx2 = tx.subscribe(); + let mut _rx3 = tx.subscribe(); + + tx.send(1).unwrap(); + tx.send(2).unwrap(); + tx.send(3).unwrap(); + tx.send(4).unwrap(); // This will cause lag for the first subscriber + + // drop the sender to terminate the stream and allow collect to work. + drop(tx); + + // Ensure lag errors are skipped and only valid items are collected + let items: Vec = my_stream.collect().await; + + assert_eq!(items, vec![3, 4]); + } +} diff --git a/crates/tokio-util/src/lib.rs b/crates/tokio-util/src/lib.rs index 7db8dcfba16..2053bf60bc5 100644 --- a/crates/tokio-util/src/lib.rs +++ b/crates/tokio-util/src/lib.rs @@ -8,5 +8,7 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -mod event_listeners; -pub use event_listeners::EventListeners; +mod event_sender; +mod event_stream; +pub use event_sender::EventSender; +pub use event_stream::EventStream;