Skip to content

Commit

Permalink
fix: validate dht header before dedup cache (#3468)
Browse files Browse the repository at this point in the history
Description
---
- reorders the DHT messaging layers to validate the message before entering the dedup store.
- adds the origin_mac to the dedup hash

This PR was written by @Impala123, I finished off a rust integration test 

Motivation and Context
---

From original PR: #3450 

> With the current order of layers, a malicious node could tamper with a message which would then be discarded by the 
> validation layer. However the dedup cache currently stores this before it is discarded by validate. Thus any un-tampered 
> version of the same message would no longer be processed.

A valid origin mac means the message comes from the possessor of the private key and has not been altered.
The valid origin mac bytes are included in the dedup hash preimage so that the origin of the message (if any) is tied to
the dedup entry. Previously, an attacker could craft a message `A'` that had no/different valid origin MAC but the same body and cause a subsequent message `A` to not to be discarded as a duplicate.

How Has This Been Tested?
---
Rust integration test
memorynet
  • Loading branch information
sdbondi committed Oct 18, 2021
1 parent 6f61582 commit 81f01d2
Show file tree
Hide file tree
Showing 11 changed files with 224 additions and 44 deletions.
Expand Up @@ -688,7 +688,7 @@ where
.outbound_message_service
.closest_broadcast(
NodeId::from_public_key(&self.dest_pubkey),
OutboundEncryption::EncryptFor(Box::new(self.dest_pubkey.clone())),
OutboundEncryption::encrypt_for(self.dest_pubkey.clone()),
vec![],
OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message),
)
Expand Down
Expand Up @@ -215,7 +215,7 @@ async fn send_transaction_finalized_message_store_and_forward(
match outbound_message_service
.closest_broadcast(
NodeId::from_public_key(&destination_pubkey),
OutboundEncryption::EncryptFor(Box::new(destination_pubkey.clone())),
OutboundEncryption::encrypt_for(destination_pubkey.clone()),
vec![],
OutboundDomainMessage::new(TariMessageType::TransactionFinalized, msg.clone()),
)
Expand Down
Expand Up @@ -48,7 +48,7 @@ pub async fn send_transaction_cancelled_message(
let _ = outbound_message_service
.closest_broadcast(
NodeId::from_public_key(&destination_public_key),
OutboundEncryption::EncryptFor(Box::new(destination_public_key)),
OutboundEncryption::encrypt_for(destination_public_key),
vec![],
OutboundDomainMessage::new(TariMessageType::SenderPartialTransaction, proto_message),
)
Expand Down
Expand Up @@ -196,7 +196,7 @@ async fn send_transaction_reply_store_and_forward(
match outbound_message_service
.closest_broadcast(
NodeId::from_public_key(&destination_pubkey),
OutboundEncryption::EncryptFor(Box::new(destination_pubkey.clone())),
OutboundEncryption::encrypt_for(destination_pubkey.clone()),
vec![],
OutboundDomainMessage::new(TariMessageType::ReceiverPartialTransactionReply, msg),
)
Expand Down
4 changes: 2 additions & 2 deletions comms/dht/examples/memory_net/utilities.rs
Expand Up @@ -436,7 +436,7 @@ pub async fn do_store_and_forward_message_propagation(
.outbound_requester()
.closest_broadcast(
node_identity.node_id().clone(),
OutboundEncryption::EncryptFor(Box::new(node_identity.public_key().clone())),
OutboundEncryption::encrypt_for(node_identity.public_key().clone()),
vec![],
OutboundDomainMessage::new(123i32, secret_message.clone()),
)
Expand Down Expand Up @@ -716,7 +716,7 @@ impl TestNode {
loop {
match conn_man_event_sub.recv().await {
Ok(event) => {
events_tx.send(logger(event)).await.unwrap();
let _ = events_tx.send(logger(event)).await;
},
Err(broadcast::error::RecvError::Closed) => break,
Err(err) => log::error!("{}", err),
Expand Down
43 changes: 20 additions & 23 deletions comms/dht/src/dedup/mod.rs
Expand Up @@ -24,24 +24,19 @@ mod dedup_cache;

pub use dedup_cache::DedupCacheDatabase;

use crate::{actor::DhtRequester, inbound::DhtInboundMessage};
use digest::Digest;
use crate::{actor::DhtRequester, inbound::DecryptedDhtMessage};
use futures::{future::BoxFuture, task::Context};
use log::*;
use std::task::Poll;
use tari_comms::{pipeline::PipelineError, types::Challenge};
use tari_comms::pipeline::PipelineError;
use tari_utilities::hex::Hex;
use tower::{layer::Layer, Service, ServiceExt};

const LOG_TARGET: &str = "comms::dht::dedup";

fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
Challenge::new().chain(&message.body).finalize().to_vec()
}

/// # DHT Deduplication middleware
///
/// Takes in a `DhtInboundMessage` and checks the message signature cache for duplicates.
/// Takes in a `DecryptedDhtMessage` and checks the message signature cache for duplicates.
/// If a duplicate message is detected, it is discarded.
#[derive(Clone)]
pub struct DedupMiddleware<S> {
Expand All @@ -60,9 +55,9 @@ impl<S> DedupMiddleware<S> {
}
}

impl<S> Service<DhtInboundMessage> for DedupMiddleware<S>
impl<S> Service<DecryptedDhtMessage> for DedupMiddleware<S>
where
S: Service<DhtInboundMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError> + Clone + Send + 'static,
S::Future: Send,
{
type Error = PipelineError;
Expand All @@ -73,22 +68,21 @@ where
Poll::Ready(Ok(()))
}

fn call(&mut self, mut message: DhtInboundMessage) -> Self::Future {
fn call(&mut self, mut message: DecryptedDhtMessage) -> Self::Future {
let next_service = self.next_service.clone();
let mut dht_requester = self.dht_requester.clone();
let allowed_message_occurrences = self.allowed_message_occurrences;
Box::pin(async move {
let hash = hash_inbound_message(&message);
trace!(
target: LOG_TARGET,
"Inserting message hash {} for message {} (Trace: {})",
hash.to_hex(),
message.hash.to_hex(),
message.tag,
message.dht_header.message_tag
);

message.dedup_hit_count = dht_requester
.add_message_to_dedup_cache(hash, message.source_peer.public_key.clone())
.add_message_to_dedup_cache(message.hash.clone(), message.source_peer.public_key.clone())
.await?;

if message.dedup_hit_count as usize > allowed_message_occurrences {
Expand Down Expand Up @@ -144,6 +138,7 @@ mod test {
envelope::DhtMessageFlags,
test_utils::{create_dht_actor_mock, make_dht_inbound_message, make_node_identity, service_spy},
};
use tari_comms::wrap_in_envelope_body;
use tari_test_utils::panic_context;
use tokio::runtime::Runtime;

Expand All @@ -163,13 +158,14 @@ mod test {

assert!(dedup.poll_ready(&mut cx).is_ready());
let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false);
let inbound_message = make_dht_inbound_message(&node_identity, vec![], DhtMessageFlags::empty(), false, false);
let decrypted_msg = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, inbound_message);

rt.block_on(dedup.call(msg.clone())).unwrap();
rt.block_on(dedup.call(decrypted_msg.clone())).unwrap();
assert_eq!(spy.call_count(), 1);

mock_state.set_number_of_message_hits(4);
rt.block_on(dedup.call(msg)).unwrap();
rt.block_on(dedup.call(decrypted_msg)).unwrap();
assert_eq!(spy.call_count(), 1);
// Drop dedup so that the DhtMock will stop running
drop(dedup);
Expand All @@ -179,28 +175,29 @@ mod test {
fn deterministic_hash() {
const TEST_MSG: &[u8] = b"test123";
const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224";

let node_identity = make_node_identity();
let msg = make_dht_inbound_message(
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash1 = hash_inbound_message(&msg);
let decrypted1 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message);

let node_identity = make_node_identity();
let msg = make_dht_inbound_message(
let dht_message = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash2 = hash_inbound_message(&msg);
let decrypted2 = DecryptedDhtMessage::succeeded(wrap_in_envelope_body!(vec![]), None, dht_message);

assert_eq!(hash1, hash2);
let subjects = &[hash1, hash2];
assert_eq!(decrypted1.hash, decrypted2.hash);
let subjects = &[decrypted1.hash, decrypted2.hash];
assert!(subjects.iter().all(|h| h.to_hex() == EXPECTED_HASH));
}
}
14 changes: 7 additions & 7 deletions comms/dht/src/dht.rs
Expand Up @@ -295,21 +295,21 @@ impl Dht {
ServiceBuilder::new()
.layer(MetricsLayer::new(self.metrics_collector.clone()))
.layer(inbound::DeserializeLayer::new(self.peer_manager.clone()))
.layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
.layer(inbound::DecryptionLayer::new(
self.config.clone(),
self.node_identity.clone(),
self.connectivity.clone(),
))
.layer(DedupLayer::new(
self.dht_requester(),
self.config.dedup_allowed_message_occurrences,
))
.layer(filter::FilterLayer::new(self.unsupported_saf_messages_filter()))
.layer(filter::FilterLayer::new(filter_messages_to_rebroadcast))
.layer(MessageLoggingLayer::new(format!(
"Inbound [{}]",
self.node_identity.node_id().short_str()
)))
.layer(inbound::DecryptionLayer::new(
self.config.clone(),
self.node_identity.clone(),
self.connectivity.clone(),
))
.layer(filter::FilterLayer::new(filter_messages_to_rebroadcast))
.layer(store_forward::StoreLayer::new(
self.config.clone(),
Arc::clone(&self.peer_manager),
Expand Down
14 changes: 13 additions & 1 deletion comms/dht/src/inbound/message.rs
Expand Up @@ -21,6 +21,7 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use crate::envelope::{DhtMessageFlags, DhtMessageHeader};
use digest::Digest;
use std::{
fmt,
fmt::{Display, Formatter},
Expand All @@ -29,9 +30,17 @@ use std::{
use tari_comms::{
message::{EnvelopeBody, MessageTag},
peer_manager::Peer,
types::CommsPublicKey,
types::{Challenge, CommsPublicKey},
};

fn hash_inbound_message(message: &DhtInboundMessage) -> Vec<u8> {
Challenge::new()
.chain(&message.dht_header.origin_mac)
.chain(&message.body)
.finalize()
.to_vec()
}

#[derive(Debug, Clone)]
pub struct DhtInboundMessage {
pub tag: MessageTag,
Expand Down Expand Up @@ -84,6 +93,7 @@ pub struct DecryptedDhtMessage {
pub is_already_forwarded: bool,
pub decryption_result: Result<EnvelopeBody, Vec<u8>>,
pub dedup_hit_count: u32,
pub hash: Vec<u8>,
}

impl DecryptedDhtMessage {
Expand All @@ -104,6 +114,7 @@ impl DecryptedDhtMessage {
message: DhtInboundMessage,
) -> Self {
Self {
hash: hash_inbound_message(&message),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin,
Expand All @@ -118,6 +129,7 @@ impl DecryptedDhtMessage {

pub fn failed(message: DhtInboundMessage) -> Self {
Self {
hash: hash_inbound_message(&message),
tag: message.tag,
source_peer: message.source_peer,
authenticated_origin: None,
Expand Down
4 changes: 4 additions & 0 deletions comms/dht/src/outbound/message.rs
Expand Up @@ -46,6 +46,10 @@ pub enum OutboundEncryption {
}

impl OutboundEncryption {
pub fn encrypt_for(public_key: CommsPublicKey) -> Self {
OutboundEncryption::EncryptFor(Box::new(public_key))
}

/// Return the correct DHT flags for the encryption setting
pub fn flags(&self) -> DhtMessageFlags {
match self {
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/outbound/message_params.rs
Expand Up @@ -40,7 +40,7 @@ use tari_comms::{message::MessageTag, peer_manager::NodeId, types::CommsPublicKe
/// let dest_public_key = CommsPublicKey::default();
/// let params = SendMessageParams::new()
/// .random(5)
/// .with_encryption(OutboundEncryption::EncryptFor(Box::new(dest_public_key)))
/// .with_encryption(OutboundEncryption::encrypt_for(dest_public_key))
/// .finish();
/// ```
#[derive(Debug, Clone)]
Expand Down

0 comments on commit 81f01d2

Please sign in to comment.