Skip to content

Commit

Permalink
introduce RegistrationId wrapper struct and remove a sealed sender test
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed Oct 9, 2021
1 parent d6da511 commit df2a1e0
Show file tree
Hide file tree
Showing 17 changed files with 182 additions and 182 deletions.
7 changes: 5 additions & 2 deletions rust/bridge/shared/src/ffi/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
Ok(IdentityKeyPair::new(IdentityKey::new(pub_key), *priv_key))
}

async fn get_local_registration_id(&self, ctx: Context) -> Result<u32, SignalProtocolError> {
async fn get_local_registration_id(
&self,
ctx: Context,
) -> Result<RegistrationId, SignalProtocolError> {
let ctx = ctx.unwrap_or(std::ptr::null_mut());
let mut id = 0;
let result = (self.get_local_registration_id)(self.ctx, &mut id, ctx);
Expand All @@ -89,7 +92,7 @@ impl IdentityKeyStore for &FfiIdentityKeyStoreStruct {
));
}

Ok(id)
Ok(RegistrationId::unsafe_from_value(id))
}

async fn save_identity(
Expand Down
9 changes: 7 additions & 2 deletions rust/bridge/shared/src/jni/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,13 @@ impl<'a> IdentityKeyStore for JniIdentityKeyStore<'a> {
Ok(self.do_get_identity_key_pair()?)
}

async fn get_local_registration_id(&self, _ctx: Context) -> Result<u32, SignalProtocolError> {
Ok(self.do_get_local_registration_id()?)
async fn get_local_registration_id(
&self,
_ctx: Context,
) -> Result<RegistrationId, SignalProtocolError> {
Ok(RegistrationId::unsafe_from_value(
self.do_get_local_registration_id()?,
))
}

async fn save_identity(
Expand Down
10 changes: 6 additions & 4 deletions rust/bridge/shared/src/node/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,12 @@ impl IdentityKeyStore for NodeIdentityKeyStore {
async fn get_local_registration_id(
&self,
_ctx: libsignal_protocol::Context,
) -> Result<u32, SignalProtocolError> {
self.do_get_local_registration_id()
.await
.map_err(|s| js_error_to_rust("getLocalRegistrationId", s))
) -> Result<RegistrationId, SignalProtocolError> {
Ok(RegistrationId::unsafe_from_value(
self.do_get_local_registration_id()
.await
.map_err(|s| js_error_to_rust("getLocalRegistrationId", s))?,
))
}

async fn get_identity(
Expand Down
8 changes: 4 additions & 4 deletions rust/bridge/shared/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ fn PreKeySignalMessage_New(
) -> Result<PreKeySignalMessage> {
PreKeySignalMessage::new(
message_version,
registration_id,
RegistrationId::unsafe_from_value(registration_id),
pre_key_id.map(|id| id.into()),
signed_pre_key_id.into(),
*base_key,
Expand Down Expand Up @@ -531,7 +531,7 @@ fn PreKeyBundle_New(
};

PreKeyBundle::new(
registration_id,
RegistrationId::unsafe_from_value(registration_id),
device_id.into(),
prekey,
signed_prekey_id.into(),
Expand Down Expand Up @@ -864,8 +864,8 @@ bridge_get_optional_bytearray!(
ffi = false,
node = false
);
bridge_get!(SessionRecord::local_registration_id -> u32);
bridge_get!(SessionRecord::remote_registration_id -> u32);
bridge_get!(SessionRecord::unsafe_local_registration_id -> u32);
bridge_get!(SessionRecord::unsafe_remote_registration_id -> u32);
bridge_get!(SessionRecord::has_sender_chain as HasSenderChain -> bool, ffi = false, node = false);

bridge_get!(SealedSenderDecryptionResult::sender_uuid -> String, ffi = false, jni = false);
Expand Down
3 changes: 2 additions & 1 deletion rust/protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub use {
storage::{
Context, Direction, IdentityKeyStore, InMemIdentityKeyStore, InMemPreKeyStore,
InMemSenderKeyStore, InMemSessionStore, InMemSignalProtocolStore, InMemSignedPreKeyStore,
PreKeyStore, ProtocolStore, SenderKeyStore, SessionStore, SignedPreKeyStore,
PreKeyStore, ProtocolStore, RegistrationId, SenderKeyStore, SessionStore,
SignedPreKeyStore,
},
};
9 changes: 5 additions & 4 deletions rust/protocol/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use crate::proto;
use crate::state::{PreKeyId, SignedPreKeyId};
use crate::{IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError};
use crate::{IdentityKey, PrivateKey, PublicKey, RegistrationId, Result, SignalProtocolError};

use std::convert::TryFrom;

Expand Down Expand Up @@ -249,13 +249,14 @@ pub struct PreKeySignalMessage {
impl PreKeySignalMessage {
pub fn new(
message_version: u8,
registration_id: u32,
registration_id: RegistrationId,
pre_key_id: Option<PreKeyId>,
signed_pre_key_id: SignedPreKeyId,
base_key: PublicKey,
identity_key: IdentityKey,
message: SignalMessage,
) -> Result<Self> {
let registration_id: u32 = registration_id.into();
let proto_message = proto::wire::PreKeySignalMessage {
registration_id: Some(registration_id),
pre_key_id: pre_key_id.map(|id| id.into()),
Expand Down Expand Up @@ -894,7 +895,7 @@ mod tests {
let message = create_signal_message(&mut csprng)?;
let pre_key_signal_message = PreKeySignalMessage::new(
3,
365,
RegistrationId::unsafe_from_value(365),
None,
97.into(),
base_key_pair.public_key,
Expand Down Expand Up @@ -1004,7 +1005,7 @@ mod tests {

let pre_key_signal_message = PreKeySignalMessage::new(
3,
365,
RegistrationId::unsafe_from_value(365),
None,
97.into(),
base_key_pair.public_key,
Expand Down
35 changes: 14 additions & 21 deletions rust/protocol/src/sealed_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1289,28 +1289,21 @@ pub async fn sealed_sender_multi_recipient_encrypt<R: Rng + CryptoRng>(
.await?
.ok_or_else(|| SignalProtocolError::SessionNotFound(format!("{}", destination)))?;

let their_registration_id = session.remote_registration_id().map_err(|_| {
SignalProtocolError::InvalidState(
"sealed_sender_multi_recipient_encrypt",
format!(
concat!(
"cannot get registration ID from session with {} ",
"(maybe it was recently archived)"
let their_registration_id: u16 = session
.remote_registration_id(destination)
.map_err(|_| {
SignalProtocolError::InvalidState(
"sealed_sender_multi_recipient_encrypt",
format!(
concat!(
"cannot get registration ID from session with {} ",
"(maybe it was recently archived)"
),
destination
),
destination
),
)
})?;
// Valid registration IDs fit in 14 bits.
// TODO: move this into a RegistrationId strong type.
if their_registration_id & 0x3FFF != their_registration_id {
return Err(SignalProtocolError::InvalidRegistrationId(
destination.clone(),
their_registration_id,
));
}
let their_registration_id =
u16::try_from(their_registration_id).expect("just checked range");
)
})?
.into();

let end_of_previous_recipient_data = serialized.len();

Expand Down
16 changes: 10 additions & 6 deletions rust/protocol/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

use crate::{
Context, Direction, IdentityKeyStore, KeyPair, PreKeyBundle, PreKeyId, PreKeySignalMessage,
PreKeyStore, ProtocolAddress, Result, SessionRecord, SessionStore, SignalProtocolError,
SignedPreKeyStore,
PreKeyStore, ProtocolAddress, RegistrationId, Result, SessionRecord, SessionStore,
SignalProtocolError, SignedPreKeyStore,
};

use crate::ratchet;
Expand Down Expand Up @@ -116,8 +116,11 @@ async fn process_prekey_v3(

let mut new_session = ratchet::initialize_bob_session(&parameters)?;

new_session.set_local_registration_id(identity_store.get_local_registration_id(ctx).await?)?;
new_session.set_remote_registration_id(message.registration_id())?;
new_session
.set_local_registration_id(identity_store.get_local_registration_id(ctx).await?.into())?;
let remote_registration_id =
RegistrationId::deserialize(message.registration_id().into(), remote_address)?;
new_session.set_remote_registration_id(remote_registration_id)?;
new_session.set_alice_base_key(&message.base_key().serialize())?;

session_record.promote_state(new_session)?;
Expand Down Expand Up @@ -187,8 +190,9 @@ pub async fn process_prekey_bundle<R: Rng + CryptoRng>(
&our_base_key_pair.public_key,
)?;

session.set_local_registration_id(identity_store.get_local_registration_id(ctx).await?)?;
session.set_remote_registration_id(bundle.registration_id()?)?;
session
.set_local_registration_id(identity_store.get_local_registration_id(ctx).await?.into())?;
session.set_remote_registration_id(bundle.registration_id()?.into())?;
session.set_alice_base_key(&our_base_key_pair.public_key.serialize())?;

identity_store
Expand Down
4 changes: 2 additions & 2 deletions rust/protocol/src/session_cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub async fn message_encrypt(
let ctext = crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())?;

let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
let local_registration_id = session_state.local_registration_id()?;
let local_registration_id = session_state.local_registration_id(remote_address)?;

log::info!(
"Building PreKeyWhisperMessage for: {} with preKeyId: {}",
Expand All @@ -69,7 +69,7 @@ pub async fn message_encrypt(

CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
session_version,
local_registration_id,
local_registration_id.into(),
items.pre_key_id()?,
items.signed_pre_key_id()?,
*items.base_key()?,
Expand Down
8 changes: 4 additions & 4 deletions rust/protocol/src/state/bundle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
//

use crate::state::{PreKeyId, SignedPreKeyId};
use crate::{DeviceId, IdentityKey, PublicKey, Result};
use crate::{DeviceId, IdentityKey, PublicKey, RegistrationId, Result};

#[derive(Debug, Clone)]
pub struct PreKeyBundle {
registration_id: u32,
registration_id: RegistrationId,
device_id: DeviceId,
pre_key_id: Option<PreKeyId>,
pre_key_public: Option<PublicKey>,
Expand All @@ -20,7 +20,7 @@ pub struct PreKeyBundle {

impl PreKeyBundle {
pub fn new(
registration_id: u32,
registration_id: RegistrationId,
device_id: DeviceId,
pre_key: Option<(PreKeyId, PublicKey)>,
signed_pre_key_id: SignedPreKeyId,
Expand All @@ -45,7 +45,7 @@ impl PreKeyBundle {
})
}

pub fn registration_id(&self) -> Result<u32> {
pub fn registration_id(&self) -> Result<RegistrationId> {
Ok(self.registration_id)
}

Expand Down
50 changes: 39 additions & 11 deletions rust/protocol/src/state/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
use prost::Message;

use crate::ratchet::{ChainKey, MessageKeys, RootKey};
use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, Result, SignalProtocolError, HKDF};
use crate::{
IdentityKey, KeyPair, PrivateKey, ProtocolAddress, PublicKey, RegistrationId, Result,
SignalProtocolError, HKDF,
};

use crate::consts;
use crate::proto::storage::session_structure;
Expand Down Expand Up @@ -418,21 +421,38 @@ impl SessionState {
Ok(())
}

pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) -> Result<()> {
self.session.remote_registration_id = registration_id;
pub(crate) fn set_remote_registration_id(
&mut self,
registration_id: RegistrationId,
) -> Result<()> {
self.session.remote_registration_id = registration_id.into();
Ok(())
}

pub(crate) fn remote_registration_id(&self) -> Result<u32> {
pub(crate) fn remote_registration_id(
&self,
remote: &ProtocolAddress,
) -> Result<RegistrationId> {
RegistrationId::deserialize(self.session.remote_registration_id, remote)
}

pub(crate) fn unsafe_remote_registration_id(&self) -> Result<u32> {
Ok(self.session.remote_registration_id)
}

pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) -> Result<()> {
self.session.local_registration_id = registration_id;
pub(crate) fn set_local_registration_id(
&mut self,
registration_id: RegistrationId,
) -> Result<()> {
self.session.local_registration_id = registration_id.into();
Ok(())
}

pub(crate) fn local_registration_id(&self) -> Result<u32> {
pub(crate) fn local_registration_id(&self, remote: &ProtocolAddress) -> Result<RegistrationId> {
RegistrationId::deserialize(self.session.local_registration_id, remote)
}

pub(crate) fn unsafe_local_registration_id(&self) -> Result<u32> {
Ok(self.session.local_registration_id)
}
}
Expand Down Expand Up @@ -596,12 +616,20 @@ impl SessionRecord {
Ok(record.encode_to_vec())
}

pub fn remote_registration_id(&self) -> Result<u32> {
self.session_state()?.remote_registration_id()
pub fn remote_registration_id(&self, remote: &ProtocolAddress) -> Result<RegistrationId> {
self.session_state()?.remote_registration_id(remote)
}

pub fn unsafe_remote_registration_id(&self) -> Result<u32> {
self.session_state()?.unsafe_remote_registration_id()
}

pub fn local_registration_id(&self, remote: &ProtocolAddress) -> Result<RegistrationId> {
self.session_state()?.local_registration_id(remote)
}

pub fn local_registration_id(&self) -> Result<u32> {
self.session_state()?.local_registration_id()
pub fn unsafe_local_registration_id(&self) -> Result<u32> {
self.session_state()?.unsafe_local_registration_id()
}

pub fn session_version(&self) -> Result<u32> {
Expand Down
4 changes: 2 additions & 2 deletions rust/protocol/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub use {
InMemSignalProtocolStore, InMemSignedPreKeyStore,
},
traits::{
Context, Direction, IdentityKeyStore, PreKeyStore, ProtocolStore, SenderKeyStore,
SessionStore, SignedPreKeyStore,
Context, Direction, IdentityKeyStore, PreKeyStore, ProtocolStore, RegistrationId,
SenderKeyStore, SessionStore, SignedPreKeyStore,
},
};
Loading

0 comments on commit df2a1e0

Please sign in to comment.