From fc7912f20646e60207322b6083d6fb90764607b0 Mon Sep 17 00:00:00 2001 From: Danilo Bargen Date: Mon, 7 May 2018 11:09:03 +0200 Subject: [PATCH] Allow specifying server permanent key in SaltyClientBuilder Refs #12 --- src/lib.rs | 13 +++++++++++++ src/protocol/context.rs | 1 + src/protocol/mod.rs | 14 ++++++++++++-- src/protocol/tests/signaling_messages.rs | 5 +++-- src/protocol/tests/validate_nonce.rs | 14 +++++++------- 5 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index ba2b2cb..55fd676 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -149,6 +149,7 @@ pub struct SaltyClientBuilder { permanent_key: KeyPair, tasks: Vec, ping_interval: Option, + server_public_permanent_key: Option, } impl SaltyClientBuilder { @@ -158,6 +159,7 @@ impl SaltyClientBuilder { permanent_key, tasks: vec![], ping_interval: None, + server_public_permanent_key: None, } } @@ -170,6 +172,13 @@ impl SaltyClientBuilder { self } + /// Specify the server public permanent key if you want to use server key + /// pinning. + pub fn with_server_key(mut self, server_public_permanent_key: PublicKey) -> Self { + self.server_public_permanent_key = Some(server_public_permanent_key); + self + } + /// Request that the server sends a WebSocket ping message at the specified interval. /// /// Set the `interval` argument to `None` or to a zero duration to disable intervals. @@ -190,6 +199,7 @@ impl SaltyClientBuilder { self.permanent_key, tasks, None, + self.server_public_permanent_key, self.ping_interval, ); Ok(SaltyClient { @@ -204,6 +214,7 @@ impl SaltyClientBuilder { self.permanent_key, tasks, Some(responder_trusted_pubkey), + self.server_public_permanent_key, self.ping_interval, ); Ok(SaltyClient { @@ -218,6 +229,7 @@ impl SaltyClientBuilder { self.permanent_key, initiator_pubkey, Some(auth_token), + self.server_public_permanent_key, tasks, self.ping_interval, ); @@ -233,6 +245,7 @@ impl SaltyClientBuilder { self.permanent_key, initiator_trusted_pubkey, None, + self.server_public_permanent_key, tasks, self.ping_interval, ); diff --git a/src/protocol/context.rs b/src/protocol/context.rs index 3209f9e..79794cf 100644 --- a/src/protocol/context.rs +++ b/src/protocol/context.rs @@ -54,6 +54,7 @@ pub(crate) struct ServerContext { } impl ServerContext { + /// Create a new `ServerContext` instance. pub fn new() -> Self { ServerContext { handshake_state: ServerHandshakeState::New, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index bd0831f..c194dd0 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1124,6 +1124,7 @@ impl InitiatorSignaling { pub(crate) fn new(permanent_keypair: KeyPair, tasks: Tasks, responder_trusted_pubkey: Option, + server_public_permanent_key: Option, ping_interval: Option) -> Self { InitiatorSignaling { common: Common { @@ -1135,7 +1136,11 @@ impl InitiatorSignaling { Some(key) => AuthProvider::TrustedKey(key), None => AuthProvider::Token(AuthToken::new()), }), - server: ServerContext::new(), + server: { + let mut ctx = ServerContext::new(); + ctx.permanent_key = server_public_permanent_key; + ctx + }, tasks: Some(tasks), task: None, task_supported_types: None, @@ -1660,6 +1665,7 @@ impl ResponderSignaling { pub(crate) fn new(permanent_keypair: KeyPair, initiator_pubkey: PublicKey, auth_token: Option, + server_public_permanent_key: Option, tasks: Tasks, ping_interval: Option) -> Self { ResponderSignaling { @@ -1672,7 +1678,11 @@ impl ResponderSignaling { Some(token) => AuthProvider::Token(token), None => AuthProvider::TrustedKey(initiator_pubkey), }), - server: ServerContext::new(), + server: { + let mut ctx = ServerContext::new(); + ctx.permanent_key = server_public_permanent_key; + ctx + }, tasks: Some(tasks), task: None, task_supported_types: None, diff --git a/src/protocol/tests/signaling_messages.rs b/src/protocol/tests/signaling_messages.rs index a6df7f8..bb40bfd 100644 --- a/src/protocol/tests/signaling_messages.rs +++ b/src/protocol/tests/signaling_messages.rs @@ -31,7 +31,7 @@ impl TestContext { let server_cookie = Cookie::random(); let ks = KeyPair::from_private_key(our_ks.private_key().clone()); let tasks = Tasks::new(Box::new(DummyTask::new(42))); - let mut signaling = InitiatorSignaling::new(ks, tasks, peer_trusted_pubkey, None); + let mut signaling = InitiatorSignaling::new(ks, tasks, peer_trusted_pubkey, None, None); signaling.common_mut().identity = identity; signaling.server_mut().set_handshake_state(server_handshake_state); signaling.server_mut().cookie_pair = CookiePair { @@ -70,7 +70,7 @@ impl TestContext { let ks = KeyPair::from_private_key(our_ks.private_key().clone()); let mut tasks = Tasks::new(Box::new(DummyTask::new(23))); tasks.add_task(Box::new(DummyTask::new(42))).unwrap(); - ResponderSignaling::new(ks, pk, auth_token, tasks, None) + ResponderSignaling::new(ks, pk, auth_token, None, tasks, None) }; signaling.common_mut().identity = identity; signaling.server_mut().set_handshake_state(server_handshake_state); @@ -408,6 +408,7 @@ mod client_auth { kp, Tasks::new(Box::new(DummyTask::new(123))), None, + None, interval, ); diff --git a/src/protocol/tests/validate_nonce.rs b/src/protocol/tests/validate_nonce.rs index 72b0293..b51800b 100644 --- a/src/protocol/tests/validate_nonce.rs +++ b/src/protocol/tests/validate_nonce.rs @@ -10,7 +10,7 @@ use super::*; #[test] fn first_message_wrong_destination() { let ks = KeyPair::new(); - let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None); + let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None, None); let msg = ServerHello::random().into_message(); let cs = CombinedSequenceSnapshot::random(); @@ -34,7 +34,7 @@ fn first_message_wrong_destination() { #[test] fn wrong_source_initiator() { let ks = KeyPair::new(); - let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None); + let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None, None); let make_msg = |src: u8, dest: u8| { let msg = ServerHello::random().into_message(); @@ -74,7 +74,7 @@ fn wrong_source_initiator() { fn wrong_source_responder() { let ks = KeyPair::new(); let initiator_pubkey = PublicKey::from_slice(&[0u8; 32]).unwrap(); - let mut s = ResponderSignaling::new(ks, initiator_pubkey, None, Tasks(vec![]), None); + let mut s = ResponderSignaling::new(ks, initiator_pubkey, None, None, Tasks(vec![]), None); let make_msg = |src: u8, dest: u8| { let msg = ServerHello::random().into_message(); @@ -111,7 +111,7 @@ fn wrong_source_responder() { #[test] fn first_message_bad_overflow_number() { let ks = KeyPair::new(); - let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None); + let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None, None); let msg = ServerHello::random().into_message(); let cs = CombinedSequenceSnapshot::new(1, 1234); @@ -132,7 +132,7 @@ fn _test_sequence_number(first: CombinedSequenceSnapshot, second: CombinedSequenceSnapshot) -> SignalingResult> { let ks = KeyPair::new(); - let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None); + let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None, None); // Process ServerHello let msg = ServerHello::random().into_message(); @@ -191,7 +191,7 @@ fn sequence_number_reset() { #[test] fn cookie_differs_from_own() { let ks = KeyPair::new(); - let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None); + let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None, None); let msg = ServerHello::random().into_message(); let cookie = s.server().cookie_pair.ours.clone(); @@ -213,7 +213,7 @@ fn cookie_differs_from_own() { fn cookie_did_not_change() { // Create new signaling instance let ks = KeyPair::new(); - let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None); + let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None, None, None); // Prepare 'server-hello' message let msg = ServerHello::random().into_message();