diff --git a/base_layer/p2p/src/services/liveness/config.rs b/base_layer/p2p/src/services/liveness/config.rs index 5e90db6528..d70314f30f 100644 --- a/base_layer/p2p/src/services/liveness/config.rs +++ b/base_layer/p2p/src/services/liveness/config.rs @@ -36,6 +36,8 @@ pub struct LivenessConfig { pub num_peers_per_round: usize, /// Peers to include in every auto ping round (Default: ) pub monitored_peers: Vec, + /// Number of ping failures to tolerate before disconnecting the peer. A value of zero disables this feature. + pub max_allowed_ping_failures: usize, } impl Default for LivenessConfig { @@ -46,6 +48,7 @@ impl Default for LivenessConfig { refresh_random_pool_interval: Duration::from_secs(2 * 60 * 60), num_peers_per_round: 8, monitored_peers: Default::default(), + max_allowed_ping_failures: 2, } } } diff --git a/base_layer/p2p/src/services/liveness/error.rs b/base_layer/p2p/src/services/liveness/error.rs index fdea8677c7..f07f8147bf 100644 --- a/base_layer/p2p/src/services/liveness/error.rs +++ b/base_layer/p2p/src/services/liveness/error.rs @@ -20,7 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use tari_comms::{connectivity::ConnectivityError, message::MessageError}; +use tari_comms::{connectivity::ConnectivityError, message::MessageError, PeerConnectionError}; use tari_comms_dht::{outbound::DhtOutboundError, DhtActorError}; use tari_service_framework::reply_channel::TransportChannelError; use thiserror::Error; @@ -31,6 +31,8 @@ pub enum LivenessError { DhtOutboundError(#[from] DhtOutboundError), #[error("Connectivity error: `{0}`")] ConnectivityError(#[from] ConnectivityError), + #[error("Peer connection error: `{0}`")] + PeerConnectionError(#[from] PeerConnectionError), #[error("DHT actor error: `{0}`")] DhtActorError(#[from] DhtActorError), #[error("Failed to send a pong message")] diff --git a/base_layer/p2p/src/services/liveness/service.rs b/base_layer/p2p/src/services/liveness/service.rs index bc1cd3bbd5..167992f128 100644 --- a/base_layer/p2p/src/services/liveness/service.rs +++ b/base_layer/p2p/src/services/liveness/service.rs @@ -122,6 +122,11 @@ where if let Err(err) = self.start_ping_round().await { warn!(target: LOG_TARGET, "Error when pinging peers: {}", err); } + if self.config.max_allowed_ping_failures > 0 { + if let Err(err) = self.disconnect_failed_peers().await { + error!(target: LOG_TARGET, "Error occurred while disconnecting failed peers: {}", err); + } + } }, // Incoming messages from the Comms layer @@ -179,7 +184,7 @@ where return Ok(()); } - let maybe_latency = self.state.record_pong(ping_pong_msg.nonce); + let maybe_latency = self.state.record_pong(ping_pong_msg.nonce, &node_id); debug!( target: LOG_TARGET, "Received pong from peer '{}' with useragent '{}'. {} (Trace: {})", @@ -285,6 +290,26 @@ where Ok(()) } + async fn disconnect_failed_peers(&mut self) -> Result<(), LivenessError> { + let max_allowed_ping_failures = self.config.max_allowed_ping_failures; + for node_id in self + .state + .failed_pings_iter() + .filter(|(_, n)| **n > max_allowed_ping_failures) + .map(|(node_id, _)| node_id) + { + if let Some(mut conn) = self.connectivity.get_connection(node_id.clone()).await? { + debug!( + target: LOG_TARGET, + "Disconnecting peer {} that failed {} rounds of pings", node_id, max_allowed_ping_failures + ); + conn.disconnect().await?; + } + } + self.state.clear_failed_pings(); + Ok(()) + } + fn publish_event(&mut self, event: LivenessEvent) { let _ = self.event_publisher.send(Arc::new(event)).map_err(|_| { trace!( diff --git a/base_layer/p2p/src/services/liveness/state.rs b/base_layer/p2p/src/services/liveness/state.rs index d8b5b29d20..4e89d8e91f 100644 --- a/base_layer/p2p/src/services/liveness/state.rs +++ b/base_layer/p2p/src/services/liveness/state.rs @@ -20,17 +20,18 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use super::LOG_TARGET; use crate::proto::liveness::MetadataKey; -use chrono::{NaiveDateTime, Utc}; +use log::*; use std::{ - collections::{hash_map::RandomState, HashMap}, + collections::HashMap, convert::TryInto, - time::Duration, + time::{Duration, Instant}, }; use tari_comms::peer_manager::NodeId; const LATENCY_SAMPLE_WINDOW_SIZE: usize = 25; -const MAX_INFLIGHT_TTL: Duration = Duration::from_secs(20); +const MAX_INFLIGHT_TTL: Duration = Duration::from_secs(40); /// Represents metadata in a ping/pong message. #[derive(Clone, Debug, Default, PartialEq, Eq)] @@ -62,7 +63,7 @@ impl From>> for Metadata { } } -impl From for HashMap, RandomState> { +impl From for HashMap> { fn from(metadata: Metadata) -> Self { metadata.inner } @@ -71,8 +72,9 @@ impl From for HashMap, RandomState> { /// State for the LivenessService. #[derive(Default, Debug)] pub struct LivenessState { - inflight_pings: HashMap, + inflight_pings: HashMap, peer_latency: HashMap, + failed_pings: HashMap, pings_received: usize, pongs_received: usize, @@ -133,18 +135,27 @@ impl LivenessState { /// Adds a ping to the inflight ping list, while noting the current time that a ping was sent. pub fn add_inflight_ping(&mut self, nonce: u64, node_id: NodeId) { - let now = Utc::now().naive_utc(); - self.inflight_pings.insert(nonce, (node_id, now)); + self.inflight_pings.insert(nonce, (node_id, Instant::now())); self.clear_stale_inflight_pings(); } - /// Clears inflight ping requests which have not responded + /// Clears inflight ping requests which have not responded and adds them to failed_ping counter fn clear_stale_inflight_pings(&mut self) { - self.inflight_pings = self + let (inflight, expired) = self .inflight_pings .drain() - .filter(|(_, (_, time))| convert_to_std_duration(Utc::now().naive_utc() - *time) <= MAX_INFLIGHT_TTL) - .collect(); + .partition(|(_, (_, time))| time.elapsed() <= MAX_INFLIGHT_TTL); + + self.inflight_pings = inflight; + + for (_, (node_id, _)) in expired { + self.failed_pings + .entry(node_id) + .and_modify(|v| { + *v = *v + 1; + }) + .or_insert(1); + } } /// Returns true if the nonce is inflight, otherwise false @@ -153,19 +164,25 @@ impl LivenessState { } /// Records a pong. Specifically, the pong counter is incremented and - /// a latency sample is added and calculated. - pub fn record_pong(&mut self, nonce: u64) -> Option { + /// a latency sample is added and calculated. The given `peer` must match the recorded peer + pub fn record_pong(&mut self, nonce: u64, sent_by: &NodeId) -> Option { self.inc_pongs_received(); - - match self.inflight_pings.remove_entry(&nonce) { - Some((_, (node_id, sent_time))) => { - let now = Utc::now().naive_utc(); - let latency = self - .add_latency_sample(node_id, convert_to_std_duration(now - sent_time)) - .calc_average(); - Some(latency) - }, - None => None, + self.failed_pings.remove_entry(&sent_by); + + let (node_id, _) = self.inflight_pings.get(&nonce)?; + if node_id == sent_by { + self.inflight_pings + .remove(&nonce) + .map(|(node_id, sent_time)| self.add_latency_sample(node_id, sent_time.elapsed()).calc_average()) + } else { + warn!( + target: LOG_TARGET, + "Peer {} sent an nonce for another peer {}. This could indicate malicious behaviour or a bug. \ + Ignoring.", + sent_by, + node_id + ); + None } } @@ -195,11 +212,14 @@ impl LivenessState { // num_peers in map will always be > 0 .map(|latency| latency / num_peers as u32) } -} -/// Convert `chrono::Duration` to `std::time::Duration` -pub(super) fn convert_to_std_duration(old_duration: chrono::Duration) -> Duration { - Duration::from_millis(old_duration.num_milliseconds() as u64) + pub fn failed_pings_iter(&self) -> impl Iterator { + self.failed_pings.iter() + } + + pub fn clear_failed_pings(&mut self) { + self.failed_pings.clear(); + } } /// A very simple implementation for calculating average latency. Samples are added in milliseconds and the mean average @@ -299,9 +319,9 @@ mod test { let mut state = LivenessState::new(); let node_id = NodeId::default(); - state.add_inflight_ping(123, node_id); + state.add_inflight_ping(123, node_id.clone()); - let latency = state.record_pong(123).unwrap(); + let latency = state.record_pong(123, &node_id).unwrap(); assert!(latency < 50); } @@ -311,4 +331,35 @@ mod test { state.set_metadata_entry(MetadataKey::ChainMetadata, b"dummy-data".to_vec()); assert_eq!(state.metadata().get(MetadataKey::ChainMetadata).unwrap(), b"dummy-data"); } + + #[test] + fn clear_stale_inflight_pings() { + let mut state = LivenessState::new(); + + let peer1 = NodeId::default(); + state.add_inflight_ping(1, peer1.clone()); + let peer2 = NodeId::from_public_key(&Default::default()); + state.add_inflight_ping(2, peer2.clone()); + state.add_inflight_ping(3, peer2.clone()); + + assert!(state.failed_pings.get(&peer1).is_none()); + assert!(state.failed_pings.get(&peer2).is_none()); + + // MAX_INFLIGHT_TTL passes + for n in [1, 2, 3] { + let (_, time) = state.inflight_pings.get_mut(&n).unwrap(); + *time = Instant::now() - (MAX_INFLIGHT_TTL + Duration::from_secs(1)); + } + + state.clear_stale_inflight_pings(); + let n = state.failed_pings.get(&peer1).unwrap(); + assert_eq!(*n, 1); + let n = state.failed_pings.get(&peer2).unwrap(); + assert_eq!(*n, 2); + + assert!(state.record_pong(2, &peer2).is_none()); + let n = state.failed_pings.get(&peer1).unwrap(); + assert_eq!(*n, 1); + assert!(state.failed_pings.get(&peer2).is_none()); + } }