From b60fd69c16296ff1f0d8952757a91e8c1f58617e Mon Sep 17 00:00:00 2001 From: Hansie Odendaal Date: Mon, 22 Apr 2024 12:54:02 +0200 Subject: [PATCH] Limit wallet peer connections Added functionality to limit the number of base node peer connections that a wallet can have, based on a config setting. The furthest nodes will be disconnected. --- .../src/grpc/wallet_grpc_server.rs | 2 +- base_layer/chat_ffi/src/byte_vector.rs | 2 +- .../src/chain_storage/blockchain_database.rs | 2 +- .../core/src/chain_storage/lmdb_db/lmdb_db.rs | 4 +- .../priority/prioritized_transaction.rs | 2 +- .../core/src/transactions/crypto_factories.rs | 2 +- .../transaction_components/encrypted_data.rs | 2 +- .../transaction_protocol/sender.rs | 2 +- base_layer/mmr/src/backend.rs | 2 +- .../mmr/src/sparse_merkle_tree/proofs.rs | 2 +- base_layer/p2p/src/initialization.rs | 5 + base_layer/p2p/src/transport.rs | 2 +- .../src/output_manager_service/service.rs | 4 +- base_layer/wallet_ffi/src/lib.rs | 29 ++-- common/config/presets/c_base_node_c.toml | 6 +- common/config/presets/d_console_wallet.toml | 6 +- comms/core/Cargo.toml | 1 + comms/core/src/builder/mod.rs | 14 ++ comms/core/src/connection_manager/dialer.rs | 2 +- comms/core/src/connectivity/config.rs | 4 + comms/core/src/connectivity/manager.rs | 56 ++++++++ comms/core/src/peer_manager/peer_query.rs | 20 +-- comms/dht/src/config.rs | 4 + comms/dht/src/connectivity/mod.rs | 136 ++++++++++++------ comms/dht/src/connectivity/test.rs | 8 +- 25 files changed, 234 insertions(+), 85 deletions(-) diff --git a/applications/minotari_console_wallet/src/grpc/wallet_grpc_server.rs b/applications/minotari_console_wallet/src/grpc/wallet_grpc_server.rs index f74f793664..563bb004e7 100644 --- a/applications/minotari_console_wallet/src/grpc/wallet_grpc_server.rs +++ b/applications/minotari_console_wallet/src/grpc/wallet_grpc_server.rs @@ -160,7 +160,7 @@ impl WalletGrpcServer { fn get_consensus_constants(&self) -> Result<&ConsensusConstants,
WalletStorageError> { // If we don't have the chain metadata, we hope that VNReg consensus constants did not change - worst case, we - // spend more than we need to or the the transaction is rejected. + // spend more than we need to or the transaction is rejected. let height = self .wallet .db diff --git a/base_layer/chat_ffi/src/byte_vector.rs b/base_layer/chat_ffi/src/byte_vector.rs index 233840c66d..cc666adbb5 100644 --- a/base_layer/chat_ffi/src/byte_vector.rs +++ b/base_layer/chat_ffi/src/byte_vector.rs @@ -100,7 +100,7 @@ pub unsafe extern "C" fn chat_byte_vector_destroy(bytes: *mut ChatByteVector) { /// /// # Safety /// None -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn chat_byte_vector_get_at( diff --git a/base_layer/core/src/chain_storage/blockchain_database.rs b/base_layer/core/src/chain_storage/blockchain_database.rs index e81abdb26f..0c88b14930 100644 --- a/base_layer/core/src/chain_storage/blockchain_database.rs +++ b/base_layer/core/src/chain_storage/blockchain_database.rs @@ -2405,7 +2405,7 @@ fn get_previous_timestamps( Ok(timestamps) } -/// Gets all blocks ordered from the the block that connects (via prev_hash) to the main chain, to the orphan tip. +/// Gets all blocks ordered from the block that connects (via prev_hash) to the main chain, to the orphan tip. 
#[allow(clippy::ptr_arg)] fn get_orphan_link_main_chain( db: &T, diff --git a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs index ed5a941c7e..05a2cdcafb 100644 --- a/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs +++ b/base_layer/core/src/chain_storage/lmdb_db/lmdb_db.rs @@ -2554,9 +2554,9 @@ impl BlockchainBackend for LMDBDatabase { } trace!( target: LOG_TARGET, - "Finished calculating new smt (size: {}), took: #{}s", + "Finished calculating new smt (size: {}), took: {:.2?}", smt.size(), - start.elapsed().as_millis() + start.elapsed() ); Ok(smt) } diff --git a/base_layer/core/src/mempool/priority/prioritized_transaction.rs b/base_layer/core/src/mempool/priority/prioritized_transaction.rs index 0c78db88b7..e3fb1664eb 100644 --- a/base_layer/core/src/mempool/priority/prioritized_transaction.rs +++ b/base_layer/core/src/mempool/priority/prioritized_transaction.rs @@ -35,7 +35,7 @@ use crate::transactions::{ }; /// Create a unique unspent transaction priority based on the transaction fee, maturity of the oldest input UTXO and the -/// excess_sig. The excess_sig is included to ensure the the priority key unique so it can be used with a BTreeMap. +/// excess_sig. The excess_sig is included to ensure the priority key unique so it can be used with a BTreeMap. /// Normally, duplicate keys will be overwritten in a BTreeMap. 
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone)] pub struct FeePriority(Vec); diff --git a/base_layer/core/src/transactions/crypto_factories.rs b/base_layer/core/src/transactions/crypto_factories.rs index c2cc0f9841..7b1e560f50 100644 --- a/base_layer/core/src/transactions/crypto_factories.rs +++ b/base_layer/core/src/transactions/crypto_factories.rs @@ -31,7 +31,7 @@ impl CryptoFactories { /// /// ## Parameters /// - /// * `max_proof_range`: Sets the the maximum value in range proofs, where `max = 2^max_proof_range` + /// * `max_proof_range`: Sets the maximum value in range proofs, where `max = 2^max_proof_range` pub fn new(max_proof_range: usize) -> Self { Self { commitment: Arc::new(CommitmentFactory::default()), diff --git a/base_layer/core/src/transactions/transaction_components/encrypted_data.rs b/base_layer/core/src/transactions/transaction_components/encrypted_data.rs index 273fd14740..f3df4a9ef2 100644 --- a/base_layer/core/src/transactions/transaction_components/encrypted_data.rs +++ b/base_layer/core/src/transactions/transaction_components/encrypted_data.rs @@ -23,7 +23,7 @@ // Portions of this file were originally copyrighted (c) 2018 The Grin Developers, issued under the Apache License, // Version 2.0, available at http://www.apache.org/licenses/LICENSE-2.0. -//! Encrypted data using the the extended-nonce variant XChaCha20-Poly1305 encryption with secure random nonce. +//! Encrypted data using the extended-nonce variant XChaCha20-Poly1305 encryption with secure random nonce. 
use std::mem::size_of; diff --git a/base_layer/core/src/transactions/transaction_protocol/sender.rs b/base_layer/core/src/transactions/transaction_protocol/sender.rs index 2e6ae88294..01a15de043 100644 --- a/base_layer/core/src/transactions/transaction_protocol/sender.rs +++ b/base_layer/core/src/transactions/transaction_protocol/sender.rs @@ -472,7 +472,7 @@ impl SenderTransactionProtocol { Ok((public_nonce, public_excess)) } - /// Add partial signatures, add the the recipient info to sender state and move to the Finalizing state + /// Add partial signatures, add the recipient info to sender state and move to the Finalizing state pub async fn add_single_recipient_info( &mut self, mut rec: RecipientSignedMessage, diff --git a/base_layer/mmr/src/backend.rs b/base_layer/mmr/src/backend.rs index 69235daf01..fe5943420b 100644 --- a/base_layer/mmr/src/backend.rs +++ b/base_layer/mmr/src/backend.rs @@ -41,7 +41,7 @@ pub trait ArrayLike { /// Return the item at the given index fn get(&self, index: usize) -> Result, Self::Error>; - /// Remove all stored items from the the backend. + /// Remove all stored items from the backend. fn clear(&mut self) -> Result<(), Self::Error>; /// Finds the index of the specified stored item, it will return None if the object could not be found. diff --git a/base_layer/mmr/src/sparse_merkle_tree/proofs.rs b/base_layer/mmr/src/sparse_merkle_tree/proofs.rs index cf7b3405fe..0cd10a9c04 100644 --- a/base_layer/mmr/src/sparse_merkle_tree/proofs.rs +++ b/base_layer/mmr/src/sparse_merkle_tree/proofs.rs @@ -98,7 +98,7 @@ pub struct InclusionProof { /// ``` pub struct ExclusionProof { siblings: Vec, - // The terminal node of the tree proof, or `None` if the the node is `Empty`. + // The terminal node of the tree proof, or `None` if the node is `Empty`. 
leaf: Option>, phantom: std::marker::PhantomData, } diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 4d3ba863d1..1aac56609a 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -559,6 +559,11 @@ impl ServiceInitializer for P2pInitializer { network_byte: self.network.as_byte(), user_agent: config.user_agent.clone(), }) + .with_minimize_connections(if self.config.dht.minimize_connections { + Some(self.config.dht.num_neighbouring_nodes + self.config.dht.num_random_nodes) + } else { + None + }) .set_self_liveness_check(config.listener_self_liveness_check_interval); if config.allow_test_addresses || config.dht.peer_validator_config.allow_test_addresses { diff --git a/base_layer/p2p/src/transport.rs b/base_layer/p2p/src/transport.rs index a220fa9d0e..939a96329e 100644 --- a/base_layer/p2p/src/transport.rs +++ b/base_layer/p2p/src/transport.rs @@ -147,7 +147,7 @@ pub struct TorTransportConfig { /// When set to true, outbound TCP connections bypass the tor proxy. Defaults to false for better privacy, setting /// to true may improve network performance for TCP nodes. pub proxy_bypass_for_outbound_tcp: bool, - /// If set, instructs tor to forward traffic the the provided address. Otherwise, an OS-assigned port on 127.0.0.1 + /// If set, instructs tor to forward traffic the provided address. Otherwise, an OS-assigned port on 127.0.0.1 /// is used. pub forward_address: Option, /// If set, the listener will bind to this address instead of the forward_address. 
diff --git a/base_layer/wallet/src/output_manager_service/service.rs b/base_layer/wallet/src/output_manager_service/service.rs index 98ec0f2e20..a24deb7df6 100644 --- a/base_layer/wallet/src/output_manager_service/service.rs +++ b/base_layer/wallet/src/output_manager_service/service.rs @@ -1293,7 +1293,7 @@ where let uo_len = uo.len(); trace!( target: LOG_TARGET, - "select_utxos profile - fetch_unspent_outputs_for_spending: {} outputs, {} ms (at {})", + "select_utxos profile - fetch_unspent_outputs_for_spending: {} outputs, {} ms (at {} ms)", uo_len, start_new.elapsed().as_millis(), start.elapsed().as_millis(), @@ -1362,7 +1362,7 @@ where let enough_spendable = utxos_total_value > amount + fee_with_change; trace!( target: LOG_TARGET, - "select_utxos profile - final_selection: {} outputs from {}, {} ms (at {})", + "select_utxos profile - final_selection: {} outputs from {}, {} ms (at {} ms)", utxos.len(), uo_len, start_new.elapsed().as_millis(), diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index ec5050a0ca..f3bed75ba2 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -27,7 +27,7 @@ //! becoming a `CompletedTransaction` with the `Completed` status. This means that the transaction has been //! negotiated between the parties and is now ready to be broadcast to the Base Layer. The funds are still encumbered //! as pending because the transaction has not been mined yet. -//! 3. The finalized `CompletedTransaction` will be sent back to the the receiver so that they have a copy. +//! 3. The finalized `CompletedTransaction` will be sent back to the receiver so that they have a copy. //! 4. The wallet will broadcast the `CompletedTransaction` to a Base Node to be added to the mempool. Its status will //! move from `Completed` to `Broadcast`. //! 5. Wait until the transaction is mined. 
The `CompleteTransaction` status will then move from `Broadcast` to `Mined` @@ -131,7 +131,13 @@ use tari_comms::{ transports::MemoryTransport, types::CommsPublicKey, }; -use tari_comms_dht::{store_forward::SafConfig, DbConnectionUrl, DhtConfig, NetworkDiscoveryConfig}; +use tari_comms_dht::{ + store_forward::SafConfig, + DbConnectionUrl, + DhtConfig, + DhtConnectivityConfig, + NetworkDiscoveryConfig, +}; use tari_contacts::contacts_service::{handle::ContactsServiceHandle, types::Contact}; use tari_core::{ borsh::FromBytes, @@ -818,7 +824,7 @@ pub unsafe extern "C" fn byte_vector_destroy(bytes: *mut ByteVector) { /// /// # Safety /// None -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn byte_vector_get_at(ptr: *mut ByteVector, position: c_uint, error_out: *mut c_int) -> c_uchar { @@ -1778,7 +1784,7 @@ pub unsafe extern "C" fn unblinded_outputs_get_length( /// /// # Safety /// The ```contact_destroy``` method must be called when finished with a TariContact to prevent a memory leak -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn unblinded_outputs_get_at( @@ -2884,7 +2890,7 @@ pub unsafe extern "C" fn contacts_get_length(contacts: *mut TariContacts, error_ /// /// # Safety /// The ```contact_destroy``` method must be called when finished with a TariContact to prevent a memory leak -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn contacts_get_at( @@ -3185,7 +3191,7 @@ pub unsafe extern "C" fn completed_transactions_get_length( 
/// # Safety /// The ```completed_transaction_destroy``` method must be called when finished with a TariCompletedTransaction to /// prevent a memory leak -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn completed_transactions_get_at( @@ -3278,7 +3284,7 @@ pub unsafe extern "C" fn pending_outbound_transactions_get_length( /// # Safety /// The ```pending_outbound_transaction_destroy``` method must be called when finished with a /// TariPendingOutboundTransaction to prevent a memory leak -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn pending_outbound_transactions_get_at( @@ -3370,7 +3376,7 @@ pub unsafe extern "C" fn pending_inbound_transactions_get_length( /// # Safety /// The ```pending_inbound_transaction_destroy``` method must be called when finished with a /// TariPendingOutboundTransaction to prevent a memory leak -// converting between here is fine as its used to clamp the the array to length +// converting between here is fine as its used to clamp the array to length #[allow(clippy::cast_possible_wrap)] #[no_mangle] pub unsafe extern "C" fn pending_inbound_transactions_get_at( @@ -4851,6 +4857,9 @@ pub unsafe extern "C" fn comms_config_create( max_concurrent_inbound_tasks: 25, max_concurrent_outbound_tasks: 50, dht: DhtConfig { + num_neighbouring_nodes: 6, + num_random_nodes: 2, + minimize_connections: true, discovery_request_timeout: Duration::from_secs(discovery_timeout_in_secs), database_url: DbConnectionUrl::File(dht_database_path), auto_join: true, @@ -4864,6 +4873,10 @@ pub unsafe extern "C" fn comms_config_create( initial_peer_sync_delay: Some(Duration::from_secs(25)), ..Default::default() 
}, + connectivity: DhtConnectivityConfig { + update_interval: Duration::from_secs(180), + ..Default::default() + }, ..Default::default() }, allow_test_addresses: true, diff --git a/common/config/presets/c_base_node_c.toml b/common/config/presets/c_base_node_c.toml index f0754403f0..66a7f89343 100644 --- a/common/config/presets/c_base_node_c.toml +++ b/common/config/presets/c_base_node_c.toml @@ -192,7 +192,7 @@ listener_self_liveness_check_interval = 15 # When using the tor transport and set to true, outbound TCP connections bypass the tor proxy. Defaults to false for # better privacy #tor.proxy_bypass_for_outbound_tcp = false -# If set, instructs tor to forward traffic the the provided address. (e.g. "/dns4/my-base-node/tcp/32123") (default = OS-assigned port) +# If set, instructs tor to forward traffic the provided address. (e.g. "/dns4/my-base-node/tcp/32123") (default = OS-assigned port) #tor.forward_address = # If set, the listener will bind to this address instead of the forward_address. You need to make sure that this listener is connectable from the forward_address. #tor.listener_address_override = @@ -216,7 +216,9 @@ database_url = "data/base_node/dht.db" # The maximum number of peer nodes that a message has to be closer to, to be considered a neighbour. Default: 8 #num_neighbouring_nodes = 8 # Number of random peers to include. Default: 4 -#num_random_nodes= 4 +#num_random_nodes = 4 +# Connections above the configured number of neighbouring and random nodes will be removed (default: false) +#minimize_connections = false # Send to this many peers when using the broadcast strategy. Default: 8 #broadcast_factor = 8 # Send to this many peers when using the propagate strategy. 
Default: 4 diff --git a/common/config/presets/d_console_wallet.toml b/common/config/presets/d_console_wallet.toml index 7d076434bf..7f91bee5b1 100644 --- a/common/config/presets/d_console_wallet.toml +++ b/common/config/presets/d_console_wallet.toml @@ -242,7 +242,7 @@ event_channel_size = 3500 # When using the tor transport and set to true, outbound TCP connections bypass the tor proxy. Defaults to false for # better privacy #tor.proxy_bypass_for_outbound_tcp = false -# If set, instructs tor to forward traffic the the provided address. (e.g. "/ip4/127.0.0.1/tcp/0") (default = ) +# If set, instructs tor to forward traffic the provided address. (e.g. "/ip4/127.0.0.1/tcp/0") (default = ) #tor.forward_address = # Use a SOCKS5 proxy transport. This transport recognises any addresses supported by the proxy. @@ -264,7 +264,9 @@ database_url = "data/wallet/dht.db" # The maximum number of peer nodes that a message has to be closer to, to be considered a neighbour. Default: 8 #num_neighbouring_nodes = 8 # Number of random peers to include. Default: 4 -#num_random_nodes= 4 +#num_random_nodes = 4 +# Connections above the configured number of neighbouring and random nodes will be removed (default: false) +minimize_connections = true # Send to this many peers when using the broadcast strategy. Default: 8 #broadcast_factor = 8 # Send to this many peers when using the propagate strategy. 
Default: 4 diff --git a/comms/core/Cargo.toml b/comms/core/Cargo.toml index 14e6a0aa47..95f1b035bc 100644 --- a/comms/core/Cargo.toml +++ b/comms/core/Cargo.toml @@ -15,6 +15,7 @@ tari_metrics = { path = "../../infrastructure/metrics", optional = true, version tari_storage = { path = "../../infrastructure/storage", version = "1.0.0-pre.12" } tari_shutdown = { path = "../../infrastructure/shutdown" , version = "1.0.0-pre.12"} tari_utilities = { version = "0.7" } +tari_common = { path = "../../common", version = "1.0.0-pre.11a" } anyhow = "1.0.53" async-trait = "0.1.36" diff --git a/comms/core/src/builder/mod.rs b/comms/core/src/builder/mod.rs index 26f4cc503a..43b78874b0 100644 --- a/comms/core/src/builder/mod.rs +++ b/comms/core/src/builder/mod.rs @@ -70,6 +70,7 @@ use crate::{ /// # #[tokio::main] /// # async fn main() { /// use std::env::temp_dir; +/// use tari_comms::connectivity::ConnectivityConfig; /// /// use tari_storage::{ /// lmdb_store::{LMDBBuilder, LMDBConfig}, @@ -126,6 +127,7 @@ pub struct CommsBuilder { connection_manager_config: ConnectionManagerConfig, connectivity_config: ConnectivityConfig, shutdown_signal: Option, + maintain_n_closest_connections_only: Option, } impl Default for CommsBuilder { @@ -139,6 +141,7 @@ impl Default for CommsBuilder { connection_manager_config: ConnectionManagerConfig::default(), connectivity_config: ConnectivityConfig::default(), shutdown_signal: None, + maintain_n_closest_connections_only: None, } } } @@ -292,6 +295,17 @@ impl CommsBuilder { self } + /// The closest number of peer connections to maintain; connections above the threshold will be removed + pub fn with_minimize_connections(mut self, connections: Option) -> Self { + self.maintain_n_closest_connections_only = connections; + self.connectivity_config.maintain_n_closest_connections_only = connections; + if let Some(val) = connections { + self.connectivity_config.reaper_min_connection_threshold = val; + } + 
self.connectivity_config.connection_pool_refresh_interval = Duration::from_secs(180); + self + } + fn make_peer_manager(&mut self) -> Result, CommsBuilderError> { let file_lock = self.peer_storage_file_lock.take(); diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 8226eab9d5..e502b9bd4e 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -598,7 +598,7 @@ where let noise_upgrade_time = timer.elapsed(); debug!( - "Dial - upgraded noise: {} on address: {} on tcp after: {}", + "Dial - upgraded noise: {} on address: {} on tcp after: {} ms", node_id.short_str(), moved_address, timer.elapsed().as_millis() diff --git a/comms/core/src/connectivity/config.rs b/comms/core/src/connectivity/config.rs index 2ebc47fe91..02a65b3c7d 100644 --- a/comms/core/src/connectivity/config.rs +++ b/comms/core/src/connectivity/config.rs @@ -49,6 +49,9 @@ pub struct ConnectivityConfig { /// next connection attempt. 
/// Default: 24 hours pub expire_peer_last_seen_duration: Duration, + /// The closest number of peer connections to maintain; connections above the threshold will be removed + /// (default: disabled) + pub maintain_n_closest_connections_only: Option, } impl Default for ConnectivityConfig { @@ -62,6 +65,7 @@ impl Default for ConnectivityConfig { max_failures_mark_offline: 1, connection_tie_break_linger: Duration::from_secs(2), expire_peer_last_seen_duration: Duration::from_secs(24 * 60 * 60), + maintain_n_closest_connections_only: None, } } } diff --git a/comms/core/src/connectivity/manager.rs b/comms/core/src/connectivity/manager.rs index 1e9b7d18e3..ed61e8ad8f 100644 --- a/comms/core/src/connectivity/manager.rs +++ b/comms/core/src/connectivity/manager.rs @@ -389,11 +389,67 @@ impl ConnectivityManagerActor { if self.config.is_connection_reaping_enabled { self.reap_inactive_connections().await; } + if let Some(threshold) = self.config.maintain_n_closest_connections_only { + self.maintain_n_closest_peer_connections_only(threshold).await; + } self.update_connectivity_status(); self.update_connectivity_metrics(); Ok(()) } + async fn maintain_n_closest_peer_connections_only(&mut self, threshold: usize) { + // Select all active peer connections + let mut connections = match self + .select_connections(ConnectivitySelection::closest_to( + self.node_identity.node_id().clone(), + self.pool.count_connected_nodes(), + vec![], + )) + .await + { + Ok(peers) => peers, + Err(e) => { + warn!( + target: LOG_TARGET, + "Connectivity error trying to maintain {} closest peers ({:?})", + threshold, + e + ); + return; + }, + }; + + // Remove peers that are on the allow list + let mut nodes_to_ignore = vec![]; + for conn in &connections { + if self.allow_list.contains(conn.peer_node_id()) { + nodes_to_ignore.push(conn.peer_node_id().clone()); + } + } + connections.retain(|conn| !nodes_to_ignore.contains(conn.peer_node_id())); + // Remove peers that are not communication nodes + 
connections.retain(|conn| conn.peer_features().is_node()); + + // Disconnect all remaining peers above the threshold + for conn in connections.iter_mut().skip(threshold) { + debug!( + target: LOG_TARGET, + "Disconnecting '{}' because the node is not among the {} closest peers", + conn.peer_node_id(), + threshold + ); + if let Err(err) = conn.disconnect().await { + // Already disconnected + debug!( + target: LOG_TARGET, + "Peer '{}' already disconnected. Error: {:?}", + conn.peer_node_id().short_str(), + err + ); + } + } + } + async fn reap_inactive_connections(&mut self) { let excess_connections = self .pool diff --git a/comms/core/src/peer_manager/peer_query.rs b/comms/core/src/peer_manager/peer_query.rs index 71f56b3fd7..fe049e7e37 100644 --- a/comms/core/src/peer_manager/peer_query.rs +++ b/comms/core/src/peer_manager/peer_query.rs @@ -246,7 +246,7 @@ mod test { #[test] fn limit_query() { - // Create 20 peers were the 1st and last one is bad + // Create some good peers let db = HashmapDatabase::new(); let mut id_counter = 0; @@ -262,11 +262,7 @@ mod test { #[test] fn select_where_query() { - // Create peer manager with random peers - let mut sample_peers = Vec::new(); - // Create 20 peers were the 1st and last one is bad - let _rng = rand::rngs::OsRng; - sample_peers.push(create_test_peer(true)); + // Create some good and bad peers let db = HashmapDatabase::new(); let mut id_counter = 0; @@ -292,11 +288,7 @@ mod test { #[test] fn select_where_limit_query() { - // Create peer manager with random peers - let mut sample_peers = Vec::new(); - // Create 20 peers were the 1st and last one is bad - let _rng = rand::rngs::OsRng; - sample_peers.push(create_test_peer(true)); + // Create some good and bad peers let db = HashmapDatabase::new(); let mut id_counter = 0; @@ -333,11 +325,7 @@ mod test { #[test] fn sort_by_query() { - // Create peer manager with random peers - let mut sample_peers = Vec::new(); - // Create 20 peers were the 1st and last one is bad - let _rng = 
rand::rngs::OsRng; - sample_peers.push(create_test_peer(true)); + // Create some good and bad peers let db = HashmapDatabase::new(); let mut id_counter = 0; diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index 4937eaacb4..ec502499bc 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -50,6 +50,9 @@ pub struct DhtConfig { /// Number of random peers to include /// Default: 4 pub num_random_nodes: usize, + /// Connections above the configured number of neighbouring and random nodes will be removed + /// (default: false) + pub minimize_connections: bool, /// Send to this many peers when using the broadcast strategy /// Default: 8 pub broadcast_factor: usize, @@ -169,6 +172,7 @@ impl Default for DhtConfig { protocol_version: DhtProtocolVersion::latest(), num_neighbouring_nodes: 8, num_random_nodes: 4, + minimize_connections: false, propagation_factor: 4, broadcast_factor: 8, outbound_buffer_size: 20, diff --git a/comms/dht/src/connectivity/mod.rs b/comms/dht/src/connectivity/mod.rs index 9094396392..6aceaf04f4 100644 --- a/comms/dht/src/connectivity/mod.rs +++ b/comms/dht/src/connectivity/mod.rs @@ -385,14 +385,9 @@ impl DhtConnectivity { debug!( target: LOG_TARGET, - "Adding {} neighbouring peer(s), removing {} peers", - new_neighbours.len(), - difference.len() - ); - debug!( - target: LOG_TARGET, - "Adding {} peer(s) to DHT connectivity manager: {}", + "Adding {} neighbouring peer(s), removing {} peers: {}", new_neighbours.len(), + difference.len(), new_neighbours .iter() .map(ToString::to_string) @@ -401,7 +396,7 @@ impl DhtConnectivity { ); new_neighbours.iter().cloned().for_each(|peer| { - self.insert_neighbour(peer); + self.insert_neighbour_ordered_by_distance(peer); }); if !new_neighbours.is_empty() { @@ -482,7 +477,9 @@ impl DhtConnectivity { random_peers, difference ); - self.random_pool.extend(random_peers.clone()); + for peer in &random_peers { + self.insert_random_peer_ordered_by_distance(peer.clone()); + } // Drop any 
connection handles that removed from the random pool difference.iter().for_each(|peer| { self.remove_connection_handle(peer); @@ -524,24 +521,55 @@ impl DhtConnectivity { let peer_to_insert = conn.peer_node_id().clone(); self.insert_connection_handle(conn); - if let Some(node_id) = self.insert_neighbour(peer_to_insert.clone()) { - // If we kicked a neighbour out of our neighbour pool but the random pool is not full. - // Add the neighbour to the random pool, otherwise remove the handle from the connection pool - if self.random_pool.len() < self.config.num_random_nodes { - debug!( - target: LOG_TARGET, - "Moving peer '{}' from neighbouring pool to random pool", peer_to_insert - ); - self.random_pool.push(node_id); - } else { - self.remove_connection_handle(&node_id) - } + if let Some(node_id) = self.insert_neighbour_ordered_by_distance(peer_to_insert.clone()) { + // If we kicked a neighbour out of our neighbour pool, add it to the random pool if + // it is not full or if it is closer than the furthest random peer. 
+ debug!( + target: LOG_TARGET, + "Moving peer '{}' from neighbouring pool to random pool if not full or closer", peer_to_insert + ); + self.insert_random_peer_ordered_by_distance(node_id) + } } Ok(()) } + async fn minimize_connections(&mut self) -> Result<(), DhtConnectivityError> { + // Retrieve all communication node peers with an active connection status + let query = PeerQuery::new() + .select_where(|peer| { + self.connection_handles + .iter() + .any(|conn| conn.peer_node_id() == &peer.node_id && peer.features.is_node()) + }) + .sort_by(PeerQuerySortBy::DistanceFrom(self.node_identity.node_id())); + let mut peers_by_distance = self.peer_manager.perform_query(query).await?; + debug!( + target: LOG_TARGET, + "minimize_connections: Filtered peers: {}, Handles: {}", + peers_by_distance.len(), + self.connection_handles.len(), + ); + + // Remove all above threshold connections + let threshold = self.config.num_neighbouring_nodes + self.config.num_random_nodes; + for peer in peers_by_distance.iter_mut().skip(threshold) { + debug!( + target: LOG_TARGET, + "minimize_connections: Disconnecting '{}' because the node is not among the {} closest peers", + peer.node_id, + threshold + ); + // Remove from managed pool if applicable + self.replace_pool_peer(&peer.node_id).await?; + // In case the connection was not managed, remove the connection handle + self.remove_connection_handle(&peer.node_id); + } + + Ok(()) + } + fn insert_connection_handle(&mut self, conn: PeerConnection) { // Remove any existing connection for this peer self.remove_connection_handle(conn.peer_node_id()); @@ -563,6 +591,9 @@ impl DhtConnectivity { match event { PeerConnected(conn) => { self.handle_new_peer_connected(*conn).await?; + if self.config.minimize_connections { + self.minimize_connections().await?; + } }, PeerConnectFailed(node_id) => { self.connection_handles.retain(|c| *c.peer_node_id() != node_id); @@ -587,6 +618,9 @@ impl DhtConnectivity { "Failed to clear metrics for peer `{}`.
Metric collector is shut down.", node_id ); }; + if self.config.minimize_connections { + self.minimize_connections().await?; + } if !self.is_pool_peer(&node_id) { debug!(target: LOG_TARGET, "{} is not managed by the DHT. Ignoring", node_id); return Ok(()); @@ -629,7 +663,8 @@ impl DhtConnectivity { .iter() .position(|n| n == current_peer) .expect("unreachable panic"); - self.random_pool.swap_remove(pos); + self.random_pool.remove(pos); + self.remove_connection_handle(current_peer); debug!( target: LOG_TARGET, @@ -637,11 +672,7 @@ impl DhtConnectivity { ); match self.fetch_random_peers(1, &exclude).await?.pop() { Some(new_peer) => { - self.remove_connection_handle(current_peer); - if let Some(pos) = self.random_pool.iter().position(|n| n == current_peer) { - self.random_pool.swap_remove(pos); - } - self.random_pool.push(new_peer.clone()); + self.insert_random_peer_ordered_by_distance(new_peer.clone()); self.connectivity.request_many_dials([new_peer]).await?; }, None => { @@ -664,19 +695,16 @@ impl DhtConnectivity { .position(|n| n == current_peer) .expect("unreachable panic"); self.neighbours.remove(pos); + self.remove_connection_handle(current_peer); debug!( target: LOG_TARGET, "Peer '{}' in neighbour pool is offline. 
Adding a new peer if possible", current_peer ); match self.fetch_neighbouring_peers(1, &exclude).await?.pop() { - Some(node_id) => { - self.remove_connection_handle(current_peer); - if let Some(pos) = self.neighbours.iter().position(|n| n == current_peer) { - self.neighbours.remove(pos); - } - self.insert_neighbour(node_id.clone()); - self.connectivity.request_many_dials([node_id]).await?; + Some(new_peer) => { + self.insert_neighbour_ordered_by_distance(new_peer.clone()); + self.connectivity.request_many_dials([new_peer]).await?; }, None => { info!( @@ -693,29 +721,48 @@ impl DhtConnectivity { Ok(()) } - fn insert_neighbour(&mut self, node_id: NodeId) -> Option { + fn insert_neighbour_ordered_by_distance(&mut self, node_id: NodeId) -> Option { let dist = node_id.distance(self.node_identity.node_id()); let pos = self .neighbours .iter() .position(|node_id| node_id.distance(self.node_identity.node_id()) > dist); - let removed_peer = if self.neighbours.len() + 1 > self.config.num_neighbouring_nodes { + match pos { + Some(idx) => { + self.neighbours.insert(idx, node_id); + }, + None => { + self.neighbours.push(node_id); + }, + } + + if self.neighbours.len() > self.config.num_neighbouring_nodes { self.neighbours.pop() } else { None - }; + } + } + + fn insert_random_peer_ordered_by_distance(&mut self, node_id: NodeId) { + let dist = node_id.distance(self.node_identity.node_id()); + let pos = self + .random_pool + .iter() + .position(|node_id| node_id.distance(self.node_identity.node_id()) > dist); match pos { Some(idx) => { - self.neighbours.insert(idx, node_id); + self.random_pool.insert(idx, node_id); }, None => { - self.neighbours.push(node_id); + self.random_pool.push(node_id); }, } - removed_peer + if self.random_pool.len() > self.config.num_random_nodes { + self.random_pool.pop(); + } } fn is_pool_peer(&self, node_id: &NodeId) -> bool { @@ -790,6 +837,15 @@ impl DhtConnectivity { return false; } + if self.config.minimize_connections { + // If the peer is not 
closer, return false + let dist = self.node_identity.node_id().distance(&peer.node_id); + let neighbour_distance = self.get_neighbour_max_distance(); + if dist >= neighbour_distance { + return false; + } + } + true }) .sort_by(PeerQuerySortBy::DistanceFrom(node_id)) diff --git a/comms/dht/src/connectivity/test.rs b/comms/dht/src/connectivity/test.rs index 3120aa075b..53a07090d1 100644 --- a/comms/dht/src/connectivity/test.rs +++ b/comms/dht/src/connectivity/test.rs @@ -242,12 +242,16 @@ async fn insert_neighbour() { // First 8 inserts should not remove a peer (because num_neighbouring_nodes == 8) for ni in shuffled.iter().take(8) { - assert!(dht_connectivity.insert_neighbour(ni.node_id().clone()).is_none()); + assert!(dht_connectivity + .insert_neighbour_ordered_by_distance(ni.node_id().clone()) + .is_none()); } // Next 2 inserts will always remove a node id for ni in shuffled.iter().skip(8) { - assert!(dht_connectivity.insert_neighbour(ni.node_id().clone()).is_some()) + assert!(dht_connectivity + .insert_neighbour_ordered_by_distance(ni.node_id().clone()) + .is_some()) } // Check the first 7 node ids match our neighbours, the last element depends on distance and ordering of inserts