Skip to content

Commit

Permalink
introduce MessageVersion enum and HKDF::new() with no args
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed May 5, 2021
1 parent 6f1badf commit 8a01414
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 80 deletions.
31 changes: 20 additions & 11 deletions rust/protocol/src/kdf.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//
// Copyright 2020 Signal Messenger, LLC.
// Copyright 2020-2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

use crate::{Result, SignalProtocolError};
use crate::{MessageVersion, Result};

use std::default::Default;

use hmac::{Hmac, Mac, NewMac};
use sha2::Sha256;
Expand All @@ -16,17 +18,18 @@ pub struct HKDF {
impl HKDF {
const HASH_OUTPUT_SIZE: usize = 32;

pub fn new(message_version: u32) -> Result<Self> {
pub fn new() -> Result<Self> {
Self::new_for_version(MessageVersion::default())
}

pub fn new_for_version(message_version: MessageVersion) -> Result<Self> {
match message_version {
2 => Ok(HKDF {
MessageVersion::Version2 => Ok(HKDF {
iteration_start_offset: 0,
}),
3 => Ok(HKDF {
MessageVersion::Version3 => Ok(HKDF {
iteration_start_offset: 1,
}),
_ => Err(SignalProtocolError::UnrecognizedMessageVersion(
message_version,
)),
}
}

Expand Down Expand Up @@ -92,6 +95,7 @@ impl HKDF {
#[cfg(test)]
mod tests {
use super::*;
use crate::MessageVersion;

#[test]
fn test_vector_v3() -> Result<()> {
Expand All @@ -109,7 +113,7 @@ mod tests {
0xec, 0xc4, 0xc5, 0xbf, 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, 0x58, 0x65,
];

let output = HKDF::new(3)?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?;
let output = HKDF::new()?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?;

assert_eq!(&okm[..], &output[..]);

Expand Down Expand Up @@ -151,7 +155,7 @@ mod tests {
0x3e, 0x87, 0xc1, 0x4c, 0x01, 0xd5, 0xc1, 0xf3, 0x43, 0x4f, 0x1d, 0x87,
];

let output = HKDF::new(3)?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?;
let output = HKDF::new()?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?;

assert_eq!(&okm[..], &output[..]);

Expand All @@ -176,7 +180,12 @@ mod tests {
0x4a, 0xa9, 0xfd, 0xa8, 0x99, 0xda, 0xeb, 0xec,
];

let output = HKDF::new(2)?.derive_salted_secrets(&ikm, &salt, &info, okm.len())?;
let output = HKDF::new_for_version(MessageVersion::Version2)?.derive_salted_secrets(
&ikm,
&salt,
&info,
okm.len(),
)?;

assert_eq!(&okm[..], &output[..]);

Expand Down
2 changes: 1 addition & 1 deletion rust/protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub use {
identity_key::{IdentityKey, IdentityKeyPair},
kdf::HKDF,
protocol::{
CiphertextMessage, CiphertextMessageType, PreKeySignalMessage,
CiphertextMessage, CiphertextMessageType, MessageVersion, PreKeySignalMessage,
SenderKeyDistributionMessage, SenderKeyMessage, SignalMessage,
},
ratchet::{
Expand Down
121 changes: 82 additions & 39 deletions rust/protocol/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
//

use crate::proto;
use crate::{IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError};
use crate::state::{PreKeyId, SignedPreKeyId};
use crate::{DeviceId, IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError};

use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use std::default::Default;

use hmac::{Hmac, Mac, NewMac};
use num_enum::TryFromPrimitiveError;
use prost::Message;
use rand::{CryptoRng, Rng};
use sha2::Sha256;
Expand All @@ -17,6 +20,44 @@ use uuid::Uuid;

pub const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 3;

/// A [u8] describing the version of the message chain format to use when starting a chain.
#[derive(Copy, Clone, Eq, PartialEq, Debug, num_enum::TryFromPrimitive, num_enum::IntoPrimitive)]
#[repr(u8)]
pub enum MessageVersion {
Version2 = 2,
/// **\[CURRENT\]**.
Version3 = CIPHERTEXT_MESSAGE_CURRENT_VERSION,
}

impl Default for MessageVersion {
fn default() -> Self {
Self::Version3
}
}

impl TryFrom<u32> for MessageVersion {
type Error = SignalProtocolError;
fn try_from(value: u32) -> Result<Self> {
let value_u8: u8 = value
.try_into()
.map_err(|_| SignalProtocolError::UnrecognizedMessageVersion(value))?;
Ok(Self::try_from(value_u8)?)
}
}

impl From<MessageVersion> for u32 {
fn from(value: MessageVersion) -> u32 {
let value_u8: u8 = value.into();
value_u8 as u32
}
}

impl From<TryFromPrimitiveError<MessageVersion>> for SignalProtocolError {
fn from(value: TryFromPrimitiveError<MessageVersion>) -> SignalProtocolError {
SignalProtocolError::UnrecognizedMessageVersion(value.number.into())
}
}

pub enum CiphertextMessage {
SignalMessage(SignalMessage),
PreKeySignalMessage(PreKeySignalMessage),
Expand Down Expand Up @@ -52,7 +93,7 @@ impl CiphertextMessage {

#[derive(Debug, Clone)]
pub struct SignalMessage {
message_version: u8,
message_version: MessageVersion,
sender_ratchet_key: PublicKey,
counter: u32,
#[allow(dead_code)]
Expand All @@ -65,7 +106,7 @@ impl SignalMessage {
const MAC_LENGTH: usize = 8;

pub fn new(
message_version: u8,
message_version: MessageVersion,
mac_key: &[u8],
sender_ratchet_key: PublicKey,
counter: u32,
Expand All @@ -81,7 +122,8 @@ impl SignalMessage {
ciphertext: Some(Vec::<u8>::from(&ciphertext[..])),
};
let mut serialized = vec![0u8; 1 + message.encoded_len() + Self::MAC_LENGTH];
serialized[0] = ((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
let message_version_u8: u8 = message_version.into();
serialized[0] = ((message_version_u8 & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
message.encode(&mut &mut serialized[1..message.encoded_len() + 1])?;
let msg_len_for_mac = serialized.len() - Self::MAC_LENGTH;
let mac = Self::compute_mac(
Expand All @@ -103,7 +145,7 @@ impl SignalMessage {
}

#[inline]
pub fn message_version(&self) -> u8 {
pub fn message_version(&self) -> MessageVersion {
self.message_version
}

Expand Down Expand Up @@ -218,7 +260,7 @@ impl TryFrom<&[u8]> for SignalMessage {
.into_boxed_slice();

Ok(SignalMessage {
message_version,
message_version: message_version.try_into()?,
sender_ratchet_key,
counter,
previous_counter,
Expand All @@ -230,10 +272,10 @@ impl TryFrom<&[u8]> for SignalMessage {

#[derive(Debug, Clone)]
pub struct PreKeySignalMessage {
message_version: u8,
registration_id: u32,
pre_key_id: Option<u32>,
signed_pre_key_id: u32,
message_version: MessageVersion,
registration_id: DeviceId,
pre_key_id: Option<PreKeyId>,
signed_pre_key_id: SignedPreKeyId,
base_key: PublicKey,
identity_key: IdentityKey,
message: SignalMessage,
Expand All @@ -242,24 +284,25 @@ pub struct PreKeySignalMessage {

impl PreKeySignalMessage {
pub fn new(
message_version: u8,
registration_id: u32,
pre_key_id: Option<u32>,
signed_pre_key_id: u32,
message_version: MessageVersion,
registration_id: DeviceId,
pre_key_id: Option<PreKeyId>,
signed_pre_key_id: SignedPreKeyId,
base_key: PublicKey,
identity_key: IdentityKey,
message: SignalMessage,
) -> Result<Self> {
let proto_message = proto::wire::PreKeySignalMessage {
registration_id: Some(registration_id),
pre_key_id,
signed_pre_key_id: Some(signed_pre_key_id),
registration_id: Some(registration_id.into()),
pre_key_id: pre_key_id.map(|id| id.into()),
signed_pre_key_id: Some(signed_pre_key_id.into()),
base_key: Some(base_key.serialize().into_vec()),
identity_key: Some(identity_key.serialize().into_vec()),
message: Some(Vec::from(message.as_ref())),
};
let mut serialized = vec![0u8; 1 + proto_message.encoded_len()];
serialized[0] = ((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
let message_version_u8: u8 = message_version.into();
serialized[0] = ((message_version_u8 & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION;
proto_message.encode(&mut &mut serialized[1..])?;
Ok(Self {
message_version,
Expand All @@ -274,22 +317,22 @@ impl PreKeySignalMessage {
}

#[inline]
pub fn message_version(&self) -> u8 {
pub fn message_version(&self) -> MessageVersion {
self.message_version
}

#[inline]
pub fn registration_id(&self) -> u32 {
pub fn registration_id(&self) -> DeviceId {
self.registration_id
}

#[inline]
pub fn pre_key_id(&self) -> Option<u32> {
pub fn pre_key_id(&self) -> Option<PreKeyId> {
self.pre_key_id
}

#[inline]
pub fn signed_pre_key_id(&self) -> u32 {
pub fn signed_pre_key_id(&self) -> SignedPreKeyId {
self.signed_pre_key_id
}

Expand Down Expand Up @@ -358,10 +401,10 @@ impl TryFrom<&[u8]> for PreKeySignalMessage {
let base_key = PublicKey::deserialize(base_key.as_ref())?;

Ok(PreKeySignalMessage {
message_version,
registration_id: proto_structure.registration_id.unwrap_or(0),
pre_key_id: proto_structure.pre_key_id,
signed_pre_key_id,
message_version: message_version.try_into()?,
registration_id: (proto_structure.registration_id.unwrap_or(0) as u32).into(),
pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
signed_pre_key_id: signed_pre_key_id.into(),
base_key,
identity_key: IdentityKey::try_from(identity_key.as_ref())?,
message: SignalMessage::try_from(message.as_ref())?,
Expand All @@ -372,7 +415,7 @@ impl TryFrom<&[u8]> for PreKeySignalMessage {

#[derive(Debug, Clone)]
pub struct SenderKeyMessage {
message_version: u8,
message_version: MessageVersion,
distribution_id: Uuid,
chain_id: u32,
iteration: u32,
Expand Down Expand Up @@ -406,7 +449,7 @@ impl SenderKeyMessage {
signature_key.calculate_signature(&serialized[..1 + proto_message_len], csprng)?;
serialized[1 + proto_message_len..].copy_from_slice(&signature[..]);
Ok(Self {
message_version: CIPHERTEXT_MESSAGE_CURRENT_VERSION,
message_version: MessageVersion::default(),
distribution_id,
chain_id,
iteration,
Expand All @@ -425,7 +468,7 @@ impl SenderKeyMessage {
}

#[inline]
pub fn message_version(&self) -> u8 {
pub fn message_version(&self) -> MessageVersion {
self.message_version
}

Expand Down Expand Up @@ -498,7 +541,7 @@ impl TryFrom<&[u8]> for SenderKeyMessage {
.into_boxed_slice();

Ok(SenderKeyMessage {
message_version,
message_version: message_version.try_into()?,
distribution_id,
chain_id,
iteration,
Expand All @@ -510,7 +553,7 @@ impl TryFrom<&[u8]> for SenderKeyMessage {

#[derive(Debug, Clone)]
pub struct SenderKeyDistributionMessage {
message_version: u8,
message_version: MessageVersion,
distribution_id: Uuid,
chain_id: u32,
iteration: u32,
Expand Down Expand Up @@ -540,7 +583,7 @@ impl SenderKeyDistributionMessage {
proto_message.encode(&mut &mut serialized[1..])?;

Ok(Self {
message_version,
message_version: message_version.try_into()?,
distribution_id,
chain_id,
iteration,
Expand All @@ -551,7 +594,7 @@ impl SenderKeyDistributionMessage {
}

#[inline]
pub fn message_version(&self) -> u8 {
pub fn message_version(&self) -> MessageVersion {
self.message_version
}

Expand Down Expand Up @@ -640,7 +683,7 @@ impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
let signing_key = PublicKey::deserialize(&signing_key)?;

Ok(SenderKeyDistributionMessage {
message_version,
message_version: message_version.try_into()?,
distribution_id,
chain_id,
iteration,
Expand Down Expand Up @@ -676,7 +719,7 @@ mod tests {
let receiver_identity_key_pair = KeyPair::generate(csprng);

SignalMessage::new(
3,
MessageVersion::default(),
&mac_key,
sender_ratchet_key_pair.public_key,
42,
Expand Down Expand Up @@ -713,10 +756,10 @@ mod tests {
let base_key_pair = KeyPair::generate(&mut csprng);
let message = create_signal_message(&mut csprng)?;
let pre_key_signal_message = PreKeySignalMessage::new(
3,
365,
MessageVersion::default(),
365.into(),
None,
97,
97.into(),
base_key_pair.public_key,
identity_key_pair.public_key.into(),
message,
Expand Down
4 changes: 2 additions & 2 deletions rust/protocol/src/ratchet.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//
// Copyright 2020 Signal Messenger, LLC.
// Copyright 2020-2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

Expand All @@ -15,7 +15,7 @@ use crate::{KeyPair, Result, SessionRecord};
use rand::{CryptoRng, Rng};

fn derive_keys(secret_input: &[u8]) -> Result<(RootKey, ChainKey)> {
let kdf = crate::kdf::HKDF::new(3)?;
let kdf = crate::kdf::HKDF::new()?;

let secrets = kdf.derive_secrets(secret_input, b"WhisperText", 64)?;

Expand Down
Loading

0 comments on commit 8a01414

Please sign in to comment.