Skip to content

Commit

Permalink
Support setting ping interval (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbrgn committed Dec 19, 2017
1 parent 873b329 commit b38cf16
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 15 deletions.
2 changes: 2 additions & 0 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::io::{Read, Write};
use std::path::Path;
use std::process;
use std::rc::Rc;
use std::time::Duration;

use chrono::Local;
use clap::{Arg, App, SubCommand};
Expand Down Expand Up @@ -136,6 +137,7 @@ fn main() {
let task = ChatTask::new("initiat0r");
let salty = SaltyClientBuilder::new(keypair)
.add_task(Box::new(task))
.with_ping_interval(Some(Duration::from_secs(30)))
.initiator()
.expect("Could not create SaltyClient instance");
let auth_token_hex = HEXLOWER.encode(salty.auth_token().unwrap().secret_key_bytes());
Expand Down
23 changes: 22 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ mod test_helpers;
use std::cell::RefCell;
use std::ops::Deref;
use std::rc::Rc;
use std::time::Duration;

// Third party imports
use native_tls::TlsConnector;
Expand Down Expand Up @@ -94,6 +95,7 @@ macro_rules! boxed {
pub struct SaltyClientBuilder {
permanent_key: KeyPair,
tasks: Vec<Box<Task>>,
ping_interval: Option<Duration>,
}

impl SaltyClientBuilder {
Expand All @@ -102,6 +104,7 @@ impl SaltyClientBuilder {
SaltyClientBuilder {
permanent_key,
tasks: vec![],
ping_interval: None,
}
}

Expand All @@ -114,10 +117,27 @@ impl SaltyClientBuilder {
self
}

/// Request that the server sends a WebSocket ping message at the specified interval.
///
/// Set the `interval` argument to `None` or to a zero duration to disable intervals.
///
/// Note: Fractions of seconds are ignored, so if you set the duration to 13.37s,
/// then the ping interval 13s will be requested.
///
/// By default, ping messages are disabled.
pub fn with_ping_interval(mut self, interval: Option<Duration>) -> Self {
self.ping_interval = interval;
self
}

/// Create a new SaltyRTC initiator.
pub fn initiator(self) -> Result<SaltyClient, BuilderError> {
let tasks = Tasks::from_vec(self.tasks).map_err(|_| BuilderError::MissingTask)?;
let signaling = InitiatorSignaling::new(self.permanent_key, tasks);
let signaling = InitiatorSignaling::new(
self.permanent_key,
tasks,
self.ping_interval,
);
Ok(SaltyClient {
signaling: Box::new(signaling),
})
Expand All @@ -131,6 +151,7 @@ impl SaltyClientBuilder {
initiator_pubkey,
auth_token,
tasks,
self.ping_interval,
);
Ok(SaltyClient {
signaling: Box::new(signaling),
Expand Down
27 changes: 22 additions & 5 deletions src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

use std::collections::{HashMap, HashSet};
use std::mem;
use std::time::Duration;

use boxes::{ByteBox, OpenBox};
use crypto::{KeyPair, AuthToken, PublicKey};
Expand Down Expand Up @@ -409,10 +410,19 @@ pub(crate) trait Signaling {
}

// Send client-auth message
let ping_interval = self.common()
.ping_interval
.map(|duration| duration.as_secs())
.map(|secs| if secs > ::std::u32::MAX as u64 {
::std::u32::MAX
} else {
secs as u32
})
.unwrap_or(0u32);
let client_auth = ClientAuth {
your_cookie: self.server().cookie_pair().theirs.clone().unwrap(),
subprotocols: vec![::SUBPROTOCOL.into()],
ping_interval: 0, // TODO (#11)
ping_interval,
your_key: None, // TODO (#12)
}.into_message();
let client_auth_nonce = Nonce::new(
Expand Down Expand Up @@ -512,11 +522,14 @@ pub(crate) struct Common {
/// The server context.
pub(crate) server: ServerContext,

/// The list of possible task instances
/// The list of possible task instances.
pub(crate) tasks: Option<Tasks>,

/// The chosen task
/// The chosen task.
pub(crate) task: Option<Box<Task>>,

/// The interval at which the server should send WebSocket ping messages.
pub(crate) ping_interval: Option<Duration>,
}

impl Common {
Expand Down Expand Up @@ -805,7 +818,8 @@ impl Signaling for InitiatorSignaling {

impl InitiatorSignaling {
pub(crate) fn new(permanent_keypair: KeyPair,
tasks: Tasks) -> Self {
tasks: Tasks,
ping_interval: Option<Duration>) -> Self {
InitiatorSignaling {
common: Common {
signaling_state: SignalingState::ServerHandshake,
Expand All @@ -816,6 +830,7 @@ impl InitiatorSignaling {
server: ServerContext::new(),
tasks: Some(tasks),
task: None,
ping_interval: ping_interval,
},
responders: HashMap::new(),
responder: None,
Expand Down Expand Up @@ -1226,7 +1241,8 @@ impl ResponderSignaling {
pub(crate) fn new(permanent_keypair: KeyPair,
initiator_pubkey: PublicKey,
auth_token: Option<AuthToken>,
tasks: Tasks) -> Self {
tasks: Tasks,
ping_interval: Option<Duration>) -> Self {
ResponderSignaling {
common: Common {
signaling_state: SignalingState::ServerHandshake,
Expand All @@ -1237,6 +1253,7 @@ impl ResponderSignaling {
server: ServerContext::new(),
tasks: Some(tasks),
task: None,
ping_interval: ping_interval,
},
initiator: InitiatorContext::new(initiator_pubkey),
}
Expand Down
72 changes: 69 additions & 3 deletions src/protocol/tests/signaling_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl TestContext<InitiatorSignaling> {
let server_cookie = Cookie::random();
let ks = KeyPair::from_private_key(our_ks.private_key().clone());
let tasks = Tasks::new(Box::new(DummyTask::new(42)));
let mut signaling = InitiatorSignaling::new(ks, tasks);
let mut signaling = InitiatorSignaling::new(ks, tasks, None);
signaling.common_mut().identity = identity;
signaling.server_mut().set_handshake_state(server_handshake_state);
signaling.server_mut().cookie_pair = CookiePair {
Expand Down Expand Up @@ -68,7 +68,7 @@ impl TestContext<ResponderSignaling> {
let ks = KeyPair::from_private_key(our_ks.private_key().clone());
let mut tasks = Tasks::new(Box::new(DummyTask::new(23)));
tasks.add_task(Box::new(DummyTask::new(42)));
ResponderSignaling::new(ks, pk, auth_token, tasks)
ResponderSignaling::new(ks, pk, auth_token, tasks, None)
};
signaling.common_mut().identity = identity;
signaling.server_mut().set_handshake_state(server_handshake_state);
Expand Down Expand Up @@ -393,11 +393,77 @@ mod server_auth {
}
}

mod client_auth {
use super::*;

fn _test_ping_interval(interval: Option<Duration>) -> ClientAuth {
let kp = KeyPair::new();
let mut s = InitiatorSignaling::new(kp, Tasks::new(Box::new(DummyTask::new(123))), interval);

// Create and encode ServerHello message
let server_pubkey = PublicKey::random();
let server_hello = ServerHello::new(server_pubkey.clone()).into_message();
let cs = CombinedSequenceSnapshot::random();
let nonce = Nonce::new(Cookie::random(), Address(0), Address(0), cs);
let obox = OpenBox::<Message>::new(server_hello, nonce);
let bbox = obox.encode();

// Handle message
assert_eq!(s.server().handshake_state(), ServerHandshakeState::New);
let mut actions = s.handle_message(bbox).unwrap();
assert_eq!(s.server().handshake_state(), ServerHandshakeState::ClientInfoSent);
assert_eq!(actions.len(), 1); // Reply with client-auth

// Action contains ClientAuth message, encrypted with our permanent key
// and the server session key. Decrypt it to take a look at its contents.
let action = actions.remove(0);
let bytes: ByteBox = match action {
HandleAction::Reply(bbox) => bbox,
};

let decrypted = OpenBox::<Message>::decrypt(
bytes, &s.common().permanent_keypair, &server_pubkey
).unwrap();
match decrypted.message {
Message::ClientAuth(client_auth) => client_auth,
other => panic!("Expected ClientAuth, got {:?}", other)
}
}

/// If ping interval is None, send zero.
#[test]
fn ping_interval_none() {
let client_auth = _test_ping_interval(None);
assert_eq!(client_auth.ping_interval, 0);
}

/// If ping interval is 0s, send zero.
#[test]
fn ping_interval_zero() {
let client_auth = _test_ping_interval(Some(Duration::from_secs(0)));
assert_eq!(client_auth.ping_interval, 0);
}

/// If ping interval is a larger number, send that (as seconds).
#[test]
fn ping_interval_12345() {
let client_auth = _test_ping_interval(Some(Duration::from_secs(12345)));
assert_eq!(client_auth.ping_interval, 12345);
}

/// Ignore sub-second values.
#[test]
fn ping_interval_nanos() {
let client_auth = _test_ping_interval(Some(Duration::new(123, 45)));
assert_eq!(client_auth.ping_interval, 123);
}
}

mod token {
use super::*;

/// A receiving initiator MUST check that the message contains a valid NaCl
/// public key (32 bytes) in the key field.
/// public key (32 bytes) in the key field.
#[test]
fn token_initiator_validate_public_key() {
let mut ctx = TestContext::initiator(
Expand Down
12 changes: 6 additions & 6 deletions src/protocol/tests/validate_nonce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::*;
#[test]
fn first_message_wrong_destination() {
let ks = KeyPair::new();
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]));
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None);

let msg = ServerHello::random().into_message();
let cs = CombinedSequenceSnapshot::random();
Expand All @@ -34,7 +34,7 @@ fn first_message_wrong_destination() {
#[test]
fn wrong_source_initiator() {
let ks = KeyPair::new();
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]));
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None);

let make_msg = |src: u8, dest: u8| {
let msg = ServerHello::random().into_message();
Expand Down Expand Up @@ -74,7 +74,7 @@ fn wrong_source_initiator() {
fn wrong_source_responder() {
let ks = KeyPair::new();
let initiator_pubkey = PublicKey::from_slice(&[0u8; 32]).unwrap();
let mut s = ResponderSignaling::new(ks, initiator_pubkey, None, Tasks(vec![]));
let mut s = ResponderSignaling::new(ks, initiator_pubkey, None, Tasks(vec![]), None);

let make_msg = |src: u8, dest: u8| {
let msg = ServerHello::random().into_message();
Expand Down Expand Up @@ -111,7 +111,7 @@ fn wrong_source_responder() {
#[test]
fn first_message_bad_overflow_number() {
let ks = KeyPair::new();
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]));
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None);

let msg = ServerHello::random().into_message();
let cs = CombinedSequenceSnapshot::new(1, 1234);
Expand All @@ -132,7 +132,7 @@ fn _test_sequence_number(first: CombinedSequenceSnapshot,
second: CombinedSequenceSnapshot)
-> SignalingResult<Vec<HandleAction>> {
let ks = KeyPair::new();
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]));
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None);

// Process ServerHello
let msg = ServerHello::random().into_message();
Expand Down Expand Up @@ -191,7 +191,7 @@ fn sequence_number_reset() {
#[test]
fn cookie_differs_from_own() {
let ks = KeyPair::new();
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]));
let mut s = InitiatorSignaling::new(ks, Tasks(vec![]), None);

let msg = ServerHello::random().into_message();
let cookie = s.server().cookie_pair.ours.clone();
Expand Down

0 comments on commit b38cf16

Please sign in to comment.