Skip to content

Commit

Permalink
respond to review comments round 1
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed May 16, 2021
1 parent e252bf2 commit 2648530
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 43 deletions.
9 changes: 8 additions & 1 deletion rust/protocol/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
//
// Copyright 2020 Signal Messenger, LLC.
// Copyright 2020, 2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

use crate::curve::KeyType;

use std::convert::Infallible;
use std::error::Error;
use std::fmt;
use std::panic::UnwindSafe;
Expand Down Expand Up @@ -77,6 +78,12 @@ impl Error for SignalProtocolError {
}
}

impl From<Infallible> for SignalProtocolError {
fn from(_value: Infallible) -> SignalProtocolError {
unreachable!()
}
}

impl From<prost::DecodeError> for SignalProtocolError {
fn from(value: prost::DecodeError) -> SignalProtocolError {
SignalProtocolError::ProtobufDecodingError(value)
Expand Down
4 changes: 2 additions & 2 deletions rust/protocol/src/group_cipher.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright 2020 Signal Messenger, LLC.
// Copyright 2020-2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

Expand Down Expand Up @@ -211,7 +211,7 @@ pub async fn create_sender_key_distribution_message<R: Rng + CryptoRng>(
let sender_key: [u8; 32] = csprng.gen();
let signing_key = KeyPair::generate(csprng);
sender_key_record.set_sender_key_state(
SENDERKEY_MESSAGE_CURRENT_VERSION,
SENDERKEY_MESSAGE_CURRENT_VERSION.into(),
chain_id,
iteration,
&sender_key,
Expand Down
10 changes: 4 additions & 6 deletions rust/protocol/src/kdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

use crate::{MessageVersion, Result};

use std::default::Default;

use hmac::{Hmac, Mac, NewMac};
use sha2::Sha256;

Expand All @@ -19,15 +17,15 @@ impl HKDF {
const HASH_OUTPUT_SIZE: usize = 32;

pub fn new() -> Result<Self> {
Self::new_for_version(MessageVersion::default())
Self::new_for_version(MessageVersion::V3)
}

pub fn new_for_version(message_version: MessageVersion) -> Result<Self> {
match message_version {
MessageVersion::Version2 => Ok(HKDF {
MessageVersion::V2 => Ok(HKDF {
iteration_start_offset: 0,
}),
MessageVersion::Version3 => Ok(HKDF {
MessageVersion::V3 => Ok(HKDF {
iteration_start_offset: 1,
}),
}
Expand Down Expand Up @@ -180,7 +178,7 @@ mod tests {
0x4a, 0xa9, 0xfd, 0xa8, 0x99, 0xda, 0xeb, 0xec,
];

let output = HKDF::new_for_version(MessageVersion::Version2)?.derive_salted_secrets(
let output = HKDF::new_for_version(MessageVersion::V2)?.derive_salted_secrets(
&ikm,
&salt,
&info,
Expand Down
69 changes: 38 additions & 31 deletions rust/protocol/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::state::{PreKeyId, SignedPreKeyId};
use crate::{DeviceId, IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError};

use std::convert::{TryFrom, TryInto};
use std::default::Default;

use hmac::{Hmac, Mac, NewMac};
use num_enum::TryFromPrimitiveError;
Expand All @@ -18,22 +17,26 @@ use sha2::Sha256;
use subtle::ConstantTimeEq;
use uuid::Uuid;

pub const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 3;
pub const SENDERKEY_MESSAGE_CURRENT_VERSION: u8 = 3;
pub const CIPHERTEXT_MESSAGE_CURRENT_VERSION: MessageVersion = MessageVersion::V3;
pub const SENDERKEY_MESSAGE_CURRENT_VERSION: MessageVersion = MessageVersion::V3;

/// A [u8] describing the version of the message chain format to use when starting a chain.
#[derive(Copy, Clone, Eq, PartialEq, Debug, num_enum::TryFromPrimitive, num_enum::IntoPrimitive)]
#[derive(
Copy,
Clone,
Ord,
PartialOrd,
Eq,
PartialEq,
Debug,
num_enum::TryFromPrimitive,
num_enum::IntoPrimitive,
)]
#[repr(u8)]
pub enum MessageVersion {
Version2 = 2,
V2 = 2,
/// **\[CURRENT\]**.
Version3 = CIPHERTEXT_MESSAGE_CURRENT_VERSION,
}

impl Default for MessageVersion {
fn default() -> Self {
Self::Version3
}
V3 = 3,
}

impl TryFrom<u32> for MessageVersion {
Expand Down Expand Up @@ -124,7 +127,8 @@ impl SignalMessage {
};
let mut serialized = vec![0u8; 1 + message.encoded_len() + Self::MAC_LENGTH];
let message_version_u8: u8 = message_version.into();
serialized[0] = ((message_version_u8 & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
let current_version_u8: u8 = CIPHERTEXT_MESSAGE_CURRENT_VERSION.into();
serialized[0] = ((message_version_u8 & 0xF) << 4) | current_version_u8;
message.encode(&mut &mut serialized[1..message.encoded_len() + 1])?;
let msg_len_for_mac = serialized.len() - Self::MAC_LENGTH;
let mac = Self::compute_mac(
Expand Down Expand Up @@ -232,15 +236,15 @@ impl TryFrom<&[u8]> for SignalMessage {
if value.len() < SignalMessage::MAC_LENGTH + 1 {
return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
}
let message_version = value[0] >> 4;
let message_version: MessageVersion = (value[0] >> 4).try_into()?;
if message_version < CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::LegacyCiphertextVersion(
message_version,
message_version.into(),
));
}
if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
message_version,
message_version.into(),
));
}

Expand Down Expand Up @@ -303,7 +307,8 @@ impl PreKeySignalMessage {
};
let mut serialized = vec![0u8; 1 + proto_message.encoded_len()];
let message_version_u8: u8 = message_version.into();
serialized[0] = ((message_version_u8 & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
let current_version_u8: u8 = CIPHERTEXT_MESSAGE_CURRENT_VERSION.into();
serialized[0] = ((message_version_u8 & 0xF) << 4) | current_version_u8;
proto_message.encode(&mut &mut serialized[1..])?;
Ok(Self {
message_version,
Expand Down Expand Up @@ -372,15 +377,15 @@ impl TryFrom<&[u8]> for PreKeySignalMessage {
return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
}

let message_version = value[0] >> 4;
let message_version: MessageVersion = (value[0] >> 4).try_into()?;
if message_version < CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::LegacyCiphertextVersion(
message_version,
message_version.into(),
));
}
if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
message_version,
message_version.into(),
));
}

Expand All @@ -402,8 +407,8 @@ impl TryFrom<&[u8]> for PreKeySignalMessage {
let base_key = PublicKey::deserialize(base_key.as_ref())?;

Ok(PreKeySignalMessage {
message_version: message_version.try_into()?,
registration_id: (proto_structure.registration_id.unwrap_or(0) as u32).into(),
message_version: message_version.into(),
registration_id: (proto_structure.registration_id.unwrap_or(0)).into(),
pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
signed_pre_key_id: signed_pre_key_id.into(),
base_key,
Expand Down Expand Up @@ -444,13 +449,14 @@ impl SenderKeyMessage {
};
let proto_message_len = proto_message.encoded_len();
let mut serialized = vec![0u8; 1 + proto_message_len + Self::SIGNATURE_LEN];
serialized[0] = ((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION;
let current_senderkey_version_u8: u8 = SENDERKEY_MESSAGE_CURRENT_VERSION.into();
serialized[0] = ((message_version & 0xF) << 4) | current_senderkey_version_u8;
proto_message.encode(&mut &mut serialized[1..1 + proto_message_len])?;
let signature =
signature_key.calculate_signature(&serialized[..1 + proto_message_len], csprng)?;
serialized[1 + proto_message_len..].copy_from_slice(&signature[..]);
Ok(Self {
message_version: MessageVersion::default(),
message_version: MessageVersion::V3,
distribution_id,
chain_id,
iteration,
Expand Down Expand Up @@ -512,15 +518,15 @@ impl TryFrom<&[u8]> for SenderKeyMessage {
if value.len() < 1 + Self::SIGNATURE_LEN {
return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
}
let message_version = value[0] >> 4;
let message_version: MessageVersion = (value[0] >> 4).try_into()?;
if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::LegacyCiphertextVersion(
message_version,
message_version.into(),
));
}
if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
message_version,
message_version.into(),
));
}
let proto_structure =
Expand Down Expand Up @@ -580,7 +586,8 @@ impl SenderKeyDistributionMessage {
signing_key: Some(signing_key.serialize().to_vec()),
};
let mut serialized = vec![0u8; 1 + proto_message.encoded_len()];
serialized[0] = ((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION;
let current_senderkey_version_u8: u8 = SENDERKEY_MESSAGE_CURRENT_VERSION.into();
serialized[0] = ((message_version & 0xF) << 4) | current_senderkey_version_u8;
proto_message.encode(&mut &mut serialized[1..])?;

Ok(Self {
Expand Down Expand Up @@ -645,16 +652,16 @@ impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
}

let message_version = value[0] >> 4;
let message_version: MessageVersion = (value[0] >> 4).try_into()?;

if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::LegacyCiphertextVersion(
message_version,
message_version.into(),
));
}
if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
message_version,
message_version.into(),
));
}

Expand Down
4 changes: 2 additions & 2 deletions rust/protocol/src/ratchet/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ mod tests {
0xa2, 0x46, 0xd1, 0x5d,
];

let chain_key = ChainKey::new(HKDF::new_for_version(MessageVersion::Version2)?, &seed, 0)?;
let chain_key = ChainKey::new(HKDF::new_for_version(MessageVersion::V2)?, &seed, 0)?;
assert_eq!(&seed, chain_key.key());
assert_eq!(&message_key, chain_key.message_keys()?.cipher_key());
assert_eq!(&mac_key, chain_key.message_keys()?.mac_key());
Expand Down Expand Up @@ -287,7 +287,7 @@ mod tests {
let alice_private_key = PrivateKey::deserialize(&alice_private)?;
let bob_public_key = PublicKey::deserialize(&bob_public)?;
let root_key = RootKey::new(
HKDF::new_for_version(MessageVersion::Version2)?,
HKDF::new_for_version(MessageVersion::V2)?,
&root_key_seed,
)?;

Expand Down
2 changes: 1 addition & 1 deletion rust/protocol/src/state/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ impl SessionState {
) -> Result<()> {
let signed_pre_key_id: u32 = signed_pre_key_id.into();
let pending = session_structure::PendingPreKey {
pre_key_id: pre_key_id.unwrap_or(0.into()).into(),
pre_key_id: pre_key_id.map(|id| id.into()).unwrap_or(0),
signed_pre_key_id: signed_pre_key_id as i32,
base_key: base_key.serialize().to_vec(),
};
Expand Down

0 comments on commit 2648530

Please sign in to comment.