diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000000..5deae4867d --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,8 @@ +hard_tabs = true +array_layout = "Block" +fn_args_layout = "Block" +chain_indent = "Visual" +chain_one_line_max = 100 +take_source_hints = true +write_mode = "Overwrite" + diff --git a/.travis.yml b/.travis.yml index 65fe2cf2c7..f47b1f4114 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,12 @@ language: rust rust: nightly +cache: cargo +before_script: + - export PATH="$PATH:$HOME/.cargo/bin" + - which rustfmt || cargo install rustfmt script: + - cargo fmt -- --write-mode=diff - cargo build --features nightly - cargo test --features nightly - cargo bench --features nightly @@ -12,36 +17,36 @@ after_success: - > [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && sudo pip install ghp-import - > - [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Running Autobahn TestSuite for client" ; - wstest -m fuzzingserver -s ./autobahn/fuzzingserver.json & FUZZINGSERVER_PID=$! ; - sleep 10 ; - ./target/debug/examples/autobahn-client ; + [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { + echo "Running Autobahn TestSuite for client" ; + wstest -m fuzzingserver -s ./autobahn/fuzzingserver.json & FUZZINGSERVER_PID=$! ; + sleep 10 ; + ./target/debug/examples/autobahn-client ; kill -9 ${FUZZINGSERVER_PID} ; } - > - [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Running Autobahn TestSuite for server" ; - ./target/debug/examples/autobahn-server & WSSERVER_PID=$! ; - sleep 10 ; - wstest -m fuzzingclient -s ./autobahn/fuzzingclient.json ; + [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { + echo "Running Autobahn TestSuite for server" ; + ./target/debug/examples/autobahn-server & WSSERVER_PID=$! ; + sleep 10 ; + wstest -m fuzzingclient -s ./autobahn/fuzzingclient.json ; kill -9 ${WSSERVER_PID} ; } - > [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Building docs and gh-pages" ; - PROJECT_VERSION=$(cargo doc --features nightly | grep "Documenting websocket v" | sed 's/.*Documenting websocket v\(.*\) .*/\1/') ; - curl -sL https://github.com/${TRAVIS_REPO_SLUG}/archive/html.tar.gz | tar xz ; - cd ./rust-websocket-html && - find . -type f | xargs sed -i 's//'"${PROJECT_VERSION}"'/g' ; - mv ../target/doc ./doc ; - mv ../autobahn/server ./autobahn/server ; - mv ../autobahn/client ./autobahn/client ; - mv ./autobahn/server/index.json ./autobahn/server/index.temp && rm ./autobahn/server/*.json && mv ./autobahn/server/index.temp ./autobahn/server/index.json ; + echo "Building docs and gh-pages" ; + PROJECT_VERSION=$(cargo doc --features nightly | grep "Documenting websocket v" | sed 's/.*Documenting websocket v\(.*\) .*/\1/') ; + curl -sL https://github.com/${TRAVIS_REPO_SLUG}/archive/html.tar.gz | tar xz ; + cd ./rust-websocket-html && + find . -type f | xargs sed -i 's//'"${PROJECT_VERSION}"'/g' ; + mv ../target/doc ./doc ; + mv ../autobahn/server ./autobahn/server ; + mv ../autobahn/client ./autobahn/client ; + mv ./autobahn/server/index.json ./autobahn/server/index.temp && rm ./autobahn/server/*.json && mv ./autobahn/server/index.temp ./autobahn/server/index.json ; mv ./autobahn/client/index.json ./autobahn/client/index.temp && rm ./autobahn/client/*.json && mv ./autobahn/client/index.temp ./autobahn/client/index.json ; cd ../ ; } - > [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Pushing gh-pages" ; - ghp-import -n ./rust-websocket-html -m "Generated by Travis CI build ${TRAVIS_BUILD_NUMBER} for commit ${TRAVIS_COMMIT}" && + echo "Pushing gh-pages" ; + ghp-import -n ./rust-websocket-html -m "Generated by Travis CI build ${TRAVIS_BUILD_NUMBER} for commit ${TRAVIS_COMMIT}" && git push -fq https://${TOKEN}@github.com/${TRAVIS_REPO_SLUG}.git gh-pages ; } env: diff --git a/Cargo.toml b/Cargo.toml index 884b3d0b8c..8d61955c62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "websocket" -version = "0.17.2" -authors = ["cyderize "] +version = "0.18.0" +authors = ["cyderize ", "Michael Eden "] description = "A WebSocket (RFC6455) library for Rust." @@ -17,15 +17,17 @@ keywords = ["websocket", "websockets", "rfc6455"] license = "MIT" [dependencies] -hyper = ">=0.7, <0.11" -unicase = "1.0.1" -openssl = "0.7.6" -url = "1.0" -rustc-serialize = "0.3.16" -bitflags = "0.7" -rand = "0.3.12" -byteorder = "1.0" -net2 = "0.2.17" +hyper = { git = "https://github.com/hyperium/hyper.git", branch = "0.10.x" } +unicase = "^1.0" +url = "^1.0" +rustc-serialize = "^0.3" +bitflags = "^0.8" +rand = "^0.3" +byteorder = "^1.0" +sha1 = "^0.2" +openssl = { version = "^0.9.10", optional = true } [features] +default = ["ssl"] +ssl = ["openssl"] nightly = ["hyper/nightly"] diff --git a/README.md b/README.md index 1d965bc859..1d6702f0a1 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Rust-WebSocket provides a framework for dealing with WebSocket connections (both To add a library release version from [crates.io](https://crates.io/crates/websocket) to a Cargo project, add this to the 'dependencies' section of your Cargo.toml: ```INI -websocket = "0.17.1" +websocket = "0.18.0" ``` To add the library's Git repository to a Cargo project, add this to your Cargo.toml: diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000000..afd557c319 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,41 @@ +# The Roadmap + +## More Docs, Examples and Tests + +Easy as that, every method should be tested and documented. +Every use-case should have an example. + +## Adding Features + +### `net2` Feature + +This is a feature to add the `net2` crate which will let us do cool things +like set the option `SO_REUSEADDR` and similar when making TCP connections. + +This is discussed in [vi/rust-websocket#2](https://github.com/vi/rust-websocket/pull/2). + +### Add Mio & Tokio (Evented Websocket) + +There are a lot of issues that would be solved if this was evented, such as: + + - [#88 tokio support](https://github.com/cyderize/rust-websocket/issues/88) + - [#66 Timeout on recv_message](https://github.com/cyderize/rust-websocket/issues/66) + - [#6 one client, one thread?](https://github.com/cyderize/rust-websocket/issues/6) + +So maybe we should _just_ add `tokio` support, or maybe `mio` is still used and popular. + +### Support Permessage-Deflate + +We need this to pass more autobahn tests! + +### Buffer Reads and Writes + +In the old crate the stream was split up into a reader and writer stream so you could +have both a `BufReader` and a `BufWriter` to buffer your operations to gain some speed. +However is doesn't make sense to split the stream up anymore +(see [#83](https://github.com/cyderize/rust-websocket/issues/83)) +meaning that we should buffer reads and writes in some other way. + +Some work has begun on this, like [#91](https://github.com/cyderize/rust-websocket/pull/91), +but is this enough? And what about writing? + diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 77423c7b47..2519e79f1e 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -2,8 +2,8 @@ extern crate websocket; extern crate rustc_serialize as serialize; use std::str::from_utf8; -use websocket::client::request::Url; -use websocket::{Client, Message, Sender, Receiver}; +use websocket::ClientBuilder; +use websocket::Message; use websocket::message::Type; use serialize::json; @@ -20,22 +20,18 @@ fn main() { let case_count = get_case_count(addr.clone()); while current_case_id <= case_count { - let url = addr.clone() + "/runCase?case=" + ¤t_case_id.to_string()[..] + "&agent=" + agent; + let case_id = current_case_id; + current_case_id += 1; + let url = addr.clone() + "/runCase?case=" + &case_id.to_string()[..] + "&agent=" + agent; - let ws_uri = Url::parse(&url[..]).unwrap(); - let request = Client::connect(ws_uri).unwrap(); - let response = request.send().unwrap(); - match response.validate() { - Ok(()) => (), - Err(e) => { - println!("{:?}", e); - current_case_id += 1; - continue; - } - } - let (mut sender, mut receiver) = response.begin().split(); + let client = ClientBuilder::new(&url) + .unwrap() + .connect_insecure() + .unwrap(); - println!("Executing test case: {}/{}", current_case_id, case_count); + let (mut receiver, mut sender) = client.split().unwrap(); + + println!("Executing test case: {}/{}", case_id, case_count); for message in receiver.incoming_messages() { let message: Message = match message { @@ -49,7 +45,7 @@ fn main() { match message.opcode { Type::Text => { - let response = Message::text(from_utf8(&*message.payload).unwrap()); + let response = Message::text(from_utf8(&*message.payload).unwrap()); sender.send_message(&response).unwrap(); } Type::Binary => { @@ -65,8 +61,6 @@ fn main() { _ => (), } } - - current_case_id += 1; } update_reports(addr.clone(), agent); @@ -74,17 +68,16 @@ fn main() { fn get_case_count(addr: String) -> usize { let url = addr + "/getCaseCount"; - let ws_uri = Url::parse(&url[..]).unwrap(); - let request = Client::connect(ws_uri).unwrap(); - let response = request.send().unwrap(); - match response.validate() { - Ok(()) => (), + + let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { + Ok(c) => c, Err(e) => { println!("{:?}", e); return 0; } - } - let (mut sender, mut receiver) = response.begin().split(); + }; + + let (mut receiver, mut sender) = client.split().unwrap(); let mut count = 0; @@ -93,7 +86,8 @@ fn get_case_count(addr: String) -> usize { Ok(message) => message, Err(e) => { println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close_because(1002, "".to_string())); + let _ = + sender.send_message(&Message::close_because(1002, "".to_string())); break; } }; @@ -118,17 +112,16 @@ fn get_case_count(addr: String) -> usize { fn update_reports(addr: String, agent: &str) { let url = addr + "/updateReports?agent=" + agent; - let ws_uri = Url::parse(&url[..]).unwrap(); - let request = Client::connect(ws_uri).unwrap(); - let response = request.send().unwrap(); - match response.validate() { - Ok(()) => (), + + let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { + Ok(c) => c, Err(e) => { println!("{:?}", e); return; } - } - let (mut sender, mut receiver) = response.begin().split(); + }; + + let (mut receiver, mut sender) = client.split().unwrap(); println!("Updating reports..."); diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 6e4885a0fa..f9a8d959e9 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -2,20 +2,17 @@ extern crate websocket; use std::thread; use std::str::from_utf8; -use websocket::{Server, Message, Sender, Receiver}; +use websocket::{Server, Message}; use websocket::message::Type; fn main() { - let addr = "127.0.0.1:9002".to_string(); - - let server = Server::bind(&addr[..]).unwrap(); + let server = Server::bind("127.0.0.1:9002").unwrap(); for connection in server { thread::spawn(move || { - let request = connection.unwrap().read_request().unwrap(); - request.validate().unwrap(); - let response = request.accept(); - let (mut sender, mut receiver) = response.send().unwrap().split(); + let client = connection.accept().unwrap(); + + let (mut receiver, mut sender) = client.split().unwrap(); for message in receiver.incoming_messages() { let message: Message = match message { @@ -29,10 +26,12 @@ fn main() { match message.opcode { Type::Text => { - let response = Message::text(from_utf8(&*message.payload).unwrap()); - sender.send_message(&response).unwrap() - }, - Type::Binary => sender.send_message(&Message::binary(message.payload)).unwrap(), + let response = Message::text(from_utf8(&*message.payload).unwrap()); + sender.send_message(&response).unwrap() + } + Type::Binary => { + sender.send_message(&Message::binary(message.payload)).unwrap() + } Type::Close => { let _ = sender.send_message(&Message::close()); return; diff --git a/examples/client.rs b/examples/client.rs index 95e63552d3..62bc504708 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,30 +1,26 @@ extern crate websocket; +const CONNECTION: &'static str = "ws://127.0.0.1:2794"; + fn main() { use std::thread; use std::sync::mpsc::channel; use std::io::stdin; - use websocket::{Message, Sender, Receiver}; - use websocket::message::Type; - use websocket::client::request::Url; - use websocket::Client; - - let url = Url::parse("ws://127.0.0.1:2794").unwrap(); - - println!("Connecting to {}", url); - - let request = Client::connect(url).unwrap(); + use websocket::Message; + use websocket::message::Type; + use websocket::client::ClientBuilder; - let response = request.send().unwrap(); // Send the request and retrieve a response + println!("Connecting to {}", CONNECTION); - println!("Validating response..."); - - response.validate().unwrap(); // Validate the response + let client = ClientBuilder::new(CONNECTION) + .unwrap() + .connect_insecure() + .unwrap(); println!("Successfully connected"); - let (mut sender, mut receiver) = response.begin().split(); + let (mut receiver, mut sender) = client.split().unwrap(); let (tx, rx) = channel(); @@ -45,7 +41,7 @@ fn main() { let _ = sender.send_message(&message); // If it's a close message, just send it and then return. return; - }, + } _ => (), } // Send the message @@ -77,14 +73,16 @@ fn main() { let _ = tx_1.send(Message::close()); return; } - Type::Ping => match tx_1.send(Message::pong(message.payload)) { - // Send a pong in response - Ok(()) => (), - Err(e) => { - println!("Receive Loop: {:?}", e); - return; + Type::Ping => { + match tx_1.send(Message::pong(message.payload)) { + // Send a pong in response + Ok(()) => (), + Err(e) => { + println!("Receive Loop: {:?}", e); + return; + } } - }, + } // Say what we received _ => println!("Receive Loop: {:?}", message), } diff --git a/examples/hyper.rs b/examples/hyper.rs index 640bc1c6a4..f5b8b21248 100644 --- a/examples/hyper.rs +++ b/examples/hyper.rs @@ -3,29 +3,29 @@ extern crate hyper; use std::thread; use std::io::Write; -use websocket::{Server, Message, Sender, Receiver}; -use websocket::header::WebSocketProtocol; +use websocket::{Server, Message}; use websocket::message::Type; use hyper::Server as HttpServer; -use hyper::server::Handler; use hyper::net::Fresh; use hyper::server::request::Request; use hyper::server::response::Response; +const HTML: &'static str = include_str!("websockets.html"); + // The HTTP server handler fn http_handler(_: Request, response: Response) { let mut response = response.start().unwrap(); // Send a client webpage - response.write_all(b"WebSocket Test

Received Messages:

").unwrap(); + response.write_all(HTML.as_bytes()).unwrap(); response.end().unwrap(); } fn main() { // Start listening for http connections thread::spawn(move || { - let http_server = HttpServer::http("127.0.0.1:8080").unwrap(); - http_server.handle(http_handler).unwrap(); - }); + let http_server = HttpServer::http("127.0.0.1:8080").unwrap(); + http_server.handle(http_handler).unwrap(); + }); // Start listening for WebSocket connections let ws_server = Server::bind("127.0.0.1:2794").unwrap(); @@ -33,33 +33,21 @@ fn main() { for connection in ws_server { // Spawn a new thread for each connection. thread::spawn(move || { - let request = connection.unwrap().read_request().unwrap(); // Get the request - let headers = request.headers.clone(); // Keep the headers so we can check them - - request.validate().unwrap(); // Validate the request - - let mut response = request.accept(); // Form a response - - if let Some(&WebSocketProtocol(ref protocols)) = headers.get() { - if protocols.contains(&("rust-websocket".to_string())) { - // We have a protocol we want to use - response.headers.set(WebSocketProtocol(vec!["rust-websocket".to_string()])); - } + if !connection.protocols().contains(&"rust-websocket".to_string()) { + connection.reject().unwrap(); + return; } - let mut client = response.send().unwrap(); // Send the response + let mut client = connection.use_protocol("rust-websocket").accept().unwrap(); - let ip = client.get_mut_sender() - .get_mut() - .peer_addr() - .unwrap(); + let ip = client.peer_addr().unwrap(); println!("Connection from {}", ip); let message = Message::text("Hello".to_string()); client.send_message(&message).unwrap(); - let (mut sender, mut receiver) = client.split(); + let (mut receiver, mut sender) = client.split().unwrap(); for message in receiver.incoming_messages() { let message: Message = message.unwrap(); @@ -70,11 +58,11 @@ fn main() { sender.send_message(&message).unwrap(); println!("Client {} disconnected", ip); return; - }, + } Type::Ping => { let message = Message::pong(message.payload); sender.send_message(&message).unwrap(); - }, + } _ => sender.send_message(&message).unwrap(), } } diff --git a/examples/server.rs b/examples/server.rs index 1bfef5c090..dd11f5a098 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,43 +1,30 @@ extern crate websocket; use std::thread; -use websocket::{Server, Message, Sender, Receiver}; +use websocket::{Server, Message}; use websocket::message::Type; -use websocket::header::WebSocketProtocol; fn main() { let server = Server::bind("127.0.0.1:2794").unwrap(); - for connection in server { + for request in server { // Spawn a new thread for each connection. thread::spawn(move || { - let request = connection.unwrap().read_request().unwrap(); // Get the request - let headers = request.headers.clone(); // Keep the headers so we can check them - - request.validate().unwrap(); // Validate the request - - let mut response = request.accept(); // Form a response - - if let Some(&WebSocketProtocol(ref protocols)) = headers.get() { - if protocols.contains(&("rust-websocket".to_string())) { - // We have a protocol we want to use - response.headers.set(WebSocketProtocol(vec!["rust-websocket".to_string()])); - } + if !request.protocols().contains(&"rust-websocket".to_string()) { + request.reject().unwrap(); + return; } - let mut client = response.send().unwrap(); // Send the response + let mut client = request.use_protocol("rust-websocket").accept().unwrap(); - let ip = client.get_mut_sender() - .get_mut() - .peer_addr() - .unwrap(); + let ip = client.peer_addr().unwrap(); println!("Connection from {}", ip); let message: Message = Message::text("Hello".to_string()); client.send_message(&message).unwrap(); - let (mut sender, mut receiver) = client.split(); + let (mut receiver, mut sender) = client.split().unwrap(); for message in receiver.incoming_messages() { let message: Message = message.unwrap(); @@ -48,7 +35,7 @@ fn main() { sender.send_message(&message).unwrap(); println!("Client {} disconnected", ip); return; - }, + } Type::Ping => { let message = Message::pong(message.payload); sender.send_message(&message).unwrap(); diff --git a/src/client/builder.rs b/src/client/builder.rs new file mode 100644 index 0000000000..b16e41a36a --- /dev/null +++ b/src/client/builder.rs @@ -0,0 +1,578 @@ +//! Everything you need to create a client connection to a websocket. + +use std::borrow::Cow; +use std::net::TcpStream; +pub use url::{Url, ParseError}; +use url::Position; +use hyper::version::HttpVersion; +use hyper::status::StatusCode; +use hyper::buffer::BufReader; +use hyper::http::h1::parse_response; +use hyper::header::{Headers, Header, HeaderFormat, Host, Connection, ConnectionOption, Upgrade, + Protocol, ProtocolName}; +use unicase::UniCase; +#[cfg(feature="ssl")] +use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder}; +use header::extensions::Extension; +use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, + WebSocketExtensions, Origin}; +use result::{WSUrlErrorKind, WebSocketResult, WebSocketError}; +#[cfg(feature="ssl")] +use stream::NetworkStream; +use stream::Stream; +use super::Client; + +/// Build clients with a builder-style API +/// This makes it easy to create and configure a websocket +/// connection: +/// +/// The easiest way to connect is like this: +/// +/// ```rust,no_run +/// use websocket::ClientBuilder; +/// +/// let client = ClientBuilder::new("ws://myapp.com") +/// .unwrap() +/// .connect_insecure() +/// .unwrap(); +/// ``` +/// +/// But there are so many more possibilities: +/// +/// ```rust,no_run +/// use websocket::ClientBuilder; +/// use websocket::header::{Headers, Cookie}; +/// +/// let default_protos = vec!["ping", "chat"]; +/// let mut my_headers = Headers::new(); +/// my_headers.set(Cookie(vec!["userid=1".to_owned()])); +/// +/// let mut builder = ClientBuilder::new("ws://myapp.com/room/discussion") +/// .unwrap() +/// .add_protocols(default_protos) // any IntoIterator +/// .add_protocol("video-chat") +/// .custom_headers(&my_headers); +/// +/// // connect to a chat server with a user +/// let client = builder.connect_insecure().unwrap(); +/// +/// // clone the builder and take it with you +/// let not_logged_in = builder +/// .clone() +/// .clear_header::() +/// .connect_insecure().unwrap(); +/// ``` +/// +/// You may have noticed we're not using SSL, have no fear, SSL is included! +/// This crate's openssl dependency is optional (and included by default). +/// One can use `connect_secure` to connect to an SSL service, or simply `connect` +/// to choose either SSL or not based on the protocol (`ws://` or `wss://`). +#[derive(Clone, Debug)] +pub struct ClientBuilder<'u> { + url: Cow<'u, Url>, + version: HttpVersion, + headers: Headers, + version_set: bool, + key_set: bool, +} + +impl<'u> ClientBuilder<'u> { + /// Create a client builder from an already parsed Url, + /// because there is no need to parse this will never error. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// use websocket::url::Url; + /// + /// // the parsing error will be handled outside the constructor + /// let url = Url::parse("ws://bitcoins.pizza").unwrap(); + /// + /// let builder = ClientBuilder::from_url(&url); + /// ``` + /// The path of a URL is optional if no port is given then port + /// 80 will be used in the case of `ws://` and port `443` will be + /// used in the case of `wss://`. + pub fn from_url(address: &'u Url) -> Self { + ClientBuilder::init(Cow::Borrowed(address)) + } + + /// Create a client builder from a URL string, this will + /// attempt to parse the URL immediately and return a `ParseError` + /// if the URL is invalid. URLs must be of the form: + /// `[ws or wss]://[domain]:[port]/[path]` + /// The path of a URL is optional if no port is given then port + /// 80 will be used in the case of `ws://` and port `443` will be + /// used in the case of `wss://`. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// let builder = ClientBuilder::new("wss://mycluster.club"); + /// ``` + pub fn new(address: &str) -> Result { + let url = try!(Url::parse(address)); + Ok(ClientBuilder::init(Cow::Owned(url))) + } + + fn init(url: Cow<'u, Url>) -> Self { + ClientBuilder { + url: url, + version: HttpVersion::Http11, + version_set: false, + key_set: false, + headers: Headers::new(), + } + } + + /// Adds a user-defined protocol to the handshake, the server will be + /// given a list of these protocols and will send back the ones it accepts. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::WebSocketProtocol; + /// let builder = ClientBuilder::new("wss://my-twitch-clone.rs").unwrap() + /// .add_protocol("my-chat-proto"); + /// + /// let protos = &builder.get_header::().unwrap().0; + /// assert!(protos.contains(&"my-chat-proto".to_string())); + /// ``` + pub fn add_protocol

(mut self, protocol: P) -> Self + where P: Into + { + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.push(protocol.into()), + None => WebSocketProtocol(vec![protocol.into()]) + }); + self + } + + /// Adds a user-defined protocols to the handshake. + /// This can take many kinds of iterators. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::WebSocketProtocol; + /// let builder = ClientBuilder::new("wss://my-twitch-clone.rs").unwrap() + /// .add_protocols(vec!["pubsub", "sub.events"]); + /// + /// let protos = &builder.get_header::().unwrap().0; + /// assert!(protos.contains(&"pubsub".to_string())); + /// assert!(protos.contains(&"sub.events".to_string())); + /// ``` + pub fn add_protocols(mut self, protocols: I) -> Self + where I: IntoIterator, + S: Into + { + let mut protocols: Vec = + protocols.into_iter() + .map(Into::into) + .collect(); + + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.append(&mut protocols), + None => WebSocketProtocol(protocols) + }); + self + } + + /// Removes all the currently set protocols. + pub fn clear_protocols(mut self) -> Self { + self.headers.remove::(); + self + } + + /// Adds an extension to the connection. + /// Unlike protocols, extensions can be below the application level + /// (like compression). Currently no extensions are supported + /// out-of-the-box but one can still use them by using their own + /// implementation. Support is coming soon though. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::{WebSocketExtensions}; + /// # use websocket::header::extensions::Extension; + /// let builder = ClientBuilder::new("wss://skype-for-linux-lol.com").unwrap() + /// .add_extension(Extension { + /// name: "permessage-deflate".to_string(), + /// params: vec![], + /// }); + /// + /// let exts = &builder.get_header::().unwrap().0; + /// assert!(exts.first().unwrap().name == "permessage-deflate"); + /// ``` + pub fn add_extension(mut self, extension: Extension) -> Self { + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.push(extension), + None => WebSocketExtensions(vec![extension]) + }); + self + } + + /// Adds some extensions to the connection. + /// Currently no extensions are supported out-of-the-box but one can + /// still use them by using their own implementation. Support is coming soon though. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::{WebSocketExtensions}; + /// # use websocket::header::extensions::Extension; + /// let builder = ClientBuilder::new("wss://moxie-chat.org").unwrap() + /// .add_extensions(vec![ + /// Extension { + /// name: "permessage-deflate".to_string(), + /// params: vec![], + /// }, + /// Extension { + /// name: "crypt-omemo".to_string(), + /// params: vec![], + /// }, + /// ]); + /// + /// # let exts = &builder.get_header::().unwrap().0; + /// # assert!(exts.first().unwrap().name == "permessage-deflate"); + /// # assert!(exts.last().unwrap().name == "crypt-omemo"); + /// ``` + pub fn add_extensions(mut self, extensions: I) -> Self + where I: IntoIterator + { + let mut extensions: Vec = + extensions.into_iter().collect(); + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.append(&mut extensions), + None => WebSocketExtensions(extensions) + }); + self + } + + /// Remove all the extensions added to the builder. + pub fn clear_extensions(mut self) -> Self { + self.headers.remove::(); + self + } + + /// Add a custom `Sec-WebSocket-Key` header. + /// Use this only if you know what you're doing, and this almost + /// never has to be used. + pub fn key(mut self, key: [u8; 16]) -> Self { + self.headers.set(WebSocketKey(key)); + self.key_set = true; + self + } + + /// Remove the currently set `Sec-WebSocket-Key` header if any. + pub fn clear_key(mut self) -> Self { + self.headers.remove::(); + self.key_set = false; + self + } + + /// Set the version of the Websocket connection. + /// Currently this library only supports version 13 (from RFC6455), + /// but one could use this library to create the handshake then use an + /// implementation of another websocket version. + pub fn version(mut self, version: WebSocketVersion) -> Self { + self.headers.set(version); + self.version_set = true; + self + } + + /// Unset the websocket version to be the default (WebSocket 13). + pub fn clear_version(mut self) -> Self { + self.headers.remove::(); + self.version_set = false; + self + } + + /// Sets the Origin header of the handshake. + /// Normally in browsers this is used to protect against + /// unauthorized cross-origin use of a WebSocket server, but it is rarely + /// send by non-browser clients. Still, it can be useful. + pub fn origin(mut self, origin: String) -> Self { + self.headers.set(Origin(origin)); + self + } + + /// Remove the Origin header from the handshake. + pub fn clear_origin(mut self) -> Self { + self.headers.remove::(); + self + } + + /// This is a catch all to add random headers to your handshake, + /// the process here is more manual. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::{Headers, Authorization}; + /// let mut headers = Headers::new(); + /// headers.set(Authorization("let me in".to_owned())); + /// + /// let builder = ClientBuilder::new("ws://moz.illest").unwrap() + /// .custom_headers(&headers); + /// + /// # let hds = &builder.get_header::>().unwrap().0; + /// # assert!(hds == &"let me in".to_string()); + /// ``` + pub fn custom_headers(mut self, custom_headers: &Headers) -> Self { + self.headers.extend(custom_headers.iter()); + self + } + + /// Remove a type of header from the handshake, this is to be used + /// with the catch all `custom_headers`. + pub fn clear_header(mut self) -> Self + where H: Header + HeaderFormat + { + self.headers.remove::(); + self + } + + /// Get a header to inspect it. + pub fn get_header(&self) -> Option<&H> + where H: Header + HeaderFormat + { + self.headers.get::() + } + + fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { + let port = match (self.url.port(), secure) { + (Some(port), _) => port, + (None, None) if self.url.scheme() == "wss" => 443, + (None, None) => 80, + (None, Some(true)) => 443, + (None, Some(false)) => 80, + }; + let host = match self.url.host_str() { + Some(h) => h, + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + + let tcp_stream = try!(TcpStream::connect((host, port))); + Ok(tcp_stream) + } + + #[cfg(feature="ssl")] + fn wrap_ssl( + &self, + tcp_stream: TcpStream, + connector: Option, + ) -> WebSocketResult> { + let host = match self.url.host_str() { + Some(h) => h, + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + let connector = match connector { + Some(c) => c, + None => try!(SslConnectorBuilder::new(SslMethod::tls())).build(), + }; + + let ssl_stream = try!(connector.connect(host, tcp_stream)); + Ok(ssl_stream) + } + + /// Connect to a server (finally)! + /// This will use a `Box` to represent either an SSL + /// connection or a normal TCP connection, what to use will be decided + /// using the protocol of the URL passed in (e.g. `ws://` or `wss://`) + /// + /// If you have non-default SSL circumstances, you can use the `ssl_config` + /// parameter to configure those. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// # use websocket::Message; + /// let mut client = ClientBuilder::new("wss://supersecret.l33t").unwrap() + /// .connect(None) + /// .unwrap(); + /// + /// // send messages! + /// let message = Message::text("m337 47 7pm"); + /// client.send_message(&message).unwrap(); + /// ``` + #[cfg(feature="ssl")] + pub fn connect( + &mut self, + ssl_config: Option, + ) -> WebSocketResult>> { + let tcp_stream = try!(self.establish_tcp(None)); + + let boxed_stream: Box = if + self.url.scheme() == "wss" { + Box::new(try!(self.wrap_ssl(tcp_stream, ssl_config))) + } else { + Box::new(tcp_stream) + }; + + self.connect_on(boxed_stream) + } + + /// Create an insecure (plain TCP) connection to the client. + /// In this case no `Box` will be used you will just get a TcpStream, + /// giving you the ability to split the stream into a reader and writer + /// (since SSL streams cannot be cloned). + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// let mut client = ClientBuilder::new("wss://supersecret.l33t").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// // split into two (for some reason)! + /// let (receiver, sender) = client.split().unwrap(); + /// ``` + pub fn connect_insecure(&mut self) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(Some(false))); + + self.connect_on(tcp_stream) + } + + /// Create an SSL connection to the sever. + /// This will only use an `SslStream`, this is useful + /// when you want to be sure to connect over SSL or when you want access + /// to the `SslStream` functions (without having to go through a `Box`). + #[cfg(feature="ssl")] + pub fn connect_secure( + &mut self, + ssl_config: Option, + ) -> WebSocketResult>> { + let tcp_stream = try!(self.establish_tcp(Some(true))); + + let ssl_stream = try!(self.wrap_ssl(tcp_stream, ssl_config)); + + self.connect_on(ssl_stream) + } + + // TODO: similar ability for server? + /// Connects to a websocket server on any stream you would like. + /// Possible streams: + /// - Unix Sockets + /// - Logging Middle-ware + /// - SSH + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// use websocket::stream::ReadWritePair; + /// use std::io::Cursor; + /// + /// let accept = b"HTTP/1.1 101 Switching Protocols\r + /// Upgrade: websocket\r + /// Connection: Upgrade\r + /// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r + /// \r\n"; + /// + /// let input = Cursor::new(&accept[..]); + /// let output = Cursor::new(Vec::new()); + /// + /// let client = ClientBuilder::new("wss://test.ws").unwrap() + /// .key(b"the sample nonce".clone()) + /// .connect_on(ReadWritePair(input, output)) + /// .unwrap(); + /// + /// let text = (client.into_stream().0).1.into_inner(); + /// let text = String::from_utf8(text).unwrap(); + /// assert!(text.contains("dGhlIHNhbXBsZSBub25jZQ=="), "{}", text); + /// ``` + pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> + where S: Stream + { + let resource = self.url[Position::BeforePath..Position::AfterQuery].to_owned(); + + // enter host if available (unix sockets don't have hosts) + if let Some(host) = self.url.host_str() { + self.headers + .set(Host { + hostname: host.to_string(), + port: self.url.port(), + }); + } + + self.headers + .set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); + + self.headers + .set(Upgrade(vec![ + Protocol { + name: ProtocolName::WebSocket, + version: None, + }, + ])); + + if !self.version_set { + self.headers.set(WebSocketVersion::WebSocket13); + } + + if !self.key_set { + self.headers.set(WebSocketKey::new()); + } + + // send request + try!(write!(stream, "GET {} {}\r\n", resource, self.version)); + try!(write!(stream, "{}\r\n", self.headers)); + + // wait for a response + let mut reader = BufReader::new(stream); + let response = try!(parse_response(&mut reader)); + let status = StatusCode::from_u16(response.subject.0); + + // validate + if status != StatusCode::SwitchingProtocols { + return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); + } + + let key = try!(self.headers + .get::() + .ok_or(WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid"))); + + if response.headers.get() != Some(&(WebSocketAccept::new(key))) { + return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); + } + + if response.headers.get() != + Some(&(Upgrade(vec![ + Protocol { + name: ProtocolName::WebSocket, + version: None, + }, + ]))) { + return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); + } + + if self.headers.get() != + Some(&(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())), + ]))) { + return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); + } + + Ok(Client::unchecked(reader, response.headers)) + } +} + +mod tests { + #[test] + fn build_client_with_protocols() { + use super::*; + let builder = ClientBuilder::new("ws://127.0.0.1:8080/hello/world") + .unwrap() + .add_protocol("protobeard"); + + let protos = &builder.headers.get::().unwrap().0; + assert!(protos.contains(&"protobeard".to_string())); + assert!(protos.len() == 1); + + let builder = ClientBuilder::new("ws://example.org/hello") + .unwrap() + .add_protocol("rust-websocket") + .clear_protocols() + .add_protocols(vec!["electric", "boogaloo"]); + + let protos = &builder.headers.get::().unwrap().0; + + assert!(protos.contains(&"boogaloo".to_string())); + assert!(protos.contains(&"electric".to_string())); + assert!(!protos.contains(&"rust-websocket".to_string())); + } + + // TODO: a few more +} diff --git a/src/client/mod.rs b/src/client/mod.rs index b62302d0bf..2e87651c11 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,170 +1,302 @@ //! Contains the WebSocket client. +extern crate url; use std::net::TcpStream; -use std::marker::PhantomData; +use std::net::SocketAddr; use std::io::Result as IoResult; +use std::io::{Read, Write}; +use hyper::header::Headers; +use hyper::buffer::BufReader; use ws; -use ws::util::url::ToWebSocketUrlComponents; +use ws::sender::Sender as SenderTrait; use ws::receiver::{DataFrameIterator, MessageIterator}; +use ws::receiver::Receiver as ReceiverTrait; use result::WebSocketResult; -use stream::WebSocketStream; +use stream::{AsTcpStream, Stream, Splittable, Shutdown}; use dataframe::DataFrame; -use ws::dataframe::DataFrame as DataFrameable; - -use openssl::ssl::{SslContext, SslMethod, SslStream}; - -pub use self::request::Request; -pub use self::response::Response; +use header::{WebSocketProtocol, WebSocketExtensions}; +use header::extensions::Extension; -pub use sender::Sender; -pub use receiver::Receiver; +use ws::dataframe::DataFrame as DataFrameable; +use sender::Sender; +use receiver::Receiver; +pub use sender::Writer; +pub use receiver::Reader; -pub mod request; -pub mod response; +pub mod builder; +pub use self::builder::{ClientBuilder, Url, ParseError}; /// Represents a WebSocket client, which can send and receive messages/data frames. /// -/// `D` is the data frame type, `S` is the type implementing `Sender` and `R` -/// is the type implementing `Receiver`. -/// -/// For most cases, the data frame type will be `dataframe::DataFrame`, the Sender -/// type will be `client::Sender` and the receiver type -/// will be `client::Receiver`. +/// The client just wraps around a `Stream` (which is something that can be read from +/// and written to) and handles the websocket protocol. TCP or SSL over TCP is common, +/// but any stream can be used. /// -/// A `Client` can be split into a `Sender` and a `Receiver` which can then be moved +/// A `Client` can also be split into a `Reader` and a `Writer` which can then be moved /// to different threads, often using a send loop and receiver loop concurrently, /// as shown in the client example in `examples/client.rs`. +/// This is only possible for streams that implement the `Splittable` trait, which +/// currently is only TCP streams. (it is unsafe to duplicate an SSL stream) /// -///#Connecting to a Server +///# Connecting to a Server /// ///```no_run ///extern crate websocket; ///# fn main() { /// -///use websocket::{Client, Message}; -///use websocket::client::request::Url; +///use websocket::{ClientBuilder, Message}; /// -///let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL -///let request = Client::connect(url).unwrap(); // Connect to the server -///let response = request.send().unwrap(); // Send the request -///response.validate().unwrap(); // Ensure the response is valid -/// -///let mut client = response.begin(); // Get a Client +///let mut client = ClientBuilder::new("ws://127.0.0.1:1234") +/// .unwrap() +/// .connect_insecure() +/// .unwrap(); /// ///let message = Message::text("Hello, World!"); ///client.send_message(&message).unwrap(); // Send message ///# } ///``` -pub struct Client { - sender: S, - receiver: R, - _dataframe: PhantomData +pub struct Client + where S: Stream +{ + stream: BufReader, + headers: Headers, + sender: Sender, + receiver: Receiver, } -impl Client, Receiver> { - /// Connects to the given ws:// or wss:// URL and return a Request to be sent. - /// - /// A connection is established, however the request is not sent to - /// the server until a call to ```send()```. - pub fn connect(components: T) -> WebSocketResult> { - let context = try!(SslContext::new(SslMethod::Tlsv1)); - Client::connect_ssl_context(components, &context) +impl Client { + /// Shuts down the sending half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_sender(&self) -> IoResult<()> { + self.stream.get_ref().as_tcp().shutdown(Shutdown::Write) } - /// Connects to the specified wss:// URL using the given SSL context. - /// - /// If a ws:// URL is supplied, a normal, non-secure connection is established - /// and the context parameter is ignored. - /// - /// A connection is established, however the request is not sent to - /// the server until a call to ```send()```. - pub fn connect_ssl_context(components: T, context: &SslContext) -> WebSocketResult> { - let (host, resource_name, secure) = try!(components.to_components()); - let connection = try!(TcpStream::connect( - (&host.hostname[..], host.port.unwrap_or(if secure { 443 } else { 80 })) - )); + /// Shuts down the receiving half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_receiver(&self) -> IoResult<()> { + self.stream.get_ref().as_tcp().shutdown(Shutdown::Read) + } +} - let stream = if secure { - let sslstream = try!(SslStream::connect(context, connection)); - WebSocketStream::Ssl(sslstream) - } - else { - WebSocketStream::Tcp(connection) - }; - - Request::new((host, resource_name, secure), try!(stream.try_clone()), stream) - } - - /// Shuts down the sending half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_sender(&mut self) -> IoResult<()> { - self.sender.shutdown() - } - - /// Shuts down the receiving half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_receiver(&mut self) -> IoResult<()> { - self.receiver.shutdown() - } - - /// Shuts down the client connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&mut self) -> IoResult<()> { - self.receiver.shutdown_all() - } +impl Client + where S: AsTcpStream + Stream +{ + /// Shuts down the client connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) + } + + /// See [`TcpStream::peer_addr`] + /// (https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.peer_addr). + pub fn peer_addr(&self) -> IoResult { + self.stream.get_ref().as_tcp().peer_addr() + } + + /// See [`TcpStream::local_addr`] + /// (https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.local_addr). + pub fn local_addr(&self) -> IoResult { + self.stream.get_ref().as_tcp().local_addr() + } + + /// See [`TcpStream::set_nodelay`] + /// (https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.set_nodelay). + pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { + self.stream.get_ref().as_tcp().set_nodelay(nodelay) + } + + /// Changes whether the stream is in nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { + self.stream.get_ref().as_tcp().set_nonblocking(nonblocking) + } } -impl> Client { - /// Creates a Client from the given Sender and Receiver. - /// - /// Essentially the opposite of `Client.split()`. - pub fn new(sender: S, receiver: R) -> Client { +impl Client + where S: Stream +{ + /// Creates a Client from a given stream + /// **without sending any handshake** this is meant to only be used with + /// a stream that has a websocket connection already set up. + /// If in doubt, don't use this! + #[doc(hidden)] + pub fn unchecked(stream: BufReader, headers: Headers) -> Self { Client { - sender: sender, - receiver: receiver, - _dataframe: PhantomData + headers: headers, + stream: stream, + // NOTE: these are always true & false, see + // https://tools.ietf.org/html/rfc6455#section-5 + sender: Sender::new(true), + receiver: Receiver::new(false), } } + /// Sends a single data frame to the remote endpoint. pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrameable { - self.sender.send_dataframe(dataframe) + where D: DataFrameable + { + self.sender.send_dataframe(self.stream.get_mut(), dataframe) } + /// Sends a single message to the remote endpoint. pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: ws::Message<'m, D>, D: DataFrameable { - self.sender.send_message(message) + where M: ws::Message<'m, D>, + D: DataFrameable + { + self.sender.send_message(self.stream.get_mut(), message) } + /// Reads a single data frame from the remote endpoint. - pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe() + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(&mut self.stream) } + /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, R, F> { - self.receiver.incoming_dataframes() + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader> { + self.receiver.incoming_dataframes(&mut self.stream) } + /// Reads a single message from this receiver. pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult - where M: ws::Message<'m, F, DataFrameIterator = I>, I: Iterator { - self.receiver.recv_message() + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, + I: Iterator + { + self.receiver.recv_message(&mut self.stream) + } + + /// Access the headers that were sent in the server's handshake response. + /// This is a catch all for headers other than protocols and extensions. + pub fn headers(&self) -> &Headers { + &self.headers + } + + /// **If you supplied a protocol, you must check that it was accepted by + /// the server** using this function. + /// This is not done automatically because the terms of accepting a protocol + /// can get complicated, especially if some protocols depend on others, etc. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// let mut client = ClientBuilder::new("wss://test.fysh.in").unwrap() + /// .add_protocol("xmpp") + /// .connect_insecure() + /// .unwrap(); + /// + /// // be sure to check the protocol is there! + /// assert!(client.protocols().iter().any(|p| p as &str == "xmpp")); + /// ``` + pub fn protocols(&self) -> &[String] { + self.headers + .get::() + .map(|p| p.0.as_slice()) + .unwrap_or(&[]) + } + + /// If you supplied a protocol, be sure to check if it was accepted by the + /// server here. Since no extensions are implemented out of the box yet, using + /// one will require its own implementation. + pub fn extensions(&self) -> &[Extension] { + self.headers + .get::() + .map(|e| e.0.as_slice()) + .unwrap_or(&[]) + } + + /// Get a reference to the stream. + /// Useful to be able to set options on the stream. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// let mut client = ClientBuilder::new("ws://double.down").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// client.stream_ref().set_ttl(60).unwrap(); + /// ``` + pub fn stream_ref(&self) -> &S { + self.stream.get_ref() + } + + /// Get a handle to the writable portion of this stream. + /// This can be used to write custom extensions. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// use websocket::Message; + /// use websocket::ws::sender::Sender as SenderTrait; + /// use websocket::sender::Sender; + /// + /// let mut client = ClientBuilder::new("ws://the.room").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// let message = Message::text("Oh hi, Mark."); + /// let mut sender = Sender::new(true); + /// let mut buf = Vec::new(); + /// + /// sender.send_message(&mut buf, &message); + /// + /// /* transform buf somehow */ + /// + /// client.writer_mut().write_all(&buf); + /// ``` + pub fn writer_mut(&mut self) -> &mut Write { + self.stream.get_mut() + } + + /// Get a handle to the readable portion of this stream. + /// This can be used to transform raw bytes before they + /// are read in. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// use std::io::Cursor; + /// use websocket::Message; + /// use websocket::ws::receiver::Receiver as ReceiverTrait; + /// use websocket::receiver::Receiver; + /// + /// let mut client = ClientBuilder::new("ws://the.room").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// let mut receiver = Receiver::new(false); + /// let mut buf = Vec::new(); + /// + /// client.reader_mut().read_to_end(&mut buf); + /// + /// /* transform buf somehow */ + /// + /// let mut buf_reader = Cursor::new(&mut buf); + /// let message: Message = receiver.recv_message(&mut buf_reader).unwrap(); + /// ``` + pub fn reader_mut(&mut self) -> &mut Read { + &mut self.stream } + + /// Deconstruct the client into its underlying stream and + /// maybe some of the buffer that was already read from the stream. + /// The client uses a buffered reader to read in messages, so some + /// bytes might already be read from the stream when this is called, + /// these buffered bytes are returned in the form + /// + /// `(byte_buffer: Vec, buffer_capacity: usize, buffer_position: usize)` + pub fn into_stream(self) -> (S, Option<(Vec, usize, usize)>) { + let (stream, buf, pos, cap) = self.stream.into_parts(); + (stream, Some((buf, pos, cap))) + } + /// Returns an iterator over incoming messages. /// ///```no_run ///# extern crate websocket; ///# fn main() { - ///use websocket::{Client, Message}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid + ///use websocket::{ClientBuilder, Message}; /// - ///let mut client = response.begin(); // Get a Client + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect(None).unwrap(); /// ///for message in client.incoming_messages() { - /// let message: Message = message.unwrap(); + /// let message: Message = message.unwrap(); /// println!("Recv: {:?}", message); ///} ///# } @@ -177,44 +309,32 @@ impl> Client { ///```no_run ///# extern crate websocket; ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let client = response.begin(); // Get a Client - ///let (mut sender, mut receiver) = client.split(); // Split the Client + ///use websocket::{ClientBuilder, Message}; + /// + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect_insecure().unwrap(); + /// + ///let (mut receiver, mut sender) = client.split().unwrap(); + /// ///for message in receiver.incoming_messages() { - /// let message: Message = message.unwrap(); + /// let message: Message = message.unwrap(); /// // Echo the message back /// sender.send_message(&message).unwrap(); ///} ///# } ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, R, D, F, M> - where M: ws::Message<'a, D>, - D: DataFrameable - { - self.receiver.incoming_messages() - } - /// Returns a reference to the underlying Sender. - pub fn get_sender(&self) -> &S { - &self.sender - } - /// Returns a reference to the underlying Receiver. - pub fn get_receiver(&self) -> &R { - &self.receiver - } - /// Returns a mutable reference to the underlying Sender. - pub fn get_mut_sender(&mut self) -> &mut S { - &mut self.sender - } - /// Returns a mutable reference to the underlying Receiver. - pub fn get_mut_receiver(&mut self) -> &mut R { - &mut self.receiver + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, BufReader> + where M: ws::Message<'a, D>, + D: DataFrameable + { + self.receiver.incoming_messages(&mut self.stream) } +} + +impl Client + where S: Splittable + Stream +{ /// Split this client into its constituent Sender and Receiver pair. /// /// This allows the Sender and Receiver to be sent to different threads. @@ -222,21 +342,17 @@ impl> Client { ///```no_run ///# extern crate websocket; ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; ///use std::thread; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid + ///use websocket::{ClientBuilder, Message}; /// - ///let client = response.begin(); // Get a Client + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect_insecure().unwrap(); /// - ///let (mut sender, mut receiver) = client.split(); + ///let (mut receiver, mut sender) = client.split().unwrap(); /// ///thread::spawn(move || { /// for message in receiver.incoming_messages() { - /// let message: Message = message.unwrap(); + /// let message: Message = message.unwrap(); /// println!("Recv: {:?}", message); /// } ///}); @@ -245,7 +361,18 @@ impl> Client { ///sender.send_message(&message).unwrap(); ///# } ///``` - pub fn split(self) -> (S, R) { - (self.sender, self.receiver) + pub fn split + (self,) + -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { + let (stream, buf, pos, cap) = self.stream.into_parts(); + let (read, write) = try!(stream.split()); + Ok((Reader { + stream: BufReader::from_parts(read, buf, pos, cap), + receiver: self.receiver, + }, + Writer { + stream: write, + sender: self.sender, + })) } } diff --git a/src/client/request.rs b/src/client/request.rs deleted file mode 100644 index 0e470b761a..0000000000 --- a/src/client/request.rs +++ /dev/null @@ -1,144 +0,0 @@ -//! Structs for client-side (outbound) WebSocket requests -use std::io::{Read, Write}; - -pub use url::Url; - -use hyper::version::HttpVersion; -use hyper::buffer::BufReader; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, Protocol, ProtocolName}; - -use unicase::UniCase; - -use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; -use result::WebSocketResult; -use client::response::Response; -use ws::util::url::ToWebSocketUrlComponents; - -/// Represents a WebSocket request. -/// -/// Note that nothing is written to the internal Writer until the `send()` method is called. -pub struct Request { - /// The HTTP version of this request. - pub version: HttpVersion, - /// The headers of this request. - pub headers: Headers, - - resource_name: String, - reader: BufReader, - writer: W, -} - -unsafe impl Send for Request where R: Read + Send, W: Write + Send { } - -impl Request { - /// Creates a new client-side request. - /// - /// In general `Client::connect()` should be used for connecting to servers. - /// However, if the request is to be written to a different Writer, this function - /// may be used. - pub fn new(components: T, reader: R, writer: W) -> WebSocketResult> { - let mut headers = Headers::new(); - let (host, resource_name, _) = try!(components.to_components()); - headers.set(host); - headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); - headers.set(Upgrade(vec![Protocol{ - name: ProtocolName::WebSocket, - version: None - }])); - headers.set(WebSocketVersion::WebSocket13); - headers.set(WebSocketKey::new()); - - Ok(Request { - version: HttpVersion::Http11, - headers: headers, - resource_name: resource_name, - reader: BufReader::new(reader), - writer: writer - }) - } - /// Short-cut to obtain the WebSocketKey value. - pub fn key(&self) -> Option<&WebSocketKey> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketVersion value. - pub fn version(&self) -> Option<&WebSocketVersion> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Short-cut to obtain the Origin value. - pub fn origin(&self) -> Option<&Origin> { - self.headers.get() - } - /// Short-cut to obtain a mutable reference to the WebSocketKey value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn key_mut(&mut self) -> Option<&mut WebSocketKey> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketVersion value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn version_mut(&mut self) -> Option<&mut WebSocketVersion> { - self.headers.get_mut() - } - /// Short-cut to obtaina mutable reference to the WebSocketProtocol value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn protocol_mut(&mut self) -> Option<&mut WebSocketProtocol> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketExtensions value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn extensions_mut(&mut self) -> Option<&mut WebSocketExtensions> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the Origin value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn origin_mut(&mut self) -> Option<&mut Origin> { - self.headers.get_mut() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &BufReader { - &self.reader - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - &self.writer - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut BufReader { - &mut self.reader - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - &mut self.writer - } - /// Return the inner Reader and Writer. - pub fn into_inner(self) -> (BufReader, W) { - (self.reader, self.writer) - } - /// Sends the request to the server and returns a response. - pub fn send(mut self) -> WebSocketResult> { - try!(write!(&mut self.writer, "GET {} {}\r\n", self.resource_name, self.version)); - try!(write!(&mut self.writer, "{}\r\n", self.headers)); - Response::read(self) - } -} diff --git a/src/client/response.rs b/src/client/response.rs deleted file mode 100644 index e60b3fa31c..0000000000 --- a/src/client/response.rs +++ /dev/null @@ -1,136 +0,0 @@ -//! Structs for WebSocket responses -use std::option::Option; -use std::io::{Read, Write}; - -use hyper::status::StatusCode; -use hyper::buffer::BufReader; -use hyper::version::HttpVersion; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, Protocol, ProtocolName}; -use hyper::http::h1::parse_response; - -use unicase::UniCase; - -use header::{WebSocketAccept, WebSocketProtocol, WebSocketExtensions}; - -use client::{Client, Request, Sender, Receiver}; -use result::{WebSocketResult, WebSocketError}; -use dataframe::DataFrame; -use ws::dataframe::DataFrame as DataFrameable; -use ws; - -/// Represents a WebSocket response. -pub struct Response { - /// The status of the response - pub status: StatusCode, - /// The headers contained in this response - pub headers: Headers, - /// The HTTP version of this response - pub version: HttpVersion, - - request: Request -} - -unsafe impl Send for Response where R: Read + Send, W: Write + Send { } - -impl Response { - /// Reads a Response off the stream associated with a Request. - /// - /// This is called by Request.send(), and does not need to be called by the user. - pub fn read(mut request: Request) -> WebSocketResult> { - let (status, version, headers) = { - let reader = request.get_mut_reader(); - - let response = try!(parse_response(reader)); - - let status = StatusCode::from_u16(response.subject.0); - (status, response.version, response.headers) - }; - - Ok(Response { - status: status, - headers: headers, - version: version, - request: request - }) - } - - /// Short-cut to obtain the WebSocketAccept value. - pub fn accept(&self) -> Option<&WebSocketAccept> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &BufReader { - self.request.get_reader() - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - self.request.get_writer() - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut BufReader { - self.request.get_mut_reader() - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - self.request.get_mut_writer() - } - /// Returns a reference to the request associated with this response. - pub fn get_request(&self) -> &Request { - &self.request - } - /// Return the inner Reader and Writer. - pub fn into_inner(self) -> (BufReader, W) { - self.request.into_inner() - } - - /// Check if this response constitutes a successful handshake. - pub fn validate(&self) -> WebSocketResult<()> { - if self.status != StatusCode::SwitchingProtocols { - return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); - } - let key = try!(self.request.key().ok_or( - WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") - )); - if self.accept() != Some(&(WebSocketAccept::new(key))) { - return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); - } - if self.headers.get() != Some(&(Upgrade(vec![Protocol{ - name: ProtocolName::WebSocket, - version: None - }]))) { - return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); - } - if self.headers.get() != Some(&(Connection(vec![ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()))]))) { - return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); - } - Ok(()) - } - - /// Consume this response and return a Client ready to transmit/receive data frames - /// using the data frame type D, Sender B and Receiver C. - /// - /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - pub fn begin_with(self, sender: B, receiver: C) -> Client - where B: ws::Sender, C: ws::Receiver, D: DataFrameable { - Client::new(sender, receiver) - } - /// Consume this response and return a Client ready to transmit/receive data frames. - /// - /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - pub fn begin(self) -> Client, Receiver> { - let (reader, writer) = self.into_inner(); - let sender = Sender::new(writer, true); - let receiver = Receiver::new(reader, false); - Client::new(sender, receiver) - } -} diff --git a/src/dataframe.rs b/src/dataframe.rs index cb861692a9..9f08fe2747 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -37,58 +37,55 @@ impl DataFrame { } } - /// Reads a DataFrame from a Reader. - pub fn read_dataframe(reader: &mut R, should_be_masked: bool) -> WebSocketResult - where R: Read { - let header = try!(dfh::read_header(reader)); + /// Reads a DataFrame from a Reader. + pub fn read_dataframe(reader: &mut R, should_be_masked: bool) -> WebSocketResult + where R: Read + { + let header = try!(dfh::read_header(reader)); - Ok(DataFrame { - finished: header.flags.contains(dfh::FIN), - reserved: [ - header.flags.contains(dfh::RSV1), - header.flags.contains(dfh::RSV2), - header.flags.contains(dfh::RSV3) - ], - opcode: Opcode::new(header.opcode).expect("Invalid header opcode!"), - data: match header.mask { - Some(mask) => { - if !should_be_masked { - return Err(WebSocketError::DataFrameError( - "Expected unmasked data frame" - )); - } - let mut data: Vec = Vec::with_capacity(header.len as usize); - try!(reader.take(header.len).read_to_end(&mut data)); - mask::mask_data(mask, &data) - } - None => { - if should_be_masked { - return Err(WebSocketError::DataFrameError( - "Expected masked data frame" - )); - } - let mut data: Vec = Vec::with_capacity(header.len as usize); - try!(reader.take(header.len).read_to_end(&mut data)); - data - } - } - }) - } + Ok(DataFrame { + finished: header.flags.contains(dfh::FIN), + reserved: [ + header.flags.contains(dfh::RSV1), + header.flags.contains(dfh::RSV2), + header.flags.contains(dfh::RSV3), + ], + opcode: Opcode::new(header.opcode).expect("Invalid header opcode!"), + data: match header.mask { + Some(mask) => { + if !should_be_masked { + return Err(WebSocketError::DataFrameError("Expected unmasked data frame")); + } + let mut data: Vec = Vec::with_capacity(header.len as usize); + try!(reader.take(header.len).read_to_end(&mut data)); + mask::mask_data(mask, &data) + } + None => { + if should_be_masked { + return Err(WebSocketError::DataFrameError("Expected masked data frame")); + } + let mut data: Vec = Vec::with_capacity(header.len as usize); + try!(reader.take(header.len).read_to_end(&mut data)); + data + } + }, + }) + } } impl DataFrameable for DataFrame { #[inline(always)] - fn is_last(&self) -> bool { + fn is_last(&self) -> bool { self.finished } #[inline(always)] - fn opcode(&self) -> u8 { + fn opcode(&self) -> u8 { self.opcode as u8 } #[inline(always)] - fn reserved<'a>(&'a self) -> &'a [bool; 3] { + fn reserved<'a>(&'a self) -> &'a [bool; 3] { &self.reserved } @@ -141,62 +138,60 @@ impl Opcode { /// Returns the Opcode, or None if the opcode is out of range. pub fn new(op: u8) -> Option { Some(match op { - 0 => Opcode::Continuation, - 1 => Opcode::Text, - 2 => Opcode::Binary, - 3 => Opcode::NonControl1, - 4 => Opcode::NonControl2, - 5 => Opcode::NonControl3, - 6 => Opcode::NonControl4, - 7 => Opcode::NonControl5, - 8 => Opcode::Close, - 9 => Opcode::Ping, - 10 => Opcode::Pong, - 11 => Opcode::Control1, - 12 => Opcode::Control2, - 13 => Opcode::Control3, - 14 => Opcode::Control4, - 15 => Opcode::Control5, - _ => return None, - }) + 0 => Opcode::Continuation, + 1 => Opcode::Text, + 2 => Opcode::Binary, + 3 => Opcode::NonControl1, + 4 => Opcode::NonControl2, + 5 => Opcode::NonControl3, + 6 => Opcode::NonControl4, + 7 => Opcode::NonControl5, + 8 => Opcode::Close, + 9 => Opcode::Ping, + 10 => Opcode::Pong, + 11 => Opcode::Control1, + 12 => Opcode::Control2, + 13 => Opcode::Control3, + 14 => Opcode::Control4, + 15 => Opcode::Control5, + _ => return None, + }) } } #[cfg(all(feature = "nightly", test))] mod tests { - use super::*; + use super::*; use ws::dataframe::DataFrame as DataFrameable; - use test::Bencher; + use test::Bencher; - #[test] - fn test_read_dataframe() { - let data = b"The quick brown fox jumps over the lazy dog"; - let mut dataframe = vec![0x81, 0x2B]; - for i in data.iter() { - dataframe.push(*i); - } - let obtained = DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); - let expected = DataFrame { - finished: true, - reserved: [false; 3], - opcode: Opcode::Text, - data: data.to_vec() - }; - assert_eq!(obtained, expected); - } - #[bench] + #[test] + fn test_read_dataframe() { + let data = b"The quick brown fox jumps over the lazy dog"; + let mut dataframe = vec![0x81, 0x2B]; + for i in data.iter() { + dataframe.push(*i); + } + let obtained = DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); + let expected = DataFrame { + finished: true, + reserved: [false; 3], + opcode: Opcode::Text, + data: data.to_vec(), + }; + assert_eq!(obtained, expected); + } + #[bench] fn bench_read_dataframe(b: &mut Bencher) { let data = b"The quick brown fox jumps over the lazy dog"; let mut dataframe = vec![0x81, 0x2B]; for i in data.iter() { dataframe.push(*i); } - b.iter(|| { - DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); - }); + b.iter(|| { DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); }); } - #[test] + #[test] fn test_write_dataframe() { let data = b"The quick brown fox jumps over the lazy dog"; let mut expected = vec![0x81, 0x2B]; @@ -207,10 +202,10 @@ mod tests { finished: true, reserved: [false; 3], opcode: Opcode::Text, - data: data.to_vec() + data: data.to_vec(), }; let mut obtained = Vec::new(); - dataframe.write_to(&mut obtained, false).unwrap(); + dataframe.write_to(&mut obtained, false).unwrap(); assert_eq!(&obtained[..], &expected[..]); } @@ -222,11 +217,9 @@ mod tests { finished: true, reserved: [false; 3], opcode: Opcode::Text, - data: data.to_vec() + data: data.to_vec(), }; let mut writer = Vec::with_capacity(45); - b.iter(|| { - dataframe.write_to(&mut writer, false).unwrap(); - }); + b.iter(|| { dataframe.write_to(&mut writer, false).unwrap(); }); } } diff --git a/src/header/accept.rs b/src/header/accept.rs index 8bc0ddab15..aabb718ece 100644 --- a/src/header/accept.rs +++ b/src/header/accept.rs @@ -5,8 +5,8 @@ use std::fmt::{self, Debug}; use std::str::FromStr; use serialize::base64::{ToBase64, FromBase64, STANDARD}; use header::WebSocketKey; -use openssl::crypto::hash::{self, hash}; use result::{WebSocketResult, WebSocketError}; +use sha1::Sha1; static MAGIC_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -27,9 +27,7 @@ impl FromStr for WebSocketAccept { match accept.from_base64() { Ok(vec) => { if vec.len() != 20 { - return Err(WebSocketError::ProtocolError( - "Sec-WebSocket-Accept must be 20 bytes" - )); + return Err(WebSocketError::ProtocolError("Sec-WebSocket-Accept must be 20 bytes")); } let mut array = [0u8; 20]; let mut iter = vec.into_iter(); @@ -39,9 +37,7 @@ impl FromStr for WebSocketAccept { Ok(WebSocketAccept(array)) } Err(_) => { - return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Accept " - )); + return Err(WebSocketError::ProtocolError("Invalid Sec-WebSocket-Accept ")); } } } @@ -54,12 +50,9 @@ impl WebSocketAccept { let mut concat_key = String::with_capacity(serialized.len() + 36); concat_key.push_str(&serialized[..]); concat_key.push_str(MAGIC_GUID); - let output = hash(hash::Type::SHA1, concat_key.as_bytes()); - let mut iter = output.into_iter(); - let mut bytes = [0u8; 20]; - for i in bytes.iter_mut() { - *i = iter.next().unwrap(); - } + let mut sha1 = Sha1::new(); + sha1.update(concat_key.as_bytes()); + let bytes = sha1.digest().bytes(); WebSocketAccept(bytes) } /// Return the Base64 encoding of this WebSocketAccept @@ -91,39 +84,40 @@ mod tests { use test; use std::str::FromStr; use header::{Headers, WebSocketKey}; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; + #[test] fn test_header_accept() { let key = FromStr::from_str("dGhlIHNhbXBsZSBub25jZQ==").unwrap(); let accept = WebSocketAccept::new(&key); let mut headers = Headers::new(); headers.set(accept); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); } #[bench] fn bench_header_accept_new(b: &mut test::Bencher) { let key = WebSocketKey::new(); b.iter(|| { - let mut accept = WebSocketAccept::new(&key); - test::black_box(&mut accept); - }); + let mut accept = WebSocketAccept::new(&key); + test::black_box(&mut accept); + }); } #[bench] fn bench_header_accept_parse(b: &mut test::Bencher) { let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; b.iter(|| { - let mut accept: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut accept); - }); + let mut accept: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut accept); + }); } #[bench] fn bench_header_accept_format(b: &mut test::Bencher) { let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; let val: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); - }); + format!("{}", val.serialize()); + }); } } diff --git a/src/header/extensions.rs b/src/header/extensions.rs index 07854ad109..efbb7537f8 100644 --- a/src/header/extensions.rs +++ b/src/header/extensions.rs @@ -8,6 +8,10 @@ use std::str::FromStr; use std::ops::Deref; use result::{WebSocketResult, WebSocketError}; +const INVALID_EXTENSION: &'static str = "Invalid Sec-WebSocket-Extensions extension name"; + +// TODO: check if extension name is valid according to spec + /// Represents a Sec-WebSocket-Extensions header #[derive(PartialEq, Clone, Debug)] pub struct WebSocketExtensions(pub Vec); @@ -15,9 +19,9 @@ pub struct WebSocketExtensions(pub Vec); impl Deref for WebSocketExtensions { type Target = Vec; - fn deref<'a>(&'a self) -> &'a Vec { - &self.0 - } + fn deref<'a>(&'a self) -> &'a Vec { + &self.0 + } } #[derive(PartialEq, Clone, Debug)] @@ -26,7 +30,7 @@ pub struct Extension { /// The name of this extension pub name: String, /// The parameters for this extension - pub params: Vec + pub params: Vec, } impl Extension { @@ -34,43 +38,42 @@ impl Extension { pub fn new(name: String) -> Extension { Extension { name: name, - params: Vec::new() + params: Vec::new(), } } } impl FromStr for Extension { type Err = WebSocketError; - + fn from_str(s: &str) -> WebSocketResult { let mut ext = s.split(';').map(|x| x.trim()); Ok(Extension { - name: match ext.next() { - Some(x) => x.to_string(), - None => return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Extensions extension name" - )), - }, - params: ext.map(|x| { - let mut pair = x.splitn(1, '=').map(|x| x.trim().to_string()); - - Parameter { - name: pair.next().unwrap(), - value: pair.next() - } - }).collect() - }) + name: match ext.next() { + Some(x) => x.to_string(), + None => return Err(WebSocketError::ProtocolError(INVALID_EXTENSION)), + }, + params: ext.map(|x| { + let mut pair = x.splitn(1, '=').map(|x| x.trim().to_string()); + + Parameter { + name: pair.next().unwrap(), + value: pair.next(), + } + }) + .collect(), + }) } } impl fmt::Display for Extension { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(write!(f, "{}", self.name)); + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(write!(f, "{}", self.name)); for param in self.params.iter() { try!(write!(f, "; {}", param)); } Ok(()) - } + } } #[derive(PartialEq, Clone, Debug)] @@ -79,7 +82,7 @@ pub struct Parameter { /// The name of this parameter pub name: String, /// The value of this parameter, if any - pub value: Option + pub value: Option, } impl Parameter { @@ -87,20 +90,20 @@ impl Parameter { pub fn new(name: String, value: Option) -> Parameter { Parameter { name: name, - value: value + value: value, } } } impl fmt::Display for Parameter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(write!(f, "{}", self.name)); + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(write!(f, "{}", self.name)); match self.value { Some(ref x) => try!(write!(f, "={}", x)), None => (), } Ok(()) - } + } } impl Header for WebSocketExtensions { @@ -120,37 +123,44 @@ impl HeaderFormat for WebSocketExtensions { } } +impl fmt::Display for WebSocketExtensions { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_extensions() { use header::Headers; let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; let extensions: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); - + let mut headers = Headers::new(); headers.set(extensions); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Extensions: foo, bar; baz; qux=quux\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Extensions: foo, bar; baz; qux=quux\r\n"); } #[bench] fn bench_header_extensions_parse(b: &mut test::Bencher) { let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; b.iter(|| { - let mut extensions: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut extensions); - }); + let mut extensions: WebSocketExtensions = Header::parse_header(&value[..]) + .unwrap(); + test::black_box(&mut extensions); + }); } #[bench] fn bench_header_extensions_format(b: &mut test::Bencher) { let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; let val: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); - }); + format!("{}", val); + }); } -} \ No newline at end of file +} diff --git a/src/header/key.rs b/src/header/key.rs index 31706aabf1..3e3331b953 100644 --- a/src/header/key.rs +++ b/src/header/key.rs @@ -25,22 +25,18 @@ impl FromStr for WebSocketKey { match key.from_base64() { Ok(vec) => { if vec.len() != 16 { - return Err(WebSocketError::ProtocolError( - "Sec-WebSocket-Key must be 16 bytes" - )); + return Err(WebSocketError::ProtocolError("Sec-WebSocket-Key must be 16 bytes")); } let mut array = [0u8; 16]; let mut iter = vec.into_iter(); for i in array.iter_mut() { *i = iter.next().unwrap(); } - + Ok(WebSocketKey(array)) } Err(_) => { - return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Accept" - )); + return Err(WebSocketError::ProtocolError("Invalid Sec-WebSocket-Accept")); } } } @@ -51,9 +47,7 @@ impl WebSocketKey { pub fn new() -> WebSocketKey { let key: [u8; 16] = unsafe { // Much faster than calling random() several times - mem::transmute( - rand::random::<(u64, u64)>() - ) + mem::transmute(rand::random::<(u64, u64)>()) }; WebSocketKey(key) } @@ -83,40 +77,40 @@ impl HeaderFormat for WebSocketKey { #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_key() { use header::Headers; - + let extensions = WebSocketKey([65; 16]); let mut headers = Headers::new(); headers.set(extensions); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Key: QUFBQUFBQUFBQUFBQUFBQQ==\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Key: QUFBQUFBQUFBQUFBQUFBQQ==\r\n"); } #[bench] fn bench_header_key_new(b: &mut test::Bencher) { b.iter(|| { - let mut key = WebSocketKey::new(); - test::black_box(&mut key); - }); + let mut key = WebSocketKey::new(); + test::black_box(&mut key); + }); } #[bench] fn bench_header_key_parse(b: &mut test::Bencher) { let value = vec![b"QUFBQUFBQUFBQUFBQUFBQQ==".to_vec()]; b.iter(|| { - let mut key: WebSocketKey = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut key); - }); + let mut key: WebSocketKey = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut key); + }); } #[bench] fn bench_header_key_format(b: &mut test::Bencher) { let value = vec![b"QUFBQUFBQUFBQUFBQUFBQQ==".to_vec()]; let val: WebSocketKey = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); - }); + format!("{}", val.serialize()); + }); } -} \ No newline at end of file +} diff --git a/src/header/mod.rs b/src/header/mod.rs index 9fc0dda0ae..48a73bfbcc 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -9,11 +9,11 @@ pub use self::protocol::WebSocketProtocol; pub use self::version::WebSocketVersion; pub use self::extensions::WebSocketExtensions; pub use self::origin::Origin; -pub use hyper::header::Headers; +pub use hyper::header::*; mod accept; mod key; mod protocol; mod version; pub mod extensions; -mod origin; \ No newline at end of file +mod origin; diff --git a/src/header/origin.rs b/src/header/origin.rs index 8209fe1d29..3c3d807106 100644 --- a/src/header/origin.rs +++ b/src/header/origin.rs @@ -10,9 +10,9 @@ pub struct Origin(pub String); impl Deref for Origin { type Target = String; - fn deref<'a>(&'a self) -> &'a String { - &self.0 - } + fn deref<'a>(&'a self) -> &'a String { + &self.0 + } } impl Header for Origin { @@ -28,40 +28,45 @@ impl Header for Origin { impl HeaderFormat for Origin { fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { let Origin(ref value) = *self; - write!(fmt, "{}", value) + write!(fmt, "{}", value) + } +} + +impl fmt::Display for Origin { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) } } #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_origin() { use header::Headers; - + let origin = Origin("foo bar".to_string()); let mut headers = Headers::new(); headers.set(origin); - + assert_eq!(&headers.to_string()[..], "Origin: foo bar\r\n"); } #[bench] fn bench_header_origin_parse(b: &mut test::Bencher) { let value = vec![b"foobar".to_vec()]; b.iter(|| { - let mut origin: Origin = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut origin); - }); + let mut origin: Origin = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut origin); + }); } #[bench] fn bench_header_origin_format(b: &mut test::Bencher) { let value = vec![b"foobar".to_vec()]; let val: Origin = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); - }); + format!("{}", val); + }); } -} \ No newline at end of file +} diff --git a/src/header/protocol.rs b/src/header/protocol.rs index 899c970e48..9a6e8f1cf0 100644 --- a/src/header/protocol.rs +++ b/src/header/protocol.rs @@ -4,15 +4,17 @@ use hyper; use std::fmt; use std::ops::Deref; +// TODO: only allow valid protocol names to be added + /// Represents a Sec-WebSocket-Protocol header #[derive(PartialEq, Clone, Debug)] pub struct WebSocketProtocol(pub Vec); impl Deref for WebSocketProtocol { type Target = Vec; - fn deref<'a>(&'a self) -> &'a Vec { - &self.0 - } + fn deref<'a>(&'a self) -> &'a Vec { + &self.0 + } } impl Header for WebSocketProtocol { @@ -32,36 +34,42 @@ impl HeaderFormat for WebSocketProtocol { } } +impl fmt::Display for WebSocketProtocol { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_protocol() { use header::Headers; - + let protocol = WebSocketProtocol(vec!["foo".to_string(), "bar".to_string()]); let mut headers = Headers::new(); headers.set(protocol); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Protocol: foo, bar\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Protocol: foo, bar\r\n"); } #[bench] fn bench_header_protocol_parse(b: &mut test::Bencher) { let value = vec![b"foo, bar".to_vec()]; b.iter(|| { - let mut protocol: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut protocol); - }); + let mut protocol: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut protocol); + }); } #[bench] fn bench_header_protocol_format(b: &mut test::Bencher) { let value = vec![b"foo, bar".to_vec()]; let val: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); - }); + format!("{}", val); + }); } -} \ No newline at end of file +} diff --git a/src/header/version.rs b/src/header/version.rs index 664a4faeee..7fadfc4ff5 100644 --- a/src/header/version.rs +++ b/src/header/version.rs @@ -9,18 +9,14 @@ pub enum WebSocketVersion { /// The version of WebSocket defined in RFC6455 WebSocket13, /// An unknown version of WebSocket - Unknown(String) + Unknown(String), } impl fmt::Debug for WebSocketVersion { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - WebSocketVersion::WebSocket13 => { - write!(f, "13") - } - WebSocketVersion::Unknown(ref value) => { - write!(f, "{}", value) - } + WebSocketVersion::WebSocket13 => write!(f, "13"), + WebSocketVersion::Unknown(ref value) => write!(f, "{}", value), } } } @@ -31,12 +27,10 @@ impl Header for WebSocketVersion { } fn parse_header(raw: &[Vec]) -> hyper::Result { - from_one_raw_str(raw).map(|s : String| - match &s[..] { - "13" => { WebSocketVersion::WebSocket13 } - _ => { WebSocketVersion::Unknown(s) } - } - ) + from_one_raw_str(raw).map(|s: String| match &s[..] { + "13" => WebSocketVersion::WebSocket13, + _ => WebSocketVersion::Unknown(s), + }) } } @@ -46,36 +40,41 @@ impl HeaderFormat for WebSocketVersion { } } +impl fmt::Display for WebSocketVersion { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_websocket_version() { use header::Headers; - + let version = WebSocketVersion::WebSocket13; let mut headers = Headers::new(); headers.set(version); - + assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Version: 13\r\n"); } #[bench] fn bench_header_version_parse(b: &mut test::Bencher) { let value = vec![b"13".to_vec()]; b.iter(|| { - let mut version: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut version); - }); + let mut version: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut version); + }); } #[bench] fn bench_header_version_format(b: &mut test::Bencher) { let value = vec![b"13".to_vec()]; let val: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); - }); + format!("{}", val); + }); } -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index e90c72ee95..1aa3202534 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,11 +38,13 @@ //! level. Their usage is explained in the module documentation. extern crate hyper; extern crate unicase; -extern crate url; +pub extern crate url; extern crate rustc_serialize as serialize; -extern crate openssl; extern crate rand; extern crate byteorder; +extern crate sha1; +#[cfg(feature="ssl")] +extern crate openssl; #[macro_use] extern crate bitflags; @@ -50,14 +52,34 @@ extern crate bitflags; #[cfg(all(feature = "nightly", test))] extern crate test; -pub use self::client::Client; +pub use self::client::{Client, ClientBuilder}; pub use self::server::Server; pub use self::dataframe::DataFrame; pub use self::message::Message; -pub use self::stream::WebSocketStream; +pub use self::stream::Stream; pub use self::ws::Sender; pub use self::ws::Receiver; +macro_rules! upsert_header { + ($headers:expr; $header:ty; { + Some($pat:pat) => $some_match:expr, + None => $default:expr + }) => {{ + match $headers.has::<$header>() { + true => { + match $headers.get_mut::<$header>() { + Some($pat) => { $some_match; }, + None => (), + }; + } + false => { + $headers.set($default); + }, + }; + }} +} + + pub mod ws; pub mod client; pub mod server; diff --git a/src/message.rs b/src/message.rs index 89b64e9387..f2103152c3 100644 --- a/src/message.rs +++ b/src/message.rs @@ -13,15 +13,15 @@ const FALSE_RESERVED_BITS: &'static [bool; 3] = &[false; 3]; /// Valid types of messages (in the default implementation) #[derive(Debug, PartialEq, Clone, Copy)] pub enum Type { - /// Message with UTF8 test + /// Message with UTF8 test Text = 1, - /// Message containing binary data + /// Message containing binary data Binary = 2, - /// Ping message with data + /// Ping message with data Ping = 9, - /// Pong message with data + /// Pong message with data Pong = 10, - /// Close connection message with optional reason + /// Close connection message with optional reason Close = 8, } @@ -35,12 +35,12 @@ pub enum Type { /// because this message just gets sent as one single DataFrame. #[derive(PartialEq, Clone, Debug)] pub struct Message<'a> { - /// Type of WebSocket message + /// Type of WebSocket message pub opcode: Type, - /// Optional status code to send when closing a connection. - /// (only used if this message is of Type::Close) + /// Optional status code to send when closing a connection. + /// (only used if this message is of Type::Close) pub cd_status_code: Option, - /// Main payload + /// Main payload pub payload: Cow<'a, [u8]>, } @@ -53,79 +53,88 @@ impl<'a> Message<'a> { } } - /// Create a new WebSocket message with text data + /// Create a new WebSocket message with text data pub fn text(data: S) -> Self - where S: Into> { - Message::new(Type::Text, None, match data.into() { - Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), - Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), - }) + where S: Into> + { + Message::new(Type::Text, + None, + match data.into() { + Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), + Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), + }) } - /// Create a new WebSocket message with binary data + /// Create a new WebSocket message with binary data pub fn binary(data: B) -> Self - where B: IntoCowBytes<'a> { + where B: IntoCowBytes<'a> + { Message::new(Type::Binary, None, data.into()) } - /// Create a new WebSocket message that signals the end of a WebSocket - /// connection, although messages can still be sent after sending this + /// Create a new WebSocket message that signals the end of a WebSocket + /// connection, although messages can still be sent after sending this pub fn close() -> Self { Message::new(Type::Close, None, Cow::Borrowed(&[0 as u8; 0])) } - /// Create a new WebSocket message that signals the end of a WebSocket - /// connection and provide a text reason and a status code for why. - /// Messages can still be sent after sending this message. + /// Create a new WebSocket message that signals the end of a WebSocket + /// connection and provide a text reason and a status code for why. + /// Messages can still be sent after sending this message. pub fn close_because(code: u16, reason: S) -> Self - where S: Into> { - Message::new(Type::Close, Some(code), match reason.into() { - Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), - Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), - }) + where S: Into> + { + Message::new(Type::Close, + Some(code), + match reason.into() { + Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), + Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), + }) } - /// Create a ping WebSocket message, a pong is usually sent back - /// after sending this with the same data + /// Create a ping WebSocket message, a pong is usually sent back + /// after sending this with the same data pub fn ping

(data: P) -> Self - where P: IntoCowBytes<'a> { + where P: IntoCowBytes<'a> + { Message::new(Type::Ping, None, data.into()) } - /// Create a pong WebSocket message, usually a response to a - /// ping message + /// Create a pong WebSocket message, usually a response to a + /// ping message pub fn pong

(data: P) -> Self - where P: IntoCowBytes<'a> { + where P: IntoCowBytes<'a> + { Message::new(Type::Pong, None, data.into()) } - /// Convert a ping message to a pong, keeping the data. - /// This will fail if the original message is not a ping. - pub fn into_pong(&mut self) -> Result<(), ()> { - if self.opcode == Type::Ping { - self.opcode = Type::Pong; - Ok(()) - } else { - Err(()) - } - } + /// Convert a ping message to a pong, keeping the data. + /// This will fail if the original message is not a ping. + pub fn into_pong(&mut self) -> Result<(), ()> { + if self.opcode == Type::Ping { + self.opcode = Type::Pong; + Ok(()) + } else { + Err(()) + } + } } impl<'a> ws::dataframe::DataFrame for Message<'a> { #[inline(always)] - fn is_last(&self) -> bool { - true - } + fn is_last(&self) -> bool { + true + } #[inline(always)] - fn opcode(&self) -> u8 { - self.opcode as u8 - } + fn opcode(&self) -> u8 { + self.opcode as u8 + } #[inline(always)] - fn reserved<'b>(&'b self) -> &'b [bool; 3] { + fn reserved<'b>(&'b self) -> &'b [bool; 3] { FALSE_RESERVED_BITS - } + } fn payload<'b>(&'b self) -> Cow<'b, [u8]> { let mut buf = Vec::with_capacity(self.size()); @@ -134,79 +143,70 @@ impl<'a> ws::dataframe::DataFrame for Message<'a> { } fn size(&self) -> usize { - self.payload.len() + if self.cd_status_code.is_some() { - 2 - } else { - 0 - } + self.payload.len() + if self.cd_status_code.is_some() { 2 } else { 0 } } - fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> - where W: Write { + fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> + where W: Write + { if let Some(reason) = self.cd_status_code { try!(socket.write_u16::(reason)); } try!(socket.write_all(&*self.payload)); Ok(()) - } + } } impl<'a, 'b> ws::Message<'b, &'b Message<'a>> for Message<'a> { - type DataFrameIterator = Take>>; fn dataframes(&'b self) -> Self::DataFrameIterator { repeat(self).take(1) - } + } /// Attempt to form a message from a series of data frames fn from_dataframes(frames: Vec) -> WebSocketResult - where D: ws::dataframe::DataFrame { - let opcode = try!(frames.first().ok_or(WebSocketError::ProtocolError( - "No dataframes provided" - )).map(|d| d.opcode())); + where D: ws::dataframe::DataFrame + { + let opcode = try!(frames.first() + .ok_or(WebSocketError::ProtocolError("No dataframes provided")) + .map(|d| d.opcode())); let mut data = Vec::new(); for (i, dataframe) in frames.iter().enumerate() { if i > 0 && dataframe.opcode() != Opcode::Continuation as u8 { - return Err(WebSocketError::ProtocolError( - "Unexpected non-continuation data frame" - )); + return Err(WebSocketError::ProtocolError("Unexpected non-continuation data frame")); } if *dataframe.reserved() != [false; 3] { - return Err(WebSocketError::ProtocolError( - "Unsupported reserved bits received" - )); + return Err(WebSocketError::ProtocolError("Unsupported reserved bits received")); } data.extend(dataframe.payload().iter().cloned()); } Ok(match Opcode::new(opcode) { - Some(Opcode::Text) => Message::text(try!(bytes_to_string(&data[..]))), - Some(Opcode::Binary) => Message::binary(data), - Some(Opcode::Close) => { - if data.len() > 0 { - let status_code = try!((&data[..]).read_u16::()); - let reason = try!(bytes_to_string(&data[2..])); - Message::close_because(status_code, reason) - } else { - Message::close() - } - } - Some(Opcode::Ping) => Message::ping(data), - Some(Opcode::Pong) => Message::pong(data), - _ => return Err(WebSocketError::ProtocolError( - "Unsupported opcode received" - )), - }) + Some(Opcode::Text) => Message::text(try!(bytes_to_string(&data[..]))), + Some(Opcode::Binary) => Message::binary(data), + Some(Opcode::Close) => { + if data.len() > 0 { + let status_code = try!((&data[..]).read_u16::()); + let reason = try!(bytes_to_string(&data[2..])); + Message::close_because(status_code, reason) + } else { + Message::close() + } + } + Some(Opcode::Ping) => Message::ping(data), + Some(Opcode::Pong) => Message::pong(data), + _ => return Err(WebSocketError::ProtocolError("Unsupported opcode received")), + }) } } /// Trait representing the ability to convert /// self to a `Cow<'a, [u8]>` pub trait IntoCowBytes<'a> { - /// Consume `self` and produce a `Cow<'a, [u8]>` + /// Consume `self` and produce a `Cow<'a, [u8]>` fn into(self) -> Cow<'a, [u8]>; } diff --git a/src/receiver.rs b/src/receiver.rs index 03cfab6197..ea3688094f 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -2,90 +2,125 @@ use std::io::Read; use std::io::Result as IoResult; + use hyper::buffer::BufReader; use dataframe::{DataFrame, Opcode}; use result::{WebSocketResult, WebSocketError}; -use stream::WebSocketStream; -use stream::Shutdown; use ws; +use ws::dataframe::DataFrame as DataFrameable; +use ws::receiver::Receiver as ReceiverTrait; +use ws::receiver::{MessageIterator, DataFrameIterator}; +use stream::{AsTcpStream, Stream}; +pub use stream::Shutdown; + +/// This reader bundles an existing stream with a parsing algorithm. +/// It is used by the client in its `.split()` function as the reading component. +pub struct Reader + where R: Read +{ + /// the stream to be read from + pub stream: BufReader, + /// the parser to parse bytes into messages + pub receiver: Receiver, +} + +impl Reader + where R: Read +{ + /// Reads a single data frame from the remote endpoint. + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(&mut self.stream) + } + + /// Returns an iterator over incoming data frames. + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader> { + self.receiver.incoming_dataframes(&mut self.stream) + } + + /// Reads a single message from this receiver. + pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, + I: Iterator + { + self.receiver.recv_message(&mut self.stream) + } + + /// An iterator over incoming messsages. + /// This iterator will block until new messages arrive and will never halt. + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, BufReader> + where M: ws::Message<'a, D>, + D: DataFrameable + { + self.receiver.incoming_messages(&mut self.stream) + } +} + +impl Reader + where S: AsTcpStream + Stream + Read +{ + /// Closes the receiver side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.get_ref().as_tcp().shutdown(Shutdown::Read) + } + + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) + } +} /// A Receiver that wraps a Reader and provides a default implementation using /// DataFrames and Messages. -pub struct Receiver { - inner: BufReader, +pub struct Receiver { buffer: Vec, mask: bool, } -impl Receiver -where R: Read { +impl Receiver { /// Create a new Receiver using the specified Reader. - pub fn new(reader: BufReader, mask: bool) -> Receiver { + pub fn new(mask: bool) -> Receiver { Receiver { - inner: reader, buffer: Vec::new(), mask: mask, } } - /// Returns a reference to the underlying Reader. - pub fn get_ref(&self) -> &BufReader { - &self.inner - } - /// Returns a mutable reference to the underlying Reader. - pub fn get_mut(&mut self) -> &mut BufReader { - &mut self.inner - } } -impl Receiver { - /// Closes the receiver side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&mut self) -> IoResult<()> { - self.inner.get_mut().shutdown(Shutdown::Read) - } - - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&mut self) -> IoResult<()> { - self.inner.get_mut().shutdown(Shutdown::Both) - } - - /// Changes whether the receiver is in nonblocking mode. - /// - /// If it is in nonblocking mode and there is no incoming message, trying to receive a message - /// will return an error instead of blocking. - pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.inner.get_ref().set_nonblocking(nonblocking) - } -} -impl ws::Receiver for Receiver { +impl ws::Receiver for Receiver { + type F = DataFrame; + /// Reads a single data frame from the remote endpoint. - fn recv_dataframe(&mut self) -> WebSocketResult { - DataFrame::read_dataframe(&mut self.inner, self.mask) + fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult + where R: Read + { + DataFrame::read_dataframe(reader, self.mask) } + /// Returns the data frames that constitute one message. - fn recv_message_dataframes(&mut self) -> WebSocketResult> { + fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> + where R: Read + { let mut finished = if self.buffer.is_empty() { - let first = try!(self.recv_dataframe()); + let first = try!(self.recv_dataframe(reader)); if first.opcode == Opcode::Continuation { - return Err(WebSocketError::ProtocolError( - "Unexpected continuation data frame opcode" - )); + return Err(WebSocketError::ProtocolError("Unexpected continuation data frame opcode")); } let finished = first.finished; self.buffer.push(first); finished - } - else { + } else { false }; while !finished { - let next = try!(self.recv_dataframe()); + let next = try!(self.recv_dataframe(reader)); finished = next.finished; match next.opcode as u8 { @@ -96,9 +131,7 @@ impl ws::Receiver for Receiver { return Ok(vec![next]); } // Others - _ => return Err(WebSocketError::ProtocolError( - "Unexpected data frame opcode" - )), + _ => return Err(WebSocketError::ProtocolError("Unexpected data frame opcode")), } } diff --git a/src/result.rs b/src/result.rs index f6b9b25fa8..f87b2ce5cc 100644 --- a/src/result.rs +++ b/src/result.rs @@ -5,10 +5,14 @@ use std::str::Utf8Error; use std::error::Error; use std::convert::From; use std::fmt; -use openssl::ssl::error::SslError; use hyper::Error as HttpError; use url::ParseError; +#[cfg(feature="ssl")] +use openssl::error::ErrorStack as SslError; +#[cfg(feature="ssl")] +use openssl::ssl::HandshakeError as SslHandshakeError; + /// The type used for WebSocket results pub type WebSocketResult = Result; @@ -31,10 +35,17 @@ pub enum WebSocketError { HttpError(HttpError), /// A URL parsing error UrlError(ParseError), - /// A WebSocket URL error - WebSocketUrlError(WSUrlErrorKind), + /// A WebSocket URL error + WebSocketUrlError(WSUrlErrorKind), /// An SSL error + #[cfg(feature="ssl")] SslError(SslError), + /// an ssl handshake failure + #[cfg(feature="ssl")] + SslHandshakeFailure, + /// an ssl handshake interruption + #[cfg(feature="ssl")] + SslHandshakeInterruption, /// A UTF-8 error Utf8Error(Utf8Error), } @@ -50,7 +61,7 @@ impl fmt::Display for WebSocketError { impl Error for WebSocketError { fn description(&self) -> &str { match *self { - WebSocketError::ProtocolError(_) => "WebSocket protocol error", + WebSocketError::ProtocolError(_) => "WebSocket protocol error", WebSocketError::RequestError(_) => "WebSocket request error", WebSocketError::ResponseError(_) => "WebSocket response error", WebSocketError::DataFrameError(_) => "WebSocket data frame error", @@ -58,9 +69,14 @@ impl Error for WebSocketError { WebSocketError::IoError(_) => "I/O failure", WebSocketError::HttpError(_) => "HTTP failure", WebSocketError::UrlError(_) => "URL failure", - WebSocketError::SslError(_) => "SSL failure", + #[cfg(feature="ssl")] + WebSocketError::SslError(_) => "SSL failure", + #[cfg(feature="ssl")] + WebSocketError::SslHandshakeFailure => "SSL Handshake failure", + #[cfg(feature="ssl")] + WebSocketError::SslHandshakeInterruption => "SSL Handshake interrupted", WebSocketError::Utf8Error(_) => "UTF-8 failure", - WebSocketError::WebSocketUrlError(_) => "WebSocket URL failure", + WebSocketError::WebSocketUrlError(_) => "WebSocket URL failure", } } @@ -69,9 +85,10 @@ impl Error for WebSocketError { WebSocketError::IoError(ref error) => Some(error), WebSocketError::HttpError(ref error) => Some(error), WebSocketError::UrlError(ref error) => Some(error), - WebSocketError::SslError(ref error) => Some(error), + #[cfg(feature="ssl")] + WebSocketError::SslError(ref error) => Some(error), WebSocketError::Utf8Error(ref error) => Some(error), - WebSocketError::WebSocketUrlError(ref error) => Some(error), + WebSocketError::WebSocketUrlError(ref error) => Some(error), _ => None, } } @@ -98,12 +115,24 @@ impl From for WebSocketError { } } +#[cfg(feature="ssl")] impl From for WebSocketError { fn from(err: SslError) -> WebSocketError { WebSocketError::SslError(err) } } +#[cfg(feature="ssl")] +impl From> for WebSocketError { + fn from(err: SslHandshakeError) -> WebSocketError { + match err { + SslHandshakeError::SetupFailure(err) => WebSocketError::SslError(err), + SslHandshakeError::Failure(_) => WebSocketError::SslHandshakeFailure, + SslHandshakeError::Interrupted(_) => WebSocketError::SslHandshakeInterruption, + } + } +} + impl From for WebSocketError { fn from(err: Utf8Error) -> WebSocketError { WebSocketError::Utf8Error(err) @@ -111,33 +140,36 @@ impl From for WebSocketError { } impl From for WebSocketError { - fn from(err: WSUrlErrorKind) -> WebSocketError { - WebSocketError::WebSocketUrlError(err) - } + fn from(err: WSUrlErrorKind) -> WebSocketError { + WebSocketError::WebSocketUrlError(err) + } } /// Represents a WebSocket URL error #[derive(Debug)] pub enum WSUrlErrorKind { - /// Fragments are not valid in a WebSocket URL - CannotSetFragment, - /// The scheme provided is invalid for a WebSocket - InvalidScheme, + /// Fragments are not valid in a WebSocket URL + CannotSetFragment, + /// The scheme provided is invalid for a WebSocket + InvalidScheme, + /// There is no hostname or IP address to connect to + NoHostName, } impl fmt::Display for WSUrlErrorKind { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - try!(fmt.write_str("WebSocket Url Error: ")); - try!(fmt.write_str(self.description())); - Ok(()) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + try!(fmt.write_str("WebSocket Url Error: ")); + try!(fmt.write_str(self.description())); + Ok(()) + } } impl Error for WSUrlErrorKind { - fn description(&self) -> &str { - match *self { - WSUrlErrorKind::CannotSetFragment => "WebSocket URL cannot set fragment", - WSUrlErrorKind::InvalidScheme => "WebSocket URL invalid scheme" - } - } + fn description(&self) -> &str { + match *self { + WSUrlErrorKind::CannotSetFragment => "WebSocket URL cannot set fragment", + WSUrlErrorKind::InvalidScheme => "WebSocket URL invalid scheme", + WSUrlErrorKind::NoHostName => "WebSocket URL no host name provided", + } + } } diff --git a/src/sender.rs b/src/sender.rs index a4390f4741..398520f18d 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -4,58 +4,77 @@ use std::io::Write; use std::io::Result as IoResult; use result::WebSocketResult; use ws::dataframe::DataFrame; -use stream::WebSocketStream; -use stream::Shutdown; +use stream::AsTcpStream; use ws; +use ws::sender::Sender as SenderTrait; +pub use stream::Shutdown; -/// A Sender that wraps a Writer and provides a default implementation using -/// DataFrames and Messages. -pub struct Sender { - inner: W, - mask: bool, +/// A writer that bundles a stream with a serializer to send the messages. +/// This is used in the client's `.split()` function as the writing component. +/// +/// It can also be useful to use a websocket connection without a handshake. +pub struct Writer { + /// The stream that websocket messages will be written to + pub stream: W, + /// The serializer that will be used to serialize the messages + pub sender: Sender, } -impl Sender { - /// Create a new WebSocketSender using the specified Writer. - pub fn new(writer: W, mask: bool) -> Sender { - Sender { - inner: writer, - mask: mask, - } +impl Writer + where W: Write +{ + /// Sends a single data frame to the remote endpoint. + pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write + { + self.sender.send_dataframe(&mut self.stream, dataframe) } - /// Returns a reference to the underlying Writer. - pub fn get_ref(&self) -> &W { - &self.inner + + /// Sends a single message to the remote endpoint. + pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> + where M: ws::Message<'m, D>, + D: DataFrame + { + self.sender.send_message(&mut self.stream, message) } - /// Returns a mutable reference to the underlying Writer. - pub fn get_mut(&mut self) -> &mut W { - &mut self.inner +} + +impl Writer + where S: AsTcpStream + Write +{ + /// Closes the sender side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Write) } + + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Both) + } +} + +/// A Sender that wraps a Writer and provides a default implementation using +/// DataFrames and Messages. +pub struct Sender { + mask: bool, } -impl Sender { - /// Closes the sender side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&mut self) -> IoResult<()> { - self.inner.shutdown(Shutdown::Write) - } - - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&mut self) -> IoResult<()> { - self.inner.shutdown(Shutdown::Both) - } - - /// Changes whether the sender is in nonblocking mode. - pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.inner.set_nonblocking(nonblocking) - } +impl Sender { + /// Create a new WebSocketSender using the specified Writer. + pub fn new(mask: bool) -> Sender { + Sender { mask: mask } + } } -impl ws::Sender for Sender { +impl ws::Sender for Sender { /// Sends a single data frame to the remote endpoint. - fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame { - dataframe.write_to(&mut self.inner, self.mask) + fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write + { + dataframe.write_to(writer, self.mask) } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 8236b1c089..44a5497101 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,24 +1,61 @@ //! Provides an implementation of a WebSocket server -use std::net::{SocketAddr, ToSocketAddrs, TcpListener}; -use std::net::Shutdown; -use std::io::{Read, Write}; +use std::net::{SocketAddr, ToSocketAddrs, TcpListener, TcpStream}; use std::io; -pub use self::request::Request; -pub use self::response::Response; +use std::convert::Into; +#[cfg(feature="ssl")] +use openssl::ssl::{SslStream, SslAcceptor}; +use stream::Stream; +use self::upgrade::{WsUpgrade, IntoWs, Buffer}; +pub use self::upgrade::{Request, HyperIntoWsError}; -use stream::WebSocketStream; +pub mod upgrade; -use openssl::ssl::SslContext; -use openssl::ssl::SslStream; +/// When a sever tries to accept a connection many things can go wrong. +/// +/// This struct is all the information that is recovered from a failed +/// websocket handshake, in case one wants to use the connection for something +/// else (such as HTTP). +pub struct InvalidConnection + where S: Stream +{ + /// if the stream was successfully setup it will be included here + /// on a failed connection. + pub stream: Option, + /// the parsed request. **This is a normal HTTP request** meaning you can + /// simply run this server and handle both HTTP and Websocket connections. + /// If you already have a server you want to use, checkout the + /// `server::upgrade` module to integrate this crate with your server. + pub parsed: Option, + /// the buffered data that was already taken from the stream + pub buffer: Option, + /// the cause of the failed websocket connection setup + pub error: HyperIntoWsError, +} + +/// Either the stream was established and it sent a websocket handshake +/// which represents the `Ok` variant, or there was an error (this is the +/// `Err` variant). +pub type AcceptResult = Result, InvalidConnection>; -pub mod request; -pub mod response; +/// Marker struct for a struct not being secure +#[derive(Clone)] +pub struct NoSslAcceptor; +/// Trait that is implemented over NoSslAcceptor and SslAcceptor that +/// serves as a generic bound to make a struct with. +/// Used in the Server to specify impls based on wether the server +/// is running over SSL or not. +pub trait OptionalSslAcceptor: Clone {} +impl OptionalSslAcceptor for NoSslAcceptor {} +#[cfg(feature="ssl")] +impl OptionalSslAcceptor for SslAcceptor {} -/// Represents a WebSocket server which can work with either normal (non-secure) connections, or secure WebSocket connections. +/// Represents a WebSocket server which can work with either normal +/// (non-secure) connections, or secure WebSocket connections. /// -/// This is a convenient way to implement WebSocket servers, however it is possible to use any sendable Reader and Writer to obtain +/// This is a convenient way to implement WebSocket servers, however +/// it is possible to use any sendable Reader and Writer to obtain /// a WebSocketClient, so if needed, an alternative server implementation can be used. -///#Non-secure Servers +///# Non-secure Servers /// /// ```no_run ///extern crate websocket; @@ -31,9 +68,7 @@ pub mod response; ///for connection in server { /// // Spawn a new thread for each connection. /// thread::spawn(move || { -/// let request = connection.unwrap().read_request().unwrap(); // Get the request -/// let response = request.accept(); // Form a response -/// let mut client = response.send().unwrap(); // Send the response +/// let mut client = connection.accept().unwrap(); /// /// let message = Message::text("Hello, client!"); /// let _ = client.send_message(&message); @@ -44,28 +79,40 @@ pub mod response; /// # } /// ``` /// -///#Secure Servers +///# Secure Servers /// ```no_run ///extern crate websocket; ///extern crate openssl; ///# fn main() { ///use std::thread; -///use std::path::Path; +///use std::io::Read; +///use std::fs::File; ///use websocket::{Server, Message}; -///use openssl::ssl::{SslContext, SslMethod}; -///use openssl::x509::X509FileType; +///use openssl::pkcs12::Pkcs12; +///use openssl::ssl::{SslMethod, SslAcceptorBuilder, SslStream}; +/// +///// In this example we retrieve our keypair and certificate chain from a PKCS #12 archive, +///// but but they can also be retrieved from, for example, individual PEM- or DER-formatted +///// files. See the documentation for the `PKey` and `X509` types for more details. +///let mut file = File::open("identity.pfx").unwrap(); +///let mut pkcs12 = vec![]; +///file.read_to_end(&mut pkcs12).unwrap(); +///let pkcs12 = Pkcs12::from_der(&pkcs12).unwrap(); +///let identity = pkcs12.parse("password123").unwrap(); +/// +///let acceptor = SslAcceptorBuilder::mozilla_intermediate(SslMethod::tls(), +/// &identity.pkey, +/// &identity.cert, +/// &identity.chain) +/// .unwrap() +/// .build(); /// -///let mut context = SslContext::new(SslMethod::Tlsv1).unwrap(); -///let _ = context.set_certificate_file(&(Path::new("cert.pem")), X509FileType::PEM); -///let _ = context.set_private_key_file(&(Path::new("key.pem")), X509FileType::PEM); -///let server = Server::bind_secure("127.0.0.1:1234", &context).unwrap(); +///let server = Server::bind_secure("127.0.0.1:1234", acceptor).unwrap(); /// ///for connection in server { /// // Spawn a new thread for each connection. /// thread::spawn(move || { -/// let request = connection.unwrap().read_request().unwrap(); // Get the request -/// let response = request.accept(); // Form a response -/// let mut client = response.send().unwrap(); // Send the response +/// let mut client = connection.accept().unwrap(); /// /// let message = Message::text("Hello, client!"); /// let _ = client.send_message(&message); @@ -75,95 +122,152 @@ pub mod response; ///} /// # } /// ``` -pub struct Server<'a> { - inner: TcpListener, - context: Option<&'a SslContext>, +/// +/// # A Hyper Server +/// This crates comes with hyper integration out of the box, you can create a hyper +/// server and serve websocket and HTTP **on the same port!** +/// check out the docs over at `websocket::server::upgrade::from_hyper` for an example. +/// +/// # A Custom Server +/// So you don't want to use any of our server implementations? That's O.K. +/// All it takes is implementing the `IntoWs` trait for your server's streams, +/// then calling `.into_ws()` on them. +/// check out the docs over at `websocket::server::upgrade` for more. +pub struct Server + where S: OptionalSslAcceptor +{ + listener: TcpListener, + ssl_acceptor: S, } -impl<'a> Server<'a> { - /// Bind this Server to this socket - pub fn bind(addr: T) -> io::Result> { - Ok(Server { - inner: try!(TcpListener::bind(&addr)), - context: None, - }) - } - /// Bind this Server to this socket, utilising the given SslContext - pub fn bind_secure(addr: T, context: &'a SslContext) -> io::Result> { - Ok(Server { - inner: try!(TcpListener::bind(&addr)), - context: Some(context), - }) - } +impl Server + where S: OptionalSslAcceptor +{ /// Get the socket address of this server pub fn local_addr(&self) -> io::Result { - self.inner.local_addr() + self.listener.local_addr() } /// Create a new independently owned handle to the underlying socket. - pub fn try_clone(&self) -> io::Result> { - let inner = try!(self.inner.try_clone()); + pub fn try_clone(&self) -> io::Result> { + let inner = try!(self.listener.try_clone()); Ok(Server { - inner: inner, - context: self.context - }) + listener: inner, + ssl_acceptor: self.ssl_acceptor.clone(), + }) + } +} + +#[cfg(feature="ssl")] +impl Server { + /// Bind this Server to this socket, utilising the given SslContext + pub fn bind_secure(addr: A, acceptor: SslAcceptor) -> io::Result + where A: ToSocketAddrs + { + Ok(Server { + listener: try!(TcpListener::bind(&addr)), + ssl_acceptor: acceptor, + }) } /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> io::Result> { - let stream = try!(self.inner.accept()).0; - let wsstream = match self.context { - Some(context) => { - let sslstream = match SslStream::accept(context, stream) { - Ok(s) => s, - Err(err) => { - return Err(io::Error::new(io::ErrorKind::Other, err)); - } - }; - WebSocketStream::Ssl(sslstream) + pub fn accept(&mut self) -> AcceptResult> { + let stream = match self.listener.accept() { + Ok(s) => s.0, + Err(e) => { + return Err(InvalidConnection { + stream: None, + parsed: None, + buffer: None, + error: e.into(), + }) } - None => { WebSocketStream::Tcp(stream) } }; - Ok(Connection(try!(wsstream.try_clone()), try!(wsstream.try_clone()))) + + let stream = match self.ssl_acceptor.accept(stream) { + Ok(s) => s, + Err(err) => { + return Err(InvalidConnection { + stream: None, + parsed: None, + buffer: None, + error: io::Error::new(io::ErrorKind::Other, err).into(), + }) + } + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, b, e)) => { + Err(InvalidConnection { + stream: Some(s), + parsed: r, + buffer: b, + error: e.into(), + }) + } + } } - /// Changes whether the Server is in nonblocking mode. - /// - /// If it is in nonblocking mode, accept() will return an error instead of blocking when there - /// are no incoming connections. - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.inner.set_nonblocking(nonblocking) - } + /// Changes whether the Server is in nonblocking mode. + /// + /// If it is in nonblocking mode, accept() will return an error instead of blocking when there + /// are no incoming connections. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.listener.set_nonblocking(nonblocking) + } } -impl<'a> Iterator for Server<'a> { - type Item = io::Result>; +#[cfg(feature="ssl")] +impl Iterator for Server { + type Item = WsUpgrade>; fn next(&mut self) -> Option<::Item> { - Some(self.accept()) + self.accept().ok() } } -/// Represents a connection to the server that has not been processed yet. -pub struct Connection(R, W); +impl Server { + /// Bind this Server to this socket + pub fn bind(addr: A) -> io::Result { + Ok(Server { + listener: try!(TcpListener::bind(&addr)), + ssl_acceptor: NoSslAcceptor, + }) + } + + /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest + pub fn accept(&mut self) -> AcceptResult { + let stream = match self.listener.accept() { + Ok(s) => s.0, + Err(e) => { + return Err(InvalidConnection { + stream: None, + parsed: None, + buffer: None, + error: e.into(), + }) + } + }; -impl Connection { - /// Process this connection and read the request. - pub fn read_request(self) -> io::Result> { - match Request::read(self.0, self.1) { - Ok(result) => { Ok(result) }, - Err(err) => { - Err(io::Error::new(io::ErrorKind::InvalidInput, err)) + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, b, e)) => { + Err(InvalidConnection { + stream: Some(s), + parsed: r, + buffer: b, + error: e.into(), + }) } } } } -impl Connection { - /// Shuts down the currennt connection in the specified way. - /// All future IO calls to this connection will return immediately with an appropriate - /// return value. - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - self.0.shutdown(how) - } +impl Iterator for Server { + type Item = WsUpgrade; + + fn next(&mut self) -> Option<::Item> { + self.accept().ok() + } } diff --git a/src/server/request.rs b/src/server/request.rs deleted file mode 100644 index 990647592b..0000000000 --- a/src/server/request.rs +++ /dev/null @@ -1,169 +0,0 @@ -//! The server-side WebSocket request. - -use std::io::{Read, Write}; - -use server::Response; -use result::{WebSocketResult, WebSocketError}; -use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; - -pub use hyper::uri::RequestUri; - -use hyper::buffer::BufReader; -use hyper::version::HttpVersion; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, ProtocolName}; -use hyper::http::h1::parse_request; -use hyper::method::Method; - -use unicase::UniCase; - -/// Represents a server-side (incoming) request. -pub struct Request { - /// The HTTP method used to create the request. All values except `Method::Get` are - /// rejected by `validate()`. - pub method: Method, - - /// The target URI for this request. - pub url: RequestUri, - - /// The HTTP version of this request. - pub version: HttpVersion, - - /// The headers of this request. - pub headers: Headers, - - reader: R, - writer: W, -} - -unsafe impl Send for Request where R: Read + Send, W: Write + Send { } - -impl Request { - /// Short-cut to obtain the WebSocketKey value. - pub fn key(&self) -> Option<&WebSocketKey> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketVersion value. - pub fn version(&self) -> Option<&WebSocketVersion> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Short-cut to obtain the Origin value. - pub fn origin(&self) -> Option<&Origin> { - self.headers.get() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &R { - &self.reader - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - &self.writer - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut R { - &mut self.reader - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - &mut self.writer - } - /// Return the inner Reader and Writer - pub fn into_inner(self) -> (R, W) { - (self.reader, self.writer) - } - /// Reads an inbound request. - /// - /// This method is used within servers, and returns an inbound WebSocketRequest. - /// An error will be returned if the request cannot be read, or is not a valid HTTP - /// request. - /// - /// This method does not have any restrictions on the Request. All validation happens in - /// the `validate` method. - pub fn read(reader: R, writer: W) -> WebSocketResult> { - let mut reader = BufReader::new(reader); - let request = try!(parse_request(&mut reader)); - - Ok(Request { - method: request.subject.0, - url: request.subject.1, - version: request.version, - headers: request.headers, - reader: reader.into_inner(), - writer: writer, - }) - } - /// Check if this constitutes a valid WebSocket upgrade request. - /// - /// Note that `accept()` calls this function internally, however this may be useful for - /// handling requests in a custom way. - pub fn validate(&self) -> WebSocketResult<()> { - if self.method != Method::Get { - return Err(WebSocketError::RequestError("Request method must be GET")); - } - - if self.version == HttpVersion::Http09 || self.version == HttpVersion::Http10 { - return Err(WebSocketError::RequestError("Unsupported request HTTP version")); - } - - if self.version() != Some(&(WebSocketVersion::WebSocket13)) { - return Err(WebSocketError::RequestError("Unsupported WebSocket version")); - } - - if self.key().is_none() { - return Err(WebSocketError::RequestError("Missing Sec-WebSocket-Key header")); - } - - match self.headers.get() { - Some(&Upgrade(ref upgrade)) => { - let mut correct_upgrade = false; - for u in upgrade { - if u.name == ProtocolName::WebSocket { - correct_upgrade = true; - } - } - if !correct_upgrade { - return Err(WebSocketError::RequestError("Invalid Upgrade WebSocket header")); - } - } - None => { return Err(WebSocketError::RequestError("Missing Upgrade WebSocket header")); } - } - - match self.headers.get() { - Some(&Connection(ref connection)) => { - if !connection.contains(&(ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())))) { - return Err(WebSocketError::RequestError("Invalid Connection WebSocket header")); - } - } - None => { return Err(WebSocketError::RequestError("Missing Connection WebSocket header")); } - } - - Ok(()) - } - - /// Accept this request, ready to send a response. - /// - /// This function calls `validate()` on the request, and if the request is found to be invalid, - /// generates a response with a Bad Request status code. - pub fn accept(self) -> Response { - match self.validate() { - Ok(()) => { } - Err(_) => { return self.fail(); } - } - Response::new(self) - } - - /// Fail this request by generating a Bad Request response - pub fn fail(self) -> Response { - Response::bad_request(self) - } -} - diff --git a/src/server/response.rs b/src/server/response.rs deleted file mode 100644 index 0bae5e4cd0..0000000000 --- a/src/server/response.rs +++ /dev/null @@ -1,150 +0,0 @@ -//! Struct for server-side WebSocket response. -use std::io::{Read, Write}; - -use hyper::status::StatusCode; -use hyper::version::HttpVersion; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, Protocol, ProtocolName}; -use hyper::buffer::BufReader; - -use unicase::UniCase; - -use header::{WebSocketAccept, WebSocketProtocol, WebSocketExtensions}; -use sender::Sender; -use receiver::Receiver; -use server::Request; -use client::Client; -use result::WebSocketResult; -use dataframe::DataFrame; -use ws::dataframe::DataFrame as DataFrameable; -use ws; - -/// Represents a server-side (outgoing) response. -pub struct Response { - /// The status of the response - pub status: StatusCode, - /// The headers contained in this response - pub headers: Headers, - /// The HTTP version of this response - pub version: HttpVersion, - - request: Request -} - -unsafe impl Send for Response where R: Read + Send, W: Write + Send { } - -impl Response { - /// Short-cut to obtain the WebSocketAccept value - pub fn accept(&self) -> Option<&WebSocketAccept> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &R { - self.request.get_reader() - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - self.request.get_writer() - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut R { - self.request.get_mut_reader() - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - self.request.get_mut_writer() - } - /// Returns a reference to the request associated with this response/ - pub fn get_request(&self) -> &Request { - &self.request - } - /// Return the inner Reader and Writer - pub fn into_inner(self) -> (R, W) { - self.request.into_inner() - } - /// Create a new outbound WebSocket response. - pub fn new(request: Request) -> Response { - let mut headers = Headers::new(); - headers.set(WebSocketAccept::new(request.key().unwrap())); - headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); - headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)])); - Response { - status: StatusCode::SwitchingProtocols, - headers: headers, - version: HttpVersion::Http11, - request: request - } - } - /// Create a Bad Request response - pub fn bad_request(request: Request) -> Response { - Response { - status: StatusCode::BadRequest, - headers: Headers::new(), - version: HttpVersion::Http11, - request: request - } - } - /// Short-cut to obtain a mutable reference to the WebSocketAccept value - /// Note that to add a header that does not already exist, ```WebSocketResponse.headers.set()``` - /// must be used. - pub fn accept_mut(&mut self) -> Option<&mut WebSocketAccept> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketProtocol value - /// Note that to add a header that does not already exist, ```WebSocketResponse.headers.set()``` - /// must be used. - pub fn protocol_mut(&mut self) -> Option<&mut WebSocketProtocol> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketExtensions value - /// Note that to add a header that does not already exist, ```WebSocketResponse.headers.set()``` - /// must be used. - pub fn extensions_mut(&mut self) -> Option<&mut WebSocketExtensions> { - self.headers.get_mut() - } - - /// Send this response with the given data frame type D, Sender B and Receiver C. - pub fn send_with(mut self, sender: B, receiver: C) -> WebSocketResult> - where B: ws::Sender, C: ws::Receiver, D: DataFrameable { - let version = self.version; - let status = self.status; - let headers = self.headers.clone(); - try!(write!(self.get_mut_writer(), "{} {}\r\n", version, status)); - try!(write!(self.get_mut_writer(), "{}\r\n", headers)); - Ok(Client::new(sender, receiver)) - } - - /// Send this response, retrieving the inner Reader and Writer - pub fn send_into_inner(mut self) -> WebSocketResult<(R, W)> { - let version = self.version; - let status = self.status; - let headers = self.headers.clone(); - try!(write!(self.get_mut_writer(), "{} {}\r\n", version, status)); - try!(write!(self.get_mut_writer(), "{}\r\n", headers)); - Ok(self.into_inner()) - } - - /// Send this response, returning a Client ready to transmit/receive data frames - pub fn send(mut self) -> WebSocketResult, Receiver>> { - let version = self.version; - let status = self.status; - let headers = self.headers.clone(); - try!(write!(self.get_mut_writer(), "{} {}\r\n", version, status)); - try!(write!(self.get_mut_writer(), "{}\r\n", headers)); - let (reader, writer) = self.into_inner(); - let sender = Sender::new(writer, false); - let receiver = Receiver::new(BufReader::new(reader), true); - Ok(Client::new(sender, receiver)) - } -} diff --git a/src/server/upgrade/from_hyper.rs b/src/server/upgrade/from_hyper.rs new file mode 100644 index 0000000000..da6ae1e134 --- /dev/null +++ b/src/server/upgrade/from_hyper.rs @@ -0,0 +1,88 @@ +//! Upgrade a hyper connection to a websocket one. +//! +//! Using this method, one can start a hyper server and check if each request +//! is a websocket upgrade request, if so you can use websockets and hyper on the +//! same port! +//! +//! ```rust,no_run +//! # extern crate hyper; +//! # extern crate websocket; +//! # fn main() { +//! use hyper::server::{Server, Request, Response}; +//! use websocket::Message; +//! use websocket::server::upgrade::IntoWs; +//! use websocket::server::upgrade::from_hyper::HyperRequest; +//! +//! Server::http("0.0.0.0:80").unwrap().handle(move |req: Request, res: Response| { +//! match HyperRequest(req).into_ws() { +//! Ok(upgrade) => { +//! // `accept` sends a successful handshake, no need to worry about res +//! let mut client = match upgrade.accept() { +//! Ok(c) => c, +//! Err(_) => panic!(), +//! }; +//! +//! client.send_message(&Message::text("its free real estate")); +//! }, +//! +//! Err((request, err)) => { +//! // continue using the request as normal, "echo uri" +//! res.send(b"Try connecting over ws instead.").unwrap(); +//! }, +//! }; +//! }) +//! .unwrap(); +//! # } +//! ``` + +use hyper::net::NetworkStream; +use super::{IntoWs, WsUpgrade, Buffer}; + +pub use hyper::http::h1::Incoming; +pub use hyper::method::Method; +pub use hyper::version::HttpVersion; +pub use hyper::uri::RequestUri; +pub use hyper::buffer::BufReader; +use hyper::server::Request; +pub use hyper::header::{Headers, Upgrade, ProtocolName, Connection, ConnectionOption}; + +use super::validate; +use super::HyperIntoWsError; + +/// A hyper request is implicitly defined as a stream from other `impl`s of Stream. +/// Until trait impl specialization comes along, we use this struct to differentiate +/// a hyper request (which already has parsed headers) from a normal stream. +pub struct HyperRequest<'a, 'b: 'a>(pub Request<'a, 'b>); + +impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { + type Stream = &'a mut &'b mut NetworkStream; + type Error = (Request<'a, 'b>, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + if let Err(e) = validate(&self.0.method, &self.0.version, &self.0.headers) { + return Err((self.0, e)); + } + + let (_, method, headers, uri, version, reader) = + self.0.deconstruct(); + + let reader = reader.into_inner(); + let (buf, pos, cap) = reader.take_buf(); + let stream = reader.get_mut(); + + Ok(WsUpgrade { + headers: Headers::new(), + stream: stream, + buffer: Some(Buffer { + buf: buf, + pos: pos, + cap: cap, + }), + request: Incoming { + version: version, + headers: headers, + subject: (method, uri), + }, + }) + } +} diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs new file mode 100644 index 0000000000..51fa4332fe --- /dev/null +++ b/src/server/upgrade/mod.rs @@ -0,0 +1,436 @@ +//! Allows you to take an existing request or stream of data and convert it into a +//! WebSocket client. +use std::error::Error; +use std::net::TcpStream; +use std::io; +use std::io::Result as IoResult; +use std::io::Error as IoError; +use std::fmt::{self, Formatter, Display}; +use stream::{Stream, AsTcpStream}; +use header::extensions::Extension; +use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, + WebSocketExtensions, Origin}; +use client::Client; + +use unicase::UniCase; +use hyper::status::StatusCode; +use hyper::http::h1::Incoming; +use hyper::method::Method; +use hyper::version::HttpVersion; +use hyper::uri::RequestUri; +use hyper::buffer::BufReader; +use hyper::http::h1::parse_request; +use hyper::header::{Headers, Upgrade, Protocol, ProtocolName, Connection, ConnectionOption}; + +pub mod from_hyper; + +/// This crate uses buffered readers to read in the handshake quickly, in order to +/// interface with other use cases that don't use buffered readers the buffered readers +/// is deconstructed when it is returned to the user and given as the underlying +/// reader and the buffer. +/// +/// This struct represents bytes that have already been read in from the stream. +/// A slice of valid data in this buffer can be obtained by: `&buf[pos..cap]`. +pub struct Buffer { + /// the contents of the buffered stream data + pub buf: Vec, + /// the current position of cursor in the buffer + /// Any data before `pos` has already been read and parsed. + pub pos: usize, + /// the last location of valid data + /// Any data after `cap` is not valid. + pub cap: usize, +} + +/// Intermediate representation of a half created websocket session. +/// Should be used to examine the client's handshake +/// accept the protocols requested, route the path, etc. +/// +/// Users should then call `accept` or `reject` to complete the handshake +/// and start a session. +pub struct WsUpgrade + where S: Stream +{ + /// The headers that will be used in the handshake response. + pub headers: Headers, + /// The stream that will be used to read from / write to. + pub stream: S, + /// The handshake request, filled with useful metadata. + pub request: Request, + /// Some buffered data from the stream, if it exists. + pub buffer: Option, +} + +impl WsUpgrade + where S: Stream +{ + /// Select a protocol to use in the handshake response. + pub fn use_protocol

(mut self, protocol: P) -> Self + where P: Into + { + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.push(protocol.into()), + None => WebSocketProtocol(vec![protocol.into()]) + }); + self + } + + /// Select an extension to use in the handshake response. + pub fn use_extension(mut self, extension: Extension) -> Self { + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.push(extension), + None => WebSocketExtensions(vec![extension]) + }); + self + } + + /// Select multiple extensions to use in the connection + pub fn use_extensions(mut self, extensions: I) -> Self + where I: IntoIterator + { + let mut extensions: Vec = + extensions.into_iter().collect(); + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.append(&mut extensions), + None => WebSocketExtensions(extensions) + }); + self + } + + /// Accept the handshake request and send a response, + /// if nothing goes wrong a client will be created. + pub fn accept(self) -> Result, (S, IoError)> { + self.accept_with(&Headers::new()) + } + + /// Accept the handshake request and send a response while + /// adding on a few headers. These headers are added before the required + /// headers are, so some might be overwritten. + pub fn accept_with(mut self, custom_headers: &Headers) -> Result, (S, IoError)> { + self.headers.extend(custom_headers.iter()); + self.headers + .set(WebSocketAccept::new(// NOTE: we know there is a key because this is a valid request + // i.e. to construct this you must go through the validate function + self.request.headers.get::().unwrap())); + self.headers + .set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); + self.headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)])); + + if let Err(e) = self.send(StatusCode::SwitchingProtocols) { + return Err((self.stream, e)); + } + + let stream = match self.buffer { + Some(Buffer { buf, pos, cap }) => BufReader::from_parts(self.stream, buf, pos, cap), + None => BufReader::new(self.stream), + }; + + Ok(Client::unchecked(stream, self.headers)) + } + + /// Reject the client's request to make a websocket connection. + pub fn reject(self) -> Result { + self.reject_with(&Headers::new()) + } + /// Reject the client's request to make a websocket connection + /// and send extra headers. + pub fn reject_with(mut self, headers: &Headers) -> Result { + self.headers.extend(headers.iter()); + match self.send(StatusCode::BadRequest) { + Ok(()) => Ok(self.stream), + Err(e) => Err((self.stream, e)), + } + } + + /// Drop the connection without saying anything. + pub fn drop(self) { + ::std::mem::drop(self); + } + + /// A list of protocols requested from the client. + pub fn protocols(&self) -> &[String] { + self.request + .headers + .get::() + .map(|p| p.0.as_slice()) + .unwrap_or(&[]) + } + + /// A list of extensions requested from the client. + pub fn extensions(&self) -> &[Extension] { + self.request + .headers + .get::() + .map(|e| e.0.as_slice()) + .unwrap_or(&[]) + } + + /// The client's websocket accept key. + pub fn key(&self) -> Option<&[u8; 16]> { + self.request.headers.get::().map(|k| &k.0) + } + + /// The client's websocket version. + pub fn version(&self) -> Option<&WebSocketVersion> { + self.request.headers.get::() + } + + /// Origin of the client + pub fn origin(&self) -> Option<&str> { + self.request.headers.get::().map(|o| &o.0 as &str) + } + + fn send(&mut self, status: StatusCode) -> IoResult<()> { + try!(write!(&mut self.stream, "{} {}\r\n", self.request.version, status)); + try!(write!(&mut self.stream, "{}\r\n", self.headers)); + Ok(()) + } +} + +impl WsUpgrade + where S: Stream + AsTcpStream +{ + /// Get a handle to the underlying TCP stream, useful to be able to set + /// TCP options, etc. + pub fn tcp_stream(&self) -> &TcpStream { + self.stream.as_tcp() + } +} + +/// Trait to take a stream or similar and attempt to recover the start of a +/// websocket handshake from it. +/// Should be used when a stream might contain a request for a websocket session. +/// +/// If an upgrade request can be parsed, one can accept or deny the handshake with +/// the `WsUpgrade` struct. +/// Otherwise the original stream is returned along with an error. +/// +/// Note: the stream is owned because the websocket client expects to own its stream. +/// +/// This is already implemented for all Streams, which means all types with Read + Write. +/// +/// # Example +/// +/// ```rust,no_run +/// use std::net::TcpListener; +/// use std::net::TcpStream; +/// use websocket::server::upgrade::IntoWs; +/// use websocket::Client; +/// +/// let listener = TcpListener::bind("127.0.0.1:80").unwrap(); +/// +/// for stream in listener.incoming().filter_map(Result::ok) { +/// let mut client: Client = match stream.into_ws() { +/// Ok(upgrade) => { +/// match upgrade.accept() { +/// Ok(client) => client, +/// Err(_) => panic!(), +/// } +/// }, +/// Err(_) => panic!(), +/// }; +/// } +/// ``` +pub trait IntoWs { + /// The type of stream this upgrade process is working with (TcpStream, etc.) + type Stream: Stream; + /// An error value in case the stream is not asking for a websocket connection + /// or something went wrong. It is common to also include the stream here. + type Error; + /// Attempt to parse the start of a Websocket handshake, later with the returned + /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to + /// send a handshake rejection response. + fn into_ws(self) -> Result, Self::Error>; +} + + +/// A typical request from hyper +pub type Request = Incoming<(Method, RequestUri)>; +/// If you have your requests separate from your stream you can use this struct +/// to upgrade the connection based on the request given +/// (the request should be a handshake). +pub struct RequestStreamPair(pub S, pub Request); + +impl IntoWs for S + where S: Stream +{ + type Stream = S; + type Error = (S, Option, Option, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + let mut reader = BufReader::new(self); + let request = parse_request(&mut reader); + + let (stream, buf, pos, cap) = reader.into_parts(); + let buffer = Some(Buffer { + buf: buf, + cap: cap, + pos: pos, + }); + + let request = match request { + Ok(r) => r, + Err(e) => return Err((stream, None, buffer, e.into())), + }; + + match validate(&request.subject.0, &request.version, &request.headers) { + Ok(_) => { + Ok(WsUpgrade { + headers: Headers::new(), + stream: stream, + request: request, + buffer: buffer, + }) + } + Err(e) => Err((stream, Some(request), buffer, e)), + } + } +} + +impl IntoWs for RequestStreamPair + where S: Stream +{ + type Stream = S; + type Error = (S, Request, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { + Ok(_) => { + Ok(WsUpgrade { + headers: Headers::new(), + stream: self.0, + request: self.1, + buffer: None, + }) + } + Err(e) => Err((self.0, self.1, e)), + } + } +} + +/// Errors that can occur when one tries to upgrade a connection to a +/// websocket connection. +#[derive(Debug)] +pub enum HyperIntoWsError { + /// The HTTP method in a valid websocket upgrade request must be GET + MethodNotGet, + /// Currently HTTP 2 is not supported + UnsupportedHttpVersion, + /// Currently only WebSocket13 is supported (RFC6455) + UnsupportedWebsocketVersion, + /// A websocket upgrade request must contain a key + NoSecWsKeyHeader, + /// A websocket upgrade request must ask to upgrade to a `websocket` + NoWsUpgradeHeader, + /// A websocket upgrade request must contain an `Upgrade` header + NoUpgradeHeader, + /// A websocket upgrade request's `Connection` header must be `Upgrade` + NoWsConnectionHeader, + /// A websocket upgrade request must contain a `Connection` header + NoConnectionHeader, + /// IO error from reading the underlying socket + Io(io::Error), + /// Error while parsing an incoming request + Parsing(::hyper::error::Error), +} + +impl Display for HyperIntoWsError { + fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + fmt.write_str(self.description()) + } +} + +impl Error for HyperIntoWsError { + fn description(&self) -> &str { + use self::HyperIntoWsError::*; + match self { + &MethodNotGet => "Request method must be GET", + &UnsupportedHttpVersion => "Unsupported request HTTP version", + &UnsupportedWebsocketVersion => "Unsupported WebSocket version", + &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", + &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", + &NoUpgradeHeader => "Missing Upgrade WebSocket header", + &NoWsConnectionHeader => "Invalid Connection WebSocket header", + &NoConnectionHeader => "Missing Connection WebSocket header", + &Io(ref e) => e.description(), + &Parsing(ref e) => e.description(), + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + HyperIntoWsError::Io(ref e) => Some(e), + HyperIntoWsError::Parsing(ref e) => Some(e), + _ => None, + } + } +} + +impl From for HyperIntoWsError { + fn from(err: io::Error) -> Self { + HyperIntoWsError::Io(err) + } +} + +impl From<::hyper::error::Error> for HyperIntoWsError { + fn from(err: ::hyper::error::Error) -> Self { + HyperIntoWsError::Parsing(err) + } +} + +fn validate( + method: &Method, + version: &HttpVersion, + headers: &Headers, +) -> Result<(), HyperIntoWsError> { + if *method != Method::Get { + return Err(HyperIntoWsError::MethodNotGet); + } + + if *version == HttpVersion::Http09 || *version == HttpVersion::Http10 { + return Err(HyperIntoWsError::UnsupportedHttpVersion); + } + + if let Some(version) = headers.get::() { + if version != &WebSocketVersion::WebSocket13 { + return Err(HyperIntoWsError::UnsupportedWebsocketVersion); + } + } + + if headers.get::().is_none() { + return Err(HyperIntoWsError::NoSecWsKeyHeader); + } + + match headers.get() { + Some(&Upgrade(ref upgrade)) => { + if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { + return Err(HyperIntoWsError::NoWsUpgradeHeader); + } + } + None => return Err(HyperIntoWsError::NoUpgradeHeader), + }; + + fn check_connection_header(headers: &Vec) -> bool { + for header in headers { + if let &ConnectionOption::ConnectionHeader(ref h) = header { + if h as &str == "upgrade" { + return true; + } + } + } + false + } + + match headers.get() { + Some(&Connection(ref connection)) => { + if !check_connection_header(connection) { + return Err(HyperIntoWsError::NoWsConnectionHeader); + } + } + None => return Err(HyperIntoWsError::NoConnectionHeader), + }; + + Ok(()) +} diff --git a/src/stream.rs b/src/stream.rs index 623b280bbd..c6901202d7 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,94 +1,139 @@ //! Provides the default stream type for WebSocket connections. -extern crate net2; +use std::ops::Deref; +use std::fmt::Arguments; use std::io::{self, Read, Write}; -use self::net2::TcpStreamExt; -use openssl::ssl::SslStream; +pub use std::net::TcpStream; +pub use std::net::Shutdown; +#[cfg(feature="ssl")] +pub use openssl::ssl::{SslStream, SslContext}; -pub use std::net::{SocketAddr, Shutdown, TcpStream}; +/// Represents a stream that can be read from, and written to. +/// This is an abstraction around readable and writable things to be able +/// to speak websockets over ssl, tcp, unix sockets, etc. +pub trait Stream: Read + Write {} -/// A useful stream type for carrying WebSocket connections. -pub enum WebSocketStream { - /// A TCP stream. - Tcp(TcpStream), - /// An SSL-backed TCP Stream - Ssl(SslStream) +impl Stream for S where S: Read + Write {} + +/// a `Stream` that can also be used as a borrow to a `TcpStream` +/// this is useful when you want to set `TcpStream` options on a +/// `Stream` like `nonblocking`. +pub trait NetworkStream: Read + Write + AsTcpStream {} + +impl NetworkStream for S where S: Read + Write + AsTcpStream {} + +/// some streams can be split up into separate reading and writing components +/// `TcpStream` is an example. This trait marks this ability so one can split +/// up the client into two parts. +/// +/// Notice however that this is not possible to do with SSL. +pub trait Splittable { + /// The reading component of this type + type Reader: Read; + /// The writing component of this type + type Writer: Write; + + /// Split apart this type into a reading and writing component. + fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; } -impl Read for WebSocketStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.read(buf), - WebSocketStream::Ssl(ref mut inner) => inner.read(buf), - } +impl Splittable for ReadWritePair + where R: Read, + W: Write +{ + type Reader = R; + type Writer = W; + + fn split(self) -> io::Result<(R, W)> { + Ok((self.0, self.1)) } } -impl Write for WebSocketStream { - fn write(&mut self, msg: &[u8]) -> io::Result { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.write(msg), - WebSocketStream::Ssl(ref mut inner) => inner.write(msg), - } +impl Splittable for TcpStream { + type Reader = TcpStream; + type Writer = TcpStream; + + fn split(self) -> io::Result<(TcpStream, TcpStream)> { + self.try_clone().map(|s| (s, self)) } +} - fn flush(&mut self) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.flush(), - WebSocketStream::Ssl(ref mut inner) => inner.flush(), - } +/// The ability access a borrow to an underlying TcpStream, +/// so one can set options on the stream such as `nonblocking`. +pub trait AsTcpStream { + /// Get a borrow of the TcpStream + fn as_tcp(&self) -> &TcpStream; +} + +impl AsTcpStream for TcpStream { + fn as_tcp(&self) -> &TcpStream { + &self } } -impl WebSocketStream { - /// See `TcpStream.peer_addr()`. - pub fn peer_addr(&self) -> io::Result { - match *self { - WebSocketStream::Tcp(ref inner) => inner.peer_addr(), - WebSocketStream::Ssl(ref inner) => inner.get_ref().peer_addr(), - } +#[cfg(feature="ssl")] +impl AsTcpStream for SslStream { + fn as_tcp(&self) -> &TcpStream { + self.get_ref() } - /// See `TcpStream.local_addr()`. - pub fn local_addr(&self) -> io::Result { - match *self { - WebSocketStream::Tcp(ref inner) => inner.local_addr(), - WebSocketStream::Ssl(ref inner) => inner.get_ref().local_addr(), - } +} + +impl AsTcpStream for Box + where T: AsTcpStream +{ + fn as_tcp(&self) -> &TcpStream { + self.deref().as_tcp() } - /// See `TcpStream.set_nodelay()`. - pub fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => TcpStreamExt::set_nodelay(inner, nodelay), - WebSocketStream::Ssl(ref mut inner) => TcpStreamExt::set_nodelay(inner.get_mut(), nodelay), - } +} + +/// If you would like to combine an input stream and an output stream into a single +/// stream to talk websockets over then this is the struct for you! +/// +/// This is useful if you want to use different mediums for different directions. +pub struct ReadWritePair(pub R, pub W) + where R: Read, + W: Write; + +impl Read for ReadWritePair + where R: Read, + W: Write +{ + #[inline(always)] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) } - /// See `TcpStream.set_keepalive()`. - pub fn set_keepalive(&mut self, delay_in_ms: Option) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => TcpStreamExt::set_keepalive_ms(inner, delay_in_ms), - WebSocketStream::Ssl(ref mut inner) => TcpStreamExt::set_keepalive_ms(inner.get_mut(), delay_in_ms), - } + #[inline(always)] + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.0.read_to_end(buf) } - /// See `TcpStream.shutdown()`. - pub fn shutdown(&mut self, shutdown: Shutdown) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.shutdown(shutdown), - WebSocketStream::Ssl(ref mut inner) => inner.get_mut().shutdown(shutdown), - } + #[inline(always)] + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.0.read_to_string(buf) } - /// See `TcpStream.try_clone()`. - pub fn try_clone(&self) -> io::Result { - Ok(match *self { - WebSocketStream::Tcp(ref inner) => WebSocketStream::Tcp(try!(inner.try_clone())), - WebSocketStream::Ssl(ref inner) => WebSocketStream::Ssl(try!(inner.try_clone())), - }) + #[inline(always)] + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.0.read_exact(buf) } +} - /// Changes whether the stream is in nonblocking mode. - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref inner) => inner.set_nonblocking(nonblocking), - WebSocketStream::Ssl(ref inner) => inner.get_ref().set_nonblocking(nonblocking), - } - } +impl Write for ReadWritePair + where R: Read, + W: Write +{ + #[inline(always)] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.1.write(buf) + } + #[inline(always)] + fn flush(&mut self) -> io::Result<()> { + self.1.flush() + } + #[inline(always)] + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + self.1.write_all(buf) + } + #[inline(always)] + fn write_fmt(&mut self, fmt: Arguments) -> io::Result<()> { + self.1.write_fmt(fmt) + } } diff --git a/src/ws/dataframe.rs b/src/ws/dataframe.rs index 2c8243b240..8e1669093b 100644 --- a/src/ws/dataframe.rs +++ b/src/ws/dataframe.rs @@ -13,112 +13,113 @@ use ws::util::mask; /// provide these methods. (If the payload is not known in advance then /// rewrite the write_payload method) pub trait DataFrame { - /// Is this dataframe the final dataframe of the message? - fn is_last(&self) -> bool; - /// What type of data does this dataframe contain? - fn opcode(&self) -> u8; - /// Reserved bits of this dataframe - fn reserved<'a>(&'a self) -> &'a [bool; 3]; - /// Entire payload of the dataframe. If not known then implement - /// write_payload as that is the actual method used when sending the - /// dataframe over the wire. - fn payload<'a>(&'a self) -> Cow<'a, [u8]>; + /// Is this dataframe the final dataframe of the message? + fn is_last(&self) -> bool; + /// What type of data does this dataframe contain? + fn opcode(&self) -> u8; + /// Reserved bits of this dataframe + fn reserved<'a>(&'a self) -> &'a [bool; 3]; + /// Entire payload of the dataframe. If not known then implement + /// write_payload as that is the actual method used when sending the + /// dataframe over the wire. + fn payload<'a>(&'a self) -> Cow<'a, [u8]>; - /// How long (in bytes) is this dataframe's payload - fn size(&self) -> usize { - self.payload().len() - } + /// How long (in bytes) is this dataframe's payload + fn size(&self) -> usize { + self.payload().len() + } - /// Write the payload to a writer - fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> - where W: Write { - try!(socket.write_all(&*self.payload())); - Ok(()) - } + /// Write the payload to a writer + fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> + where W: Write + { + try!(socket.write_all(&*self.payload())); + Ok(()) + } - /// Writes a DataFrame to a Writer. - fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> - where W: Write { - let mut flags = dfh::DataFrameFlags::empty(); - if self.is_last() { - flags.insert(dfh::FIN); - } - { - let reserved = self.reserved(); - if reserved[0] { - flags.insert(dfh::RSV1); - } - if reserved[1] { - flags.insert(dfh::RSV2); - } - if reserved[2] { - flags.insert(dfh::RSV3); - } - } + /// Writes a DataFrame to a Writer. + fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> + where W: Write + { + let mut flags = dfh::DataFrameFlags::empty(); + if self.is_last() { + flags.insert(dfh::FIN); + } + { + let reserved = self.reserved(); + if reserved[0] { + flags.insert(dfh::RSV1); + } + if reserved[1] { + flags.insert(dfh::RSV2); + } + if reserved[2] { + flags.insert(dfh::RSV3); + } + } - let masking_key = if mask { - Some(mask::gen_mask()) - } else { - None - }; + let masking_key = if mask { Some(mask::gen_mask()) } else { None }; - let header = dfh::DataFrameHeader { - flags: flags, - opcode: self.opcode() as u8, - mask: masking_key, - len: self.size() as u64, - }; + let header = dfh::DataFrameHeader { + flags: flags, + opcode: self.opcode() as u8, + mask: masking_key, + len: self.size() as u64, + }; - try!(dfh::write_header(writer, header)); + try!(dfh::write_header(writer, header)); - match masking_key { - Some(mask) => { - let mut masker = Masker::new(mask, writer); - try!(self.write_payload(&mut masker)) - }, - None => try!(self.write_payload(writer)), - }; - try!(writer.flush()); - Ok(()) - } + match masking_key { + Some(mask) => { + let mut masker = Masker::new(mask, writer); + try!(self.write_payload(&mut masker)) + } + None => try!(self.write_payload(writer)), + }; + try!(writer.flush()); + Ok(()) + } } impl<'a, D> DataFrame for &'a D -where D: DataFrame { - #[inline(always)] - fn is_last(&self) -> bool { - D::is_last(self) - } + where D: DataFrame +{ + #[inline(always)] + fn is_last(&self) -> bool { + D::is_last(self) + } - #[inline(always)] - fn opcode(&self) -> u8 { - D::opcode(self) - } + #[inline(always)] + fn opcode(&self) -> u8 { + D::opcode(self) + } - #[inline(always)] - fn reserved<'b>(&'b self) -> &'b [bool; 3] { - D::reserved(self) - } + #[inline(always)] + fn reserved<'b>(&'b self) -> &'b [bool; 3] { + D::reserved(self) + } - #[inline(always)] - fn payload<'b>(&'b self) -> Cow<'b, [u8]> { - D::payload(self) - } + #[inline(always)] + fn payload<'b>(&'b self) -> Cow<'b, [u8]> { + D::payload(self) + } - #[inline(always)] - fn size(&self) -> usize { - D::size(self) - } + #[inline(always)] + fn size(&self) -> usize { + D::size(self) + } - #[inline(always)] - fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> - where W: Write { - D::write_payload(self, socket) - } + #[inline(always)] + fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> + where W: Write + { + D::write_payload(self, socket) + } - #[inline(always)] - fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> - where W: Write { - D::write_to(self, writer, mask) - } + #[inline(always)] + fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> + where W: Write + { + D::write_to(self, writer, mask) + } } diff --git a/src/ws/message.rs b/src/ws/message.rs index a21e110002..dc02aa1a63 100644 --- a/src/ws/message.rs +++ b/src/ws/message.rs @@ -7,12 +7,12 @@ use ws::dataframe::DataFrame; /// A trait for WebSocket messages pub trait Message<'a, F>: Sized -where F: DataFrame { + where F: DataFrame +{ /// The iterator type returned by dataframes type DataFrameIterator: Iterator; /// Attempt to form a message from a slice of data frames. - fn from_dataframes(frames: Vec) -> WebSocketResult - where D: DataFrame; + fn from_dataframes(frames: Vec) -> WebSocketResult where D: DataFrame; /// Turns this message into an iterator over data frames fn dataframes(&'a self) -> Self::DataFrameIterator; } diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index cfd934e41b..3e237166da 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -3,90 +3,108 @@ //! Also provides iterators over data frames and messages. //! See the `ws` module documentation for more information. +use std::io::Read; use std::marker::PhantomData; use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; /// A trait for receiving data frames and messages. -pub trait Receiver: Sized -where F: DataFrame { +pub trait Receiver: Sized { + /// The type of dataframe that incoming messages will be serialized to. + type F: DataFrame; + /// Reads a single data frame from this receiver. - fn recv_dataframe(&mut self) -> WebSocketResult; + fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult where R: Read; + /// Returns the data frames that constitute one message. - fn recv_message_dataframes(&mut self) -> WebSocketResult>; + fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> + where R: Read; /// Returns an iterator over incoming data frames. - fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Self, F> { + fn incoming_dataframes<'a, R>(&'a mut self, reader: &'a mut R) -> DataFrameIterator<'a, Self, R> + where R: Read + { DataFrameIterator { + reader: reader, inner: self, - _dataframe: PhantomData } } + /// Reads a single message from this receiver. - fn recv_message<'m, D, M, I>(&mut self) -> WebSocketResult - where M: Message<'m, D, DataFrameIterator = I>, - I: Iterator, - D: DataFrame - { - let dataframes = try!(self.recv_message_dataframes()); + fn recv_message<'m, D, M, I, R>(&mut self, reader: &mut R) -> WebSocketResult + where M: Message<'m, D, DataFrameIterator = I>, + I: Iterator, + D: DataFrame, + R: Read + { + let dataframes = try!(self.recv_message_dataframes(reader)); Message::from_dataframes(dataframes) } /// Returns an iterator over incoming messages. - fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Self, D, F, M> - where M: Message<'a, D>, D: DataFrame { + fn incoming_messages<'a, M, D, R>( + &'a mut self, + reader: &'a mut R, + ) -> MessageIterator<'a, Self, D, M, R> + where M: Message<'a, D>, + D: DataFrame, + R: Read + { MessageIterator { + reader: reader, inner: self, _dataframe: PhantomData, - _receiver: PhantomData, _message: PhantomData, } } } /// An iterator over data frames from a Receiver. -pub struct DataFrameIterator<'a, R, D> -where R: 'a + Receiver, D: DataFrame { - inner: &'a mut R, - _dataframe: PhantomData +pub struct DataFrameIterator<'a, Recv, R> + where Recv: 'a + Receiver, + R: 'a + Read +{ + reader: &'a mut R, + inner: &'a mut Recv, } -impl<'a, R, D> Iterator for DataFrameIterator<'a, R, D> -where R: 'a + Receiver, D: DataFrame { - - type Item = WebSocketResult; +impl<'a, Recv, R> Iterator for DataFrameIterator<'a, Recv, R> + where Recv: 'a + Receiver, + R: Read +{ + type Item = WebSocketResult; /// Get the next data frame from the receiver. Always returns `Some`. - fn next(&mut self) -> Option> { - Some(self.inner.recv_dataframe()) + fn next(&mut self) -> Option> { + Some(self.inner.recv_dataframe(self.reader)) } } /// An iterator over messages from a Receiver. -pub struct MessageIterator<'a, R, D, F, M> -where R: 'a + Receiver, - M: Message<'a, D>, - D: DataFrame, - F: DataFrame, +pub struct MessageIterator<'a, Recv, D, M, R> + where Recv: 'a + Receiver, + M: Message<'a, D>, + D: DataFrame, + R: 'a + Read { - inner: &'a mut R, + reader: &'a mut R, + inner: &'a mut Recv, _dataframe: PhantomData, _message: PhantomData, - _receiver: PhantomData, } -impl<'a, R, D, F, M, I> Iterator for MessageIterator<'a, R, D, F, M> -where R: 'a + Receiver, - M: Message<'a, D, DataFrameIterator = I>, - I: Iterator, - D: DataFrame, - F: DataFrame, +impl<'a, Recv, D, M, I, R> Iterator for MessageIterator<'a, Recv, D, M, R> + where Recv: 'a + Receiver, + M: Message<'a, D, DataFrameIterator = I>, + I: Iterator, + D: DataFrame, + R: Read { type Item = WebSocketResult; /// Get the next message from the receiver. Always returns `Some`. fn next(&mut self) -> Option> { - Some(self.inner.recv_message()) + Some(self.inner.recv_message(self.reader)) } } diff --git a/src/ws/sender.rs b/src/ws/sender.rs index f017f8bd2b..c179fd275d 100644 --- a/src/ws/sender.rs +++ b/src/ws/sender.rs @@ -2,6 +2,7 @@ //! //! See the `ws` module documentation for more information. +use std::io::Write; use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; @@ -9,14 +10,18 @@ use result::WebSocketResult; /// A trait for sending data frames and messages. pub trait Sender { /// Sends a single data frame using this sender. - fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame; + fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write; /// Sends a single message using this sender. - fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: Message<'m, D>, D: DataFrame { + fn send_message<'m, M, D, W>(&mut self, writer: &mut W, message: &'m M) -> WebSocketResult<()> + where M: Message<'m, D>, + D: DataFrame, + W: Write + { for ref dataframe in message.dataframes() { - try!(self.send_dataframe(dataframe)); + try!(self.send_dataframe(writer, dataframe)); } Ok(()) } diff --git a/src/ws/util/header.rs b/src/ws/util/header.rs index bdc845bd69..c5de25f209 100644 --- a/src/ws/util/header.rs +++ b/src/ws/util/header.rs @@ -28,41 +28,35 @@ pub struct DataFrameHeader { /// The masking key, if any. pub mask: Option<[u8; 4]>, /// The length of the payload. - pub len: u64 + pub len: u64, } /// Writes a data frame header. pub fn write_header(writer: &mut W, header: DataFrameHeader) -> WebSocketResult<()> - where W: Write { + where W: Write +{ if header.opcode > 0xF { - return Err(WebSocketError::DataFrameError( - "Invalid data frame opcode" - )); + return Err(WebSocketError::DataFrameError("Invalid data frame opcode")); } if header.opcode >= 8 && header.len >= 126 { - return Err(WebSocketError::DataFrameError( - "Control frame length too long" - )); + return Err(WebSocketError::DataFrameError("Control frame length too long")); } // Write 'FIN', 'RSV1', 'RSV2', 'RSV3' and 'opcode' try!(writer.write_u8((header.flags.bits) | header.opcode)); - try!(writer.write_u8( - // Write the 'MASK' - if header.mask.is_some() { 0x80 } else { 0x00 } | + try!(writer.write_u8(// Write the 'MASK' + if header.mask.is_some() { 0x80 } else { 0x00 } | // Write the 'Payload len' if header.len <= 125 { header.len as u8 } else if header.len <= 65535 { 126 } - else { 127 } - )); + else { 127 })); // Write 'Extended payload length' if header.len >= 126 && header.len <= 65535 { try!(writer.write_u16::(header.len as u16)); - } - else if header.len > 65535 { + } else if header.len > 65535 { try!(writer.write_u64::(header.len)); } @@ -77,7 +71,8 @@ pub fn write_header(writer: &mut W, header: DataFrameHeader) -> WebSocketResu /// Reads a data frame header. pub fn read_header(reader: &mut R) -> WebSocketResult - where R: Read { + where R: Read +{ let byte0 = try!(reader.read_u8()); let byte1 = try!(reader.read_u8()); @@ -90,18 +85,14 @@ pub fn read_header(reader: &mut R) -> WebSocketResult 126 => { let len = try!(reader.read_u16::()) as u64; if len <= 125 { - return Err(WebSocketError::DataFrameError( - "Invalid data frame length" - )); + return Err(WebSocketError::DataFrameError("Invalid data frame length")); } len } 127 => { let len = try!(reader.read_u64::()); if len <= 65535 { - return Err(WebSocketError::DataFrameError( - "Invalid data frame length" - )); + return Err(WebSocketError::DataFrameError("Invalid data frame length")); } len } @@ -110,14 +101,10 @@ pub fn read_header(reader: &mut R) -> WebSocketResult if opcode >= 8 { if len >= 126 { - return Err(WebSocketError::DataFrameError( - "Control frame length too long" - )); + return Err(WebSocketError::DataFrameError("Control frame length too long")); } if !flags.contains(FIN) { - return Err(WebSocketError::ProtocolError( - "Illegal fragmented control frame" - )); + return Err(WebSocketError::ProtocolError("Illegal fragmented control frame")); } } @@ -126,19 +113,18 @@ pub fn read_header(reader: &mut R) -> WebSocketResult try!(reader.read_u8()), try!(reader.read_u8()), try!(reader.read_u8()), - try!(reader.read_u8()) + try!(reader.read_u8()), ]) - } - else { + } else { None }; Ok(DataFrameHeader { - flags: flags, - opcode: opcode, - mask: mask, - len: len - }) + flags: flags, + opcode: opcode, + mask: mask, + len: len, + }) } #[cfg(all(feature = "nightly", test))] @@ -153,7 +139,7 @@ mod tests { flags: FIN, opcode: 1, mask: None, - len: 43 + len: 43, }; assert_eq!(obtained, expected); } @@ -163,7 +149,7 @@ mod tests { flags: FIN, opcode: 1, mask: None, - len: 43 + len: 43, }; let expected = [0x81, 0x2B]; let mut obtained = Vec::with_capacity(2); @@ -179,7 +165,7 @@ mod tests { flags: RSV1, opcode: 2, mask: Some([2, 4, 8, 16]), - len: 512 + len: 512, }; assert_eq!(obtained, expected); } @@ -189,7 +175,7 @@ mod tests { flags: RSV1, opcode: 2, mask: Some([2, 4, 8, 16]), - len: 512 + len: 512, }; let expected = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10]; let mut obtained = Vec::with_capacity(8); @@ -200,9 +186,7 @@ mod tests { #[bench] fn bench_read_header(b: &mut test::Bencher) { let header = vec![0x42u8, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10]; - b.iter(|| { - read_header(&mut &header[..]).unwrap(); - }); + b.iter(|| { read_header(&mut &header[..]).unwrap(); }); } #[bench] fn bench_write_header(b: &mut test::Bencher) { @@ -210,11 +194,9 @@ mod tests { flags: RSV1, opcode: 2, mask: Some([2, 4, 8, 16]), - len: 512 + len: 512, }; let mut writer = Vec::with_capacity(8); - b.iter(|| { - write_header(&mut writer, header).unwrap(); - }); + b.iter(|| { write_header(&mut writer, header).unwrap(); }); } } diff --git a/src/ws/util/mask.rs b/src/ws/util/mask.rs index 21c9219dec..114727881a 100644 --- a/src/ws/util/mask.rs +++ b/src/ws/util/mask.rs @@ -7,39 +7,42 @@ use std::mem; /// Struct to pipe data into another writer, /// while masking the data being written pub struct Masker<'w, W> -where W: Write + 'w { - key: [u8; 4], - pos: usize, - end: &'w mut W, + where W: Write + 'w +{ + key: [u8; 4], + pos: usize, + end: &'w mut W, } impl<'w, W> Masker<'w, W> -where W: Write + 'w { - /// Create a new Masker with the key and the endpoint - /// to be writter to. - pub fn new(key: [u8; 4], endpoint: &'w mut W) -> Self { - Masker { - key: key, - pos: 0, - end: endpoint, - } - } + where W: Write + 'w +{ + /// Create a new Masker with the key and the endpoint + /// to be writter to. + pub fn new(key: [u8; 4], endpoint: &'w mut W) -> Self { + Masker { + key: key, + pos: 0, + end: endpoint, + } + } } impl<'w, W> Write for Masker<'w, W> -where W: Write + 'w { - fn write(&mut self, data: &[u8]) -> IoResult { - let mut buf = Vec::with_capacity(data.len()); - for &byte in data.iter() { - buf.push(byte ^ self.key[self.pos]); - self.pos = (self.pos + 1) % self.key.len(); - } - self.end.write(&buf) - } + where W: Write + 'w +{ + fn write(&mut self, data: &[u8]) -> IoResult { + let mut buf = Vec::with_capacity(data.len()); + for &byte in data.iter() { + buf.push(byte ^ self.key[self.pos]); + self.pos = (self.pos + 1) % self.key.len(); + } + self.end.write(&buf) + } - fn flush(&mut self) -> IoResult<()> { - self.end.flush() - } + fn flush(&mut self) -> IoResult<()> { + self.end.flush() + } } /// Generates a random masking key @@ -50,12 +53,12 @@ pub fn gen_mask() -> [u8; 4] { /// Masks data to send to a server and writes pub fn mask_data(mask: [u8; 4], data: &[u8]) -> Vec { - let mut out = Vec::with_capacity(data.len()); - let zip_iter = data.iter().zip(mask.iter().cycle()); - for (&buf_item, &key_item) in zip_iter { - out.push(buf_item ^ key_item); - } - out + let mut out = Vec::with_capacity(data.len()); + let zip_iter = data.iter().zip(mask.iter().cycle()); + for (&buf_item, &key_item) in zip_iter { + out.push(buf_item ^ key_item); + } + out } #[cfg(all(feature = "nightly", test))] @@ -79,16 +82,16 @@ mod tests { let buffer = b"The quick brown fox jumps over the lazy dog"; let key = gen_mask(); b.iter(|| { - let mut output = mask_data(key, buffer); - test::black_box(&mut output); - }); + let mut output = mask_data(key, buffer); + test::black_box(&mut output); + }); } #[bench] fn bench_gen_mask(b: &mut test::Bencher) { b.iter(|| { - let mut key = gen_mask(); - test::black_box(&mut key); - }); + let mut key = gen_mask(); + test::black_box(&mut key); + }); } } diff --git a/src/ws/util/mod.rs b/src/ws/util/mod.rs index e5878006b8..c59e04db63 100644 --- a/src/ws/util/mod.rs +++ b/src/ws/util/mod.rs @@ -2,7 +2,6 @@ pub mod header; pub mod mask; -pub mod url; use std::str::from_utf8; use std::str::Utf8Error; diff --git a/src/ws/util/url.rs b/src/ws/util/url.rs deleted file mode 100644 index f6862c12b1..0000000000 --- a/src/ws/util/url.rs +++ /dev/null @@ -1,343 +0,0 @@ -//! Utility functions for dealing with URLs - -use url::{Url, Position}; -use url::Host as UrlHost; -use hyper::header::Host; -use result::{WebSocketResult, WSUrlErrorKind}; - -/// Trait that gets required WebSocket URL components -pub trait ToWebSocketUrlComponents { - /// Retrieve the required WebSocket URL components from this - fn to_components(&self) -> WebSocketResult<(Host, String, bool)>; -} - -impl ToWebSocketUrlComponents for str { - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - parse_url_str(&self) - } -} - -impl ToWebSocketUrlComponents for Url { - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - parse_url(&self) - } -} - -impl ToWebSocketUrlComponents for (Host, String, bool) { - /// Convert a Host, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - let (mut host, mut resource_name, secure) = self.clone(); - host.port = Some(match host.port { - Some(port) => port, - None => if secure { 443 } else { 80 }, - }); - if resource_name.is_empty() { - resource_name = "/".to_owned(); - } - Ok((host, resource_name, secure)) - } -} - -impl<'a> ToWebSocketUrlComponents for (Host, &'a str, bool) { - /// Convert a Host, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (self.0.clone(), self.1.to_owned(), self.2).to_components() - } -} - -impl<'a> ToWebSocketUrlComponents for (Host, &'a str) { - /// Convert a Host and resource name to WebSocket URL components, assuming an insecure connection. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (self.0.clone(), self.1.to_owned(), false).to_components() - } -} - -impl ToWebSocketUrlComponents for (Host, String) { - /// Convert a Host and resource name to WebSocket URL components, assuming an insecure connection. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (self.0.clone(), self.1.clone(), false).to_components() - } -} - -impl ToWebSocketUrlComponents for (UrlHost, u16, String, bool) { - /// Convert a Host, port, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (Host { - hostname: self.0.to_string(), - port: Some(self.1) - }, self.2.clone(), self.3).to_components() - } -} - -impl<'a> ToWebSocketUrlComponents for (UrlHost, u16, &'a str, bool) { - /// Convert a Host, port, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (Host { - hostname: self.0.to_string(), - port: Some(self.1) - }, self.2, self.3).to_components() - } -} - -impl<'a, T: ToWebSocketUrlComponents> ToWebSocketUrlComponents for &'a T { - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (**self).to_components() - } -} - -/// Gets the host, port, resource and secure from the string representation of a url -pub fn parse_url_str(url_str: &str) -> WebSocketResult<(Host, String, bool)> { - // https://html.spec.whatwg.org/multipage/#parse-a-websocket-url's-components - // Steps 1 and 2 - let parsed_url = try!(Url::parse(url_str)); - parse_url(&parsed_url) -} - -/// Gets the host, port, resource, and secure flag from a url -pub fn parse_url(url: &Url) -> WebSocketResult<(Host, String, bool)> { - // https://html.spec.whatwg.org/multipage/#parse-a-websocket-url's-components - - // Step 4 - if url.fragment().is_some() { - return Err(From::from(WSUrlErrorKind::CannotSetFragment)); - } - - let secure = match url.scheme() { - // step 5 - "ws" => false, - "wss" => true, - // step 3 - _ => return Err(From::from(WSUrlErrorKind::InvalidScheme)), - }; - - let host = url.host_str().unwrap().to_owned(); // Step 6 - let port = url.port_or_known_default(); // Steps 7 and 8 - - // steps 9, 10, 11 - let resource = url[Position::BeforePath..Position::AfterQuery].to_owned(); - - // Step 12 - Ok((Host { hostname: host, port: port }, resource, secure)) -} - -#[cfg(all(feature = "nightly", test))] -mod tests { - use super::*; - //use test; - use url::Url; - use result::{WebSocketError, WSUrlErrorKind}; - - fn url_for_test() -> Url { - Url::parse("ws://www.example.com:8080/some/path?a=b&c=d").unwrap() - } - - #[test] - fn test_parse_url_fragments_not_accepted() { - let url = &mut url_for_test(); - url.set_fragment(Some("non_null_fragment")); - - let result = parse_url(url); - match result { - Err(WebSocketError::WebSocketUrlError( - WSUrlErrorKind::CannotSetFragment)) => (), - Err(e) => panic!("Expected WSUrlErrorKind::CannotSetFragment but got {}", e), - Ok(_) => panic!("Expected WSUrlErrorKind::CannotSetFragment but got Ok") - } - } - - #[test] - fn test_parse_url_invalid_schemes_return_error() { - let url = &mut url_for_test(); - - let invalid_schemes = &["http", "https", "gopher", "file", "ftp", "other"]; - for scheme in invalid_schemes { - url.set_scheme(scheme).unwrap(); - - let result = parse_url(url); - match result { - Err(WebSocketError::WebSocketUrlError( - WSUrlErrorKind::InvalidScheme)) => (), - Err(e) => panic!("Expected WSUrlErrorKind::InvalidScheme but got {}", e), - Ok(_) => panic!("Expected WSUrlErrorKind::InvalidScheme but got Ok") - } - } - } - - #[test] - fn test_parse_url_valid_schemes_return_ok() { - let url = &mut url_for_test(); - - let valid_schemes = &["ws", "wss"]; - for scheme in valid_schemes { - url.set_scheme(scheme).unwrap(); - - let result = parse_url(url); - match result { - Ok(_) => (), - Err(e) => panic!("Expected Ok, but got {}", e) - } - } - } - - #[test] - fn test_parse_url_ws_returns_unset_secure_flag() { - let url = &mut url_for_test(); - url.set_scheme("ws").unwrap(); - - let result = parse_url(url); - let secure = match result { - Ok((_, _, secure)) => secure, - Err(e) => panic!(e), - }; - assert!(!secure); - } - - #[test] - fn test_parse_url_wss_returns_set_secure_flag() { - let url = &mut url_for_test(); - url.set_scheme("wss").unwrap(); - - let result = parse_url(url); - let secure = match result { - Ok((_, _, secure)) => secure, - Err(e) => panic!(e), - }; - assert!(secure); - } - - #[test] - fn test_parse_url_generates_proper_output() { - let url = &url_for_test(); - - let result = parse_url(url); - let (host, resource) = match result { - Ok((host, resource, _)) => (host, resource), - Err(e) => panic!(e), - }; - - assert_eq!(host.hostname, "www.example.com".to_owned()); - assert_eq!(resource, "/some/path?a=b&c=d".to_owned()); - - match host.port { - Some(port) => assert_eq!(port, 8080), - _ => panic!("Port should not be None"), - } - } - - #[test] - fn test_parse_url_empty_path_should_give_slash() { - let url = &mut url_for_test(); - url.set_path("/"); - - let result = parse_url(url); - let resource = match result { - Ok((_, resource, _)) => resource, - Err(e) => panic!(e), - }; - - assert_eq!(resource, "/?a=b&c=d".to_owned()); - } - - #[test] - fn test_parse_url_none_query_should_not_append_question_mark() { - let url = &mut url_for_test(); - url.set_query(None); - - let result = parse_url(url); - let resource = match result { - Ok((_, resource, _)) => resource, - Err(e) => panic!(e), - }; - - assert_eq!(resource, "/some/path".to_owned()); - } - - #[test] - fn test_parse_url_none_port_should_use_default_port() { - let url = &mut url_for_test(); - url.set_port(None).unwrap(); - - let result = parse_url(url); - let host = match result { - Ok((host, _, _)) => host, - Err(e) => panic!(e), - }; - - match host.port { - Some(80) => (), - Some(p) => panic!("Expected port to be 80 but got {}", p), - None => panic!("Expected port to be 80 but got `None`"), - } - } - - #[test] - fn test_parse_url_str_valid_url1() { - let url_str = "ws://www.example.com/some/path?a=b&c=d"; - let result = parse_url_str(url_str); - let (host, resource, secure) = match result { - Ok((host, resource, secure)) => (host, resource, secure), - Err(e) => panic!(e), - }; - - match host.port { - Some(80) => (), - Some(p) => panic!("Expected port 80 but got {}", p), - None => panic!("Expected port 80 but got `None`") - } - assert_eq!(host.hostname, "www.example.com".to_owned()); - assert_eq!(resource, "/some/path?a=b&c=d".to_owned()); - assert!(!secure); - } - - #[test] - fn test_parse_url_str_valid_url2() { - let url_str = "wss://www.example.com"; - let result = parse_url_str(url_str); - let (host, resource, secure) = match result { - Ok((host, resource, secure)) => (host, resource, secure), - Err(e) => panic!(e) - }; - - match host.port { - Some(443) => (), - Some(p) => panic!("Expected port 443 but got {}", p), - None => panic!("Expected port 443 but got `None`") - } - assert_eq!(host.hostname, "www.example.com".to_owned()); - assert_eq!(resource, "/".to_owned()); - assert!(secure); - } - - #[test] - fn test_parse_url_str_invalid_relative_url() { - let url_str = "/some/relative/path?a=b&c=d"; - let result = parse_url_str(url_str); - match result { - Err(WebSocketError::UrlError(_)) => (), - Err(e) => panic!("Expected UrlError, but got unexpected error {}", e), - Ok(_) => panic!("Expected UrlError, but got Ok"), - } - } - - #[test] - fn test_parse_url_str_invalid_url_scheme() { - let url_str = "http://www.example.com/some/path?a=b&c=d"; - let result = parse_url_str(url_str); - match result { - Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::InvalidScheme)) => (), - Err(e) => panic!("Expected InvalidScheme, but got unexpected error {}", e), - Ok(_) => panic!("Expected InvalidScheme, but got Ok"), - } - } - - #[test] - fn test_parse_url_str_invalid_url_fragment() { - let url_str = "http://www.example.com/some/path#some-id"; - let result = parse_url_str(url_str); - match result { - Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::CannotSetFragment)) => (), - Err(e) => panic!("Expected CannotSetFragment, but got unexpected error {}", e), - Ok(_) => panic!("Expected CannotSetFragment, but got Ok"), - } - } -}