diff --git a/rust/protocol/src/kdf.rs b/rust/protocol/src/kdf.rs index 5637b6ee41..c20b10ea69 100644 --- a/rust/protocol/src/kdf.rs +++ b/rust/protocol/src/kdf.rs @@ -1,9 +1,11 @@ // -// Copyright 2020 Signal Messenger, LLC. +// Copyright 2020-2021 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // -use crate::{Result, SignalProtocolError}; +use crate::{MessageVersion, Result}; + +use std::default::Default; use hmac::{Hmac, Mac, NewMac}; use sha2::Sha256; @@ -16,17 +18,18 @@ pub struct HKDF { impl HKDF { const HASH_OUTPUT_SIZE: usize = 32; - pub fn new(message_version: u32) -> Result { + pub fn new() -> Result { + Self::new_for_version(MessageVersion::default()) + } + + pub fn new_for_version(message_version: MessageVersion) -> Result { match message_version { - 2 => Ok(HKDF { + MessageVersion::Version2 => Ok(HKDF { iteration_start_offset: 0, }), - 3 => Ok(HKDF { + MessageVersion::Version3 => Ok(HKDF { iteration_start_offset: 1, }), - _ => Err(SignalProtocolError::UnrecognizedMessageVersion( - message_version, - )), } } @@ -92,6 +95,7 @@ impl HKDF { #[cfg(test)] mod tests { use super::*; + use crate::MessageVersion; #[test] fn test_vector_v3() -> Result<()> { @@ -109,7 +113,7 @@ mod tests { 0xec, 0xc4, 0xc5, 0xbf, 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, 0x58, 0x65, ]; - let output = HKDF::new(3)?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?; + let output = HKDF::new()?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?; assert_eq!(&okm[..], &output[..]); @@ -151,7 +155,7 @@ mod tests { 0x3e, 0x87, 0xc1, 0x4c, 0x01, 0xd5, 0xc1, 0xf3, 0x43, 0x4f, 0x1d, 0x87, ]; - let output = HKDF::new(3)?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?; + let output = HKDF::new()?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?; assert_eq!(&okm[..], &output[..]); @@ -176,7 +180,12 @@ mod tests { 0x4a, 0xa9, 0xfd, 0xa8, 0x99, 0xda, 0xeb, 0xec, ]; - let output = HKDF::new(2)?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?; + let output = HKDF::new_for_version(MessageVersion::Version2)?.derive_salted_secrets( + &ikm, + &salt, + &info, + okm.len(), + )?; assert_eq!(&okm[..], &output[..]); diff --git a/rust/protocol/src/lib.rs b/rust/protocol/src/lib.rs index 7fc2615878..a56ce532f0 100644 --- a/rust/protocol/src/lib.rs +++ b/rust/protocol/src/lib.rs @@ -40,7 +40,7 @@ pub use { identity_key::{IdentityKey, IdentityKeyPair}, kdf::HKDF, protocol::{ - CiphertextMessage, CiphertextMessageType, PreKeySignalMessage, + CiphertextMessage, CiphertextMessageType, MessageVersion, PreKeySignalMessage, SenderKeyDistributionMessage, SenderKeyMessage, SignalMessage, }, ratchet::{ diff --git a/rust/protocol/src/protocol.rs b/rust/protocol/src/protocol.rs index f2ebbef631..43972cf0cf 100644 --- a/rust/protocol/src/protocol.rs +++ b/rust/protocol/src/protocol.rs @@ -4,11 +4,14 @@ // use crate::proto; -use crate::{IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError}; +use crate::state::{PreKeyId, SignedPreKeyId}; +use crate::{DeviceId, IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError}; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; +use std::default::Default; use hmac::{Hmac, Mac, NewMac}; +use num_enum::TryFromPrimitiveError; use prost::Message; use rand::{CryptoRng, Rng}; use sha2::Sha256; @@ -17,6 +20,44 @@ use uuid::Uuid; pub const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 3; +/// A [u8] describing the version of the message chain format to use when starting a chain. +#[derive(Copy, Clone, Eq, PartialEq, Debug, num_enum::TryFromPrimitive, num_enum::IntoPrimitive)] +#[repr(u8)] +pub enum MessageVersion { + Version2 = 2, + /// **\[CURRENT\]**. + Version3 = CIPHERTEXT_MESSAGE_CURRENT_VERSION, +} + +impl Default for MessageVersion { + fn default() -> Self { + Self::Version3 + } +} + +impl TryFrom for MessageVersion { + type Error = SignalProtocolError; + fn try_from(value: u32) -> Result { + let value_u8: u8 = value + .try_into() + .map_err(|_| SignalProtocolError::UnrecognizedMessageVersion(value))?; + Ok(Self::try_from(value_u8)?) + } +} + +impl From for u32 { + fn from(value: MessageVersion) -> u32 { + let value_u8: u8 = value.into(); + value_u8 as u32 + } +} + +impl From> for SignalProtocolError { + fn from(value: TryFromPrimitiveError) -> SignalProtocolError { + SignalProtocolError::UnrecognizedMessageVersion(value.number.into()) + } +} + pub enum CiphertextMessage { SignalMessage(SignalMessage), PreKeySignalMessage(PreKeySignalMessage), @@ -52,7 +93,7 @@ impl CiphertextMessage { #[derive(Debug, Clone)] pub struct SignalMessage { - message_version: u8, + message_version: MessageVersion, sender_ratchet_key: PublicKey, counter: u32, #[allow(dead_code)] @@ -65,7 +106,7 @@ impl SignalMessage { const MAC_LENGTH: usize = 8; pub fn new( - message_version: u8, + message_version: MessageVersion, mac_key: &[u8], sender_ratchet_key: PublicKey, counter: u32, @@ -81,7 +122,8 @@ impl SignalMessage { ciphertext: Some(Vec::::from(&ciphertext[..])), }; let mut serialized = vec![0u8; 1 + message.encoded_len() + Self::MAC_LENGTH]; - serialized[0] = ((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION; + let message_version_u8: u8 = message_version.into(); + serialized[0] = ((message_version_u8 & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION; message.encode(&mut &mut serialized[1..message.encoded_len() + 1])?; let msg_len_for_mac = serialized.len() - Self::MAC_LENGTH; let mac = Self::compute_mac( @@ -103,7 +145,7 @@ impl SignalMessage { } #[inline] - pub fn message_version(&self) -> u8 { + pub fn message_version(&self) -> MessageVersion { self.message_version } @@ -218,7 +260,7 @@ impl TryFrom<&[u8]> for SignalMessage { .into_boxed_slice(); Ok(SignalMessage { - message_version, + message_version: message_version.try_into()?, sender_ratchet_key, counter, previous_counter, @@ -230,10 +272,10 @@ impl TryFrom<&[u8]> for SignalMessage { #[derive(Debug, Clone)] pub struct PreKeySignalMessage { - message_version: u8, - registration_id: u32, - pre_key_id: Option, - signed_pre_key_id: u32, + message_version: MessageVersion, + registration_id: DeviceId, + pre_key_id: Option, + signed_pre_key_id: SignedPreKeyId, base_key: PublicKey, identity_key: IdentityKey, message: SignalMessage, @@ -242,24 +284,25 @@ pub struct PreKeySignalMessage { impl PreKeySignalMessage { pub fn new( - message_version: u8, - registration_id: u32, - pre_key_id: Option, - signed_pre_key_id: u32, + message_version: MessageVersion, + registration_id: DeviceId, + pre_key_id: Option, + signed_pre_key_id: SignedPreKeyId, base_key: PublicKey, identity_key: IdentityKey, message: SignalMessage, ) -> Result { let proto_message = proto::wire::PreKeySignalMessage { - registration_id: Some(registration_id), - pre_key_id, - signed_pre_key_id: Some(signed_pre_key_id), + registration_id: Some(registration_id.into()), + pre_key_id: pre_key_id.map(|id| id.into()), + signed_pre_key_id: Some(signed_pre_key_id.into()), base_key: Some(base_key.serialize().into_vec()), identity_key: Some(identity_key.serialize().into_vec()), message: Some(Vec::from(message.as_ref())), }; let mut serialized = vec![0u8; 1 + proto_message.encoded_len()]; - serialized[0] = ((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION; + let message_version_u8: u8 = message_version.into(); + serialized[0] = ((message_version_u8 & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION; proto_message.encode(&mut &mut serialized[1..])?; Ok(Self { message_version, @@ -274,22 +317,22 @@ impl PreKeySignalMessage { } #[inline] - pub fn message_version(&self) -> u8 { + pub fn message_version(&self) -> MessageVersion { self.message_version } #[inline] - pub fn registration_id(&self) -> u32 { + pub fn registration_id(&self) -> DeviceId { self.registration_id } #[inline] - pub fn pre_key_id(&self) -> Option { + pub fn pre_key_id(&self) -> Option { self.pre_key_id } #[inline] - pub fn signed_pre_key_id(&self) -> u32 { + pub fn signed_pre_key_id(&self) -> SignedPreKeyId { self.signed_pre_key_id } @@ -358,10 +401,10 @@ impl TryFrom<&[u8]> for PreKeySignalMessage { let base_key = PublicKey::deserialize(base_key.as_ref())?; Ok(PreKeySignalMessage { - message_version, - registration_id: proto_structure.registration_id.unwrap_or(0), - pre_key_id: proto_structure.pre_key_id, - signed_pre_key_id, + message_version: message_version.try_into()?, + registration_id: (proto_structure.registration_id.unwrap_or(0) as u32).into(), + pre_key_id: proto_structure.pre_key_id.map(|id| id.into()), + signed_pre_key_id: signed_pre_key_id.into(), base_key, identity_key: IdentityKey::try_from(identity_key.as_ref())?, message: SignalMessage::try_from(message.as_ref())?, @@ -372,7 +415,7 @@ impl TryFrom<&[u8]> for PreKeySignalMessage { #[derive(Debug, Clone)] pub struct SenderKeyMessage { - message_version: u8, + message_version: MessageVersion, distribution_id: Uuid, chain_id: u32, iteration: u32, @@ -406,7 +449,7 @@ impl SenderKeyMessage { signature_key.calculate_signature(&serialized[..1 + proto_message_len], csprng)?; serialized[1 + proto_message_len..].copy_from_slice(&signature[..]); Ok(Self { - message_version: CIPHERTEXT_MESSAGE_CURRENT_VERSION, + message_version: MessageVersion::default(), distribution_id, chain_id, iteration, @@ -425,7 +468,7 @@ impl SenderKeyMessage { } #[inline] - pub fn message_version(&self) -> u8 { + pub fn message_version(&self) -> MessageVersion { self.message_version } @@ -498,7 +541,7 @@ impl TryFrom<&[u8]> for SenderKeyMessage { .into_boxed_slice(); Ok(SenderKeyMessage { - message_version, + message_version: message_version.try_into()?, distribution_id, chain_id, iteration, @@ -510,7 +553,7 @@ impl TryFrom<&[u8]> for SenderKeyMessage { #[derive(Debug, Clone)] pub struct SenderKeyDistributionMessage { - message_version: u8, + message_version: MessageVersion, distribution_id: Uuid, chain_id: u32, iteration: u32, @@ -540,7 +583,7 @@ impl SenderKeyDistributionMessage { proto_message.encode(&mut &mut serialized[1..])?; Ok(Self { - message_version, + message_version: message_version.try_into()?, distribution_id, chain_id, iteration, @@ -551,7 +594,7 @@ impl SenderKeyDistributionMessage { } #[inline] - pub fn message_version(&self) -> u8 { + pub fn message_version(&self) -> MessageVersion { self.message_version } @@ -640,7 +683,7 @@ impl TryFrom<&[u8]> for SenderKeyDistributionMessage { let signing_key = PublicKey::deserialize(&signing_key)?; Ok(SenderKeyDistributionMessage { - message_version, + message_version: message_version.try_into()?, distribution_id, chain_id, iteration, @@ -676,7 +719,7 @@ mod tests { let receiver_identity_key_pair = KeyPair::generate(csprng); SignalMessage::new( - 3, + MessageVersion::default(), &mac_key, sender_ratchet_key_pair.public_key, 42, @@ -713,10 +756,10 @@ mod tests { let base_key_pair = KeyPair::generate(&mut csprng); let message = create_signal_message(&mut csprng)?; let pre_key_signal_message = PreKeySignalMessage::new( - 3, - 365, + MessageVersion::default(), + 365.into(), None, - 97, + 97.into(), base_key_pair.public_key, identity_key_pair.public_key.into(), message, diff --git a/rust/protocol/src/ratchet.rs b/rust/protocol/src/ratchet.rs index 41a78eda30..d243589958 100644 --- a/rust/protocol/src/ratchet.rs +++ b/rust/protocol/src/ratchet.rs @@ -1,5 +1,5 @@ // -// Copyright 2020 Signal Messenger, LLC. +// Copyright 2020-2021 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // @@ -15,7 +15,7 @@ use crate::{KeyPair, Result, SessionRecord}; use rand::{CryptoRng, Rng}; fn derive_keys(secret_input: &[u8]) -> Result<(RootKey, ChainKey)> { - let kdf = crate::kdf::HKDF::new(3)?; + let kdf = crate::kdf::HKDF::new()?; let secrets = kdf.derive_secrets(secret_input, b"WhisperText", 64)?; diff --git a/rust/protocol/src/ratchet/keys.rs b/rust/protocol/src/ratchet/keys.rs index 66bf85e55f..49de2d937f 100644 --- a/rust/protocol/src/ratchet/keys.rs +++ b/rust/protocol/src/ratchet/keys.rs @@ -1,5 +1,5 @@ // -// Copyright 2020 Signal Messenger, LLC. +// Copyright 2020-2021 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // @@ -177,7 +177,7 @@ impl fmt::Display for RootKey { #[cfg(test)] mod tests { use super::*; - use crate::{PrivateKey, PublicKey}; + use crate::{MessageVersion, PrivateKey, PublicKey}; #[test] fn test_chain_key_derivation_v2() -> Result<()> { @@ -202,7 +202,7 @@ mod tests { 0xa2, 0x46, 0xd1, 0x5d, ]; - let chain_key = ChainKey::new(HKDF::new(2)?, &seed, 0)?; + let chain_key = ChainKey::new(HKDF::new_for_version(MessageVersion::Version2)?, &seed, 0)?; assert_eq!(&seed, chain_key.key()); assert_eq!(&message_key, chain_key.message_keys()?.cipher_key()); assert_eq!(&mac_key, chain_key.message_keys()?.mac_key()); @@ -237,7 +237,7 @@ mod tests { 0xa2, 0x46, 0xd1, 0x5d, ]; - let chain_key = ChainKey::new(HKDF::new(3)?, &seed, 0)?; + let chain_key = ChainKey::new(HKDF::new()?, &seed, 0)?; assert_eq!(&seed, chain_key.key()); assert_eq!(&message_key, chain_key.message_keys()?.cipher_key()); assert_eq!(&mac_key, chain_key.message_keys()?.mac_key()); @@ -286,7 +286,10 @@ mod tests { let alice_private_key = PrivateKey::deserialize(&alice_private)?; let bob_public_key = PublicKey::deserialize(&bob_public)?; - let root_key = RootKey::new(HKDF::new(2)?, &root_key_seed)?; + let root_key = RootKey::new( + HKDF::new_for_version(MessageVersion::Version2)?, + &root_key_seed, + )?; let (next_root_key, next_chain_key) = root_key.create_chain(&bob_public_key, &alice_private_key)?; diff --git a/rust/protocol/src/sealed_sender.rs b/rust/protocol/src/sealed_sender.rs index 31e354804d..64c049a4e0 100644 --- a/rust/protocol/src/sealed_sender.rs +++ b/rust/protocol/src/sealed_sender.rs @@ -627,7 +627,7 @@ mod sealed_sender_v1 { } let shared_secret = our_private.calculate_agreement(their_public)?; - let kdf = HKDF::new(3)?; + let kdf = HKDF::new()?; let derived_values = kdf.derive_salted_secrets(&shared_secret, &ephemeral_salt, &[], 96)?; @@ -662,7 +662,7 @@ mod sealed_sender_v1 { salt.extend_from_slice(ctext); let shared_secret = our_private.calculate_agreement(their_public)?; - let kdf = HKDF::new(3)?; + let kdf = HKDF::new()?; // 96 bytes are derived but the first 32 are discarded/unused let derived_values = kdf.derive_salted_secrets(&shared_secret, &salt, &[], 96)?; @@ -798,7 +798,7 @@ mod sealed_sender_v2 { impl DerivedKeys { pub(super) fn calculate(m: &[u8]) -> DerivedKeys { - let kdf = HKDF::new(3).expect("valid KDF version"); + let kdf = HKDF::new().expect("valid KDF version"); let r = kdf .derive_secrets(&m, LABEL_R, 64) .expect("valid use of KDF"); @@ -835,7 +835,7 @@ mod sealed_sender_v2 { } .concat(); - let mut result = HKDF::new(3)?.derive_secrets(&agreement_key_input, LABEL_DH, 32)?; + let mut result = HKDF::new()?.derive_secrets(&agreement_key_input, LABEL_DH, 32)?; result .iter_mut() .zip(input) @@ -865,7 +865,7 @@ mod sealed_sender_v2 { } } - HKDF::new(3)?.derive_secrets(&agreement_key_input, LABEL_DH_S, 16) + HKDF::new()?.derive_secrets(&agreement_key_input, LABEL_DH_S, 16) } } diff --git a/rust/protocol/src/sender_keys.rs b/rust/protocol/src/sender_keys.rs index 73ae2b4923..45cbd59c8c 100644 --- a/rust/protocol/src/sender_keys.rs +++ b/rust/protocol/src/sender_keys.rs @@ -1,5 +1,5 @@ // -// Copyright 2020 Signal Messenger, LLC. +// Copyright 2020-2021 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // @@ -22,7 +22,7 @@ pub struct SenderMessageKey { impl SenderMessageKey { pub fn new(iteration: u32, seed: Vec) -> Result { - let hkdf = HKDF::new(3)?; + let hkdf = HKDF::new()?; let derived = hkdf.derive_secrets(&seed, b"WhisperGroup", 48)?; Ok(Self { iteration, diff --git a/rust/protocol/src/session_cipher.rs b/rust/protocol/src/session_cipher.rs index 219bb08f89..232398f8c9 100644 --- a/rust/protocol/src/session_cipher.rs +++ b/rust/protocol/src/session_cipher.rs @@ -3,17 +3,18 @@ // SPDX-License-Identifier: AGPL-3.0-only // +use crate::consts::MAX_FORWARD_JUMPS; +use crate::crypto; +use crate::ratchet::{ChainKey, MessageKeys}; +use crate::session; +use crate::state::SessionState; use crate::{ CiphertextMessage, Context, Direction, IdentityKeyStore, KeyPair, PreKeySignalMessage, PreKeyStore, ProtocolAddress, PublicKey, Result, SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore, }; -use crate::consts::MAX_FORWARD_JUMPS; -use crate::crypto; -use crate::ratchet::{ChainKey, MessageKeys}; -use crate::session; -use crate::state::SessionState; +use std::convert::TryInto; use rand::{CryptoRng, Rng}; @@ -57,7 +58,7 @@ pub async fn message_encrypt( ); let message = SignalMessage::new( - session_version, + session_version.try_into()?, message_keys.mac_key(), sender_ephemeral, chain_key.index(), @@ -68,8 +69,8 @@ pub async fn message_encrypt( )?; CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new( - session_version, - local_registration_id, + session_version.try_into()?, + local_registration_id.into(), items.pre_key_id()?, items.signed_pre_key_id()?, *items.base_key()?, @@ -78,7 +79,7 @@ pub async fn message_encrypt( )?) } else { CiphertextMessage::SignalMessage(SignalMessage::new( - session_version, + session_version.try_into()?, message_keys.mac_key(), sender_ephemeral, chain_key.index(), @@ -475,10 +476,10 @@ fn decrypt_message_with_state( )); } - let ciphertext_version = ciphertext.message_version() as u32; - if ciphertext_version != state.session_version()? { + let ciphertext_version: u8 = ciphertext.message_version().into(); + if ciphertext_version as u32 != state.session_version()? { return Err(SignalProtocolError::UnrecognizedMessageVersion( - ciphertext_version, + ciphertext_version.into(), )); } diff --git a/rust/protocol/src/state/session.rs b/rust/protocol/src/state/session.rs index ad5f4c2122..af18496658 100644 --- a/rust/protocol/src/state/session.rs +++ b/rust/protocol/src/state/session.rs @@ -10,9 +10,11 @@ use crate::consts; use crate::proto::storage::session_structure; use crate::proto::storage::{RecordStructure, SessionStructure}; use crate::state::{PreKeyId, SignedPreKeyId}; + use prost::Message; use std::collections::VecDeque; +use std::convert::TryInto; #[derive(Debug, Clone)] pub(crate) struct UnacknowledgedPreKeyMessageItems { @@ -119,7 +121,7 @@ impl SessionState { if self.session.root_key.len() != 32 { return Err(SignalProtocolError::InvalidProtobufEncoding); } - let hkdf = HKDF::new(self.session_version()?)?; + let hkdf = HKDF::new_for_version(self.session_version()?.try_into()?)?; RootKey::new(hkdf, &self.session.root_key) } @@ -198,7 +200,7 @@ impl SessionState { if c.key.len() != 32 { return Err(SignalProtocolError::InvalidProtobufEncoding); } - let hkdf = HKDF::new(self.session_version()?)?; + let hkdf = HKDF::new_for_version(self.session_version()?.try_into()?)?; Ok(Some(ChainKey::new(hkdf, &c.key, c.index)?)) } }, @@ -268,7 +270,7 @@ impl SessionState { SignalProtocolError::InvalidState("get_sender_chain_key", "No chain key".to_owned()) })?; - let hkdf = HKDF::new(self.session_version()?)?; + let hkdf = HKDF::new_for_version(self.session_version()?.try_into()?)?; ChainKey::new(hkdf, &chain_key.key, chain_key.index) }