From 1251247fe29a3f62df5e078784efc39412a718a8 Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Tue, 21 Jul 2020 10:27:04 +0200 Subject: [PATCH] Graceful shutdown for connections, networks and swarms. Building on the ability to wait for connection shutdown to complete introduced in https://github.com/libp2p/rust-libp2p/pull/1619, this commit extends the ability for performing graceful shutdowns in the following ways: 1. The `ConnectionHandler` (and thus also `ProtocolsHandler`) can participate in the shutdown, via new `poll_close` methods. The muxer and underlying transport connection only starts closing once the connection handler signals readiness to do so. 2. A `Network` can be gracefully shut down, which involves a graceful shutdown of the underlying connection `Pool`. The `Pool` in turn proceeds with a shutdown by rejecting new connections while draining established connections. 3. A `Swarm` can be gracefully shut down, which involves a graceful shutdown of the underlying `Network` followed by polling the `NetworkBehaviour` until it returns `Poll::Pending`, i.e. it has no more output. In particular, the following are important details: * Analogous to new inbound and outbound connections during shutdown, while a single connection is shutting down, it rejects new inbound substreams and, by the return type of `ConnectionHandler::poll_close`, no new outbound substreams can be requested. * The `NodeHandlerWrapper` managing the `ProtocolsHandler` always waits for already ongoing inbound and outbound substream upgrades to complete. Since the `NodeHandlerWrapper` is a `ConnectionHandler`, the previous point applies w.r.t. new inbound and outbound substreams. * When the `connection_keep_alive` expires, a graceful shutdown is initiated. --- core/src/connection.rs | 122 ++++- core/src/connection/handler.rs | 38 +- core/src/connection/manager/task.rs | 64 +-- core/src/connection/pool.rs | 126 ++++- core/src/connection/substream.rs | 62 +-- core/src/network.rs | 207 ++++--- core/src/network/peer.rs | 3 +- core/tests/network_dial_error.rs | 29 +- examples/ipfs-kad.rs | 4 +- misc/core-derive/src/lib.rs | 6 + misc/core-derive/tests/test.rs | 5 +- protocols/gossipsub/src/behaviour.rs | 6 + protocols/gossipsub/tests/smoke.rs | 16 +- protocols/identify/src/identify.rs | 6 +- protocols/kad/src/behaviour/test.rs | 577 +++++++++++--------- protocols/ping/tests/ping.rs | 4 +- protocols/request-response/src/handler.rs | 83 ++- protocols/request-response/src/lib.rs | 28 +- protocols/request-response/tests/ping.rs | 8 +- swarm/src/behaviour.rs | 11 + swarm/src/lib.rs | 197 +++++-- swarm/src/protocols_handler.rs | 48 +- swarm/src/protocols_handler/node_handler.rs | 165 +++--- 23 files changed, 1201 insertions(+), 614 deletions(-) diff --git a/core/src/connection.rs b/core/src/connection.rs index d160aeb23b39..20f78007566a 100644 --- a/core/src/connection.rs +++ b/core/src/connection.rs @@ -30,7 +30,7 @@ pub use error::{ConnectionError, PendingConnectionError}; pub use handler::{ConnectionHandler, ConnectionHandlerEvent, IntoConnectionHandler}; pub use listeners::{ListenerId, ListenersStream, ListenersEvent}; pub use manager::ConnectionId; -pub use substream::{Substream, SubstreamEndpoint, Close}; +pub use substream::{Substream, SubstreamEndpoint}; pub use pool::{EstablishedConnection, EstablishedConnectionIter, PendingConnection}; use crate::muxing::StreamMuxer; @@ -194,10 +194,12 @@ where TMuxer: StreamMuxer, THandler: ConnectionHandler>, { - /// Node that handles the muxing. 
+ /// The substream multiplexer over the connection I/O stream. muxing: substream::Muxing, - /// Handler that processes substreams. + /// The connection handler for the substreams. handler: THandler, + /// The operating state of the connection. + state: ConnectionState, } impl fmt::Debug for Connection @@ -231,44 +233,76 @@ where Connection { muxing: Muxing::new(muxer), handler, + state: ConnectionState::Open, } } - /// Returns a reference to the `ConnectionHandler` - pub fn handler(&self) -> &THandler { - &self.handler - } - - /// Returns a mutable reference to the `ConnectionHandler` - pub fn handler_mut(&mut self) -> &mut THandler { - &mut self.handler - } - /// Notifies the connection handler of an event. + /// + /// Has no effect if the connection handler is already closed. pub fn inject_event(&mut self, event: THandler::InEvent) { - self.handler.inject_event(event); + match self.state { + ConnectionState::Open | ConnectionState::CloseHandler + => self.handler.inject_event(event), + _ => { + log::trace!("Ignoring handler event. Handler is closed.") + } + } } - /// Begins an orderly shutdown of the connection, returning a - /// `Future` that resolves when connection shutdown is complete. - pub fn close(self) -> Close { - self.muxing.close().0 + /// Begins a graceful shutdown of the connection. + /// + /// The connection must continue to be `poll()`ed to drive the + /// shutdown process to completion. Once connection shutdown is + /// complete, `poll()` returns `Ok(None)`. + pub fn start_close(&mut self) { + if self.state == ConnectionState::Open { + self.state = ConnectionState::CloseHandler; + } } /// Polls the connection for events produced by the associated handler /// as a result of I/O activity on the substream multiplexer. + /// + /// > **Note**: A return value of `Ok(None)` signals successful + /// > connection shutdown, whereas an `Err` signals termination + /// > of the connection due to an error. In either case, the + /// > connection must be dropped; any further method calls + /// > result in unspecified behaviour. pub fn poll(mut self: Pin<&mut Self>, cx: &mut Context) - -> Poll, ConnectionError>> + -> Poll>, ConnectionError>> { loop { + if let ConnectionState::Closed = self.state { // (1) + return Poll::Ready(Ok(None)) + } + + if let ConnectionState::CloseMuxer = self.state { // (2) + match futures::ready!(self.muxing.poll_close(cx)) { + Ok(()) => { + self.state = ConnectionState::Closed; + return Poll::Ready(Ok(None)) + } + Err(e) => return Poll::Ready(Err(ConnectionError::IO(e))) + } + } + + // At this point the connection is either open or in the process + // of a graceful shutdown by the connection handler. let mut io_pending = false; // Perform I/O on the connection through the muxer, informing the handler - // of new substreams. + // of new substreams or other muxer events. match self.muxing.poll(cx) { Poll::Pending => io_pending = true, Poll::Ready(Ok(SubstreamEvent::InboundSubstream { substream })) => { - self.handler.inject_substream(substream, SubstreamEndpoint::Listener) + // Drop new inbound substreams when closing. This is analogous + // to rejecting new connections. + if self.state == ConnectionState::Open { + self.handler.inject_substream(substream, SubstreamEndpoint::Listener) + } else { + log::trace!("Inbound substream dropped. 
Connection is closing.") + } } Poll::Ready(Ok(SubstreamEvent::OutboundSubstream { user_data, substream })) => { let endpoint = SubstreamEndpoint::Dialer(user_data); @@ -276,23 +310,37 @@ where } Poll::Ready(Ok(SubstreamEvent::AddressChange(address))) => { self.handler.inject_address_change(&address); - return Poll::Ready(Ok(Event::AddressChange(address))); + return Poll::Ready(Ok(Some(Event::AddressChange(address)))); } Poll::Ready(Err(err)) => return Poll::Ready(Err(ConnectionError::IO(err))), } // Poll the handler for new events. - match self.handler.poll(cx) { + let poll = match &self.state { + ConnectionState::Open => self.handler.poll(cx).map_ok(Some), + ConnectionState::CloseHandler => self.handler.poll_close(cx).map_ok( + |event| event.map(ConnectionHandlerEvent::Custom)), + s => panic!("Unexpected closing state: {:?}", s) // s.a. (1),(2) + }; + + match poll { Poll::Pending => { if io_pending { return Poll::Pending // Nothing to do } } - Poll::Ready(Ok(ConnectionHandlerEvent::OutboundSubstreamRequest(user_data))) => { + Poll::Ready(Ok(Some(ConnectionHandlerEvent::OutboundSubstreamRequest(user_data)))) => { self.muxing.open_substream(user_data); } - Poll::Ready(Ok(ConnectionHandlerEvent::Custom(event))) => { - return Poll::Ready(Ok(Event::Handler(event))); + Poll::Ready(Ok(Some(ConnectionHandlerEvent::Custom(event)))) => { + return Poll::Ready(Ok(Some(Event::Handler(event)))); + } + Poll::Ready(Ok(Some(ConnectionHandlerEvent::Close))) => { + self.start_close() + } + Poll::Ready(Ok(None)) => { + // The handler is done, we can now close the muxer (i.e. connection). + self.state = ConnectionState::CloseMuxer; } Poll::Ready(Err(err)) => return Poll::Ready(Err(ConnectionError::Handler(err))), } @@ -352,3 +400,25 @@ impl fmt::Display for ConnectionLimit { /// A `ConnectionLimit` can represent an error if it has been exceeded. impl Error for ConnectionLimit {} + +/// The state of a [`Connection`] w.r.t. an active graceful close. +#[derive(Debug, PartialEq, Eq)] +enum ConnectionState { + /// The connection is open, accepting new inbound and outbound + /// substreams. + Open, + /// The connection is closing, rejecting new inbound substreams + /// and not permitting new outbound substreams while the + /// connection handler closes. [`ConnectionHandler::poll_close`] + /// is called until completion which results in transitioning to + /// `CloseMuxer`. + CloseHandler, + /// The connection is closing, rejecting new inbound substreams + /// and not permitting new outbound substreams while the + /// muxer is closing the transport connection. [`Muxer::poll_close`] + /// is called until completion, which results in transitioning + /// to `Closed`. + CloseMuxer, + /// The connection is closed. + Closed +} diff --git a/core/src/connection/handler.rs b/core/src/connection/handler.rs index 07006f8c3d2e..12705d45b058 100644 --- a/core/src/connection/handler.rs +++ b/core/src/connection/handler.rs @@ -19,7 +19,7 @@ // DEALINGS IN THE SOFTWARE. use crate::{Multiaddr, PeerId}; -use std::{task::Context, task::Poll}; +use std::task::{Context, Poll}; use super::{Connected, SubstreamEndpoint}; /// The interface of a connection handler. @@ -66,6 +66,37 @@ pub trait ConnectionHandler { /// Returning an error will close the connection to the remote. fn poll(&mut self, cx: &mut Context) -> Poll, Self::Error>>; + + /// Polls the handler to make progress towards closing the connection. 
+ /// + /// When a connection is actively closed, the handler can perform + /// a graceful shutdown of the connection by draining the I/O + /// activity, e.g. allowing in-flight requests to complete without + /// accepting new ones, possibly signaling the remote that it + /// should direct further requests elsewhere. + /// + /// The handler can also use the opportunity to flush any buffers + /// or clean up any other (asynchronous) resources before the + /// connection is ultimately dropped and closed on the transport + /// layer. + /// + /// While closing, new inbound substreams are rejected and the + /// handler is unable to request new outbound substreams as + /// per the return type of `poll_close`. + /// + /// The handler signals its readiness for the connection + /// to be closed by returning `Ready(Ok(None))`, which is the + /// default implementation. Hence, by default, connection + /// shutdown is not delayed and may result in ungraceful + /// interruption of ongoing I/O. + /// + /// > **Note**: Once `poll_close()` is invoked, the handler is no + /// > longer `poll()`ed. + fn poll_close(&mut self, _: &mut Context) + -> Poll, Self::Error>> + { + Poll::Ready(Ok(None)) + } } /// Prototype for a `ConnectionHandler`. @@ -99,6 +130,9 @@ pub enum ConnectionHandlerEvent { /// Other event. Custom(TCustom), + + /// Initiate connection shutdown. + Close, } /// Event produced by a handler. @@ -112,6 +146,7 @@ impl ConnectionHandlerEvent ConnectionHandlerEvent::Custom(val), + ConnectionHandlerEvent::Close => ConnectionHandlerEvent::Close, } } @@ -124,6 +159,7 @@ impl ConnectionHandlerEvent ConnectionHandlerEvent::Custom(map(val)), + ConnectionHandlerEvent::Close => ConnectionHandlerEvent::Close, } } } diff --git a/core/src/connection/manager/task.rs b/core/src/connection/manager/task.rs index bea84513dbe3..20c9ae4efb53 100644 --- a/core/src/connection/manager/task.rs +++ b/core/src/connection/manager/task.rs @@ -23,7 +23,6 @@ use crate::{ muxing::StreamMuxer, connection::{ self, - Close, Connected, Connection, ConnectionError, @@ -168,9 +167,6 @@ where event: Option::Error, C>> }, - /// The connection is closing (active close). - Closing(Close), - /// The task is terminating with a final event for the `Manager`. Terminating(Event::Error, C>), @@ -250,11 +246,8 @@ where Poll::Ready(Some(Command::NotifyHandler(event))) => connection.inject_event(event), Poll::Ready(Some(Command::Close)) => { - // Don't accept any further commands. - this.commands.get_mut().close(); - // Discard the event, if any, and start a graceful close. - this.state = State::Closing(connection.close()); - continue 'poll + // Start closing the connection, if not already. + connection.start_close(); } Poll::Ready(None) => { // The manager has dropped the task or disappeared; abort. @@ -267,13 +260,19 @@ where // Send the event to the manager. match this.events.poll_ready(cx) { Poll::Pending => { - this.state = State::Established { connection, event: Some(event) }; + this.state = State::Established { + connection, + event: Some(event), + }; return Poll::Pending } Poll::Ready(result) => { if result.is_ok() { if let Ok(()) = this.events.start_send(event) { - this.state = State::Established { connection, event: None }; + this.state = State::Established { + connection, + event: None, + }; continue 'poll } } @@ -282,24 +281,34 @@ where } } } else { - // Poll the connection for new events. 
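// Illustrative sketch, not part of this patch: a handler that uses the new
// `poll_close` hook to flush in-flight work before letting the connection
// close. The `pending_responses: VecDeque<Self::OutEvent>` field is a
// hypothetical stand-in for such work; only the method signature matches
// `ConnectionHandler::poll_close` as introduced above.
fn poll_close(&mut self, _: &mut Context)
    -> Poll<Result<Option<Self::OutEvent>, Self::Error>>
{
    // While closing, the `Connection` rejects new inbound substreams and no
    // new outbound substreams can be requested, so only work that is already
    // queued needs to be drained here.
    if let Some(event) = self.pending_responses.pop_front() {
        return Poll::Ready(Ok(Some(event)))
    }
    // Nothing left to flush: signal readiness, after which the muxer and the
    // underlying transport connection are closed.
    Poll::Ready(Ok(None))
}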
match Connection::poll(Pin::new(&mut connection), cx) { Poll::Pending => { - this.state = State::Established { connection, event: None }; + this.state = State::Established { + connection, + event: None, + }; return Poll::Pending } - Poll::Ready(Ok(connection::Event::Handler(event))) => { + Poll::Ready(Ok(Some(connection::Event::Handler(event)))) => { this.state = State::Established { connection, - event: Some(Event::Notify { id, event }) + event: Some(Event::Notify { id, event }), }; } - Poll::Ready(Ok(connection::Event::AddressChange(new_address))) => { + Poll::Ready(Ok(Some(connection::Event::AddressChange(new_address)))) => { this.state = State::Established { connection, - event: Some(Event::AddressChange { id, new_address }) + event: Some(Event::AddressChange { id, new_address }), }; } + Poll::Ready(Ok(None)) => { + // The connection is closed, don't accept any further commands + // and terminate the task with a final event. + this.commands.get_mut().close(); + let event = Event::Closed { id: this.id, error: None }; + this.state = State::Terminating(event); + continue 'poll + } Poll::Ready(Err(error)) => { // Don't accept any further commands. this.commands.get_mut().close(); @@ -311,27 +320,6 @@ where } } - State::Closing(mut closing) => { - // Try to gracefully close the connection. - match closing.poll_unpin(cx) { - Poll::Ready(Ok(())) => { - let event = Event::Closed { id: this.id, error: None }; - this.state = State::Terminating(event); - } - Poll::Ready(Err(e)) => { - let event = Event::Closed { - id: this.id, - error: Some(ConnectionError::IO(e)) - }; - this.state = State::Terminating(event); - } - Poll::Pending => { - this.state = State::Closing(closing); - return Poll::Pending - } - } - } - State::Terminating(event) => { // Try to deliver the final event. match this.events.poll_ready(cx) { diff --git a/core/src/connection/pool.rs b/core/src/connection/pool.rs index 003fd50ae123..ca509afa038c 100644 --- a/core/src/connection/pool.rs +++ b/core/src/connection/pool.rs @@ -70,15 +70,20 @@ pub struct Pool>, + + /// The current operating state of the pool. + state: PoolState, } impl fmt::Debug for Pool { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - // TODO: More useful debug impl? f.debug_struct("Pool") + .field("state", &self.state) .field("limits", &self.limits) + .field("peers", &self.established.len()) + .field("pending", &self.pending.len()) .finish() } } @@ -218,6 +223,7 @@ where established: Default::default(), pending: Default::default(), disconnected: Vec::new(), + state: PoolState::Open, } } @@ -226,6 +232,22 @@ where &self.limits } + /// Whether the `Pool` is open, i.e. accepting new connections. + pub fn is_open(&self) -> bool { + self.state == PoolState::Open + } + + /// Whether the `Pool` is closing, i.e. rejecting new connections. + pub fn is_closing(&self) -> bool { + self.state == PoolState::Draining + } + + /// Whether the `Pool` is closed, i.e. all connections are closed + /// and no new connections are accepted. + pub fn is_closed(&self) -> bool { + self.state == PoolState::Closed + } + /// Adds a pending incoming connection to the pool in the form of a /// `Future` that establishes and negotiates the connection. /// @@ -599,16 +621,57 @@ where self.established.keys() } + /// Initiates a graceful shutdown of the `Pool`. + /// + /// All pending connections are immediately aborted and the `Pool` + /// refuses to accept any new connections. Established connections + /// will be drained by continued `poll()`ing of the `Pool`. 
+ pub fn start_close(&mut self) { + if self.state != PoolState::Open { + return // Already closing + } + + // Set the state and limits for the shutdown. + self.state = PoolState::Draining; + self.limits = SHUTDOWN_LIMITS; + + // Immediately abort all pending connections. + for (id, _) in self.pending.drain() { + match self.manager.entry(id) { + Some(manager::Entry::Pending(e)) => { e.abort(); }, + _ => {} + } + } + + // Start a clean shutdown for all established connections. + for id in self.established.values().flat_map(|conns| conns.keys()) { + match self.manager.entry(*id) { + Some(manager::Entry::Established(mut e)) => { + e.start_close(); + }, + _ => { + panic!("Entry for established connection not found: {:?}", id) + } + } + } + + // Shutdown progress for established connections is driven by `Pool::poll`. + } + /// Polls the connection pool for events. /// /// > **Note**: We use a regular `poll` method instead of implementing `Stream`, /// > because we want the `Pool` to stay borrowed if necessary. pub fn poll<'a>(&'a mut self, cx: &mut Context) -> Poll< - PoolEvent<'a, TInEvent, TOutEvent, THandler, TTransErr, THandlerErr, TConnInfo, TPeerId> + Option> > where TConnInfo: ConnectionInfo + Clone, TPeerId: Clone { + if self.state == PoolState::Closed { + return Poll::Ready(None) + } + // Drain events resulting from forced disconnections. // // Note: The `Disconnected` entries in `self.disconnected` @@ -619,33 +682,39 @@ where while let Some(Disconnected { id, connected, num_established }) = self.disconnected.pop() { - return Poll::Ready(PoolEvent::ConnectionClosed { + return Poll::Ready(Some(PoolEvent::ConnectionClosed { id, connected, num_established, error: None, pool: self, - }) + })) } // Poll the connection `Manager`. loop { + // If there are no more established connections, shutdown is complete. + if self.state == PoolState::Draining && self.established.is_empty() { + self.state = PoolState::Closed; + return Poll::Ready(None) + } + let item = match self.manager.poll(cx) { Poll::Ready(item) => item, - Poll::Pending => return Poll::Pending, + Poll::Pending => return Poll::Pending }; match item { manager::Event::PendingConnectionError { id, error, handler } => { if let Some((endpoint, peer)) = self.pending.remove(&id) { - return Poll::Ready(PoolEvent::PendingConnectionError { + return Poll::Ready(Some(PoolEvent::PendingConnectionError { id, endpoint, error, handler: Some(handler), peer, pool: self - }) + })) } }, manager::Event::ConnectionClosed { id, connected, error } => { @@ -659,9 +728,9 @@ where if num_established == 0 { self.established.remove(connected.peer_id()); } - return Poll::Ready(PoolEvent::ConnectionClosed { + return Poll::Ready(Some(PoolEvent::ConnectionClosed { id, connected, error, num_established, pool: self - }) + })) } manager::Event::ConnectionEstablished { entry } => { let id = entry.id(); @@ -672,14 +741,14 @@ where .map_or(0, |conns| conns.len()); if let Err(e) = self.limits.check_established(current) { let connected = entry.remove(); - return Poll::Ready(PoolEvent::PendingConnectionError { + return Poll::Ready(Some(PoolEvent::PendingConnectionError { id, endpoint: connected.endpoint, error: PendingConnectionError::ConnectionLimit(e), handler: None, peer, pool: self - }) + })) } // Peer ID checks must already have happened. See `add_pending`. 
if cfg!(debug_assertions) { @@ -700,9 +769,9 @@ where conns.insert(id, endpoint); match self.get(id) { Some(PoolConnection::Established(connection)) => - return Poll::Ready(PoolEvent::ConnectionEstablished { + return Poll::Ready(Some(PoolEvent::ConnectionEstablished { connection, num_established - }), + })), _ => unreachable!("since `entry` is an `EstablishedEntry`.") } } @@ -711,10 +780,10 @@ where let id = entry.id(); match self.get(id) { Some(PoolConnection::Established(connection)) => - return Poll::Ready(PoolEvent::ConnectionEvent { + return Poll::Ready(Some(PoolEvent::ConnectionEvent { connection, event, - }), + })), _ => unreachable!("since `entry` is an `EstablishedEntry`.") } }, @@ -730,11 +799,11 @@ where match self.get(id) { Some(PoolConnection::Established(connection)) => - return Poll::Ready(PoolEvent::AddressChange { + return Poll::Ready(Some(PoolEvent::AddressChange { connection, new_endpoint, old_endpoint, - }), + })), _ => unreachable!("since `entry` is an `EstablishedEntry`.") } }, @@ -914,7 +983,7 @@ where } /// The configurable limits of a connection [`Pool`]. -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct PoolLimits { pub max_outgoing: Option, pub max_incoming: Option, @@ -922,6 +991,14 @@ pub struct PoolLimits { pub max_outgoing_per_peer: Option, } +const SHUTDOWN_LIMITS: PoolLimits = + PoolLimits { + max_outgoing: Some(0), + max_incoming: Some(0), + max_established_per_peer: Some(0), + max_outgoing_per_peer: Some(0), + }; + impl PoolLimits { fn check_established(&self, current: F) -> Result<(), ConnectionLimit> where @@ -974,3 +1051,16 @@ struct Disconnected { /// to the same peer. num_established: u32, } + +/// The operating state of a [`Pool`]. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +enum PoolState { + /// The pool is open for new connections, subject to the + /// configured limits. + Open, + /// The pool is waiting for established connections to close + /// while rejecting new connections. + Draining, + /// The pool has shut down and is closed. + Closed +} diff --git a/core/src/connection/substream.rs b/core/src/connection/substream.rs index cbba375cdc5c..d770e881799c 100644 --- a/core/src/connection/substream.rs +++ b/core/src/connection/substream.rs @@ -19,11 +19,10 @@ // DEALINGS IN THE SOFTWARE. use crate::muxing::{StreamMuxer, StreamMuxerEvent, SubstreamRef, substream_from_ref}; -use futures::prelude::*; use multiaddr::Multiaddr; use smallvec::SmallVec; use std::sync::Arc; -use std::{fmt, io::Error as IoError, pin::Pin, task::Context, task::Poll}; +use std::{fmt, io, task::Context, task::Poll}; /// Endpoint for a received substream. #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -67,12 +66,6 @@ where outbound_substreams: SmallVec<[(TUserData, TMuxer::OutboundSubstream); 8]>, } -/// Future that signals the remote that we have closed the connection. -pub struct Close { - /// Muxer to close. - muxer: Arc, -} - /// A successfully opened substream. pub type Substream = SubstreamRef>; @@ -130,27 +123,27 @@ where self.outbound_substreams.push((user_data, raw)); } - /// Destroys the node stream and returns all the pending outbound substreams, plus an object - /// that signals the remote that we shut down the connection. - #[must_use] - pub fn close(mut self) -> (Close, Vec) { - let substreams = self.cancel_outgoing(); - let close = Close { muxer: self.inner.clone() }; - (close, substreams) + /// Closes the underlying connection, canceling any pending outbound substreams. 
+ pub fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll> { + self.cancel_outgoing(); + match self.inner.close(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), + } } - /// Destroys all outbound streams and returns the corresponding user data. - pub fn cancel_outgoing(&mut self) -> Vec { - let mut out = Vec::with_capacity(self.outbound_substreams.len()); - for (user_data, outbound) in self.outbound_substreams.drain(..) { - out.push(user_data); + /// Destroys all outbound streams. + fn cancel_outgoing(&mut self) { + for (_, outbound) in self.outbound_substreams.drain(..) { self.inner.destroy_outbound(outbound); } - out } /// Provides an API similar to `Future`. - pub fn poll(&mut self, cx: &mut Context) -> Poll, IoError>> { + pub fn poll(&mut self, cx: &mut Context) + -> Poll, io::Error>> + { // Polling inbound substream. match self.inner.poll_event(cx) { Poll::Ready(Ok(StreamMuxerEvent::InboundSubstream(substream))) => { @@ -218,31 +211,6 @@ where } } -impl Future for Close -where - TMuxer: StreamMuxer, -{ - type Output = Result<(), IoError>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { - match self.muxer.close(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Ok(())) => Poll::Ready(Ok(())), - Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())), - } - } -} - -impl fmt::Debug for Close -where - TMuxer: StreamMuxer, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - f.debug_struct("Close") - .finish() - } -} - impl fmt::Debug for SubstreamEvent where TMuxer: StreamMuxer, diff --git a/core/src/network.rs b/core/src/network.rs index 17790b6d846e..3c6447b62fb6 100644 --- a/core/src/network.rs +++ b/core/src/network.rs @@ -137,14 +137,20 @@ where impl Network where - TTrans: Transport + Clone, - TMuxer: StreamMuxer, + TTrans: Transport + Clone, + TTrans::Error: Send + 'static, + TTrans::Dial: Send + 'static, + TTrans::ListenerUpgrade: Send + 'static, + TMuxer: StreamMuxer + Send + Sync + 'static, + TMuxer::OutboundSubstream: Send, THandler: IntoConnectionHandler + Send + 'static, THandler::Handler: ConnectionHandler, InEvent = TInEvent, OutEvent = TOutEvent> + Send + 'static, ::OutboundOpenInfo: Send + 'static, // TODO: shouldn't be necessary ::Error: error::Error + Send + 'static, - TConnInfo: fmt::Debug + ConnectionInfo + Send + 'static, - TPeerId: Eq + Hash + Clone, + TConnInfo: ConnectionInfo + fmt::Debug + Clone + Send + 'static, + TPeerId: Eq + Hash + Clone + Send + 'static, + TInEvent: Send + 'static, + TOutEvent: Send + 'static, { /// Creates a new node events stream. pub fn new( @@ -220,16 +226,6 @@ where /// connection ID is returned. pub fn dial(&mut self, address: &Multiaddr, handler: THandler) -> Result - where - TTrans: Transport, - TTrans::Error: Send + 'static, - TTrans::Dial: Send + 'static, - TMuxer: Send + Sync + 'static, - TMuxer::OutboundSubstream: Send, - TInEvent: Send + 'static, - TOutEvent: Send + 'static, - TConnInfo: Send + 'static, - TPeerId: Send + 'static, { let info = OutgoingInfo { address, peer_id: None }; match self.transport().clone().dial(address.clone()) { @@ -326,59 +322,125 @@ where Peer::new(self, peer_id) } + /// Initiates a graceful shutdown of the `Network`. + /// + /// A graceful shutdown proceeds by not accepting any new + /// connections while waiting for all currently established + /// connections to close. 
+ /// + /// After calling this method, [`Network::poll`] makes progress + /// towards shutdown, eventually returning `Poll::Ready(None)` when + /// shutdown is complete. + /// + /// A graceful shutdown involves gracefully closing all established + /// connections, as defined by [`ConnectionHandler::poll_close`]. + pub fn start_close(&mut self) { + self.pool.start_close(); + } + + /// Performs a graceful shutdown of the `Network`, ignoring + /// any further events. + /// + /// See [`Network::start_close`] for further details. + pub async fn close(&mut self) { + self.start_close(); + future::poll_fn(move |cx| { + loop { + match self.poll(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(()), + Poll::Ready(Some(_)) => {}, + } + } + }).await + } + + /// Whether the `Network` is closed. + /// + /// The `Network` is closed after successful completion + /// of a graceful shutdown. See [`Network::start_close`] + /// and [`Network::closed`]. + /// + /// A closed `Network` no longer performs any I/O and + /// shoudl eventually be discarded. + pub fn is_closed(&self) -> bool { + self.pool.is_closed() + } + + /// Whether the `Network` is closing. + /// + /// When the `Network` is closing, no new inbound connections + /// are accepted and no new outbound connections can be requested + /// while already established connections are drained. + /// + /// Returns `false` if the network is already closed + /// or is not currently closing. + pub fn is_closing(&self) -> bool { + self.pool.is_closing() + } + /// Provides an API similar to `Stream`, except that it cannot error. - pub fn poll<'a>(&'a mut self, cx: &mut Context) -> Poll> - where - TTrans: Transport, - TTrans::Error: Send + 'static, - TTrans::Dial: Send + 'static, - TTrans::ListenerUpgrade: Send + 'static, - TMuxer: Send + Sync + 'static, - TMuxer::OutboundSubstream: Send, - TInEvent: Send + 'static, - TOutEvent: Send + 'static, - THandler: IntoConnectionHandler + Send + 'static, - THandler::Handler: ConnectionHandler, InEvent = TInEvent, OutEvent = TOutEvent> + Send + 'static, - ::Error: error::Error + Send + 'static, - TConnInfo: Clone, - TPeerId: Send + 'static, + pub fn poll<'a>(&'a mut self, cx: &mut Context) + -> Poll>> { - // Poll the listener(s) for new connections. - match ListenersStream::poll(Pin::new(&mut self.listeners), cx) { - Poll::Pending => (), - Poll::Ready(ListenersEvent::Incoming { - listener_id, - upgrade, - local_addr, - send_back_addr - }) => { - return Poll::Ready(NetworkEvent::IncomingConnection( - IncomingConnectionEvent { - listener_id, - upgrade, - local_addr, - send_back_addr, - pool: &mut self.pool, - })) - } - Poll::Ready(ListenersEvent::NewAddress { listener_id, listen_addr }) => { - return Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) - } - Poll::Ready(ListenersEvent::AddressExpired { listener_id, listen_addr }) => { - return Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) - } - Poll::Ready(ListenersEvent::Closed { listener_id, addresses, reason }) => { - return Poll::Ready(NetworkEvent::ListenerClosed { listener_id, addresses, reason }) - } - Poll::Ready(ListenersEvent::Error { listener_id, error }) => { - return Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) + // If the connection pool is closed, the network is considered closed as well. + if self.pool.is_closed() { + return Poll::Ready(None) + } + + // If the pool accepts new connections, poll the listeners stream. 
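// Illustrative sketch, not part of this patch: a manual graceful shutdown of
// a `Network`, equivalent to what `Network::close()` above does internally.
// `network` is assumed to be an already constructed `Network`; handling of
// the remaining events is application-specific.
network.start_close();
futures::future::poll_fn(|cx| {
    loop {
        match network.poll(cx) {
            // Events may still surface while established connections drain.
            Poll::Ready(Some(_event)) => { /* handle or ignore */ }
            // `None` means the pool is drained and the network is closed.
            Poll::Ready(None) => return Poll::Ready(()),
            Poll::Pending => return Poll::Pending,
        }
    }
}).await;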
+ if self.pool.is_open() { + match ListenersStream::poll(Pin::new(&mut self.listeners), cx) { + Poll::Pending => (), + Poll::Ready(ListenersEvent::Incoming { + listener_id, + upgrade, + local_addr, + send_back_addr + }) => { + return Poll::Ready(Some( + NetworkEvent::IncomingConnection( + IncomingConnectionEvent { + listener_id, + upgrade, + local_addr, + send_back_addr, + pool: &mut self.pool, + }))) + } + Poll::Ready(ListenersEvent::NewAddress { listener_id, listen_addr }) => { + return Poll::Ready(Some( + NetworkEvent::NewListenerAddress { + listener_id, listen_addr + })) + } + Poll::Ready(ListenersEvent::AddressExpired { listener_id, listen_addr }) => { + return Poll::Ready(Some( + NetworkEvent::ExpiredListenerAddress { + listener_id, listen_addr + })) + } + Poll::Ready(ListenersEvent::Closed { listener_id, addresses, reason }) => { + return Poll::Ready(Some( + NetworkEvent::ListenerClosed { + listener_id, addresses, reason + })) + } + Poll::Ready(ListenersEvent::Error { listener_id, error }) => { + return Poll::Ready(Some( + NetworkEvent::ListenerError { + listener_id, error + })) + } } } - // Poll the known peers. + // Poll the connection pool. let event = match self.pool.poll(cx) { Poll::Pending => return Poll::Pending, - Poll::Ready(PoolEvent::ConnectionEstablished { connection, num_established }) => { + Poll::Ready(Some( + PoolEvent::ConnectionEstablished { connection, num_established } + )) => { match self.dialing.entry(connection.peer_id().clone()) { hash_map::Entry::Occupied(mut e) => { e.get_mut().retain(|s| s.current.0 != connection.id()); @@ -394,7 +456,9 @@ where num_established, } } - Poll::Ready(PoolEvent::PendingConnectionError { id, endpoint, error, handler, pool, .. }) => { + Poll::Ready(Some( + PoolEvent::PendingConnectionError { id, endpoint, error, handler, pool, .. } + )) => { let dialing = &mut self.dialing; let (next, event) = on_connection_failed(dialing, id, endpoint, error, handler); if let Some(dial) = next { @@ -405,7 +469,9 @@ where } event } - Poll::Ready(PoolEvent::ConnectionClosed { id, connected, error, num_established, .. }) => { + Poll::Ready(Some( + PoolEvent::ConnectionClosed { id, connected, error, num_established, .. } + )) => { NetworkEvent::ConnectionClosed { id, connected, @@ -413,31 +479,36 @@ where error, } } - Poll::Ready(PoolEvent::ConnectionEvent { connection, event }) => { + Poll::Ready(Some( + PoolEvent::ConnectionEvent { connection, event } + )) => { NetworkEvent::ConnectionEvent { connection, event, } } - Poll::Ready(PoolEvent::AddressChange { connection, new_endpoint, old_endpoint }) => { + Poll::Ready(Some( + PoolEvent::AddressChange { connection, new_endpoint, old_endpoint } + )) => { NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint, } } + Poll::Ready(None) => { + // If the connection pool closed, so does the network. + return Poll::Ready(None) + } }; - Poll::Ready(event) + Poll::Ready(Some(event)) } /// Initiates a connection attempt to a known peer. 
fn dial_peer(&mut self, opts: DialingOpts) -> Result where - TTrans: Transport, - TTrans::Dial: Send + 'static, - TTrans::Error: Send + 'static, TMuxer: Send + Sync + 'static, TMuxer::OutboundSubstream: Send, TInEvent: Send + 'static, diff --git a/core/src/network/peer.rs b/core/src/network/peer.rs index 2966404759a7..b07d6feda137 100644 --- a/core/src/network/peer.rs +++ b/core/src/network/peer.rs @@ -165,6 +165,7 @@ where TTrans: Transport + Clone, TTrans::Error: Send + 'static, TTrans::Dial: Send + 'static, + TTrans::ListenerUpgrade: Send + 'static, TMuxer: StreamMuxer + Send + Sync + 'static, TMuxer::OutboundSubstream: Send, TInEvent: Send + 'static, @@ -173,7 +174,7 @@ where THandler::Handler: ConnectionHandler, InEvent = TInEvent, OutEvent = TOutEvent> + Send, ::OutboundOpenInfo: Send, ::Error: error::Error + Send + 'static, - TConnInfo: fmt::Debug + ConnectionInfo + Send + 'static, + TConnInfo: fmt::Debug + ConnectionInfo + Clone + Send + 'static, TPeerId: Eq + Hash + Clone + Send + 'static, { /// Checks whether the peer is currently connected. diff --git a/core/tests/network_dial_error.rs b/core/tests/network_dial_error.rs index 630eccc01e18..a83ea4cbc23a 100644 --- a/core/tests/network_dial_error.rs +++ b/core/tests/network_dial_error.rs @@ -80,7 +80,7 @@ fn deny_incoming_connec() { swarm1.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()).unwrap(); let address = async_std::task::block_on(future::poll_fn(|cx| { - if let Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm1.poll(cx) { + if let Poll::Ready(Some(NetworkEvent::NewListenerAddress { listen_addr, .. })) = swarm1.poll(cx) { Poll::Ready(listen_addr) } else { panic!("Was expecting the listen address to be reported") @@ -94,18 +94,18 @@ fn deny_incoming_connec() { async_std::task::block_on(future::poll_fn(|cx| -> Poll> { match swarm1.poll(cx) { - Poll::Ready(NetworkEvent::IncomingConnection(inc)) => drop(inc), + Poll::Ready(Some(NetworkEvent::IncomingConnection(inc))) => drop(inc), Poll::Ready(_) => unreachable!(), Poll::Pending => (), } match swarm2.poll(cx) { - Poll::Ready(NetworkEvent::DialError { + Poll::Ready(Some(NetworkEvent::DialError { attempts_remaining: 0, peer_id, multiaddr, error: PendingConnectionError::Transport(_) - }) => { + })) => { assert_eq!(peer_id, *swarm1.local_peer_id()); assert_eq!(multiaddr, address); return Poll::Ready(Ok(())); @@ -136,7 +136,7 @@ fn dial_self() { let (local_address, mut swarm) = async_std::task::block_on( future::lazy(move |cx| { - if let Poll::Ready(NetworkEvent::NewListenerAddress { listen_addr, .. }) = swarm.poll(cx) { + if let Poll::Ready(Some(NetworkEvent::NewListenerAddress { listen_addr, .. })) = swarm.poll(cx) { Ok::<_, void::Void>((listen_addr, swarm)) } else { panic!("Was expecting the listen address to be reported") @@ -151,11 +151,11 @@ fn dial_self() { async_std::task::block_on(future::poll_fn(|cx| -> Poll> { loop { match swarm.poll(cx) { - Poll::Ready(NetworkEvent::UnknownPeerDialError { + Poll::Ready(Some(NetworkEvent::UnknownPeerDialError { multiaddr, error: PendingConnectionError::InvalidPeerId { .. }, .. - }) => { + })) => { assert!(!got_dial_err); assert_eq!(multiaddr, local_address); got_dial_err = true; @@ -163,9 +163,9 @@ fn dial_self() { return Poll::Ready(Ok(())) } }, - Poll::Ready(NetworkEvent::IncomingConnectionError { + Poll::Ready(Some(NetworkEvent::IncomingConnectionError { local_addr, .. 
- }) => { + })) => { assert!(!got_inc_err); assert_eq!(local_addr, local_address); got_inc_err = true; @@ -173,13 +173,16 @@ fn dial_self() { return Poll::Ready(Ok(())) } }, - Poll::Ready(NetworkEvent::IncomingConnection(inc)) => { + Poll::Ready(Some(NetworkEvent::IncomingConnection(inc))) => { assert_eq!(*inc.local_addr(), local_address); inc.accept(TestHandler()).unwrap(); }, - Poll::Ready(ev) => { + Poll::Ready(Some(ev)) => { panic!("Unexpected event: {:?}", ev) } + Poll::Ready(None) => { + panic!("Unexpected shutdown") + } Poll::Pending => break Poll::Pending, } } @@ -221,12 +224,12 @@ fn multiple_addresses_err() { async_std::task::block_on(future::poll_fn(|cx| -> Poll> { loop { match swarm.poll(cx) { - Poll::Ready(NetworkEvent::DialError { + Poll::Ready(Some(NetworkEvent::DialError { attempts_remaining, peer_id, multiaddr, error: PendingConnectionError::Transport(_) - }) => { + })) => { assert_eq!(peer_id, target); let expected = addresses.remove(0); assert_eq!(multiaddr, expected); diff --git a/examples/ipfs-kad.rs b/examples/ipfs-kad.rs index ec48435db002..3d29ad3afd62 100644 --- a/examples/ipfs-kad.rs +++ b/examples/ipfs-kad.rs @@ -97,10 +97,10 @@ fn main() -> Result<(), Box> { task::block_on(async move { loop { let event = swarm.next().await; - if let KademliaEvent::QueryResult { + if let Some(KademliaEvent::QueryResult { result: QueryResult::GetClosestPeers(result), .. - } = event { + }) = event { match result { Ok(ok) => if !ok.peers.is_empty() { diff --git a/misc/core-derive/src/lib.rs b/misc/core-derive/src/lib.rs index c100b516366d..6726251033dd 100644 --- a/misc/core-derive/src/lib.rs +++ b/misc/core-derive/src/lib.rs @@ -453,6 +453,12 @@ fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> TokenStream { std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address }) => { return std::task::Poll::Ready(#network_behaviour_action::ReportObservedAddr { address }); } + std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection_id }) => { + return std::task::Poll::Ready(#network_behaviour_action::CloseConnection { peer_id, connection_id }); + } + std::task::Poll::Ready(#network_behaviour_action::DisconnectPeer { peer_id }) => { + return std::task::Poll::Ready(#network_behaviour_action::DisconnectPeer { peer_id }); + } std::task::Poll::Pending => break, } } diff --git a/misc/core-derive/tests/test.rs b/misc/core-derive/tests/test.rs index 1dfcd84723e3..2aadd2d1c32d 100644 --- a/misc/core-derive/tests/test.rs +++ b/misc/core-derive/tests/test.rs @@ -300,8 +300,9 @@ fn event_process_false() { // check that the event is bubbled up all the way to swarm let _ = async { match swarm.next().await { - BehaviourOutEvent::Ping(_) => {}, - BehaviourOutEvent::Identify(_) => {}, + Some(BehaviourOutEvent::Ping(_)) => {}, + Some(BehaviourOutEvent::Identify(_)) => {}, + None => panic!("swarm terminated unexpectedly") } }; } diff --git a/protocols/gossipsub/src/behaviour.rs b/protocols/gossipsub/src/behaviour.rs index 2a17efaf8d93..0f83fa429487 100644 --- a/protocols/gossipsub/src/behaviour.rs +++ b/protocols/gossipsub/src/behaviour.rs @@ -1170,6 +1170,12 @@ impl NetworkBehaviour for Gossipsub { NetworkBehaviourAction::ReportObservedAddr { address } => { return Poll::Ready(NetworkBehaviourAction::ReportObservedAddr { address }); } + NetworkBehaviourAction::CloseConnection { peer_id, connection_id } => { + return Poll::Ready(NetworkBehaviourAction::CloseConnection { peer_id, connection_id }) + } + NetworkBehaviourAction::DisconnectPeer { 
peer_id } => { + return Poll::Ready(NetworkBehaviourAction::DisconnectPeer { peer_id }) + } } } diff --git a/protocols/gossipsub/tests/smoke.rs b/protocols/gossipsub/tests/smoke.rs index f16486e66cce..7734e978d995 100644 --- a/protocols/gossipsub/tests/smoke.rs +++ b/protocols/gossipsub/tests/smoke.rs @@ -48,17 +48,20 @@ struct Graph { } impl Future for Graph { - type Output = (Multiaddr, GossipsubEvent); + type Output = Option<(Multiaddr, GossipsubEvent)>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { for (addr, node) in &mut self.nodes { match node.poll_next_unpin(cx) { - Poll::Ready(Some(event)) => return Poll::Ready((addr.clone(), event)), - Poll::Ready(None) => panic!("unexpected None when polling nodes"), - Poll::Pending => {} + Poll::Ready(Some(event)) => return Poll::Ready(Some((addr.clone(), event))), + Poll::Ready(None) | Poll::Pending => {}, } } + if self.nodes.iter().all(|(_, s)| Swarm::is_closed(s)) { + return Poll::Ready(None) + } + Poll::Pending } } @@ -118,11 +121,12 @@ impl Graph { let fut = futures::future::poll_fn(move |cx| match &mut this { Some(graph) => loop { match graph.poll_unpin(cx) { - Poll::Ready((_addr, ev)) => { + Poll::Ready(Some((_addr, ev))) => { if f(ev) { - return Poll::Ready(this.take().unwrap()); + graph.nodes.iter_mut().for_each(|(_, s)| Swarm::start_close(s)); } } + Poll::Ready(None) => return Poll::Ready(this.take().unwrap()), Poll::Pending => return Poll::Pending, } }, diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 312e273ded15..738bae8ee0e8 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -323,7 +323,7 @@ mod tests { let swarm1_fut = swarm1.next_event(); pin_mut!(swarm1_fut); match swarm1_fut.await { - SwarmEvent::NewListenAddr(addr) => return addr, + Some(SwarmEvent::NewListenAddr(addr)) => return addr, _ => {} } } @@ -342,7 +342,7 @@ mod tests { pin_mut!(swarm2_fut); match future::select(swarm1_fut, swarm2_fut).await.factor_second().0 { - future::Either::Left(IdentifyEvent::Received { info, .. }) => { + future::Either::Left(Some(IdentifyEvent::Received { info, .. })) => { assert_eq!(info.public_key, pubkey2); assert_eq!(info.protocol_version, "c"); assert_eq!(info.agent_version, "d"); @@ -350,7 +350,7 @@ mod tests { assert!(info.listen_addrs.is_empty()); return; } - future::Either::Right(IdentifyEvent::Received { info, .. }) => { + future::Either::Right(Some(IdentifyEvent::Received { info, .. 
})) => { assert_eq!(info.public_key, pubkey1); assert_eq!(info.protocol_version, "a"); assert_eq!(info.agent_version, "b"); diff --git a/protocols/kad/src/behaviour/test.rs b/protocols/kad/src/behaviour/test.rs index f7c5f97818a6..6ca005afd1e0 100644 --- a/protocols/kad/src/behaviour/test.rs +++ b/protocols/kad/src/behaviour/test.rs @@ -62,7 +62,7 @@ fn build_node_with_config(cfg: KademliaConfig) -> (Multiaddr, TestSwarm) { .authenticate(SecioConfig::new(local_key)) .multiplex(yamux::Config::default()) .map(|(p, m), _| (p, StreamMuxerBox::new(m))) - .map_err(|e| -> io::Error { panic!("Failed to create transport: {:?}", e); }) + .map_err(|e| -> io::Error { io::Error::new(io::ErrorKind::Other, e.to_string()) }) .boxed(); let local_id = local_public_key.clone().into_peer_id(); @@ -179,43 +179,55 @@ fn bootstrap() { let expected_known = swarm_ids.iter().skip(1).cloned().collect::>(); let mut first = true; + let mut success = false; + // Run test - block_on( - poll_fn(move |ctx| { - for (i, swarm) in swarms.iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(KademliaEvent::QueryResult { - id, result: QueryResult::Bootstrap(Ok(ok)), .. - })) => { - assert_eq!(id, qid); - assert_eq!(i, 0); - if first { - // Bootstrapping must start with a self-lookup. - assert_eq!(ok.peer, swarm_ids[0]); - } - first = false; - if ok.num_remaining == 0 { - let mut known = HashSet::new(); - for b in swarm.kbuckets.iter() { - for e in b.iter() { - known.insert(e.node.key.preimage().clone()); - } + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(KademliaEvent::QueryResult { + id, result: QueryResult::Bootstrap(Ok(ok)), .. + })) => { + assert_eq!(id, qid); + assert_eq!(i, 0); + if first { + // Bootstrapping must start with a self-lookup. + assert_eq!(ok.peer, swarm_ids[0]); + } + first = false; + if ok.num_remaining == 0 { + let mut known = HashSet::new(); + for b in swarm.kbuckets.iter() { + for e in b.iter() { + known.insert(e.node.key.preimage().clone()); } - assert_eq!(expected_known, known); - return Poll::Ready(()) } + assert_eq!(expected_known, known); + success = true; + Swarm::start_close(swarm); } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => break, + Poll::Pending => + if success && !Swarm::is_closing(swarm) { + Swarm::start_close(swarm) + } else { + break + } } } - Poll::Pending - }) - ) + } + + if swarms.iter().all(Swarm::is_closed) { + assert!(success); + return Poll::Ready(()) + } + + Poll::Pending + })) } QuickCheck::new().tests(10).quickcheck(prop as fn(_) -> _) @@ -259,34 +271,46 @@ fn query_iter() { let mut expected_distances = distances(&search_target_key, expected_peer_ids.clone()); expected_distances.sort(); + let mut success = false; + // Run test - block_on( - poll_fn(move |ctx| { - for (i, swarm) in swarms.iter_mut().enumerate() { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(KademliaEvent::QueryResult { - id, result: QueryResult::GetClosestPeers(Ok(ok)), .. 
- })) => { - assert_eq!(id, qid); - assert_eq!(&ok.key[..], search_target.as_bytes()); - assert_eq!(swarm_ids[i], expected_swarm_id); - assert_eq!(swarm.queries.size(), 0); - assert!(expected_peer_ids.iter().all(|p| ok.peers.contains(p))); - let key = kbucket::Key::new(ok.key); - assert_eq!(expected_distances, distances(&key, ok.peers)); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + block_on(poll_fn(move |ctx| { + for (i, swarm) in swarms.iter_mut().enumerate() { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(KademliaEvent::QueryResult { + id, result: QueryResult::GetClosestPeers(Ok(ok)), .. + })) => { + assert_eq!(id, qid); + assert_eq!(&ok.key[..], search_target.as_bytes()); + assert_eq!(swarm_ids[i], expected_swarm_id); + assert_eq!(swarm.queries.size(), 0); + assert!(expected_peer_ids.iter().all(|p| ok.peers.contains(p))); + let key = kbucket::Key::new(ok.key); + assert_eq!(expected_distances, distances(&key, ok.peers)); + success = true; + Swarm::start_close(swarm); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => break, + Poll::Pending => + if success && !Swarm::is_closing(swarm) { + Swarm::start_close(swarm); + } else { + break + } } } - Poll::Pending - }) - ) + } + + if swarms.iter().all(Swarm::is_closed) { + assert!(success); + return Poll::Ready(()) + } + + Poll::Pending + })) } let mut rng = thread_rng(); @@ -323,11 +347,11 @@ fn unresponsive_not_returned_direct() { })) => { assert_eq!(&ok.key[..], search_target.as_bytes()); assert_eq!(ok.peers.len(), 0); - return Poll::Ready(()); + Swarm::start_close(swarm); } // Ignore any other event. Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Ready(None) => return Poll::Ready(()), Poll::Pending => break, } } @@ -374,11 +398,11 @@ fn unresponsive_not_returned_indirect() { assert_eq!(&ok.key[..], search_target.as_bytes()); assert_eq!(ok.peers.len(), 1); assert_eq!(ok.peers[0], first_peer_id); - return Poll::Ready(()); + Swarm::start_close(swarm); } // Ignore any other event. Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), + Poll::Ready(None) => return Poll::Ready(()), Poll::Pending => break, } } @@ -408,36 +432,40 @@ fn get_record_not_found() { let target_key = record::Key::from(random_multihash()); let qid = swarms[0].get_record(&target_key, Quorum::One); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(KademliaEvent::QueryResult { - id, result: QueryResult::GetRecord(Err(e)), .. - })) => { - assert_eq!(id, qid); - if let GetRecordError::NotFound { key, closest_peers, } = e { - assert_eq!(key, target_key); - assert_eq!(closest_peers.len(), 2); - assert!(closest_peers.contains(&swarm_ids[1])); - assert!(closest_peers.contains(&swarm_ids[2])); - return Poll::Ready(()); - } else { - panic!("Unexpected error result: {:?}", e); - } + let mut success = false; + + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(KademliaEvent::QueryResult { + id, result: QueryResult::GetRecord(Err(e)), .. 
+ })) => { + assert_eq!(id, qid); + if let GetRecordError::NotFound { key, closest_peers, } = e { + assert_eq!(key, target_key); + assert_eq!(closest_peers.len(), 2); + assert!(closest_peers.contains(&swarm_ids[1])); + assert!(closest_peers.contains(&swarm_ids[2])); + success = true; + Swarm::start_close(swarm); + } else { + panic!("Unexpected error result: {:?}", e); } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + } + Poll::Pending => break, + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => { + assert!(success); + return Poll::Ready(()) } } } + } - Poll::Pending - }) - ) + Poll::Pending + })) } /// A node joining a fully connected network via three (ALPHA_VALUE) bootnodes @@ -514,6 +542,8 @@ fn put_record() { // The accumulated results for one round of publishing. let mut results = Vec::new(); + let mut success = false; + block_on( poll_fn(move |ctx| loop { // Poll all swarms until they are "Pending". @@ -542,15 +572,14 @@ fn put_record() { } // Ignore any other event. Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + Poll::Ready(None) | Poll::Pending => break, } } } // All swarms are Pending and not enough results have been collected // so far, thus wait to be polled again for further progress. - if results.len() != records.len() { + if !success && results.len() != records.len() { return Poll::Pending } @@ -608,14 +637,21 @@ fn put_record() { } if republished { - assert_eq!(swarms[0].store.records().count(), records.len()); - assert_eq!(swarms[0].queries.size(), 0); - for k in records.keys() { - swarms[0].store.remove(&k); + if !success { + assert_eq!(swarms[0].store.records().count(), records.len()); + assert_eq!(swarms[0].queries.size(), 0); + for k in records.keys() { + swarms[0].store.remove(&k); + } + assert_eq!(swarms[0].store.records().count(), 0); + // All records have been republished, thus the test is complete. + swarms.iter_mut().for_each(Swarm::start_close); + success = true; + } + if swarms.iter().all(Swarm::is_closed) { + assert!(success); + return Poll::Ready(()) } - assert_eq!(swarms[0].store.records().count(), 0); - // All records have been republished, thus the test is complete. - return Poll::Ready(()); } // Tell the replication job to republish asap. @@ -646,32 +682,44 @@ fn get_record() { swarms[1].store.put(record.clone()).unwrap(); let qid = swarms[0].get_record(&record.key, Quorum::One); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(KademliaEvent::QueryResult { - id, - result: QueryResult::GetRecord(Ok(GetRecordOk { records })), - .. - })) => { - assert_eq!(id, qid); - assert_eq!(records.len(), 1); - assert_eq!(records.first().unwrap().record, record); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + let mut success = false; + + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(KademliaEvent::QueryResult { + id, + result: QueryResult::GetRecord(Ok(GetRecordOk { records })), + .. 
+ })) => { + assert!(!success); + assert_eq!(id, qid); + assert_eq!(records.len(), 1); + assert_eq!(records.first().unwrap().record, record); + success = true; + Swarm::start_close(swarm); } + // Ignore any other event. + Poll::Ready(Some(..)) => {}, + Poll::Pending => + if success && !Swarm::is_closing(swarm) { + Swarm::start_close(swarm); + } else { + break + }, + Poll::Ready(None) => break } } + } - Poll::Pending - }) - ) + if swarms.iter().all(Swarm::is_closed) { + assert!(success); + return Poll::Ready(()) + } + + Poll::Pending + })); } #[test] @@ -692,31 +740,43 @@ fn get_record_many() { let quorum = Quorum::N(NonZeroUsize::new(num_results).unwrap()); let qid = swarms[0].get_record(&record.key, quorum); - block_on( - poll_fn(move |ctx| { - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(KademliaEvent::QueryResult { - id, - result: QueryResult::GetRecord(Ok(GetRecordOk { records })), - .. - })) => { - assert_eq!(id, qid); - assert_eq!(records.len(), num_results); - assert_eq!(records.first().unwrap().record, record); - return Poll::Ready(()); - } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, + let mut success = false; + + block_on(poll_fn(move |ctx| { + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(KademliaEvent::QueryResult { + id, + result: QueryResult::GetRecord(Ok(GetRecordOk { records })), + .. + })) => { + assert_eq!(id, qid); + assert_eq!(records.len(), num_results); + assert_eq!(records.first().unwrap().record, record); + success = true; + Swarm::start_close(swarm); } + // Ignore any other event. + Poll::Ready(Some(_)) => (), + Poll::Ready(None) => break, + Poll::Pending => + if success && !Swarm::is_closing(swarm) { + Swarm::start_close(swarm); + } else { + break + } } } - Poll::Pending - }) - ) + } + + if swarms.iter().all(Swarm::is_closed) { + assert!(success); + return Poll::Ready(()) + } + + Poll::Pending + })); } /// A node joining a fully connected network via three (ALPHA_VALUE) bootnodes @@ -773,106 +833,112 @@ fn add_provider() { qids.insert(qid); } - block_on( - poll_fn(move |ctx| loop { - // Poll all swarms until they are "Pending". - for swarm in &mut swarms { - loop { - match swarm.poll_next_unpin(ctx) { - Poll::Ready(Some(KademliaEvent::QueryResult { - id, result: QueryResult::StartProviding(res), .. - })) | - Poll::Ready(Some(KademliaEvent::QueryResult { - id, result: QueryResult::RepublishProvider(res), .. - })) => { - assert!(qids.is_empty() || qids.remove(&id)); - match res { - Err(e) => panic!(e), - Ok(ok) => { - assert!(keys.contains(&ok.key)); - results.push(ok.key); - } + let mut success = false; + + block_on(poll_fn(move |ctx| loop { + // Poll all swarms until they are "Pending". + for swarm in &mut swarms { + loop { + match swarm.poll_next_unpin(ctx) { + Poll::Ready(Some(KademliaEvent::QueryResult { + id, result: QueryResult::StartProviding(res), .. + })) | + Poll::Ready(Some(KademliaEvent::QueryResult { + id, result: QueryResult::RepublishProvider(res), .. + })) => { + assert!(qids.is_empty() || qids.remove(&id)); + match res { + Err(e) => panic!(e), + Ok(ok) => { + assert!(keys.contains(&ok.key)); + results.push(ok.key); } } - // Ignore any other event. - Poll::Ready(Some(_)) => (), - e @ Poll::Ready(_) => panic!("Unexpected return value: {:?}", e), - Poll::Pending => break, } + // Ignore any other event. 
+ Poll::Ready(Some(_)) => (), + Poll::Ready(None) | Poll::Pending => break, } } + } - if results.len() == keys.len() { - // All requests have been sent for one round of publishing. - published = true - } + if results.len() == keys.len() { + // All requests have been sent for one round of publishing. + published = true + } + + if !success && !published { + // Still waiting for all requests to be sent for one round + // of publishing. + return Poll::Pending + } - if !published { - // Still waiting for all requests to be sent for one round - // of publishing. + // A round of publishing is complete. Consume the results, checking that + // each key was published to the `replication_factor` closest peers. + while let Some(key) = results.pop() { + // Collect the nodes that have a provider record for `key`. + let actual = swarms.iter().skip(1) + .filter_map(|swarm| + if swarm.store.providers(&key).len() == 1 { + Some(Swarm::local_peer_id(&swarm).clone()) + } else { + None + }) + .collect::>(); + + if actual.len() != replication_factor.get() { + // Still waiting for some nodes to process the request. + results.push(key); return Poll::Pending } - // A round of publishing is complete. Consume the results, checking that - // each key was published to the `replication_factor` closest peers. - while let Some(key) = results.pop() { - // Collect the nodes that have a provider record for `key`. - let actual = swarms.iter().skip(1) - .filter_map(|swarm| - if swarm.store.providers(&key).len() == 1 { - Some(Swarm::local_peer_id(&swarm).clone()) - } else { - None - }) - .collect::>(); - - if actual.len() != replication_factor.get() { - // Still waiting for some nodes to process the request. - results.push(key); - return Poll::Pending - } - - let mut expected = swarms.iter() - .skip(1) - .map(Swarm::local_peer_id) - .cloned() - .collect::>(); - let kbucket_key = kbucket::Key::new(key); - expected.sort_by(|id1, id2| - kbucket::Key::new(id1.clone()).distance(&kbucket_key).cmp( - &kbucket::Key::new(id2.clone()).distance(&kbucket_key))); - - let expected = expected - .into_iter() - .take(replication_factor.get()) - .collect::>(); - - assert_eq!(actual, expected); - } + let mut expected = swarms.iter() + .skip(1) + .map(Swarm::local_peer_id) + .cloned() + .collect::>(); + let kbucket_key = kbucket::Key::new(key); + expected.sort_by(|id1, id2| + kbucket::Key::new(id1.clone()).distance(&kbucket_key).cmp( + &kbucket::Key::new(id2.clone()).distance(&kbucket_key))); + + let expected = expected + .into_iter() + .take(replication_factor.get()) + .collect::>(); + + assert_eq!(actual, expected); + } - // One round of publishing is complete. - assert!(results.is_empty()); - for swarm in &swarms { - assert_eq!(swarm.queries.size(), 0); - } + // One round of publishing is complete. + assert!(results.is_empty()); + for swarm in &swarms { + assert_eq!(swarm.queries.size(), 0); + } - if republished { + if republished { + // All records have been republished, thus the test is complete. + if !success { assert_eq!(swarms[0].store.provided().count(), keys.len()); for k in &keys { swarms[0].stop_providing(&k); } assert_eq!(swarms[0].store.provided().count(), 0); - // All records have been republished, thus the test is complete. + success = true; + swarms.iter_mut().for_each(Swarm::start_close); + } + if swarms.iter().all(Swarm::is_closed) { + assert!(success); return Poll::Ready(()); } + } - // Initiate the second round of publishing by telling the - // periodic provider job to run asap. 
- swarms[0].add_provider_job.as_mut().unwrap().asap(); - published = false; - republished = true; - }) - ) + // Initiate the second round of publishing by telling the + // periodic provider job to run asap. + swarms[0].add_provider_job.as_mut().unwrap().asap(); + published = false; + republished = true; + })) } QuickCheck::new().tests(3).quickcheck(prop as fn(_,_)) @@ -891,25 +957,23 @@ fn exceed_jobs_max_queries() { assert_eq!(swarm.queries.size(), num); - block_on( - poll_fn(move |ctx| { - for _ in 0 .. num { - // There are no other nodes, so the queries finish instantly. - if let Poll::Ready(Some(e)) = swarm.poll_next_unpin(ctx) { - if let KademliaEvent::QueryResult { - result: QueryResult::GetClosestPeers(Ok(r)), .. - } = e { - assert!(r.peers.is_empty()) - } else { - panic!("Unexpected event: {:?}", e) - } + block_on(poll_fn(move |ctx| { + for _ in 0 .. num { + // There are no other nodes, so the queries finish instantly. + if let Poll::Ready(Some(e)) = swarm.poll_next_unpin(ctx) { + if let KademliaEvent::QueryResult { + result: QueryResult::GetClosestPeers(Ok(r)), .. + } = e { + assert!(r.peers.is_empty()) } else { - panic!("Expected event") + panic!("Unexpected event: {:?}", e) } + } else { + panic!("Expected event") } - Poll::Ready(()) - }) - ) + } + Poll::Ready(()) + })) } #[test] @@ -1011,9 +1075,11 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { } }); + let mut records = Vec::new(); + // Poll `alice` and `bob` expecting `alice` to return a successful query // result as it is now able to explore the second disjoint path. - let records = block_on( + block_on( poll_fn(|ctx| { for (i, swarm) in [&mut alice, &mut bob].iter_mut().enumerate() { loop { @@ -1027,20 +1093,31 @@ fn disjoint_query_does_not_finish_before_all_paths_did() { } match result { - Ok(ok) => return Poll::Ready(ok.records), + Ok(ok) => { + records = ok.records; + Swarm::start_close(swarm); + }, Err(e) => unreachable!("{:?}", e), } } // Ignore any other event. Poll::Ready(Some(_)) => (), - Poll::Ready(None) => panic!( - "Expected Kademlia behaviour not to finish.", - ), - Poll::Pending => break, + Poll::Ready(None) => break, + Poll::Pending => { + if !records.is_empty() && !Swarm::is_closing(swarm) { + Swarm::start_close(swarm) + } else { + break + } + } } } } + if [&alice, &bob].iter().all(|s| Swarm::is_closed(*s)) { + return Poll::Ready(()) + } + Poll::Pending }) ); @@ -1074,6 +1151,7 @@ fn manual_bucket_inserts() { let mut routable = Vec::new(); // Start an iterative query from the first peer. swarms[0].1.get_closest_peers(PeerId::random()); + let mut success = false; block_on(poll_fn(move |ctx| { for (_, swarm) in swarms.iter_mut() { loop { @@ -1088,14 +1166,27 @@ fn manual_bucket_inserts() { let bucket = swarm.kbucket(peer.clone()).unwrap(); assert!(bucket.iter().all(|e| e.node.key.preimage() != peer)); } - return Poll::Ready(()) + success = true; + Swarm::start_close(swarm); } } - Poll::Ready(..) 
=> {}, - Poll::Pending => break + Poll::Ready(Some(_)) => {}, + Poll::Ready(None) => break, + Poll::Pending => + if success && !Swarm::is_closing(swarm) { + Swarm::start_close(swarm) + } else { + break + } } } } + + if swarms.iter().all(|(_, s)| Swarm::is_closed(s)) { + assert!(success); + return Poll::Ready(()) + } + Poll::Pending })); } diff --git a/protocols/ping/tests/ping.rs b/protocols/ping/tests/ping.rs index 30e8de601ee0..4f36673e3e3f 100644 --- a/protocols/ping/tests/ping.rs +++ b/protocols/ping/tests/ping.rs @@ -60,7 +60,7 @@ fn ping() { loop { match swarm1.next().await { - PingEvent { peer, result: Ok(PingSuccess::Ping { rtt }) } => { + Some(PingEvent { peer, result: Ok(PingSuccess::Ping { rtt }) }) => { return (pid1.clone(), peer, rtt) }, _ => {} @@ -74,7 +74,7 @@ fn ping() { loop { match swarm2.next().await { - PingEvent { peer, result: Ok(PingSuccess::Ping { rtt }) } => { + Some(PingEvent { peer, result: Ok(PingSuccess::Ping { rtt }) }) => { return (pid2.clone(), peer, rtt) }, _ => {} diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index 3a491a69b6a0..715971010630 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -86,7 +86,7 @@ where impl RequestResponseHandler where - TCodec: RequestResponseCodec, + TCodec: RequestResponseCodec + Clone + Send + 'static, { pub(super) fn new( inbound_protocols: SmallVec<[TCodec::Protocol; 2]>, @@ -106,6 +106,30 @@ where pending_error: None, } } + + fn poll_inbound(&mut self, cx: &mut Context<'_>) + -> Poll> + { + while let Poll::Ready(Some(result)) = self.inbound.poll_next_unpin(cx) { + match result { + Ok((rq, rs_sender)) => { + // We received an inbound request. + self.keep_alive = KeepAlive::Yes; + return Poll::Ready( + RequestResponseHandlerEvent::Request { + request: rq, sender: rs_sender + }) + } + Err(oneshot::Canceled) => { + // The inbound upgrade has errored or timed out reading + // or waiting for the request. The handler is informed + // via `inject_listen_upgrade_error`. + } + } + } + + Poll::Pending + } } /// The events emitted by the [`RequestResponseHandler`]. @@ -132,6 +156,8 @@ where InboundTimeout, /// An inbound request failed to negotiate a mutually supported protocol. InboundUnsupportedProtocols, + /// An outbound request could not be sent because the connection is closing. + Closing(RequestProtocol), } impl ProtocolsHandler for RequestResponseHandler @@ -256,10 +282,7 @@ where self.keep_alive } - fn poll( - &mut self, - cx: &mut Context, - ) -> Poll< + fn poll(&mut self, cx: &mut Context) -> Poll< ProtocolsHandlerEvent, RequestId, Self::OutEvent, Self::Error>, > { // Check for a pending (fatal) error. @@ -276,22 +299,8 @@ where } // Check for inbound requests. - while let Poll::Ready(Some(result)) = self.inbound.poll_next_unpin(cx) { - match result { - Ok((rq, rs_sender)) => { - // We received an inbound request. - self.keep_alive = KeepAlive::Yes; - return Poll::Ready(ProtocolsHandlerEvent::Custom( - RequestResponseHandlerEvent::Request { - request: rq, sender: rs_sender - })) - } - Err(oneshot::Canceled) => { - // The inbound upgrade has errored or timed out reading - // or waiting for the request. The handler is informed - // via `inject_listen_upgrade_error`. - } - } + if let Poll::Ready(event) = self.poll_inbound(cx) { + return Poll::Ready(ProtocolsHandlerEvent::Custom(event)) } // Emit outbound requests. 
@@ -322,5 +331,37 @@ where Poll::Pending } + + fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll< + Result, Self::Error> + > { + // Check for a pending (fatal) error. + if let Some(err) = self.pending_error.take() { + // The handler will not be polled again by the `Swarm`. + return Poll::Ready(Err(err)) + } + + // Drain pending events. + if let Some(event) = self.pending_events.pop_front() { + return Poll::Ready(Ok(Some(event))) + } + + // Drain remaining inbound requests. New inbound requests + // will not be received. + if let Poll::Ready(event) = self.poll_inbound(cx) { + return Poll::Ready(Ok(Some(event))) + } + + // Deny new outbound requests so they can be rescheduled on + // a different connection. + if let Some(request) = self.outbound.pop_front() { + return Poll::Ready(Ok(Some(RequestResponseHandlerEvent::Closing(request)))) + } + + // Ongoing inbound and outbound upgrades are always allowed to + // complete, thus there is nothing else that needs to delay + // connection shutdown. + return Poll::Ready(Ok(None)) + } } diff --git a/protocols/request-response/src/lib.rs b/protocols/request-response/src/lib.rs index c21929343a0c..80666322f635 100644 --- a/protocols/request-response/src/lib.rs +++ b/protocols/request-response/src/lib.rs @@ -398,7 +398,7 @@ where /// Tries to send a request by queueing an appropriate event to be /// emitted to the `Swarm`. If the peer is not currently connected, - /// the given request is return unchanged. + /// the given request is returned unchanged. fn try_send_request(&mut self, peer: &PeerId, request: RequestProtocol) -> Option> { @@ -520,7 +520,7 @@ where fn inject_event( &mut self, peer: PeerId, - _: ConnectionId, + conn: ConnectionId, event: RequestResponseHandlerEvent, ) { match event { @@ -574,6 +574,30 @@ where error: InboundFailure::UnsupportedProtocols, })); } + RequestResponseHandlerEvent::Closing(request) => { + if let Some((req_peer, req_conn)) = self.pending_responses.remove(&request.request_id) { + debug_assert_eq!(req_peer, peer); + debug_assert_eq!(req_conn, conn); + // Try to send the request on a different connection. + if let Some(conn) = self.connected.get(&peer).and_then(|conns| + conns.iter().find(|c| c.id != conn) + ) { + self.pending_responses.insert(request.request_id, (peer, conn.id)); + self.pending_events.push_back(NetworkBehaviourAction::NotifyHandler { + peer_id: req_peer, + handler: NotifyHandler::One(conn.id), + event: request + }); + } else { + // There is no other existing connection to use, so request a new one. 
+ self.pending_events.push_back(NetworkBehaviourAction::DialPeer { + peer_id: peer.clone(), + condition: DialPeerCondition::Disconnected, + }); + self.pending_requests.entry(peer).or_default().push(request); + } + } + } } } diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 107a37edf044..6e8cfde22b16 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -72,10 +72,10 @@ fn ping_protocol() { loop { match swarm1.next().await { - RequestResponseEvent::Message { + Some(RequestResponseEvent::Message { peer, message: RequestResponseMessage::Request { request, channel } - } => { + }) => { assert_eq!(&request, &expected_ping); assert_eq!(&peer, &peer2_id); swarm1.send_response(channel, pong.clone()); @@ -93,10 +93,10 @@ fn ping_protocol() { loop { match swarm2.next().await { - RequestResponseEvent::Message { + Some(RequestResponseEvent::Message { peer, message: RequestResponseMessage::Response { request_id, response } - } => { + }) => { count += 1; assert_eq!(&response, &expected_pong); assert_eq!(&peer, &peer1_id); diff --git a/swarm/src/behaviour.rs b/swarm/src/behaviour.rs index f75a31ab0443..98f7f0cd7e13 100644 --- a/swarm/src/behaviour.rs +++ b/swarm/src/behaviour.rs @@ -279,6 +279,17 @@ pub enum NetworkBehaviourAction { /// The observed address of the local node. address: Multiaddr, }, + + /// Instructs the `Swarm` to initiate a shutdown of a connection. + CloseConnection { + peer_id: PeerId, + connection_id: ConnectionId + }, + + /// Instructs the `Swarm` to immediately disconnect a peer. + DisconnectPeer { + peer_id: PeerId, + } } /// The options w.r.t. which connection handlers to notify of an event. diff --git a/swarm/src/lib.rs b/swarm/src/lib.rs index 5f2c3c628a74..8c3087be475a 100644 --- a/swarm/src/lib.rs +++ b/swarm/src/lib.rs @@ -289,7 +289,10 @@ where /// Pending event to be delivered to connection handlers /// (or dropped if the peer disconnected) before the `behaviour` /// can be polled again. - pending_event: Option<(PeerId, PendingNotifyHandler, TInEvent)> + pending_event: Option<(PeerId, PendingNotifyHandler, TInEvent)>, + + /// Whether the `Swarm` has been closed, i.e. completed a clean shutdown. + closed: bool, } impl Deref for @@ -461,20 +464,66 @@ where TBehaviour: NetworkBehaviour, me.banned_peers.remove(&peer_id); } + /// Starts closing down the `Swarm` and underlying `Network`. + /// + /// Closing down a `Swarm` proceeds as follows: + /// + /// 1. The underlying `Network` drains all established connections + /// while rejecting new connections. Established connections are + /// drained by asking the associated [`ProtocolsHandler`] to close, + /// after which the underlying connection is closed. + /// + /// 2. The [`NetworkBehaviour`] associated with the `Swarm` continues + /// to be polled once the `Network` has closed until it returns + /// `Poll::Pending`, indicating that further network I/O would be + /// needed for the behaviour to make progress. In this way, the + /// behaviour can finish any work and emit resulting events. + /// + /// After calling this method, the `Swarm` must be `poll()`ed to + /// drive the shutdown to completion. + pub fn start_close(me: &mut Self) { + me.network.start_close() + } + + /// Closes the `Swarm`, ignoring any further events. + /// + /// See [`Self::start_close`] for details. + pub async fn close(&mut self) { + self.network.start_close(); + while let Some(_) = self.next_event().await {} + } + + /// Whether the `Swarm` is closed.
+ /// + /// When the `Swarm` is closed, neither the underlying `Network` + /// nor the associated `NetworkBehaviour` are polled again. + pub fn is_closed(me: &Self) -> bool { + me.closed + } + + /// Whether the `Swarm` is closing. + /// + /// While the `Swarm` is closing, new connections are rejected + /// while established connections and the `NetworkBehaviour` + /// are drained until no more work is to be done. + pub fn is_closing(me: &Self) -> bool { + me.network.is_closing() || (me.network.is_closed() && !me.closed) + } + /// Returns the next event that happens in the `Swarm`. /// /// Includes events from the `NetworkBehaviour` but also events about the connections status. - pub async fn next_event(&mut self) -> SwarmEvent { + pub async fn next_event(&mut self) -> Option> { future::poll_fn(move |cx| ExpandedSwarm::poll_next_event(Pin::new(self), cx)).await } /// Returns the next event produced by the [`NetworkBehaviour`]. - pub async fn next(&mut self) -> TBehaviour::OutEvent { + pub async fn next(&mut self) -> Option { future::poll_fn(move |cx| { loop { let event = futures::ready!(ExpandedSwarm::poll_next_event(Pin::new(self), cx)); - if let SwarmEvent::Behaviour(event) = event { - return Poll::Ready(event); + if let Some(SwarmEvent::Behaviour(event)) = event { + return Poll::Ready(Some(event)); } } }).await @@ -484,29 +533,44 @@ where TBehaviour: NetworkBehaviour, /// /// Polls the `Swarm` for the next event. fn poll_next_event(mut self: Pin<&mut Self>, cx: &mut Context) - -> Poll> + -> Poll>> { // We use a `this` variable because the compiler can't mutably borrow multiple times // across a `Deref`. let this = &mut *self; loop { + if this.closed { + return Poll::Ready(None) + } + let mut network_not_ready = false; // First let the network make progress. match this.network.poll(cx) { Poll::Pending => network_not_ready = true, - Poll::Ready(NetworkEvent::ConnectionEvent { connection, event }) => { + Poll::Ready(None) => { + // The network closed, but the behaviour may still be + // doing work and produce events. Thus we wait until it + // returns `Pending`. 
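For users of the `Swarm` API, the closing procedure described in the documentation of `start_close` above boils down to the following pattern. This is an illustrative sketch, not part of the patch; `my_swarm` stands for any fully constructed `Swarm` inside an async context:

    // Initiate a graceful shutdown; the swarm must continue to be polled.
    Swarm::start_close(&mut my_swarm);

    // Drain remaining events. `next_event()` now returns an `Option` and
    // yields `None` once the `Network` has closed and the `NetworkBehaviour`
    // has no more pending work.
    while let Some(_event) = my_swarm.next_event().await {
        // Established connections are being drained here; the behaviour
        // may still emit events until it returns `Poll::Pending`.
    }

    assert!(Swarm::is_closed(&my_swarm));

Alternatively, `my_swarm.close().await` performs the same steps while discarding the drained events.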
+ } + Poll::Ready(Some( + NetworkEvent::ConnectionEvent { connection, event } + )) => { let peer = connection.peer_id().clone(); let connection = connection.id(); this.behaviour.inject_event(peer, connection, event); }, - Poll::Ready(NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint }) => { + Poll::Ready(Some( + NetworkEvent::AddressChange { connection, new_endpoint, old_endpoint } + )) => { let peer = connection.peer_id(); let connection = connection.id(); this.behaviour.inject_address_change(&peer, &connection, &old_endpoint, &new_endpoint); }, - Poll::Ready(NetworkEvent::ConnectionEstablished { connection, num_established }) => { + Poll::Ready(Some( + NetworkEvent::ConnectionEstablished { connection, num_established } + )) => { let peer_id = connection.peer_id().clone(); let endpoint = connection.endpoint().clone(); if this.banned_peers.contains(&peer_id) { @@ -514,10 +578,10 @@ where TBehaviour: NetworkBehaviour, .into_connected() .expect("the Network just notified us that we were connected; QED") .disconnect(); - return Poll::Ready(SwarmEvent::BannedPeer { + return Poll::Ready(Some(SwarmEvent::BannedPeer { peer_id, endpoint, - }); + })); } else { log::debug!("Connection established: {:?}; Total (peer): {}.", connection.connected(), num_established); @@ -526,12 +590,14 @@ where TBehaviour: NetworkBehaviour, if num_established.get() == 1 { this.behaviour.inject_connected(&peer_id); } - return Poll::Ready(SwarmEvent::ConnectionEstablished { + return Poll::Ready(Some(SwarmEvent::ConnectionEstablished { peer_id, num_established, endpoint - }); + })); } }, - Poll::Ready(NetworkEvent::ConnectionClosed { id, connected, error, num_established }) => { + Poll::Ready(Some( + NetworkEvent::ConnectionClosed { id, connected, error, num_established } + )) => { if let Some(error) = error.as_ref() { log::debug!("Connection {:?} closed: {:?}", connected, error); } else { @@ -543,40 +609,46 @@ where TBehaviour: NetworkBehaviour, if num_established == 0 { this.behaviour.inject_disconnected(info.peer_id()); } - return Poll::Ready(SwarmEvent::ConnectionClosed { + return Poll::Ready(Some(SwarmEvent::ConnectionClosed { peer_id: info.peer_id().clone(), endpoint, cause: error, num_established, - }); + })); }, - Poll::Ready(NetworkEvent::IncomingConnection(incoming)) => { + Poll::Ready(Some(NetworkEvent::IncomingConnection(incoming))) => { let handler = this.behaviour.new_handler(); let local_addr = incoming.local_addr().clone(); let send_back_addr = incoming.send_back_addr().clone(); if let Err(e) = incoming.accept(handler.into_node_handler_builder()) { log::warn!("Incoming connection rejected: {:?}", e); } - return Poll::Ready(SwarmEvent::IncomingConnection { + return Poll::Ready(Some(SwarmEvent::IncomingConnection { local_addr, send_back_addr, - }); + })); }, - Poll::Ready(NetworkEvent::NewListenerAddress { listener_id, listen_addr }) => { + Poll::Ready(Some( + NetworkEvent::NewListenerAddress { listener_id, listen_addr } + )) => { log::debug!("Listener {:?}; New address: {:?}", listener_id, listen_addr); if !this.listened_addrs.contains(&listen_addr) { this.listened_addrs.push(listen_addr.clone()) } this.behaviour.inject_new_listen_addr(&listen_addr); - return Poll::Ready(SwarmEvent::NewListenAddr(listen_addr)); + return Poll::Ready(Some(SwarmEvent::NewListenAddr(listen_addr))); } - Poll::Ready(NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr }) => { + Poll::Ready(Some( + NetworkEvent::ExpiredListenerAddress { listener_id, listen_addr } + )) => { log::debug!("Listener {:?}; 
Expired address {:?}.", listener_id, listen_addr); this.listened_addrs.retain(|a| a != &listen_addr); this.behaviour.inject_expired_listen_addr(&listen_addr); - return Poll::Ready(SwarmEvent::ExpiredListenAddr(listen_addr)); + return Poll::Ready(Some(SwarmEvent::ExpiredListenAddr(listen_addr))); } - Poll::Ready(NetworkEvent::ListenerClosed { listener_id, addresses, reason }) => { + Poll::Ready(Some( + NetworkEvent::ListenerClosed { listener_id, addresses, reason } + )) => { log::debug!("Listener {:?}; Closed by {:?}.", listener_id, reason); for addr in addresses.iter() { this.behaviour.inject_expired_listen_addr(addr); @@ -585,26 +657,32 @@ where TBehaviour: NetworkBehaviour, Ok(()) => Ok(()), Err(err) => Err(err), }); - return Poll::Ready(SwarmEvent::ListenerClosed { + return Poll::Ready(Some(SwarmEvent::ListenerClosed { addresses, reason, - }); + })); } - Poll::Ready(NetworkEvent::ListenerError { listener_id, error }) => { + Poll::Ready(Some( + NetworkEvent::ListenerError { listener_id, error } + )) => { this.behaviour.inject_listener_error(listener_id, &error); - return Poll::Ready(SwarmEvent::ListenerError { + return Poll::Ready(Some(SwarmEvent::ListenerError { error, - }); + })); }, - Poll::Ready(NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, error }) => { + Poll::Ready(Some( + NetworkEvent::IncomingConnectionError { local_addr, send_back_addr, error } + )) => { log::debug!("Incoming connection failed: {:?}", error); - return Poll::Ready(SwarmEvent::IncomingConnectionError { + return Poll::Ready(Some(SwarmEvent::IncomingConnectionError { local_addr, send_back_addr, error, - }); + })); }, - Poll::Ready(NetworkEvent::DialError { peer_id, multiaddr, error, attempts_remaining }) => { + Poll::Ready(Some( + NetworkEvent::DialError { peer_id, multiaddr, error, attempts_remaining } + )) => { log::debug!( "Connection attempt to {:?} via {:?} failed with {:?}. Attempts remaining: {}.", peer_id, multiaddr, error, attempts_remaining); @@ -612,21 +690,23 @@ where TBehaviour: NetworkBehaviour, if attempts_remaining == 0 { this.behaviour.inject_dial_failure(&peer_id); } - return Poll::Ready(SwarmEvent::UnreachableAddr { + return Poll::Ready(Some(SwarmEvent::UnreachableAddr { peer_id, address: multiaddr, error, attempts_remaining, - }); + })); }, - Poll::Ready(NetworkEvent::UnknownPeerDialError { multiaddr, error, .. }) => { + Poll::Ready(Some( + NetworkEvent::UnknownPeerDialError { multiaddr, error, .. } + )) => { log::debug!("Connection attempt to address {:?} of unknown peer failed with {:?}", multiaddr, error); this.behaviour.inject_addr_reach_failure(None, &multiaddr, &error); - return Poll::Ready(SwarmEvent::UnknownPeerUnreachableAddr { + return Poll::Ready(Some(SwarmEvent::UnknownPeerUnreachableAddr { address: multiaddr, error, - }); + })); }, } @@ -676,10 +756,19 @@ where TBehaviour: NetworkBehaviour, }; match behaviour_poll { - Poll::Pending if network_not_ready => return Poll::Pending, - Poll::Pending => (), + Poll::Pending if network_not_ready => { + return Poll::Pending + } + Poll::Pending => { + if this.network.is_closed() { + // The network is closed and the behaviour is + // waiting for network I/O. Hence the swarm terminates. 
+ this.closed = true; + return Poll::Ready(None) + } + }, Poll::Ready(NetworkBehaviourAction::GenerateEvent(event)) => { - return Poll::Ready(SwarmEvent::Behaviour(event)) + return Poll::Ready(Some(SwarmEvent::Behaviour(event))) }, Poll::Ready(NetworkBehaviourAction::DialAddress { address }) => { let _ = ExpandedSwarm::dial_addr(&mut *this, address); @@ -697,7 +786,7 @@ where TBehaviour: NetworkBehaviour, }; if condition_matched { if ExpandedSwarm::dial(this, &peer_id).is_ok() { - return Poll::Ready(SwarmEvent::Dialing(peer_id)) + return Poll::Ready(Some(SwarmEvent::Dialing(peer_id))) } } else { // Even if the condition for a _new_ dialing attempt is not met, @@ -718,6 +807,18 @@ where TBehaviour: NetworkBehaviour, } } }, + Poll::Ready(NetworkBehaviourAction::DisconnectPeer { peer_id }) => { + if let Some(peer) = this.network.peer(peer_id).into_connected() { + peer.disconnect(); + } + } + Poll::Ready(NetworkBehaviourAction::CloseConnection { peer_id, connection_id }) => { + if let Some(mut peer) = this.network.peer(peer_id).into_connected() { + if let Some(conn) = peer.connection(connection_id) { + conn.start_close(); + } + } + } Poll::Ready(NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }) => { if let Some(mut peer) = this.network.peer(peer_id.clone()).into_connected() { match handler { @@ -914,9 +1015,10 @@ where TBehaviour: NetworkBehaviour, fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { - let event = futures::ready!(ExpandedSwarm::poll_next_event(self.as_mut(), cx)); - if let SwarmEvent::Behaviour(event) = event { - return Poll::Ready(Some(event)); + match futures::ready!(ExpandedSwarm::poll_next_event(self.as_mut(), cx)) { + Some(SwarmEvent::Behaviour(event)) => return Poll::Ready(Some(event)), + Some(_) => {} + None => return Poll::Ready(None), } } } @@ -933,7 +1035,7 @@ where TBehaviour: NetworkBehaviour, TConnInfo: ConnectionInfo + fmt::Debug + Clone + Send + 'static, { fn is_terminated(&self) -> bool { - false + ExpandedSwarm::is_closed(self) } } @@ -1121,7 +1223,8 @@ where TBehaviour: NetworkBehaviour, listened_addrs: SmallVec::new(), external_addrs: Addresses::default(), banned_peers: HashSet::new(), - pending_event: None + pending_event: None, + closed: false, } } } diff --git a/swarm/src/protocols_handler.rs b/swarm/src/protocols_handler.rs index 9721e9db7451..188a4f688343 100644 --- a/swarm/src/protocols_handler.rs +++ b/swarm/src/protocols_handler.rs @@ -96,9 +96,13 @@ pub use select::{IntoProtocolsHandlerSelect, ProtocolsHandlerSelect}; /// implemented by the handler can include conditions for terminating the connection. /// The lifetime of successfully negotiated substreams is fully controlled by the handler. /// -/// Implementors of this trait should keep in mind that the connection can be closed at any time. -/// When a connection is closed gracefully, the substreams used by the handler may still -/// continue reading data until the remote closes its side of the connection. +/// When the keep-alive expires, the connection is gracefully closed. However, +/// implementors of this trait should keep in mind that the connection can close +/// as a result of encountering an error (including reading EOF as a result of +/// the connection being closed by the remote) at any time. +/// +/// When a connection is gracefully closed, the substreams used by the handler may be +/// shut down in an orderly fashion by implementing [`ProtocolsHandler::poll_close`]. 
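The `poll_close` hook referenced here is documented in detail below. As a rough illustration of how a protocol implementation might use it, consider a hypothetical handler that keeps a queue of events destined for the behaviour (the `pending_events` field is illustrative, not from this patch); during shutdown it drains that queue and only then lets the connection close:

    // Inside a hypothetical `impl ProtocolsHandler for MyHandler`.
    fn poll_close(&mut self, _cx: &mut Context<'_>)
        -> Poll<Result<Option<Self::OutEvent>, Self::Error>>
    {
        // Report any event the behaviour still needs to see. Returning
        // `Some(event)` keeps the connection open for another poll.
        if let Some(event) = self.pending_events.pop_front() {
            return Poll::Ready(Ok(Some(event)))
        }
        // Nothing left to flush: signal readiness for the connection
        // to be closed (this matches the default implementation).
        Poll::Ready(Ok(None))
    }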
pub trait ProtocolsHandler: Send + 'static { /// Custom event that can be received from the outside. type InEvent: Send + 'static; @@ -164,15 +168,14 @@ pub trait ProtocolsHandler: Send + 'static { /// Returns until when the connection should be kept alive. /// /// This method is called by the `Swarm` after each invocation of - /// [`ProtocolsHandler::poll`] to determine if the connection and the associated + /// [`Self::poll`] to determine if the connection and the associated /// `ProtocolsHandler`s should be kept alive as far as this handler is concerned /// and if so, for how long. /// - /// Returning [`KeepAlive::No`] indicates that the connection should be - /// closed and this handler destroyed immediately. + /// Returning [`KeepAlive::No`] indicates that the connection should be closed. /// /// Returning [`KeepAlive::Until`] indicates that the connection may be closed - /// and this handler destroyed after the specified `Instant`. + /// after the specified `Instant`. /// /// Returning [`KeepAlive::Yes`] indicates that the connection should /// be kept alive until the next call to this method. @@ -188,6 +191,37 @@ pub trait ProtocolsHandler: Send + 'static { ProtocolsHandlerEvent >; + /// Polls the handler to make progress towards closing the connection. + /// + /// When a connection is actively closed, the handler can perform + /// a graceful shutdown of the connection by draining the I/O + /// activity, e.g. allowing in-flight requests to complete without + /// accepting new ones, possibly signaling the remote that it + /// should direct further requests elsewhere. + /// + /// The handler can also use the opportunity to flush any buffers + /// or clean up any other (asynchronous) resources before the + /// connection is ultimately dropped and closed on the transport + /// layer. + /// + /// While closing, new inbound substreams are rejected and the + /// handler is unable to request new outbound substreams as + /// per the return type of `poll_close`. + /// + /// The handler signals its readiness for the connection + /// to be closed by returning `Ready(Ok(None))`, which is the + /// default implementation. Hence, by default, connection + /// shutdown is not delayed and may result in sudden + /// interruption of ongoing I/O. + /// + /// > **Note**: Once `poll_close()` is invoked, the handler is no + /// > longer `poll()`ed. + fn poll_close(&mut self, _: &mut Context) + -> Poll, Self::Error>> + { + Poll::Ready(Ok(None)) + } + /// Adds a closure that turns the input event into something else. #[inline] fn map_in_event(self, map: TMap) -> MapInEvent diff --git a/swarm/src/protocols_handler/node_handler.rs b/swarm/src/protocols_handler/node_handler.rs index a24ea2cc6652..890a04d4ec24 100644 --- a/swarm/src/protocols_handler/node_handler.rs +++ b/swarm/src/protocols_handler/node_handler.rs @@ -113,6 +113,75 @@ where shutdown: Shutdown, } +impl NodeHandlerWrapper +where + TProtoHandler: ProtocolsHandler, +{ + fn poll_inbound(&mut self, cx: &mut Context<'_>) { + // Continue negotiation of newly-opened substreams on the listening side. + // We remove each element from `negotiating_in` one by one and add them back if not ready. 
+ for n in (0..self.negotiating_in.len()).rev() { + let (mut in_progress, mut timeout) = self.negotiating_in.swap_remove(n); + match Future::poll(Pin::new(&mut timeout), cx) { + Poll::Ready(Ok(_)) => { + let err = ProtocolsHandlerUpgrErr::Timeout; + self.handler.inject_listen_upgrade_error(err); + continue + } + Poll::Ready(Err(_)) => { + let err = ProtocolsHandlerUpgrErr::Timer; + self.handler.inject_listen_upgrade_error(err); + continue; + } + Poll::Pending => {}, + } + match Future::poll(Pin::new(&mut in_progress), cx) { + Poll::Ready(Ok(upgrade)) => + self.handler.inject_fully_negotiated_inbound(upgrade), + Poll::Pending => self.negotiating_in.push((in_progress, timeout)), + Poll::Ready(Err(err)) => { + let err = ProtocolsHandlerUpgrErr::Upgrade(err); + self.handler.inject_listen_upgrade_error(err); + } + } + } + + } + + fn poll_outbound(&mut self, cx: &mut Context<'_>) { + // Continue negotiation of newly-opened substreams. + // We remove each element from `negotiating_out` one by one and add them back if not ready. + for n in (0..self.negotiating_out.len()).rev() { + let (upgr_info, mut in_progress, mut timeout) = self.negotiating_out.swap_remove(n); + match Future::poll(Pin::new(&mut timeout), cx) { + Poll::Ready(Ok(_)) => { + let err = ProtocolsHandlerUpgrErr::Timeout; + self.handler.inject_dial_upgrade_error(upgr_info, err); + continue; + }, + Poll::Ready(Err(_)) => { + let err = ProtocolsHandlerUpgrErr::Timer; + self.handler.inject_dial_upgrade_error(upgr_info, err); + continue; + }, + Poll::Pending => {}, + } + match Future::poll(Pin::new(&mut in_progress), cx) { + Poll::Ready(Ok(upgrade)) => { + self.handler.inject_fully_negotiated_outbound(upgrade, upgr_info); + } + Poll::Pending => { + self.negotiating_out.push((upgr_info, in_progress, timeout)); + } + Poll::Ready(Err(err)) => { + let err = ProtocolsHandlerUpgrErr::Upgrade(err); + self.handler.inject_dial_upgrade_error(upgr_info, err); + } + } + } + } +} + /// The options for a planned connection & handler shutdown. /// /// A shutdown is planned anew based on the the return value of @@ -228,64 +297,8 @@ where fn poll(&mut self, cx: &mut Context) -> Poll< Result, Self::Error> > { - // Continue negotiation of newly-opened substreams on the listening side. - // We remove each element from `negotiating_in` one by one and add them back if not ready. - for n in (0..self.negotiating_in.len()).rev() { - let (mut in_progress, mut timeout) = self.negotiating_in.swap_remove(n); - match Future::poll(Pin::new(&mut timeout), cx) { - Poll::Ready(Ok(_)) => { - let err = ProtocolsHandlerUpgrErr::Timeout; - self.handler.inject_listen_upgrade_error(err); - continue - } - Poll::Ready(Err(_)) => { - let err = ProtocolsHandlerUpgrErr::Timer; - self.handler.inject_listen_upgrade_error(err); - continue; - } - Poll::Pending => {}, - } - match Future::poll(Pin::new(&mut in_progress), cx) { - Poll::Ready(Ok(upgrade)) => - self.handler.inject_fully_negotiated_inbound(upgrade), - Poll::Pending => self.negotiating_in.push((in_progress, timeout)), - Poll::Ready(Err(err)) => { - let err = ProtocolsHandlerUpgrErr::Upgrade(err); - self.handler.inject_listen_upgrade_error(err); - } - } - } - - // Continue negotiation of newly-opened substreams. - // We remove each element from `negotiating_out` one by one and add them back if not ready. 
- for n in (0..self.negotiating_out.len()).rev() { - let (upgr_info, mut in_progress, mut timeout) = self.negotiating_out.swap_remove(n); - match Future::poll(Pin::new(&mut timeout), cx) { - Poll::Ready(Ok(_)) => { - let err = ProtocolsHandlerUpgrErr::Timeout; - self.handler.inject_dial_upgrade_error(upgr_info, err); - continue; - }, - Poll::Ready(Err(_)) => { - let err = ProtocolsHandlerUpgrErr::Timer; - self.handler.inject_dial_upgrade_error(upgr_info, err); - continue; - }, - Poll::Pending => {}, - } - match Future::poll(Pin::new(&mut in_progress), cx) { - Poll::Ready(Ok(upgrade)) => { - self.handler.inject_fully_negotiated_outbound(upgrade, upgr_info); - } - Poll::Pending => { - self.negotiating_out.push((upgr_info, in_progress, timeout)); - } - Poll::Ready(Err(err)) => { - let err = ProtocolsHandlerUpgrErr::Upgrade(err); - self.handler.inject_dial_upgrade_error(upgr_info, err); - } - } - } + self.poll_inbound(cx); + self.poll_outbound(cx); // Poll the handler at the end so that we see the consequences of the method // calls on `self.handler`. @@ -328,16 +341,42 @@ where // Check if the connection (and handler) should be shut down. // As long as we're still negotiating substreams, shutdown is always postponed. if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { - match self.shutdown { - Shutdown::None => {}, - Shutdown::Asap => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)), + let close = match self.shutdown { + Shutdown::None => false, + Shutdown::Asap => true, Shutdown::Later(ref mut delay, _) => match Future::poll(Pin::new(delay), cx) { - Poll::Ready(_) => return Poll::Ready(Err(NodeHandlerWrapperError::KeepAliveTimeout)), - Poll::Pending => {} + Poll::Ready(_) => true, + Poll::Pending => false } + }; + if close { + log::debug!("Closing connection due to keep-alive timeout."); + return Poll::Ready(Ok(ConnectionHandlerEvent::Close)) } } Poll::Pending } + + fn poll_close(&mut self, cx: &mut Context) + -> Poll, Self::Error>> + { + // Allow ongoing inbound and outbound substream upgrades to complete. + // New inbound substreams are dropped / rejected and new outbound + // substreams are not permitted as per the return type of `poll_close`. + self.poll_inbound(cx); + self.poll_outbound(cx); + + // If the handler is ready to close and there are no more + // substreams being negotiated, the connection can close. + match self.handler.poll_close(cx).map_err(NodeHandlerWrapperError::Handler) { + Poll::Ready(Ok(None)) => + if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { + return Poll::Ready(Ok(None)) + } else { + return Poll::Pending + } + poll => poll + } + } }
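At the behaviour level, the counterpart to the handler-side shutdown above are the two new `NetworkBehaviourAction` variants added in swarm/src/behaviour.rs. A hypothetical fragment of a `NetworkBehaviour::poll` implementation that makes use of them (the `to_close` and `to_disconnect` queues are illustrative only, not part of this patch):

    // Ask the swarm to gracefully close a single connection, letting its
    // handler drain in-flight work via `poll_close` first.
    if let Some((peer_id, connection_id)) = self.to_close.pop_front() {
        return Poll::Ready(NetworkBehaviourAction::CloseConnection {
            peer_id,
            connection_id,
        });
    }

    // Or drop all connections to a peer immediately, without waiting for
    // a graceful handler shutdown.
    if let Some(peer_id) = self.to_disconnect.pop_front() {
        return Poll::Ready(NetworkBehaviourAction::DisconnectPeer { peer_id });
    }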