Skip to content

Commit

Permalink
feat!: Make DhtHeader field non-malleable
Browse files Browse the repository at this point in the history
### Breaking Change to comms

This PR includes the relevant DhtHeader fields in the commitment of a Dht Message MAC signature so that they cannot be changed while in route.

The header fields included in the commitment are:
- Major version
- Minor version
- Destination
- Message type
- Message flags
- Expiry time (if exists)
- Ephemeral public key (if exists)
  • Loading branch information
philipr-za committed Aug 26, 2021
1 parent 098f25d commit b8cc0a3
Show file tree
Hide file tree
Showing 14 changed files with 690 additions and 409 deletions.
764 changes: 396 additions & 368 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion comms/dht/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ serde_derive = "1.0.90"
serde_repr = "0.1.5"
thiserror = "1.0.20"
tokio = {version="0.2.10", features=["rt-threaded", "blocking"]}
tower= "0.3.1"
tower = {version= "0.4.8", features=["full"]}
ttl_cache = "0.5.1"

# tower-filter dependencies
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

/// Version for DHT envelope
pub const DHT_MAJOR_VERSION: u32 = 0;
pub const DHT_MAJOR_VERSION: u32 = 1;
pub const DHT_MINOR_VERSION: u32 = 0;
18 changes: 15 additions & 3 deletions comms/dht/src/dedup/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ mod test {

assert!(dedup.poll_ready(&mut cx).is_ready());
let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false);
let msg = make_dht_inbound_message(&node_identity, Vec::new(), DhtMessageFlags::empty(), false, false);

rt.block_on(dedup.call(msg.clone())).unwrap();
assert_eq!(spy.call_count(), 1);
Expand All @@ -169,11 +169,23 @@ mod test {
const TEST_MSG: &[u8] = b"test123";
const EXPECTED_HASH: &str = "90cccd774db0ac8c6ea2deff0e26fc52768a827c91c737a2e050668d8c39c224";
let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false);
let msg = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash1 = hash_inbound_message(&msg);

let node_identity = make_node_identity();
let msg = make_dht_inbound_message(&node_identity, TEST_MSG.to_vec(), DhtMessageFlags::empty(), false);
let msg = make_dht_inbound_message(
&node_identity,
TEST_MSG.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let hash2 = hash_inbound_message(&msg);

assert_eq!(hash1, hash2);
Expand Down
5 changes: 5 additions & 0 deletions comms/dht/src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ mod test {
DhtMessageFlags::empty(),
false,
MessageTag::new(),
false,
);
let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

Expand Down Expand Up @@ -504,7 +505,9 @@ mod test {
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
false,
);

let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into());

let msg = {
Expand Down Expand Up @@ -559,6 +562,7 @@ mod test {
DhtMessageFlags::ENCRYPTED,
true,
MessageTag::new(),
false,
);

let origin_mac = dht_envelope.header.as_ref().unwrap().origin_mac.clone();
Expand Down Expand Up @@ -611,6 +615,7 @@ mod test {
DhtMessageFlags::empty(),
false,
MessageTag::new(),
false,
);
dht_envelope.header.as_mut().map(|header| {
header.message_type = DhtMessageType::SafStoredMessages as i32;
Expand Down
49 changes: 39 additions & 10 deletions comms/dht/src/inbound/decryption.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,18 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
) -> Result<DecryptedDhtMessage, DecryptionError> {
let dht_header = &message.dht_header;

let mut header_mac_bytes = Vec::with_capacity(256);
header_mac_bytes.extend_from_slice(&dht_header.major.to_le_bytes());
header_mac_bytes.extend_from_slice(&dht_header.minor.to_le_bytes());
header_mac_bytes.extend_from_slice(dht_header.destination.to_inner_bytes().as_slice());
header_mac_bytes.extend_from_slice(&(dht_header.message_type as i32).to_le_bytes());
header_mac_bytes.extend_from_slice(&dht_header.flags.bits().to_le_bytes());
if let Some(t) = dht_header.expires {
header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes());
}

if !dht_header.flags.contains(DhtMessageFlags::ENCRYPTED) {
return Self::success_not_encrypted(message).await;
return Self::success_not_encrypted(message, header_mac_bytes).await;
}
trace!(
target: LOG_TARGET,
Expand All @@ -214,9 +224,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
// Decrypt and verify the origin
let authenticated_origin = match Self::attempt_decrypt_origin_mac(&shared_secret, dht_header) {
Ok((public_key, signature)) => {
header_mac_bytes.extend_from_slice(e_pk.as_bytes());

// If this fails, discard the message because we decrypted and deserialized the message with our shared
// ECDH secret but the message could not be authenticated
Self::authenticate_origin_mac(&public_key, &signature, &message.body)?;
Self::authenticate_origin_mac(&public_key, &signature, header_mac_bytes.as_slice(), &message.body)?;
public_key
},
Err(err) => {
Expand Down Expand Up @@ -307,9 +319,11 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
fn authenticate_origin_mac(
public_key: &CommsPublicKey,
signature: &[u8],
mac_header: &[u8],
body: &[u8],
) -> Result<(), DecryptionError> {
if signature::verify(public_key, signature, body) {
let mac_body = [mac_header, body].concat();
if signature::verify(public_key, signature, mac_body) {
Ok(())
} else {
Err(DecryptionError::OriginMacInvalidSignature)
Expand Down Expand Up @@ -350,15 +364,24 @@ where S: Service<DecryptedDhtMessage, Response = (), Error = PipelineError>
.map_err(|_| DecryptionError::MessageBodyDecryptionFailed)
}

async fn success_not_encrypted(message: DhtInboundMessage) -> Result<DecryptedDhtMessage, DecryptionError> {
async fn success_not_encrypted(
message: DhtInboundMessage,
header_mac_bytes: Vec<u8>,
) -> Result<DecryptedDhtMessage, DecryptionError> {
let authenticated_pk = if message.dht_header.origin_mac.is_empty() {
None
} else {
let origin_mac = OriginMac::decode(message.dht_header.origin_mac.as_slice())
.map_err(|_| DecryptionError::OriginMacClearTextDecodeFailed)?;
let public_key = CommsPublicKey::from_bytes(&origin_mac.public_key)
.map_err(|_| DecryptionError::OriginMacInvalidPublicKey)?;
Self::authenticate_origin_mac(&public_key, &origin_mac.signature, &message.body)?;

Self::authenticate_origin_mac(
&public_key,
&origin_mac.signature,
header_mac_bytes.as_slice(),
&message.body,
)?;
Some(public_key)
};

Expand Down Expand Up @@ -435,6 +458,7 @@ mod test {
plain_text_msg.to_encoded_bytes(),
DhtMessageFlags::ENCRYPTED,
true,
false,
);

block_on(service.call(inbound_msg)).unwrap();
Expand All @@ -459,8 +483,13 @@ mod test {

let some_secret = b"Super secret message".to_vec();
let some_other_node_identity = make_node_identity();
let inbound_msg =
make_dht_inbound_message(&some_other_node_identity, some_secret, DhtMessageFlags::ENCRYPTED, true);
let inbound_msg = make_dht_inbound_message(
&some_other_node_identity,
some_secret,
DhtMessageFlags::ENCRYPTED,
true,
false,
);

block_on(service.call(inbound_msg.clone())).unwrap();
let decrypted = result.lock().unwrap().take().unwrap();
Expand All @@ -471,6 +500,7 @@ mod test {

#[tokio_macros::test_basic]
async fn decrypt_inbound_fail_destination() {
let _ = env_logger::try_init();
let (connectivity, mock) = create_connectivity_mock();
mock.spawn();
let result = Arc::new(Mutex::new(None));
Expand All @@ -485,9 +515,8 @@ mod test {
let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service);

let nonsense = b"Cannot Decrypt this".to_vec();
let mut inbound_msg =
make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true);
inbound_msg.dht_header.destination = node_identity.public_key().clone().into();
let inbound_msg =
make_dht_inbound_message(&node_identity, nonsense.clone(), DhtMessageFlags::ENCRYPTED, true, true);

let err = service.call(inbound_msg).await.unwrap_err();
let err = err.downcast::<DecryptionError>().unwrap();
Expand Down
3 changes: 2 additions & 1 deletion comms/dht/src/inbound/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,11 @@ mod test {
DhtMessageFlags::empty(),
false,
MessageTag::new(),
false,
);

deserialize
.ready_and()
.ready()
.await
.unwrap()
.call(make_comms_inbound_message(
Expand Down
54 changes: 47 additions & 7 deletions comms/dht/src/outbound/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use super::{error::DhtOutboundError, message::DhtOutboundRequest};
use crate::{
actor::DhtRequester,
broadcast_strategy::BroadcastStrategy,
consts::{DHT_MAJOR_VERSION, DHT_MINOR_VERSION},
crypt,
discovery::DhtDiscoveryRequester,
envelope::{datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination},
envelope::{datetime_to_epochtime, datetime_to_timestamp, DhtMessageFlags, DhtMessageHeader, NodeDestination},
outbound::{
message::{DhtOutboundMessage, OutboundEncryption, SendFailure},
message_params::FinalSendMessageParams,
Expand Down Expand Up @@ -57,7 +58,7 @@ use tari_comms::{
};
use tari_crypto::{
keys::PublicKey,
tari_utilities::{message_format::MessageFormat, ByteArray},
tari_utilities::{epoch_time::EpochTime, message_format::MessageFormat, ByteArray},
};
use tari_utilities::hex::Hex;
use tower::{layer::Layer, Service, ServiceExt};
Expand Down Expand Up @@ -413,8 +414,17 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
expires: Option<DateTime<Utc>>,
) -> Result<(Vec<DhtOutboundMessage>, Vec<MessageSendState>), DhtOutboundError> {
let dht_flags = encryption.flags() | extra_flags;
let expires_epochtime = expires.map(datetime_to_epochtime);

let (ephemeral_public_key, origin_mac, body) = self.process_encryption(&encryption, force_origin, body)?;
let (ephemeral_public_key, origin_mac, body) = self.process_encryption(
&encryption,
force_origin,
&destination,
&dht_message_type,
&dht_flags,
expires_epochtime.as_ref(),
body,
)?;

if is_broadcast {
self.add_to_dedup_cache(&body, self.node_identity.public_key().clone())
Expand Down Expand Up @@ -462,10 +472,15 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
.map_err(|_| DhtOutboundError::FailedToInsertMessageHash)
}

#[allow(clippy::too_many_arguments)]
fn process_encryption(
&self,
encryption: &OutboundEncryption,
include_origin: bool,
destination: &NodeDestination,
message_type: &DhtMessageType,
flags: &DhtMessageFlags,
expires: Option<&EpochTime>,
body: Bytes,
) -> Result<FinalMessageParts, DhtOutboundError> {
match encryption {
Expand All @@ -477,8 +492,18 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
// Encrypt the message with the body
let encrypted_body = crypt::encrypt(&shared_ephemeral_secret, &body)?;

let mut header_mac_bytes = Vec::with_capacity(256);
header_mac_bytes.extend_from_slice(&DHT_MAJOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(&DHT_MINOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(destination.to_inner_bytes().as_slice());
header_mac_bytes.extend_from_slice(&(*message_type as i32).to_le_bytes());
header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes());
if let Some(t) = expires {
header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes());
}
header_mac_bytes.extend_from_slice(e_pk.as_bytes());
// Sign the encrypted message
let origin_mac = create_origin_mac(&self.node_identity, &encrypted_body)?;
let origin_mac = create_origin_mac(&self.node_identity, header_mac_bytes.as_slice(), &encrypted_body)?;
// Encrypt and set the origin field
let encrypted_origin_mac = crypt::encrypt(&shared_ephemeral_secret, &origin_mac)?;
Ok((
Expand All @@ -491,7 +516,16 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
trace!(target: LOG_TARGET, "Encryption not requested for message");

if include_origin {
let origin_mac = create_origin_mac(&self.node_identity, &body)?;
let mut header_mac_bytes = Vec::with_capacity(256);
header_mac_bytes.extend_from_slice(&DHT_MAJOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(&DHT_MINOR_VERSION.to_le_bytes());
header_mac_bytes.extend_from_slice(destination.to_inner_bytes().as_slice());
header_mac_bytes.extend_from_slice(&(*message_type as i32).to_le_bytes());
header_mac_bytes.extend_from_slice(&flags.bits().to_le_bytes());
if let Some(t) = expires {
header_mac_bytes.extend_from_slice(&t.as_u64().to_le_bytes());
}
let origin_mac = create_origin_mac(&self.node_identity, &header_mac_bytes, &body)?;
Ok((None, Some(origin_mac.into()), body))
} else {
Ok((None, None, body))
Expand All @@ -501,8 +535,14 @@ where S: Service<DhtOutboundMessage, Response = (), Error = PipelineError>
}
}

fn create_origin_mac(node_identity: &NodeIdentity, body: &[u8]) -> Result<Vec<u8>, DhtOutboundError> {
let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), body)?;
fn create_origin_mac(
node_identity: &NodeIdentity,
mac_header: &[u8],
body: &[u8],
) -> Result<Vec<u8>, DhtOutboundError> {
let mac_body = [mac_header, body].concat();

let signature = signature::sign(&mut OsRng, node_identity.secret_key().clone(), mac_body)?;

let mac = OriginMac {
public_key: node_identity.public_key().to_vec(),
Expand Down
2 changes: 1 addition & 1 deletion comms/dht/src/outbound/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ mod test {
let msg = create_outbound_message(body);
assert_send_static_service(&serialize);

let service = serialize.ready_and().await.unwrap();
let service = serialize.ready().await.unwrap();
service.call(msg).await.unwrap();
let mut msg = spy.pop_request().unwrap();
let dht_envelope = DhtEnvelope::decode(&mut msg.body).unwrap();
Expand Down
4 changes: 3 additions & 1 deletion comms/dht/src/store_forward/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ mod test {
let mut service = ForwardLayer::new(oms, true).layer(spy.to_service::<PipelineError>());

let node_identity = make_node_identity();
let inbound_msg = make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false);
let inbound_msg =
make_dht_inbound_message(&node_identity, b"".to_vec(), DhtMessageFlags::empty(), false, false);
let msg = DecryptedDhtMessage::succeeded(
wrap_in_envelope_body!(Vec::new()),
Some(node_identity.public_key().clone()),
Expand All @@ -301,6 +302,7 @@ mod test {
sample_body.to_vec(),
DhtMessageFlags::empty(),
false,
false,
);
let header = inbound_msg.dht_header.clone();
let msg = DecryptedDhtMessage::failed(inbound_msg);
Expand Down
Loading

0 comments on commit b8cc0a3

Please sign in to comment.