Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: make DhtHeader field non-malleable #3243

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
764 changes: 396 additions & 368 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion comms/dht/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ serde_derive = "1.0.90"
serde_repr = "0.1.5"
thiserror = "1.0.20"
tokio = {version="0.2.10", features=["rt-threaded", "blocking"]}
tower= "0.3.1"
tower = {version= "0.4.8", features=["full"]}
ttl_cache = "0.5.1"

# tower-filter dependencies
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

/// Version for DHT envelope
pub const DHT_MAJOR_VERSION: u32 = 0;
pub const DHT_MAJOR_VERSION: u32 = 1;
pub const DHT_MINOR_VERSION: u32 = 0;
18 changes: 15 additions & 3 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ mod test {

assert!(dedup.poll_ready(&mut cx).is_ready());
let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false);
let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false);

rt.block_on(dedup.call(msg.clone())).unwrap();
assert_eq!(spy.call_count(), 1);
Expand All @@ -169,11 +169,23 @@ mod test {
const TEST_MSG: &[u8] = b"test123";
const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224";
let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false);
let msg = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash1 = hash_inbound_message(&msg);

let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false);
let msg = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash2 = hash_inbound_message(&msg);

assert_eq!(hash1, hash2);
Expand Down
5 changes: 5 additions & 0 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ mod test {
DhtMessageFlags::empty(),
false,
MessageTag::new(),
false,
);
let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

Expand Down Expand Up @@ -504,7 +505,9 @@ mod test {
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
false,
);

let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

let msg = {
Expand Down Expand Up @@ -559,6 +562,7 @@ mod test {
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
false,
);

let origin_mac = dht_envelope.header.as_ref().unwrap().origin_mac.clone();
Expand Down Expand Up @@ -611,6 +615,7 @@ mod test {
DhtMessageFlags::empty(),
false,
MessageTag::new(),
false,
);
dht_envelope.header.as_mut().map(|header| {
header.message_type = DhtMessageType::SafStoredMessages as i32;
Expand Down
49 changes: 39 additions & 10 deletions comms/dht/src/inbound/decryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,18 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
) -> Result<DecryptedDhtMessage, DecryptionError> {
let dht_header = &message.dht_header;

let mut header_mac_bytes = Vec::with_capacity(256);
header_mac_bytes.extend_from_slice(&dht_header.major.to_le_bytes());
header_mac_bytes.extend_from_slice(&dht_header.minor.to_le_bytes());
header_mac_bytes.extend_from_slice(dht_header.destination.to_inner_bytes().as_slice());
header_mac_bytes.extend_from_slice(&(dht_header.message_type as i32).to_le_bytes());
header_mac_bytes.extend_from_slice(&dht_header.flags.bits().to_le_bytes());
if let Some(t) = dht_header.expires {
header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes());
}

if !dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) {
return Self::success_not_encrypted(message).await;
return Self::success_not_encrypted(message, header_mac_bytes).await;
}
trace!(
target: LOG_TARGET,
Expand All @@ -214,9 +224,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
// Decrypt and verify the origin
let authenticated_origin = match Self::attempt_decrypt_origin_mac(&shared_secret, dht_header) {
Ok((public_key, signature)) => {
header_mac_bytes.extend_from_slice(e_pk.as_bytes());

// If this fails, discard the message because we decrypted and deserialized the message with our shared
// ECDH secret but the message could not be authenticated
Self::authenticate_origin_mac(&public_key, &signature, &message.body)?;
Self::authenticate_origin_mac(&public_key, &signature, header_mac_bytes.as_slice(), &message.body)?;
public_key
},
Err(err) => {
Expand Down Expand Up @@ -307,9 +319,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
fn authenticate_origin_mac(
public_key: &CommsPublicKey,
signature: &[u8],
mac_header: &[u8],
body: &[u8],
) -> Result<(), DecryptionError> {
if signature::verify(public_key, signature, body) {
let mac_body = [mac_header, body].concat();
if signature::verify(public_key, signature, mac_body) {
Ok(())
} else {
Err(DecryptionError::OriginMacInvalidSignature)
Expand Down Expand Up @@ -350,15 +364,24 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
.map_err(|_| DecryptionError::MessageBodyDecryptionFailed)
}

async fn success_not_encrypted(message: DhtInboundMessage) -> Result<DecryptedDhtMessage, DecryptionError> {
async fn success_not_encrypted(
message: DhtInboundMessage,
header_mac_bytes: Vec<u8>,
) -> Result<DecryptedDhtMessage, DecryptionError> {
let authenticated_pk = if message.dht_header.origin_mac.is_empty() {
None
} else {
let origin_mac = OriginMac::decode(message.dht_header.origin_mac.as_slice())
.map_err(|_| DecryptionError::OriginMacClearTextDecodeFailed)?;
let public_key = CommsPublicKey::from_bytes(&origin_mac.public_key)
.map_err(|_| DecryptionError::OriginMacInvalidPublicKey)?;
Self::authenticate_origin_mac(&public_key, &origin_mac.signature, &message.body)?;

Self::authenticate_origin_mac(
&public_key,
&origin_mac.signature,
header_mac_bytes.as_slice(),
&message.body,
)?;
Some(public_key)
};

Expand Down Expand Up @@ -435,6 +458,7 @@ mod test {
plain_text_msg.to_encoded_bytes(),
DhtMessageFlags::ENCRYPTED,
true,
false,
);

block_on(service.call(inbound_msg)).unwrap();
Expand All @@ -459,8 +483,13 @@ mod test {

let some_secret = b"Super secret message".to_vec();
let some_other_node_identity = make_node_identity();
let inbound_msg =
make_dht_inbound_message(&some_other_node_identity, some_secret, DhtMessageFlags::ENCRYPTED, true);
let inbound_msg = make_dht_inbound_message(
&some_other_node_identity,
some_secret,
DhtMessageFlags::ENCRYPTED,
true,
false,
);

block_on(service.call(inbound_msg.clone())).unwrap();
let decrypted = result.lock().unwrap().take().unwrap();
Expand All @@ -471,6 +500,7 @@ mod test {

#[tokio_macros::test_basic]
async fn decrypt_inbound_fail_destination() {
let _ = env_logger::try_init();
let (connectivity, mock) = create_connectivity_mock();
mock.spawn();
let result = Arc::new(Mutex::new(None));
Expand All @@ -485,9 +515,8 @@ mod test {
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let nonsense = b"Cannot Decrypt this".to_vec();
let mut inbound_msg =
make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true);
inbound_msg.dht_header.destination = node_identity.public_key().clone().into();
let inbound_msg =
make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true, true);

let err = service.call(inbound_msg).await.unwrap_err();
let err = err.downcast::<DecryptionError>().unwrap();
Expand Down
3 changes: 2 additions & 1 deletion comms/dht/src/inbound/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@ mod test {
DhtMessageFlags::empty(),
false,
MessageTag::new(),
false,
);

deserialize
.ready_and()
.ready()
.await
.unwrap()
.call(make_comms_inbound_message(
Expand Down
54 changes: 47 additions & 7 deletions comms/dht/src/outbound/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use super::{error::DhtOutboundError, message::DhtOutboundRequest};
use crate::{
actor::DhtRequester,
broadcast_strategy::BroadcastStrategy,
consts::{DHT_MAJOR_VERSION, DHT_MINOR_VERSION},
crypt,
discovery::DhtDiscoveryRequester,
envelope::{datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination},
envelope::{datetime_to_epochtime, datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination},
outbound::{
message::{DhtOutboundMessage, OutboundEncryption, SendFailure},
message_params::FinalSendMessageParams,
Expand Down Expand Up @@ -57,7 +58,7 @@ use tari_comms::{
};
use tari_crypto::{
keys::PublicKey,
tari_utilities::{message_format::MessageFormat, ByteArray},
tari_utilities::{epoch_time::EpochTime, message_format::MessageFormat, ByteArray},
};
use tari_utilities::hex::Hex;
use tower::{layer::Layer, Service, ServiceExt};
Expand Down Expand Up @@ -413,8 +414,17 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
expires: Option<DateTime<Utc>>,
) -> Result<(Vec<DhtOutboundMessage>, Vec<MessageSendState>), DhtOutboundError> {
let dht_flags = encryption.flags() | extra_flags;
let expires_epochtime = expires.map(datetime_to_epochtime);

let (ephemeral_public_key, origin_mac, body) = self.process_encryption(&encryption, force_origin, body)?;
let (ephemeral_public_key, origin_mac, body) = self.process_encryption(
&encryption,
force_origin,
&destination,
&dht_message_type,
&dht_flags,
expires_epochtime.as_ref(),
body,
)?;

if is_broadcast {
self.add_to_dedup_cache(&body, self.node_identity.public_key().clone())
Expand Down Expand Up @@ -462,10 +472,15 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
.map_err(|_| DhtOutboundError::FailedToInsertMessageHash)
}

#[allow(clippy::too_many_arguments)]
fn process_encryption(
&self,
encryption: &OutboundEncryption,
include_origin: bool,
destination: &NodeDestination,
message_type: &DhtMessageType,
flags: &DhtMessageFlags,
expires: Option<&EpochTime>,
body: Bytes,
) -> Result<FinalMessageParts, DhtOutboundError> {
match encryption {
Expand All @@ -477,8 +492,18 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
// Encrypt the message with the body
let encrypted_body = crypt::encrypt(&shared_ephemeral_secret, &body)?;

let mut header_mac_bytes = Vec::with_capacity(256);
header_mac_bytes.extend_from_slice(&DHT_MAJOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(&DHT_MINOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(destination.to_inner_bytes().as_slice());
header_mac_bytes.extend_from_slice(&(*message_type as i32).to_le_bytes());
header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes());
if let Some(t) = expires {
header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes());
}
header_mac_bytes.extend_from_slice(e_pk.as_bytes());
// Sign the encrypted message
let origin_mac = create_origin_mac(&self.node_identity, &encrypted_body)?;
let origin_mac = create_origin_mac(&self.node_identity, header_mac_bytes.as_slice(), &encrypted_body)?;
// Encrypt and set the origin field
let encrypted_origin_mac = crypt::encrypt(&shared_ephemeral_secret, &origin_mac)?;
Ok((
Expand All @@ -491,7 +516,16 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
trace!(target: LOG_TARGET, "Encryption not requested for message");

if include_origin {
let origin_mac = create_origin_mac(&self.node_identity, &body)?;
let mut header_mac_bytes = Vec::with_capacity(256);
header_mac_bytes.extend_from_slice(&DHT_MAJOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(&DHT_MINOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(destination.to_inner_bytes().as_slice());
header_mac_bytes.extend_from_slice(&(*message_type as i32).to_le_bytes());
header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes());
if let Some(t) = expires {
header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes());
}
let origin_mac = create_origin_mac(&self.node_identity, &header_mac_bytes, &body)?;
Ok((None, Some(origin_mac.into()), body))
} else {
Ok((None, None, body))
Expand All @@ -501,8 +535,14 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
}
}

fn create_origin_mac(node_identity: &NodeIdentity, body: &[u8]) -> Result<Vec<u8>, DhtOutboundError> {
let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), body)?;
fn create_origin_mac(
node_identity: &NodeIdentity,
mac_header: &[u8],
body: &[u8],
) -> Result<Vec<u8>, DhtOutboundError> {
let mac_body = [mac_header, body].concat();

let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), mac_body)?;

let mac = OriginMac {
public_key: node_identity.public_key().to_vec(),
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/outbound/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ mod test {
let msg = create_outbound_message(body);
assert_send_static_service(&serialize);

let service = serialize.ready_and().await.unwrap();
let service = serialize.ready().await.unwrap();
service.call(msg).await.unwrap();
let mut msg = spy.pop_request().unwrap();
let dht_envelope = DhtEnvelope::decode(&mut msg.body).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion comms/dht/src/store_forward/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ mod test {
let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::<PipelineError>());

let node_identity = make_node_identity();
let inbound_msg = make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false);
let inbound_msg =
make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false, false);
let msg = DecryptedDhtMessage::succeeded(
wrap_in_envelope_body!(Vec::new()),
Some(node_identity.public_key().clone()),
Expand All @@ -301,6 +302,7 @@ mod test {
sample_body.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let header = inbound_msg.dht_header.clone();
let msg = DecryptedDhtMessage::failed(inbound_msg);
Expand Down
Loading