From 775f006d7620cc84301388da071513eb468cceef Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sat, 2 Jul 2016 15:03:18 -0400 Subject: [PATCH 01/32] adds a common stream interface that is simpler than WebsocketStream --- src/stream.rs | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/stream.rs b/src/stream.rs index 623b280bbd..1c22d80472 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -7,6 +7,70 @@ use openssl::ssl::SslStream; pub use std::net::{SocketAddr, Shutdown, TcpStream}; +/// Represents a stream that can be read from, written to, and split into two. +/// This is an abstraction around readable and writable things to be able +/// to speak websockets over ssl, tcp, unix sockets, etc. +pub trait Stream +where R: Read, + W: Write, +{ + /// Get a mutable borrow to the reading component of this stream + fn reader(&mut self) -> &mut R; + + /// Get a mutable borrow to the writing component of this stream + fn writer(&mut self) -> &mut W; + + /// Split this stream into readable and writable components. + /// The motivation behind this is to be able to read on one thread + /// and send messages on another. + fn split(self) -> Result<(R, W), io::Error>; +} + +impl Stream for (R, W) +where R: Read, + W: Write, +{ + fn reader(&mut self) -> &mut R { + &mut self.0 + } + + fn writer(&mut self) -> &mut W { + &mut self.1 + } + + fn split(self) -> Result<(R, W), io::Error> { + Ok(self) + } +} + +impl Stream for TcpStream { + fn reader(&mut self) -> &mut TcpStream { + self + } + + fn writer(&mut self) -> &mut TcpStream { + self + } + + fn split(self) -> Result<(TcpStream, TcpStream), io::Error> { + Ok((try!(self.try_clone()), self)) + } +} + +impl Stream, SslStream> for SslStream { + fn reader(&mut self) -> &mut SslStream { + self + } + + fn writer(&mut self) -> &mut SslStream { + self + } + + fn split(self) -> Result<(SslStream, SslStream), io::Error> { + Ok((try!(self.try_clone()), self)) + } +} + /// A useful stream type for carrying WebSocket connections. pub enum WebSocketStream { /// A TCP stream. From a6df89faf46aac64832c66ebc945f700d90f2f79 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sat, 2 Jul 2016 21:01:34 -0400 Subject: [PATCH 02/32] integrated new Stream trait and IntoWs struct --- src/client/mod.rs | 64 ++++++++++++----- src/client/request.rs | 5 +- src/lib.rs | 3 +- src/receiver.rs | 39 +++++------ src/sender.rs | 33 ++++----- src/server/mod.rs | 2 - src/server/upgrade.rs | 65 ++++++++++++++++++ src/stream.rs | 155 ++++++++++++++---------------------------- 8 files changed, 200 insertions(+), 166 deletions(-) create mode 100644 src/server/upgrade.rs diff --git a/src/client/mod.rs b/src/client/mod.rs index b62302d0bf..906359e9c8 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,7 +8,10 @@ use ws; use ws::util::url::ToWebSocketUrlComponents; use ws::receiver::{DataFrameIterator, MessageIterator}; use result::WebSocketResult; -use stream::WebSocketStream; +use stream::{ + AsTcpStream, + Stream, +}; use dataframe::DataFrame; use ws::dataframe::DataFrame as DataFrameable; @@ -62,15 +65,21 @@ pub struct Client { _dataframe: PhantomData } -impl Client, Receiver> { +impl Client, Receiver> +where S: AsTcpStream + Stream, +{ /// Connects to the given ws:// or wss:// URL and return a Request to be sent. /// /// A connection is established, however the request is not sent to /// the server until a call to ```send()```. - pub fn connect(components: T) -> WebSocketResult> { + pub fn connect(components: C) -> WebSocketResult> + where C: ToWebSocketUrlComponents, + { + // TODO: Do not create a default SSL Context every time let context = try!(SslContext::new(SslMethod::Tlsv1)); Client::connect_ssl_context(components, &context) } + /// Connects to the specified wss:// URL using the given SSL context. /// /// If a ws:// URL is supplied, a normal, non-secure connection is established @@ -78,39 +87,49 @@ impl Client, Receiver> { /// /// A connection is established, however the request is not sent to /// the server until a call to ```send()```. - pub fn connect_ssl_context(components: T, context: &SslContext) -> WebSocketResult> { + pub fn connect_ssl_context(components: C, context: &SslContext) -> WebSocketResult> + where C: ToWebSocketUrlComponents, + { let (host, resource_name, secure) = try!(components.to_components()); - let connection = try!(TcpStream::connect( - (&host.hostname[..], host.port.unwrap_or(if secure { 443 } else { 80 })) - )); + let port = match host.port { + Some(p) => p, + None => if secure { + 443 + } else { + 80 + }, + }; + + let hostname = &host.hostname[..]; + + let connection = try!(TcpStream::connect((hostname, port))); + + let components = (host, resource_name, secure); - let stream = if secure { + if secure { let sslstream = try!(SslStream::connect(context, connection)); - WebSocketStream::Ssl(sslstream) - } - else { - WebSocketStream::Tcp(connection) + Request::new(components, try!(sslstream.split())) + } else { + Request::new(components, try!(connection.split())) }; - - Request::new((host, resource_name, secure), try!(stream.try_clone()), stream) } /// Shuts down the sending half of the client connection, will cause all pending /// and future IO to return immediately with an appropriate value. - pub fn shutdown_sender(&mut self) -> IoResult<()> { + pub fn shutdown_sender(&self) -> IoResult<()> { self.sender.shutdown() } /// Shuts down the receiving half of the client connection, will cause all pending /// and future IO to return immediately with an appropriate value. - pub fn shutdown_receiver(&mut self) -> IoResult<()> { + pub fn shutdown_receiver(&self) -> IoResult<()> { self.receiver.shutdown() } /// Shuts down the client connection, will cause all pending and future IO to /// return immediately with an appropriate value. - pub fn shutdown(&mut self) -> IoResult<()> { + pub fn shutdown(&self) -> IoResult<()> { self.receiver.shutdown_all() } } @@ -126,29 +145,35 @@ impl> Client { _dataframe: PhantomData } } + /// Sends a single data frame to the remote endpoint. pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> where D: DataFrameable { self.sender.send_dataframe(dataframe) } + /// Sends a single message to the remote endpoint. pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> where M: ws::Message<'m, D>, D: DataFrameable { self.sender.send_message(message) } + /// Reads a single data frame from the remote endpoint. pub fn recv_dataframe(&mut self) -> WebSocketResult { self.receiver.recv_dataframe() } + /// Returns an iterator over incoming data frames. pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, R, F> { self.receiver.incoming_dataframes() } + /// Reads a single message from this receiver. pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult where M: ws::Message<'m, F, DataFrameIterator = I>, I: Iterator { self.receiver.recv_message() } + /// Returns an iterator over incoming messages. /// ///```no_run @@ -199,22 +224,27 @@ impl> Client { { self.receiver.incoming_messages() } + /// Returns a reference to the underlying Sender. pub fn get_sender(&self) -> &S { &self.sender } + /// Returns a reference to the underlying Receiver. pub fn get_receiver(&self) -> &R { &self.receiver } + /// Returns a mutable reference to the underlying Sender. pub fn get_mut_sender(&mut self) -> &mut S { &mut self.sender } + /// Returns a mutable reference to the underlying Receiver. pub fn get_mut_receiver(&mut self) -> &mut R { &mut self.receiver } + /// Split this client into its constituent Sender and Receiver pair. /// /// This allows the Sender and Receiver to be sent to different threads. diff --git a/src/client/request.rs b/src/client/request.rs index 0e470b761a..43f502cc67 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -38,7 +38,10 @@ impl Request { /// In general `Client::connect()` should be used for connecting to servers. /// However, if the request is to be written to a different Writer, this function /// may be used. - pub fn new(components: T, reader: R, writer: W) -> WebSocketResult> { + pub fn new(components: T, stream: (R, W)) -> WebSocketResult> + where T: ToWebSocketUrlComponents, + { + let (reader, writer) = stream; let mut headers = Headers::new(); let (host, resource_name, _) = try!(components.to_components()); headers.set(host); diff --git a/src/lib.rs b/src/lib.rs index e90c72ee95..1e38a1ace3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,7 +54,8 @@ pub use self::client::Client; pub use self::server::Server; pub use self::dataframe::DataFrame; pub use self::message::Message; -pub use self::stream::WebSocketStream; +pub use self::stream::Stream; +pub use self::stream::AsTcpStream; pub use self::ws::Sender; pub use self::ws::Receiver; diff --git a/src/receiver.rs b/src/receiver.rs index 03cfab6197..a7a845ccc2 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -6,9 +6,9 @@ use hyper::buffer::BufReader; use dataframe::{DataFrame, Opcode}; use result::{WebSocketResult, WebSocketError}; -use stream::WebSocketStream; -use stream::Shutdown; use ws; +use stream::AsTcpStream; +pub use stream::Shutdown; /// A Receiver that wraps a Reader and provides a default implementation using /// DataFrames and Messages. @@ -19,7 +19,8 @@ pub struct Receiver { } impl Receiver -where R: Read { +where R: Read, +{ /// Create a new Receiver using the specified Reader. pub fn new(reader: BufReader, mask: bool) -> Receiver { Receiver { @@ -38,26 +39,20 @@ where R: Read { } } -impl Receiver { - /// Closes the receiver side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&mut self) -> IoResult<()> { - self.inner.get_mut().shutdown(Shutdown::Read) - } - - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&mut self) -> IoResult<()> { - self.inner.get_mut().shutdown(Shutdown::Both) - } +impl Receiver +where S: AsTcpStream, +{ + /// Closes the receiver side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.inner.get_ref().as_tcp().shutdown(Shutdown::Read) + } - /// Changes whether the receiver is in nonblocking mode. - /// - /// If it is in nonblocking mode and there is no incoming message, trying to receive a message - /// will return an error instead of blocking. - pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.inner.get_ref().set_nonblocking(nonblocking) - } + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.inner.get_ref().as_tcp().shutdown(Shutdown::Both) + } } impl ws::Receiver for Receiver { diff --git a/src/sender.rs b/src/sender.rs index a4390f4741..0dda17e950 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -4,9 +4,9 @@ use std::io::Write; use std::io::Result as IoResult; use result::WebSocketResult; use ws::dataframe::DataFrame; -use stream::WebSocketStream; -use stream::Shutdown; +use stream::AsTcpStream; use ws; +pub use stream::Shutdown; /// A Sender that wraps a Writer and provides a default implementation using /// DataFrames and Messages. @@ -33,23 +33,20 @@ impl Sender { } } -impl Sender { - /// Closes the sender side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&mut self) -> IoResult<()> { - self.inner.shutdown(Shutdown::Write) - } - - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&mut self) -> IoResult<()> { - self.inner.shutdown(Shutdown::Both) - } +impl Sender +where S: AsTcpStream +{ + /// Closes the sender side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.inner.as_tcp().shutdown(Shutdown::Write) + } - /// Changes whether the sender is in nonblocking mode. - pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.inner.set_nonblocking(nonblocking) - } + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.inner.as_tcp().shutdown(Shutdown::Both) + } } impl ws::Sender for Sender { diff --git a/src/server/mod.rs b/src/server/mod.rs index 8236b1c089..541ce53a3b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -6,8 +6,6 @@ use std::io; pub use self::request::Request; pub use self::response::Response; -use stream::WebSocketStream; - use openssl::ssl::SslContext; use openssl::ssl::SslStream; diff --git a/src/server/upgrade.rs b/src/server/upgrade.rs new file mode 100644 index 0000000000..d4f24a2b6f --- /dev/null +++ b/src/server/upgrade.rs @@ -0,0 +1,65 @@ +//! Allows you to take an existing request or stream of data and convert it into a +//! WebSocket client. +extern crate hyper; +extern crate openssl; + +use super::super::stream::Stream; + +/// Any error that could occur when attempting +/// to parse data into a websocket upgrade request +pub enum IntoWsError { + /// If the request was not actually asking for a websocket connection + RequestIsNotUpgrade, +} + +/// Intermediate representation of a half created websocket session. +/// Should be used to examine the client's handshake +/// accept the protocols requested, route the path, etc. +/// +/// Users should then call `accept` or `deny` to complete the handshake +/// and start a session. +pub struct WsUpgrade +where S: Stream, +{ + stream: S, +} + +impl WsUpgrade +where S: Stream, +{ + fn from_stream(inner: S) -> Self { + WsUpgrade { + stream: inner, + } + } + + fn unwrap(self) -> S { + self.stream + } + + fn accept(self) { + unimplemented!(); + } + + fn reject(self) -> S { + unimplemented!(); + } +} + +/// Trait to take a stream or similar and attempt to recover the start of a +/// websocket handshake from it. +/// Should be used when a stream might contain a request for a websocket session. +/// +/// If an upgrade request can be parsed, one can accept or deny the handshake with +/// the `WsUpgrade` struct. +/// Otherwise the original stream is returned along with an error. +/// +/// Note: the stream is owned because the websocket client expects to own its stream. +pub trait IntoWs +{ + /// Attempt to parse the start of a Websocket handshake, later with the returned + /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to + /// send a handshake rejection response. + fn into_ws(self) -> Result, (O, IntoWsError)>; +} + diff --git a/src/stream.rs b/src/stream.rs index 1c22d80472..02dfa0ff5a 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,49 +1,61 @@ //! Provides the default stream type for WebSocket connections. -extern crate net2; - -use std::io::{self, Read, Write}; -use self::net2::TcpStreamExt; -use openssl::ssl::SslStream; - -pub use std::net::{SocketAddr, Shutdown, TcpStream}; +use std::io::{ + self, + Read, + Write +}; +pub use std::net::{ + TcpStream, + Shutdown, +}; +pub use openssl::ssl::SslStream; /// Represents a stream that can be read from, written to, and split into two. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. -pub trait Stream -where R: Read, - W: Write, +pub trait Stream { - /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut R; + /// The reading component of the stream + type R: Read; + /// The writing component of the stream + type W: Write; - /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut W; + /// Get a mutable borrow to the reading component of this stream + fn reader(&mut self) -> &mut Self::R; - /// Split this stream into readable and writable components. - /// The motivation behind this is to be able to read on one thread - /// and send messages on another. - fn split(self) -> Result<(R, W), io::Error>; + /// Get a mutable borrow to the writing component of this stream + fn writer(&mut self) -> &mut Self::W; + + /// Split this stream into readable and writable components. + /// The motivation behind this is to be able to read on one thread + /// and send messages on another. + fn split(self) -> io::Result<(Self::R, Self::W)>; } -impl Stream for (R, W) +impl Stream for (R, W) where R: Read, - W: Write, + W: Write, { - fn reader(&mut self) -> &mut R { + type R = R; + type W = W; + + fn reader(&mut self) -> &mut Self::R { &mut self.0 } - fn writer(&mut self) -> &mut W { + fn writer(&mut self) -> &mut Self::W { &mut self.1 } - fn split(self) -> Result<(R, W), io::Error> { + fn split(self) -> io::Result<(Self::R, Self::W)> { Ok(self) } } -impl Stream for TcpStream { +impl Stream for TcpStream { + type R = TcpStream; + type W = TcpStream; + fn reader(&mut self) -> &mut TcpStream { self } @@ -52,12 +64,15 @@ impl Stream for TcpStream { self } - fn split(self) -> Result<(TcpStream, TcpStream), io::Error> { + fn split(self) -> io::Result<(TcpStream, TcpStream)> { Ok((try!(self.try_clone()), self)) } } -impl Stream, SslStream> for SslStream { +impl Stream for SslStream { + type R = SslStream; + type W = SslStream; + fn reader(&mut self) -> &mut SslStream { self } @@ -66,93 +81,23 @@ impl Stream, SslStream> for SslStream self } - fn split(self) -> Result<(SslStream, SslStream), io::Error> { + fn split(self) -> io::Result<(SslStream, SslStream)> { Ok((try!(self.try_clone()), self)) } } -/// A useful stream type for carrying WebSocket connections. -pub enum WebSocketStream { - /// A TCP stream. - Tcp(TcpStream), - /// An SSL-backed TCP Stream - Ssl(SslStream) -} - -impl Read for WebSocketStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.read(buf), - WebSocketStream::Ssl(ref mut inner) => inner.read(buf), - } - } +pub trait AsTcpStream: Read + Write { + fn as_tcp(&self) -> &TcpStream; } -impl Write for WebSocketStream { - fn write(&mut self, msg: &[u8]) -> io::Result { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.write(msg), - WebSocketStream::Ssl(ref mut inner) => inner.write(msg), - } - } - - fn flush(&mut self) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.flush(), - WebSocketStream::Ssl(ref mut inner) => inner.flush(), - } +impl AsTcpStream for TcpStream { + fn as_tcp(&self) -> &TcpStream { + self } } -impl WebSocketStream { - /// See `TcpStream.peer_addr()`. - pub fn peer_addr(&self) -> io::Result { - match *self { - WebSocketStream::Tcp(ref inner) => inner.peer_addr(), - WebSocketStream::Ssl(ref inner) => inner.get_ref().peer_addr(), - } - } - /// See `TcpStream.local_addr()`. - pub fn local_addr(&self) -> io::Result { - match *self { - WebSocketStream::Tcp(ref inner) => inner.local_addr(), - WebSocketStream::Ssl(ref inner) => inner.get_ref().local_addr(), - } - } - /// See `TcpStream.set_nodelay()`. - pub fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => TcpStreamExt::set_nodelay(inner, nodelay), - WebSocketStream::Ssl(ref mut inner) => TcpStreamExt::set_nodelay(inner.get_mut(), nodelay), - } +impl AsTcpStream for SslStream { + fn as_tcp(&self) -> &TcpStream { + self.get_ref() } - /// See `TcpStream.set_keepalive()`. - pub fn set_keepalive(&mut self, delay_in_ms: Option) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => TcpStreamExt::set_keepalive_ms(inner, delay_in_ms), - WebSocketStream::Ssl(ref mut inner) => TcpStreamExt::set_keepalive_ms(inner.get_mut(), delay_in_ms), - } - } - /// See `TcpStream.shutdown()`. - pub fn shutdown(&mut self, shutdown: Shutdown) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref mut inner) => inner.shutdown(shutdown), - WebSocketStream::Ssl(ref mut inner) => inner.get_mut().shutdown(shutdown), - } - } - /// See `TcpStream.try_clone()`. - pub fn try_clone(&self) -> io::Result { - Ok(match *self { - WebSocketStream::Tcp(ref inner) => WebSocketStream::Tcp(try!(inner.try_clone())), - WebSocketStream::Ssl(ref inner) => WebSocketStream::Ssl(try!(inner.try_clone())), - }) - } - - /// Changes whether the stream is in nonblocking mode. - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - match *self { - WebSocketStream::Tcp(ref inner) => inner.set_nonblocking(nonblocking), - WebSocketStream::Ssl(ref inner) => inner.get_ref().set_nonblocking(nonblocking), - } - } } From 14e159c2e8eab7e8e87265759df079204af56c63 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sun, 3 Jul 2016 13:12:30 -0400 Subject: [PATCH 03/32] Fixed client creation with new stream trait Changes the way clients are created slightly: - (slightly, ostensibly) better performance when one knows if they want a secure connection or not. This means no match statement redirecting calls to their respective SSL or TCP stream. These calls have been named: - `connect` for insecure connections - `connect_secure` for secure connections - If one does not know how which connection it wants it can take the (not verified by testing) performance hit of putting the underlying stream on the heap behind a `Box`. This can be done with: - `connect_agnostic` All calls currently take the custom ToWebSocketUrlComponents but that should be changed to a `Url` (probably from rust-url). --- src/client/mod.rs | 101 ++++++++++++++++++++++++++++++---------------- src/sender.rs | 2 +- src/stream.rs | 72 ++++++++++++++++++++++++--------- 3 files changed, 120 insertions(+), 55 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 906359e9c8..5143532755 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,6 +3,7 @@ use std::net::TcpStream; use std::marker::PhantomData; use std::io::Result as IoResult; +use std::ops::Deref; use ws; use ws::util::url::ToWebSocketUrlComponents; @@ -10,6 +11,7 @@ use ws::receiver::{DataFrameIterator, MessageIterator}; use result::WebSocketResult; use stream::{ AsTcpStream, + TryUnsizedClone, Stream, }; use dataframe::DataFrame; @@ -65,33 +67,54 @@ pub struct Client { _dataframe: PhantomData } -impl Client, Receiver> -where S: AsTcpStream + Stream, -{ - /// Connects to the given ws:// or wss:// URL and return a Request to be sent. +impl Client, Receiver> { + /// Connects to the given ws:// URL and return a Request to be sent. + /// + /// If you would like to use a secure connection (wss://), please use `connect_secure`. /// /// A connection is established, however the request is not sent to /// the server until a call to ```send()```. - pub fn connect(components: C) -> WebSocketResult> + pub fn connect(components: C) -> WebSocketResult> where C: ToWebSocketUrlComponents, { - // TODO: Do not create a default SSL Context every time - let context = try!(SslContext::new(SslMethod::Tlsv1)); - Client::connect_ssl_context(components, &context) + let (host, resource_name, secure) = try!(components.to_components()); + let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(80))); + let stream = try!(stream); + Request::new((host, resource_name, secure), try!(stream.split())) } +} +impl Client>, Receiver>> { /// Connects to the specified wss:// URL using the given SSL context. /// - /// If a ws:// URL is supplied, a normal, non-secure connection is established - /// and the context parameter is ignored. + /// If you would like to use an insecure connection (ws://), please use `connect`. /// /// A connection is established, however the request is not sent to /// the server until a call to ```send()```. - pub fn connect_ssl_context(components: C, context: &SslContext) -> WebSocketResult> + pub fn connect_secure(components: C, context: Option<&SslContext>) -> WebSocketResult, SslStream>> where C: ToWebSocketUrlComponents, { let (host, resource_name, secure) = try!(components.to_components()); + let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(443))); + let stream = try!(stream); + let sslstream = if let Some(c) = context { + SslStream::connect(c, stream) + } else { + let context = try!(SslContext::new(SslMethod::Tlsv1)); + SslStream::connect(&context, stream) + }; + let sslstream = try!(sslstream); + + Request::new((host, resource_name, secure), try!(sslstream.split())) + } +} + +impl Client>, Receiver>> { + pub fn connect_agnostic(components: C, ssl_context: Option<&SslContext>) -> WebSocketResult, Box>> + where C: ToWebSocketUrlComponents + { + let (host, resource_name, secure) = try!(components.to_components()); let port = match host.port { Some(p) => p, None => if secure { @@ -100,38 +123,46 @@ where S: AsTcpStream + Stream, 80 }, }; - let hostname = &host.hostname[..]; + let tcp_stream = try!(TcpStream::connect((hostname, port))); - let connection = try!(TcpStream::connect((hostname, port))); - - let components = (host, resource_name, secure); - - if secure { - let sslstream = try!(SslStream::connect(context, connection)); - Request::new(components, try!(sslstream.split())) + let stream: Box = if secure { + if let Some(c) = ssl_context { + Box::new(try!(SslStream::connect(c, tcp_stream))) + } else { + let context = try!(SslContext::new(SslMethod::Tlsv1)); + Box::new(try!(SslStream::connect(&context, tcp_stream))) + } } else { - Request::new(components, try!(connection.split())) + Box::new(tcp_stream) }; + + let (read, write) = (try!(stream.try_clone()), stream); + + Request::new((host, resource_name, secure), (read, write)) } +} - /// Shuts down the sending half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_sender(&self) -> IoResult<()> { - self.sender.shutdown() - } +impl Client, Receiver> +where S: AsTcpStream, +{ + /// Shuts down the sending half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_sender(&self) -> IoResult<()> { + self.sender.shutdown() + } - /// Shuts down the receiving half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_receiver(&self) -> IoResult<()> { - self.receiver.shutdown() - } + /// Shuts down the receiving half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_receiver(&self) -> IoResult<()> { + self.receiver.shutdown() + } - /// Shuts down the client connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.receiver.shutdown_all() - } + /// Shuts down the client connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.receiver.shutdown_all() + } } impl> Client { diff --git a/src/sender.rs b/src/sender.rs index 0dda17e950..30d6133507 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -34,7 +34,7 @@ impl Sender { } impl Sender -where S: AsTcpStream +where S: AsTcpStream, { /// Closes the sender side of the connection, will cause all pending and future IO to /// return immediately with an appropriate value. diff --git a/src/stream.rs b/src/stream.rs index 02dfa0ff5a..d50c327160 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -4,12 +4,45 @@ use std::io::{ Read, Write }; +use std::ops::Deref; pub use std::net::{ TcpStream, Shutdown, }; pub use openssl::ssl::SslStream; +pub trait AsTcpStream: Read + Write { + fn as_tcp(&self) -> &TcpStream; +} + +impl AsTcpStream for TcpStream { + fn as_tcp(&self) -> &TcpStream { + self + } +} + +impl AsTcpStream for SslStream { + fn as_tcp(&self) -> &TcpStream { + self.get_ref() + } +} + +impl AsTcpStream for Box { + fn as_tcp(&self) -> &TcpStream { + self.deref().as_tcp() + } +} + +pub trait TryUnsizedClone { + fn try_clone(&self) -> io::Result>; +} + +impl TryUnsizedClone for Box { + fn try_clone(&self) -> io::Result> { + unimplemented!(); + } +} + /// Represents a stream that can be read from, written to, and split into two. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. @@ -53,51 +86,52 @@ where R: Read, } impl Stream for TcpStream { - type R = TcpStream; - type W = TcpStream; + type R = Self; + type W = Self; - fn reader(&mut self) -> &mut TcpStream { + fn reader(&mut self) -> &mut Self::R { self } - fn writer(&mut self) -> &mut TcpStream { + fn writer(&mut self) -> &mut Self::W { self } - fn split(self) -> io::Result<(TcpStream, TcpStream)> { + fn split(self) -> io::Result<(Self::R, Self::W)> { Ok((try!(self.try_clone()), self)) } } impl Stream for SslStream { - type R = SslStream; - type W = SslStream; + type R = Self; + type W = Self; - fn reader(&mut self) -> &mut SslStream { + fn reader(&mut self) -> &mut Self::R { self } - fn writer(&mut self) -> &mut SslStream { + fn writer(&mut self) -> &mut Self::W { self } - fn split(self) -> io::Result<(SslStream, SslStream)> { + fn split(self) -> io::Result<(Self::R, Self::W)> { Ok((try!(self.try_clone()), self)) } } -pub trait AsTcpStream: Read + Write { - fn as_tcp(&self) -> &TcpStream; -} +impl Stream for Box { + type R = Self; + type W = Self; -impl AsTcpStream for TcpStream { - fn as_tcp(&self) -> &TcpStream { + fn reader(&mut self) -> &mut Self::R { self } -} -impl AsTcpStream for SslStream { - fn as_tcp(&self) -> &TcpStream { - self.get_ref() + fn writer(&mut self) -> &mut Self::W { + self + } + + fn split(self) -> io::Result<(Self::R, Self::W)> { + Ok((try!(self.try_clone()), self)) } } From 87fecda2e696db6c06e4e42fc5e91b132d4ec573 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sun, 3 Jul 2016 16:42:36 -0400 Subject: [PATCH 04/32] started work on having the server side code use websocket upgrade --- src/client/mod.rs | 3 +- src/server/mod.rs | 163 ++++++++++++++++++++++++++---------------- src/server/upgrade.rs | 26 +++++-- src/stream.rs | 13 +++- 4 files changed, 138 insertions(+), 67 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 5143532755..8018e2d0bf 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -123,8 +123,7 @@ impl Client>, Receiver>> { 80 }, }; - let hostname = &host.hostname[..]; - let tcp_stream = try!(TcpStream::connect((hostname, port))); + let tcp_stream = try!(TcpStream::connect((&host.hostname[..], port))); let stream: Box = if secure { if let Some(c) = ssl_context { diff --git a/src/server/mod.rs b/src/server/mod.rs index 541ce53a3b..ab494ee83a 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,14 +1,31 @@ //! Provides an implementation of a WebSocket server -use std::net::{SocketAddr, ToSocketAddrs, TcpListener}; -use std::net::Shutdown; -use std::io::{Read, Write}; -use std::io; +use std::net::{ + SocketAddr, + ToSocketAddrs, + TcpListener, + TcpStream, + Shutdown, +}; +use std::io::{ + self, + Read, + Write, +}; +use std::borrow::Cow; +use std::ops::Deref; +use openssl::ssl::{ + SslContext, + SslMethod, + SslStream, +}; +use stream::{ + MaybeSslContext, + NoSslContext, +}; + pub use self::request::Request; pub use self::response::Response; -use openssl::ssl::SslContext; -use openssl::ssl::SslStream; - pub mod request; pub mod response; @@ -73,56 +90,58 @@ pub mod response; ///} /// # } /// ``` -pub struct Server<'a> { +pub struct Server<'s, S> +where S: MaybeSslContext + 's, +{ inner: TcpListener, - context: Option<&'a SslContext>, + ssl_context: Cow<'s, S>, } -impl<'a> Server<'a> { - /// Bind this Server to this socket - pub fn bind(addr: T) -> io::Result> { - Ok(Server { - inner: try!(TcpListener::bind(&addr)), - context: None, - }) - } - /// Bind this Server to this socket, utilising the given SslContext - pub fn bind_secure(addr: T, context: &'a SslContext) -> io::Result> { - Ok(Server { - inner: try!(TcpListener::bind(&addr)), - context: Some(context), - }) - } +impl<'s, S> Server<'s, S> +where S: MaybeSslContext + 's, +{ /// Get the socket address of this server pub fn local_addr(&self) -> io::Result { self.inner.local_addr() } /// Create a new independently owned handle to the underlying socket. - pub fn try_clone(&self) -> io::Result> { + pub fn try_clone(&'s self) -> io::Result> { let inner = try!(self.inner.try_clone()); Ok(Server { inner: inner, - context: self.context + ssl_context: Cow::Borrowed(&*self.ssl_context), + }) + } + + pub fn into_owned<'o>(self) -> io::Result> { + Ok(Server { + inner: self.inner, + ssl_context: Cow::Owned(self.ssl_context.into_owned()), + }) + } +} + +impl<'s> Server<'s, SslContext> { + /// Bind this Server to this socket, utilising the given SslContext + pub fn bind_secure(addr: A, context: &'s SslContext) -> io::Result + where A: ToSocketAddrs, + { + Ok(Server { + inner: try!(TcpListener::bind(&addr)), + ssl_context: Cow::Borrowed(context), }) } /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> io::Result> { + pub fn accept(&mut self) -> io::Result> { let stream = try!(self.inner.accept()).0; - let wsstream = match self.context { - Some(context) => { - let sslstream = match SslStream::accept(context, stream) { - Ok(s) => s, - Err(err) => { - return Err(io::Error::new(io::ErrorKind::Other, err)); - } - }; - WebSocketStream::Ssl(sslstream) - } - None => { WebSocketStream::Tcp(stream) } - }; - Ok(Connection(try!(wsstream.try_clone()), try!(wsstream.try_clone()))) + match SslStream::accept(&*self.ssl_context, stream) { + Ok(s) => Ok(s), + Err(err) => { + Err(io::Error::new(io::ErrorKind::Other, err)) + }, + } } /// Changes whether the Server is in nonblocking mode. @@ -134,34 +153,58 @@ impl<'a> Server<'a> { } } -impl<'a> Iterator for Server<'a> { - type Item = io::Result>; +impl<'s> Iterator for Server<'s, SslContext> { + type Item = io::Result>; fn next(&mut self) -> Option<::Item> { Some(self.accept()) } } -/// Represents a connection to the server that has not been processed yet. -pub struct Connection(R, W); +impl<'s> Server<'s, NoSslContext> { + /// Bind this Server to this socket + pub fn bind(addr: A) -> io::Result { + Ok(Server { + inner: try!(TcpListener::bind(&addr)), + ssl_context: Cow::Owned(NoSslContext), + }) + } -impl Connection { - /// Process this connection and read the request. - pub fn read_request(self) -> io::Result> { - match Request::read(self.0, self.1) { - Ok(result) => { Ok(result) }, - Err(err) => { - Err(io::Error::new(io::ErrorKind::InvalidInput, err)) - } - } + /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest + pub fn accept(&mut self) -> io::Result { + Ok(try!(self.inner.accept()).0) } } -impl Connection { - /// Shuts down the currennt connection in the specified way. - /// All future IO calls to this connection will return immediately with an appropriate - /// return value. - pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { - self.0.shutdown(how) - } +impl<'s> Iterator for Server<'s, NoSslContext> { + type Item = io::Result; + + fn next(&mut self) -> Option<::Item> { + Some(self.accept()) + } } + +// /// Represents a connection to the server that has not been processed yet. +// pub struct Connection(R, W); + +// impl Connection { +// /// Process this connection and read the request. +// pub fn read_request(self) -> io::Result> { +// match Request::read(self.0, self.1) { +// Ok(result) => { Ok(result) }, +// Err(err) => { +// Err(io::Error::new(io::ErrorKind::InvalidInput, err)) +// } +// } +// } +// } + +// impl Connection { +// /// Shuts down the currennt connection in the specified way. +// /// All future IO calls to this connection will return immediately with an appropriate +// /// return value. +// pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { +// self.0.shutdown(how) +// } +// } + diff --git a/src/server/upgrade.rs b/src/server/upgrade.rs index d4f24a2b6f..7058f0eec5 100644 --- a/src/server/upgrade.rs +++ b/src/server/upgrade.rs @@ -46,6 +46,22 @@ where S: Stream, } } +impl IntoWs for S +where S: Stream, +{ + type Stream = S; + + fn into_ws(self) -> Result, (Request, IntoWsError)> { + unimplemented!(); + } +} + +impl IntoWs for Request { + fn into_ws(self) -> Result, (Request, IntoWsError)> { + unimplemented!(); + } +} + /// Trait to take a stream or similar and attempt to recover the start of a /// websocket handshake from it. /// Should be used when a stream might contain a request for a websocket session. @@ -57,9 +73,11 @@ where S: Stream, /// Note: the stream is owned because the websocket client expects to own its stream. pub trait IntoWs { - /// Attempt to parse the start of a Websocket handshake, later with the returned - /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to - /// send a handshake rejection response. - fn into_ws(self) -> Result, (O, IntoWsError)>; + type Stream: Stream; + + /// Attempt to parse the start of a Websocket handshake, later with the returned + /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to + /// send a handshake rejection response. + fn into_ws(self) -> Result, (O, IntoWsError)>; } diff --git a/src/stream.rs b/src/stream.rs index d50c327160..f003038386 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -9,7 +9,10 @@ pub use std::net::{ TcpStream, Shutdown, }; -pub use openssl::ssl::SslStream; +pub use openssl::ssl::{ + SslStream, + SslContext, +}; pub trait AsTcpStream: Read + Write { fn as_tcp(&self) -> &TcpStream; @@ -135,3 +138,11 @@ impl Stream for Box { Ok((try!(self.try_clone()), self)) } } + +#[derive(Clone)] +pub struct NoSslContext; + +pub trait MaybeSslContext: Clone {} + +impl MaybeSslContext for NoSslContext {} +impl MaybeSslContext for SslContext {} From 7a6a13733ee1089017e46c35b36f1c1012028c42 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Mon, 4 Jul 2016 14:45:05 -0400 Subject: [PATCH 05/32] server now uses the ws upgrade trait to create clients --- src/server/mod.rs | 109 +++++++++++++------- src/server/request.rs | 2 +- src/server/response.rs | 2 +- src/server/upgrade.rs | 229 +++++++++++++++++++++++++++++++++++------ 4 files changed, 270 insertions(+), 72 deletions(-) diff --git a/src/server/mod.rs b/src/server/mod.rs index ab494ee83a..075f29b878 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -13,21 +13,39 @@ use std::io::{ }; use std::borrow::Cow; use std::ops::Deref; +use std::convert::Into; use openssl::ssl::{ SslContext, SslMethod, SslStream, }; use stream::{ + Stream, MaybeSslContext, NoSslContext, }; - -pub use self::request::Request; -pub use self::response::Response; +use self::upgrade::{ + WsUpgrade, + IntoWs, +}; +pub use self::upgrade::hyper::{ + Request, + HyperIntoWsError, +}; pub mod request; pub mod response; +pub mod upgrade; + +pub struct InvalidConnection +where S: Stream, +{ + pub stream: Option, + pub parsed: Option, + pub error: HyperIntoWsError, +} + +pub type AcceptResult = Result, InvalidConnection>; /// Represents a WebSocket server which can work with either normal (non-secure) connections, or secure WebSocket connections. /// @@ -134,13 +152,32 @@ impl<'s> Server<'s, SslContext> { } /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> io::Result> { - let stream = try!(self.inner.accept()).0; - match SslStream::accept(&*self.ssl_context, stream) { - Ok(s) => Ok(s), - Err(err) => { - Err(io::Error::new(io::ErrorKind::Other, err)) - }, + pub fn accept(&mut self) -> AcceptResult> { + let stream = match self.inner.accept() { + Ok(s) => s.0, + Err(e) => return Err(InvalidConnection { + stream: None, + parsed: None, + error: e.into(), + }), + }; + + let stream = match SslStream::accept(&*self.ssl_context, stream) { + Ok(s) => s, + Err(err) => return Err(InvalidConnection { + stream: None, + parsed: None, + error: io::Error::new(io::ErrorKind::Other, err).into(), + }), + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, e)) => Err(InvalidConnection { + stream: Some(s), + parsed: r, + error: e.into(), + }), } } @@ -154,10 +191,10 @@ impl<'s> Server<'s, SslContext> { } impl<'s> Iterator for Server<'s, SslContext> { - type Item = io::Result>; + type Item = WsUpgrade>; fn next(&mut self) -> Option<::Item> { - Some(self.accept()) + self.accept().ok() } } @@ -171,40 +208,32 @@ impl<'s> Server<'s, NoSslContext> { } /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> io::Result { - Ok(try!(self.inner.accept()).0) + pub fn accept(&mut self) -> AcceptResult { + let stream = match self.inner.accept() { + Ok(s) => s.0, + Err(e) => return Err(InvalidConnection { + stream: None, + parsed: None, + error: e.into(), + }), + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, e)) => Err(InvalidConnection { + stream: Some(s), + parsed: r, + error: e.into(), + }), + } } } impl<'s> Iterator for Server<'s, NoSslContext> { - type Item = io::Result; + type Item = WsUpgrade; fn next(&mut self) -> Option<::Item> { - Some(self.accept()) + self.accept().ok() } } -// /// Represents a connection to the server that has not been processed yet. -// pub struct Connection(R, W); - -// impl Connection { -// /// Process this connection and read the request. -// pub fn read_request(self) -> io::Result> { -// match Request::read(self.0, self.1) { -// Ok(result) => { Ok(result) }, -// Err(err) => { -// Err(io::Error::new(io::ErrorKind::InvalidInput, err)) -// } -// } -// } -// } - -// impl Connection { -// /// Shuts down the currennt connection in the specified way. -// /// All future IO calls to this connection will return immediately with an appropriate -// /// return value. -// pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> { -// self.0.shutdown(how) -// } -// } - diff --git a/src/server/request.rs b/src/server/request.rs index 990647592b..b9e1ec1114 100644 --- a/src/server/request.rs +++ b/src/server/request.rs @@ -2,7 +2,7 @@ use std::io::{Read, Write}; -use server::Response; +use server::response::Response; use result::{WebSocketResult, WebSocketError}; use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; diff --git a/src/server/response.rs b/src/server/response.rs index 0bae5e4cd0..0d5d763ccb 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -13,7 +13,7 @@ use unicase::UniCase; use header::{WebSocketAccept, WebSocketProtocol, WebSocketExtensions}; use sender::Sender; use receiver::Receiver; -use server::Request; +use server::request::Request; use client::Client; use result::WebSocketResult; use dataframe::DataFrame; diff --git a/src/server/upgrade.rs b/src/server/upgrade.rs index 7058f0eec5..402b1fe3ca 100644 --- a/src/server/upgrade.rs +++ b/src/server/upgrade.rs @@ -1,16 +1,10 @@ //! Allows you to take an existing request or stream of data and convert it into a //! WebSocket client. -extern crate hyper; -extern crate openssl; - -use super::super::stream::Stream; - -/// Any error that could occur when attempting -/// to parse data into a websocket upgrade request -pub enum IntoWsError { - /// If the request was not actually asking for a websocket connection - RequestIsNotUpgrade, -} +use std::net::TcpStream; +use stream::{ + Stream, + AsTcpStream, +}; /// Intermediate representation of a half created websocket session. /// Should be used to examine the client's handshake @@ -33,31 +27,23 @@ where S: Stream, } } - fn unwrap(self) -> S { - self.stream - } - - fn accept(self) { + pub fn accept(self) { unimplemented!(); } - fn reject(self) -> S { + pub fn reject(self) -> S { unimplemented!(); } -} -impl IntoWs for S -where S: Stream, -{ - type Stream = S; - - fn into_ws(self) -> Result, (Request, IntoWsError)> { + pub fn into_stream(self) -> S { unimplemented!(); } } -impl IntoWs for Request { - fn into_ws(self) -> Result, (Request, IntoWsError)> { +impl WsUpgrade +where S: Stream + AsTcpStream, +{ + pub fn tcp_stream(&self) -> &TcpStream { unimplemented!(); } } @@ -71,13 +57,196 @@ impl IntoWs for Request { /// Otherwise the original stream is returned along with an error. /// /// Note: the stream is owned because the websocket client expects to own its stream. -pub trait IntoWs +pub trait IntoWs +where S: Stream, { - type Stream: Stream; - + type Error; /// Attempt to parse the start of a Websocket handshake, later with the returned /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to /// send a handshake rejection response. - fn into_ws(self) -> Result, (O, IntoWsError)>; + fn into_ws(mut self) -> Result, Self::Error>; +} + +pub mod hyper { + extern crate hyper; + + use std::convert::From; + use std::error::Error; + use std::io; + use hyper::http::h1::parse_request; + use header::{ + WebSocketKey, + WebSocketVersion, + }; + use std::fmt::{ + Formatter, + Display, + self, + }; + use stream::Stream; + use super::{ + IntoWs, + WsUpgrade, + }; + + pub use hyper::http::h1::Incoming; + pub use hyper::method::Method; + pub use hyper::version::HttpVersion; + pub use hyper::uri::RequestUri; + pub use hyper::buffer::BufReader; + pub use hyper::header::{ + Upgrade, + ProtocolName, + Connection, + ConnectionOption, + }; + + pub type Request = Incoming<(Method, RequestUri)>; + + #[derive(Debug)] + pub enum HyperIntoWsError { + MethodNotGet, + UnsupportedHttpVersion, + UnsupportedWebsocketVersion, + NoSecWsKeyHeader, + NoWsUpgradeHeader, + NoUpgradeHeader, + NoWsConnectionHeader, + NoConnectionHeader, + /// IO error from reading the underlying socket + Io(io::Error), + /// Error while parsing an incoming request + Parsing(hyper::error::Error), + } + + impl Display for HyperIntoWsError { + fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + fmt.write_str(self.description()) + } + } + + impl Error for HyperIntoWsError { + fn description(&self) -> &str { + use self::HyperIntoWsError::*; + match self { + &MethodNotGet => "Request method must be GET", + &UnsupportedHttpVersion => "Unsupported request HTTP version", + &UnsupportedWebsocketVersion => "Unsupported WebSocket version", + &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", + &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", + &NoUpgradeHeader => "Missing Upgrade WebSocket header", + &NoWsConnectionHeader => "Invalid Connection WebSocket header", + &NoConnectionHeader => "Missing Connection WebSocket header", + &Io(ref e) => e.description(), + &Parsing(ref e) => e.description(), + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + HyperIntoWsError::Io(ref e) => Some(e), + HyperIntoWsError::Parsing(ref e) => Some(e), + _ => None, + } + } + } + + impl From for HyperIntoWsError { + fn from(err: io::Error) -> Self { + HyperIntoWsError::Io(err) + } + } + + impl From for HyperIntoWsError { + fn from(err: hyper::error::Error) -> Self { + HyperIntoWsError::Parsing(err) + } + } + + impl IntoWs for S + where S: Stream, + { + type Error = (Self, Option, HyperIntoWsError); + + fn into_ws(mut self) -> Result, Self::Error> { + let request = { + let mut reader = BufReader::new(self.reader()); + parse_request(&mut reader) + }; + + let request = match request { + Ok(r) => r, + Err(e) => return Err((self, None, e.into())), + }; + + match validate(&request) { + Ok(_) => unimplemented!(), + Err(e) => Err((self, Some(request), e)), + } + } + } + + // TODO: Remove request and response from server + + // TODO + // impl IntoWs for Request { + // fn into_ws(self) -> Result, (Request, IntoWsError)> { + // unimplemented!(); + // } + // } + + + pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { + if request.subject.0 != Method::Get { + return Err(HyperIntoWsError::MethodNotGet); + } + + if request.version == HttpVersion::Http09 + || request.version == HttpVersion::Http10 + { + return Err(HyperIntoWsError::UnsupportedHttpVersion); + } + + if let Some(version) = request.headers.get::() { + if version != &WebSocketVersion::WebSocket13 { + return Err(HyperIntoWsError::UnsupportedWebsocketVersion); + } + } + + if request.headers.get::().is_none() { + return Err(HyperIntoWsError::NoSecWsKeyHeader); + } + + match request.headers.get() { + Some(&Upgrade(ref upgrade)) => { + if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { + return Err(HyperIntoWsError::NoWsUpgradeHeader) + } + }, + None => return Err(HyperIntoWsError::NoUpgradeHeader), + }; + + fn check_connection_header(headers: &Vec) -> bool { + for header in headers { + if let &ConnectionOption::ConnectionHeader(ref h) = header { + if h as &str == "upgrade" { + return true; + } + } + } + false + } + + match request.headers.get() { + Some(&Connection(ref connection)) => { + if !check_connection_header(connection) { + return Err(HyperIntoWsError::NoWsConnectionHeader); + } + }, + None => return Err(HyperIntoWsError::NoConnectionHeader), + }; + + Ok(()) + } } From f7e4d3b1227ce5442d5a8df55bb8d9c4d0a23b10 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Mon, 4 Jul 2016 17:05:03 -0400 Subject: [PATCH 06/32] IntoWs for Streams & Request/Stream pairs Both streams and request/stream pairs can be made into a WsUpgrade struct. also AsTcpStream streams have a method to access the inner tcp stream. Things left: - Finish WsUpgrade struct: add methods and ability to accept or reject connections. - Implement IntoWs for a hyper serverside request. --- src/server/upgrade.rs | 62 ++++++++++++++++++++++++++++++------------- src/stream.rs | 20 +++++++------- 2 files changed, 53 insertions(+), 29 deletions(-) diff --git a/src/server/upgrade.rs b/src/server/upgrade.rs index 402b1fe3ca..5c5d61e82e 100644 --- a/src/server/upgrade.rs +++ b/src/server/upgrade.rs @@ -16,17 +16,12 @@ pub struct WsUpgrade where S: Stream, { stream: S, + request: hyper::Request, } impl WsUpgrade where S: Stream, { - fn from_stream(inner: S) -> Self { - WsUpgrade { - stream: inner, - } - } - pub fn accept(self) { unimplemented!(); } @@ -44,7 +39,7 @@ impl WsUpgrade where S: Stream + AsTcpStream, { pub fn tcp_stream(&self) -> &TcpStream { - unimplemented!(); + self.stream.as_tcp() } } @@ -57,14 +52,13 @@ where S: Stream + AsTcpStream, /// Otherwise the original stream is returned along with an error. /// /// Note: the stream is owned because the websocket client expects to own its stream. -pub trait IntoWs -where S: Stream, -{ +pub trait IntoWs { + type Stream: Stream; type Error; /// Attempt to parse the start of a Websocket handshake, later with the returned /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to /// send a handshake rejection response. - fn into_ws(mut self) -> Result, Self::Error>; + fn into_ws(mut self) -> Result, Self::Error>; } pub mod hyper { @@ -72,8 +66,8 @@ pub mod hyper { use std::convert::From; use std::error::Error; - use std::io; use hyper::http::h1::parse_request; + use hyper::net::NetworkStream; use header::{ WebSocketKey, WebSocketVersion, @@ -88,12 +82,18 @@ pub mod hyper { IntoWs, WsUpgrade, }; + use std::io::{ + Read, + Write, + self, + }; pub use hyper::http::h1::Incoming; pub use hyper::method::Method; pub use hyper::version::HttpVersion; pub use hyper::uri::RequestUri; pub use hyper::buffer::BufReader; + pub use hyper::server::Request as HyperRequest; pub use hyper::header::{ Upgrade, ProtocolName, @@ -103,6 +103,8 @@ pub mod hyper { pub type Request = Incoming<(Method, RequestUri)>; + pub struct RequestStreamPair(pub S, pub Request); + #[derive(Debug)] pub enum HyperIntoWsError { MethodNotGet, @@ -163,12 +165,13 @@ pub mod hyper { } } - impl IntoWs for S + impl IntoWs for S where S: Stream, { + type Stream = S; type Error = (Self, Option, HyperIntoWsError); - fn into_ws(mut self) -> Result, Self::Error> { + fn into_ws(mut self) -> Result, Self::Error> { let request = { let mut reader = BufReader::new(self.reader()); parse_request(&mut reader) @@ -180,21 +183,42 @@ pub mod hyper { }; match validate(&request) { - Ok(_) => unimplemented!(), + Ok(_) => Ok(WsUpgrade { + stream: self, + request: request, + }), Err(e) => Err((self, Some(request), e)), } } } - // TODO: Remove request and response from server + impl IntoWs for RequestStreamPair + where S: Stream, + { + type Stream = S; + type Error = (S, Request, HyperIntoWsError); - // TODO - // impl IntoWs for Request { - // fn into_ws(self) -> Result, (Request, IntoWsError)> { + fn into_ws(self) -> Result, Self::Error> { + match validate(&self.1) { + Ok(_) => Ok(WsUpgrade { + stream: self.0, + request: self.1, + }), + Err(e) => Err((self.0, self.1, e)), + } + } + } + + // impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { + // type Stream = Box, R=Box>>; + // type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); + + // fn into_ws(self) -> Result, Self::Error> { // unimplemented!(); // } // } + // TODO: Remove request and response from server pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { if request.subject.0 != Method::Get { diff --git a/src/stream.rs b/src/stream.rs index f003038386..47354b8895 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -57,10 +57,10 @@ pub trait Stream type W: Write; /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut Self::R; + fn reader(&mut self) -> &mut Read; /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut Self::W; + fn writer(&mut self) -> &mut Write; /// Split this stream into readable and writable components. /// The motivation behind this is to be able to read on one thread @@ -75,11 +75,11 @@ where R: Read, type R = R; type W = W; - fn reader(&mut self) -> &mut Self::R { + fn reader(&mut self) -> &mut Read { &mut self.0 } - fn writer(&mut self) -> &mut Self::W { + fn writer(&mut self) -> &mut Write { &mut self.1 } @@ -92,11 +92,11 @@ impl Stream for TcpStream { type R = Self; type W = Self; - fn reader(&mut self) -> &mut Self::R { + fn reader(&mut self) -> &mut Read { self } - fn writer(&mut self) -> &mut Self::W { + fn writer(&mut self) -> &mut Write { self } @@ -109,11 +109,11 @@ impl Stream for SslStream { type R = Self; type W = Self; - fn reader(&mut self) -> &mut Self::R { + fn reader(&mut self) -> &mut Read { self } - fn writer(&mut self) -> &mut Self::W { + fn writer(&mut self) -> &mut Write { self } @@ -126,11 +126,11 @@ impl Stream for Box { type R = Self; type W = Self; - fn reader(&mut self) -> &mut Self::R { + fn reader(&mut self) -> &mut Read { self } - fn writer(&mut self) -> &mut Self::W { + fn writer(&mut self) -> &mut Write { self } From 495351ea59ed85c0c65cd0ed92b5ee73789d9b72 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 5 Jul 2016 10:20:31 -0400 Subject: [PATCH 07/32] started on hyper IntoWs integration The last bit is to extract the NetworkStream from a hyper server-side request and save that as a Stream, then turn it into an Incoming<(..)> and validate the request. Implementing Stream for a NetworkStream is proving complicated, since because now there is conflicting implementations for many structs implementing NetworkStream. --- src/stream/hyper.rs | 54 ++++++++++++++++++++++++++++++++ src/{stream.rs => stream/mod.rs} | 2 ++ 2 files changed, 56 insertions(+) create mode 100644 src/stream/hyper.rs rename src/{stream.rs => stream/mod.rs} (99%) diff --git a/src/stream/hyper.rs b/src/stream/hyper.rs new file mode 100644 index 0000000000..5cd373f01e --- /dev/null +++ b/src/stream/hyper.rs @@ -0,0 +1,54 @@ +extern crate hyper; +extern crate openssl; + +use openssl::ssl::SslStream; +use hyper::client::pool::PooledStream; +use std::net::TcpStream; +use std::io::{ + Read, + Write, + self, +}; +use stream::{ + Stream, + AsTcpStream, +}; +use hyper::net::{ + NetworkStream, + HttpStream, + HttpsStream, +}; + +impl Stream for S +where S: NetworkStream, +{ + type R = Self; + type W = Self; + + fn reader(&mut self) -> &mut Read { + self + } + + fn writer(&mut self) -> &mut Read { + self + } + + fn split(self) -> io::Result<(Self::R, Self::W)> { + if let Some(http) = self.downcast_ref::() { + Ok((http.clone(), self)) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "Unknown implementation of NetworkStream found!", + )) + } + } +} + +impl AsTcpStream for S +where S: NetworkStream, +{ + fn as_tcp(&self) -> &TcpStream { + unimplemented!(); + } +} diff --git a/src/stream.rs b/src/stream/mod.rs similarity index 99% rename from src/stream.rs rename to src/stream/mod.rs index 47354b8895..d25a6bc779 100644 --- a/src/stream.rs +++ b/src/stream/mod.rs @@ -14,6 +14,8 @@ pub use openssl::ssl::{ SslContext, }; +mod hyper; + pub trait AsTcpStream: Read + Write { fn as_tcp(&self) -> &TcpStream; } From 8628391a526668e397d604e0936bd118d111d7ee Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 6 Jul 2016 23:20:54 -0400 Subject: [PATCH 08/32] removed TryUnsizedClone and many Stream impls for a simpler model where all AsTcpStreams implement Stream --- src/client/mod.rs | 3 +- src/server/upgrade.rs | 14 +++--- src/{stream/mod.rs => stream.rs} | 74 +++++++++++--------------------- src/stream/hyper.rs | 54 ----------------------- 4 files changed, 32 insertions(+), 113 deletions(-) rename src/{stream/mod.rs => stream.rs} (70%) delete mode 100644 src/stream/hyper.rs diff --git a/src/client/mod.rs b/src/client/mod.rs index 8018e2d0bf..f828a74f4e 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -11,7 +11,6 @@ use ws::receiver::{DataFrameIterator, MessageIterator}; use result::WebSocketResult; use stream::{ AsTcpStream, - TryUnsizedClone, Stream, }; use dataframe::DataFrame; @@ -136,7 +135,7 @@ impl Client>, Receiver>> { Box::new(tcp_stream) }; - let (read, write) = (try!(stream.try_clone()), stream); + let (read, write) = (try!(stream.duplicate()), stream); Request::new((host, resource_name, secure), (read, write)) } diff --git a/src/server/upgrade.rs b/src/server/upgrade.rs index 5c5d61e82e..b65eeb2987 100644 --- a/src/server/upgrade.rs +++ b/src/server/upgrade.rs @@ -210,7 +210,7 @@ pub mod hyper { } // impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { - // type Stream = Box, R=Box>>; + // type Stream = Box; // type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); // fn into_ws(self) -> Result, Self::Error> { @@ -218,8 +218,6 @@ pub mod hyper { // } // } - // TODO: Remove request and response from server - pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { if request.subject.0 != Method::Get { return Err(HyperIntoWsError::MethodNotGet); @@ -230,17 +228,17 @@ pub mod hyper { { return Err(HyperIntoWsError::UnsupportedHttpVersion); } - + if let Some(version) = request.headers.get::() { if version != &WebSocketVersion::WebSocket13 { return Err(HyperIntoWsError::UnsupportedWebsocketVersion); } } - + if request.headers.get::().is_none() { return Err(HyperIntoWsError::NoSecWsKeyHeader); } - + match request.headers.get() { Some(&Upgrade(ref upgrade)) => { if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { @@ -260,7 +258,7 @@ pub mod hyper { } false } - + match request.headers.get() { Some(&Connection(ref connection)) => { if !check_connection_header(connection) { @@ -269,7 +267,7 @@ pub mod hyper { }, None => return Err(HyperIntoWsError::NoConnectionHeader), }; - + Ok(()) } } diff --git a/src/stream/mod.rs b/src/stream.rs similarity index 70% rename from src/stream/mod.rs rename to src/stream.rs index d25a6bc779..d74f717e10 100644 --- a/src/stream/mod.rs +++ b/src/stream.rs @@ -1,10 +1,11 @@ //! Provides the default stream type for WebSocket connections. +use std::ops::Deref; +use std::any::Any; use std::io::{ self, Read, Write }; -use std::ops::Deref; pub use std::net::{ TcpStream, Shutdown, @@ -14,36 +15,40 @@ pub use openssl::ssl::{ SslContext, }; -mod hyper; - -pub trait AsTcpStream: Read + Write { +pub trait AsTcpStream: Read + Write + Any + 'static { fn as_tcp(&self) -> &TcpStream; + + fn duplicate(&self) -> io::Result + where Self: Sized; } impl AsTcpStream for TcpStream { fn as_tcp(&self) -> &TcpStream { self } + + fn duplicate(&self) -> io::Result { + self.try_clone() + } } impl AsTcpStream for SslStream { fn as_tcp(&self) -> &TcpStream { self.get_ref() } + + fn duplicate(&self) -> io::Result { + self.try_clone() + } } impl AsTcpStream for Box { fn as_tcp(&self) -> &TcpStream { self.deref().as_tcp() } -} - -pub trait TryUnsizedClone { - fn try_clone(&self) -> io::Result>; -} -impl TryUnsizedClone for Box { - fn try_clone(&self) -> io::Result> { + fn duplicate(&self) -> io::Result + where Self: Any { unimplemented!(); } } @@ -90,41 +95,9 @@ where R: Read, } } -impl Stream for TcpStream { - type R = Self; - type W = Self; - - fn reader(&mut self) -> &mut Read { - self - } - - fn writer(&mut self) -> &mut Write { - self - } - - fn split(self) -> io::Result<(Self::R, Self::W)> { - Ok((try!(self.try_clone()), self)) - } -} - -impl Stream for SslStream { - type R = Self; - type W = Self; - - fn reader(&mut self) -> &mut Read { - self - } - - fn writer(&mut self) -> &mut Write { - self - } - - fn split(self) -> io::Result<(Self::R, Self::W)> { - Ok((try!(self.try_clone()), self)) - } -} - -impl Stream for Box { +impl Stream for S +where S: AsTcpStream, +{ type R = Self; type W = Self; @@ -137,14 +110,17 @@ impl Stream for Box { } fn split(self) -> io::Result<(Self::R, Self::W)> { - Ok((try!(self.try_clone()), self)) + Ok((try!(self.duplicate()), self)) } } +/// Marker struct for having no SSL context in a struct. #[derive(Clone)] pub struct NoSslContext; - +/// Trait that is implemented over NoSslContext and SslContext that +/// serves as a generic bound to make a struct with. +/// Used in the Server to specify impls based on wether the server +/// is running over SSL or not. pub trait MaybeSslContext: Clone {} - impl MaybeSslContext for NoSslContext {} impl MaybeSslContext for SslContext {} diff --git a/src/stream/hyper.rs b/src/stream/hyper.rs deleted file mode 100644 index 5cd373f01e..0000000000 --- a/src/stream/hyper.rs +++ /dev/null @@ -1,54 +0,0 @@ -extern crate hyper; -extern crate openssl; - -use openssl::ssl::SslStream; -use hyper::client::pool::PooledStream; -use std::net::TcpStream; -use std::io::{ - Read, - Write, - self, -}; -use stream::{ - Stream, - AsTcpStream, -}; -use hyper::net::{ - NetworkStream, - HttpStream, - HttpsStream, -}; - -impl Stream for S -where S: NetworkStream, -{ - type R = Self; - type W = Self; - - fn reader(&mut self) -> &mut Read { - self - } - - fn writer(&mut self) -> &mut Read { - self - } - - fn split(self) -> io::Result<(Self::R, Self::W)> { - if let Some(http) = self.downcast_ref::() { - Ok((http.clone(), self)) - } else { - Err(io::Error::new( - io::ErrorKind::Other, - "Unknown implementation of NetworkStream found!", - )) - } - } -} - -impl AsTcpStream for S -where S: NetworkStream, -{ - fn as_tcp(&self) -> &TcpStream { - unimplemented!(); - } -} From dccab0f037fabc0cdbcf43aa11e7b3b0637d1b15 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 6 Jul 2016 23:22:00 -0400 Subject: [PATCH 09/32] removed response and request from the server code in favor of WsUpgrade --- src/server/mod.rs | 2 - src/server/request.rs | 169 ----------------------------------------- src/server/response.rs | 150 ------------------------------------ 3 files changed, 321 deletions(-) delete mode 100644 src/server/request.rs delete mode 100644 src/server/response.rs diff --git a/src/server/mod.rs b/src/server/mod.rs index 075f29b878..ed8acaaa2f 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -33,8 +33,6 @@ pub use self::upgrade::hyper::{ HyperIntoWsError, }; -pub mod request; -pub mod response; pub mod upgrade; pub struct InvalidConnection diff --git a/src/server/request.rs b/src/server/request.rs deleted file mode 100644 index b9e1ec1114..0000000000 --- a/src/server/request.rs +++ /dev/null @@ -1,169 +0,0 @@ -//! The server-side WebSocket request. - -use std::io::{Read, Write}; - -use server::response::Response; -use result::{WebSocketResult, WebSocketError}; -use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; - -pub use hyper::uri::RequestUri; - -use hyper::buffer::BufReader; -use hyper::version::HttpVersion; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, ProtocolName}; -use hyper::http::h1::parse_request; -use hyper::method::Method; - -use unicase::UniCase; - -/// Represents a server-side (incoming) request. -pub struct Request { - /// The HTTP method used to create the request. All values except `Method::Get` are - /// rejected by `validate()`. - pub method: Method, - - /// The target URI for this request. - pub url: RequestUri, - - /// The HTTP version of this request. - pub version: HttpVersion, - - /// The headers of this request. - pub headers: Headers, - - reader: R, - writer: W, -} - -unsafe impl Send for Request where R: Read + Send, W: Write + Send { } - -impl Request { - /// Short-cut to obtain the WebSocketKey value. - pub fn key(&self) -> Option<&WebSocketKey> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketVersion value. - pub fn version(&self) -> Option<&WebSocketVersion> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Short-cut to obtain the Origin value. - pub fn origin(&self) -> Option<&Origin> { - self.headers.get() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &R { - &self.reader - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - &self.writer - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut R { - &mut self.reader - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - &mut self.writer - } - /// Return the inner Reader and Writer - pub fn into_inner(self) -> (R, W) { - (self.reader, self.writer) - } - /// Reads an inbound request. - /// - /// This method is used within servers, and returns an inbound WebSocketRequest. - /// An error will be returned if the request cannot be read, or is not a valid HTTP - /// request. - /// - /// This method does not have any restrictions on the Request. All validation happens in - /// the `validate` method. - pub fn read(reader: R, writer: W) -> WebSocketResult> { - let mut reader = BufReader::new(reader); - let request = try!(parse_request(&mut reader)); - - Ok(Request { - method: request.subject.0, - url: request.subject.1, - version: request.version, - headers: request.headers, - reader: reader.into_inner(), - writer: writer, - }) - } - /// Check if this constitutes a valid WebSocket upgrade request. - /// - /// Note that `accept()` calls this function internally, however this may be useful for - /// handling requests in a custom way. - pub fn validate(&self) -> WebSocketResult<()> { - if self.method != Method::Get { - return Err(WebSocketError::RequestError("Request method must be GET")); - } - - if self.version == HttpVersion::Http09 || self.version == HttpVersion::Http10 { - return Err(WebSocketError::RequestError("Unsupported request HTTP version")); - } - - if self.version() != Some(&(WebSocketVersion::WebSocket13)) { - return Err(WebSocketError::RequestError("Unsupported WebSocket version")); - } - - if self.key().is_none() { - return Err(WebSocketError::RequestError("Missing Sec-WebSocket-Key header")); - } - - match self.headers.get() { - Some(&Upgrade(ref upgrade)) => { - let mut correct_upgrade = false; - for u in upgrade { - if u.name == ProtocolName::WebSocket { - correct_upgrade = true; - } - } - if !correct_upgrade { - return Err(WebSocketError::RequestError("Invalid Upgrade WebSocket header")); - } - } - None => { return Err(WebSocketError::RequestError("Missing Upgrade WebSocket header")); } - } - - match self.headers.get() { - Some(&Connection(ref connection)) => { - if !connection.contains(&(ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())))) { - return Err(WebSocketError::RequestError("Invalid Connection WebSocket header")); - } - } - None => { return Err(WebSocketError::RequestError("Missing Connection WebSocket header")); } - } - - Ok(()) - } - - /// Accept this request, ready to send a response. - /// - /// This function calls `validate()` on the request, and if the request is found to be invalid, - /// generates a response with a Bad Request status code. - pub fn accept(self) -> Response { - match self.validate() { - Ok(()) => { } - Err(_) => { return self.fail(); } - } - Response::new(self) - } - - /// Fail this request by generating a Bad Request response - pub fn fail(self) -> Response { - Response::bad_request(self) - } -} - diff --git a/src/server/response.rs b/src/server/response.rs deleted file mode 100644 index 0d5d763ccb..0000000000 --- a/src/server/response.rs +++ /dev/null @@ -1,150 +0,0 @@ -//! Struct for server-side WebSocket response. -use std::io::{Read, Write}; - -use hyper::status::StatusCode; -use hyper::version::HttpVersion; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, Protocol, ProtocolName}; -use hyper::buffer::BufReader; - -use unicase::UniCase; - -use header::{WebSocketAccept, WebSocketProtocol, WebSocketExtensions}; -use sender::Sender; -use receiver::Receiver; -use server::request::Request; -use client::Client; -use result::WebSocketResult; -use dataframe::DataFrame; -use ws::dataframe::DataFrame as DataFrameable; -use ws; - -/// Represents a server-side (outgoing) response. -pub struct Response { - /// The status of the response - pub status: StatusCode, - /// The headers contained in this response - pub headers: Headers, - /// The HTTP version of this response - pub version: HttpVersion, - - request: Request -} - -unsafe impl Send for Response where R: Read + Send, W: Write + Send { } - -impl Response { - /// Short-cut to obtain the WebSocketAccept value - pub fn accept(&self) -> Option<&WebSocketAccept> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &R { - self.request.get_reader() - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - self.request.get_writer() - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut R { - self.request.get_mut_reader() - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - self.request.get_mut_writer() - } - /// Returns a reference to the request associated with this response/ - pub fn get_request(&self) -> &Request { - &self.request - } - /// Return the inner Reader and Writer - pub fn into_inner(self) -> (R, W) { - self.request.into_inner() - } - /// Create a new outbound WebSocket response. - pub fn new(request: Request) -> Response { - let mut headers = Headers::new(); - headers.set(WebSocketAccept::new(request.key().unwrap())); - headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); - headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)])); - Response { - status: StatusCode::SwitchingProtocols, - headers: headers, - version: HttpVersion::Http11, - request: request - } - } - /// Create a Bad Request response - pub fn bad_request(request: Request) -> Response { - Response { - status: StatusCode::BadRequest, - headers: Headers::new(), - version: HttpVersion::Http11, - request: request - } - } - /// Short-cut to obtain a mutable reference to the WebSocketAccept value - /// Note that to add a header that does not already exist, ```WebSocketResponse.headers.set()``` - /// must be used. - pub fn accept_mut(&mut self) -> Option<&mut WebSocketAccept> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketProtocol value - /// Note that to add a header that does not already exist, ```WebSocketResponse.headers.set()``` - /// must be used. - pub fn protocol_mut(&mut self) -> Option<&mut WebSocketProtocol> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketExtensions value - /// Note that to add a header that does not already exist, ```WebSocketResponse.headers.set()``` - /// must be used. - pub fn extensions_mut(&mut self) -> Option<&mut WebSocketExtensions> { - self.headers.get_mut() - } - - /// Send this response with the given data frame type D, Sender B and Receiver C. - pub fn send_with(mut self, sender: B, receiver: C) -> WebSocketResult> - where B: ws::Sender, C: ws::Receiver, D: DataFrameable { - let version = self.version; - let status = self.status; - let headers = self.headers.clone(); - try!(write!(self.get_mut_writer(), "{} {}\r\n", version, status)); - try!(write!(self.get_mut_writer(), "{}\r\n", headers)); - Ok(Client::new(sender, receiver)) - } - - /// Send this response, retrieving the inner Reader and Writer - pub fn send_into_inner(mut self) -> WebSocketResult<(R, W)> { - let version = self.version; - let status = self.status; - let headers = self.headers.clone(); - try!(write!(self.get_mut_writer(), "{} {}\r\n", version, status)); - try!(write!(self.get_mut_writer(), "{}\r\n", headers)); - Ok(self.into_inner()) - } - - /// Send this response, returning a Client ready to transmit/receive data frames - pub fn send(mut self) -> WebSocketResult, Receiver>> { - let version = self.version; - let status = self.status; - let headers = self.headers.clone(); - try!(write!(self.get_mut_writer(), "{} {}\r\n", version, status)); - try!(write!(self.get_mut_writer(), "{}\r\n", headers)); - let (reader, writer) = self.into_inner(); - let sender = Sender::new(writer, false); - let receiver = Receiver::new(BufReader::new(reader), true); - Ok(Client::new(sender, receiver)) - } -} From 03152d27f93dd040d0eb4047ed553796c3266689 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Thu, 7 Jul 2016 00:00:08 -0400 Subject: [PATCH 10/32] moved hyper-specific upgrade implementation into its own module --- src/server/upgrade.rs | 274 ------------------------------------ src/server/upgrade/hyper.rs | 208 +++++++++++++++++++++++++++ src/server/upgrade/mod.rs | 65 +++++++++ 3 files changed, 273 insertions(+), 274 deletions(-) delete mode 100644 src/server/upgrade.rs create mode 100644 src/server/upgrade/hyper.rs create mode 100644 src/server/upgrade/mod.rs diff --git a/src/server/upgrade.rs b/src/server/upgrade.rs deleted file mode 100644 index b65eeb2987..0000000000 --- a/src/server/upgrade.rs +++ /dev/null @@ -1,274 +0,0 @@ -//! Allows you to take an existing request or stream of data and convert it into a -//! WebSocket client. -use std::net::TcpStream; -use stream::{ - Stream, - AsTcpStream, -}; - -/// Intermediate representation of a half created websocket session. -/// Should be used to examine the client's handshake -/// accept the protocols requested, route the path, etc. -/// -/// Users should then call `accept` or `deny` to complete the handshake -/// and start a session. -pub struct WsUpgrade -where S: Stream, -{ - stream: S, - request: hyper::Request, -} - -impl WsUpgrade -where S: Stream, -{ - pub fn accept(self) { - unimplemented!(); - } - - pub fn reject(self) -> S { - unimplemented!(); - } - - pub fn into_stream(self) -> S { - unimplemented!(); - } -} - -impl WsUpgrade -where S: Stream + AsTcpStream, -{ - pub fn tcp_stream(&self) -> &TcpStream { - self.stream.as_tcp() - } -} - -/// Trait to take a stream or similar and attempt to recover the start of a -/// websocket handshake from it. -/// Should be used when a stream might contain a request for a websocket session. -/// -/// If an upgrade request can be parsed, one can accept or deny the handshake with -/// the `WsUpgrade` struct. -/// Otherwise the original stream is returned along with an error. -/// -/// Note: the stream is owned because the websocket client expects to own its stream. -pub trait IntoWs { - type Stream: Stream; - type Error; - /// Attempt to parse the start of a Websocket handshake, later with the returned - /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to - /// send a handshake rejection response. - fn into_ws(mut self) -> Result, Self::Error>; -} - -pub mod hyper { - extern crate hyper; - - use std::convert::From; - use std::error::Error; - use hyper::http::h1::parse_request; - use hyper::net::NetworkStream; - use header::{ - WebSocketKey, - WebSocketVersion, - }; - use std::fmt::{ - Formatter, - Display, - self, - }; - use stream::Stream; - use super::{ - IntoWs, - WsUpgrade, - }; - use std::io::{ - Read, - Write, - self, - }; - - pub use hyper::http::h1::Incoming; - pub use hyper::method::Method; - pub use hyper::version::HttpVersion; - pub use hyper::uri::RequestUri; - pub use hyper::buffer::BufReader; - pub use hyper::server::Request as HyperRequest; - pub use hyper::header::{ - Upgrade, - ProtocolName, - Connection, - ConnectionOption, - }; - - pub type Request = Incoming<(Method, RequestUri)>; - - pub struct RequestStreamPair(pub S, pub Request); - - #[derive(Debug)] - pub enum HyperIntoWsError { - MethodNotGet, - UnsupportedHttpVersion, - UnsupportedWebsocketVersion, - NoSecWsKeyHeader, - NoWsUpgradeHeader, - NoUpgradeHeader, - NoWsConnectionHeader, - NoConnectionHeader, - /// IO error from reading the underlying socket - Io(io::Error), - /// Error while parsing an incoming request - Parsing(hyper::error::Error), - } - - impl Display for HyperIntoWsError { - fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { - fmt.write_str(self.description()) - } - } - - impl Error for HyperIntoWsError { - fn description(&self) -> &str { - use self::HyperIntoWsError::*; - match self { - &MethodNotGet => "Request method must be GET", - &UnsupportedHttpVersion => "Unsupported request HTTP version", - &UnsupportedWebsocketVersion => "Unsupported WebSocket version", - &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", - &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", - &NoUpgradeHeader => "Missing Upgrade WebSocket header", - &NoWsConnectionHeader => "Invalid Connection WebSocket header", - &NoConnectionHeader => "Missing Connection WebSocket header", - &Io(ref e) => e.description(), - &Parsing(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - HyperIntoWsError::Io(ref e) => Some(e), - HyperIntoWsError::Parsing(ref e) => Some(e), - _ => None, - } - } - } - - impl From for HyperIntoWsError { - fn from(err: io::Error) -> Self { - HyperIntoWsError::Io(err) - } - } - - impl From for HyperIntoWsError { - fn from(err: hyper::error::Error) -> Self { - HyperIntoWsError::Parsing(err) - } - } - - impl IntoWs for S - where S: Stream, - { - type Stream = S; - type Error = (Self, Option, HyperIntoWsError); - - fn into_ws(mut self) -> Result, Self::Error> { - let request = { - let mut reader = BufReader::new(self.reader()); - parse_request(&mut reader) - }; - - let request = match request { - Ok(r) => r, - Err(e) => return Err((self, None, e.into())), - }; - - match validate(&request) { - Ok(_) => Ok(WsUpgrade { - stream: self, - request: request, - }), - Err(e) => Err((self, Some(request), e)), - } - } - } - - impl IntoWs for RequestStreamPair - where S: Stream, - { - type Stream = S; - type Error = (S, Request, HyperIntoWsError); - - fn into_ws(self) -> Result, Self::Error> { - match validate(&self.1) { - Ok(_) => Ok(WsUpgrade { - stream: self.0, - request: self.1, - }), - Err(e) => Err((self.0, self.1, e)), - } - } - } - - // impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { - // type Stream = Box; - // type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); - - // fn into_ws(self) -> Result, Self::Error> { - // unimplemented!(); - // } - // } - - pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { - if request.subject.0 != Method::Get { - return Err(HyperIntoWsError::MethodNotGet); - } - - if request.version == HttpVersion::Http09 - || request.version == HttpVersion::Http10 - { - return Err(HyperIntoWsError::UnsupportedHttpVersion); - } - - if let Some(version) = request.headers.get::() { - if version != &WebSocketVersion::WebSocket13 { - return Err(HyperIntoWsError::UnsupportedWebsocketVersion); - } - } - - if request.headers.get::().is_none() { - return Err(HyperIntoWsError::NoSecWsKeyHeader); - } - - match request.headers.get() { - Some(&Upgrade(ref upgrade)) => { - if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { - return Err(HyperIntoWsError::NoWsUpgradeHeader) - } - }, - None => return Err(HyperIntoWsError::NoUpgradeHeader), - }; - - fn check_connection_header(headers: &Vec) -> bool { - for header in headers { - if let &ConnectionOption::ConnectionHeader(ref h) = header { - if h as &str == "upgrade" { - return true; - } - } - } - false - } - - match request.headers.get() { - Some(&Connection(ref connection)) => { - if !check_connection_header(connection) { - return Err(HyperIntoWsError::NoWsConnectionHeader); - } - }, - None => return Err(HyperIntoWsError::NoConnectionHeader), - }; - - Ok(()) - } -} - diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs new file mode 100644 index 0000000000..e40bb6c887 --- /dev/null +++ b/src/server/upgrade/hyper.rs @@ -0,0 +1,208 @@ +extern crate hyper; + +use std::convert::From; +use std::error::Error; +use hyper::http::h1::parse_request; +use hyper::net::NetworkStream; +use header::{ + WebSocketKey, + WebSocketVersion, +}; +use std::fmt::{ + Formatter, + Display, + self, +}; +use stream::Stream; +use super::{ + IntoWs, + WsUpgrade, +}; +use std::io::{ + Read, + Write, + self, +}; + +pub use hyper::http::h1::Incoming; +pub use hyper::method::Method; +pub use hyper::version::HttpVersion; +pub use hyper::uri::RequestUri; +pub use hyper::buffer::BufReader; +pub use hyper::server::Request as HyperRequest; +pub use hyper::header::{ + Upgrade, + ProtocolName, + Connection, + ConnectionOption, +}; + +pub type Request = Incoming<(Method, RequestUri)>; + +pub struct RequestStreamPair(pub S, pub Request); + +#[derive(Debug)] +pub enum HyperIntoWsError { + MethodNotGet, + UnsupportedHttpVersion, + UnsupportedWebsocketVersion, + NoSecWsKeyHeader, + NoWsUpgradeHeader, + NoUpgradeHeader, + NoWsConnectionHeader, + NoConnectionHeader, + /// IO error from reading the underlying socket + Io(io::Error), + /// Error while parsing an incoming request + Parsing(hyper::error::Error), +} + +impl Display for HyperIntoWsError { + fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + fmt.write_str(self.description()) + } +} + +impl Error for HyperIntoWsError { + fn description(&self) -> &str { + use self::HyperIntoWsError::*; + match self { + &MethodNotGet => "Request method must be GET", + &UnsupportedHttpVersion => "Unsupported request HTTP version", + &UnsupportedWebsocketVersion => "Unsupported WebSocket version", + &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", + &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", + &NoUpgradeHeader => "Missing Upgrade WebSocket header", + &NoWsConnectionHeader => "Invalid Connection WebSocket header", + &NoConnectionHeader => "Missing Connection WebSocket header", + &Io(ref e) => e.description(), + &Parsing(ref e) => e.description(), + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + HyperIntoWsError::Io(ref e) => Some(e), + HyperIntoWsError::Parsing(ref e) => Some(e), + _ => None, + } + } +} + +impl From for HyperIntoWsError { + fn from(err: io::Error) -> Self { + HyperIntoWsError::Io(err) + } +} + +impl From for HyperIntoWsError { + fn from(err: hyper::error::Error) -> Self { + HyperIntoWsError::Parsing(err) + } +} + +impl IntoWs for S +where S: Stream, +{ + type Stream = S; + type Error = (Self, Option, HyperIntoWsError); + + fn into_ws(mut self) -> Result, Self::Error> { + let request = { + let mut reader = BufReader::new(self.reader()); + parse_request(&mut reader) + }; + + let request = match request { + Ok(r) => r, + Err(e) => return Err((self, None, e.into())), + }; + + match validate(&request) { + Ok(_) => Ok(WsUpgrade { + stream: self, + request: request, + }), + Err(e) => Err((self, Some(request), e)), + } + } +} + +impl IntoWs for RequestStreamPair +where S: Stream, +{ + type Stream = S; + type Error = (S, Request, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + match validate(&self.1) { + Ok(_) => Ok(WsUpgrade { + stream: self.0, + request: self.1, + }), + Err(e) => Err((self.0, self.1, e)), + } + } +} + +// impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { +// type Stream = Box; +// type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); + +// fn into_ws(self) -> Result, Self::Error> { +// unimplemented!(); +// } +// } + +pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { + if request.subject.0 != Method::Get { + return Err(HyperIntoWsError::MethodNotGet); + } + + if request.version == HttpVersion::Http09 + || request.version == HttpVersion::Http10 + { + return Err(HyperIntoWsError::UnsupportedHttpVersion); + } + + if let Some(version) = request.headers.get::() { + if version != &WebSocketVersion::WebSocket13 { + return Err(HyperIntoWsError::UnsupportedWebsocketVersion); + } + } + + if request.headers.get::().is_none() { + return Err(HyperIntoWsError::NoSecWsKeyHeader); + } + + match request.headers.get() { + Some(&Upgrade(ref upgrade)) => { + if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { + return Err(HyperIntoWsError::NoWsUpgradeHeader) + } + }, + None => return Err(HyperIntoWsError::NoUpgradeHeader), + }; + + fn check_connection_header(headers: &Vec) -> bool { + for header in headers { + if let &ConnectionOption::ConnectionHeader(ref h) = header { + if h as &str == "upgrade" { + return true; + } + } + } + false + } + + match request.headers.get() { + Some(&Connection(ref connection)) => { + if !check_connection_header(connection) { + return Err(HyperIntoWsError::NoWsConnectionHeader); + } + }, + None => return Err(HyperIntoWsError::NoConnectionHeader), + }; + + Ok(()) +} diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs new file mode 100644 index 0000000000..88fd4dfbe4 --- /dev/null +++ b/src/server/upgrade/mod.rs @@ -0,0 +1,65 @@ +//! Allows you to take an existing request or stream of data and convert it into a +//! WebSocket client. +use std::net::TcpStream; +use stream::{ + Stream, + AsTcpStream, +}; + +pub mod hyper; + +/// Intermediate representation of a half created websocket session. +/// Should be used to examine the client's handshake +/// accept the protocols requested, route the path, etc. +/// +/// Users should then call `accept` or `deny` to complete the handshake +/// and start a session. +pub struct WsUpgrade +where S: Stream, +{ + stream: S, + request: hyper::Request, +} + +impl WsUpgrade +where S: Stream, +{ + pub fn accept(self) { + unimplemented!(); + } + + pub fn reject(self) -> S { + unimplemented!(); + } + + pub fn into_stream(self) -> S { + unimplemented!(); + } +} + +impl WsUpgrade +where S: Stream + AsTcpStream, +{ + pub fn tcp_stream(&self) -> &TcpStream { + self.stream.as_tcp() + } +} + +/// Trait to take a stream or similar and attempt to recover the start of a +/// websocket handshake from it. +/// Should be used when a stream might contain a request for a websocket session. +/// +/// If an upgrade request can be parsed, one can accept or deny the handshake with +/// the `WsUpgrade` struct. +/// Otherwise the original stream is returned along with an error. +/// +/// Note: the stream is owned because the websocket client expects to own its stream. +pub trait IntoWs { + type Stream: Stream; + type Error; + /// Attempt to parse the start of a Websocket handshake, later with the returned + /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to + /// send a handshake rejection response. + fn into_ws(mut self) -> Result, Self::Error>; +} + From 9f14e1a8eba1f8fdc0cc53a4b485ef28a23299c9 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sat, 16 Jul 2016 21:34:27 -0400 Subject: [PATCH 11/32] implemented AsTcpStream for Box This was initially very hard because one cannot clone unsized types and put them on the stack. I was about to revert to WebSocketStream style enum-variant-per-sized-type style but I managed to live another day without all those match blocks. --- src/stream.rs | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/stream.rs b/src/stream.rs index d74f717e10..d8bf80e34d 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,6 +1,5 @@ //! Provides the default stream type for WebSocket connections. use std::ops::Deref; -use std::any::Any; use std::io::{ self, Read, @@ -15,11 +14,13 @@ pub use openssl::ssl::{ SslContext, }; -pub trait AsTcpStream: Read + Write + Any + 'static { +pub trait AsTcpStream: Read + Write { fn as_tcp(&self) -> &TcpStream; fn duplicate(&self) -> io::Result where Self: Sized; + + fn box_duplicate(&self) -> io::Result>; } impl AsTcpStream for TcpStream { @@ -30,6 +31,10 @@ impl AsTcpStream for TcpStream { fn duplicate(&self) -> io::Result { self.try_clone() } + + fn box_duplicate(&self) -> io::Result> { + Ok(Box::new(try!(self.duplicate()))) + } } impl AsTcpStream for SslStream { @@ -40,6 +45,10 @@ impl AsTcpStream for SslStream { fn duplicate(&self) -> io::Result { self.try_clone() } + + fn box_duplicate(&self) -> io::Result> { + Ok(Box::new(try!(self.duplicate()))) + } } impl AsTcpStream for Box { @@ -47,9 +56,12 @@ impl AsTcpStream for Box { self.deref().as_tcp() } - fn duplicate(&self) -> io::Result - where Self: Any { - unimplemented!(); + fn duplicate(&self) -> io::Result { + self.deref().box_duplicate() + } + + fn box_duplicate(&self) -> io::Result> { + self.duplicate() } } From 6bb3f8fc86e0099bc40ed6887024335f71a5546e Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sat, 16 Jul 2016 23:08:49 -0400 Subject: [PATCH 12/32] Steps toward implementing IntoWs for Hyper request The validation is done, now the only work is to extract the underlying stream. --- Cargo.toml | 3 +- src/client/mod.rs | 2 + src/lib.rs | 1 - src/receiver.rs | 12 ++- src/server/upgrade/hyper.rs | 89 ++++++++++++++++------ src/server/upgrade/mod.rs | 5 +- src/stream.rs | 147 +++++++++++++++--------------------- src/ws/receiver.rs | 9 ++- src/ws/sender.rs | 8 +- 9 files changed, 150 insertions(+), 126 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 884b3d0b8c..4cf8ea39c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,8 @@ keywords = ["websocket", "websockets", "rfc6455"] license = "MIT" [dependencies] -hyper = ">=0.7, <0.11" +hyper = ">=0.7, <0.10" +mio = "0.5.1" unicase = "1.0.1" openssl = "0.7.6" url = "1.0" diff --git a/src/client/mod.rs b/src/client/mod.rs index f828a74f4e..02112b03d6 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -109,6 +109,7 @@ impl Client>, Receiver>, Receiver>> { pub fn connect_agnostic(components: C, ssl_context: Option<&SslContext>) -> WebSocketResult, Box>> where C: ToWebSocketUrlComponents @@ -141,6 +142,7 @@ impl Client>, Receiver>> { } } +// TODO: add method to expose tcp to edit things impl Client, Receiver> where S: AsTcpStream, { diff --git a/src/lib.rs b/src/lib.rs index 1e38a1ace3..843eebaeea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,7 +55,6 @@ pub use self::server::Server; pub use self::dataframe::DataFrame; pub use self::message::Message; pub use self::stream::Stream; -pub use self::stream::AsTcpStream; pub use self::ws::Sender; pub use self::ws::Receiver; diff --git a/src/receiver.rs b/src/receiver.rs index a7a845ccc2..6d1b1def53 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -12,27 +12,25 @@ pub use stream::Shutdown; /// A Receiver that wraps a Reader and provides a default implementation using /// DataFrames and Messages. -pub struct Receiver { - inner: BufReader, +pub struct Receiver { buffer: Vec, mask: bool, } -impl Receiver -where R: Read, -{ +impl Receiver { /// Create a new Receiver using the specified Reader. - pub fn new(reader: BufReader, mask: bool) -> Receiver { + pub fn new(mask: bool) -> Receiver { Receiver { - inner: reader, buffer: Vec::new(), mask: mask, } } + /// Returns a reference to the underlying Reader. pub fn get_ref(&self) -> &BufReader { &self.inner } + /// Returns a mutable reference to the underlying Reader. pub fn get_mut(&mut self) -> &mut BufReader { &mut self.inner diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index e40bb6c887..5eb0ccd281 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -1,9 +1,17 @@ extern crate hyper; +extern crate openssl; +use std::net::TcpStream; +use std::any::Any; use std::convert::From; use std::error::Error; +use openssl::ssl::SslStream; use hyper::http::h1::parse_request; -use hyper::net::NetworkStream; +use hyper::net::{ + NetworkStream, + HttpStream, + HttpsStream, +}; use header::{ WebSocketKey, WebSocketVersion, @@ -13,7 +21,10 @@ use std::fmt::{ Display, self, }; -use stream::Stream; +use stream::{ + Stream, + AsTcpStream, +}; use super::{ IntoWs, WsUpgrade, @@ -31,6 +42,7 @@ pub use hyper::uri::RequestUri; pub use hyper::buffer::BufReader; pub use hyper::server::Request as HyperRequest; pub use hyper::header::{ + Headers, Upgrade, ProtocolName, Connection, @@ -51,6 +63,7 @@ pub enum HyperIntoWsError { NoUpgradeHeader, NoWsConnectionHeader, NoConnectionHeader, + UnknownNetworkStream, /// IO error from reading the underlying socket Io(io::Error), /// Error while parsing an incoming request @@ -75,6 +88,7 @@ impl Error for HyperIntoWsError { &NoUpgradeHeader => "Missing Upgrade WebSocket header", &NoWsConnectionHeader => "Invalid Connection WebSocket header", &NoConnectionHeader => "Missing Connection WebSocket header", + &UnknownNetworkStream => "Cannot downcast to known impl of NetworkStream", &Io(ref e) => e.description(), &Parsing(ref e) => e.description(), } @@ -101,8 +115,11 @@ impl From for HyperIntoWsError { } } -impl IntoWs for S -where S: Stream, +// TODO: Move this into the main upgrade module +impl IntoWs for S +where S: Stream, + R: Read, + W: Write, { type Stream = S; type Error = (Self, Option, HyperIntoWsError); @@ -118,7 +135,7 @@ where S: Stream, Err(e) => return Err((self, None, e.into())), }; - match validate(&request) { + match validate(&request.subject.0, &request.version, &request.headers) { Ok(_) => Ok(WsUpgrade { stream: self, request: request, @@ -128,14 +145,17 @@ where S: Stream, } } -impl IntoWs for RequestStreamPair -where S: Stream, +// TODO: Move this into the main upgrade module +impl IntoWs for RequestStreamPair +where S: Stream, + R: Read, + W: Write, { type Stream = S; type Error = (S, Request, HyperIntoWsError); fn into_ws(self) -> Result, Self::Error> { - match validate(&self.1) { + match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { Ok(_) => Ok(WsUpgrade { stream: self.0, request: self.1, @@ -145,37 +165,57 @@ where S: Stream, } } -// impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { -// type Stream = Box; -// type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); +impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { + type Stream = Box; + type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + if let Err(e) = validate(&self.method, &self.version, &self.headers) { + return Err((self, e)); + } -// fn into_ws(self) -> Result, Self::Error> { -// unimplemented!(); -// } -// } + let stream: Option> = unimplemented!(); + + if let Some(s) = stream { + Ok(WsUpgrade { + stream: s, + request: Incoming { + version: self.version, + headers: self.headers, + subject: (self.method, self.uri), + }, + }) + } else { + Err((self, HyperIntoWsError::UnknownNetworkStream)) + } + } +} -pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { - if request.subject.0 != Method::Get { +pub fn validate( + method: &Method, + version: &HttpVersion, + headers: &Headers +) -> Result<(), HyperIntoWsError> +{ + if *method != Method::Get { return Err(HyperIntoWsError::MethodNotGet); } - if request.version == HttpVersion::Http09 - || request.version == HttpVersion::Http10 - { + if *version == HttpVersion::Http09 || *version == HttpVersion::Http10 { return Err(HyperIntoWsError::UnsupportedHttpVersion); } - if let Some(version) = request.headers.get::() { + if let Some(version) = headers.get::() { if version != &WebSocketVersion::WebSocket13 { return Err(HyperIntoWsError::UnsupportedWebsocketVersion); } } - if request.headers.get::().is_none() { + if headers.get::().is_none() { return Err(HyperIntoWsError::NoSecWsKeyHeader); } - match request.headers.get() { + match headers.get() { Some(&Upgrade(ref upgrade)) => { if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { return Err(HyperIntoWsError::NoWsUpgradeHeader) @@ -195,7 +235,7 @@ pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { false } - match request.headers.get() { + match headers.get() { Some(&Connection(ref connection)) => { if !check_connection_header(connection) { return Err(HyperIntoWsError::NoWsConnectionHeader); @@ -206,3 +246,4 @@ pub fn validate(request: &Request) -> Result<(), HyperIntoWsError> { Ok(()) } + diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 88fd4dfbe4..d616c74210 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -1,6 +1,7 @@ //! Allows you to take an existing request or stream of data and convert it into a //! WebSocket client. use std::net::TcpStream; +use std::io::Read; use stream::{ Stream, AsTcpStream, @@ -15,7 +16,9 @@ pub mod hyper; /// Users should then call `accept` or `deny` to complete the handshake /// and start a session. pub struct WsUpgrade -where S: Stream, +where S: Stream, + R: Read, + W: Write, { stream: S, request: hyper::Request, diff --git a/src/stream.rs b/src/stream.rs index d8bf80e34d..f6c0e08130 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,129 +1,102 @@ //! Provides the default stream type for WebSocket connections. -use std::ops::Deref; +extern crate mio; + use std::io::{ self, Read, Write }; -pub use std::net::{ - TcpStream, - Shutdown, -}; +pub use std::net::TcpStream; +pub use std::net::Shutdown; pub use openssl::ssl::{ SslStream, SslContext, }; +pub use self::mio::Evented; -pub trait AsTcpStream: Read + Write { - fn as_tcp(&self) -> &TcpStream; - - fn duplicate(&self) -> io::Result - where Self: Sized; - - fn box_duplicate(&self) -> io::Result>; -} - -impl AsTcpStream for TcpStream { - fn as_tcp(&self) -> &TcpStream { - self - } - - fn duplicate(&self) -> io::Result { - self.try_clone() - } - - fn box_duplicate(&self) -> io::Result> { - Ok(Box::new(try!(self.duplicate()))) - } -} - -impl AsTcpStream for SslStream { - fn as_tcp(&self) -> &TcpStream { - self.get_ref() - } - - fn duplicate(&self) -> io::Result { - self.try_clone() - } - - fn box_duplicate(&self) -> io::Result> { - Ok(Box::new(try!(self.duplicate()))) - } -} - -impl AsTcpStream for Box { - fn as_tcp(&self) -> &TcpStream { - self.deref().as_tcp() - } - - fn duplicate(&self) -> io::Result { - self.deref().box_duplicate() - } - - fn box_duplicate(&self) -> io::Result> { - self.duplicate() - } +pub trait Splittable +where R: Read + Evented, + W: Write + Evented, +{ + fn split(self) -> io::Result<(R, W)>; } -/// Represents a stream that can be read from, written to, and split into two. +/// Represents a stream that can be read from, and written to. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. -pub trait Stream +pub trait Stream +where R: Read + Evented, + W: Write + Evented, { - /// The reading component of the stream - type R: Read; - /// The writing component of the stream - type W: Write; - /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut Read; + fn reader(&mut self) -> &mut R; /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut Write; - - /// Split this stream into readable and writable components. - /// The motivation behind this is to be able to read on one thread - /// and send messages on another. - fn split(self) -> io::Result<(Self::R, Self::W)>; + fn writer(&mut self) -> &mut W; } -impl Stream for (R, W) -where R: Read, - W: Write, +pub struct ReadWritePair(pub R, pub W) +where R: Read + Evented, + W: Write + Evented; + +impl Splittable for ReadWritePair +where R: Read + Evented, + W: Write + Evented, { - type R = R; - type W = W; + fn split(self) -> io::Result<(R, W)> { + Ok((self.0, self.1)) + } +} - fn reader(&mut self) -> &mut Read { +impl Stream for ReadWritePair +where R: Read + Evented, + W: Write + Evented, +{ + #[inline] + fn reader(&mut self) -> &mut R { &mut self.0 } - fn writer(&mut self) -> &mut Write { + #[inline] + fn writer(&mut self) -> &mut W { &mut self.1 } +} - fn split(self) -> io::Result<(Self::R, Self::W)> { - Ok(self) +impl Splittable for TcpStream { + fn split(self) -> io::Result<(TcpStream, TcpStream)> { + self.try_clone().map(|s| (s, self)) } } -impl Stream for S -where S: AsTcpStream, +impl Stream for S +where S: Read + Write + Evented, { - type R = Self; - type W = Self; - - fn reader(&mut self) -> &mut Read { + #[inline] + fn reader(&mut self) -> &mut S { self } - fn writer(&mut self) -> &mut Write { + #[inline] + fn writer(&mut self) -> &mut S { self } +} - fn split(self) -> io::Result<(Self::R, Self::W)> { - Ok((try!(self.duplicate()), self)) - } +pub trait AsTcpStream { + fn as_tcp(&self) -> &TcpStream; +} + +impl AsTcpStream for TcpStream { + fn as_tcp(&self) -> &TcpStream { + &self + } +} + +impl AsTcpStream for SslStream { + fn as_tcp(&self) -> &TcpStream { + self.get_ref() + } } /// Marker struct for having no SSL context in a struct. diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index cfd934e41b..84690909e8 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -3,6 +3,7 @@ //! Also provides iterators over data frames and messages. //! See the `ws` module documentation for more information. +use std::io::Read; use std::marker::PhantomData; use ws::Message; use ws::dataframe::DataFrame; @@ -13,6 +14,7 @@ pub trait Receiver: Sized where F: DataFrame { /// Reads a single data frame from this receiver. fn recv_dataframe(&mut self) -> WebSocketResult; + /// Returns the data frames that constitute one message. fn recv_message_dataframes(&mut self) -> WebSocketResult>; @@ -23,6 +25,7 @@ where F: DataFrame { _dataframe: PhantomData } } + /// Reads a single message from this receiver. fn recv_message<'m, D, M, I>(&mut self) -> WebSocketResult where M: Message<'m, D, DataFrameIterator = I>, @@ -35,7 +38,8 @@ where F: DataFrame { /// Returns an iterator over incoming messages. fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Self, D, F, M> - where M: Message<'a, D>, D: DataFrame { + where M: Message<'a, D>, D: DataFrame + { MessageIterator { inner: self, _dataframe: PhantomData, @@ -47,7 +51,8 @@ where F: DataFrame { /// An iterator over data frames from a Receiver. pub struct DataFrameIterator<'a, R, D> -where R: 'a + Receiver, D: DataFrame { +where R: 'a + Receiver, D: DataFrame +{ inner: &'a mut R, _dataframe: PhantomData } diff --git a/src/ws/sender.rs b/src/ws/sender.rs index f017f8bd2b..6d8c7d8e37 100644 --- a/src/ws/sender.rs +++ b/src/ws/sender.rs @@ -2,6 +2,7 @@ //! //! See the `ws` module documentation for more information. +use std::io::Write; use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; @@ -9,12 +10,13 @@ use result::WebSocketResult; /// A trait for sending data frames and messages. pub trait Sender { /// Sends a single data frame using this sender. - fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + fn send_dataframe(&mut self, writer: &mut Write, dataframe: &D) -> WebSocketResult<()> where D: DataFrame; /// Sends a single message using this sender. - fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: Message<'m, D>, D: DataFrame { + fn send_message<'m, M, D>(&mut self, writer: &mut Write, message: &'m M) -> WebSocketResult<()> + where M: Message<'m, D>, D: DataFrame + { for ref dataframe in message.dataframes() { try!(self.send_dataframe(dataframe)); } From ebb11ebdb3c4e6c9b1f098e71338e2837939dd88 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Mon, 27 Mar 2017 18:49:23 -0400 Subject: [PATCH 13/32] sender and receiver not holding the stream anymore, client uses stream --- src/client/mod.rs | 470 ++++++++++++++++++------------------ src/client/response.rs | 37 +-- src/receiver.rs | 85 ++++--- src/sender.rs | 75 +++--- src/server/upgrade/hyper.rs | 63 +++-- src/server/upgrade/mod.rs | 4 +- src/stream.rs | 44 ++-- src/ws/receiver.rs | 141 ++++++----- src/ws/sender.rs | 14 +- 9 files changed, 501 insertions(+), 432 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 02112b03d6..d3a6d1a529 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,17 +1,21 @@ //! Contains the WebSocket client. - use std::net::TcpStream; use std::marker::PhantomData; use std::io::Result as IoResult; use std::ops::Deref; use ws; +use ws::sender::Sender as SenderTrait; use ws::util::url::ToWebSocketUrlComponents; -use ws::receiver::{DataFrameIterator, MessageIterator}; +use ws::receiver::{ + DataFrameIterator, + MessageIterator, +}; +use ws::receiver::Receiver as ReceiverTrait; use result::WebSocketResult; use stream::{ - AsTcpStream, - Stream, + AsTcpStream, + Stream, }; use dataframe::DataFrame; use ws::dataframe::DataFrame as DataFrameable; @@ -60,253 +64,259 @@ pub mod response; ///client.send_message(&message).unwrap(); // Send message ///# } ///``` -pub struct Client { - sender: S, - receiver: R, - _dataframe: PhantomData +pub struct Client + where S: Stream, +{ + stream: S, + sender: Sender, + receiver: Receiver, } +// TODO: add back client.split() -impl Client, Receiver> { - /// Connects to the given ws:// URL and return a Request to be sent. - /// - /// If you would like to use a secure connection (wss://), please use `connect_secure`. - /// - /// A connection is established, however the request is not sent to - /// the server until a call to ```send()```. - pub fn connect(components: C) -> WebSocketResult> - where C: ToWebSocketUrlComponents, - { - let (host, resource_name, secure) = try!(components.to_components()); - let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(80))); - let stream = try!(stream); - Request::new((host, resource_name, secure), try!(stream.split())) - } +pub struct ClientBuilder + where S: Stream, +{ + stream: S, } -impl Client>, Receiver>> { - /// Connects to the specified wss:// URL using the given SSL context. - /// - /// If you would like to use an insecure connection (ws://), please use `connect`. - /// - /// A connection is established, however the request is not sent to - /// the server until a call to ```send()```. - pub fn connect_secure(components: C, context: Option<&SslContext>) -> WebSocketResult, SslStream>> - where C: ToWebSocketUrlComponents, - { - let (host, resource_name, secure) = try!(components.to_components()); +// impl Client { +// /// Connects to the given ws:// URL and return a Request to be sent. +// /// +// /// If you would like to use a secure connection (wss://), please use `connect_secure`. +// /// +// /// A connection is established, however the request is not sent to +// /// the server until a call to ```send()```. +// pub fn connect(components: C) -> WebSocketResult> +// where C: ToWebSocketUrlComponents, +// { +// let (host, resource_name, secure) = try!(components.to_components()); +// let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(80))); +// let stream = try!(stream); +// Request::new((host, resource_name, secure), try!(stream.split())) +// } +// } - let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(443))); - let stream = try!(stream); - let sslstream = if let Some(c) = context { - SslStream::connect(c, stream) - } else { - let context = try!(SslContext::new(SslMethod::Tlsv1)); - SslStream::connect(&context, stream) - }; - let sslstream = try!(sslstream); +// impl Client> { +// /// Connects to the specified wss:// URL using the given SSL context. +// /// +// /// If you would like to use an insecure connection (ws://), please use `connect`. +// /// +// /// A connection is established, however the request is not sent to +// /// the server until a call to ```send()```. +// pub fn connect_secure(components: C, context: Option<&SslContext>) -> WebSocketResult, SslStream>> +// where C: ToWebSocketUrlComponents, +// { +// let (host, resource_name, secure) = try!(components.to_components()); - Request::new((host, resource_name, secure), try!(sslstream.split())) - } -} +// let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(443))); +// let stream = try!(stream); +// let sslstream = if let Some(c) = context { +// SslStream::connect(c, stream) +// } else { +// let context = try!(SslContext::new(SslMethod::Tlsv1)); +// SslStream::connect(&context, stream) +// }; +// let sslstream = try!(sslstream); -// TODO: look at how to get hyper to give you a stream then maybe remove this -impl Client>, Receiver>> { - pub fn connect_agnostic(components: C, ssl_context: Option<&SslContext>) -> WebSocketResult, Box>> - where C: ToWebSocketUrlComponents - { - let (host, resource_name, secure) = try!(components.to_components()); - let port = match host.port { - Some(p) => p, - None => if secure { - 443 - } else { - 80 - }, - }; - let tcp_stream = try!(TcpStream::connect((&host.hostname[..], port))); +// Request::new((host, resource_name, secure), try!(sslstream.split())) +// } +// } - let stream: Box = if secure { - if let Some(c) = ssl_context { - Box::new(try!(SslStream::connect(c, tcp_stream))) - } else { - let context = try!(SslContext::new(SslMethod::Tlsv1)); - Box::new(try!(SslStream::connect(&context, tcp_stream))) - } - } else { - Box::new(tcp_stream) - }; +// // TODO: look at how to get hyper to give you a stream then maybe remove this +// impl Client> { +// pub fn connect_agnostic(components: C, ssl_context: Option<&SslContext>) -> WebSocketResult, Box>> +// where C: ToWebSocketUrlComponents +// { +// let (host, resource_name, secure) = try!(components.to_components()); +// let port = match host.port { +// Some(p) => p, +// None => if secure { +// 443 +// } else { +// 80 +// }, +// }; +// let tcp_stream = try!(TcpStream::connect((&host.hostname[..], port))); - let (read, write) = (try!(stream.duplicate()), stream); - - Request::new((host, resource_name, secure), (read, write)) - } -} - -// TODO: add method to expose tcp to edit things -impl Client, Receiver> -where S: AsTcpStream, -{ - /// Shuts down the sending half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_sender(&self) -> IoResult<()> { - self.sender.shutdown() - } +// let stream: Box = if secure { +// if let Some(c) = ssl_context { +// Box::new(try!(SslStream::connect(c, tcp_stream))) +// } else { +// let context = try!(SslContext::new(SslMethod::Tlsv1)); +// Box::new(try!(SslStream::connect(&context, tcp_stream))) +// } +// } else { +// Box::new(tcp_stream) +// }; - /// Shuts down the receiving half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_receiver(&self) -> IoResult<()> { - self.receiver.shutdown() - } +// let (read, write) = (try!(stream.duplicate()), stream); - /// Shuts down the client connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.receiver.shutdown_all() - } -} +// Request::new((host, resource_name, secure), (read, write)) +// } +// } -impl> Client { - /// Creates a Client from the given Sender and Receiver. - /// - /// Essentially the opposite of `Client.split()`. - pub fn new(sender: S, receiver: R) -> Client { - Client { - sender: sender, - receiver: receiver, - _dataframe: PhantomData - } - } +// // TODO: add method to expose tcp to edit things +// impl Client +// where S: AsTcpStream + Stream, +// { +// /// Shuts down the sending half of the client connection, will cause all pending +// /// and future IO to return immediately with an appropriate value. +// pub fn shutdown_sender(&self) -> IoResult<()> { +// self.sender.shutdown() +// } - /// Sends a single data frame to the remote endpoint. - pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrameable { - self.sender.send_dataframe(dataframe) - } +// /// Shuts down the receiving half of the client connection, will cause all pending +// /// and future IO to return immediately with an appropriate value. +// pub fn shutdown_receiver(&self) -> IoResult<()> { +// self.receiver.shutdown() +// } - /// Sends a single message to the remote endpoint. - pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: ws::Message<'m, D>, D: DataFrameable { - self.sender.send_message(message) - } +// /// Shuts down the client connection, will cause all pending and future IO to +// /// return immediately with an appropriate value. +// pub fn shutdown(&self) -> IoResult<()> { +// self.receiver.shutdown_all() +// } +// } - /// Reads a single data frame from the remote endpoint. - pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe() - } +impl Client + where S: Stream, +{ + /// Creates a Client from the given Sender and Receiver. + /// + /// Essentially the opposite of `Client.split()`. + fn new(stream: S) -> Client { + Client { + stream: stream, + // TODO: always true? + sender: Sender::new(true), + // TODO: always false? + receiver: Receiver::new(false), + } + } - /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, R, F> { - self.receiver.incoming_dataframes() - } + /// Sends a single data frame to the remote endpoint. + pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + where D: DataFrameable { + self.sender.send_dataframe(self.stream.writer(), dataframe) + } - /// Reads a single message from this receiver. - pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult - where M: ws::Message<'m, F, DataFrameIterator = I>, I: Iterator { - self.receiver.recv_message() - } + /// Sends a single message to the remote endpoint. + pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> + where M: ws::Message<'m, D>, D: DataFrameable { + self.sender.send_message(self.stream.writer(), message) + } - /// Returns an iterator over incoming messages. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{Client, Message}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let mut client = response.begin(); // Get a Client - /// - ///for message in client.incoming_messages() { - /// let message: Message = message.unwrap(); - /// println!("Recv: {:?}", message); - ///} - ///# } - ///``` - /// - /// Note that since this method mutably borrows the `Client`, it may be necessary to - /// first `split()` the `Client` and call `incoming_messages()` on the returned - /// `Receiver` to be able to send messages within an iteration. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let client = response.begin(); // Get a Client - ///let (mut sender, mut receiver) = client.split(); // Split the Client - ///for message in receiver.incoming_messages() { - /// let message: Message = message.unwrap(); - /// // Echo the message back - /// sender.send_message(&message).unwrap(); - ///} - ///# } - ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, R, D, F, M> - where M: ws::Message<'a, D>, - D: DataFrameable - { - self.receiver.incoming_messages() - } + /// Reads a single data frame from the remote endpoint. + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(self.stream.reader()) + } - /// Returns a reference to the underlying Sender. - pub fn get_sender(&self) -> &S { - &self.sender - } + /// Returns an iterator over incoming data frames. + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S::Reader> { + self.receiver.incoming_dataframes(self.stream.reader()) + } - /// Returns a reference to the underlying Receiver. - pub fn get_receiver(&self) -> &R { - &self.receiver - } + /// Reads a single message from this receiver. + pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, I: Iterator { + self.receiver.recv_message(self.stream.reader()) + } - /// Returns a mutable reference to the underlying Sender. - pub fn get_mut_sender(&mut self) -> &mut S { - &mut self.sender - } + pub fn stream_ref(&self) -> &S { + &self.stream + } - /// Returns a mutable reference to the underlying Receiver. - pub fn get_mut_receiver(&mut self) -> &mut R { - &mut self.receiver - } + pub fn stream_ref_mut(&mut self) -> &mut S { + &mut self.stream + } - /// Split this client into its constituent Sender and Receiver pair. - /// - /// This allows the Sender and Receiver to be sent to different threads. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; - ///use std::thread; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let client = response.begin(); // Get a Client - /// - ///let (mut sender, mut receiver) = client.split(); - /// - ///thread::spawn(move || { - /// for message in receiver.incoming_messages() { - /// let message: Message = message.unwrap(); - /// println!("Recv: {:?}", message); - /// } - ///}); - /// - ///let message = Message::text("Hello, World!"); - ///sender.send_message(&message).unwrap(); - ///# } - ///``` - pub fn split(self) -> (S, R) { - (self.sender, self.receiver) - } + /// Returns an iterator over incoming messages. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{Client, Message}; + ///# use websocket::client::request::Url; + ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL + ///# let request = Client::connect(url).unwrap(); // Connect to the server + ///# let response = request.send().unwrap(); // Send the request + ///# response.validate().unwrap(); // Ensure the response is valid + /// + ///let mut client = response.begin(); // Get a Client + /// + ///for message in client.incoming_messages() { + /// let message: Message = message.unwrap(); + /// println!("Recv: {:?}", message); + ///} + ///# } + ///``` + /// + /// Note that since this method mutably borrows the `Client`, it may be necessary to + /// first `split()` the `Client` and call `incoming_messages()` on the returned + /// `Receiver` to be able to send messages within an iteration. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{Client, Message, Sender, Receiver}; + ///# use websocket::client::request::Url; + ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL + ///# let request = Client::connect(url).unwrap(); // Connect to the server + ///# let response = request.send().unwrap(); // Send the request + ///# response.validate().unwrap(); // Ensure the response is valid + /// + ///let client = response.begin(); // Get a Client + ///let (mut sender, mut receiver) = client.split(); // Split the Client + ///for message in receiver.incoming_messages() { + /// let message: Message = message.unwrap(); + /// // Echo the message back + /// sender.send_message(&message).unwrap(); + ///} + ///# } + ///``` + pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, DataFrame, M, S::Reader> + where M: ws::Message<'a, D>, + D: DataFrameable + { + self.receiver.incoming_messages(self.stream.reader()) + } } + +// TODO +// impl Client +// where S: Splittable, +// { +// /// Split this client into its constituent Sender and Receiver pair. +// /// +// /// This allows the Sender and Receiver to be sent to different threads. +// /// +// ///```no_run +// ///# extern crate websocket; +// ///# fn main() { +// ///use websocket::{Client, Message, Sender, Receiver}; +// ///use std::thread; +// ///# use websocket::client::request::Url; +// ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL +// ///# let request = Client::connect(url).unwrap(); // Connect to the server +// ///# let response = request.send().unwrap(); // Send the request +// ///# response.validate().unwrap(); // Ensure the response is valid +// /// +// ///let client = response.begin(); // Get a Client +// /// +// ///let (mut sender, mut receiver) = client.split(); +// /// +// ///thread::spawn(move || { +// /// for message in receiver.incoming_messages() { +// /// let message: Message = message.unwrap(); +// /// println!("Recv: {:?}", message); +// /// } +// ///}); +// /// +// ///let message = Message::text("Hello, World!"); +// ///sender.send_message(&message).unwrap(); +// ///# } +// ///``` +// pub fn split(self) -> (Reader, Writer) { +// let cloned_stream = try!(self.stream.try_clone()); +// } +// } diff --git a/src/client/response.rs b/src/client/response.rs index e60b3fa31c..7c0cdfaf6e 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -116,21 +116,24 @@ impl Response { Ok(()) } - /// Consume this response and return a Client ready to transmit/receive data frames - /// using the data frame type D, Sender B and Receiver C. - /// - /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - pub fn begin_with(self, sender: B, receiver: C) -> Client - where B: ws::Sender, C: ws::Receiver, D: DataFrameable { - Client::new(sender, receiver) - } - /// Consume this response and return a Client ready to transmit/receive data frames. - /// - /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - pub fn begin(self) -> Client, Receiver> { - let (reader, writer) = self.into_inner(); - let sender = Sender::new(writer, true); - let receiver = Receiver::new(reader, false); - Client::new(sender, receiver) - } + // TODO + // /// Consume this response and return a Client ready to transmit/receive data frames + // /// using the data frame type D, Sender B and Receiver C. + // /// + // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. + // pub fn begin_with(self, sender: B, receiver: C) -> Client + // where B: ws::Sender, C: ws::Receiver, D: DataFrameable { + // Client::new(sender, receiver) + // } + + // TODO + // /// Consume this response and return a Client ready to transmit/receive data frames. + // /// + // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. + // pub fn begin(self) -> Client, Receiver> { + // let (reader, writer) = self.into_inner(); + // let sender = Sender::new(writer, true); + // let receiver = Receiver::new(reader, false); + // Client::new(sender, receiver) + // } } diff --git a/src/receiver.rs b/src/receiver.rs index 6d1b1def53..fd8b06c985 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -7,9 +7,50 @@ use hyper::buffer::BufReader; use dataframe::{DataFrame, Opcode}; use result::{WebSocketResult, WebSocketError}; use ws; -use stream::AsTcpStream; +use stream::{ + AsTcpStream, + Stream, +}; pub use stream::Shutdown; +// TODO: buffer the readers +pub struct Reader + where R: Read +{ + reader: R, + receiver: Receiver, +} + +impl Reader + where R: Read, +{ + /// Returns a reference to the underlying Reader. + pub fn get_ref(&self) -> &R { + &self.reader + } + + /// Returns a mutable reference to the underlying Reader. + pub fn get_mut(&mut self) -> &mut R { + &mut self.reader + } +} + +impl Reader + where S: AsTcpStream + Stream + Read, +{ + /// Closes the receiver side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.reader.as_tcp().shutdown(Shutdown::Read) + } + + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.reader.as_tcp().shutdown(Shutdown::Both) + } +} + /// A Receiver that wraps a Reader and provides a default implementation using /// DataFrames and Messages. pub struct Receiver { @@ -25,43 +66,25 @@ impl Receiver { mask: mask, } } - - /// Returns a reference to the underlying Reader. - pub fn get_ref(&self) -> &BufReader { - &self.inner - } - - /// Returns a mutable reference to the underlying Reader. - pub fn get_mut(&mut self) -> &mut BufReader { - &mut self.inner - } } -impl Receiver -where S: AsTcpStream, -{ - /// Closes the receiver side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.inner.get_ref().as_tcp().shutdown(Shutdown::Read) - } - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&self) -> IoResult<()> { - self.inner.get_ref().as_tcp().shutdown(Shutdown::Both) - } -} +impl ws::Receiver for Receiver { + type F = DataFrame; -impl ws::Receiver for Receiver { /// Reads a single data frame from the remote endpoint. - fn recv_dataframe(&mut self) -> WebSocketResult { - DataFrame::read_dataframe(&mut self.inner, self.mask) + fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult + where R: Read, + { + DataFrame::read_dataframe(reader, self.mask) } + /// Returns the data frames that constitute one message. - fn recv_message_dataframes(&mut self) -> WebSocketResult> { + fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> + where R: Read, + { let mut finished = if self.buffer.is_empty() { - let first = try!(self.recv_dataframe()); + let first = try!(self.recv_dataframe(reader)); if first.opcode == Opcode::Continuation { return Err(WebSocketError::ProtocolError( @@ -78,7 +101,7 @@ impl ws::Receiver for Receiver { }; while !finished { - let next = try!(self.recv_dataframe()); + let next = try!(self.recv_dataframe(reader)); finished = next.finished; match next.opcode as u8 { diff --git a/src/sender.rs b/src/sender.rs index 30d6133507..a53d9899f8 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -8,51 +8,62 @@ use stream::AsTcpStream; use ws; pub use stream::Shutdown; +pub struct Writer { + writer: W, + sender: Sender, +} + +impl Writer + where W: Write, +{ + /// Returns a reference to the underlying Writer. + pub fn get_ref(&self) -> &W { + &self.writer + } + /// Returns a mutable reference to the underlying Writer. + pub fn get_mut(&mut self) -> &mut W { + &mut self.writer + } +} + +impl Writer + where S: AsTcpStream + Write, +{ + /// Closes the sender side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.writer.as_tcp().shutdown(Shutdown::Write) + } + + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.writer.as_tcp().shutdown(Shutdown::Both) + } +} + /// A Sender that wraps a Writer and provides a default implementation using /// DataFrames and Messages. -pub struct Sender { - inner: W, +pub struct Sender { mask: bool, } -impl Sender { +impl Sender { /// Create a new WebSocketSender using the specified Writer. - pub fn new(writer: W, mask: bool) -> Sender { + pub fn new(mask: bool) -> Sender { Sender { - inner: writer, mask: mask, } } - /// Returns a reference to the underlying Writer. - pub fn get_ref(&self) -> &W { - &self.inner - } - /// Returns a mutable reference to the underlying Writer. - pub fn get_mut(&mut self) -> &mut W { - &mut self.inner - } } -impl Sender -where S: AsTcpStream, -{ - /// Closes the sender side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.inner.as_tcp().shutdown(Shutdown::Write) - } - - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&self) -> IoResult<()> { - self.inner.as_tcp().shutdown(Shutdown::Both) - } -} -impl ws::Sender for Sender { +impl ws::Sender for Sender { /// Sends a single data frame to the remote endpoint. - fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame { - dataframe.write_to(&mut self.inner, self.mask) + fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write, + { + dataframe.write_to(writer, self.mask) } } diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index 5eb0ccd281..6bf35abf43 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -116,10 +116,8 @@ impl From for HyperIntoWsError { } // TODO: Move this into the main upgrade module -impl IntoWs for S -where S: Stream, - R: Read, - W: Write, +impl IntoWs for S +where S: Stream, { type Stream = S; type Error = (Self, Option, HyperIntoWsError); @@ -146,10 +144,8 @@ where S: Stream, } // TODO: Move this into the main upgrade module -impl IntoWs for RequestStreamPair -where S: Stream, - R: Read, - W: Write, +impl IntoWs for RequestStreamPair +where S: Stream, { type Stream = S; type Error = (S, Request, HyperIntoWsError); @@ -165,31 +161,32 @@ where S: Stream, } } -impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { - type Stream = Box; - type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); - - fn into_ws(self) -> Result, Self::Error> { - if let Err(e) = validate(&self.method, &self.version, &self.headers) { - return Err((self, e)); - } - - let stream: Option> = unimplemented!(); - - if let Some(s) = stream { - Ok(WsUpgrade { - stream: s, - request: Incoming { - version: self.version, - headers: self.headers, - subject: (self.method, self.uri), - }, - }) - } else { - Err((self, HyperIntoWsError::UnknownNetworkStream)) - } - } -} +// TODO +// impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { +// type Stream = Box; +// type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); + +// fn into_ws(self) -> Result, Self::Error> { +// if let Err(e) = validate(&self.method, &self.version, &self.headers) { +// return Err((self, e)); +// } + +// let stream: Option> = unimplemented!(); + +// if let Some(s) = stream { +// Ok(WsUpgrade { +// stream: s, +// request: Incoming { +// version: self.version, +// headers: self.headers, +// subject: (self.method, self.uri), +// }, +// }) +// } else { +// Err((self, HyperIntoWsError::UnknownNetworkStream)) +// } +// } +// } pub fn validate( method: &Method, diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index d616c74210..d905561fba 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -16,9 +16,7 @@ pub mod hyper; /// Users should then call `accept` or `deny` to complete the handshake /// and start a session. pub struct WsUpgrade -where S: Stream, - R: Read, - W: Write, +where S: Stream, { stream: S, request: hyper::Request, diff --git a/src/stream.rs b/src/stream.rs index f6c0e08130..cdcc91792c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,5 +1,6 @@ //! Provides the default stream type for WebSocket connections. -extern crate mio; +// TODO: add mio support & tokio +// extern crate mio; use std::io::{ self, @@ -12,11 +13,10 @@ pub use openssl::ssl::{ SslStream, SslContext, }; -pub use self::mio::Evented; pub trait Splittable -where R: Read + Evented, - W: Write + Evented, +where R: Read, + W: Write, { fn split(self) -> io::Result<(R, W)>; } @@ -24,34 +24,37 @@ where R: Read + Evented, /// Represents a stream that can be read from, and written to. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. -pub trait Stream -where R: Read + Evented, - W: Write + Evented, -{ +pub trait Stream { + type Reader: Read; + type Writer: Write; + /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut R; + fn reader(&mut self) -> &mut Self::Reader; /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut W; + fn writer(&mut self) -> &mut Self::Writer; } pub struct ReadWritePair(pub R, pub W) -where R: Read + Evented, - W: Write + Evented; +where R: Read, + W: Write; impl Splittable for ReadWritePair -where R: Read + Evented, - W: Write + Evented, +where R: Read, + W: Write, { fn split(self) -> io::Result<(R, W)> { Ok((self.0, self.1)) } } -impl Stream for ReadWritePair -where R: Read + Evented, - W: Write + Evented, +impl Stream for ReadWritePair +where R: Read, + W: Write, { + type Reader = R; + type Writer = W; + #[inline] fn reader(&mut self) -> &mut R { &mut self.0 @@ -69,9 +72,12 @@ impl Splittable for TcpStream { } } -impl Stream for S -where S: Read + Write + Evented, +impl Stream for S +where S: Read + Write, { + type Reader = Self; + type Writer = Self; + #[inline] fn reader(&mut self) -> &mut S { self diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index 84690909e8..2ee5b0a184 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -9,89 +9,106 @@ use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; +// TODO: maybe this is not needed anymore /// A trait for receiving data frames and messages. -pub trait Receiver: Sized -where F: DataFrame { - /// Reads a single data frame from this receiver. - fn recv_dataframe(&mut self) -> WebSocketResult; +pub trait Receiver: Sized +{ + type F: DataFrame; - /// Returns the data frames that constitute one message. - fn recv_message_dataframes(&mut self) -> WebSocketResult>; + /// Reads a single data frame from this receiver. + fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult + where R: Read; - /// Returns an iterator over incoming data frames. - fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Self, F> { - DataFrameIterator { - inner: self, - _dataframe: PhantomData - } - } + /// Returns the data frames that constitute one message. + fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> + where R: Read; - /// Reads a single message from this receiver. - fn recv_message<'m, D, M, I>(&mut self) -> WebSocketResult - where M: Message<'m, D, DataFrameIterator = I>, - I: Iterator, - D: DataFrame + /// Returns an iterator over incoming data frames. + fn incoming_dataframes<'a, R>(&'a mut self, reader: &'a mut R) -> DataFrameIterator<'a, Self, R> + where R: Read, { - let dataframes = try!(self.recv_message_dataframes()); - Message::from_dataframes(dataframes) - } + DataFrameIterator { + reader: reader, + inner: self, + } + } + + /// Reads a single message from this receiver. + fn recv_message<'m, D, M, I, R>(&mut self, reader: &mut R) -> WebSocketResult + where M: Message<'m, D, DataFrameIterator = I>, + I: Iterator, + D: DataFrame, + R: Read, + { + let dataframes = try!(self.recv_message_dataframes(reader)); + Message::from_dataframes(dataframes) + } - /// Returns an iterator over incoming messages. - fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Self, D, F, M> - where M: Message<'a, D>, D: DataFrame - { - MessageIterator { - inner: self, - _dataframe: PhantomData, - _receiver: PhantomData, - _message: PhantomData, - } - } + /// Returns an iterator over incoming messages. + fn incoming_messages<'a, M, D, R>(&'a mut self, reader: &'a mut R) -> MessageIterator<'a, Self, D, Self::F, M, R> + where M: Message<'a, D>, + D: DataFrame, + R: Read, + { + MessageIterator { + reader: reader, + inner: self, + _dataframe: PhantomData, + _receiver: PhantomData, + _message: PhantomData, + } + } } /// An iterator over data frames from a Receiver. -pub struct DataFrameIterator<'a, R, D> -where R: 'a + Receiver, D: DataFrame +pub struct DataFrameIterator<'a, Recv, R> + where Recv: 'a + Receiver, + R: 'a + Read, { - inner: &'a mut R, - _dataframe: PhantomData + reader: &'a mut R, + inner: &'a mut Recv, } -impl<'a, R, D> Iterator for DataFrameIterator<'a, R, D> -where R: 'a + Receiver, D: DataFrame { +impl<'a, Recv, R> Iterator for DataFrameIterator<'a, Recv, R> + where Recv: 'a + Receiver, + R: Read, +{ - type Item = WebSocketResult; + type Item = WebSocketResult; - /// Get the next data frame from the receiver. Always returns `Some`. - fn next(&mut self) -> Option> { - Some(self.inner.recv_dataframe()) - } + /// Get the next data frame from the receiver. Always returns `Some`. + fn next(&mut self) -> Option> { + Some(self.inner.recv_dataframe(self.reader)) + } } /// An iterator over messages from a Receiver. -pub struct MessageIterator<'a, R, D, F, M> -where R: 'a + Receiver, - M: Message<'a, D>, - D: DataFrame, - F: DataFrame, +pub struct MessageIterator<'a, Recv, D, F, M, R> + where Recv: 'a + Receiver, + M: Message<'a, D>, + D: DataFrame, + F: DataFrame, + R: 'a + Read, { - inner: &'a mut R, - _dataframe: PhantomData, - _message: PhantomData, + reader: &'a mut R, + inner: &'a mut Recv, + _dataframe: PhantomData, + _message: PhantomData, _receiver: PhantomData, } -impl<'a, R, D, F, M, I> Iterator for MessageIterator<'a, R, D, F, M> -where R: 'a + Receiver, - M: Message<'a, D, DataFrameIterator = I>, - I: Iterator, - D: DataFrame, - F: DataFrame, +impl<'a, Recv, D, F, M, I, R> Iterator for MessageIterator<'a, Recv, D, F, M, R> + where Recv: 'a + Receiver, + M: Message<'a, D, DataFrameIterator = I>, + I: Iterator, + D: DataFrame, + F: DataFrame, + R: Read, { - type Item = WebSocketResult; + type Item = WebSocketResult; - /// Get the next message from the receiver. Always returns `Some`. - fn next(&mut self) -> Option> { - Some(self.inner.recv_message()) - } + /// Get the next message from the receiver. Always returns `Some`. + fn next(&mut self) -> Option> { + Some(self.inner.recv_message(self.reader)) + } } diff --git a/src/ws/sender.rs b/src/ws/sender.rs index 6d8c7d8e37..0a76ce66eb 100644 --- a/src/ws/sender.rs +++ b/src/ws/sender.rs @@ -7,18 +7,22 @@ use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; +// TODO: maybe this is not needed anymore /// A trait for sending data frames and messages. pub trait Sender { /// Sends a single data frame using this sender. - fn send_dataframe(&mut self, writer: &mut Write, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame; + fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write; /// Sends a single message using this sender. - fn send_message<'m, M, D>(&mut self, writer: &mut Write, message: &'m M) -> WebSocketResult<()> - where M: Message<'m, D>, D: DataFrame + fn send_message<'m, M, D, W>(&mut self, writer: &mut W, message: &'m M) -> WebSocketResult<()> + where M: Message<'m, D>, + D: DataFrame, + W: Write, { for ref dataframe in message.dataframes() { - try!(self.send_dataframe(dataframe)); + try!(self.send_dataframe(writer, dataframe)); } Ok(()) } From f1595d8fdc5d5d76432ad818d643d20e6b21cef0 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Mon, 27 Mar 2017 19:33:07 -0400 Subject: [PATCH 14/32] added back splitting and shutting down client --- src/client/mod.rs | 144 +++++++++++++++++++++++++-------------------- src/receiver.rs | 47 +++++++++++---- src/sender.rs | 27 +++++---- src/stream.rs | 100 ++++++++++++++++--------------- src/ws/receiver.rs | 10 +--- 5 files changed, 190 insertions(+), 138 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index d3a6d1a529..0bae6be9a1 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -2,6 +2,10 @@ use std::net::TcpStream; use std::marker::PhantomData; use std::io::Result as IoResult; +use std::io::{ + Read, + Write, +}; use std::ops::Deref; use ws; @@ -16,6 +20,8 @@ use result::WebSocketResult; use stream::{ AsTcpStream, Stream, + Splittable, + Shutdown, }; use dataframe::DataFrame; use ws::dataframe::DataFrame as DataFrameable; @@ -25,8 +31,11 @@ use openssl::ssl::{SslContext, SslMethod, SslStream}; pub use self::request::Request; pub use self::response::Response; -pub use sender::Sender; -pub use receiver::Receiver; +use sender::Sender; +pub use sender::Writer; + +use receiver::Receiver; +pub use receiver::Reader; pub mod request; pub mod response; @@ -155,36 +164,37 @@ pub struct ClientBuilder // } // } -// // TODO: add method to expose tcp to edit things -// impl Client -// where S: AsTcpStream + Stream, -// { -// /// Shuts down the sending half of the client connection, will cause all pending -// /// and future IO to return immediately with an appropriate value. -// pub fn shutdown_sender(&self) -> IoResult<()> { -// self.sender.shutdown() -// } +impl Client + where S: AsTcpStream + Stream, +{ + /// Shuts down the sending half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_sender(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Write) + } -// /// Shuts down the receiving half of the client connection, will cause all pending -// /// and future IO to return immediately with an appropriate value. -// pub fn shutdown_receiver(&self) -> IoResult<()> { -// self.receiver.shutdown() -// } + /// Shuts down the receiving half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_receiver(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Read) + } -// /// Shuts down the client connection, will cause all pending and future IO to -// /// return immediately with an appropriate value. -// pub fn shutdown(&self) -> IoResult<()> { -// self.receiver.shutdown_all() -// } -// } + /// Shuts down the client connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Both) + } + + // TODO: add net2 set_nonblocking and stuff +} impl Client where S: Stream, { - /// Creates a Client from the given Sender and Receiver. + /// Crtes a Client from the given Sender and Receiver. /// - /// Essentially the opposite of `Client.split()`. - fn new(stream: S) -> Client { + /// Esstiallthe opposite of `Client.split()`. + fn new(stream: S) -> Self { Client { stream: stream, // TODO: always true? @@ -274,7 +284,7 @@ impl Client ///} ///# } ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, DataFrame, M, S::Reader> + pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S::Reader> where M: ws::Message<'a, D>, D: DataFrameable { @@ -282,41 +292,47 @@ impl Client } } -// TODO -// impl Client -// where S: Splittable, -// { -// /// Split this client into its constituent Sender and Receiver pair. -// /// -// /// This allows the Sender and Receiver to be sent to different threads. -// /// -// ///```no_run -// ///# extern crate websocket; -// ///# fn main() { -// ///use websocket::{Client, Message, Sender, Receiver}; -// ///use std::thread; -// ///# use websocket::client::request::Url; -// ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL -// ///# let request = Client::connect(url).unwrap(); // Connect to the server -// ///# let response = request.send().unwrap(); // Send the request -// ///# response.validate().unwrap(); // Ensure the response is valid -// /// -// ///let client = response.begin(); // Get a Client -// /// -// ///let (mut sender, mut receiver) = client.split(); -// /// -// ///thread::spawn(move || { -// /// for message in receiver.incoming_messages() { -// /// let message: Message = message.unwrap(); -// /// println!("Recv: {:?}", message); -// /// } -// ///}); -// /// -// ///let message = Message::text("Hello, World!"); -// ///sender.send_message(&message).unwrap(); -// ///# } -// ///``` -// pub fn split(self) -> (Reader, Writer) { -// let cloned_stream = try!(self.stream.try_clone()); -// } -// } +impl Client + where S: Splittable + Stream, +{ + /// Split this client into its constituent Sender and Receiver pair. + /// + /// This allows the Sender and Receiver to be sent to different threads. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{Client, Message, Sender, Receiver}; + ///use std::thread; + ///# use websocket::client::request::Url; + ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL + ///# let request = Client::connect(url).unwrap(); // Connect to the server + ///# let response = request.send().unwrap(); // Send the request + ///# response.validate().unwrap(); // Ensure the response is valid + /// + ///let client = response.begin(); // Get a Client + /// + ///let (mut sender, mut receiver) = client.split(); + /// + ///thread::spawn(move || { + /// for message in receiver.incoming_messages() { + /// let message: Message = message.unwrap(); + /// println!("Recv: {:?}", message); + /// } + ///}); + /// + ///let message = Message::text("Hello, World!"); + ///sender.send_message(&message).unwrap(); + ///# } + ///``` + pub fn split(self) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { + let (read, write) = try!(self.stream.split()); + Ok((Reader { + reader: read, + receiver: self.receiver, + }, Writer { + writer: write, + sender: self.sender, + })) + } +} diff --git a/src/receiver.rs b/src/receiver.rs index fd8b06c985..16d757f4c1 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -4,9 +4,21 @@ use std::io::Read; use std::io::Result as IoResult; use hyper::buffer::BufReader; -use dataframe::{DataFrame, Opcode}; -use result::{WebSocketResult, WebSocketError}; +use dataframe::{ + DataFrame, + Opcode +}; +use result::{ + WebSocketResult, + WebSocketError +}; use ws; +use ws::dataframe::DataFrame as DataFrameable; +use ws::receiver::Receiver as ReceiverTrait; +use ws::receiver::{ + MessageIterator, + DataFrameIterator, +}; use stream::{ AsTcpStream, Stream, @@ -17,21 +29,36 @@ pub use stream::Shutdown; pub struct Reader where R: Read { - reader: R, - receiver: Receiver, + pub reader: R, + pub receiver: Receiver, } impl Reader where R: Read, { - /// Returns a reference to the underlying Reader. - pub fn get_ref(&self) -> &R { - &self.reader + /// Reads a single data frame from the remote endpoint. + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(&mut self.reader) + } + + /// Returns an iterator over incoming data frames. + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> { + self.receiver.incoming_dataframes(&mut self.reader) } - /// Returns a mutable reference to the underlying Reader. - pub fn get_mut(&mut self) -> &mut R { - &mut self.reader + /// Reads a single message from this receiver. + pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, + I: Iterator + { + self.receiver.recv_message(&mut self.reader) + } + + pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R> + where M: ws::Message<'a, D>, + D: DataFrameable + { + self.receiver.incoming_messages(&mut self.reader) } } diff --git a/src/sender.rs b/src/sender.rs index a53d9899f8..fafec5d982 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -6,23 +6,31 @@ use result::WebSocketResult; use ws::dataframe::DataFrame; use stream::AsTcpStream; use ws; +use ws::sender::Sender as SenderTrait; pub use stream::Shutdown; pub struct Writer { - writer: W, - sender: Sender, + pub writer: W, + pub sender: Sender, } impl Writer where W: Write, { - /// Returns a reference to the underlying Writer. - pub fn get_ref(&self) -> &W { - &self.writer - } - /// Returns a mutable reference to the underlying Writer. - pub fn get_mut(&mut self) -> &mut W { - &mut self.writer + /// Sends a single data frame to the remote endpoint. + fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write, + { + self.sender.send_dataframe(&mut self.writer, dataframe) + } + + /// Sends a single message to the remote endpoint. + pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> + where M: ws::Message<'m, D>, + D: DataFrame + { + self.sender.send_message(&mut self.writer, message) } } @@ -57,7 +65,6 @@ impl Sender { } } - impl ws::Sender for Sender { /// Sends a single data frame to the remote endpoint. fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> diff --git a/src/stream.rs b/src/stream.rs index cdcc91792c..b0f0eeafa6 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -3,90 +3,96 @@ // extern crate mio; use std::io::{ - self, - Read, - Write + self, + Read, + Write }; pub use std::net::TcpStream; pub use std::net::Shutdown; pub use openssl::ssl::{ - SslStream, - SslContext, + SslStream, + SslContext, }; -pub trait Splittable -where R: Read, - W: Write, -{ - fn split(self) -> io::Result<(R, W)>; +pub trait Splittable { + type Reader: Read; + type Writer: Write; + + fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; } /// Represents a stream that can be read from, and written to. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. pub trait Stream { - type Reader: Read; - type Writer: Write; + type Reader: Read; + type Writer: Write; - /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut Self::Reader; + /// Get a mutable borrow to the reading component of this stream + fn reader(&mut self) -> &mut Self::Reader; - /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut Self::Writer; + /// Get a mutable borrow to the writing component of this stream + fn writer(&mut self) -> &mut Self::Writer; } pub struct ReadWritePair(pub R, pub W) -where R: Read, - W: Write; + where R: Read, + W: Write; -impl Splittable for ReadWritePair -where R: Read, - W: Write, +impl Splittable for ReadWritePair + where R: Read, + W: Write, { - fn split(self) -> io::Result<(R, W)> { - Ok((self.0, self.1)) - } + type Reader = R; + type Writer = W; + + fn split(self) -> io::Result<(R, W)> { + Ok((self.0, self.1)) + } } impl Stream for ReadWritePair -where R: Read, - W: Write, + where R: Read, + W: Write, { type Reader = R; type Writer = W; - #[inline] - fn reader(&mut self) -> &mut R { - &mut self.0 - } + #[inline] + fn reader(&mut self) -> &mut R { + &mut self.0 + } - #[inline] - fn writer(&mut self) -> &mut W { - &mut self.1 - } + #[inline] + fn writer(&mut self) -> &mut W { + &mut self.1 + } } -impl Splittable for TcpStream { - fn split(self) -> io::Result<(TcpStream, TcpStream)> { - self.try_clone().map(|s| (s, self)) - } +impl Splittable for TcpStream { + type Reader = TcpStream; + type Writer = TcpStream; + + fn split(self) -> io::Result<(TcpStream, TcpStream)> { + self.try_clone().map(|s| (s, self)) + } } impl Stream for S -where S: Read + Write, + where S: Read + Write, { type Reader = Self; type Writer = Self; - #[inline] - fn reader(&mut self) -> &mut S { - self - } + #[inline] + fn reader(&mut self) -> &mut S { + self + } - #[inline] - fn writer(&mut self) -> &mut S { - self - } + #[inline] + fn writer(&mut self) -> &mut S { + self + } } pub trait AsTcpStream { diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index 2ee5b0a184..bea0c1d962 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -45,7 +45,7 @@ pub trait Receiver: Sized } /// Returns an iterator over incoming messages. - fn incoming_messages<'a, M, D, R>(&'a mut self, reader: &'a mut R) -> MessageIterator<'a, Self, D, Self::F, M, R> + fn incoming_messages<'a, M, D, R>(&'a mut self, reader: &'a mut R) -> MessageIterator<'a, Self, D, M, R> where M: Message<'a, D>, D: DataFrame, R: Read, @@ -54,7 +54,6 @@ pub trait Receiver: Sized reader: reader, inner: self, _dataframe: PhantomData, - _receiver: PhantomData, _message: PhantomData, } } @@ -83,26 +82,23 @@ impl<'a, Recv, R> Iterator for DataFrameIterator<'a, Recv, R> } /// An iterator over messages from a Receiver. -pub struct MessageIterator<'a, Recv, D, F, M, R> +pub struct MessageIterator<'a, Recv, D, M, R> where Recv: 'a + Receiver, M: Message<'a, D>, D: DataFrame, - F: DataFrame, R: 'a + Read, { reader: &'a mut R, inner: &'a mut Recv, _dataframe: PhantomData, _message: PhantomData, - _receiver: PhantomData, } -impl<'a, Recv, D, F, M, I, R> Iterator for MessageIterator<'a, Recv, D, F, M, R> +impl<'a, Recv, D, M, I, R> Iterator for MessageIterator<'a, Recv, D, M, R> where Recv: 'a + Receiver, M: Message<'a, D, DataFrameIterator = I>, I: Iterator, D: DataFrame, - F: DataFrame, R: Read, { type Item = WebSocketResult; From f03c0071087d18323d6419c973271064a9f83665 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 01:04:21 -0400 Subject: [PATCH 15/32] started working on builder --- src/client/mod.rs | 318 ++++++++++++++++++++++++++++++++++------- src/client/request.rs | 232 +++++++++++++++--------------- src/client/response.rs | 198 ++++++++++++------------- 3 files changed, 480 insertions(+), 268 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 0bae6be9a1..91b1a2890b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,4 +1,7 @@ //! Contains the WebSocket client. +extern crate url; + +use std::borrow::Borrow; use std::net::TcpStream; use std::marker::PhantomData; use std::io::Result as IoResult; @@ -8,6 +11,27 @@ use std::io::{ }; use std::ops::Deref; +use self::url::{ + Url, + Position, +}; +use openssl::ssl::{ + SslContext, + SslMethod, + SslStream, +}; +use hyper::version::HttpVersion; +use hyper::header::{ + Headers, + Host, + Connection, + ConnectionOption, + Upgrade, + Protocol, + ProtocolName, +}; +use unicase::UniCase; + use ws; use ws::sender::Sender as SenderTrait; use ws::util::url::ToWebSocketUrlComponents; @@ -16,6 +40,14 @@ use ws::receiver::{ MessageIterator, }; use ws::receiver::Receiver as ReceiverTrait; +use header::extensions::Extension; +use header::{ + WebSocketKey, + WebSocketVersion, + WebSocketProtocol, + WebSocketExtensions, + Origin +}; use result::WebSocketResult; use stream::{ AsTcpStream, @@ -24,68 +56,201 @@ use stream::{ Shutdown, }; use dataframe::DataFrame; -use ws::dataframe::DataFrame as DataFrameable; - -use openssl::ssl::{SslContext, SslMethod, SslStream}; +use ws::dataframe::DataFrame as DataFrameable; +use sender::Sender; +use receiver::Receiver; pub use self::request::Request; pub use self::response::Response; - -use sender::Sender; pub use sender::Writer; - -use receiver::Receiver; pub use receiver::Reader; pub mod request; pub mod response; -/// Represents a WebSocket client, which can send and receive messages/data frames. -/// -/// `D` is the data frame type, `S` is the type implementing `Sender` and `R` -/// is the type implementing `Receiver`. -/// -/// For most cases, the data frame type will be `dataframe::DataFrame`, the Sender -/// type will be `client::Sender` and the receiver type -/// will be `client::Receiver`. -/// -/// A `Client` can be split into a `Sender` and a `Receiver` which can then be moved -/// to different threads, often using a send loop and receiver loop concurrently, -/// as shown in the client example in `examples/client.rs`. -/// -///#Connecting to a Server -/// -///```no_run -///extern crate websocket; -///# fn main() { -/// -///use websocket::{Client, Message}; -///use websocket::client::request::Url; -/// -///let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL -///let request = Client::connect(url).unwrap(); // Connect to the server -///let response = request.send().unwrap(); // Send the request -///response.validate().unwrap(); // Ensure the response is valid -/// -///let mut client = response.begin(); // Get a Client -/// -///let message = Message::text("Hello, World!"); -///client.send_message(&message).unwrap(); // Send message -///# } -///``` -pub struct Client - where S: Stream, -{ - stream: S, - sender: Sender, - receiver: Receiver, +// TODO: if using URL, remove the to url components trait +// TODO: implement ToOwned +/// Build clients with a builder-style API +pub struct ClientBuilder<'u, 's> { + url: &'u Url, + version: HttpVersion, + headers: Headers, + version_set: bool, + key_set: bool, + ssl_context: Option<&'s SslContext>, } -// TODO: add back client.split() -pub struct ClientBuilder - where S: Stream, -{ - stream: S, +macro_rules! upsert_header { + ($headers:expr; $header:ty; { + Some($pat:pat) => $some_match:expr, + None => $default:expr + }) => {{ + match $headers.has::<$header>() { + true => { + match $headers.get_mut::<$header>() { + Some($pat) => { $some_match; }, + None => (), + }; + } + false => { + $headers.set($default); + }, + }; + }} +} + +impl<'u, 's> ClientBuilder<'u, 's> { + pub fn new(url: &'u Url) -> Self { + ClientBuilder { + url: url, + version: HttpVersion::Http11, + version_set: false, + key_set: false, + ssl_context: None, + headers: Headers::new(), + } + } + + pub fn add_protocol

(mut self, protocol: P) -> Self + where P: Into, + { + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.push(protocol.into()), + None => WebSocketProtocol(vec![protocol.into()]) + }); + self + } + + pub fn add_protocols(mut self, protocols: I) -> Self + where I: IntoIterator, + S: Into, + { + let mut protocols: Vec = protocols.into_iter() + .map(Into::into).collect(); + + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.append(&mut protocols), + None => WebSocketProtocol(protocols) + }); + self + } + + pub fn clear_protocols(mut self) -> Self { + self.headers.remove::(); + self + } + + pub fn add_extension(mut self, extension: Extension) -> Self + { + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.push(extension), + None => WebSocketExtensions(vec![extension]) + }); + self + } + + pub fn add_extensions(mut self, extensions: I) -> Self + where I: IntoIterator, + { + let mut extensions: Vec = extensions.into_iter().collect(); + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.append(&mut extensions), + None => WebSocketExtensions(extensions) + }); + self + } + + pub fn clear_extensions(mut self) -> Self { + self.headers.remove::(); + self + } + + pub fn key(mut self, key: [u8; 16]) -> Self { + self.headers.set(WebSocketKey(key)); + self.key_set = true; + self + } + + pub fn clear_key(mut self) -> Self { + self.headers.remove::(); + self.key_set = false; + self + } + + pub fn version(mut self, version: WebSocketVersion) -> Self { + self.headers.set(version); + self.version_set = true; + self + } + + pub fn clear_version(mut self) -> Self { + self.headers.remove::(); + self.version_set = false; + self + } + + pub fn origin(mut self, origin: String) -> Self { + self.headers.set(Origin(origin)); + self + } + + pub fn custom_headers(mut self, edit: F) -> Self + where F: Fn(&mut Headers), + { + edit(&mut self.headers); + self + } + + pub fn ssl_context(mut self, context: &'s SslContext) -> Self { + self.ssl_context = Some(context); + self + } + + pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> + where S: Stream, + { + // Get info about ports + let is_ssl = self.url.scheme() == "wss"; + let port = match self.url.port() { + Some(port) => port, + None if is_ssl => 443, + None => 80, + }; + let host = match self.url.host_str() { + Some(h) => h, + None => unimplemented!(), + }; + let resource = self.url[Position::BeforePath..Position::AfterQuery] + .to_owned(); + + // set host header and other headers + self.headers.set(Host { + hostname: host.to_string(), + port: self.url.port(), + }); + + self.headers.set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); + + self.headers.set(Upgrade(vec![Protocol { + name: ProtocolName::WebSocket, + // TODO: actually correct or just works? + version: None + }])); + + if !self.version_set { + self.headers.set(WebSocketVersion::WebSocket13); + } + + if !self.key_set { + self.headers.set(WebSocketKey::new()); + } + + try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); + try!(write!(stream.writer(), "{}\r\n", self.headers)); + unimplemented!(); + } } // impl Client { @@ -164,6 +329,48 @@ pub struct ClientBuilder // } // } +/// Represents a WebSocket client, which can send and receive messages/data frames. +/// +/// `D` is the data frame type, `S` is the type implementing `Sender` and `R` +/// is the type implementing `Receiver`. +/// +/// For most cases, the data frame type will be `dataframe::DataFrame`, the Sender +/// type will be `client::Sender` and the receiver type +/// will be `client::Receiver`. +/// +/// A `Client` can be split into a `Sender` and a `Receiver` which can then be moved +/// to different threads, often using a send loop and receiver loop concurrently, +/// as shown in the client example in `examples/client.rs`. +/// +///#Connecting to a Server +/// +///```no_run +///extern crate websocket; +///# fn main() { +/// +///use websocket::{Client, Message}; +///use websocket::client::request::Url; +/// +///let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL +///let request = Client::connect(url).unwrap(); // Connect to the server +///let response = request.send().unwrap(); // Send the request +///response.validate().unwrap(); // Ensure the response is valid +/// +///let mut client = response.begin(); // Get a Client +/// +///let message = Message::text("Hello, World!"); +///client.send_message(&message).unwrap(); // Send message +///# } +///``` +pub struct Client + where S: Stream, +{ + stream: S, + sender: Sender, + receiver: Receiver, +} + +// TODO: maybe make shutdown options only for TcpStream? how does it work with SslStream? impl Client where S: AsTcpStream + Stream, { @@ -188,10 +395,15 @@ impl Client // TODO: add net2 set_nonblocking and stuff } -impl Client +impl<'u, 'p, 'e, 's, S> Client where S: Stream, { - /// Crtes a Client from the given Sender and Receiver. + + pub fn build(address: &'u Url) -> ClientBuilder<'u, 's> { + ClientBuilder::new(address) + } + + /// Creates a Client from the given Sender and Receiver. /// /// Esstiallthe opposite of `Client.split()`. fn new(stream: S) -> Self { diff --git a/src/client/request.rs b/src/client/request.rs index 43f502cc67..92224d23b9 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -22,126 +22,126 @@ use ws::util::url::ToWebSocketUrlComponents; pub struct Request { /// The HTTP version of this request. pub version: HttpVersion, - /// The headers of this request. - pub headers: Headers, - + /// The headers of this request. + pub headers: Headers, + resource_name: String, - reader: BufReader, - writer: W, + reader: BufReader, + writer: W, } unsafe impl Send for Request where R: Read + Send, W: Write + Send { } impl Request { - /// Creates a new client-side request. - /// - /// In general `Client::connect()` should be used for connecting to servers. - /// However, if the request is to be written to a different Writer, this function - /// may be used. - pub fn new(components: T, stream: (R, W)) -> WebSocketResult> - where T: ToWebSocketUrlComponents, - { - let (reader, writer) = stream; - let mut headers = Headers::new(); - let (host, resource_name, _) = try!(components.to_components()); - headers.set(host); - headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); - headers.set(Upgrade(vec![Protocol{ - name: ProtocolName::WebSocket, - version: None - }])); - headers.set(WebSocketVersion::WebSocket13); - headers.set(WebSocketKey::new()); - - Ok(Request { - version: HttpVersion::Http11, - headers: headers, - resource_name: resource_name, - reader: BufReader::new(reader), - writer: writer - }) - } - /// Short-cut to obtain the WebSocketKey value. - pub fn key(&self) -> Option<&WebSocketKey> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketVersion value. - pub fn version(&self) -> Option<&WebSocketVersion> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Short-cut to obtain the Origin value. - pub fn origin(&self) -> Option<&Origin> { - self.headers.get() - } - /// Short-cut to obtain a mutable reference to the WebSocketKey value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn key_mut(&mut self) -> Option<&mut WebSocketKey> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketVersion value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn version_mut(&mut self) -> Option<&mut WebSocketVersion> { - self.headers.get_mut() - } - /// Short-cut to obtaina mutable reference to the WebSocketProtocol value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn protocol_mut(&mut self) -> Option<&mut WebSocketProtocol> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketExtensions value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn extensions_mut(&mut self) -> Option<&mut WebSocketExtensions> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the Origin value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn origin_mut(&mut self) -> Option<&mut Origin> { - self.headers.get_mut() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &BufReader { - &self.reader - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - &self.writer - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut BufReader { - &mut self.reader - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - &mut self.writer - } - /// Return the inner Reader and Writer. - pub fn into_inner(self) -> (BufReader, W) { - (self.reader, self.writer) - } - /// Sends the request to the server and returns a response. - pub fn send(mut self) -> WebSocketResult> { - try!(write!(&mut self.writer, "GET {} {}\r\n", self.resource_name, self.version)); - try!(write!(&mut self.writer, "{}\r\n", self.headers)); - Response::read(self) - } + /// Creates a new client-side request. + /// + /// In general `Client::connect()` should be used for connecting to servers. + /// However, if the request is to be written to a different Writer, this function + /// may be used. + pub fn new(components: T, stream: (R, W)) -> WebSocketResult> + where T: ToWebSocketUrlComponents, + { + let (reader, writer) = stream; + let mut headers = Headers::new(); + let (host, resource_name, _) = try!(components.to_components()); + headers.set(host); + headers.set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); + headers.set(Upgrade(vec![Protocol{ + name: ProtocolName::WebSocket, + version: None + }])); + headers.set(WebSocketVersion::WebSocket13); + headers.set(WebSocketKey::new()); + + Ok(Request { + version: HttpVersion::Http11, + headers: headers, + resource_name: resource_name, + reader: BufReader::new(reader), + writer: writer + }) + } + /// Short-cut to obtain the WebSocketKey value. + pub fn key(&self) -> Option<&WebSocketKey> { + self.headers.get() + } + /// Short-cut to obtain the WebSocketVersion value. + pub fn version(&self) -> Option<&WebSocketVersion> { + self.headers.get() + } + /// Short-cut to obtain the WebSocketProtocol value. + pub fn protocol(&self) -> Option<&WebSocketProtocol> { + self.headers.get() + } + /// Short-cut to obtain the WebSocketExtensions value. + pub fn extensions(&self) -> Option<&WebSocketExtensions> { + self.headers.get() + } + /// Short-cut to obtain the Origin value. + pub fn origin(&self) -> Option<&Origin> { + self.headers.get() + } + /// Short-cut to obtain a mutable reference to the WebSocketKey value. + /// + /// Note that to add a header that does not already exist, ```Request.headers.set()``` + /// must be used. + pub fn key_mut(&mut self) -> Option<&mut WebSocketKey> { + self.headers.get_mut() + } + /// Short-cut to obtain a mutable reference to the WebSocketVersion value. + /// + /// Note that to add a header that does not already exist, ```Request.headers.set()``` + /// must be used. + pub fn version_mut(&mut self) -> Option<&mut WebSocketVersion> { + self.headers.get_mut() + } + /// Short-cut to obtaina mutable reference to the WebSocketProtocol value. + /// + /// Note that to add a header that does not already exist, ```Request.headers.set()``` + /// must be used. + pub fn protocol_mut(&mut self) -> Option<&mut WebSocketProtocol> { + self.headers.get_mut() + } + /// Short-cut to obtain a mutable reference to the WebSocketExtensions value. + /// + /// Note that to add a header that does not already exist, ```Request.headers.set()``` + /// must be used. + pub fn extensions_mut(&mut self) -> Option<&mut WebSocketExtensions> { + self.headers.get_mut() + } + /// Short-cut to obtain a mutable reference to the Origin value. + /// + /// Note that to add a header that does not already exist, ```Request.headers.set()``` + /// must be used. + pub fn origin_mut(&mut self) -> Option<&mut Origin> { + self.headers.get_mut() + } + /// Returns a reference to the inner Reader. + pub fn get_reader(&self) -> &BufReader { + &self.reader + } + /// Returns a reference to the inner Writer. + pub fn get_writer(&self) -> &W { + &self.writer + } + /// Returns a mutable reference to the inner Reader. + pub fn get_mut_reader(&mut self) -> &mut BufReader { + &mut self.reader + } + /// Returns a mutable reference to the inner Writer. + pub fn get_mut_writer(&mut self) -> &mut W { + &mut self.writer + } + /// Return the inner Reader and Writer. + pub fn into_inner(self) -> (BufReader, W) { + (self.reader, self.writer) + } + /// Sends the request to the server and returns a response. + pub fn send(mut self) -> WebSocketResult> { + try!(write!(&mut self.writer, "GET {} {}\r\n", self.resource_name, self.version)); + try!(write!(&mut self.writer, "{}\r\n", self.headers)); + Response::read(self) + } } diff --git a/src/client/response.rs b/src/client/response.rs index 7c0cdfaf6e..8fbce9f3d8 100644 --- a/src/client/response.rs +++ b/src/client/response.rs @@ -22,118 +22,118 @@ use ws; /// Represents a WebSocket response. pub struct Response { - /// The status of the response - pub status: StatusCode, - /// The headers contained in this response - pub headers: Headers, - /// The HTTP version of this response - pub version: HttpVersion, + /// The status of the response + pub status: StatusCode, + /// The headers contained in this response + pub headers: Headers, + /// The HTTP version of this response + pub version: HttpVersion, - request: Request + request: Request } unsafe impl Send for Response where R: Read + Send, W: Write + Send { } impl Response { - /// Reads a Response off the stream associated with a Request. - /// - /// This is called by Request.send(), and does not need to be called by the user. - pub fn read(mut request: Request) -> WebSocketResult> { - let (status, version, headers) = { - let reader = request.get_mut_reader(); + /// Reads a Response off the stream associated with a Request. + /// + /// This is called by Request.send(), and does not need to be called by the user. + pub fn read(mut request: Request) -> WebSocketResult> { + let (status, version, headers) = { + let reader = request.get_mut_reader(); - let response = try!(parse_response(reader)); + let response = try!(parse_response(reader)); - let status = StatusCode::from_u16(response.subject.0); - (status, response.version, response.headers) - }; + let status = StatusCode::from_u16(response.subject.0); + (status, response.version, response.headers) + }; - Ok(Response { - status: status, - headers: headers, - version: version, - request: request - }) - } + Ok(Response { + status: status, + headers: headers, + version: version, + request: request + }) + } - /// Short-cut to obtain the WebSocketAccept value. - pub fn accept(&self) -> Option<&WebSocketAccept> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } + /// Short-cut to obtain the WebSocketAccept value. + pub fn accept(&self) -> Option<&WebSocketAccept> { + self.headers.get() + } + /// Short-cut to obtain the WebSocketProtocol value. + pub fn protocol(&self) -> Option<&WebSocketProtocol> { + self.headers.get() + } + /// Short-cut to obtain the WebSocketExtensions value. + pub fn extensions(&self) -> Option<&WebSocketExtensions> { + self.headers.get() + } /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &BufReader { - self.request.get_reader() - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - self.request.get_writer() - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut BufReader { - self.request.get_mut_reader() - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - self.request.get_mut_writer() - } - /// Returns a reference to the request associated with this response. - pub fn get_request(&self) -> &Request { - &self.request - } - /// Return the inner Reader and Writer. - pub fn into_inner(self) -> (BufReader, W) { - self.request.into_inner() - } + pub fn get_reader(&self) -> &BufReader { + self.request.get_reader() + } + /// Returns a reference to the inner Writer. + pub fn get_writer(&self) -> &W { + self.request.get_writer() + } + /// Returns a mutable reference to the inner Reader. + pub fn get_mut_reader(&mut self) -> &mut BufReader { + self.request.get_mut_reader() + } + /// Returns a mutable reference to the inner Writer. + pub fn get_mut_writer(&mut self) -> &mut W { + self.request.get_mut_writer() + } + /// Returns a reference to the request associated with this response. + pub fn get_request(&self) -> &Request { + &self.request + } + /// Return the inner Reader and Writer. + pub fn into_inner(self) -> (BufReader, W) { + self.request.into_inner() + } - /// Check if this response constitutes a successful handshake. - pub fn validate(&self) -> WebSocketResult<()> { - if self.status != StatusCode::SwitchingProtocols { - return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); - } - let key = try!(self.request.key().ok_or( - WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") - )); - if self.accept() != Some(&(WebSocketAccept::new(key))) { - return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); - } - if self.headers.get() != Some(&(Upgrade(vec![Protocol{ - name: ProtocolName::WebSocket, - version: None - }]))) { - return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); - } - if self.headers.get() != Some(&(Connection(vec![ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()))]))) { - return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); - } - Ok(()) - } + /// Check if this response constitutes a successful handshake. + pub fn validate(&self) -> WebSocketResult<()> { + if self.status != StatusCode::SwitchingProtocols { + return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); + } + let key = try!(self.request.key().ok_or( + WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") + )); + if self.accept() != Some(&(WebSocketAccept::new(key))) { + return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); + } + if self.headers.get() != Some(&(Upgrade(vec![Protocol{ + name: ProtocolName::WebSocket, + version: None + }]))) { + return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); + } + if self.headers.get() != Some(&(Connection(vec![ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()))]))) { + return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); + } + Ok(()) + } // TODO - // /// Consume this response and return a Client ready to transmit/receive data frames - // /// using the data frame type D, Sender B and Receiver C. - // /// - // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - // pub fn begin_with(self, sender: B, receiver: C) -> Client - // where B: ws::Sender, C: ws::Receiver, D: DataFrameable { - // Client::new(sender, receiver) - // } + // /// Consume this response and return a Client ready to transmit/receive data frames + // /// using the data frame type D, Sender B and Receiver C. + // /// + // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. + // pub fn begin_with(self, sender: B, receiver: C) -> Client + // where B: ws::Sender, C: ws::Receiver, D: DataFrameable { + // Client::new(sender, receiver) + // } - // TODO - // /// Consume this response and return a Client ready to transmit/receive data frames. - // /// - // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - // pub fn begin(self) -> Client, Receiver> { - // let (reader, writer) = self.into_inner(); - // let sender = Sender::new(writer, true); - // let receiver = Receiver::new(reader, false); - // Client::new(sender, receiver) - // } + // TODO + // /// Consume this response and return a Client ready to transmit/receive data frames. + // /// + // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. + // pub fn begin(self) -> Client, Receiver> { + // let (reader, writer) = self.into_inner(); + // let sender = Sender::new(writer, true); + // let receiver = Receiver::new(reader, false); + // Client::new(sender, receiver) + // } } From ae9b002b63326056be1460b29d72ac519170b7c4 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 10:54:03 -0400 Subject: [PATCH 16/32] finished builder pattern and removed unused modules --- src/client/mod.rs | 538 +++++++++++++++++++++-------------------- src/client/request.rs | 147 ----------- src/client/response.rs | 139 ----------- src/result.rs | 5 +- src/stream.rs | 55 +++++ src/ws/receiver.rs | 1 - src/ws/sender.rs | 1 - src/ws/util/mod.rs | 1 - src/ws/util/url.rs | 343 -------------------------- 9 files changed, 335 insertions(+), 895 deletions(-) delete mode 100644 src/client/request.rs delete mode 100644 src/client/response.rs delete mode 100644 src/ws/util/url.rs diff --git a/src/client/mod.rs b/src/client/mod.rs index 91b1a2890b..180a56ffb8 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,7 +1,10 @@ //! Contains the WebSocket client. extern crate url; -use std::borrow::Borrow; +use std::borrow::{ + Borrow, + Cow, +}; use std::net::TcpStream; use std::marker::PhantomData; use std::io::Result as IoResult; @@ -13,6 +16,7 @@ use std::ops::Deref; use self::url::{ Url, + ParseError, Position, }; use openssl::ssl::{ @@ -20,6 +24,10 @@ use openssl::ssl::{ SslMethod, SslStream, }; +use openssl::ssl::error::SslError; +use hyper::buffer::BufReader; +use hyper::status::StatusCode; +use hyper::http::h1::parse_response; use hyper::version::HttpVersion; use hyper::header::{ Headers, @@ -34,7 +42,6 @@ use unicase::UniCase; use ws; use ws::sender::Sender as SenderTrait; -use ws::util::url::ToWebSocketUrlComponents; use ws::receiver::{ DataFrameIterator, MessageIterator, @@ -42,16 +49,22 @@ use ws::receiver::{ use ws::receiver::Receiver as ReceiverTrait; use header::extensions::Extension; use header::{ + WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin }; -use result::WebSocketResult; +use result::{ + WSUrlErrorKind, + WebSocketResult, + WebSocketError, +}; use stream::{ - AsTcpStream, - Stream, + BoxedNetworkStream, + AsTcpStream, + Stream, Splittable, Shutdown, }; @@ -60,24 +73,18 @@ use dataframe::DataFrame; use ws::dataframe::DataFrame as DataFrameable; use sender::Sender; use receiver::Receiver; -pub use self::request::Request; -pub use self::response::Response; pub use sender::Writer; pub use receiver::Reader; -pub mod request; -pub mod response; - -// TODO: if using URL, remove the to url components trait -// TODO: implement ToOwned /// Build clients with a builder-style API +#[derive(Clone, Debug)] pub struct ClientBuilder<'u, 's> { - url: &'u Url, + url: Cow<'u, Url>, version: HttpVersion, - headers: Headers, + headers: Headers, version_set: bool, key_set: bool, - ssl_context: Option<&'s SslContext>, + ssl_context: Option>, } macro_rules! upsert_header { @@ -100,10 +107,10 @@ macro_rules! upsert_header { } impl<'u, 's> ClientBuilder<'u, 's> { - pub fn new(url: &'u Url) -> Self { + pub fn new(url: Cow<'u, Url>) -> Self { ClientBuilder { url: url, - version: HttpVersion::Http11, + version: HttpVersion::Http11, version_set: false, key_set: false, ssl_context: None, @@ -165,11 +172,11 @@ impl<'u, 's> ClientBuilder<'u, 's> { self } - pub fn key(mut self, key: [u8; 16]) -> Self { + pub fn key(mut self, key: [u8; 16]) -> Self { self.headers.set(WebSocketKey(key)); self.key_set = true; self - } + } pub fn clear_key(mut self) -> Self { self.headers.remove::(); @@ -177,11 +184,11 @@ impl<'u, 's> ClientBuilder<'u, 's> { self } - pub fn version(mut self, version: WebSocketVersion) -> Self { + pub fn version(mut self, version: WebSocketVersion) -> Self { self.headers.set(version); self.version_set = true; self - } + } pub fn clear_version(mut self) -> Self { self.headers.remove::(); @@ -189,10 +196,10 @@ impl<'u, 's> ClientBuilder<'u, 's> { self } - pub fn origin(mut self, origin: String) -> Self { + pub fn origin(mut self, origin: String) -> Self { self.headers.set(Origin(origin)); self - } + } pub fn custom_headers(mut self, edit: F) -> Self where F: Fn(&mut Headers), @@ -202,133 +209,134 @@ impl<'u, 's> ClientBuilder<'u, 's> { } pub fn ssl_context(mut self, context: &'s SslContext) -> Self { - self.ssl_context = Some(context); + self.ssl_context = Some(Cow::Borrowed(context)); self } - pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> - where S: Stream, - { - // Get info about ports - let is_ssl = self.url.scheme() == "wss"; - let port = match self.url.port() { - Some(port) => port, - None if is_ssl => 443, - None => 80, + fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { + let port = match (self.url.port(), secure) { + (Some(port), _) => port, + (None, None) if self.url.scheme() == "wss" => 443, + (None, None) => 80, + (None, Some(true)) => 443, + (None, Some(false)) => 80, }; let host = match self.url.host_str() { Some(h) => h, - None => unimplemented!(), + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + + let tcp_stream = try!(TcpStream::connect((host, port))); + Ok(tcp_stream) + } + + fn wrap_ssl(&self, tcp_stream: TcpStream) -> Result, SslError> { + let context = match self.ssl_context { + Some(ref ctx) => Cow::Borrowed(ctx.as_ref()), + None => Cow::Owned(try!(SslContext::new(SslMethod::Tlsv1))), + }; + + SslStream::connect(&*context, tcp_stream) + } + + pub fn connect(&mut self) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(None)); + + let boxed_stream = if self.url.scheme() == "wss" { + BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream)))) + } else { + BoxedNetworkStream(Box::new(tcp_stream)) }; + + self.connect_on(boxed_stream) + } + + pub fn connect_insecure(&mut self) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(Some(false))); + + self.connect_on(tcp_stream) + } + + pub fn connect_secure(&mut self) -> WebSocketResult>> { + let tcp_stream = try!(self.establish_tcp(Some(true))); + + let ssl_stream = try!(self.wrap_ssl(tcp_stream)); + + self.connect_on(ssl_stream) + } + + // TODO: refactor and split apart into two parts, for when evented happens + pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> + where S: Stream, + { let resource = self.url[Position::BeforePath..Position::AfterQuery] .to_owned(); - // set host header and other headers - self.headers.set(Host { - hostname: host.to_string(), - port: self.url.port(), - }); + // enter host if available (unix sockets don't have hosts) + if let Some(host) = self.url.host_str() { + self.headers.set(Host { + hostname: host.to_string(), + port: self.url.port(), + }); + } - self.headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); + self.headers.set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); - self.headers.set(Upgrade(vec![Protocol { - name: ProtocolName::WebSocket, + self.headers.set(Upgrade(vec![Protocol { + name: ProtocolName::WebSocket, // TODO: actually correct or just works? - version: None - }])); + version: None + }])); if !self.version_set { - self.headers.set(WebSocketVersion::WebSocket13); + self.headers.set(WebSocketVersion::WebSocket13); } if !self.key_set { - self.headers.set(WebSocketKey::new()); + self.headers.set(WebSocketKey::new()); + } + + // send request + try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); + try!(write!(stream.writer(), "{}\r\n", self.headers)); + + // wait for a response + // TODO: we should buffer it all, how to set up stream for this? + let response = try!(parse_response(&mut BufReader::new(stream.reader()))); + let status = StatusCode::from_u16(response.subject.0); + + // validate + if status != StatusCode::SwitchingProtocols { + return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); } - try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); - try!(write!(stream.writer(), "{}\r\n", self.headers)); - unimplemented!(); + let key = try!(self.headers.get::().ok_or( + WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") + )); + + if response.headers.get() != Some(&(WebSocketAccept::new(key))) { + return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); + } + + if response.headers.get() != Some(&(Upgrade(vec![Protocol { + name: ProtocolName::WebSocket, + version: None + }]))) { + return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); + } + + if self.headers.get() != Some(&(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())), + ]))) { + return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); + } + + Ok(Client::new(stream)) } } -// impl Client { -// /// Connects to the given ws:// URL and return a Request to be sent. -// /// -// /// If you would like to use a secure connection (wss://), please use `connect_secure`. -// /// -// /// A connection is established, however the request is not sent to -// /// the server until a call to ```send()```. -// pub fn connect(components: C) -> WebSocketResult> -// where C: ToWebSocketUrlComponents, -// { -// let (host, resource_name, secure) = try!(components.to_components()); -// let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(80))); -// let stream = try!(stream); -// Request::new((host, resource_name, secure), try!(stream.split())) -// } -// } - -// impl Client> { -// /// Connects to the specified wss:// URL using the given SSL context. -// /// -// /// If you would like to use an insecure connection (ws://), please use `connect`. -// /// -// /// A connection is established, however the request is not sent to -// /// the server until a call to ```send()```. -// pub fn connect_secure(components: C, context: Option<&SslContext>) -> WebSocketResult, SslStream>> -// where C: ToWebSocketUrlComponents, -// { -// let (host, resource_name, secure) = try!(components.to_components()); - -// let stream = TcpStream::connect((&host.hostname[..], host.port.unwrap_or(443))); -// let stream = try!(stream); -// let sslstream = if let Some(c) = context { -// SslStream::connect(c, stream) -// } else { -// let context = try!(SslContext::new(SslMethod::Tlsv1)); -// SslStream::connect(&context, stream) -// }; -// let sslstream = try!(sslstream); - -// Request::new((host, resource_name, secure), try!(sslstream.split())) -// } -// } - -// // TODO: look at how to get hyper to give you a stream then maybe remove this -// impl Client> { -// pub fn connect_agnostic(components: C, ssl_context: Option<&SslContext>) -> WebSocketResult, Box>> -// where C: ToWebSocketUrlComponents -// { -// let (host, resource_name, secure) = try!(components.to_components()); -// let port = match host.port { -// Some(p) => p, -// None => if secure { -// 443 -// } else { -// 80 -// }, -// }; -// let tcp_stream = try!(TcpStream::connect((&host.hostname[..], port))); - -// let stream: Box = if secure { -// if let Some(c) = ssl_context { -// Box::new(try!(SslStream::connect(c, tcp_stream))) -// } else { -// let context = try!(SslContext::new(SslMethod::Tlsv1)); -// Box::new(try!(SslStream::connect(&context, tcp_stream))) -// } -// } else { -// Box::new(tcp_stream) -// }; - -// let (read, write) = (try!(stream.duplicate()), stream); - -// Request::new((host, resource_name, secure), (read, write)) -// } -// } - /// Represents a WebSocket client, which can send and receive messages/data frames. /// /// `D` is the data frame type, `S` is the type implementing `Sender` and `R` @@ -365,84 +373,90 @@ impl<'u, 's> ClientBuilder<'u, 's> { pub struct Client where S: Stream, { - stream: S, + stream: S, sender: Sender, receiver: Receiver, } -// TODO: maybe make shutdown options only for TcpStream? how does it work with SslStream? +impl Client { + /// Shuts down the sending half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_sender(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Write) + } + + /// Shuts down the receiving half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_receiver(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Read) + } +} + +// TODO: add net2 set_nonblocking and stuff impl Client where S: AsTcpStream + Stream, { - /// Shuts down the sending half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_sender(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Write) - } - - /// Shuts down the receiving half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_receiver(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) - } - - /// Shuts down the client connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) - } - - // TODO: add net2 set_nonblocking and stuff + /// Shuts down the client connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Both) + } } impl<'u, 'p, 'e, 's, S> Client where S: Stream, { + pub fn from_url(address: &'u Url) -> ClientBuilder<'u, 's> { + ClientBuilder::new(Cow::Borrowed(address)) + } - pub fn build(address: &'u Url) -> ClientBuilder<'u, 's> { - ClientBuilder::new(address) + pub fn build(address: &str) -> Result, ParseError> { + let url = try!(Url::parse(address)); + Ok(ClientBuilder::new(Cow::Owned(url))) } - /// Creates a Client from the given Sender and Receiver. - /// - /// Esstiallthe opposite of `Client.split()`. - fn new(stream: S) -> Self { - Client { + /// Creates a Client from the given Sender and Receiver. + /// + /// Esstiallthe opposite of `Client.split()`. + fn new(stream: S) -> Self { + Client { stream: stream, // TODO: always true? sender: Sender::new(true), // TODO: always false? receiver: Receiver::new(false), - } - } - - /// Sends a single data frame to the remote endpoint. - pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrameable { - self.sender.send_dataframe(self.stream.writer(), dataframe) - } - - /// Sends a single message to the remote endpoint. - pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: ws::Message<'m, D>, D: DataFrameable { - self.sender.send_message(self.stream.writer(), message) - } - - /// Reads a single data frame from the remote endpoint. - pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe(self.stream.reader()) - } - - /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S::Reader> { - self.receiver.incoming_dataframes(self.stream.reader()) - } - - /// Reads a single message from this receiver. - pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult - where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, I: Iterator { - self.receiver.recv_message(self.stream.reader()) - } + } + } + + /// Sends a single data frame to the remote endpoint. + pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + where D: DataFrameable + { + self.sender.send_dataframe(self.stream.writer(), dataframe) + } + + /// Sends a single message to the remote endpoint. + pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> + where M: ws::Message<'m, D>, D: DataFrameable + { + self.sender.send_message(self.stream.writer(), message) + } + + /// Reads a single data frame from the remote endpoint. + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(self.stream.reader()) + } + + /// Returns an iterator over incoming data frames. + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S::Reader> { + self.receiver.incoming_dataframes(self.stream.reader()) + } + + /// Reads a single message from this receiver. + pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, I: Iterator { + self.receiver.recv_message(self.stream.reader()) + } pub fn stream_ref(&self) -> &S { &self.stream @@ -452,92 +466,92 @@ impl<'u, 'p, 'e, 's, S> Client &mut self.stream } - /// Returns an iterator over incoming messages. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{Client, Message}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let mut client = response.begin(); // Get a Client - /// - ///for message in client.incoming_messages() { + /// Returns an iterator over incoming messages. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{Client, Message}; + ///# use websocket::client::request::Url; + ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL + ///# let request = Client::connect(url).unwrap(); // Connect to the server + ///# let response = request.send().unwrap(); // Send the request + ///# response.validate().unwrap(); // Ensure the response is valid + /// + ///let mut client = response.begin(); // Get a Client + /// + ///for message in client.incoming_messages() { /// let message: Message = message.unwrap(); - /// println!("Recv: {:?}", message); - ///} - ///# } - ///``` - /// - /// Note that since this method mutably borrows the `Client`, it may be necessary to - /// first `split()` the `Client` and call `incoming_messages()` on the returned - /// `Receiver` to be able to send messages within an iteration. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let client = response.begin(); // Get a Client - ///let (mut sender, mut receiver) = client.split(); // Split the Client - ///for message in receiver.incoming_messages() { + /// println!("Recv: {:?}", message); + ///} + ///# } + ///``` + /// + /// Note that since this method mutably borrows the `Client`, it may be necessary to + /// first `split()` the `Client` and call `incoming_messages()` on the returned + /// `Receiver` to be able to send messages within an iteration. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{Client, Message, Sender, Receiver}; + ///# use websocket::client::request::Url; + ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL + ///# let request = Client::connect(url).unwrap(); // Connect to the server + ///# let response = request.send().unwrap(); // Send the request + ///# response.validate().unwrap(); // Ensure the response is valid + /// + ///let client = response.begin(); // Get a Client + ///let (mut sender, mut receiver) = client.split(); // Split the Client + ///for message in receiver.incoming_messages() { /// let message: Message = message.unwrap(); - /// // Echo the message back - /// sender.send_message(&message).unwrap(); - ///} - ///# } - ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S::Reader> - where M: ws::Message<'a, D>, + /// // Echo the message back + /// sender.send_message(&message).unwrap(); + ///} + ///# } + ///``` + pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S::Reader> + where M: ws::Message<'a, D>, D: DataFrameable { - self.receiver.incoming_messages(self.stream.reader()) - } + self.receiver.incoming_messages(self.stream.reader()) + } } impl Client where S: Splittable + Stream, { - /// Split this client into its constituent Sender and Receiver pair. - /// - /// This allows the Sender and Receiver to be sent to different threads. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; - ///use std::thread; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid - /// - ///let client = response.begin(); // Get a Client - /// - ///let (mut sender, mut receiver) = client.split(); - /// - ///thread::spawn(move || { - /// for message in receiver.incoming_messages() { + /// Split this client into its constituent Sender and Receiver pair. + /// + /// This allows the Sender and Receiver to be sent to different threads. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{Client, Message, Sender, Receiver}; + ///use std::thread; + ///# use websocket::client::request::Url; + ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL + ///# let request = Client::connect(url).unwrap(); // Connect to the server + ///# let response = request.send().unwrap(); // Send the request + ///# response.validate().unwrap(); // Ensure the response is valid + /// + ///let client = response.begin(); // Get a Client + /// + ///let (mut sender, mut receiver) = client.split(); + /// + ///thread::spawn(move || { + /// for message in receiver.incoming_messages() { /// let message: Message = message.unwrap(); - /// println!("Recv: {:?}", message); - /// } - ///}); - /// - ///let message = Message::text("Hello, World!"); - ///sender.send_message(&message).unwrap(); - ///# } - ///``` - pub fn split(self) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { + /// println!("Recv: {:?}", message); + /// } + ///}); + /// + ///let message = Message::text("Hello, World!"); + ///sender.send_message(&message).unwrap(); + ///# } + ///``` + pub fn split(self) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { let (read, write) = try!(self.stream.split()); Ok((Reader { reader: read, @@ -546,5 +560,5 @@ impl Client writer: write, sender: self.sender, })) - } + } } diff --git a/src/client/request.rs b/src/client/request.rs deleted file mode 100644 index 92224d23b9..0000000000 --- a/src/client/request.rs +++ /dev/null @@ -1,147 +0,0 @@ -//! Structs for client-side (outbound) WebSocket requests -use std::io::{Read, Write}; - -pub use url::Url; - -use hyper::version::HttpVersion; -use hyper::buffer::BufReader; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, Protocol, ProtocolName}; - -use unicase::UniCase; - -use header::{WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; -use result::WebSocketResult; -use client::response::Response; -use ws::util::url::ToWebSocketUrlComponents; - -/// Represents a WebSocket request. -/// -/// Note that nothing is written to the internal Writer until the `send()` method is called. -pub struct Request { - /// The HTTP version of this request. - pub version: HttpVersion, - /// The headers of this request. - pub headers: Headers, - - resource_name: String, - reader: BufReader, - writer: W, -} - -unsafe impl Send for Request where R: Read + Send, W: Write + Send { } - -impl Request { - /// Creates a new client-side request. - /// - /// In general `Client::connect()` should be used for connecting to servers. - /// However, if the request is to be written to a different Writer, this function - /// may be used. - pub fn new(components: T, stream: (R, W)) -> WebSocketResult> - where T: ToWebSocketUrlComponents, - { - let (reader, writer) = stream; - let mut headers = Headers::new(); - let (host, resource_name, _) = try!(components.to_components()); - headers.set(host); - headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); - headers.set(Upgrade(vec![Protocol{ - name: ProtocolName::WebSocket, - version: None - }])); - headers.set(WebSocketVersion::WebSocket13); - headers.set(WebSocketKey::new()); - - Ok(Request { - version: HttpVersion::Http11, - headers: headers, - resource_name: resource_name, - reader: BufReader::new(reader), - writer: writer - }) - } - /// Short-cut to obtain the WebSocketKey value. - pub fn key(&self) -> Option<&WebSocketKey> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketVersion value. - pub fn version(&self) -> Option<&WebSocketVersion> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Short-cut to obtain the Origin value. - pub fn origin(&self) -> Option<&Origin> { - self.headers.get() - } - /// Short-cut to obtain a mutable reference to the WebSocketKey value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn key_mut(&mut self) -> Option<&mut WebSocketKey> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketVersion value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn version_mut(&mut self) -> Option<&mut WebSocketVersion> { - self.headers.get_mut() - } - /// Short-cut to obtaina mutable reference to the WebSocketProtocol value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn protocol_mut(&mut self) -> Option<&mut WebSocketProtocol> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the WebSocketExtensions value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn extensions_mut(&mut self) -> Option<&mut WebSocketExtensions> { - self.headers.get_mut() - } - /// Short-cut to obtain a mutable reference to the Origin value. - /// - /// Note that to add a header that does not already exist, ```Request.headers.set()``` - /// must be used. - pub fn origin_mut(&mut self) -> Option<&mut Origin> { - self.headers.get_mut() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &BufReader { - &self.reader - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - &self.writer - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut BufReader { - &mut self.reader - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - &mut self.writer - } - /// Return the inner Reader and Writer. - pub fn into_inner(self) -> (BufReader, W) { - (self.reader, self.writer) - } - /// Sends the request to the server and returns a response. - pub fn send(mut self) -> WebSocketResult> { - try!(write!(&mut self.writer, "GET {} {}\r\n", self.resource_name, self.version)); - try!(write!(&mut self.writer, "{}\r\n", self.headers)); - Response::read(self) - } -} diff --git a/src/client/response.rs b/src/client/response.rs deleted file mode 100644 index 8fbce9f3d8..0000000000 --- a/src/client/response.rs +++ /dev/null @@ -1,139 +0,0 @@ -//! Structs for WebSocket responses -use std::option::Option; -use std::io::{Read, Write}; - -use hyper::status::StatusCode; -use hyper::buffer::BufReader; -use hyper::version::HttpVersion; -use hyper::header::Headers; -use hyper::header::{Connection, ConnectionOption}; -use hyper::header::{Upgrade, Protocol, ProtocolName}; -use hyper::http::h1::parse_response; - -use unicase::UniCase; - -use header::{WebSocketAccept, WebSocketProtocol, WebSocketExtensions}; - -use client::{Client, Request, Sender, Receiver}; -use result::{WebSocketResult, WebSocketError}; -use dataframe::DataFrame; -use ws::dataframe::DataFrame as DataFrameable; -use ws; - -/// Represents a WebSocket response. -pub struct Response { - /// The status of the response - pub status: StatusCode, - /// The headers contained in this response - pub headers: Headers, - /// The HTTP version of this response - pub version: HttpVersion, - - request: Request -} - -unsafe impl Send for Response where R: Read + Send, W: Write + Send { } - -impl Response { - /// Reads a Response off the stream associated with a Request. - /// - /// This is called by Request.send(), and does not need to be called by the user. - pub fn read(mut request: Request) -> WebSocketResult> { - let (status, version, headers) = { - let reader = request.get_mut_reader(); - - let response = try!(parse_response(reader)); - - let status = StatusCode::from_u16(response.subject.0); - (status, response.version, response.headers) - }; - - Ok(Response { - status: status, - headers: headers, - version: version, - request: request - }) - } - - /// Short-cut to obtain the WebSocketAccept value. - pub fn accept(&self) -> Option<&WebSocketAccept> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketProtocol value. - pub fn protocol(&self) -> Option<&WebSocketProtocol> { - self.headers.get() - } - /// Short-cut to obtain the WebSocketExtensions value. - pub fn extensions(&self) -> Option<&WebSocketExtensions> { - self.headers.get() - } - /// Returns a reference to the inner Reader. - pub fn get_reader(&self) -> &BufReader { - self.request.get_reader() - } - /// Returns a reference to the inner Writer. - pub fn get_writer(&self) -> &W { - self.request.get_writer() - } - /// Returns a mutable reference to the inner Reader. - pub fn get_mut_reader(&mut self) -> &mut BufReader { - self.request.get_mut_reader() - } - /// Returns a mutable reference to the inner Writer. - pub fn get_mut_writer(&mut self) -> &mut W { - self.request.get_mut_writer() - } - /// Returns a reference to the request associated with this response. - pub fn get_request(&self) -> &Request { - &self.request - } - /// Return the inner Reader and Writer. - pub fn into_inner(self) -> (BufReader, W) { - self.request.into_inner() - } - - /// Check if this response constitutes a successful handshake. - pub fn validate(&self) -> WebSocketResult<()> { - if self.status != StatusCode::SwitchingProtocols { - return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); - } - let key = try!(self.request.key().ok_or( - WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") - )); - if self.accept() != Some(&(WebSocketAccept::new(key))) { - return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); - } - if self.headers.get() != Some(&(Upgrade(vec![Protocol{ - name: ProtocolName::WebSocket, - version: None - }]))) { - return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); - } - if self.headers.get() != Some(&(Connection(vec![ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string()))]))) { - return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); - } - Ok(()) - } - - // TODO - // /// Consume this response and return a Client ready to transmit/receive data frames - // /// using the data frame type D, Sender B and Receiver C. - // /// - // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - // pub fn begin_with(self, sender: B, receiver: C) -> Client - // where B: ws::Sender, C: ws::Receiver, D: DataFrameable { - // Client::new(sender, receiver) - // } - - // TODO - // /// Consume this response and return a Client ready to transmit/receive data frames. - // /// - // /// Does not check if the response was valid. Use `validate()` to ensure that the response constitutes a successful handshake. - // pub fn begin(self) -> Client, Receiver> { - // let (reader, writer) = self.into_inner(); - // let sender = Sender::new(writer, true); - // let receiver = Receiver::new(reader, false); - // Client::new(sender, receiver) - // } -} diff --git a/src/result.rs b/src/result.rs index f6b9b25fa8..c7eb9d1943 100644 --- a/src/result.rs +++ b/src/result.rs @@ -123,6 +123,8 @@ pub enum WSUrlErrorKind { CannotSetFragment, /// The scheme provided is invalid for a WebSocket InvalidScheme, + /// There is no hostname or IP address to connect to + NoHostName, } impl fmt::Display for WSUrlErrorKind { @@ -137,7 +139,8 @@ impl Error for WSUrlErrorKind { fn description(&self) -> &str { match *self { WSUrlErrorKind::CannotSetFragment => "WebSocket URL cannot set fragment", - WSUrlErrorKind::InvalidScheme => "WebSocket URL invalid scheme" + WSUrlErrorKind::InvalidScheme => "WebSocket URL invalid scheme", + WSUrlErrorKind::NoHostName => "WebSocket URL no host name provided", } } } diff --git a/src/stream.rs b/src/stream.rs index b0f0eeafa6..1891f99163 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -2,6 +2,7 @@ // TODO: add mio support & tokio // extern crate mio; +use std::ops::Deref; use std::io::{ self, Read, @@ -69,6 +70,52 @@ impl Stream for ReadWritePair } } +pub trait ReadWrite: Read + Write {} +impl ReadWrite for S where S: Read + Write {} + +pub struct BoxedStream(pub Box); + +impl Stream for BoxedStream { + type Reader = Box; + type Writer = Box; + + #[inline] + fn reader(&mut self) -> &mut Self::Reader { + &mut self.0 + } + + #[inline] + fn writer(&mut self) -> &mut Self::Writer { + &mut self.0 + } +} + +pub trait NetworkStream: Read + Write + AsTcpStream {} +impl NetworkStream for S where S: Read + Write + AsTcpStream {} + +pub struct BoxedNetworkStream(pub Box); + +impl AsTcpStream for BoxedNetworkStream { + fn as_tcp(&self) -> &TcpStream { + self.0.deref().as_tcp() + } +} + +impl Stream for BoxedNetworkStream { + type Reader = Box; + type Writer = Box; + + #[inline] + fn reader(&mut self) -> &mut Self::Reader { + &mut self.0 + } + + #[inline] + fn writer(&mut self) -> &mut Self::Writer { + &mut self.0 + } +} + impl Splittable for TcpStream { type Reader = TcpStream; type Writer = TcpStream; @@ -111,6 +158,14 @@ impl AsTcpStream for SslStream { } } +impl AsTcpStream for Box + where T: AsTcpStream, +{ + fn as_tcp(&self) -> &TcpStream { + self.deref().as_tcp() + } +} + /// Marker struct for having no SSL context in a struct. #[derive(Clone)] pub struct NoSslContext; diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index bea0c1d962..3da8ade00f 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -9,7 +9,6 @@ use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; -// TODO: maybe this is not needed anymore /// A trait for receiving data frames and messages. pub trait Receiver: Sized { diff --git a/src/ws/sender.rs b/src/ws/sender.rs index 0a76ce66eb..ca2bde11ea 100644 --- a/src/ws/sender.rs +++ b/src/ws/sender.rs @@ -7,7 +7,6 @@ use ws::Message; use ws::dataframe::DataFrame; use result::WebSocketResult; -// TODO: maybe this is not needed anymore /// A trait for sending data frames and messages. pub trait Sender { /// Sends a single data frame using this sender. diff --git a/src/ws/util/mod.rs b/src/ws/util/mod.rs index e5878006b8..c59e04db63 100644 --- a/src/ws/util/mod.rs +++ b/src/ws/util/mod.rs @@ -2,7 +2,6 @@ pub mod header; pub mod mask; -pub mod url; use std::str::from_utf8; use std::str::Utf8Error; diff --git a/src/ws/util/url.rs b/src/ws/util/url.rs deleted file mode 100644 index f6862c12b1..0000000000 --- a/src/ws/util/url.rs +++ /dev/null @@ -1,343 +0,0 @@ -//! Utility functions for dealing with URLs - -use url::{Url, Position}; -use url::Host as UrlHost; -use hyper::header::Host; -use result::{WebSocketResult, WSUrlErrorKind}; - -/// Trait that gets required WebSocket URL components -pub trait ToWebSocketUrlComponents { - /// Retrieve the required WebSocket URL components from this - fn to_components(&self) -> WebSocketResult<(Host, String, bool)>; -} - -impl ToWebSocketUrlComponents for str { - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - parse_url_str(&self) - } -} - -impl ToWebSocketUrlComponents for Url { - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - parse_url(&self) - } -} - -impl ToWebSocketUrlComponents for (Host, String, bool) { - /// Convert a Host, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - let (mut host, mut resource_name, secure) = self.clone(); - host.port = Some(match host.port { - Some(port) => port, - None => if secure { 443 } else { 80 }, - }); - if resource_name.is_empty() { - resource_name = "/".to_owned(); - } - Ok((host, resource_name, secure)) - } -} - -impl<'a> ToWebSocketUrlComponents for (Host, &'a str, bool) { - /// Convert a Host, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (self.0.clone(), self.1.to_owned(), self.2).to_components() - } -} - -impl<'a> ToWebSocketUrlComponents for (Host, &'a str) { - /// Convert a Host and resource name to WebSocket URL components, assuming an insecure connection. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (self.0.clone(), self.1.to_owned(), false).to_components() - } -} - -impl ToWebSocketUrlComponents for (Host, String) { - /// Convert a Host and resource name to WebSocket URL components, assuming an insecure connection. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (self.0.clone(), self.1.clone(), false).to_components() - } -} - -impl ToWebSocketUrlComponents for (UrlHost, u16, String, bool) { - /// Convert a Host, port, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (Host { - hostname: self.0.to_string(), - port: Some(self.1) - }, self.2.clone(), self.3).to_components() - } -} - -impl<'a> ToWebSocketUrlComponents for (UrlHost, u16, &'a str, bool) { - /// Convert a Host, port, resource name and secure flag to WebSocket URL components. - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (Host { - hostname: self.0.to_string(), - port: Some(self.1) - }, self.2, self.3).to_components() - } -} - -impl<'a, T: ToWebSocketUrlComponents> ToWebSocketUrlComponents for &'a T { - fn to_components(&self) -> WebSocketResult<(Host, String, bool)> { - (**self).to_components() - } -} - -/// Gets the host, port, resource and secure from the string representation of a url -pub fn parse_url_str(url_str: &str) -> WebSocketResult<(Host, String, bool)> { - // https://html.spec.whatwg.org/multipage/#parse-a-websocket-url's-components - // Steps 1 and 2 - let parsed_url = try!(Url::parse(url_str)); - parse_url(&parsed_url) -} - -/// Gets the host, port, resource, and secure flag from a url -pub fn parse_url(url: &Url) -> WebSocketResult<(Host, String, bool)> { - // https://html.spec.whatwg.org/multipage/#parse-a-websocket-url's-components - - // Step 4 - if url.fragment().is_some() { - return Err(From::from(WSUrlErrorKind::CannotSetFragment)); - } - - let secure = match url.scheme() { - // step 5 - "ws" => false, - "wss" => true, - // step 3 - _ => return Err(From::from(WSUrlErrorKind::InvalidScheme)), - }; - - let host = url.host_str().unwrap().to_owned(); // Step 6 - let port = url.port_or_known_default(); // Steps 7 and 8 - - // steps 9, 10, 11 - let resource = url[Position::BeforePath..Position::AfterQuery].to_owned(); - - // Step 12 - Ok((Host { hostname: host, port: port }, resource, secure)) -} - -#[cfg(all(feature = "nightly", test))] -mod tests { - use super::*; - //use test; - use url::Url; - use result::{WebSocketError, WSUrlErrorKind}; - - fn url_for_test() -> Url { - Url::parse("ws://www.example.com:8080/some/path?a=b&c=d").unwrap() - } - - #[test] - fn test_parse_url_fragments_not_accepted() { - let url = &mut url_for_test(); - url.set_fragment(Some("non_null_fragment")); - - let result = parse_url(url); - match result { - Err(WebSocketError::WebSocketUrlError( - WSUrlErrorKind::CannotSetFragment)) => (), - Err(e) => panic!("Expected WSUrlErrorKind::CannotSetFragment but got {}", e), - Ok(_) => panic!("Expected WSUrlErrorKind::CannotSetFragment but got Ok") - } - } - - #[test] - fn test_parse_url_invalid_schemes_return_error() { - let url = &mut url_for_test(); - - let invalid_schemes = &["http", "https", "gopher", "file", "ftp", "other"]; - for scheme in invalid_schemes { - url.set_scheme(scheme).unwrap(); - - let result = parse_url(url); - match result { - Err(WebSocketError::WebSocketUrlError( - WSUrlErrorKind::InvalidScheme)) => (), - Err(e) => panic!("Expected WSUrlErrorKind::InvalidScheme but got {}", e), - Ok(_) => panic!("Expected WSUrlErrorKind::InvalidScheme but got Ok") - } - } - } - - #[test] - fn test_parse_url_valid_schemes_return_ok() { - let url = &mut url_for_test(); - - let valid_schemes = &["ws", "wss"]; - for scheme in valid_schemes { - url.set_scheme(scheme).unwrap(); - - let result = parse_url(url); - match result { - Ok(_) => (), - Err(e) => panic!("Expected Ok, but got {}", e) - } - } - } - - #[test] - fn test_parse_url_ws_returns_unset_secure_flag() { - let url = &mut url_for_test(); - url.set_scheme("ws").unwrap(); - - let result = parse_url(url); - let secure = match result { - Ok((_, _, secure)) => secure, - Err(e) => panic!(e), - }; - assert!(!secure); - } - - #[test] - fn test_parse_url_wss_returns_set_secure_flag() { - let url = &mut url_for_test(); - url.set_scheme("wss").unwrap(); - - let result = parse_url(url); - let secure = match result { - Ok((_, _, secure)) => secure, - Err(e) => panic!(e), - }; - assert!(secure); - } - - #[test] - fn test_parse_url_generates_proper_output() { - let url = &url_for_test(); - - let result = parse_url(url); - let (host, resource) = match result { - Ok((host, resource, _)) => (host, resource), - Err(e) => panic!(e), - }; - - assert_eq!(host.hostname, "www.example.com".to_owned()); - assert_eq!(resource, "/some/path?a=b&c=d".to_owned()); - - match host.port { - Some(port) => assert_eq!(port, 8080), - _ => panic!("Port should not be None"), - } - } - - #[test] - fn test_parse_url_empty_path_should_give_slash() { - let url = &mut url_for_test(); - url.set_path("/"); - - let result = parse_url(url); - let resource = match result { - Ok((_, resource, _)) => resource, - Err(e) => panic!(e), - }; - - assert_eq!(resource, "/?a=b&c=d".to_owned()); - } - - #[test] - fn test_parse_url_none_query_should_not_append_question_mark() { - let url = &mut url_for_test(); - url.set_query(None); - - let result = parse_url(url); - let resource = match result { - Ok((_, resource, _)) => resource, - Err(e) => panic!(e), - }; - - assert_eq!(resource, "/some/path".to_owned()); - } - - #[test] - fn test_parse_url_none_port_should_use_default_port() { - let url = &mut url_for_test(); - url.set_port(None).unwrap(); - - let result = parse_url(url); - let host = match result { - Ok((host, _, _)) => host, - Err(e) => panic!(e), - }; - - match host.port { - Some(80) => (), - Some(p) => panic!("Expected port to be 80 but got {}", p), - None => panic!("Expected port to be 80 but got `None`"), - } - } - - #[test] - fn test_parse_url_str_valid_url1() { - let url_str = "ws://www.example.com/some/path?a=b&c=d"; - let result = parse_url_str(url_str); - let (host, resource, secure) = match result { - Ok((host, resource, secure)) => (host, resource, secure), - Err(e) => panic!(e), - }; - - match host.port { - Some(80) => (), - Some(p) => panic!("Expected port 80 but got {}", p), - None => panic!("Expected port 80 but got `None`") - } - assert_eq!(host.hostname, "www.example.com".to_owned()); - assert_eq!(resource, "/some/path?a=b&c=d".to_owned()); - assert!(!secure); - } - - #[test] - fn test_parse_url_str_valid_url2() { - let url_str = "wss://www.example.com"; - let result = parse_url_str(url_str); - let (host, resource, secure) = match result { - Ok((host, resource, secure)) => (host, resource, secure), - Err(e) => panic!(e) - }; - - match host.port { - Some(443) => (), - Some(p) => panic!("Expected port 443 but got {}", p), - None => panic!("Expected port 443 but got `None`") - } - assert_eq!(host.hostname, "www.example.com".to_owned()); - assert_eq!(resource, "/".to_owned()); - assert!(secure); - } - - #[test] - fn test_parse_url_str_invalid_relative_url() { - let url_str = "/some/relative/path?a=b&c=d"; - let result = parse_url_str(url_str); - match result { - Err(WebSocketError::UrlError(_)) => (), - Err(e) => panic!("Expected UrlError, but got unexpected error {}", e), - Ok(_) => panic!("Expected UrlError, but got Ok"), - } - } - - #[test] - fn test_parse_url_str_invalid_url_scheme() { - let url_str = "http://www.example.com/some/path?a=b&c=d"; - let result = parse_url_str(url_str); - match result { - Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::InvalidScheme)) => (), - Err(e) => panic!("Expected InvalidScheme, but got unexpected error {}", e), - Ok(_) => panic!("Expected InvalidScheme, but got Ok"), - } - } - - #[test] - fn test_parse_url_str_invalid_url_fragment() { - let url_str = "http://www.example.com/some/path#some-id"; - let result = parse_url_str(url_str); - match result { - Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::CannotSetFragment)) => (), - Err(e) => panic!("Expected CannotSetFragment, but got unexpected error {}", e), - Ok(_) => panic!("Expected CannotSetFragment, but got Ok"), - } - } -} From c582ee4f6c4060c7eb664eb8f8d9c370fa23e0a2 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 12:54:05 -0400 Subject: [PATCH 17/32] implement intows for hyper request, cleanup and refactoring --- src/client/mod.rs | 4 - src/receiver.rs | 1 - src/sender.rs | 2 +- src/server/mod.rs | 2 +- src/server/upgrade/hyper.rs | 254 +++++------------------------------- src/server/upgrade/mod.rs | 243 ++++++++++++++++++++++++++++++---- 6 files changed, 249 insertions(+), 257 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 180a56ffb8..663c2e357f 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -2,17 +2,13 @@ extern crate url; use std::borrow::{ - Borrow, Cow, }; use std::net::TcpStream; -use std::marker::PhantomData; use std::io::Result as IoResult; use std::io::{ - Read, Write, }; -use std::ops::Deref; use self::url::{ Url, diff --git a/src/receiver.rs b/src/receiver.rs index 16d757f4c1..951c6f30e5 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -2,7 +2,6 @@ use std::io::Read; use std::io::Result as IoResult; -use hyper::buffer::BufReader; use dataframe::{ DataFrame, diff --git a/src/sender.rs b/src/sender.rs index fafec5d982..5e78f6c68e 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -18,7 +18,7 @@ impl Writer where W: Write, { /// Sends a single data frame to the remote endpoint. - fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> where D: DataFrame, W: Write, { diff --git a/src/server/mod.rs b/src/server/mod.rs index ed8acaaa2f..acdd159301 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -28,7 +28,7 @@ use self::upgrade::{ WsUpgrade, IntoWs, }; -pub use self::upgrade::hyper::{ +pub use self::upgrade::{ Request, HyperIntoWsError, }; diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index 6bf35abf43..323fc329b2 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -1,38 +1,12 @@ extern crate hyper; extern crate openssl; -use std::net::TcpStream; -use std::any::Any; -use std::convert::From; -use std::error::Error; -use openssl::ssl::SslStream; -use hyper::http::h1::parse_request; use hyper::net::{ - NetworkStream, - HttpStream, - HttpsStream, -}; -use header::{ - WebSocketKey, - WebSocketVersion, -}; -use std::fmt::{ - Formatter, - Display, - self, -}; -use stream::{ - Stream, - AsTcpStream, + NetworkStream, }; use super::{ - IntoWs, - WsUpgrade, -}; -use std::io::{ - Read, - Write, - self, + IntoWs, + WsUpgrade, }; pub use hyper::http::h1::Incoming; @@ -40,207 +14,39 @@ pub use hyper::method::Method; pub use hyper::version::HttpVersion; pub use hyper::uri::RequestUri; pub use hyper::buffer::BufReader; -pub use hyper::server::Request as HyperRequest; +use hyper::server::Request; pub use hyper::header::{ - Headers, - Upgrade, - ProtocolName, - Connection, - ConnectionOption, + Headers, + Upgrade, + ProtocolName, + Connection, + ConnectionOption, }; -pub type Request = Incoming<(Method, RequestUri)>; - -pub struct RequestStreamPair(pub S, pub Request); - -#[derive(Debug)] -pub enum HyperIntoWsError { - MethodNotGet, - UnsupportedHttpVersion, - UnsupportedWebsocketVersion, - NoSecWsKeyHeader, - NoWsUpgradeHeader, - NoUpgradeHeader, - NoWsConnectionHeader, - NoConnectionHeader, - UnknownNetworkStream, - /// IO error from reading the underlying socket - Io(io::Error), - /// Error while parsing an incoming request - Parsing(hyper::error::Error), -} - -impl Display for HyperIntoWsError { - fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { - fmt.write_str(self.description()) - } -} - -impl Error for HyperIntoWsError { - fn description(&self) -> &str { - use self::HyperIntoWsError::*; - match self { - &MethodNotGet => "Request method must be GET", - &UnsupportedHttpVersion => "Unsupported request HTTP version", - &UnsupportedWebsocketVersion => "Unsupported WebSocket version", - &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", - &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", - &NoUpgradeHeader => "Missing Upgrade WebSocket header", - &NoWsConnectionHeader => "Invalid Connection WebSocket header", - &NoConnectionHeader => "Missing Connection WebSocket header", - &UnknownNetworkStream => "Cannot downcast to known impl of NetworkStream", - &Io(ref e) => e.description(), - &Parsing(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - HyperIntoWsError::Io(ref e) => Some(e), - HyperIntoWsError::Parsing(ref e) => Some(e), - _ => None, - } - } -} - -impl From for HyperIntoWsError { - fn from(err: io::Error) -> Self { - HyperIntoWsError::Io(err) - } -} - -impl From for HyperIntoWsError { - fn from(err: hyper::error::Error) -> Self { - HyperIntoWsError::Parsing(err) - } -} - -// TODO: Move this into the main upgrade module -impl IntoWs for S -where S: Stream, -{ - type Stream = S; - type Error = (Self, Option, HyperIntoWsError); - - fn into_ws(mut self) -> Result, Self::Error> { - let request = { - let mut reader = BufReader::new(self.reader()); - parse_request(&mut reader) - }; - - let request = match request { - Ok(r) => r, - Err(e) => return Err((self, None, e.into())), - }; - - match validate(&request.subject.0, &request.version, &request.headers) { - Ok(_) => Ok(WsUpgrade { - stream: self, - request: request, - }), - Err(e) => Err((self, Some(request), e)), - } - } -} - -// TODO: Move this into the main upgrade module -impl IntoWs for RequestStreamPair -where S: Stream, -{ - type Stream = S; - type Error = (S, Request, HyperIntoWsError); - - fn into_ws(self) -> Result, Self::Error> { - match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { - Ok(_) => Ok(WsUpgrade { - stream: self.0, - request: self.1, - }), - Err(e) => Err((self.0, self.1, e)), - } - } -} - -// TODO -// impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { -// type Stream = Box; -// type Error = (HyperRequest<'a, 'b>, HyperIntoWsError); - -// fn into_ws(self) -> Result, Self::Error> { -// if let Err(e) = validate(&self.method, &self.version, &self.headers) { -// return Err((self, e)); -// } - -// let stream: Option> = unimplemented!(); - -// if let Some(s) = stream { -// Ok(WsUpgrade { -// stream: s, -// request: Incoming { -// version: self.version, -// headers: self.headers, -// subject: (self.method, self.uri), -// }, -// }) -// } else { -// Err((self, HyperIntoWsError::UnknownNetworkStream)) -// } -// } -// } +use super::validate; +use super::HyperIntoWsError; -pub fn validate( - method: &Method, - version: &HttpVersion, - headers: &Headers -) -> Result<(), HyperIntoWsError> -{ - if *method != Method::Get { - return Err(HyperIntoWsError::MethodNotGet); - } +pub struct HyperRequest<'a, 'b: 'a>(pub Request<'a, 'b>); - if *version == HttpVersion::Http09 || *version == HttpVersion::Http10 { - return Err(HyperIntoWsError::UnsupportedHttpVersion); - } +impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { + type Stream = &'a mut &'b mut NetworkStream; + type Error = (Request<'a, 'b>, HyperIntoWsError); - if let Some(version) = headers.get::() { - if version != &WebSocketVersion::WebSocket13 { - return Err(HyperIntoWsError::UnsupportedWebsocketVersion); - } - } + fn into_ws(self) -> Result, Self::Error> { + if let Err(e) = validate(&self.0.method, &self.0.version, &self.0.headers) { + return Err((self.0, e)); + } - if headers.get::().is_none() { - return Err(HyperIntoWsError::NoSecWsKeyHeader); - } + let (_, method, headers, uri, version, reader) = self.0.deconstruct(); + let stream = reader.into_inner().get_mut(); - match headers.get() { - Some(&Upgrade(ref upgrade)) => { - if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { - return Err(HyperIntoWsError::NoWsUpgradeHeader) - } - }, - None => return Err(HyperIntoWsError::NoUpgradeHeader), - }; - - fn check_connection_header(headers: &Vec) -> bool { - for header in headers { - if let &ConnectionOption::ConnectionHeader(ref h) = header { - if h as &str == "upgrade" { - return true; - } - } - } - false - } - - match headers.get() { - Some(&Connection(ref connection)) => { - if !check_connection_header(connection) { - return Err(HyperIntoWsError::NoWsConnectionHeader); - } - }, - None => return Err(HyperIntoWsError::NoConnectionHeader), - }; - - Ok(()) + Ok(WsUpgrade { + stream: stream, + request: Incoming { + version: version, + headers: headers, + subject: (method, uri), + }, + }) + } } - diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index d905561fba..9b6c270ef2 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -1,10 +1,36 @@ //! Allows you to take an existing request or stream of data and convert it into a //! WebSocket client. +extern crate hyper as real_hyper; + +use std::error::Error; use std::net::TcpStream; -use std::io::Read; +use std::io; +use std::fmt::{ + Formatter, + Display, + self, +}; use stream::{ - Stream, - AsTcpStream, + Stream, + AsTcpStream, +}; +use header::{ + WebSocketKey, + WebSocketVersion, +}; + +pub use self::real_hyper::http::h1::Incoming; +pub use self::real_hyper::method::Method; +pub use self::real_hyper::version::HttpVersion; +pub use self::real_hyper::uri::RequestUri; +pub use self::real_hyper::buffer::BufReader; +pub use self::real_hyper::http::h1::parse_request; +pub use self::real_hyper::header::{ + Headers, + Upgrade, + ProtocolName, + Connection, + ConnectionOption, }; pub mod hyper; @@ -16,34 +42,34 @@ pub mod hyper; /// Users should then call `accept` or `deny` to complete the handshake /// and start a session. pub struct WsUpgrade -where S: Stream, + where S: Stream, { - stream: S, - request: hyper::Request, + stream: S, + request: Request, } impl WsUpgrade -where S: Stream, + where S: Stream, { - pub fn accept(self) { - unimplemented!(); - } + pub fn accept(self) { + unimplemented!(); + } - pub fn reject(self) -> S { - unimplemented!(); - } + pub fn reject(self) -> S { + unimplemented!(); + } - pub fn into_stream(self) -> S { - unimplemented!(); - } + pub fn into_stream(self) -> S { + unimplemented!(); + } } impl WsUpgrade -where S: Stream + AsTcpStream, + where S: Stream + AsTcpStream, { - pub fn tcp_stream(&self) -> &TcpStream { - self.stream.as_tcp() - } + pub fn tcp_stream(&self) -> &TcpStream { + self.stream.as_tcp() + } } /// Trait to take a stream or similar and attempt to recover the start of a @@ -56,11 +82,176 @@ where S: Stream + AsTcpStream, /// /// Note: the stream is owned because the websocket client expects to own its stream. pub trait IntoWs { - type Stream: Stream; - type Error; - /// Attempt to parse the start of a Websocket handshake, later with the returned - /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to - /// send a handshake rejection response. - fn into_ws(mut self) -> Result, Self::Error>; + type Stream: Stream; + type Error; + /// Attempt to parse the start of a Websocket handshake, later with the returned + /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to + /// send a handshake rejection response. + fn into_ws(self) -> Result, Self::Error>; +} + + +pub type Request = Incoming<(Method, RequestUri)>; +pub struct RequestStreamPair(pub S, pub Request); + +impl IntoWs for S + where S: Stream, +{ + type Stream = S; + type Error = (Self, Option, HyperIntoWsError); + + fn into_ws(mut self) -> Result, Self::Error> { + let request = { + let mut reader = BufReader::new(self.reader()); + parse_request(&mut reader) + }; + + let request = match request { + Ok(r) => r, + Err(e) => return Err((self, None, e.into())), + }; + + match validate(&request.subject.0, &request.version, &request.headers) { + Ok(_) => Ok(WsUpgrade { + stream: self, + request: request, + }), + Err(e) => Err((self, Some(request), e)), + } + } +} + +impl IntoWs for RequestStreamPair + where S: Stream, +{ + type Stream = S; + type Error = (S, Request, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { + Ok(_) => Ok(WsUpgrade { + stream: self.0, + request: self.1, + }), + Err(e) => Err((self.0, self.1, e)), + } + } +} + +#[derive(Debug)] +pub enum HyperIntoWsError { + MethodNotGet, + UnsupportedHttpVersion, + UnsupportedWebsocketVersion, + NoSecWsKeyHeader, + NoWsUpgradeHeader, + NoUpgradeHeader, + NoWsConnectionHeader, + NoConnectionHeader, + UnknownNetworkStream, + /// IO error from reading the underlying socket + Io(io::Error), + /// Error while parsing an incoming request + Parsing(self::real_hyper::error::Error), } +impl Display for HyperIntoWsError { + fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + fmt.write_str(self.description()) + } +} + +impl Error for HyperIntoWsError { + fn description(&self) -> &str { + use self::HyperIntoWsError::*; + match self { + &MethodNotGet => "Request method must be GET", + &UnsupportedHttpVersion => "Unsupported request HTTP version", + &UnsupportedWebsocketVersion => "Unsupported WebSocket version", + &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", + &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", + &NoUpgradeHeader => "Missing Upgrade WebSocket header", + &NoWsConnectionHeader => "Invalid Connection WebSocket header", + &NoConnectionHeader => "Missing Connection WebSocket header", + &UnknownNetworkStream => "Cannot downcast to known impl of NetworkStream", + &Io(ref e) => e.description(), + &Parsing(ref e) => e.description(), + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + HyperIntoWsError::Io(ref e) => Some(e), + HyperIntoWsError::Parsing(ref e) => Some(e), + _ => None, + } + } +} + +impl From for HyperIntoWsError { + fn from(err: io::Error) -> Self { + HyperIntoWsError::Io(err) + } +} + +impl From for HyperIntoWsError { + fn from(err: real_hyper::error::Error) -> Self { + HyperIntoWsError::Parsing(err) + } +} + +pub fn validate( + method: &Method, + version: &HttpVersion, + headers: &Headers +) -> Result<(), HyperIntoWsError> +{ + if *method != Method::Get { + return Err(HyperIntoWsError::MethodNotGet); + } + + if *version == HttpVersion::Http09 || *version == HttpVersion::Http10 { + return Err(HyperIntoWsError::UnsupportedHttpVersion); + } + + if let Some(version) = headers.get::() { + if version != &WebSocketVersion::WebSocket13 { + return Err(HyperIntoWsError::UnsupportedWebsocketVersion); + } + } + + if headers.get::().is_none() { + return Err(HyperIntoWsError::NoSecWsKeyHeader); + } + + match headers.get() { + Some(&Upgrade(ref upgrade)) => { + if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { + return Err(HyperIntoWsError::NoWsUpgradeHeader) + } + }, + None => return Err(HyperIntoWsError::NoUpgradeHeader), + }; + + fn check_connection_header(headers: &Vec) -> bool { + for header in headers { + if let &ConnectionOption::ConnectionHeader(ref h) = header { + if h as &str == "upgrade" { + return true; + } + } + } + false + } + + match headers.get() { + Some(&Connection(ref connection)) => { + if !check_connection_header(connection) { + return Err(HyperIntoWsError::NoWsConnectionHeader); + } + }, + None => return Err(HyperIntoWsError::NoConnectionHeader), + }; + + Ok(()) +} From e659824f08517f15379725bb744b0ed1c7d5192b Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 14:11:30 -0400 Subject: [PATCH 18/32] moved client builder into its own module --- src/client/builder.rs | 315 ++++++++++++++++++++++++++++++++++++++++++ src/client/mod.rs | 273 ++---------------------------------- 2 files changed, 323 insertions(+), 265 deletions(-) create mode 100644 src/client/builder.rs diff --git a/src/client/builder.rs b/src/client/builder.rs new file mode 100644 index 0000000000..12d1676508 --- /dev/null +++ b/src/client/builder.rs @@ -0,0 +1,315 @@ +use std::borrow::{ + Cow, +}; +use std::io::{ + Write, +}; +use std::net::TcpStream; +use url::{ + Url, + Position, +}; +use hyper::version::HttpVersion; +use hyper::status::StatusCode; +use hyper::buffer::BufReader; +use hyper::http::h1::parse_response; +use hyper::header::{ + Headers, + Host, + Connection, + ConnectionOption, + Upgrade, + Protocol, + ProtocolName, +}; +use unicase::UniCase; +use openssl::ssl::error::SslError; +use openssl::ssl::{ + SslContext, + SslMethod, + SslStream, +}; +use header::extensions::Extension; +use header::{ + WebSocketAccept, + WebSocketKey, + WebSocketVersion, + WebSocketProtocol, + WebSocketExtensions, + Origin, +}; +use result::{ + WSUrlErrorKind, + WebSocketResult, + WebSocketError, +}; +use stream::{ + BoxedNetworkStream, + AsTcpStream, + Stream, + Splittable, + Shutdown, +}; +use super::Client; + +macro_rules! upsert_header { + ($headers:expr; $header:ty; { + Some($pat:pat) => $some_match:expr, + None => $default:expr + }) => {{ + match $headers.has::<$header>() { + true => { + match $headers.get_mut::<$header>() { + Some($pat) => { $some_match; }, + None => (), + }; + } + false => { + $headers.set($default); + }, + }; + }} +} + +/// Build clients with a builder-style API +#[derive(Clone, Debug)] +pub struct ClientBuilder<'u, 's> { + url: Cow<'u, Url>, + version: HttpVersion, + headers: Headers, + version_set: bool, + key_set: bool, + ssl_context: Option>, +} + +impl<'u, 's> ClientBuilder<'u, 's> { + pub fn new(url: Cow<'u, Url>) -> Self { + ClientBuilder { + url: url, + version: HttpVersion::Http11, + version_set: false, + key_set: false, + ssl_context: None, + headers: Headers::new(), + } + } + + pub fn add_protocol

(mut self, protocol: P) -> Self + where P: Into, + { + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.push(protocol.into()), + None => WebSocketProtocol(vec![protocol.into()]) + }); + self + } + + pub fn add_protocols(mut self, protocols: I) -> Self + where I: IntoIterator, + S: Into, + { + let mut protocols: Vec = protocols.into_iter() + .map(Into::into).collect(); + + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.append(&mut protocols), + None => WebSocketProtocol(protocols) + }); + self + } + + pub fn clear_protocols(mut self) -> Self { + self.headers.remove::(); + self + } + + pub fn add_extension(mut self, extension: Extension) -> Self + { + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.push(extension), + None => WebSocketExtensions(vec![extension]) + }); + self + } + + pub fn add_extensions(mut self, extensions: I) -> Self + where I: IntoIterator, + { + let mut extensions: Vec = extensions.into_iter().collect(); + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.append(&mut extensions), + None => WebSocketExtensions(extensions) + }); + self + } + + pub fn clear_extensions(mut self) -> Self { + self.headers.remove::(); + self + } + + pub fn key(mut self, key: [u8; 16]) -> Self { + self.headers.set(WebSocketKey(key)); + self.key_set = true; + self + } + + pub fn clear_key(mut self) -> Self { + self.headers.remove::(); + self.key_set = false; + self + } + + pub fn version(mut self, version: WebSocketVersion) -> Self { + self.headers.set(version); + self.version_set = true; + self + } + + pub fn clear_version(mut self) -> Self { + self.headers.remove::(); + self.version_set = false; + self + } + + pub fn origin(mut self, origin: String) -> Self { + self.headers.set(Origin(origin)); + self + } + + pub fn custom_headers(mut self, edit: F) -> Self + where F: Fn(&mut Headers), + { + edit(&mut self.headers); + self + } + + pub fn ssl_context(mut self, context: &'s SslContext) -> Self { + self.ssl_context = Some(Cow::Borrowed(context)); + self + } + + fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { + let port = match (self.url.port(), secure) { + (Some(port), _) => port, + (None, None) if self.url.scheme() == "wss" => 443, + (None, None) => 80, + (None, Some(true)) => 443, + (None, Some(false)) => 80, + }; + let host = match self.url.host_str() { + Some(h) => h, + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + + let tcp_stream = try!(TcpStream::connect((host, port))); + Ok(tcp_stream) + } + + fn wrap_ssl(&self, tcp_stream: TcpStream) -> Result, SslError> { + let context = match self.ssl_context { + Some(ref ctx) => Cow::Borrowed(ctx.as_ref()), + None => Cow::Owned(try!(SslContext::new(SslMethod::Tlsv1))), + }; + + SslStream::connect(&*context, tcp_stream) + } + + pub fn connect(&mut self) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(None)); + + let boxed_stream = if self.url.scheme() == "wss" { + BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream)))) + } else { + BoxedNetworkStream(Box::new(tcp_stream)) + }; + + self.connect_on(boxed_stream) + } + + pub fn connect_insecure(&mut self) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(Some(false))); + + self.connect_on(tcp_stream) + } + + pub fn connect_secure(&mut self) -> WebSocketResult>> { + let tcp_stream = try!(self.establish_tcp(Some(true))); + + let ssl_stream = try!(self.wrap_ssl(tcp_stream)); + + self.connect_on(ssl_stream) + } + + // TODO: refactor and split apart into two parts, for when evented happens + pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> + where S: Stream, + { + let resource = self.url[Position::BeforePath..Position::AfterQuery] + .to_owned(); + + // enter host if available (unix sockets don't have hosts) + if let Some(host) = self.url.host_str() { + self.headers.set(Host { + hostname: host.to_string(), + port: self.url.port(), + }); + } + + self.headers.set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); + + self.headers.set(Upgrade(vec![Protocol { + name: ProtocolName::WebSocket, + // TODO: actually correct or just works? + version: None + }])); + + if !self.version_set { + self.headers.set(WebSocketVersion::WebSocket13); + } + + if !self.key_set { + self.headers.set(WebSocketKey::new()); + } + + // send request + try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); + try!(write!(stream.writer(), "{}\r\n", self.headers)); + + // wait for a response + // TODO: we should buffer it all, how to set up stream for this? + let response = try!(parse_response(&mut BufReader::new(stream.reader()))); + let status = StatusCode::from_u16(response.subject.0); + + // validate + if status != StatusCode::SwitchingProtocols { + return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); + } + + let key = try!(self.headers.get::().ok_or( + WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") + )); + + if response.headers.get() != Some(&(WebSocketAccept::new(key))) { + return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); + } + + if response.headers.get() != Some(&(Upgrade(vec![Protocol { + name: ProtocolName::WebSocket, + version: None + }]))) { + return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); + } + + if self.headers.get() != Some(&(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())), + ]))) { + return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); + } + + Ok(Client::unchecked(stream)) + } +} + diff --git a/src/client/mod.rs b/src/client/mod.rs index 663c2e357f..52f533e18b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -50,7 +50,7 @@ use header::{ WebSocketVersion, WebSocketProtocol, WebSocketExtensions, - Origin + Origin, }; use result::{ WSUrlErrorKind, @@ -72,266 +72,8 @@ use receiver::Receiver; pub use sender::Writer; pub use receiver::Reader; -/// Build clients with a builder-style API -#[derive(Clone, Debug)] -pub struct ClientBuilder<'u, 's> { - url: Cow<'u, Url>, - version: HttpVersion, - headers: Headers, - version_set: bool, - key_set: bool, - ssl_context: Option>, -} - -macro_rules! upsert_header { - ($headers:expr; $header:ty; { - Some($pat:pat) => $some_match:expr, - None => $default:expr - }) => {{ - match $headers.has::<$header>() { - true => { - match $headers.get_mut::<$header>() { - Some($pat) => { $some_match; }, - None => (), - }; - } - false => { - $headers.set($default); - }, - }; - }} -} - -impl<'u, 's> ClientBuilder<'u, 's> { - pub fn new(url: Cow<'u, Url>) -> Self { - ClientBuilder { - url: url, - version: HttpVersion::Http11, - version_set: false, - key_set: false, - ssl_context: None, - headers: Headers::new(), - } - } - - pub fn add_protocol

(mut self, protocol: P) -> Self - where P: Into, - { - upsert_header!(self.headers; WebSocketProtocol; { - Some(protos) => protos.0.push(protocol.into()), - None => WebSocketProtocol(vec![protocol.into()]) - }); - self - } - - pub fn add_protocols(mut self, protocols: I) -> Self - where I: IntoIterator, - S: Into, - { - let mut protocols: Vec = protocols.into_iter() - .map(Into::into).collect(); - - upsert_header!(self.headers; WebSocketProtocol; { - Some(protos) => protos.0.append(&mut protocols), - None => WebSocketProtocol(protocols) - }); - self - } - - pub fn clear_protocols(mut self) -> Self { - self.headers.remove::(); - self - } - - pub fn add_extension(mut self, extension: Extension) -> Self - { - upsert_header!(self.headers; WebSocketExtensions; { - Some(protos) => protos.0.push(extension), - None => WebSocketExtensions(vec![extension]) - }); - self - } - - pub fn add_extensions(mut self, extensions: I) -> Self - where I: IntoIterator, - { - let mut extensions: Vec = extensions.into_iter().collect(); - upsert_header!(self.headers; WebSocketExtensions; { - Some(protos) => protos.0.append(&mut extensions), - None => WebSocketExtensions(extensions) - }); - self - } - - pub fn clear_extensions(mut self) -> Self { - self.headers.remove::(); - self - } - - pub fn key(mut self, key: [u8; 16]) -> Self { - self.headers.set(WebSocketKey(key)); - self.key_set = true; - self - } - - pub fn clear_key(mut self) -> Self { - self.headers.remove::(); - self.key_set = false; - self - } - - pub fn version(mut self, version: WebSocketVersion) -> Self { - self.headers.set(version); - self.version_set = true; - self - } - - pub fn clear_version(mut self) -> Self { - self.headers.remove::(); - self.version_set = false; - self - } - - pub fn origin(mut self, origin: String) -> Self { - self.headers.set(Origin(origin)); - self - } - - pub fn custom_headers(mut self, edit: F) -> Self - where F: Fn(&mut Headers), - { - edit(&mut self.headers); - self - } - - pub fn ssl_context(mut self, context: &'s SslContext) -> Self { - self.ssl_context = Some(Cow::Borrowed(context)); - self - } - - fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { - let port = match (self.url.port(), secure) { - (Some(port), _) => port, - (None, None) if self.url.scheme() == "wss" => 443, - (None, None) => 80, - (None, Some(true)) => 443, - (None, Some(false)) => 80, - }; - let host = match self.url.host_str() { - Some(h) => h, - None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), - }; - - let tcp_stream = try!(TcpStream::connect((host, port))); - Ok(tcp_stream) - } - - fn wrap_ssl(&self, tcp_stream: TcpStream) -> Result, SslError> { - let context = match self.ssl_context { - Some(ref ctx) => Cow::Borrowed(ctx.as_ref()), - None => Cow::Owned(try!(SslContext::new(SslMethod::Tlsv1))), - }; - - SslStream::connect(&*context, tcp_stream) - } - - pub fn connect(&mut self) -> WebSocketResult> { - let tcp_stream = try!(self.establish_tcp(None)); - - let boxed_stream = if self.url.scheme() == "wss" { - BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream)))) - } else { - BoxedNetworkStream(Box::new(tcp_stream)) - }; - - self.connect_on(boxed_stream) - } - - pub fn connect_insecure(&mut self) -> WebSocketResult> { - let tcp_stream = try!(self.establish_tcp(Some(false))); - - self.connect_on(tcp_stream) - } - - pub fn connect_secure(&mut self) -> WebSocketResult>> { - let tcp_stream = try!(self.establish_tcp(Some(true))); - - let ssl_stream = try!(self.wrap_ssl(tcp_stream)); - - self.connect_on(ssl_stream) - } - - // TODO: refactor and split apart into two parts, for when evented happens - pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> - where S: Stream, - { - let resource = self.url[Position::BeforePath..Position::AfterQuery] - .to_owned(); - - // enter host if available (unix sockets don't have hosts) - if let Some(host) = self.url.host_str() { - self.headers.set(Host { - hostname: host.to_string(), - port: self.url.port(), - }); - } - - self.headers.set(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) - ])); - - self.headers.set(Upgrade(vec![Protocol { - name: ProtocolName::WebSocket, - // TODO: actually correct or just works? - version: None - }])); - - if !self.version_set { - self.headers.set(WebSocketVersion::WebSocket13); - } - - if !self.key_set { - self.headers.set(WebSocketKey::new()); - } - - // send request - try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); - try!(write!(stream.writer(), "{}\r\n", self.headers)); - - // wait for a response - // TODO: we should buffer it all, how to set up stream for this? - let response = try!(parse_response(&mut BufReader::new(stream.reader()))); - let status = StatusCode::from_u16(response.subject.0); - - // validate - if status != StatusCode::SwitchingProtocols { - return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); - } - - let key = try!(self.headers.get::().ok_or( - WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") - )); - - if response.headers.get() != Some(&(WebSocketAccept::new(key))) { - return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); - } - - if response.headers.get() != Some(&(Upgrade(vec![Protocol { - name: ProtocolName::WebSocket, - version: None - }]))) { - return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); - } - - if self.headers.get() != Some(&(Connection(vec![ - ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())), - ]))) { - return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); - } - - Ok(Client::new(stream)) - } -} +pub mod builder; +pub use self::builder::ClientBuilder; /// Represents a WebSocket client, which can send and receive messages/data frames. /// @@ -411,10 +153,11 @@ impl<'u, 'p, 'e, 's, S> Client Ok(ClientBuilder::new(Cow::Owned(url))) } - /// Creates a Client from the given Sender and Receiver. - /// - /// Esstiallthe opposite of `Client.split()`. - fn new(stream: S) -> Self { + /// Creates a Client from a given stream + /// **without sending any handshake** this is meant to only be used with + /// a stream that has a websocket connection already set up. + /// If in doubt, don't use this! + pub fn unchecked(stream: S) -> Self { Client { stream: stream, // TODO: always true? From 6a28b31b1bd6ffc87cf96d0abbb26af7380f637e Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 14:11:59 -0400 Subject: [PATCH 19/32] implemented intows --- src/server/upgrade/mod.rs | 85 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 80 insertions(+), 5 deletions(-) diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 9b6c270ef2..abb49ac890 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -5,6 +5,11 @@ extern crate hyper as real_hyper; use std::error::Error; use std::net::TcpStream; use std::io; +use std::io::Result as IoResult; +use std::io::Error as IoError; +use std::io::{ + Write, +}; use std::fmt::{ Formatter, Display, @@ -14,11 +19,19 @@ use stream::{ Stream, AsTcpStream, }; +use header::extensions::Extension; use header::{ + WebSocketAccept, WebSocketKey, WebSocketVersion, + WebSocketProtocol, + WebSocketExtensions, + Origin, }; +use client::Client; +use unicase::UniCase; +use self::real_hyper::status::StatusCode; pub use self::real_hyper::http::h1::Incoming; pub use self::real_hyper::method::Method; pub use self::real_hyper::version::HttpVersion; @@ -28,6 +41,7 @@ pub use self::real_hyper::http::h1::parse_request; pub use self::real_hyper::header::{ Headers, Upgrade, + Protocol, ProtocolName, Connection, ConnectionOption, @@ -51,17 +65,78 @@ pub struct WsUpgrade impl WsUpgrade where S: Stream, { - pub fn accept(self) { - unimplemented!(); + pub fn accept(self) -> IoResult> { + self.accept_with(&Headers::new()) } - pub fn reject(self) -> S { - unimplemented!(); + pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult> { + let mut headers = Headers::new(); + headers.extend(custom_headers.iter()); + headers.set(WebSocketAccept::new( + // NOTE: we know there is a key because this is a valid request + // i.e. to construct this you must go through the validate function + self.request.headers.get::().unwrap() + )); + headers.set(Connection(vec![ + ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) + ])); + headers.set(Upgrade(vec![ + // TODO: really not set a version for this? + Protocol::new(ProtocolName::WebSocket, None) + ])); + + try!(self.send(StatusCode::SwitchingProtocols, &headers)); + + Ok(Client::unchecked(self.stream)) } + pub fn reject(self) -> Result { + self.reject_with(&Headers::new()) + } + + pub fn reject_with(mut self, headers: &Headers) -> Result { + match self.send(StatusCode::BadRequest, headers) { + Ok(()) => Ok(self.stream), + Err(e) => Err((self.stream, e)) + } + } + + pub fn drop(self) { + ::std::mem::drop(self); + } + + pub fn protocols(&self) -> Option<&[String]> { + self.request.headers.get::().map(|p| p.0.as_slice()) + } + + pub fn extensions(&self) -> Option<&[Extension]> { + self.request.headers.get::().map(|e| e.0.as_slice()) + } + + pub fn key(&self) -> Option<&[u8; 16]> { + self.request.headers.get::().map(|k| &k.0) + } + + pub fn version(&self) -> Option<&WebSocketVersion> { + self.request.headers.get::() + } + + pub fn origin(&self) -> Option<&str> { + self.request.headers.get::().map(|o| &o.0 as &str) + } + pub fn into_stream(self) -> S { - unimplemented!(); + self.stream } + + fn send(&mut self, + status: StatusCode, + headers: &Headers + ) -> IoResult<()> { + try!(write!(self.stream.writer(), "{} {}\r\n", self.request.version, status)); + try!(write!(self.stream.writer(), "{}\r\n", headers)); + Ok(()) + } } impl WsUpgrade From 7fed8bbc7052f49de7c8fb4b06fa5da3bd513d57 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 14:54:41 -0400 Subject: [PATCH 20/32] added TCP settings functions and verified questions about ws protocol --- ROADMAP.md | 5 +++++ src/client/builder.rs | 1 - src/client/mod.rs | 39 +++++++++++++++++++++++++++++++++------ src/receiver.rs | 14 +++++++------- src/sender.rs | 10 +++++----- src/server/upgrade/mod.rs | 1 - 6 files changed, 50 insertions(+), 20 deletions(-) create mode 100644 ROADMAP.md diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000000..1bf0284e08 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,5 @@ + +### Adding Features + + - Make the usage of `net2` a feature + - Make evented diff --git a/src/client/builder.rs b/src/client/builder.rs index 12d1676508..63d1be71a2 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -262,7 +262,6 @@ impl<'u, 's> ClientBuilder<'u, 's> { self.headers.set(Upgrade(vec![Protocol { name: ProtocolName::WebSocket, - // TODO: actually correct or just works? version: None }])); diff --git a/src/client/mod.rs b/src/client/mod.rs index 52f533e18b..0071d80de0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,10 +1,12 @@ //! Contains the WebSocket client. extern crate url; +extern crate net2; use std::borrow::{ Cow, }; use std::net::TcpStream; +use std::net::SocketAddr; use std::io::Result as IoResult; use std::io::{ Write, @@ -35,6 +37,7 @@ use hyper::header::{ ProtocolName, }; use unicase::UniCase; +use self::net2::TcpStreamExt; use ws; use ws::sender::Sender as SenderTrait; @@ -111,7 +114,7 @@ pub use self::builder::ClientBuilder; pub struct Client where S: Stream, { - stream: S, + pub stream: S, sender: Sender, receiver: Receiver, } @@ -130,7 +133,6 @@ impl Client { } } -// TODO: add net2 set_nonblocking and stuff impl Client where S: AsTcpStream + Stream, { @@ -139,6 +141,31 @@ impl Client pub fn shutdown(&self) -> IoResult<()> { self.stream.as_tcp().shutdown(Shutdown::Both) } + + /// See `TcpStream.peer_addr()`. + pub fn peer_addr(&self) -> IoResult { + self.stream.as_tcp().peer_addr() + } + + /// See `TcpStream.local_addr()`. + pub fn local_addr(&self) -> IoResult { + self.stream.as_tcp().local_addr() + } + + /// See `TcpStream.set_nodelay()`. + pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { + self.stream.as_tcp().set_nodelay(nodelay) + } + + /// See `TcpStream.set_keepalive()`. + pub fn set_keepalive(&mut self, delay_in_ms: Option) -> IoResult<()> { + TcpStreamExt::set_keepalive_ms(self.stream.as_tcp(), delay_in_ms) + } + + /// Changes whether the stream is in nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { + self.stream.as_tcp().set_nonblocking(nonblocking) + } } impl<'u, 'p, 'e, 's, S> Client @@ -160,9 +187,9 @@ impl<'u, 'p, 'e, 's, S> Client pub fn unchecked(stream: S) -> Self { Client { stream: stream, - // TODO: always true? + // NOTE: these are always true & false, see + // https://tools.ietf.org/html/rfc6455#section-5 sender: Sender::new(true), - // TODO: always false? receiver: Receiver::new(false), } } @@ -293,10 +320,10 @@ impl Client pub fn split(self) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { let (read, write) = try!(self.stream.split()); Ok((Reader { - reader: read, + stream: read, receiver: self.receiver, }, Writer { - writer: write, + stream: write, sender: self.sender, })) } diff --git a/src/receiver.rs b/src/receiver.rs index 951c6f30e5..85b7f09cea 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -28,7 +28,7 @@ pub use stream::Shutdown; pub struct Reader where R: Read { - pub reader: R, + pub stream: R, pub receiver: Receiver, } @@ -37,12 +37,12 @@ impl Reader { /// Reads a single data frame from the remote endpoint. pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe(&mut self.reader) + self.receiver.recv_dataframe(&mut self.stream) } /// Returns an iterator over incoming data frames. pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> { - self.receiver.incoming_dataframes(&mut self.reader) + self.receiver.incoming_dataframes(&mut self.stream) } /// Reads a single message from this receiver. @@ -50,14 +50,14 @@ impl Reader where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, I: Iterator { - self.receiver.recv_message(&mut self.reader) + self.receiver.recv_message(&mut self.stream) } pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R> where M: ws::Message<'a, D>, D: DataFrameable { - self.receiver.incoming_messages(&mut self.reader) + self.receiver.incoming_messages(&mut self.stream) } } @@ -67,13 +67,13 @@ impl Reader /// Closes the receiver side of the connection, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown(&self) -> IoResult<()> { - self.reader.as_tcp().shutdown(Shutdown::Read) + self.stream.as_tcp().shutdown(Shutdown::Read) } /// Shuts down both Sender and Receiver, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown_all(&self) -> IoResult<()> { - self.reader.as_tcp().shutdown(Shutdown::Both) + self.stream.as_tcp().shutdown(Shutdown::Both) } } diff --git a/src/sender.rs b/src/sender.rs index 5e78f6c68e..ed0ebefca0 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -10,7 +10,7 @@ use ws::sender::Sender as SenderTrait; pub use stream::Shutdown; pub struct Writer { - pub writer: W, + pub stream: W, pub sender: Sender, } @@ -22,7 +22,7 @@ impl Writer where D: DataFrame, W: Write, { - self.sender.send_dataframe(&mut self.writer, dataframe) + self.sender.send_dataframe(&mut self.stream, dataframe) } /// Sends a single message to the remote endpoint. @@ -30,7 +30,7 @@ impl Writer where M: ws::Message<'m, D>, D: DataFrame { - self.sender.send_message(&mut self.writer, message) + self.sender.send_message(&mut self.stream, message) } } @@ -40,13 +40,13 @@ impl Writer /// Closes the sender side of the connection, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown(&self) -> IoResult<()> { - self.writer.as_tcp().shutdown(Shutdown::Write) + self.stream.as_tcp().shutdown(Shutdown::Write) } /// Shuts down both Sender and Receiver, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown_all(&self) -> IoResult<()> { - self.writer.as_tcp().shutdown(Shutdown::Both) + self.stream.as_tcp().shutdown(Shutdown::Both) } } diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index abb49ac890..3f3680f2e2 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -81,7 +81,6 @@ impl WsUpgrade ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) ])); headers.set(Upgrade(vec![ - // TODO: really not set a version for this? Protocol::new(ProtocolName::WebSocket, None) ])); From d6fb9fe84c2ba7f7555dab8382f3c4e08775eabc Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 19:21:58 -0400 Subject: [PATCH 21/32] updated all dependencies to the latest and greatest --- Cargo.toml | 20 ++- src/client/builder.rs | 47 ++++--- src/client/mod.rs | 14 +- src/header/accept.rs | 9 +- src/result.rs | 153 ++++++++++++---------- src/server/mod.rs | 264 ++++++++++++++++++++------------------ src/server/upgrade/mod.rs | 4 +- src/stream.rs | 11 -- 8 files changed, 267 insertions(+), 255 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4cf8ea39c8..ac71b6d32b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "websocket" version = "0.17.2" -authors = ["cyderize "] +authors = ["cyderize ", "Michael Eden "] description = "A WebSocket (RFC6455) library for Rust." @@ -17,16 +17,14 @@ keywords = ["websocket", "websockets", "rfc6455"] license = "MIT" [dependencies] -hyper = ">=0.7, <0.10" -mio = "0.5.1" -unicase = "1.0.1" -openssl = "0.7.6" -url = "1.0" -rustc-serialize = "0.3.16" -bitflags = "0.7" -rand = "0.3.12" -byteorder = "1.0" -net2 = "0.2.17" +hyper = "^0.10" +unicase = "^1.0" +openssl = "^0.9.10" +url = "^1.0" +rustc-serialize = "^0.3" +bitflags = "^0.8" +rand = "^0.3" +byteorder = "^1.0" [features] nightly = ["hyper/nightly"] diff --git a/src/client/builder.rs b/src/client/builder.rs index 63d1be71a2..d361310ecb 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -23,11 +23,13 @@ use hyper::header::{ ProtocolName, }; use unicase::UniCase; -use openssl::ssl::error::SslError; +use openssl::error::ErrorStack as SslError; use openssl::ssl::{ SslContext, SslMethod, SslStream, + SslConnector, + SslConnectorBuilder, }; use header::extensions::Extension; use header::{ @@ -73,23 +75,21 @@ macro_rules! upsert_header { /// Build clients with a builder-style API #[derive(Clone, Debug)] -pub struct ClientBuilder<'u, 's> { +pub struct ClientBuilder<'u> { url: Cow<'u, Url>, version: HttpVersion, headers: Headers, version_set: bool, key_set: bool, - ssl_context: Option>, } -impl<'u, 's> ClientBuilder<'u, 's> { +impl<'u> ClientBuilder<'u> { pub fn new(url: Cow<'u, Url>) -> Self { ClientBuilder { url: url, version: HttpVersion::Http11, version_set: false, key_set: false, - ssl_context: None, headers: Headers::new(), } } @@ -184,11 +184,6 @@ impl<'u, 's> ClientBuilder<'u, 's> { self } - pub fn ssl_context(mut self, context: &'s SslContext) -> Self { - self.ssl_context = Some(Cow::Borrowed(context)); - self - } - fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { let port = match (self.url.port(), secure) { (Some(port), _) => port, @@ -206,20 +201,30 @@ impl<'u, 's> ClientBuilder<'u, 's> { Ok(tcp_stream) } - fn wrap_ssl(&self, tcp_stream: TcpStream) -> Result, SslError> { - let context = match self.ssl_context { - Some(ref ctx) => Cow::Borrowed(ctx.as_ref()), - None => Cow::Owned(try!(SslContext::new(SslMethod::Tlsv1))), + fn wrap_ssl(&self, + tcp_stream: TcpStream, + connector: Option + ) -> WebSocketResult> { + let host = match self.url.host_str() { + Some(h) => h, + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + let connector = match connector { + Some(c) => c, + None => try!(SslConnectorBuilder::new(SslMethod::tls())).build(), }; - SslStream::connect(&*context, tcp_stream) + let ssl_stream = try!(connector.connect(host, tcp_stream)); + Ok(ssl_stream) } - pub fn connect(&mut self) -> WebSocketResult> { + pub fn connect(&mut self, + ssl_config: Option + ) -> WebSocketResult> { let tcp_stream = try!(self.establish_tcp(None)); let boxed_stream = if self.url.scheme() == "wss" { - BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream)))) + BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream, ssl_config)))) } else { BoxedNetworkStream(Box::new(tcp_stream)) }; @@ -233,10 +238,12 @@ impl<'u, 's> ClientBuilder<'u, 's> { self.connect_on(tcp_stream) } - pub fn connect_secure(&mut self) -> WebSocketResult>> { + pub fn connect_secure(&mut self, + ssl_config: Option + ) -> WebSocketResult>> { let tcp_stream = try!(self.establish_tcp(Some(true))); - let ssl_stream = try!(self.wrap_ssl(tcp_stream)); + let ssl_stream = try!(self.wrap_ssl(tcp_stream, ssl_config)); self.connect_on(ssl_stream) } @@ -291,7 +298,7 @@ impl<'u, 's> ClientBuilder<'u, 's> { WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") )); - if response.headers.get() != Some(&(WebSocketAccept::new(key))) { + if response.headers.get() != Some(&(try!(WebSocketAccept::new(key)))) { return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 0071d80de0..017432995b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,6 +1,5 @@ //! Contains the WebSocket client. extern crate url; -extern crate net2; use std::borrow::{ Cow, @@ -22,7 +21,6 @@ use openssl::ssl::{ SslMethod, SslStream, }; -use openssl::ssl::error::SslError; use hyper::buffer::BufReader; use hyper::status::StatusCode; use hyper::http::h1::parse_response; @@ -37,7 +35,6 @@ use hyper::header::{ ProtocolName, }; use unicase::UniCase; -use self::net2::TcpStreamExt; use ws; use ws::sender::Sender as SenderTrait; @@ -157,25 +154,20 @@ impl Client self.stream.as_tcp().set_nodelay(nodelay) } - /// See `TcpStream.set_keepalive()`. - pub fn set_keepalive(&mut self, delay_in_ms: Option) -> IoResult<()> { - TcpStreamExt::set_keepalive_ms(self.stream.as_tcp(), delay_in_ms) - } - /// Changes whether the stream is in nonblocking mode. pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { self.stream.as_tcp().set_nonblocking(nonblocking) } } -impl<'u, 'p, 'e, 's, S> Client +impl<'u, S> Client where S: Stream, { - pub fn from_url(address: &'u Url) -> ClientBuilder<'u, 's> { + pub fn from_url(address: &'u Url) -> ClientBuilder<'u> { ClientBuilder::new(Cow::Borrowed(address)) } - pub fn build(address: &str) -> Result, ParseError> { + pub fn build(address: &str) -> Result, ParseError> { let url = try!(Url::parse(address)); Ok(ClientBuilder::new(Cow::Owned(url))) } diff --git a/src/header/accept.rs b/src/header/accept.rs index 8bc0ddab15..c6edc8d130 100644 --- a/src/header/accept.rs +++ b/src/header/accept.rs @@ -5,7 +5,8 @@ use std::fmt::{self, Debug}; use std::str::FromStr; use serialize::base64::{ToBase64, FromBase64, STANDARD}; use header::WebSocketKey; -use openssl::crypto::hash::{self, hash}; +use openssl::hash::{self, hash}; +use openssl::error::ErrorStack as SslError; use result::{WebSocketResult, WebSocketError}; static MAGIC_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -49,18 +50,18 @@ impl FromStr for WebSocketAccept { impl WebSocketAccept { /// Create a new WebSocketAccept from the given WebSocketKey - pub fn new(key: &WebSocketKey) -> WebSocketAccept { + pub fn new(key: &WebSocketKey) -> Result { let serialized = key.serialize(); let mut concat_key = String::with_capacity(serialized.len() + 36); concat_key.push_str(&serialized[..]); concat_key.push_str(MAGIC_GUID); - let output = hash(hash::Type::SHA1, concat_key.as_bytes()); + let output = try!(hash(hash::MessageDigest::sha1(), concat_key.as_bytes())); let mut iter = output.into_iter(); let mut bytes = [0u8; 20]; for i in bytes.iter_mut() { *i = iter.next().unwrap(); } - WebSocketAccept(bytes) + Ok(WebSocketAccept(bytes)) } /// Return the Base64 encoding of this WebSocketAccept pub fn serialize(&self) -> String { diff --git a/src/result.rs b/src/result.rs index c7eb9d1943..449709bf63 100644 --- a/src/result.rs +++ b/src/result.rs @@ -5,7 +5,8 @@ use std::str::Utf8Error; use std::error::Error; use std::convert::From; use std::fmt; -use openssl::ssl::error::SslError; +use openssl::error::ErrorStack as SslError; +use openssl::ssl::HandshakeError as SslHandshakeError; use hyper::Error as HttpError; use url::ParseError; @@ -15,99 +16,115 @@ pub type WebSocketResult = Result; /// Represents a WebSocket error #[derive(Debug)] pub enum WebSocketError { - /// A WebSocket protocol error - ProtocolError(&'static str), - /// Invalid WebSocket request error - RequestError(&'static str), - /// Invalid WebSocket response error - ResponseError(&'static str), - /// Invalid WebSocket data frame error - DataFrameError(&'static str), - /// No data available - NoDataAvailable, - /// An input/output error - IoError(io::Error), - /// An HTTP parsing error - HttpError(HttpError), - /// A URL parsing error - UrlError(ParseError), + /// A WebSocket protocol error + ProtocolError(&'static str), + /// Invalid WebSocket request error + RequestError(&'static str), + /// Invalid WebSocket response error + ResponseError(&'static str), + /// Invalid WebSocket data frame error + DataFrameError(&'static str), + /// No data available + NoDataAvailable, + /// An input/output error + IoError(io::Error), + /// An HTTP parsing error + HttpError(HttpError), + /// A URL parsing error + UrlError(ParseError), /// A WebSocket URL error WebSocketUrlError(WSUrlErrorKind), - /// An SSL error - SslError(SslError), - /// A UTF-8 error - Utf8Error(Utf8Error), + /// An SSL error + SslError(SslError), + /// an ssl handshake failure + SslHandshakeFailure, + /// an ssl handshake interruption + SslHandshakeInterruption, + /// A UTF-8 error + Utf8Error(Utf8Error), } impl fmt::Display for WebSocketError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - try!(fmt.write_str("WebSocketError: ")); - try!(fmt.write_str(self.description())); - Ok(()) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + try!(fmt.write_str("WebSocketError: ")); + try!(fmt.write_str(self.description())); + Ok(()) + } } impl Error for WebSocketError { - fn description(&self) -> &str { - match *self { + fn description(&self) -> &str { + match *self { WebSocketError::ProtocolError(_) => "WebSocket protocol error", - WebSocketError::RequestError(_) => "WebSocket request error", - WebSocketError::ResponseError(_) => "WebSocket response error", - WebSocketError::DataFrameError(_) => "WebSocket data frame error", - WebSocketError::NoDataAvailable => "No data available", - WebSocketError::IoError(_) => "I/O failure", - WebSocketError::HttpError(_) => "HTTP failure", - WebSocketError::UrlError(_) => "URL failure", - WebSocketError::SslError(_) => "SSL failure", - WebSocketError::Utf8Error(_) => "UTF-8 failure", + WebSocketError::RequestError(_) => "WebSocket request error", + WebSocketError::ResponseError(_) => "WebSocket response error", + WebSocketError::DataFrameError(_) => "WebSocket data frame error", + WebSocketError::NoDataAvailable => "No data available", + WebSocketError::IoError(_) => "I/O failure", + WebSocketError::HttpError(_) => "HTTP failure", + WebSocketError::UrlError(_) => "URL failure", + WebSocketError::SslError(_) => "SSL failure", + WebSocketError::SslHandshakeFailure => "SSL Handshake failure", + WebSocketError::SslHandshakeInterruption => "SSL Handshake interrupted", + WebSocketError::Utf8Error(_) => "UTF-8 failure", WebSocketError::WebSocketUrlError(_) => "WebSocket URL failure", - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - WebSocketError::IoError(ref error) => Some(error), - WebSocketError::HttpError(ref error) => Some(error), - WebSocketError::UrlError(ref error) => Some(error), - WebSocketError::SslError(ref error) => Some(error), - WebSocketError::Utf8Error(ref error) => Some(error), + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + WebSocketError::IoError(ref error) => Some(error), + WebSocketError::HttpError(ref error) => Some(error), + WebSocketError::UrlError(ref error) => Some(error), + WebSocketError::SslError(ref error) => Some(error), + WebSocketError::Utf8Error(ref error) => Some(error), WebSocketError::WebSocketUrlError(ref error) => Some(error), - _ => None, - } - } + _ => None, + } + } } impl From for WebSocketError { - fn from(err: io::Error) -> WebSocketError { - if err.kind() == io::ErrorKind::UnexpectedEof { - return WebSocketError::NoDataAvailable; - } - WebSocketError::IoError(err) - } + fn from(err: io::Error) -> WebSocketError { + if err.kind() == io::ErrorKind::UnexpectedEof { + return WebSocketError::NoDataAvailable; + } + WebSocketError::IoError(err) + } } impl From for WebSocketError { - fn from(err: HttpError) -> WebSocketError { - WebSocketError::HttpError(err) - } + fn from(err: HttpError) -> WebSocketError { + WebSocketError::HttpError(err) + } } impl From for WebSocketError { - fn from(err: ParseError) -> WebSocketError { - WebSocketError::UrlError(err) - } + fn from(err: ParseError) -> WebSocketError { + WebSocketError::UrlError(err) + } } impl From for WebSocketError { - fn from(err: SslError) -> WebSocketError { - WebSocketError::SslError(err) - } + fn from(err: SslError) -> WebSocketError { + WebSocketError::SslError(err) + } +} + +impl From> for WebSocketError { + fn from(err: SslHandshakeError) -> WebSocketError { + match err { + SslHandshakeError::SetupFailure(err) => WebSocketError::SslError(err), + SslHandshakeError::Failure(_) => WebSocketError::SslHandshakeFailure, + SslHandshakeError::Interrupted(_) => WebSocketError::SslHandshakeInterruption, + } + } } impl From for WebSocketError { - fn from(err: Utf8Error) -> WebSocketError { - WebSocketError::Utf8Error(err) - } + fn from(err: Utf8Error) -> WebSocketError { + WebSocketError::Utf8Error(err) + } } impl From for WebSocketError { diff --git a/src/server/mod.rs b/src/server/mod.rs index acdd159301..8f33fb9ca5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,50 +1,60 @@ //! Provides an implementation of a WebSocket server use std::net::{ - SocketAddr, - ToSocketAddrs, - TcpListener, - TcpStream, - Shutdown, + SocketAddr, + ToSocketAddrs, + TcpListener, + TcpStream, + Shutdown, }; use std::io::{ - self, - Read, - Write, + self, + Read, + Write, }; use std::borrow::Cow; use std::ops::Deref; use std::convert::Into; use openssl::ssl::{ - SslContext, - SslMethod, - SslStream, + SslContext, + SslMethod, + SslStream, + SslAcceptor, }; use stream::{ - Stream, - MaybeSslContext, - NoSslContext, + Stream, }; use self::upgrade::{ - WsUpgrade, - IntoWs, + WsUpgrade, + IntoWs, }; pub use self::upgrade::{ - Request, - HyperIntoWsError, + Request, + HyperIntoWsError, }; pub mod upgrade; pub struct InvalidConnection -where S: Stream, + where S: Stream, { - pub stream: Option, - pub parsed: Option, - pub error: HyperIntoWsError, + pub stream: Option, + pub parsed: Option, + pub error: HyperIntoWsError, } pub type AcceptResult = Result, InvalidConnection>; +/// Marker struct for a struct not being secure +#[derive(Clone)] +pub struct NoSslAcceptor; +/// Trait that is implemented over NoSslAcceptor and SslAcceptor that +/// serves as a generic bound to make a struct with. +/// Used in the Server to specify impls based on wether the server +/// is running over SSL or not. +pub trait OptionalSslAcceptor: Clone {} +impl OptionalSslAcceptor for NoSslAcceptor {} +impl OptionalSslAcceptor for SslAcceptor {} + /// Represents a WebSocket server which can work with either normal (non-secure) connections, or secure WebSocket connections. /// /// This is a convenient way to implement WebSocket servers, however it is possible to use any sendable Reader and Writer to obtain @@ -106,132 +116,130 @@ pub type AcceptResult = Result, InvalidConnection>; ///} /// # } /// ``` -pub struct Server<'s, S> -where S: MaybeSslContext + 's, +pub struct Server + where S: OptionalSslAcceptor, { - inner: TcpListener, - ssl_context: Cow<'s, S>, + pub listener: TcpListener, + ssl_acceptor: S, } -impl<'s, S> Server<'s, S> -where S: MaybeSslContext + 's, +impl Server + where S: OptionalSslAcceptor, { - /// Get the socket address of this server - pub fn local_addr(&self) -> io::Result { - self.inner.local_addr() - } - - /// Create a new independently owned handle to the underlying socket. - pub fn try_clone(&'s self) -> io::Result> { - let inner = try!(self.inner.try_clone()); - Ok(Server { - inner: inner, - ssl_context: Cow::Borrowed(&*self.ssl_context), - }) - } - - pub fn into_owned<'o>(self) -> io::Result> { - Ok(Server { - inner: self.inner, - ssl_context: Cow::Owned(self.ssl_context.into_owned()), - }) - } + /// Get the socket address of this server + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } + + /// Create a new independently owned handle to the underlying socket. + pub fn try_clone(&self) -> io::Result> { + let inner = try!(self.listener.try_clone()); + Ok(Server { + listener: inner, + ssl_acceptor: self.ssl_acceptor.clone(), + }) + } } -impl<'s> Server<'s, SslContext> { - /// Bind this Server to this socket, utilising the given SslContext - pub fn bind_secure(addr: A, context: &'s SslContext) -> io::Result - where A: ToSocketAddrs, - { - Ok(Server { - inner: try!(TcpListener::bind(&addr)), - ssl_context: Cow::Borrowed(context), - }) - } - - /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> AcceptResult> { - let stream = match self.inner.accept() { - Ok(s) => s.0, - Err(e) => return Err(InvalidConnection { - stream: None, - parsed: None, - error: e.into(), - }), - }; - - let stream = match SslStream::accept(&*self.ssl_context, stream) { - Ok(s) => s, - Err(err) => return Err(InvalidConnection { - stream: None, - parsed: None, - error: io::Error::new(io::ErrorKind::Other, err).into(), - }), - }; - - match stream.into_ws() { - Ok(u) => Ok(u), - Err((s, r, e)) => Err(InvalidConnection { - stream: Some(s), - parsed: r, - error: e.into(), - }), - } - } +impl Server { + /// Bind this Server to this socket, utilising the given SslContext + pub fn bind_secure(addr: A, acceptor: Option) -> io::Result + where A: ToSocketAddrs, + { + Ok(Server { + listener: try!(TcpListener::bind(&addr)), + ssl_acceptor: match acceptor { + Some(acc) => acc, + None => { + unimplemented!(); + } + }, + }) + } + + /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest + pub fn accept(&mut self) -> AcceptResult> { + let stream = match self.listener.accept() { + Ok(s) => s.0, + Err(e) => return Err(InvalidConnection { + stream: None, + parsed: None, + error: e.into(), + }), + }; + + let stream = match self.ssl_acceptor.accept(stream) { + Ok(s) => s, + Err(err) => return Err(InvalidConnection { + stream: None, + parsed: None, + error: io::Error::new(io::ErrorKind::Other, err).into(), + }), + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, e)) => Err(InvalidConnection { + stream: Some(s), + parsed: r, + error: e.into(), + }), + } + } /// Changes whether the Server is in nonblocking mode. /// /// If it is in nonblocking mode, accept() will return an error instead of blocking when there /// are no incoming connections. pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.inner.set_nonblocking(nonblocking) + self.listener.set_nonblocking(nonblocking) } } -impl<'s> Iterator for Server<'s, SslContext> { - type Item = WsUpgrade>; +impl Iterator for Server { + type Item = WsUpgrade>; - fn next(&mut self) -> Option<::Item> { - self.accept().ok() - } + fn next(&mut self) -> Option<::Item> { + self.accept().ok() + } } -impl<'s> Server<'s, NoSslContext> { - /// Bind this Server to this socket - pub fn bind(addr: A) -> io::Result { - Ok(Server { - inner: try!(TcpListener::bind(&addr)), - ssl_context: Cow::Owned(NoSslContext), - }) - } - - /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> AcceptResult { - let stream = match self.inner.accept() { - Ok(s) => s.0, - Err(e) => return Err(InvalidConnection { - stream: None, - parsed: None, - error: e.into(), - }), - }; - - match stream.into_ws() { - Ok(u) => Ok(u), - Err((s, r, e)) => Err(InvalidConnection { - stream: Some(s), - parsed: r, - error: e.into(), - }), - } - } +impl Server { + /// Bind this Server to this socket + pub fn bind(addr: A) -> io::Result { + Ok(Server { + listener: try!(TcpListener::bind(&addr)), + ssl_acceptor: NoSslAcceptor, + }) + } + + /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest + pub fn accept(&mut self) -> AcceptResult { + let stream = match self.listener.accept() { + Ok(s) => s.0, + Err(e) => return Err(InvalidConnection { + stream: None, + parsed: None, + error: e.into(), + }), + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, e)) => Err(InvalidConnection { + stream: Some(s), + parsed: r, + error: e.into(), + }), + } + } } -impl<'s> Iterator for Server<'s, NoSslContext> { - type Item = WsUpgrade; +impl Iterator for Server { + type Item = WsUpgrade; - fn next(&mut self) -> Option<::Item> { - self.accept().ok() - } + fn next(&mut self) -> Option<::Item> { + self.accept().ok() + } } diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 3f3680f2e2..6c072a9ad6 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -72,11 +72,11 @@ impl WsUpgrade pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult> { let mut headers = Headers::new(); headers.extend(custom_headers.iter()); - headers.set(WebSocketAccept::new( + headers.set(try!(WebSocketAccept::new( // NOTE: we know there is a key because this is a valid request // i.e. to construct this you must go through the validate function self.request.headers.get::().unwrap() - )); + ))); headers.set(Connection(vec![ ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) ])); diff --git a/src/stream.rs b/src/stream.rs index 1891f99163..c9860e7a13 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -165,14 +165,3 @@ impl AsTcpStream for Box self.deref().as_tcp() } } - -/// Marker struct for having no SSL context in a struct. -#[derive(Clone)] -pub struct NoSslContext; -/// Trait that is implemented over NoSslContext and SslContext that -/// serves as a generic bound to make a struct with. -/// Used in the Server to specify impls based on wether the server -/// is running over SSL or not. -pub trait MaybeSslContext: Clone {} -impl MaybeSslContext for NoSslContext {} -impl MaybeSslContext for SslContext {} From d5864e27a813a3f9371c9b417355d385533ea4ee Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Tue, 28 Mar 2017 20:08:51 -0400 Subject: [PATCH 22/32] made openssl optional behind a feature gate --- Cargo.toml | 5 +- src/client/builder.rs | 14 +-- src/client/mod.rs | 35 ------- src/header/accept.rs | 186 ++++++++++++++++++------------------ src/lib.rs | 4 +- src/result.rs | 16 +++- src/server/mod.rs | 11 +-- src/server/upgrade/hyper.rs | 1 - src/server/upgrade/mod.rs | 4 +- src/stream.rs | 2 + 10 files changed, 127 insertions(+), 151 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ac71b6d32b..6fed692717 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,12 +19,15 @@ license = "MIT" [dependencies] hyper = "^0.10" unicase = "^1.0" -openssl = "^0.9.10" url = "^1.0" rustc-serialize = "^0.3" bitflags = "^0.8" rand = "^0.3" byteorder = "^1.0" +sha1 = "^0.2" +openssl = { version = "^0.9.10", optional = true } [features] +default = ["ssl"] +ssl = ["openssl"] nightly = ["hyper/nightly"] diff --git a/src/client/builder.rs b/src/client/builder.rs index d361310ecb..c626cb8083 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -23,14 +23,15 @@ use hyper::header::{ ProtocolName, }; use unicase::UniCase; -use openssl::error::ErrorStack as SslError; +#[cfg(feature="ssl")] use openssl::ssl::{ - SslContext, SslMethod, SslStream, SslConnector, SslConnectorBuilder, }; +#[cfg(feature="ssl")] +use stream::BoxedNetworkStream; use header::extensions::Extension; use header::{ WebSocketAccept, @@ -46,11 +47,7 @@ use result::{ WebSocketError, }; use stream::{ - BoxedNetworkStream, - AsTcpStream, Stream, - Splittable, - Shutdown, }; use super::Client; @@ -201,6 +198,7 @@ impl<'u> ClientBuilder<'u> { Ok(tcp_stream) } + #[cfg(feature="ssl")] fn wrap_ssl(&self, tcp_stream: TcpStream, connector: Option @@ -218,6 +216,7 @@ impl<'u> ClientBuilder<'u> { Ok(ssl_stream) } + #[cfg(feature="ssl")] pub fn connect(&mut self, ssl_config: Option ) -> WebSocketResult> { @@ -238,6 +237,7 @@ impl<'u> ClientBuilder<'u> { self.connect_on(tcp_stream) } + #[cfg(feature="ssl")] pub fn connect_secure(&mut self, ssl_config: Option ) -> WebSocketResult>> { @@ -298,7 +298,7 @@ impl<'u> ClientBuilder<'u> { WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") )); - if response.headers.get() != Some(&(try!(WebSocketAccept::new(key)))) { + if response.headers.get() != Some(&(WebSocketAccept::new(key))) { return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); } diff --git a/src/client/mod.rs b/src/client/mod.rs index 017432995b..fe28fdd9e3 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -7,34 +7,11 @@ use std::borrow::{ use std::net::TcpStream; use std::net::SocketAddr; use std::io::Result as IoResult; -use std::io::{ - Write, -}; use self::url::{ Url, ParseError, - Position, -}; -use openssl::ssl::{ - SslContext, - SslMethod, - SslStream, }; -use hyper::buffer::BufReader; -use hyper::status::StatusCode; -use hyper::http::h1::parse_response; -use hyper::version::HttpVersion; -use hyper::header::{ - Headers, - Host, - Connection, - ConnectionOption, - Upgrade, - Protocol, - ProtocolName, -}; -use unicase::UniCase; use ws; use ws::sender::Sender as SenderTrait; @@ -43,22 +20,10 @@ use ws::receiver::{ MessageIterator, }; use ws::receiver::Receiver as ReceiverTrait; -use header::extensions::Extension; -use header::{ - WebSocketAccept, - WebSocketKey, - WebSocketVersion, - WebSocketProtocol, - WebSocketExtensions, - Origin, -}; use result::{ - WSUrlErrorKind, WebSocketResult, - WebSocketError, }; use stream::{ - BoxedNetworkStream, AsTcpStream, Stream, Splittable, diff --git a/src/header/accept.rs b/src/header/accept.rs index c6edc8d130..aa40070045 100644 --- a/src/header/accept.rs +++ b/src/header/accept.rs @@ -5,9 +5,8 @@ use std::fmt::{self, Debug}; use std::str::FromStr; use serialize::base64::{ToBase64, FromBase64, STANDARD}; use header::WebSocketKey; -use openssl::hash::{self, hash}; -use openssl::error::ErrorStack as SslError; use result::{WebSocketResult, WebSocketError}; +use sha1::Sha1; static MAGIC_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; @@ -16,115 +15,112 @@ static MAGIC_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; pub struct WebSocketAccept([u8; 20]); impl Debug for WebSocketAccept { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "WebSocketAccept({})", self.serialize()) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "WebSocketAccept({})", self.serialize()) + } } impl FromStr for WebSocketAccept { - type Err = WebSocketError; + type Err = WebSocketError; - fn from_str(accept: &str) -> WebSocketResult { - match accept.from_base64() { - Ok(vec) => { - if vec.len() != 20 { - return Err(WebSocketError::ProtocolError( - "Sec-WebSocket-Accept must be 20 bytes" - )); - } - let mut array = [0u8; 20]; - let mut iter = vec.into_iter(); - for i in array.iter_mut() { - *i = iter.next().unwrap(); - } - Ok(WebSocketAccept(array)) - } - Err(_) => { - return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Accept " - )); - } - } - } + fn from_str(accept: &str) -> WebSocketResult { + match accept.from_base64() { + Ok(vec) => { + if vec.len() != 20 { + return Err(WebSocketError::ProtocolError( + "Sec-WebSocket-Accept must be 20 bytes" + )); + } + let mut array = [0u8; 20]; + let mut iter = vec.into_iter(); + for i in array.iter_mut() { + *i = iter.next().unwrap(); + } + Ok(WebSocketAccept(array)) + } + Err(_) => { + return Err(WebSocketError::ProtocolError( + "Invalid Sec-WebSocket-Accept " + )); + } + } + } } impl WebSocketAccept { - /// Create a new WebSocketAccept from the given WebSocketKey - pub fn new(key: &WebSocketKey) -> Result { - let serialized = key.serialize(); - let mut concat_key = String::with_capacity(serialized.len() + 36); - concat_key.push_str(&serialized[..]); - concat_key.push_str(MAGIC_GUID); - let output = try!(hash(hash::MessageDigest::sha1(), concat_key.as_bytes())); - let mut iter = output.into_iter(); - let mut bytes = [0u8; 20]; - for i in bytes.iter_mut() { - *i = iter.next().unwrap(); - } - Ok(WebSocketAccept(bytes)) - } - /// Return the Base64 encoding of this WebSocketAccept - pub fn serialize(&self) -> String { - let WebSocketAccept(accept) = *self; - accept.to_base64(STANDARD) - } + /// Create a new WebSocketAccept from the given WebSocketKey + pub fn new(key: &WebSocketKey) -> WebSocketAccept { + let serialized = key.serialize(); + let mut concat_key = String::with_capacity(serialized.len() + 36); + concat_key.push_str(&serialized[..]); + concat_key.push_str(MAGIC_GUID); + let mut sha1 = Sha1::new(); + sha1.update(concat_key.as_bytes()); + let bytes = sha1.digest().bytes(); + WebSocketAccept(bytes) + } + /// Return the Base64 encoding of this WebSocketAccept + pub fn serialize(&self) -> String { + let WebSocketAccept(accept) = *self; + accept.to_base64(STANDARD) + } } impl Header for WebSocketAccept { - fn header_name() -> &'static str { - "Sec-WebSocket-Accept" - } + fn header_name() -> &'static str { + "Sec-WebSocket-Accept" + } - fn parse_header(raw: &[Vec]) -> hyper::Result { - from_one_raw_str(raw) - } + fn parse_header(raw: &[Vec]) -> hyper::Result { + from_one_raw_str(raw) + } } impl HeaderFormat for WebSocketAccept { - fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{}", self.serialize()) - } + fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}", self.serialize()) + } } #[cfg(all(feature = "nightly", test))] mod tests { - use super::*; - use test; - use std::str::FromStr; - use header::{Headers, WebSocketKey}; - use hyper::header::{Header, HeaderFormatter}; - #[test] - fn test_header_accept() { - let key = FromStr::from_str("dGhlIHNhbXBsZSBub25jZQ==").unwrap(); - let accept = WebSocketAccept::new(&key); - let mut headers = Headers::new(); - headers.set(accept); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); - } - #[bench] - fn bench_header_accept_new(b: &mut test::Bencher) { - let key = WebSocketKey::new(); - b.iter(|| { - let mut accept = WebSocketAccept::new(&key); - test::black_box(&mut accept); - }); - } - #[bench] - fn bench_header_accept_parse(b: &mut test::Bencher) { - let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; - b.iter(|| { - let mut accept: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut accept); - }); - } - #[bench] - fn bench_header_accept_format(b: &mut test::Bencher) { - let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; - let val: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); - b.iter(|| { - format!("{}", fmt); - }); - } + use super::*; + use test; + use std::str::FromStr; + use header::{Headers, WebSocketKey}; + use hyper::header::{Header, HeaderFormatter}; + #[test] + fn test_header_accept() { + let key = FromStr::from_str("dGhlIHNhbXBsZSBub25jZQ==").unwrap(); + let accept = WebSocketAccept::new(&key); + let mut headers = Headers::new(); + headers.set(accept); + + assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); + } + #[bench] + fn bench_header_accept_new(b: &mut test::Bencher) { + let key = WebSocketKey::new(); + b.iter(|| { + let mut accept = WebSocketAccept::new(&key); + test::black_box(&mut accept); + }); + } + #[bench] + fn bench_header_accept_parse(b: &mut test::Bencher) { + let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; + b.iter(|| { + let mut accept: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut accept); + }); + } + #[bench] + fn bench_header_accept_format(b: &mut test::Bencher) { + let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; + let val: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); + let fmt = HeaderFormatter(&val); + b.iter(|| { + format!("{}", fmt); + }); + } } diff --git a/src/lib.rs b/src/lib.rs index 843eebaeea..8d6c11d75a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,9 +40,11 @@ extern crate hyper; extern crate unicase; extern crate url; extern crate rustc_serialize as serialize; -extern crate openssl; extern crate rand; extern crate byteorder; +extern crate sha1; +#[cfg(feature="ssl")] +extern crate openssl; #[macro_use] extern crate bitflags; diff --git a/src/result.rs b/src/result.rs index 449709bf63..712c49c0b9 100644 --- a/src/result.rs +++ b/src/result.rs @@ -5,11 +5,14 @@ use std::str::Utf8Error; use std::error::Error; use std::convert::From; use std::fmt; -use openssl::error::ErrorStack as SslError; -use openssl::ssl::HandshakeError as SslHandshakeError; use hyper::Error as HttpError; use url::ParseError; +#[cfg(feature="ssl")] +use openssl::error::ErrorStack as SslError; +#[cfg(feature="ssl")] +use openssl::ssl::HandshakeError as SslHandshakeError; + /// The type used for WebSocket results pub type WebSocketResult = Result; @@ -35,10 +38,13 @@ pub enum WebSocketError { /// A WebSocket URL error WebSocketUrlError(WSUrlErrorKind), /// An SSL error + #[cfg(feature="ssl")] SslError(SslError), /// an ssl handshake failure + #[cfg(feature="ssl")] SslHandshakeFailure, /// an ssl handshake interruption + #[cfg(feature="ssl")] SslHandshakeInterruption, /// A UTF-8 error Utf8Error(Utf8Error), @@ -63,8 +69,11 @@ impl Error for WebSocketError { WebSocketError::IoError(_) => "I/O failure", WebSocketError::HttpError(_) => "HTTP failure", WebSocketError::UrlError(_) => "URL failure", + #[cfg(feature="ssl")] WebSocketError::SslError(_) => "SSL failure", + #[cfg(feature="ssl")] WebSocketError::SslHandshakeFailure => "SSL Handshake failure", + #[cfg(feature="ssl")] WebSocketError::SslHandshakeInterruption => "SSL Handshake interrupted", WebSocketError::Utf8Error(_) => "UTF-8 failure", WebSocketError::WebSocketUrlError(_) => "WebSocket URL failure", @@ -76,6 +85,7 @@ impl Error for WebSocketError { WebSocketError::IoError(ref error) => Some(error), WebSocketError::HttpError(ref error) => Some(error), WebSocketError::UrlError(ref error) => Some(error), + #[cfg(feature="ssl")] WebSocketError::SslError(ref error) => Some(error), WebSocketError::Utf8Error(ref error) => Some(error), WebSocketError::WebSocketUrlError(ref error) => Some(error), @@ -105,12 +115,14 @@ impl From for WebSocketError { } } +#[cfg(feature="ssl")] impl From for WebSocketError { fn from(err: SslError) -> WebSocketError { WebSocketError::SslError(err) } } +#[cfg(feature="ssl")] impl From> for WebSocketError { fn from(err: SslHandshakeError) -> WebSocketError { match err { diff --git a/src/server/mod.rs b/src/server/mod.rs index 8f33fb9ca5..1b14c7041e 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -4,19 +4,13 @@ use std::net::{ ToSocketAddrs, TcpListener, TcpStream, - Shutdown, }; use std::io::{ self, - Read, - Write, }; -use std::borrow::Cow; -use std::ops::Deref; use std::convert::Into; +#[cfg(feature="ssl")] use openssl::ssl::{ - SslContext, - SslMethod, SslStream, SslAcceptor, }; @@ -53,6 +47,7 @@ pub struct NoSslAcceptor; /// is running over SSL or not. pub trait OptionalSslAcceptor: Clone {} impl OptionalSslAcceptor for NoSslAcceptor {} +#[cfg(feature="ssl")] impl OptionalSslAcceptor for SslAcceptor {} /// Represents a WebSocket server which can work with either normal (non-secure) connections, or secure WebSocket connections. @@ -141,6 +136,7 @@ impl Server } } +#[cfg(feature="ssl")] impl Server { /// Bind this Server to this socket, utilising the given SslContext pub fn bind_secure(addr: A, acceptor: Option) -> io::Result @@ -196,6 +192,7 @@ impl Server { } } +#[cfg(feature="ssl")] impl Iterator for Server { type Item = WsUpgrade>; diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index 323fc329b2..90b9f3d86a 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -1,5 +1,4 @@ extern crate hyper; -extern crate openssl; use hyper::net::{ NetworkStream, diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 6c072a9ad6..3f3680f2e2 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -72,11 +72,11 @@ impl WsUpgrade pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult> { let mut headers = Headers::new(); headers.extend(custom_headers.iter()); - headers.set(try!(WebSocketAccept::new( + headers.set(WebSocketAccept::new( // NOTE: we know there is a key because this is a valid request // i.e. to construct this you must go through the validate function self.request.headers.get::().unwrap() - ))); + )); headers.set(Connection(vec![ ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) ])); diff --git a/src/stream.rs b/src/stream.rs index c9860e7a13..9c2209ba84 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -10,6 +10,7 @@ use std::io::{ }; pub use std::net::TcpStream; pub use std::net::Shutdown; +#[cfg(feature="ssl")] pub use openssl::ssl::{ SslStream, SslContext, @@ -152,6 +153,7 @@ impl AsTcpStream for TcpStream { } } +#[cfg(feature="ssl")] impl AsTcpStream for SslStream { fn as_tcp(&self) -> &TcpStream { self.get_ref() From 19f298a415e110015eb7a1f221da4a1d417ef916 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 00:33:35 -0400 Subject: [PATCH 23/32] pass all current tests with warnings --- examples/autobahn-client.rs | 272 +++++++++++++++++------------------- examples/autobahn-server.rs | 65 ++++----- examples/client.rs | 242 ++++++++++++++++---------------- examples/hyper.rs | 97 ++++++------- examples/server.rs | 102 +++++++------- src/client/builder.rs | 16 ++- src/client/mod.rs | 69 +++------ src/lib.rs | 5 +- src/server/mod.rs | 45 +++--- src/server/upgrade/mod.rs | 6 +- 10 files changed, 440 insertions(+), 479 deletions(-) diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index 77423c7b47..ce5eb0def9 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -2,156 +2,146 @@ extern crate websocket; extern crate rustc_serialize as serialize; use std::str::from_utf8; -use websocket::client::request::Url; -use websocket::{Client, Message, Sender, Receiver}; +use websocket::ClientBuilder; +use websocket::Message; use websocket::message::Type; use serialize::json; fn main() { - let addr = "ws://127.0.0.1:9001".to_string(); - let agent = "rust-websocket"; - - println!("Using fuzzingserver {}", addr); - println!("Using agent {}", agent); - - println!("Running test suite..."); - - let mut current_case_id = 1; - let case_count = get_case_count(addr.clone()); - - while current_case_id <= case_count { - let url = addr.clone() + "/runCase?case=" + ¤t_case_id.to_string()[..] + "&agent=" + agent; - - let ws_uri = Url::parse(&url[..]).unwrap(); - let request = Client::connect(ws_uri).unwrap(); - let response = request.send().unwrap(); - match response.validate() { - Ok(()) => (), - Err(e) => { - println!("{:?}", e); - current_case_id += 1; - continue; - } - } - let (mut sender, mut receiver) = response.begin().split(); - - println!("Executing test case: {}/{}", current_case_id, case_count); - - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close()); - break; - } - }; - - match message.opcode { - Type::Text => { + let addr = "ws://127.0.0.1:9001".to_string(); + let agent = "rust-websocket"; + + println!("Using fuzzingserver {}", addr); + println!("Using agent {}", agent); + + println!("Running test suite..."); + + let mut current_case_id = 1; + let case_count = get_case_count(addr.clone()); + + while current_case_id <= case_count { + let case_id = current_case_id; + current_case_id += 1; + let url = addr.clone() + "/runCase?case=" + &case_id.to_string()[..] + "&agent=" + agent; + + let client = ClientBuilder::new(&url).unwrap() + .connect_insecure().unwrap(); + + let (mut receiver, mut sender) = client.split().unwrap(); + + println!("Executing test case: {}/{}", case_id, case_count); + + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("Error: {:?}", e); + let _ = sender.send_message(&Message::close()); + break; + } + }; + + match message.opcode { + Type::Text => { let response = Message::text(from_utf8(&*message.payload).unwrap()); - sender.send_message(&response).unwrap(); - } - Type::Binary => { - sender.send_message(&Message::binary(message.payload)).unwrap(); - } - Type::Close => { - let _ = sender.send_message(&Message::close()); - break; - } - Type::Ping => { - sender.send_message(&Message::pong(message.payload)).unwrap(); - } - _ => (), - } - } - - current_case_id += 1; - } - - update_reports(addr.clone(), agent); + sender.send_message(&response).unwrap(); + } + Type::Binary => { + sender.send_message(&Message::binary(message.payload)).unwrap(); + } + Type::Close => { + let _ = sender.send_message(&Message::close()); + break; + } + Type::Ping => { + sender.send_message(&Message::pong(message.payload)).unwrap(); + } + _ => (), + } + } + } + + update_reports(addr.clone(), agent); } fn get_case_count(addr: String) -> usize { - let url = addr + "/getCaseCount"; - let ws_uri = Url::parse(&url[..]).unwrap(); - let request = Client::connect(ws_uri).unwrap(); - let response = request.send().unwrap(); - match response.validate() { - Ok(()) => (), - Err(e) => { - println!("{:?}", e); - return 0; - } - } - let (mut sender, mut receiver) = response.begin().split(); - - let mut count = 0; - - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close_because(1002, "".to_string())); - break; - } - }; - match message.opcode { - Type::Text => { - count = json::decode(from_utf8(&*message.payload).unwrap()).unwrap(); - println!("Will run {} cases...", count); - } - Type::Close => { - let _ = sender.send_message(&Message::close()); - break; - } - Type::Ping => { - sender.send_message(&Message::pong(message.payload)).unwrap(); - } - _ => (), - } - } - - count + let url = addr + "/getCaseCount"; + + let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { + Ok(c) => c, + Err(e) => { + println!("{:?}", e); + return 0; + }, + }; + + let (mut receiver, mut sender) = client.split().unwrap(); + + let mut count = 0; + + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("Error: {:?}", e); + let _ = sender.send_message(&Message::close_because(1002, "".to_string())); + break; + } + }; + match message.opcode { + Type::Text => { + count = json::decode(from_utf8(&*message.payload).unwrap()).unwrap(); + println!("Will run {} cases...", count); + } + Type::Close => { + let _ = sender.send_message(&Message::close()); + break; + } + Type::Ping => { + sender.send_message(&Message::pong(message.payload)).unwrap(); + } + _ => (), + } + } + + count } fn update_reports(addr: String, agent: &str) { - let url = addr + "/updateReports?agent=" + agent; - let ws_uri = Url::parse(&url[..]).unwrap(); - let request = Client::connect(ws_uri).unwrap(); - let response = request.send().unwrap(); - match response.validate() { - Ok(()) => (), - Err(e) => { - println!("{:?}", e); - return; - } - } - let (mut sender, mut receiver) = response.begin().split(); - - println!("Updating reports..."); - - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close()); - return; - } - }; - match message.opcode { - Type::Close => { - let _ = sender.send_message(&Message::close()); - println!("Reports updated."); - println!("Test suite finished!"); - return; - } - Type::Ping => { - sender.send_message(&Message::pong(message.payload)).unwrap(); - } - _ => (), - } - } + let url = addr + "/updateReports?agent=" + agent; + + let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { + Ok(c) => c, + Err(e) => { + println!("{:?}", e); + return; + }, + }; + + let (mut receiver, mut sender) = client.split().unwrap(); + + println!("Updating reports..."); + + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("Error: {:?}", e); + let _ = sender.send_message(&Message::close()); + return; + } + }; + match message.opcode { + Type::Close => { + let _ = sender.send_message(&Message::close()); + println!("Reports updated."); + println!("Test suite finished!"); + return; + } + Type::Ping => { + sender.send_message(&Message::pong(message.payload)).unwrap(); + } + _ => (), + } + } } diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 6e4885a0fa..6e7cb68eab 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -2,48 +2,45 @@ extern crate websocket; use std::thread; use std::str::from_utf8; -use websocket::{Server, Message, Sender, Receiver}; +use websocket::{Server, Message}; use websocket::message::Type; fn main() { - let addr = "127.0.0.1:9002".to_string(); + let server = Server::bind("127.0.0.1:9002").unwrap(); - let server = Server::bind(&addr[..]).unwrap(); + for connection in server { + thread::spawn(move || { + let client = connection.accept().unwrap(); - for connection in server { - thread::spawn(move || { - let request = connection.unwrap().read_request().unwrap(); - request.validate().unwrap(); - let response = request.accept(); - let (mut sender, mut receiver) = response.send().unwrap().split(); + let (mut receiver, mut sender) = client.split().unwrap(); - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("{:?}", e); - let _ = sender.send_message(&Message::close()); - return; - } - }; + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("{:?}", e); + let _ = sender.send_message(&Message::close()); + return; + } + }; - match message.opcode { - Type::Text => { + match message.opcode { + Type::Text => { let response = Message::text(from_utf8(&*message.payload).unwrap()); sender.send_message(&response).unwrap() }, - Type::Binary => sender.send_message(&Message::binary(message.payload)).unwrap(), - Type::Close => { - let _ = sender.send_message(&Message::close()); - return; - } - Type::Ping => { - let message = Message::pong(message.payload); - sender.send_message(&message).unwrap(); - } - _ => (), - } - } - }); - } + Type::Binary => sender.send_message(&Message::binary(message.payload)).unwrap(), + Type::Close => { + let _ = sender.send_message(&Message::close()); + return; + } + Type::Ping => { + let message = Message::pong(message.payload); + sender.send_message(&message).unwrap(); + } + _ => (), + } + } + }); + } } diff --git a/examples/client.rs b/examples/client.rs index 95e63552d3..d5ed6cb9aa 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,130 +1,124 @@ extern crate websocket; +const CONNECTION: &'static str = "ws://127.0.0.1:2794"; + fn main() { - use std::thread; - use std::sync::mpsc::channel; - use std::io::stdin; + use std::thread; + use std::sync::mpsc::channel; + use std::io::stdin; - use websocket::{Message, Sender, Receiver}; + use websocket::Message; use websocket::message::Type; - use websocket::client::request::Url; - use websocket::Client; - - let url = Url::parse("ws://127.0.0.1:2794").unwrap(); - - println!("Connecting to {}", url); - - let request = Client::connect(url).unwrap(); - - let response = request.send().unwrap(); // Send the request and retrieve a response - - println!("Validating response..."); - - response.validate().unwrap(); // Validate the response - - println!("Successfully connected"); - - let (mut sender, mut receiver) = response.begin().split(); - - let (tx, rx) = channel(); - - let tx_1 = tx.clone(); - - let send_loop = thread::spawn(move || { - loop { - // Send loop - let message: Message = match rx.recv() { - Ok(m) => m, - Err(e) => { - println!("Send Loop: {:?}", e); - return; - } - }; - match message.opcode { - Type::Close => { - let _ = sender.send_message(&message); - // If it's a close message, just send it and then return. - return; - }, - _ => (), - } - // Send the message - match sender.send_message(&message) { - Ok(()) => (), - Err(e) => { - println!("Send Loop: {:?}", e); - let _ = sender.send_message(&Message::close()); - return; - } - } - } - }); - - let receive_loop = thread::spawn(move || { - // Receive loop - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(m) => m, - Err(e) => { - println!("Receive Loop: {:?}", e); - let _ = tx_1.send(Message::close()); - return; - } - }; - match message.opcode { - Type::Close => { - // Got a close message, so send a close message and return - let _ = tx_1.send(Message::close()); - return; - } - Type::Ping => match tx_1.send(Message::pong(message.payload)) { - // Send a pong in response - Ok(()) => (), - Err(e) => { - println!("Receive Loop: {:?}", e); - return; - } - }, - // Say what we received - _ => println!("Receive Loop: {:?}", message), - } - } - }); - - loop { - let mut input = String::new(); - - stdin().read_line(&mut input).unwrap(); - - let trimmed = input.trim(); - - let message = match trimmed { - "/close" => { - // Close the connection - let _ = tx.send(Message::close()); - break; - } - // Send a ping - "/ping" => Message::ping(b"PING".to_vec()), - // Otherwise, just send text - _ => Message::text(trimmed.to_string()), - }; - - match tx.send(message) { - Ok(()) => (), - Err(e) => { - println!("Main Loop: {:?}", e); - break; - } - } - } - - // We're exiting - - println!("Waiting for child threads to exit"); - - let _ = send_loop.join(); - let _ = receive_loop.join(); - - println!("Exited"); + use websocket::client::ClientBuilder; + + println!("Connecting to {}", CONNECTION); + + let client = ClientBuilder::new(CONNECTION).unwrap() + .connect_insecure().unwrap(); + + println!("Successfully connected"); + + let (mut receiver, mut sender) = client.split().unwrap(); + + let (tx, rx) = channel(); + + let tx_1 = tx.clone(); + + let send_loop = thread::spawn(move || { + loop { + // Send loop + let message: Message = match rx.recv() { + Ok(m) => m, + Err(e) => { + println!("Send Loop: {:?}", e); + return; + } + }; + match message.opcode { + Type::Close => { + let _ = sender.send_message(&message); + // If it's a close message, just send it and then return. + return; + }, + _ => (), + } + // Send the message + match sender.send_message(&message) { + Ok(()) => (), + Err(e) => { + println!("Send Loop: {:?}", e); + let _ = sender.send_message(&Message::close()); + return; + } + } + } + }); + + let receive_loop = thread::spawn(move || { + // Receive loop + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(m) => m, + Err(e) => { + println!("Receive Loop: {:?}", e); + let _ = tx_1.send(Message::close()); + return; + } + }; + match message.opcode { + Type::Close => { + // Got a close message, so send a close message and return + let _ = tx_1.send(Message::close()); + return; + } + Type::Ping => match tx_1.send(Message::pong(message.payload)) { + // Send a pong in response + Ok(()) => (), + Err(e) => { + println!("Receive Loop: {:?}", e); + return; + } + }, + // Say what we received + _ => println!("Receive Loop: {:?}", message), + } + } + }); + + loop { + let mut input = String::new(); + + stdin().read_line(&mut input).unwrap(); + + let trimmed = input.trim(); + + let message = match trimmed { + "/close" => { + // Close the connection + let _ = tx.send(Message::close()); + break; + } + // Send a ping + "/ping" => Message::ping(b"PING".to_vec()), + // Otherwise, just send text + _ => Message::text(trimmed.to_string()), + }; + + match tx.send(message) { + Ok(()) => (), + Err(e) => { + println!("Main Loop: {:?}", e); + break; + } + } + } + + // We're exiting + + println!("Waiting for child threads to exit"); + + let _ = send_loop.join(); + let _ = receive_loop.join(); + + println!("Exited"); } diff --git a/examples/hyper.rs b/examples/hyper.rs index 640bc1c6a4..f0f731f9bb 100644 --- a/examples/hyper.rs +++ b/examples/hyper.rs @@ -14,70 +14,59 @@ use hyper::server::response::Response; // The HTTP server handler fn http_handler(_: Request, response: Response) { - let mut response = response.start().unwrap(); - // Send a client webpage - response.write_all(b"WebSocket Test

Received Messages:

").unwrap(); - response.end().unwrap(); + let mut response = response.start().unwrap(); + // Send a client webpage + response.write_all(b"WebSocket Test

Received Messages:

").unwrap(); + response.end().unwrap(); } fn main() { - // Start listening for http connections - thread::spawn(move || { - let http_server = HttpServer::http("127.0.0.1:8080").unwrap(); - http_server.handle(http_handler).unwrap(); - }); + // Start listening for http connections + thread::spawn(move || { + let http_server = HttpServer::http("127.0.0.1:8080").unwrap(); + http_server.handle(http_handler).unwrap(); + }); - // Start listening for WebSocket connections - let ws_server = Server::bind("127.0.0.1:2794").unwrap(); + // Start listening for WebSocket connections + let ws_server = Server::bind("127.0.0.1:2794").unwrap(); - for connection in ws_server { - // Spawn a new thread for each connection. - thread::spawn(move || { - let request = connection.unwrap().read_request().unwrap(); // Get the request - let headers = request.headers.clone(); // Keep the headers so we can check them + for connection in ws_server { + // Spawn a new thread for each connection. + thread::spawn(move || { + if !connection.protocols().contains(&"rust-websocket".to_string()) { + connection.reject().unwrap(); + return; + } - request.validate().unwrap(); // Validate the request + // TODO: same check like in server.rs + let mut client = connection.accept().unwrap(); - let mut response = request.accept(); // Form a response + let ip = client.peer_addr().unwrap(); - if let Some(&WebSocketProtocol(ref protocols)) = headers.get() { - if protocols.contains(&("rust-websocket".to_string())) { - // We have a protocol we want to use - response.headers.set(WebSocketProtocol(vec!["rust-websocket".to_string()])); - } - } + println!("Connection from {}", ip); - let mut client = response.send().unwrap(); // Send the response + let message = Message::text("Hello".to_string()); + client.send_message(&message).unwrap(); - let ip = client.get_mut_sender() - .get_mut() - .peer_addr() - .unwrap(); + let (mut receiver, mut sender) = client.split().unwrap(); - println!("Connection from {}", ip); + for message in receiver.incoming_messages() { + let message: Message = message.unwrap(); - let message = Message::text("Hello".to_string()); - client.send_message(&message).unwrap(); - - let (mut sender, mut receiver) = client.split(); - - for message in receiver.incoming_messages() { - let message: Message = message.unwrap(); - - match message.opcode { - Type::Close => { - let message = Message::close(); - sender.send_message(&message).unwrap(); - println!("Client {} disconnected", ip); - return; - }, - Type::Ping => { - let message = Message::pong(message.payload); - sender.send_message(&message).unwrap(); - }, - _ => sender.send_message(&message).unwrap(), - } - } - }); - } + match message.opcode { + Type::Close => { + let message = Message::close(); + sender.send_message(&message).unwrap(); + println!("Client {} disconnected", ip); + return; + }, + Type::Ping => { + let message = Message::pong(message.payload); + sender.send_message(&message).unwrap(); + }, + _ => sender.send_message(&message).unwrap(), + } + } + }); + } } diff --git a/examples/server.rs b/examples/server.rs index 1bfef5c090..87e136e513 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,61 +1,57 @@ extern crate websocket; use std::thread; -use websocket::{Server, Message, Sender, Receiver}; +use websocket::{Server, Message}; use websocket::message::Type; use websocket::header::WebSocketProtocol; +// TODO: I think the .reject() call is only for malformed packets +// there should be an easy way to accept the socket with the given protocols +// this would mean there should be a way to accept or reject on the client +// Do you send the protocol you want to talk when you are not given it as an +// option? What is a rejection response? Does the client check for it? +// Client should expose what the decided protocols/extensions/etc are. +// can you accept only one protocol?? + fn main() { - let server = Server::bind("127.0.0.1:2794").unwrap(); - - for connection in server { - // Spawn a new thread for each connection. - thread::spawn(move || { - let request = connection.unwrap().read_request().unwrap(); // Get the request - let headers = request.headers.clone(); // Keep the headers so we can check them - - request.validate().unwrap(); // Validate the request - - let mut response = request.accept(); // Form a response - - if let Some(&WebSocketProtocol(ref protocols)) = headers.get() { - if protocols.contains(&("rust-websocket".to_string())) { - // We have a protocol we want to use - response.headers.set(WebSocketProtocol(vec!["rust-websocket".to_string()])); - } - } - - let mut client = response.send().unwrap(); // Send the response - - let ip = client.get_mut_sender() - .get_mut() - .peer_addr() - .unwrap(); - - println!("Connection from {}", ip); - - let message: Message = Message::text("Hello".to_string()); - client.send_message(&message).unwrap(); - - let (mut sender, mut receiver) = client.split(); - - for message in receiver.incoming_messages() { - let message: Message = message.unwrap(); - - match message.opcode { - Type::Close => { - let message = Message::close(); - sender.send_message(&message).unwrap(); - println!("Client {} disconnected", ip); - return; - }, - Type::Ping => { - let message = Message::pong(message.payload); - sender.send_message(&message).unwrap(); - } - _ => sender.send_message(&message).unwrap(), - } - } - }); - } + let server = Server::bind("127.0.0.1:2794").unwrap(); + + for request in server { + // Spawn a new thread for each connection. + thread::spawn(move || { + if !request.protocols().contains(&"rust-websocket".to_string()) { + request.reject().unwrap(); + return; + } + + let mut client = request.accept().unwrap(); + + let ip = client.peer_addr().unwrap(); + + println!("Connection from {}", ip); + + let message: Message = Message::text("Hello".to_string()); + client.send_message(&message).unwrap(); + + let (mut receiver, mut sender) = client.split().unwrap(); + + for message in receiver.incoming_messages() { + let message: Message = message.unwrap(); + + match message.opcode { + Type::Close => { + let message = Message::close(); + sender.send_message(&message).unwrap(); + println!("Client {} disconnected", ip); + return; + }, + Type::Ping => { + let message = Message::pong(message.payload); + sender.send_message(&message).unwrap(); + } + _ => sender.send_message(&message).unwrap(), + } + } + }); + } } diff --git a/src/client/builder.rs b/src/client/builder.rs index c626cb8083..ec9317093c 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -5,10 +5,11 @@ use std::io::{ Write, }; use std::net::TcpStream; -use url::{ +pub use url::{ Url, - Position, + ParseError, }; +use url::Position; use hyper::version::HttpVersion; use hyper::status::StatusCode; use hyper::buffer::BufReader; @@ -81,7 +82,16 @@ pub struct ClientBuilder<'u> { } impl<'u> ClientBuilder<'u> { - pub fn new(url: Cow<'u, Url>) -> Self { + pub fn from_url(address: &'u Url) -> Self { + ClientBuilder::init(Cow::Borrowed(address)) + } + + pub fn new(address: &str) -> Result { + let url = try!(Url::parse(address)); + Ok(ClientBuilder::init(Cow::Owned(url))) + } + + fn init(url: Cow<'u, Url>) -> Self { ClientBuilder { url: url, version: HttpVersion::Http11, diff --git a/src/client/mod.rs b/src/client/mod.rs index fe28fdd9e3..85753b7c87 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,11 +8,6 @@ use std::net::TcpStream; use std::net::SocketAddr; use std::io::Result as IoResult; -use self::url::{ - Url, - ParseError, -}; - use ws; use ws::sender::Sender as SenderTrait; use ws::receiver::{ @@ -38,7 +33,11 @@ pub use sender::Writer; pub use receiver::Reader; pub mod builder; -pub use self::builder::ClientBuilder; +pub use self::builder::{ + ClientBuilder, + Url, + ParseError, +}; /// Represents a WebSocket client, which can send and receive messages/data frames. /// @@ -59,15 +58,10 @@ pub use self::builder::ClientBuilder; ///extern crate websocket; ///# fn main() { /// -///use websocket::{Client, Message}; -///use websocket::client::request::Url; -/// -///let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL -///let request = Client::connect(url).unwrap(); // Connect to the server -///let response = request.send().unwrap(); // Send the request -///response.validate().unwrap(); // Ensure the response is valid +///use websocket::{ClientBuilder, Message}; /// -///let mut client = response.begin(); // Get a Client +///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() +/// .connect(None).unwrap(); /// ///let message = Message::text("Hello, World!"); ///client.send_message(&message).unwrap(); // Send message @@ -125,18 +119,9 @@ impl Client } } -impl<'u, S> Client +impl Client where S: Stream, { - pub fn from_url(address: &'u Url) -> ClientBuilder<'u> { - ClientBuilder::new(Cow::Borrowed(address)) - } - - pub fn build(address: &str) -> Result, ParseError> { - let url = try!(Url::parse(address)); - Ok(ClientBuilder::new(Cow::Owned(url))) - } - /// Creates a Client from a given stream /// **without sending any handshake** this is meant to only be used with /// a stream that has a websocket connection already set up. @@ -194,14 +179,10 @@ impl<'u, S> Client ///```no_run ///# extern crate websocket; ///# fn main() { - ///use websocket::{Client, Message}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid + ///use websocket::{ClientBuilder, Message}; /// - ///let mut client = response.begin(); // Get a Client + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect(None).unwrap(); /// ///for message in client.incoming_messages() { /// let message: Message = message.unwrap(); @@ -217,15 +198,13 @@ impl<'u, S> Client ///```no_run ///# extern crate websocket; ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid + ///use websocket::{ClientBuilder, Message}; + /// + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect_insecure().unwrap(); + /// + ///let (mut receiver, mut sender) = client.split().unwrap(); /// - ///let client = response.begin(); // Get a Client - ///let (mut sender, mut receiver) = client.split(); // Split the Client ///for message in receiver.incoming_messages() { /// let message: Message = message.unwrap(); /// // Echo the message back @@ -251,17 +230,13 @@ impl Client ///```no_run ///# extern crate websocket; ///# fn main() { - ///use websocket::{Client, Message, Sender, Receiver}; ///use std::thread; - ///# use websocket::client::request::Url; - ///# let url = Url::parse("ws://127.0.0.1:1234").unwrap(); // Get the URL - ///# let request = Client::connect(url).unwrap(); // Connect to the server - ///# let response = request.send().unwrap(); // Send the request - ///# response.validate().unwrap(); // Ensure the response is valid + ///use websocket::{ClientBuilder, Message}; /// - ///let client = response.begin(); // Get a Client + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect_insecure().unwrap(); /// - ///let (mut sender, mut receiver) = client.split(); + ///let (mut receiver, mut sender) = client.split().unwrap(); /// ///thread::spawn(move || { /// for message in receiver.incoming_messages() { diff --git a/src/lib.rs b/src/lib.rs index 8d6c11d75a..7f489841b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,7 +52,10 @@ extern crate bitflags; #[cfg(all(feature = "nightly", test))] extern crate test; -pub use self::client::Client; +pub use self::client::{ + Client, + ClientBuilder, +}; pub use self::server::Server; pub use self::dataframe::DataFrame; pub use self::message::Message; diff --git a/src/server/mod.rs b/src/server/mod.rs index 1b14c7041e..105f84ca0c 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -67,9 +67,7 @@ impl OptionalSslAcceptor for SslAcceptor {} ///for connection in server { /// // Spawn a new thread for each connection. /// thread::spawn(move || { -/// let request = connection.unwrap().read_request().unwrap(); // Get the request -/// let response = request.accept(); // Form a response -/// let mut client = response.send().unwrap(); // Send the response +/// let mut client = connection.accept().unwrap(); /// /// let message = Message::text("Hello, client!"); /// let _ = client.send_message(&message); @@ -86,22 +84,34 @@ impl OptionalSslAcceptor for SslAcceptor {} ///extern crate openssl; ///# fn main() { ///use std::thread; -///use std::path::Path; +///use std::io::Read; +///use std::fs::File; ///use websocket::{Server, Message}; -///use openssl::ssl::{SslContext, SslMethod}; -///use openssl::x509::X509FileType; +///use openssl::pkcs12::Pkcs12; +///use openssl::ssl::{SslMethod, SslAcceptorBuilder, SslStream}; /// -///let mut context = SslContext::new(SslMethod::Tlsv1).unwrap(); -///let _ = context.set_certificate_file(&(Path::new("cert.pem")), X509FileType::PEM); -///let _ = context.set_private_key_file(&(Path::new("key.pem")), X509FileType::PEM); -///let server = Server::bind_secure("127.0.0.1:1234", &context).unwrap(); +///// In this example we retrieve our keypair and certificate chain from a PKCS #12 archive, +///// but but they can also be retrieved from, for example, individual PEM- or DER-formatted +///// files. See the documentation for the `PKey` and `X509` types for more details. +///let mut file = File::open("identity.pfx").unwrap(); +///let mut pkcs12 = vec![]; +///file.read_to_end(&mut pkcs12).unwrap(); +///let pkcs12 = Pkcs12::from_der(&pkcs12).unwrap(); +///let identity = pkcs12.parse("password123").unwrap(); +/// +///let acceptor = SslAcceptorBuilder::mozilla_intermediate(SslMethod::tls(), +/// &identity.pkey, +/// &identity.cert, +/// &identity.chain) +/// .unwrap() +/// .build(); +/// +///let server = Server::bind_secure("127.0.0.1:1234", acceptor).unwrap(); /// ///for connection in server { /// // Spawn a new thread for each connection. /// thread::spawn(move || { -/// let request = connection.unwrap().read_request().unwrap(); // Get the request -/// let response = request.accept(); // Form a response -/// let mut client = response.send().unwrap(); // Send the response +/// let mut client = connection.accept().unwrap(); /// /// let message = Message::text("Hello, client!"); /// let _ = client.send_message(&message); @@ -139,17 +149,12 @@ impl Server #[cfg(feature="ssl")] impl Server { /// Bind this Server to this socket, utilising the given SslContext - pub fn bind_secure
(addr: A, acceptor: Option) -> io::Result + pub fn bind_secure(addr: A, acceptor: SslAcceptor) -> io::Result where A: ToSocketAddrs, { Ok(Server { listener: try!(TcpListener::bind(&addr)), - ssl_acceptor: match acceptor { - Some(acc) => acc, - None => { - unimplemented!(); - } - }, + ssl_acceptor: acceptor, }) } diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 3f3680f2e2..6fae866f83 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -104,8 +104,10 @@ impl WsUpgrade ::std::mem::drop(self); } - pub fn protocols(&self) -> Option<&[String]> { - self.request.headers.get::().map(|p| p.0.as_slice()) + pub fn protocols(&self) -> &[String] { + self.request.headers.get::() + .map(|p| p.0.as_slice()) + .unwrap_or(&[]) } pub fn extensions(&self) -> Option<&[Extension]> { From 8dd2ff0b43c366cfe957de0e126674e66ee2d0c9 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 01:10:11 -0400 Subject: [PATCH 24/32] tests pass without warnings --- examples/hyper.rs | 4 +--- examples/server.rs | 1 - src/client/mod.rs | 3 --- src/header/accept.rs | 6 +++--- src/header/extensions.rs | 13 +++++++++---- src/header/key.rs | 7 +++---- src/header/origin.rs | 13 +++++++++---- src/header/protocol.rs | 13 +++++++++---- src/header/version.rs | 13 +++++++++---- src/server/upgrade/mod.rs | 6 ++++-- 10 files changed, 47 insertions(+), 32 deletions(-) diff --git a/examples/hyper.rs b/examples/hyper.rs index f0f731f9bb..932604cb7d 100644 --- a/examples/hyper.rs +++ b/examples/hyper.rs @@ -3,11 +3,9 @@ extern crate hyper; use std::thread; use std::io::Write; -use websocket::{Server, Message, Sender, Receiver}; -use websocket::header::WebSocketProtocol; +use websocket::{Server, Message}; use websocket::message::Type; use hyper::Server as HttpServer; -use hyper::server::Handler; use hyper::net::Fresh; use hyper::server::request::Request; use hyper::server::response::Response; diff --git a/examples/server.rs b/examples/server.rs index 87e136e513..91d21fe65a 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -3,7 +3,6 @@ extern crate websocket; use std::thread; use websocket::{Server, Message}; use websocket::message::Type; -use websocket::header::WebSocketProtocol; // TODO: I think the .reject() call is only for malformed packets // there should be an easy way to accept the socket with the given protocols diff --git a/src/client/mod.rs b/src/client/mod.rs index 85753b7c87..bd224bc945 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,9 +1,6 @@ //! Contains the WebSocket client. extern crate url; -use std::borrow::{ - Cow, -}; use std::net::TcpStream; use std::net::SocketAddr; use std::io::Result as IoResult; diff --git a/src/header/accept.rs b/src/header/accept.rs index aa40070045..0b2c57a171 100644 --- a/src/header/accept.rs +++ b/src/header/accept.rs @@ -88,7 +88,8 @@ mod tests { use test; use std::str::FromStr; use header::{Headers, WebSocketKey}; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; + #[test] fn test_header_accept() { let key = FromStr::from_str("dGhlIHNhbXBsZSBub25jZQ==").unwrap(); @@ -118,9 +119,8 @@ mod tests { fn bench_header_accept_format(b: &mut test::Bencher) { let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; let val: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); + format!("{}", val.serialize()); }); } } diff --git a/src/header/extensions.rs b/src/header/extensions.rs index 07854ad109..6c842dd32d 100644 --- a/src/header/extensions.rs +++ b/src/header/extensions.rs @@ -120,10 +120,16 @@ impl HeaderFormat for WebSocketExtensions { } } +impl fmt::Display for WebSocketExtensions { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_extensions() { @@ -148,9 +154,8 @@ mod tests { fn bench_header_extensions_format(b: &mut test::Bencher) { let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; let val: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); + format!("{}", val); }); } -} \ No newline at end of file +} diff --git a/src/header/key.rs b/src/header/key.rs index 31706aabf1..fb96faf4f1 100644 --- a/src/header/key.rs +++ b/src/header/key.rs @@ -83,7 +83,7 @@ impl HeaderFormat for WebSocketKey { #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_key() { @@ -114,9 +114,8 @@ mod tests { fn bench_header_key_format(b: &mut test::Bencher) { let value = vec![b"QUFBQUFBQUFBQUFBQUFBQQ==".to_vec()]; let val: WebSocketKey = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); + format!("{}", val.serialize()); }); } -} \ No newline at end of file +} diff --git a/src/header/origin.rs b/src/header/origin.rs index 8209fe1d29..3f172e108b 100644 --- a/src/header/origin.rs +++ b/src/header/origin.rs @@ -32,10 +32,16 @@ impl HeaderFormat for Origin { } } +impl fmt::Display for Origin { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_origin() { @@ -59,9 +65,8 @@ mod tests { fn bench_header_origin_format(b: &mut test::Bencher) { let value = vec![b"foobar".to_vec()]; let val: Origin = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); + format!("{}", val); }); } -} \ No newline at end of file +} diff --git a/src/header/protocol.rs b/src/header/protocol.rs index 899c970e48..3582b323c8 100644 --- a/src/header/protocol.rs +++ b/src/header/protocol.rs @@ -32,10 +32,16 @@ impl HeaderFormat for WebSocketProtocol { } } +impl fmt::Display for WebSocketProtocol { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_header_protocol() { @@ -59,9 +65,8 @@ mod tests { fn bench_header_protocol_format(b: &mut test::Bencher) { let value = vec![b"foo, bar".to_vec()]; let val: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); + format!("{}", val); }); } -} \ No newline at end of file +} diff --git a/src/header/version.rs b/src/header/version.rs index 664a4faeee..4128e35267 100644 --- a/src/header/version.rs +++ b/src/header/version.rs @@ -46,10 +46,16 @@ impl HeaderFormat for WebSocketVersion { } } +impl fmt::Display for WebSocketVersion { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } +} + #[cfg(all(feature = "nightly", test))] mod tests { use super::*; - use hyper::header::{Header, HeaderFormatter}; + use hyper::header::Header; use test; #[test] fn test_websocket_version() { @@ -73,9 +79,8 @@ mod tests { fn bench_header_version_format(b: &mut test::Bencher) { let value = vec![b"13".to_vec()]; let val: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); - let fmt = HeaderFormatter(&val); b.iter(|| { - format!("{}", fmt); + format!("{}", val); }); } -} \ No newline at end of file +} diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 6fae866f83..72db067b5f 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -110,8 +110,10 @@ impl WsUpgrade .unwrap_or(&[]) } - pub fn extensions(&self) -> Option<&[Extension]> { - self.request.headers.get::().map(|e| e.0.as_slice()) + pub fn extensions(&self) -> &[Extension] { + self.request.headers.get::() + .map(|e| e.0.as_slice()) + .unwrap_or(&[]) } pub fn key(&self) -> Option<&[u8; 16]> { From 7b86a411496e3fdbb9b88f44e1cf07df8fbdb641 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 01:38:20 -0400 Subject: [PATCH 25/32] fleshed out roadmap, fixes #27 --- ROADMAP.md | 31 ++++++++++++++++++++++++++++--- src/stream.rs | 2 -- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/ROADMAP.md b/ROADMAP.md index 1bf0284e08..3e8910742b 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -1,5 +1,30 @@ +# The Roadmap -### Adding Features +## More Docs, Examples and Tests + +Easy as that, every method should be tested and documented. +Every use-case should have an example. + +## Adding Features + +### `net2` Feature + +This is a feature to add the `net2` crate which will let us do cool things +like set the option `SO_REUSEADDR` and similar when making TCP connections. + +This is discussed in [vi/rust-websocket#2](https://github.com/vi/rust-websocket/pull/2). + +### Add Mio & Tokio (Evented Websocket) + +There are a lot of issues that would be solved if this was evented, such as: + + - [#88 tokio support](https://github.com/cyderize/rust-websocket/issues/88) + - [#66 Timeout on recv_message](https://github.com/cyderize/rust-websocket/issues/66) + - [#6 one client, one thread?](https://github.com/cyderize/rust-websocket/issues/6) + +So maybe we should _just_ add `tokio` support, or maybe `mio` is still used and popular. + +### Support Permessage-Deflate + +We need this to pass more autobahn tests! - - Make the usage of `net2` a feature - - Make evented diff --git a/src/stream.rs b/src/stream.rs index 9c2209ba84..79c9bb8572 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,6 +1,4 @@ //! Provides the default stream type for WebSocket connections. -// TODO: add mio support & tokio -// extern crate mio; use std::ops::Deref; use std::io::{ From ca8d8245c6effa29b96a7ce56f6567690016fdb8 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 02:05:03 -0400 Subject: [PATCH 26/32] added standard style enforced by the CI --- .rustfmt.toml | 8 + .travis.yml | 47 ++-- examples/autobahn-client.rs | 263 +++++++++--------- examples/autobahn-server.rs | 68 ++--- examples/client.rs | 240 +++++++++-------- examples/hyper.rs | 88 +++--- examples/server.rs | 80 +++--- src/client/builder.rs | 524 +++++++++++++++++------------------- src/client/mod.rs | 379 +++++++++++++------------- src/dataframe.rs | 165 ++++++------ src/header/accept.rs | 173 ++++++------ src/header/extensions.rs | 83 +++--- src/header/key.rs | 37 ++- src/header/mod.rs | 2 +- src/header/origin.rs | 28 +- src/header/protocol.rs | 29 +- src/header/version.rs | 40 ++- src/lib.rs | 5 +- src/message.rs | 178 ++++++------ src/receiver.rs | 125 ++++----- src/result.rs | 218 +++++++-------- src/sender.rs | 68 +++-- src/server/mod.rs | 243 ++++++++--------- src/server/upgrade/hyper.rs | 58 ++-- src/server/upgrade/mod.rs | 476 ++++++++++++++++---------------- src/stream.rs | 187 +++++++------ src/ws/dataframe.rs | 185 ++++++------- src/ws/message.rs | 6 +- src/ws/receiver.rs | 142 +++++----- src/ws/sender.rs | 10 +- src/ws/util/header.rs | 76 ++---- src/ws/util/mask.rs | 79 +++--- 32 files changed, 2107 insertions(+), 2203 deletions(-) create mode 100644 .rustfmt.toml diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000000..5deae4867d --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,8 @@ +hard_tabs = true +array_layout = "Block" +fn_args_layout = "Block" +chain_indent = "Visual" +chain_one_line_max = 100 +take_source_hints = true +write_mode = "Overwrite" + diff --git a/.travis.yml b/.travis.yml index 65fe2cf2c7..f47b1f4114 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,12 @@ language: rust rust: nightly +cache: cargo +before_script: + - export PATH="$PATH:$HOME/.cargo/bin" + - which rustfmt || cargo install rustfmt script: + - cargo fmt -- --write-mode=diff - cargo build --features nightly - cargo test --features nightly - cargo bench --features nightly @@ -12,36 +17,36 @@ after_success: - > [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && sudo pip install ghp-import - > - [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Running Autobahn TestSuite for client" ; - wstest -m fuzzingserver -s ./autobahn/fuzzingserver.json & FUZZINGSERVER_PID=$! ; - sleep 10 ; - ./target/debug/examples/autobahn-client ; + [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { + echo "Running Autobahn TestSuite for client" ; + wstest -m fuzzingserver -s ./autobahn/fuzzingserver.json & FUZZINGSERVER_PID=$! ; + sleep 10 ; + ./target/debug/examples/autobahn-client ; kill -9 ${FUZZINGSERVER_PID} ; } - > - [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Running Autobahn TestSuite for server" ; - ./target/debug/examples/autobahn-server & WSSERVER_PID=$! ; - sleep 10 ; - wstest -m fuzzingclient -s ./autobahn/fuzzingclient.json ; + [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { + echo "Running Autobahn TestSuite for server" ; + ./target/debug/examples/autobahn-server & WSSERVER_PID=$! ; + sleep 10 ; + wstest -m fuzzingclient -s ./autobahn/fuzzingclient.json ; kill -9 ${WSSERVER_PID} ; } - > [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Building docs and gh-pages" ; - PROJECT_VERSION=$(cargo doc --features nightly | grep "Documenting websocket v" | sed 's/.*Documenting websocket v\(.*\) .*/\1/') ; - curl -sL https://github.com/${TRAVIS_REPO_SLUG}/archive/html.tar.gz | tar xz ; - cd ./rust-websocket-html && - find . -type f | xargs sed -i 's//'"${PROJECT_VERSION}"'/g' ; - mv ../target/doc ./doc ; - mv ../autobahn/server ./autobahn/server ; - mv ../autobahn/client ./autobahn/client ; - mv ./autobahn/server/index.json ./autobahn/server/index.temp && rm ./autobahn/server/*.json && mv ./autobahn/server/index.temp ./autobahn/server/index.json ; + echo "Building docs and gh-pages" ; + PROJECT_VERSION=$(cargo doc --features nightly | grep "Documenting websocket v" | sed 's/.*Documenting websocket v\(.*\) .*/\1/') ; + curl -sL https://github.com/${TRAVIS_REPO_SLUG}/archive/html.tar.gz | tar xz ; + cd ./rust-websocket-html && + find . -type f | xargs sed -i 's//'"${PROJECT_VERSION}"'/g' ; + mv ../target/doc ./doc ; + mv ../autobahn/server ./autobahn/server ; + mv ../autobahn/client ./autobahn/client ; + mv ./autobahn/server/index.json ./autobahn/server/index.temp && rm ./autobahn/server/*.json && mv ./autobahn/server/index.temp ./autobahn/server/index.json ; mv ./autobahn/client/index.json ./autobahn/client/index.temp && rm ./autobahn/client/*.json && mv ./autobahn/client/index.temp ./autobahn/client/index.json ; cd ../ ; } - > [ $TRAVIS_BRANCH = master ] && [ $TRAVIS_PULL_REQUEST = false ] && { - echo "Pushing gh-pages" ; - ghp-import -n ./rust-websocket-html -m "Generated by Travis CI build ${TRAVIS_BUILD_NUMBER} for commit ${TRAVIS_COMMIT}" && + echo "Pushing gh-pages" ; + ghp-import -n ./rust-websocket-html -m "Generated by Travis CI build ${TRAVIS_BUILD_NUMBER} for commit ${TRAVIS_COMMIT}" && git push -fq https://${TOKEN}@github.com/${TRAVIS_REPO_SLUG}.git gh-pages ; } env: diff --git a/examples/autobahn-client.rs b/examples/autobahn-client.rs index ce5eb0def9..2519e79f1e 100644 --- a/examples/autobahn-client.rs +++ b/examples/autobahn-client.rs @@ -8,140 +8,143 @@ use websocket::message::Type; use serialize::json; fn main() { - let addr = "ws://127.0.0.1:9001".to_string(); - let agent = "rust-websocket"; - - println!("Using fuzzingserver {}", addr); - println!("Using agent {}", agent); - - println!("Running test suite..."); - - let mut current_case_id = 1; - let case_count = get_case_count(addr.clone()); - - while current_case_id <= case_count { - let case_id = current_case_id; - current_case_id += 1; - let url = addr.clone() + "/runCase?case=" + &case_id.to_string()[..] + "&agent=" + agent; - - let client = ClientBuilder::new(&url).unwrap() - .connect_insecure().unwrap(); - - let (mut receiver, mut sender) = client.split().unwrap(); - - println!("Executing test case: {}/{}", case_id, case_count); - - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close()); - break; - } - }; - - match message.opcode { - Type::Text => { - let response = Message::text(from_utf8(&*message.payload).unwrap()); - sender.send_message(&response).unwrap(); - } - Type::Binary => { - sender.send_message(&Message::binary(message.payload)).unwrap(); - } - Type::Close => { - let _ = sender.send_message(&Message::close()); - break; - } - Type::Ping => { - sender.send_message(&Message::pong(message.payload)).unwrap(); - } - _ => (), - } - } - } - - update_reports(addr.clone(), agent); + let addr = "ws://127.0.0.1:9001".to_string(); + let agent = "rust-websocket"; + + println!("Using fuzzingserver {}", addr); + println!("Using agent {}", agent); + + println!("Running test suite..."); + + let mut current_case_id = 1; + let case_count = get_case_count(addr.clone()); + + while current_case_id <= case_count { + let case_id = current_case_id; + current_case_id += 1; + let url = addr.clone() + "/runCase?case=" + &case_id.to_string()[..] + "&agent=" + agent; + + let client = ClientBuilder::new(&url) + .unwrap() + .connect_insecure() + .unwrap(); + + let (mut receiver, mut sender) = client.split().unwrap(); + + println!("Executing test case: {}/{}", case_id, case_count); + + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("Error: {:?}", e); + let _ = sender.send_message(&Message::close()); + break; + } + }; + + match message.opcode { + Type::Text => { + let response = Message::text(from_utf8(&*message.payload).unwrap()); + sender.send_message(&response).unwrap(); + } + Type::Binary => { + sender.send_message(&Message::binary(message.payload)).unwrap(); + } + Type::Close => { + let _ = sender.send_message(&Message::close()); + break; + } + Type::Ping => { + sender.send_message(&Message::pong(message.payload)).unwrap(); + } + _ => (), + } + } + } + + update_reports(addr.clone(), agent); } fn get_case_count(addr: String) -> usize { - let url = addr + "/getCaseCount"; - - let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { - Ok(c) => c, - Err(e) => { - println!("{:?}", e); - return 0; - }, - }; - - let (mut receiver, mut sender) = client.split().unwrap(); - - let mut count = 0; - - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close_because(1002, "".to_string())); - break; - } - }; - match message.opcode { - Type::Text => { - count = json::decode(from_utf8(&*message.payload).unwrap()).unwrap(); - println!("Will run {} cases...", count); - } - Type::Close => { - let _ = sender.send_message(&Message::close()); - break; - } - Type::Ping => { - sender.send_message(&Message::pong(message.payload)).unwrap(); - } - _ => (), - } - } - - count + let url = addr + "/getCaseCount"; + + let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { + Ok(c) => c, + Err(e) => { + println!("{:?}", e); + return 0; + } + }; + + let (mut receiver, mut sender) = client.split().unwrap(); + + let mut count = 0; + + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("Error: {:?}", e); + let _ = + sender.send_message(&Message::close_because(1002, "".to_string())); + break; + } + }; + match message.opcode { + Type::Text => { + count = json::decode(from_utf8(&*message.payload).unwrap()).unwrap(); + println!("Will run {} cases...", count); + } + Type::Close => { + let _ = sender.send_message(&Message::close()); + break; + } + Type::Ping => { + sender.send_message(&Message::pong(message.payload)).unwrap(); + } + _ => (), + } + } + + count } fn update_reports(addr: String, agent: &str) { - let url = addr + "/updateReports?agent=" + agent; - - let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { - Ok(c) => c, - Err(e) => { - println!("{:?}", e); - return; - }, - }; - - let (mut receiver, mut sender) = client.split().unwrap(); - - println!("Updating reports..."); - - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("Error: {:?}", e); - let _ = sender.send_message(&Message::close()); - return; - } - }; - match message.opcode { - Type::Close => { - let _ = sender.send_message(&Message::close()); - println!("Reports updated."); - println!("Test suite finished!"); - return; - } - Type::Ping => { - sender.send_message(&Message::pong(message.payload)).unwrap(); - } - _ => (), - } - } + let url = addr + "/updateReports?agent=" + agent; + + let client = match ClientBuilder::new(&url).unwrap().connect_insecure() { + Ok(c) => c, + Err(e) => { + println!("{:?}", e); + return; + } + }; + + let (mut receiver, mut sender) = client.split().unwrap(); + + println!("Updating reports..."); + + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("Error: {:?}", e); + let _ = sender.send_message(&Message::close()); + return; + } + }; + match message.opcode { + Type::Close => { + let _ = sender.send_message(&Message::close()); + println!("Reports updated."); + println!("Test suite finished!"); + return; + } + Type::Ping => { + sender.send_message(&Message::pong(message.payload)).unwrap(); + } + _ => (), + } + } } diff --git a/examples/autobahn-server.rs b/examples/autobahn-server.rs index 6e7cb68eab..f9a8d959e9 100644 --- a/examples/autobahn-server.rs +++ b/examples/autobahn-server.rs @@ -6,41 +6,43 @@ use websocket::{Server, Message}; use websocket::message::Type; fn main() { - let server = Server::bind("127.0.0.1:9002").unwrap(); + let server = Server::bind("127.0.0.1:9002").unwrap(); - for connection in server { - thread::spawn(move || { - let client = connection.accept().unwrap(); + for connection in server { + thread::spawn(move || { + let client = connection.accept().unwrap(); - let (mut receiver, mut sender) = client.split().unwrap(); + let (mut receiver, mut sender) = client.split().unwrap(); - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(message) => message, - Err(e) => { - println!("{:?}", e); - let _ = sender.send_message(&Message::close()); - return; - } - }; + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(message) => message, + Err(e) => { + println!("{:?}", e); + let _ = sender.send_message(&Message::close()); + return; + } + }; - match message.opcode { - Type::Text => { - let response = Message::text(from_utf8(&*message.payload).unwrap()); - sender.send_message(&response).unwrap() - }, - Type::Binary => sender.send_message(&Message::binary(message.payload)).unwrap(), - Type::Close => { - let _ = sender.send_message(&Message::close()); - return; - } - Type::Ping => { - let message = Message::pong(message.payload); - sender.send_message(&message).unwrap(); - } - _ => (), - } - } - }); - } + match message.opcode { + Type::Text => { + let response = Message::text(from_utf8(&*message.payload).unwrap()); + sender.send_message(&response).unwrap() + } + Type::Binary => { + sender.send_message(&Message::binary(message.payload)).unwrap() + } + Type::Close => { + let _ = sender.send_message(&Message::close()); + return; + } + Type::Ping => { + let message = Message::pong(message.payload); + sender.send_message(&message).unwrap(); + } + _ => (), + } + } + }); + } } diff --git a/examples/client.rs b/examples/client.rs index d5ed6cb9aa..62bc504708 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -3,122 +3,126 @@ extern crate websocket; const CONNECTION: &'static str = "ws://127.0.0.1:2794"; fn main() { - use std::thread; - use std::sync::mpsc::channel; - use std::io::stdin; - - use websocket::Message; - use websocket::message::Type; - use websocket::client::ClientBuilder; - - println!("Connecting to {}", CONNECTION); - - let client = ClientBuilder::new(CONNECTION).unwrap() - .connect_insecure().unwrap(); - - println!("Successfully connected"); - - let (mut receiver, mut sender) = client.split().unwrap(); - - let (tx, rx) = channel(); - - let tx_1 = tx.clone(); - - let send_loop = thread::spawn(move || { - loop { - // Send loop - let message: Message = match rx.recv() { - Ok(m) => m, - Err(e) => { - println!("Send Loop: {:?}", e); - return; - } - }; - match message.opcode { - Type::Close => { - let _ = sender.send_message(&message); - // If it's a close message, just send it and then return. - return; - }, - _ => (), - } - // Send the message - match sender.send_message(&message) { - Ok(()) => (), - Err(e) => { - println!("Send Loop: {:?}", e); - let _ = sender.send_message(&Message::close()); - return; - } - } - } - }); - - let receive_loop = thread::spawn(move || { - // Receive loop - for message in receiver.incoming_messages() { - let message: Message = match message { - Ok(m) => m, - Err(e) => { - println!("Receive Loop: {:?}", e); - let _ = tx_1.send(Message::close()); - return; - } - }; - match message.opcode { - Type::Close => { - // Got a close message, so send a close message and return - let _ = tx_1.send(Message::close()); - return; - } - Type::Ping => match tx_1.send(Message::pong(message.payload)) { - // Send a pong in response - Ok(()) => (), - Err(e) => { - println!("Receive Loop: {:?}", e); - return; - } - }, - // Say what we received - _ => println!("Receive Loop: {:?}", message), - } - } - }); - - loop { - let mut input = String::new(); - - stdin().read_line(&mut input).unwrap(); - - let trimmed = input.trim(); - - let message = match trimmed { - "/close" => { - // Close the connection - let _ = tx.send(Message::close()); - break; - } - // Send a ping - "/ping" => Message::ping(b"PING".to_vec()), - // Otherwise, just send text - _ => Message::text(trimmed.to_string()), - }; - - match tx.send(message) { - Ok(()) => (), - Err(e) => { - println!("Main Loop: {:?}", e); - break; - } - } - } - - // We're exiting - - println!("Waiting for child threads to exit"); - - let _ = send_loop.join(); - let _ = receive_loop.join(); - - println!("Exited"); + use std::thread; + use std::sync::mpsc::channel; + use std::io::stdin; + + use websocket::Message; + use websocket::message::Type; + use websocket::client::ClientBuilder; + + println!("Connecting to {}", CONNECTION); + + let client = ClientBuilder::new(CONNECTION) + .unwrap() + .connect_insecure() + .unwrap(); + + println!("Successfully connected"); + + let (mut receiver, mut sender) = client.split().unwrap(); + + let (tx, rx) = channel(); + + let tx_1 = tx.clone(); + + let send_loop = thread::spawn(move || { + loop { + // Send loop + let message: Message = match rx.recv() { + Ok(m) => m, + Err(e) => { + println!("Send Loop: {:?}", e); + return; + } + }; + match message.opcode { + Type::Close => { + let _ = sender.send_message(&message); + // If it's a close message, just send it and then return. + return; + } + _ => (), + } + // Send the message + match sender.send_message(&message) { + Ok(()) => (), + Err(e) => { + println!("Send Loop: {:?}", e); + let _ = sender.send_message(&Message::close()); + return; + } + } + } + }); + + let receive_loop = thread::spawn(move || { + // Receive loop + for message in receiver.incoming_messages() { + let message: Message = match message { + Ok(m) => m, + Err(e) => { + println!("Receive Loop: {:?}", e); + let _ = tx_1.send(Message::close()); + return; + } + }; + match message.opcode { + Type::Close => { + // Got a close message, so send a close message and return + let _ = tx_1.send(Message::close()); + return; + } + Type::Ping => { + match tx_1.send(Message::pong(message.payload)) { + // Send a pong in response + Ok(()) => (), + Err(e) => { + println!("Receive Loop: {:?}", e); + return; + } + } + } + // Say what we received + _ => println!("Receive Loop: {:?}", message), + } + } + }); + + loop { + let mut input = String::new(); + + stdin().read_line(&mut input).unwrap(); + + let trimmed = input.trim(); + + let message = match trimmed { + "/close" => { + // Close the connection + let _ = tx.send(Message::close()); + break; + } + // Send a ping + "/ping" => Message::ping(b"PING".to_vec()), + // Otherwise, just send text + _ => Message::text(trimmed.to_string()), + }; + + match tx.send(message) { + Ok(()) => (), + Err(e) => { + println!("Main Loop: {:?}", e); + break; + } + } + } + + // We're exiting + + println!("Waiting for child threads to exit"); + + let _ = send_loop.join(); + let _ = receive_loop.join(); + + println!("Exited"); } diff --git a/examples/hyper.rs b/examples/hyper.rs index 932604cb7d..33a52edaae 100644 --- a/examples/hyper.rs +++ b/examples/hyper.rs @@ -10,61 +10,63 @@ use hyper::net::Fresh; use hyper::server::request::Request; use hyper::server::response::Response; +const HTML: &'static str = include_str!("websockets.html"); + // The HTTP server handler fn http_handler(_: Request, response: Response) { - let mut response = response.start().unwrap(); - // Send a client webpage - response.write_all(b"WebSocket Test

Received Messages:

").unwrap(); - response.end().unwrap(); + let mut response = response.start().unwrap(); + // Send a client webpage + response.write_all(HTML.as_bytes()).unwrap(); + response.end().unwrap(); } fn main() { - // Start listening for http connections - thread::spawn(move || { - let http_server = HttpServer::http("127.0.0.1:8080").unwrap(); - http_server.handle(http_handler).unwrap(); - }); + // Start listening for http connections + thread::spawn(move || { + let http_server = HttpServer::http("127.0.0.1:8080").unwrap(); + http_server.handle(http_handler).unwrap(); + }); - // Start listening for WebSocket connections - let ws_server = Server::bind("127.0.0.1:2794").unwrap(); + // Start listening for WebSocket connections + let ws_server = Server::bind("127.0.0.1:2794").unwrap(); - for connection in ws_server { - // Spawn a new thread for each connection. - thread::spawn(move || { - if !connection.protocols().contains(&"rust-websocket".to_string()) { - connection.reject().unwrap(); - return; - } + for connection in ws_server { + // Spawn a new thread for each connection. + thread::spawn(move || { + if !connection.protocols().contains(&"rust-websocket".to_string()) { + connection.reject().unwrap(); + return; + } - // TODO: same check like in server.rs - let mut client = connection.accept().unwrap(); + // TODO: same check like in server.rs + let mut client = connection.accept().unwrap(); - let ip = client.peer_addr().unwrap(); + let ip = client.peer_addr().unwrap(); - println!("Connection from {}", ip); + println!("Connection from {}", ip); - let message = Message::text("Hello".to_string()); - client.send_message(&message).unwrap(); + let message = Message::text("Hello".to_string()); + client.send_message(&message).unwrap(); - let (mut receiver, mut sender) = client.split().unwrap(); + let (mut receiver, mut sender) = client.split().unwrap(); - for message in receiver.incoming_messages() { - let message: Message = message.unwrap(); + for message in receiver.incoming_messages() { + let message: Message = message.unwrap(); - match message.opcode { - Type::Close => { - let message = Message::close(); - sender.send_message(&message).unwrap(); - println!("Client {} disconnected", ip); - return; - }, - Type::Ping => { - let message = Message::pong(message.payload); - sender.send_message(&message).unwrap(); - }, - _ => sender.send_message(&message).unwrap(), - } - } - }); - } + match message.opcode { + Type::Close => { + let message = Message::close(); + sender.send_message(&message).unwrap(); + println!("Client {} disconnected", ip); + return; + } + Type::Ping => { + let message = Message::pong(message.payload); + sender.send_message(&message).unwrap(); + } + _ => sender.send_message(&message).unwrap(), + } + } + }); + } } diff --git a/examples/server.rs b/examples/server.rs index 91d21fe65a..776c396d34 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -13,44 +13,44 @@ use websocket::message::Type; // can you accept only one protocol?? fn main() { - let server = Server::bind("127.0.0.1:2794").unwrap(); - - for request in server { - // Spawn a new thread for each connection. - thread::spawn(move || { - if !request.protocols().contains(&"rust-websocket".to_string()) { - request.reject().unwrap(); - return; - } - - let mut client = request.accept().unwrap(); - - let ip = client.peer_addr().unwrap(); - - println!("Connection from {}", ip); - - let message: Message = Message::text("Hello".to_string()); - client.send_message(&message).unwrap(); - - let (mut receiver, mut sender) = client.split().unwrap(); - - for message in receiver.incoming_messages() { - let message: Message = message.unwrap(); - - match message.opcode { - Type::Close => { - let message = Message::close(); - sender.send_message(&message).unwrap(); - println!("Client {} disconnected", ip); - return; - }, - Type::Ping => { - let message = Message::pong(message.payload); - sender.send_message(&message).unwrap(); - } - _ => sender.send_message(&message).unwrap(), - } - } - }); - } + let server = Server::bind("127.0.0.1:2794").unwrap(); + + for request in server { + // Spawn a new thread for each connection. + thread::spawn(move || { + if !request.protocols().contains(&"rust-websocket".to_string()) { + request.reject().unwrap(); + return; + } + + let mut client = request.accept().unwrap(); + + let ip = client.peer_addr().unwrap(); + + println!("Connection from {}", ip); + + let message: Message = Message::text("Hello".to_string()); + client.send_message(&message).unwrap(); + + let (mut receiver, mut sender) = client.split().unwrap(); + + for message in receiver.incoming_messages() { + let message: Message = message.unwrap(); + + match message.opcode { + Type::Close => { + let message = Message::close(); + sender.send_message(&message).unwrap(); + println!("Client {} disconnected", ip); + return; + } + Type::Ping => { + let message = Message::pong(message.payload); + sender.send_message(&message).unwrap(); + } + _ => sender.send_message(&message).unwrap(), + } + } + }); + } } diff --git a/src/client/builder.rs b/src/client/builder.rs index ec9317093c..30bc2d2dbf 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -1,55 +1,23 @@ -use std::borrow::{ - Cow, -}; -use std::io::{ - Write, -}; +use std::borrow::Cow; +use std::io::Write; use std::net::TcpStream; -pub use url::{ - Url, - ParseError, -}; +pub use url::{Url, ParseError}; use url::Position; use hyper::version::HttpVersion; use hyper::status::StatusCode; use hyper::buffer::BufReader; use hyper::http::h1::parse_response; -use hyper::header::{ - Headers, - Host, - Connection, - ConnectionOption, - Upgrade, - Protocol, - ProtocolName, -}; +use hyper::header::{Headers, Host, Connection, ConnectionOption, Upgrade, Protocol, ProtocolName}; use unicase::UniCase; #[cfg(feature="ssl")] -use openssl::ssl::{ - SslMethod, - SslStream, - SslConnector, - SslConnectorBuilder, -}; +use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder}; #[cfg(feature="ssl")] use stream::BoxedNetworkStream; use header::extensions::Extension; -use header::{ - WebSocketAccept, - WebSocketKey, - WebSocketVersion, - WebSocketProtocol, - WebSocketExtensions, - Origin, -}; -use result::{ - WSUrlErrorKind, - WebSocketResult, - WebSocketError, -}; -use stream::{ - Stream, -}; +use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, + WebSocketExtensions, Origin}; +use result::{WSUrlErrorKind, WebSocketResult, WebSocketError}; +use stream::Stream; use super::Client; macro_rules! upsert_header { @@ -74,258 +42,270 @@ macro_rules! upsert_header { /// Build clients with a builder-style API #[derive(Clone, Debug)] pub struct ClientBuilder<'u> { - url: Cow<'u, Url>, - version: HttpVersion, - headers: Headers, - version_set: bool, - key_set: bool, + url: Cow<'u, Url>, + version: HttpVersion, + headers: Headers, + version_set: bool, + key_set: bool, } impl<'u> ClientBuilder<'u> { - pub fn from_url(address: &'u Url) -> Self { - ClientBuilder::init(Cow::Borrowed(address)) - } - - pub fn new(address: &str) -> Result { - let url = try!(Url::parse(address)); - Ok(ClientBuilder::init(Cow::Owned(url))) - } - - fn init(url: Cow<'u, Url>) -> Self { - ClientBuilder { - url: url, - version: HttpVersion::Http11, - version_set: false, - key_set: false, - headers: Headers::new(), - } - } - - pub fn add_protocol

(mut self, protocol: P) -> Self - where P: Into, - { - upsert_header!(self.headers; WebSocketProtocol; { + pub fn from_url(address: &'u Url) -> Self { + ClientBuilder::init(Cow::Borrowed(address)) + } + + pub fn new(address: &str) -> Result { + let url = try!(Url::parse(address)); + Ok(ClientBuilder::init(Cow::Owned(url))) + } + + fn init(url: Cow<'u, Url>) -> Self { + ClientBuilder { + url: url, + version: HttpVersion::Http11, + version_set: false, + key_set: false, + headers: Headers::new(), + } + } + + pub fn add_protocol

(mut self, protocol: P) -> Self + where P: Into + { + upsert_header!(self.headers; WebSocketProtocol; { Some(protos) => protos.0.push(protocol.into()), None => WebSocketProtocol(vec![protocol.into()]) }); - self - } - - pub fn add_protocols(mut self, protocols: I) -> Self - where I: IntoIterator, - S: Into, - { - let mut protocols: Vec = protocols.into_iter() - .map(Into::into).collect(); - - upsert_header!(self.headers; WebSocketProtocol; { + self + } + + pub fn add_protocols(mut self, protocols: I) -> Self + where I: IntoIterator, + S: Into + { + let mut protocols: Vec = + protocols.into_iter() + .map(Into::into) + .collect(); + + upsert_header!(self.headers; WebSocketProtocol; { Some(protos) => protos.0.append(&mut protocols), None => WebSocketProtocol(protocols) }); - self - } + self + } - pub fn clear_protocols(mut self) -> Self { - self.headers.remove::(); - self - } + pub fn clear_protocols(mut self) -> Self { + self.headers.remove::(); + self + } - pub fn add_extension(mut self, extension: Extension) -> Self - { - upsert_header!(self.headers; WebSocketExtensions; { + pub fn add_extension(mut self, extension: Extension) -> Self { + upsert_header!(self.headers; WebSocketExtensions; { Some(protos) => protos.0.push(extension), None => WebSocketExtensions(vec![extension]) }); - self - } - - pub fn add_extensions(mut self, extensions: I) -> Self - where I: IntoIterator, - { - let mut extensions: Vec = extensions.into_iter().collect(); - upsert_header!(self.headers; WebSocketExtensions; { + self + } + + pub fn add_extensions(mut self, extensions: I) -> Self + where I: IntoIterator + { + let mut extensions: Vec = + extensions.into_iter().collect(); + upsert_header!(self.headers; WebSocketExtensions; { Some(protos) => protos.0.append(&mut extensions), None => WebSocketExtensions(extensions) }); - self - } - - pub fn clear_extensions(mut self) -> Self { - self.headers.remove::(); - self - } - - pub fn key(mut self, key: [u8; 16]) -> Self { - self.headers.set(WebSocketKey(key)); - self.key_set = true; - self - } - - pub fn clear_key(mut self) -> Self { - self.headers.remove::(); - self.key_set = false; - self - } - - pub fn version(mut self, version: WebSocketVersion) -> Self { - self.headers.set(version); - self.version_set = true; - self - } - - pub fn clear_version(mut self) -> Self { - self.headers.remove::(); - self.version_set = false; - self - } - - pub fn origin(mut self, origin: String) -> Self { - self.headers.set(Origin(origin)); - self - } - - pub fn custom_headers(mut self, edit: F) -> Self - where F: Fn(&mut Headers), - { - edit(&mut self.headers); - self - } - - fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { - let port = match (self.url.port(), secure) { - (Some(port), _) => port, - (None, None) if self.url.scheme() == "wss" => 443, - (None, None) => 80, - (None, Some(true)) => 443, - (None, Some(false)) => 80, - }; - let host = match self.url.host_str() { - Some(h) => h, - None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), - }; - - let tcp_stream = try!(TcpStream::connect((host, port))); - Ok(tcp_stream) - } - - #[cfg(feature="ssl")] - fn wrap_ssl(&self, - tcp_stream: TcpStream, - connector: Option - ) -> WebSocketResult> { - let host = match self.url.host_str() { - Some(h) => h, - None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), - }; - let connector = match connector { - Some(c) => c, - None => try!(SslConnectorBuilder::new(SslMethod::tls())).build(), - }; - - let ssl_stream = try!(connector.connect(host, tcp_stream)); - Ok(ssl_stream) - } - - #[cfg(feature="ssl")] - pub fn connect(&mut self, - ssl_config: Option - ) -> WebSocketResult> { - let tcp_stream = try!(self.establish_tcp(None)); - - let boxed_stream = if self.url.scheme() == "wss" { - BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream, ssl_config)))) - } else { - BoxedNetworkStream(Box::new(tcp_stream)) - }; - - self.connect_on(boxed_stream) - } - - pub fn connect_insecure(&mut self) -> WebSocketResult> { - let tcp_stream = try!(self.establish_tcp(Some(false))); - - self.connect_on(tcp_stream) - } - - #[cfg(feature="ssl")] - pub fn connect_secure(&mut self, - ssl_config: Option - ) -> WebSocketResult>> { - let tcp_stream = try!(self.establish_tcp(Some(true))); - - let ssl_stream = try!(self.wrap_ssl(tcp_stream, ssl_config)); - - self.connect_on(ssl_stream) - } - - // TODO: refactor and split apart into two parts, for when evented happens - pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> - where S: Stream, - { - let resource = self.url[Position::BeforePath..Position::AfterQuery] - .to_owned(); - - // enter host if available (unix sockets don't have hosts) - if let Some(host) = self.url.host_str() { - self.headers.set(Host { - hostname: host.to_string(), - port: self.url.port(), - }); - } - - self.headers.set(Connection(vec![ + self + } + + pub fn clear_extensions(mut self) -> Self { + self.headers.remove::(); + self + } + + pub fn key(mut self, key: [u8; 16]) -> Self { + self.headers.set(WebSocketKey(key)); + self.key_set = true; + self + } + + pub fn clear_key(mut self) -> Self { + self.headers.remove::(); + self.key_set = false; + self + } + + pub fn version(mut self, version: WebSocketVersion) -> Self { + self.headers.set(version); + self.version_set = true; + self + } + + pub fn clear_version(mut self) -> Self { + self.headers.remove::(); + self.version_set = false; + self + } + + pub fn origin(mut self, origin: String) -> Self { + self.headers.set(Origin(origin)); + self + } + + pub fn custom_headers(mut self, edit: F) -> Self + where F: Fn(&mut Headers) + { + edit(&mut self.headers); + self + } + + fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { + let port = match (self.url.port(), secure) { + (Some(port), _) => port, + (None, None) if self.url.scheme() == "wss" => 443, + (None, None) => 80, + (None, Some(true)) => 443, + (None, Some(false)) => 80, + }; + let host = match self.url.host_str() { + Some(h) => h, + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + + let tcp_stream = try!(TcpStream::connect((host, port))); + Ok(tcp_stream) + } + + #[cfg(feature="ssl")] + fn wrap_ssl( + &self, + tcp_stream: TcpStream, + connector: Option, + ) -> WebSocketResult> { + let host = match self.url.host_str() { + Some(h) => h, + None => return Err(WebSocketError::WebSocketUrlError(WSUrlErrorKind::NoHostName)), + }; + let connector = match connector { + Some(c) => c, + None => try!(SslConnectorBuilder::new(SslMethod::tls())).build(), + }; + + let ssl_stream = try!(connector.connect(host, tcp_stream)); + Ok(ssl_stream) + } + + #[cfg(feature="ssl")] + pub fn connect( + &mut self, + ssl_config: Option, + ) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(None)); + + let boxed_stream = if self.url.scheme() == "wss" { + BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream, ssl_config)))) + } else { + BoxedNetworkStream(Box::new(tcp_stream)) + }; + + self.connect_on(boxed_stream) + } + + pub fn connect_insecure(&mut self) -> WebSocketResult> { + let tcp_stream = try!(self.establish_tcp(Some(false))); + + self.connect_on(tcp_stream) + } + + #[cfg(feature="ssl")] + pub fn connect_secure( + &mut self, + ssl_config: Option, + ) -> WebSocketResult>> { + let tcp_stream = try!(self.establish_tcp(Some(true))); + + let ssl_stream = try!(self.wrap_ssl(tcp_stream, ssl_config)); + + self.connect_on(ssl_stream) + } + + // TODO: refactor and split apart into two parts, for when evented happens + pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> + where S: Stream + { + let resource = self.url[Position::BeforePath..Position::AfterQuery].to_owned(); + + // enter host if available (unix sockets don't have hosts) + if let Some(host) = self.url.host_str() { + self.headers + .set(Host { + hostname: host.to_string(), + port: self.url.port(), + }); + } + + self.headers + .set(Connection(vec![ ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) ])); - self.headers.set(Upgrade(vec![Protocol { - name: ProtocolName::WebSocket, - version: None - }])); - - if !self.version_set { - self.headers.set(WebSocketVersion::WebSocket13); - } - - if !self.key_set { - self.headers.set(WebSocketKey::new()); - } - - // send request - try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); - try!(write!(stream.writer(), "{}\r\n", self.headers)); - - // wait for a response - // TODO: we should buffer it all, how to set up stream for this? - let response = try!(parse_response(&mut BufReader::new(stream.reader()))); - let status = StatusCode::from_u16(response.subject.0); - - // validate - if status != StatusCode::SwitchingProtocols { - return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); - } - - let key = try!(self.headers.get::().ok_or( - WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid") - )); - - if response.headers.get() != Some(&(WebSocketAccept::new(key))) { - return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); - } - - if response.headers.get() != Some(&(Upgrade(vec![Protocol { - name: ProtocolName::WebSocket, - version: None - }]))) { - return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); - } - - if self.headers.get() != Some(&(Connection(vec![ + self.headers + .set(Upgrade(vec![ + Protocol { + name: ProtocolName::WebSocket, + version: None, + }, + ])); + + if !self.version_set { + self.headers.set(WebSocketVersion::WebSocket13); + } + + if !self.key_set { + self.headers.set(WebSocketKey::new()); + } + + // send request + try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); + try!(write!(stream.writer(), "{}\r\n", self.headers)); + + // wait for a response + // TODO: we should buffer it all, how to set up stream for this? + let response = try!(parse_response(&mut BufReader::new(stream.reader()))); + let status = StatusCode::from_u16(response.subject.0); + + // validate + if status != StatusCode::SwitchingProtocols { + return Err(WebSocketError::ResponseError("Status code must be Switching Protocols")); + } + + let key = try!(self.headers + .get::() + .ok_or(WebSocketError::RequestError("Request Sec-WebSocket-Key was invalid"))); + + if response.headers.get() != Some(&(WebSocketAccept::new(key))) { + return Err(WebSocketError::ResponseError("Sec-WebSocket-Accept is invalid")); + } + + if response.headers.get() != + Some(&(Upgrade(vec![ + Protocol { + name: ProtocolName::WebSocket, + version: None, + }, + ]))) { + return Err(WebSocketError::ResponseError("Upgrade field must be WebSocket")); + } + + if self.headers.get() != + Some(&(Connection(vec![ ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())), ]))) { - return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); - } + return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); + } - Ok(Client::unchecked(stream)) - } + Ok(Client::unchecked(stream)) + } } - diff --git a/src/client/mod.rs b/src/client/mod.rs index bd224bc945..1cb50af356 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -7,20 +7,10 @@ use std::io::Result as IoResult; use ws; use ws::sender::Sender as SenderTrait; -use ws::receiver::{ - DataFrameIterator, - MessageIterator, -}; +use ws::receiver::{DataFrameIterator, MessageIterator}; use ws::receiver::Receiver as ReceiverTrait; -use result::{ - WebSocketResult, -}; -use stream::{ - AsTcpStream, - Stream, - Splittable, - Shutdown, -}; +use result::WebSocketResult; +use stream::{AsTcpStream, Stream, Splittable, Shutdown}; use dataframe::DataFrame; use ws::dataframe::DataFrame as DataFrameable; @@ -30,11 +20,7 @@ pub use sender::Writer; pub use receiver::Reader; pub mod builder; -pub use self::builder::{ - ClientBuilder, - Url, - ParseError, -}; +pub use self::builder::{ClientBuilder, Url, ParseError}; /// Represents a WebSocket client, which can send and receive messages/data frames. /// @@ -65,195 +51,202 @@ pub use self::builder::{ ///# } ///``` pub struct Client - where S: Stream, + where S: Stream { - pub stream: S, - sender: Sender, - receiver: Receiver, + pub stream: S, + sender: Sender, + receiver: Receiver, } impl Client { - /// Shuts down the sending half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_sender(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Write) - } - - /// Shuts down the receiving half of the client connection, will cause all pending - /// and future IO to return immediately with an appropriate value. - pub fn shutdown_receiver(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) - } + /// Shuts down the sending half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_sender(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Write) + } + + /// Shuts down the receiving half of the client connection, will cause all pending + /// and future IO to return immediately with an appropriate value. + pub fn shutdown_receiver(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Read) + } } impl Client - where S: AsTcpStream + Stream, + where S: AsTcpStream + Stream { - /// Shuts down the client connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) - } - - /// See `TcpStream.peer_addr()`. - pub fn peer_addr(&self) -> IoResult { - self.stream.as_tcp().peer_addr() - } - - /// See `TcpStream.local_addr()`. - pub fn local_addr(&self) -> IoResult { - self.stream.as_tcp().local_addr() - } - - /// See `TcpStream.set_nodelay()`. - pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { - self.stream.as_tcp().set_nodelay(nodelay) - } - - /// Changes whether the stream is in nonblocking mode. - pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.stream.as_tcp().set_nonblocking(nonblocking) - } + /// Shuts down the client connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Both) + } + + /// See `TcpStream.peer_addr()`. + pub fn peer_addr(&self) -> IoResult { + self.stream.as_tcp().peer_addr() + } + + /// See `TcpStream.local_addr()`. + pub fn local_addr(&self) -> IoResult { + self.stream.as_tcp().local_addr() + } + + /// See `TcpStream.set_nodelay()`. + pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { + self.stream.as_tcp().set_nodelay(nodelay) + } + + /// Changes whether the stream is in nonblocking mode. + pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { + self.stream.as_tcp().set_nonblocking(nonblocking) + } } impl Client - where S: Stream, + where S: Stream { - /// Creates a Client from a given stream - /// **without sending any handshake** this is meant to only be used with - /// a stream that has a websocket connection already set up. - /// If in doubt, don't use this! - pub fn unchecked(stream: S) -> Self { - Client { - stream: stream, - // NOTE: these are always true & false, see - // https://tools.ietf.org/html/rfc6455#section-5 - sender: Sender::new(true), - receiver: Receiver::new(false), - } - } - - /// Sends a single data frame to the remote endpoint. - pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrameable - { - self.sender.send_dataframe(self.stream.writer(), dataframe) - } - - /// Sends a single message to the remote endpoint. - pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: ws::Message<'m, D>, D: DataFrameable - { - self.sender.send_message(self.stream.writer(), message) - } - - /// Reads a single data frame from the remote endpoint. - pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe(self.stream.reader()) - } - - /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S::Reader> { - self.receiver.incoming_dataframes(self.stream.reader()) - } - - /// Reads a single message from this receiver. - pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult - where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, I: Iterator { - self.receiver.recv_message(self.stream.reader()) - } - - pub fn stream_ref(&self) -> &S { - &self.stream - } - - pub fn stream_ref_mut(&mut self) -> &mut S { - &mut self.stream - } - - /// Returns an iterator over incoming messages. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{ClientBuilder, Message}; - /// - ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() - /// .connect(None).unwrap(); - /// - ///for message in client.incoming_messages() { - /// let message: Message = message.unwrap(); - /// println!("Recv: {:?}", message); - ///} - ///# } - ///``` - /// - /// Note that since this method mutably borrows the `Client`, it may be necessary to - /// first `split()` the `Client` and call `incoming_messages()` on the returned - /// `Receiver` to be able to send messages within an iteration. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use websocket::{ClientBuilder, Message}; - /// - ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() - /// .connect_insecure().unwrap(); - /// - ///let (mut receiver, mut sender) = client.split().unwrap(); - /// - ///for message in receiver.incoming_messages() { - /// let message: Message = message.unwrap(); - /// // Echo the message back - /// sender.send_message(&message).unwrap(); - ///} - ///# } - ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S::Reader> - where M: ws::Message<'a, D>, - D: DataFrameable - { - self.receiver.incoming_messages(self.stream.reader()) - } + /// Creates a Client from a given stream + /// **without sending any handshake** this is meant to only be used with + /// a stream that has a websocket connection already set up. + /// If in doubt, don't use this! + pub fn unchecked(stream: S) -> Self { + Client { + stream: stream, + // NOTE: these are always true & false, see + // https://tools.ietf.org/html/rfc6455#section-5 + sender: Sender::new(true), + receiver: Receiver::new(false), + } + } + + /// Sends a single data frame to the remote endpoint. + pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + where D: DataFrameable + { + self.sender.send_dataframe(self.stream.writer(), dataframe) + } + + /// Sends a single message to the remote endpoint. + pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> + where M: ws::Message<'m, D>, + D: DataFrameable + { + self.sender.send_message(self.stream.writer(), message) + } + + /// Reads a single data frame from the remote endpoint. + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(self.stream.reader()) + } + + /// Returns an iterator over incoming data frames. + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S::Reader> { + self.receiver.incoming_dataframes(self.stream.reader()) + } + + /// Reads a single message from this receiver. + pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, + I: Iterator + { + self.receiver.recv_message(self.stream.reader()) + } + + pub fn stream_ref(&self) -> &S { + &self.stream + } + + pub fn stream_ref_mut(&mut self) -> &mut S { + &mut self.stream + } + + /// Returns an iterator over incoming messages. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{ClientBuilder, Message}; + /// + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect(None).unwrap(); + /// + ///for message in client.incoming_messages() { + /// let message: Message = message.unwrap(); + /// println!("Recv: {:?}", message); + ///} + ///# } + ///``` + /// + /// Note that since this method mutably borrows the `Client`, it may be necessary to + /// first `split()` the `Client` and call `incoming_messages()` on the returned + /// `Receiver` to be able to send messages within an iteration. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use websocket::{ClientBuilder, Message}; + /// + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect_insecure().unwrap(); + /// + ///let (mut receiver, mut sender) = client.split().unwrap(); + /// + ///for message in receiver.incoming_messages() { + /// let message: Message = message.unwrap(); + /// // Echo the message back + /// sender.send_message(&message).unwrap(); + ///} + ///# } + ///``` + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, S::Reader> + where M: ws::Message<'a, D>, + D: DataFrameable + { + self.receiver.incoming_messages(self.stream.reader()) + } } impl Client - where S: Splittable + Stream, + where S: Splittable + Stream { - /// Split this client into its constituent Sender and Receiver pair. - /// - /// This allows the Sender and Receiver to be sent to different threads. - /// - ///```no_run - ///# extern crate websocket; - ///# fn main() { - ///use std::thread; - ///use websocket::{ClientBuilder, Message}; - /// - ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() - /// .connect_insecure().unwrap(); - /// - ///let (mut receiver, mut sender) = client.split().unwrap(); - /// - ///thread::spawn(move || { - /// for message in receiver.incoming_messages() { - /// let message: Message = message.unwrap(); - /// println!("Recv: {:?}", message); - /// } - ///}); - /// - ///let message = Message::text("Hello, World!"); - ///sender.send_message(&message).unwrap(); - ///# } - ///``` - pub fn split(self) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { - let (read, write) = try!(self.stream.split()); - Ok((Reader { - stream: read, - receiver: self.receiver, - }, Writer { - stream: write, - sender: self.sender, - })) - } + /// Split this client into its constituent Sender and Receiver pair. + /// + /// This allows the Sender and Receiver to be sent to different threads. + /// + ///```no_run + ///# extern crate websocket; + ///# fn main() { + ///use std::thread; + ///use websocket::{ClientBuilder, Message}; + /// + ///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() + /// .connect_insecure().unwrap(); + /// + ///let (mut receiver, mut sender) = client.split().unwrap(); + /// + ///thread::spawn(move || { + /// for message in receiver.incoming_messages() { + /// let message: Message = message.unwrap(); + /// println!("Recv: {:?}", message); + /// } + ///}); + /// + ///let message = Message::text("Hello, World!"); + ///sender.send_message(&message).unwrap(); + ///# } + ///``` + pub fn split + (self,) + -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { + let (read, write) = try!(self.stream.split()); + Ok((Reader { + stream: read, + receiver: self.receiver, + }, + Writer { + stream: write, + sender: self.sender, + })) + } } diff --git a/src/dataframe.rs b/src/dataframe.rs index cb861692a9..9f08fe2747 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -37,58 +37,55 @@ impl DataFrame { } } - /// Reads a DataFrame from a Reader. - pub fn read_dataframe(reader: &mut R, should_be_masked: bool) -> WebSocketResult - where R: Read { - let header = try!(dfh::read_header(reader)); + /// Reads a DataFrame from a Reader. + pub fn read_dataframe(reader: &mut R, should_be_masked: bool) -> WebSocketResult + where R: Read + { + let header = try!(dfh::read_header(reader)); - Ok(DataFrame { - finished: header.flags.contains(dfh::FIN), - reserved: [ - header.flags.contains(dfh::RSV1), - header.flags.contains(dfh::RSV2), - header.flags.contains(dfh::RSV3) - ], - opcode: Opcode::new(header.opcode).expect("Invalid header opcode!"), - data: match header.mask { - Some(mask) => { - if !should_be_masked { - return Err(WebSocketError::DataFrameError( - "Expected unmasked data frame" - )); - } - let mut data: Vec = Vec::with_capacity(header.len as usize); - try!(reader.take(header.len).read_to_end(&mut data)); - mask::mask_data(mask, &data) - } - None => { - if should_be_masked { - return Err(WebSocketError::DataFrameError( - "Expected masked data frame" - )); - } - let mut data: Vec = Vec::with_capacity(header.len as usize); - try!(reader.take(header.len).read_to_end(&mut data)); - data - } - } - }) - } + Ok(DataFrame { + finished: header.flags.contains(dfh::FIN), + reserved: [ + header.flags.contains(dfh::RSV1), + header.flags.contains(dfh::RSV2), + header.flags.contains(dfh::RSV3), + ], + opcode: Opcode::new(header.opcode).expect("Invalid header opcode!"), + data: match header.mask { + Some(mask) => { + if !should_be_masked { + return Err(WebSocketError::DataFrameError("Expected unmasked data frame")); + } + let mut data: Vec = Vec::with_capacity(header.len as usize); + try!(reader.take(header.len).read_to_end(&mut data)); + mask::mask_data(mask, &data) + } + None => { + if should_be_masked { + return Err(WebSocketError::DataFrameError("Expected masked data frame")); + } + let mut data: Vec = Vec::with_capacity(header.len as usize); + try!(reader.take(header.len).read_to_end(&mut data)); + data + } + }, + }) + } } impl DataFrameable for DataFrame { #[inline(always)] - fn is_last(&self) -> bool { + fn is_last(&self) -> bool { self.finished } #[inline(always)] - fn opcode(&self) -> u8 { + fn opcode(&self) -> u8 { self.opcode as u8 } #[inline(always)] - fn reserved<'a>(&'a self) -> &'a [bool; 3] { + fn reserved<'a>(&'a self) -> &'a [bool; 3] { &self.reserved } @@ -141,62 +138,60 @@ impl Opcode { /// Returns the Opcode, or None if the opcode is out of range. pub fn new(op: u8) -> Option { Some(match op { - 0 => Opcode::Continuation, - 1 => Opcode::Text, - 2 => Opcode::Binary, - 3 => Opcode::NonControl1, - 4 => Opcode::NonControl2, - 5 => Opcode::NonControl3, - 6 => Opcode::NonControl4, - 7 => Opcode::NonControl5, - 8 => Opcode::Close, - 9 => Opcode::Ping, - 10 => Opcode::Pong, - 11 => Opcode::Control1, - 12 => Opcode::Control2, - 13 => Opcode::Control3, - 14 => Opcode::Control4, - 15 => Opcode::Control5, - _ => return None, - }) + 0 => Opcode::Continuation, + 1 => Opcode::Text, + 2 => Opcode::Binary, + 3 => Opcode::NonControl1, + 4 => Opcode::NonControl2, + 5 => Opcode::NonControl3, + 6 => Opcode::NonControl4, + 7 => Opcode::NonControl5, + 8 => Opcode::Close, + 9 => Opcode::Ping, + 10 => Opcode::Pong, + 11 => Opcode::Control1, + 12 => Opcode::Control2, + 13 => Opcode::Control3, + 14 => Opcode::Control4, + 15 => Opcode::Control5, + _ => return None, + }) } } #[cfg(all(feature = "nightly", test))] mod tests { - use super::*; + use super::*; use ws::dataframe::DataFrame as DataFrameable; - use test::Bencher; + use test::Bencher; - #[test] - fn test_read_dataframe() { - let data = b"The quick brown fox jumps over the lazy dog"; - let mut dataframe = vec![0x81, 0x2B]; - for i in data.iter() { - dataframe.push(*i); - } - let obtained = DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); - let expected = DataFrame { - finished: true, - reserved: [false; 3], - opcode: Opcode::Text, - data: data.to_vec() - }; - assert_eq!(obtained, expected); - } - #[bench] + #[test] + fn test_read_dataframe() { + let data = b"The quick brown fox jumps over the lazy dog"; + let mut dataframe = vec![0x81, 0x2B]; + for i in data.iter() { + dataframe.push(*i); + } + let obtained = DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); + let expected = DataFrame { + finished: true, + reserved: [false; 3], + opcode: Opcode::Text, + data: data.to_vec(), + }; + assert_eq!(obtained, expected); + } + #[bench] fn bench_read_dataframe(b: &mut Bencher) { let data = b"The quick brown fox jumps over the lazy dog"; let mut dataframe = vec![0x81, 0x2B]; for i in data.iter() { dataframe.push(*i); } - b.iter(|| { - DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); - }); + b.iter(|| { DataFrame::read_dataframe(&mut &dataframe[..], false).unwrap(); }); } - #[test] + #[test] fn test_write_dataframe() { let data = b"The quick brown fox jumps over the lazy dog"; let mut expected = vec![0x81, 0x2B]; @@ -207,10 +202,10 @@ mod tests { finished: true, reserved: [false; 3], opcode: Opcode::Text, - data: data.to_vec() + data: data.to_vec(), }; let mut obtained = Vec::new(); - dataframe.write_to(&mut obtained, false).unwrap(); + dataframe.write_to(&mut obtained, false).unwrap(); assert_eq!(&obtained[..], &expected[..]); } @@ -222,11 +217,9 @@ mod tests { finished: true, reserved: [false; 3], opcode: Opcode::Text, - data: data.to_vec() + data: data.to_vec(), }; let mut writer = Vec::with_capacity(45); - b.iter(|| { - dataframe.write_to(&mut writer, false).unwrap(); - }); + b.iter(|| { dataframe.write_to(&mut writer, false).unwrap(); }); } } diff --git a/src/header/accept.rs b/src/header/accept.rs index 0b2c57a171..aabb718ece 100644 --- a/src/header/accept.rs +++ b/src/header/accept.rs @@ -15,112 +15,109 @@ static MAGIC_GUID: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; pub struct WebSocketAccept([u8; 20]); impl Debug for WebSocketAccept { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "WebSocketAccept({})", self.serialize()) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "WebSocketAccept({})", self.serialize()) + } } impl FromStr for WebSocketAccept { - type Err = WebSocketError; + type Err = WebSocketError; - fn from_str(accept: &str) -> WebSocketResult { - match accept.from_base64() { - Ok(vec) => { - if vec.len() != 20 { - return Err(WebSocketError::ProtocolError( - "Sec-WebSocket-Accept must be 20 bytes" - )); - } - let mut array = [0u8; 20]; - let mut iter = vec.into_iter(); - for i in array.iter_mut() { - *i = iter.next().unwrap(); - } - Ok(WebSocketAccept(array)) - } - Err(_) => { - return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Accept " - )); - } - } - } + fn from_str(accept: &str) -> WebSocketResult { + match accept.from_base64() { + Ok(vec) => { + if vec.len() != 20 { + return Err(WebSocketError::ProtocolError("Sec-WebSocket-Accept must be 20 bytes")); + } + let mut array = [0u8; 20]; + let mut iter = vec.into_iter(); + for i in array.iter_mut() { + *i = iter.next().unwrap(); + } + Ok(WebSocketAccept(array)) + } + Err(_) => { + return Err(WebSocketError::ProtocolError("Invalid Sec-WebSocket-Accept ")); + } + } + } } impl WebSocketAccept { - /// Create a new WebSocketAccept from the given WebSocketKey - pub fn new(key: &WebSocketKey) -> WebSocketAccept { - let serialized = key.serialize(); - let mut concat_key = String::with_capacity(serialized.len() + 36); - concat_key.push_str(&serialized[..]); - concat_key.push_str(MAGIC_GUID); - let mut sha1 = Sha1::new(); - sha1.update(concat_key.as_bytes()); - let bytes = sha1.digest().bytes(); - WebSocketAccept(bytes) - } - /// Return the Base64 encoding of this WebSocketAccept - pub fn serialize(&self) -> String { - let WebSocketAccept(accept) = *self; - accept.to_base64(STANDARD) - } + /// Create a new WebSocketAccept from the given WebSocketKey + pub fn new(key: &WebSocketKey) -> WebSocketAccept { + let serialized = key.serialize(); + let mut concat_key = String::with_capacity(serialized.len() + 36); + concat_key.push_str(&serialized[..]); + concat_key.push_str(MAGIC_GUID); + let mut sha1 = Sha1::new(); + sha1.update(concat_key.as_bytes()); + let bytes = sha1.digest().bytes(); + WebSocketAccept(bytes) + } + /// Return the Base64 encoding of this WebSocketAccept + pub fn serialize(&self) -> String { + let WebSocketAccept(accept) = *self; + accept.to_base64(STANDARD) + } } impl Header for WebSocketAccept { - fn header_name() -> &'static str { - "Sec-WebSocket-Accept" - } + fn header_name() -> &'static str { + "Sec-WebSocket-Accept" + } - fn parse_header(raw: &[Vec]) -> hyper::Result { - from_one_raw_str(raw) - } + fn parse_header(raw: &[Vec]) -> hyper::Result { + from_one_raw_str(raw) + } } impl HeaderFormat for WebSocketAccept { - fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{}", self.serialize()) - } + fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}", self.serialize()) + } } #[cfg(all(feature = "nightly", test))] mod tests { - use super::*; - use test; - use std::str::FromStr; - use header::{Headers, WebSocketKey}; - use hyper::header::Header; + use super::*; + use test; + use std::str::FromStr; + use header::{Headers, WebSocketKey}; + use hyper::header::Header; - #[test] - fn test_header_accept() { - let key = FromStr::from_str("dGhlIHNhbXBsZSBub25jZQ==").unwrap(); - let accept = WebSocketAccept::new(&key); - let mut headers = Headers::new(); - headers.set(accept); + #[test] + fn test_header_accept() { + let key = FromStr::from_str("dGhlIHNhbXBsZSBub25jZQ==").unwrap(); + let accept = WebSocketAccept::new(&key); + let mut headers = Headers::new(); + headers.set(accept); - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); - } - #[bench] - fn bench_header_accept_new(b: &mut test::Bencher) { - let key = WebSocketKey::new(); - b.iter(|| { - let mut accept = WebSocketAccept::new(&key); - test::black_box(&mut accept); - }); - } - #[bench] - fn bench_header_accept_parse(b: &mut test::Bencher) { - let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; - b.iter(|| { - let mut accept: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut accept); - }); - } - #[bench] - fn bench_header_accept_format(b: &mut test::Bencher) { - let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; - let val: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); - b.iter(|| { - format!("{}", val.serialize()); - }); - } + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n"); + } + #[bench] + fn bench_header_accept_new(b: &mut test::Bencher) { + let key = WebSocketKey::new(); + b.iter(|| { + let mut accept = WebSocketAccept::new(&key); + test::black_box(&mut accept); + }); + } + #[bench] + fn bench_header_accept_parse(b: &mut test::Bencher) { + let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; + b.iter(|| { + let mut accept: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut accept); + }); + } + #[bench] + fn bench_header_accept_format(b: &mut test::Bencher) { + let value = vec![b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".to_vec()]; + let val: WebSocketAccept = Header::parse_header(&value[..]).unwrap(); + b.iter(|| { + format!("{}", val.serialize()); + }); + } } diff --git a/src/header/extensions.rs b/src/header/extensions.rs index 6c842dd32d..d1c2f20bf6 100644 --- a/src/header/extensions.rs +++ b/src/header/extensions.rs @@ -8,6 +8,8 @@ use std::str::FromStr; use std::ops::Deref; use result::{WebSocketResult, WebSocketError}; +const INVALID_EXTENSION: &'static str = "Invalid Sec-WebSocket-Extensions extension name"; + /// Represents a Sec-WebSocket-Extensions header #[derive(PartialEq, Clone, Debug)] pub struct WebSocketExtensions(pub Vec); @@ -15,9 +17,9 @@ pub struct WebSocketExtensions(pub Vec); impl Deref for WebSocketExtensions { type Target = Vec; - fn deref<'a>(&'a self) -> &'a Vec { - &self.0 - } + fn deref<'a>(&'a self) -> &'a Vec { + &self.0 + } } #[derive(PartialEq, Clone, Debug)] @@ -26,7 +28,7 @@ pub struct Extension { /// The name of this extension pub name: String, /// The parameters for this extension - pub params: Vec + pub params: Vec, } impl Extension { @@ -34,43 +36,42 @@ impl Extension { pub fn new(name: String) -> Extension { Extension { name: name, - params: Vec::new() + params: Vec::new(), } } } impl FromStr for Extension { type Err = WebSocketError; - + fn from_str(s: &str) -> WebSocketResult { let mut ext = s.split(';').map(|x| x.trim()); Ok(Extension { - name: match ext.next() { - Some(x) => x.to_string(), - None => return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Extensions extension name" - )), - }, - params: ext.map(|x| { - let mut pair = x.splitn(1, '=').map(|x| x.trim().to_string()); - - Parameter { - name: pair.next().unwrap(), - value: pair.next() - } - }).collect() - }) + name: match ext.next() { + Some(x) => x.to_string(), + None => return Err(WebSocketError::ProtocolError(INVALID_EXTENSION)), + }, + params: ext.map(|x| { + let mut pair = x.splitn(1, '=').map(|x| x.trim().to_string()); + + Parameter { + name: pair.next().unwrap(), + value: pair.next(), + } + }) + .collect(), + }) } } impl fmt::Display for Extension { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(write!(f, "{}", self.name)); + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(write!(f, "{}", self.name)); for param in self.params.iter() { try!(write!(f, "; {}", param)); } Ok(()) - } + } } #[derive(PartialEq, Clone, Debug)] @@ -79,7 +80,7 @@ pub struct Parameter { /// The name of this parameter pub name: String, /// The value of this parameter, if any - pub value: Option + pub value: Option, } impl Parameter { @@ -87,20 +88,20 @@ impl Parameter { pub fn new(name: String, value: Option) -> Parameter { Parameter { name: name, - value: value + value: value, } } } impl fmt::Display for Parameter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - try!(write!(f, "{}", self.name)); + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + try!(write!(f, "{}", self.name)); match self.value { Some(ref x) => try!(write!(f, "={}", x)), None => (), } Ok(()) - } + } } impl Header for WebSocketExtensions { @@ -121,9 +122,9 @@ impl HeaderFormat for WebSocketExtensions { } impl fmt::Display for WebSocketExtensions { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - self.fmt_header(fmt) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } } #[cfg(all(feature = "nightly", test))] @@ -136,26 +137,28 @@ mod tests { use header::Headers; let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; let extensions: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); - + let mut headers = Headers::new(); headers.set(extensions); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Extensions: foo, bar; baz; qux=quux\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Extensions: foo, bar; baz; qux=quux\r\n"); } #[bench] fn bench_header_extensions_parse(b: &mut test::Bencher) { let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; b.iter(|| { - let mut extensions: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut extensions); - }); + let mut extensions: WebSocketExtensions = Header::parse_header(&value[..]) + .unwrap(); + test::black_box(&mut extensions); + }); } #[bench] fn bench_header_extensions_format(b: &mut test::Bencher) { let value = vec![b"foo, bar; baz; qux=quux".to_vec()]; let val: WebSocketExtensions = Header::parse_header(&value[..]).unwrap(); b.iter(|| { - format!("{}", val); - }); + format!("{}", val); + }); } } diff --git a/src/header/key.rs b/src/header/key.rs index fb96faf4f1..3e3331b953 100644 --- a/src/header/key.rs +++ b/src/header/key.rs @@ -25,22 +25,18 @@ impl FromStr for WebSocketKey { match key.from_base64() { Ok(vec) => { if vec.len() != 16 { - return Err(WebSocketError::ProtocolError( - "Sec-WebSocket-Key must be 16 bytes" - )); + return Err(WebSocketError::ProtocolError("Sec-WebSocket-Key must be 16 bytes")); } let mut array = [0u8; 16]; let mut iter = vec.into_iter(); for i in array.iter_mut() { *i = iter.next().unwrap(); } - + Ok(WebSocketKey(array)) } Err(_) => { - return Err(WebSocketError::ProtocolError( - "Invalid Sec-WebSocket-Accept" - )); + return Err(WebSocketError::ProtocolError("Invalid Sec-WebSocket-Accept")); } } } @@ -51,9 +47,7 @@ impl WebSocketKey { pub fn new() -> WebSocketKey { let key: [u8; 16] = unsafe { // Much faster than calling random() several times - mem::transmute( - rand::random::<(u64, u64)>() - ) + mem::transmute(rand::random::<(u64, u64)>()) }; WebSocketKey(key) } @@ -88,34 +82,35 @@ mod tests { #[test] fn test_header_key() { use header::Headers; - + let extensions = WebSocketKey([65; 16]); let mut headers = Headers::new(); headers.set(extensions); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Key: QUFBQUFBQUFBQUFBQUFBQQ==\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Key: QUFBQUFBQUFBQUFBQUFBQQ==\r\n"); } #[bench] fn bench_header_key_new(b: &mut test::Bencher) { b.iter(|| { - let mut key = WebSocketKey::new(); - test::black_box(&mut key); - }); + let mut key = WebSocketKey::new(); + test::black_box(&mut key); + }); } #[bench] fn bench_header_key_parse(b: &mut test::Bencher) { let value = vec![b"QUFBQUFBQUFBQUFBQUFBQQ==".to_vec()]; b.iter(|| { - let mut key: WebSocketKey = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut key); - }); + let mut key: WebSocketKey = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut key); + }); } #[bench] fn bench_header_key_format(b: &mut test::Bencher) { let value = vec![b"QUFBQUFBQUFBQUFBQUFBQQ==".to_vec()]; let val: WebSocketKey = Header::parse_header(&value[..]).unwrap(); b.iter(|| { - format!("{}", val.serialize()); - }); + format!("{}", val.serialize()); + }); } } diff --git a/src/header/mod.rs b/src/header/mod.rs index 9fc0dda0ae..01a4aa79a4 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -16,4 +16,4 @@ mod key; mod protocol; mod version; pub mod extensions; -mod origin; \ No newline at end of file +mod origin; diff --git a/src/header/origin.rs b/src/header/origin.rs index 3f172e108b..3c3d807106 100644 --- a/src/header/origin.rs +++ b/src/header/origin.rs @@ -10,9 +10,9 @@ pub struct Origin(pub String); impl Deref for Origin { type Target = String; - fn deref<'a>(&'a self) -> &'a String { - &self.0 - } + fn deref<'a>(&'a self) -> &'a String { + &self.0 + } } impl Header for Origin { @@ -28,14 +28,14 @@ impl Header for Origin { impl HeaderFormat for Origin { fn fmt_header(&self, fmt: &mut fmt::Formatter) -> fmt::Result { let Origin(ref value) = *self; - write!(fmt, "{}", value) + write!(fmt, "{}", value) } } impl fmt::Display for Origin { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - self.fmt_header(fmt) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } } #[cfg(all(feature = "nightly", test))] @@ -46,27 +46,27 @@ mod tests { #[test] fn test_header_origin() { use header::Headers; - + let origin = Origin("foo bar".to_string()); let mut headers = Headers::new(); headers.set(origin); - + assert_eq!(&headers.to_string()[..], "Origin: foo bar\r\n"); } #[bench] fn bench_header_origin_parse(b: &mut test::Bencher) { let value = vec![b"foobar".to_vec()]; b.iter(|| { - let mut origin: Origin = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut origin); - }); + let mut origin: Origin = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut origin); + }); } #[bench] fn bench_header_origin_format(b: &mut test::Bencher) { let value = vec![b"foobar".to_vec()]; let val: Origin = Header::parse_header(&value[..]).unwrap(); b.iter(|| { - format!("{}", val); - }); + format!("{}", val); + }); } } diff --git a/src/header/protocol.rs b/src/header/protocol.rs index 3582b323c8..51773cda8f 100644 --- a/src/header/protocol.rs +++ b/src/header/protocol.rs @@ -10,9 +10,9 @@ pub struct WebSocketProtocol(pub Vec); impl Deref for WebSocketProtocol { type Target = Vec; - fn deref<'a>(&'a self) -> &'a Vec { - &self.0 - } + fn deref<'a>(&'a self) -> &'a Vec { + &self.0 + } } impl Header for WebSocketProtocol { @@ -33,9 +33,9 @@ impl HeaderFormat for WebSocketProtocol { } impl fmt::Display for WebSocketProtocol { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - self.fmt_header(fmt) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } } #[cfg(all(feature = "nightly", test))] @@ -46,27 +46,28 @@ mod tests { #[test] fn test_header_protocol() { use header::Headers; - + let protocol = WebSocketProtocol(vec!["foo".to_string(), "bar".to_string()]); let mut headers = Headers::new(); headers.set(protocol); - - assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Protocol: foo, bar\r\n"); + + assert_eq!(&headers.to_string()[..], + "Sec-WebSocket-Protocol: foo, bar\r\n"); } #[bench] fn bench_header_protocol_parse(b: &mut test::Bencher) { let value = vec![b"foo, bar".to_vec()]; b.iter(|| { - let mut protocol: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut protocol); - }); + let mut protocol: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut protocol); + }); } #[bench] fn bench_header_protocol_format(b: &mut test::Bencher) { let value = vec![b"foo, bar".to_vec()]; let val: WebSocketProtocol = Header::parse_header(&value[..]).unwrap(); b.iter(|| { - format!("{}", val); - }); + format!("{}", val); + }); } } diff --git a/src/header/version.rs b/src/header/version.rs index 4128e35267..7fadfc4ff5 100644 --- a/src/header/version.rs +++ b/src/header/version.rs @@ -9,18 +9,14 @@ pub enum WebSocketVersion { /// The version of WebSocket defined in RFC6455 WebSocket13, /// An unknown version of WebSocket - Unknown(String) + Unknown(String), } impl fmt::Debug for WebSocketVersion { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - WebSocketVersion::WebSocket13 => { - write!(f, "13") - } - WebSocketVersion::Unknown(ref value) => { - write!(f, "{}", value) - } + WebSocketVersion::WebSocket13 => write!(f, "13"), + WebSocketVersion::Unknown(ref value) => write!(f, "{}", value), } } } @@ -31,12 +27,10 @@ impl Header for WebSocketVersion { } fn parse_header(raw: &[Vec]) -> hyper::Result { - from_one_raw_str(raw).map(|s : String| - match &s[..] { - "13" => { WebSocketVersion::WebSocket13 } - _ => { WebSocketVersion::Unknown(s) } - } - ) + from_one_raw_str(raw).map(|s: String| match &s[..] { + "13" => WebSocketVersion::WebSocket13, + _ => WebSocketVersion::Unknown(s), + }) } } @@ -47,9 +41,9 @@ impl HeaderFormat for WebSocketVersion { } impl fmt::Display for WebSocketVersion { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - self.fmt_header(fmt) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + self.fmt_header(fmt) + } } #[cfg(all(feature = "nightly", test))] @@ -60,27 +54,27 @@ mod tests { #[test] fn test_websocket_version() { use header::Headers; - + let version = WebSocketVersion::WebSocket13; let mut headers = Headers::new(); headers.set(version); - + assert_eq!(&headers.to_string()[..], "Sec-WebSocket-Version: 13\r\n"); } #[bench] fn bench_header_version_parse(b: &mut test::Bencher) { let value = vec![b"13".to_vec()]; b.iter(|| { - let mut version: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); - test::black_box(&mut version); - }); + let mut version: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); + test::black_box(&mut version); + }); } #[bench] fn bench_header_version_format(b: &mut test::Bencher) { let value = vec![b"13".to_vec()]; let val: WebSocketVersion = Header::parse_header(&value[..]).unwrap(); b.iter(|| { - format!("{}", val); - }); + format!("{}", val); + }); } } diff --git a/src/lib.rs b/src/lib.rs index 7f489841b5..e1ca80cc63 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,10 +52,7 @@ extern crate bitflags; #[cfg(all(feature = "nightly", test))] extern crate test; -pub use self::client::{ - Client, - ClientBuilder, -}; +pub use self::client::{Client, ClientBuilder}; pub use self::server::Server; pub use self::dataframe::DataFrame; pub use self::message::Message; diff --git a/src/message.rs b/src/message.rs index 89b64e9387..f2103152c3 100644 --- a/src/message.rs +++ b/src/message.rs @@ -13,15 +13,15 @@ const FALSE_RESERVED_BITS: &'static [bool; 3] = &[false; 3]; /// Valid types of messages (in the default implementation) #[derive(Debug, PartialEq, Clone, Copy)] pub enum Type { - /// Message with UTF8 test + /// Message with UTF8 test Text = 1, - /// Message containing binary data + /// Message containing binary data Binary = 2, - /// Ping message with data + /// Ping message with data Ping = 9, - /// Pong message with data + /// Pong message with data Pong = 10, - /// Close connection message with optional reason + /// Close connection message with optional reason Close = 8, } @@ -35,12 +35,12 @@ pub enum Type { /// because this message just gets sent as one single DataFrame. #[derive(PartialEq, Clone, Debug)] pub struct Message<'a> { - /// Type of WebSocket message + /// Type of WebSocket message pub opcode: Type, - /// Optional status code to send when closing a connection. - /// (only used if this message is of Type::Close) + /// Optional status code to send when closing a connection. + /// (only used if this message is of Type::Close) pub cd_status_code: Option, - /// Main payload + /// Main payload pub payload: Cow<'a, [u8]>, } @@ -53,79 +53,88 @@ impl<'a> Message<'a> { } } - /// Create a new WebSocket message with text data + /// Create a new WebSocket message with text data pub fn text(data: S) -> Self - where S: Into> { - Message::new(Type::Text, None, match data.into() { - Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), - Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), - }) + where S: Into> + { + Message::new(Type::Text, + None, + match data.into() { + Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), + Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), + }) } - /// Create a new WebSocket message with binary data + /// Create a new WebSocket message with binary data pub fn binary(data: B) -> Self - where B: IntoCowBytes<'a> { + where B: IntoCowBytes<'a> + { Message::new(Type::Binary, None, data.into()) } - /// Create a new WebSocket message that signals the end of a WebSocket - /// connection, although messages can still be sent after sending this + /// Create a new WebSocket message that signals the end of a WebSocket + /// connection, although messages can still be sent after sending this pub fn close() -> Self { Message::new(Type::Close, None, Cow::Borrowed(&[0 as u8; 0])) } - /// Create a new WebSocket message that signals the end of a WebSocket - /// connection and provide a text reason and a status code for why. - /// Messages can still be sent after sending this message. + /// Create a new WebSocket message that signals the end of a WebSocket + /// connection and provide a text reason and a status code for why. + /// Messages can still be sent after sending this message. pub fn close_because(code: u16, reason: S) -> Self - where S: Into> { - Message::new(Type::Close, Some(code), match reason.into() { - Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), - Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), - }) + where S: Into> + { + Message::new(Type::Close, + Some(code), + match reason.into() { + Cow::Owned(msg) => Cow::Owned(msg.into_bytes()), + Cow::Borrowed(msg) => Cow::Borrowed(msg.as_bytes()), + }) } - /// Create a ping WebSocket message, a pong is usually sent back - /// after sending this with the same data + /// Create a ping WebSocket message, a pong is usually sent back + /// after sending this with the same data pub fn ping

(data: P) -> Self - where P: IntoCowBytes<'a> { + where P: IntoCowBytes<'a> + { Message::new(Type::Ping, None, data.into()) } - /// Create a pong WebSocket message, usually a response to a - /// ping message + /// Create a pong WebSocket message, usually a response to a + /// ping message pub fn pong

(data: P) -> Self - where P: IntoCowBytes<'a> { + where P: IntoCowBytes<'a> + { Message::new(Type::Pong, None, data.into()) } - /// Convert a ping message to a pong, keeping the data. - /// This will fail if the original message is not a ping. - pub fn into_pong(&mut self) -> Result<(), ()> { - if self.opcode == Type::Ping { - self.opcode = Type::Pong; - Ok(()) - } else { - Err(()) - } - } + /// Convert a ping message to a pong, keeping the data. + /// This will fail if the original message is not a ping. + pub fn into_pong(&mut self) -> Result<(), ()> { + if self.opcode == Type::Ping { + self.opcode = Type::Pong; + Ok(()) + } else { + Err(()) + } + } } impl<'a> ws::dataframe::DataFrame for Message<'a> { #[inline(always)] - fn is_last(&self) -> bool { - true - } + fn is_last(&self) -> bool { + true + } #[inline(always)] - fn opcode(&self) -> u8 { - self.opcode as u8 - } + fn opcode(&self) -> u8 { + self.opcode as u8 + } #[inline(always)] - fn reserved<'b>(&'b self) -> &'b [bool; 3] { + fn reserved<'b>(&'b self) -> &'b [bool; 3] { FALSE_RESERVED_BITS - } + } fn payload<'b>(&'b self) -> Cow<'b, [u8]> { let mut buf = Vec::with_capacity(self.size()); @@ -134,79 +143,70 @@ impl<'a> ws::dataframe::DataFrame for Message<'a> { } fn size(&self) -> usize { - self.payload.len() + if self.cd_status_code.is_some() { - 2 - } else { - 0 - } + self.payload.len() + if self.cd_status_code.is_some() { 2 } else { 0 } } - fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> - where W: Write { + fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> + where W: Write + { if let Some(reason) = self.cd_status_code { try!(socket.write_u16::(reason)); } try!(socket.write_all(&*self.payload)); Ok(()) - } + } } impl<'a, 'b> ws::Message<'b, &'b Message<'a>> for Message<'a> { - type DataFrameIterator = Take>>; fn dataframes(&'b self) -> Self::DataFrameIterator { repeat(self).take(1) - } + } /// Attempt to form a message from a series of data frames fn from_dataframes(frames: Vec) -> WebSocketResult - where D: ws::dataframe::DataFrame { - let opcode = try!(frames.first().ok_or(WebSocketError::ProtocolError( - "No dataframes provided" - )).map(|d| d.opcode())); + where D: ws::dataframe::DataFrame + { + let opcode = try!(frames.first() + .ok_or(WebSocketError::ProtocolError("No dataframes provided")) + .map(|d| d.opcode())); let mut data = Vec::new(); for (i, dataframe) in frames.iter().enumerate() { if i > 0 && dataframe.opcode() != Opcode::Continuation as u8 { - return Err(WebSocketError::ProtocolError( - "Unexpected non-continuation data frame" - )); + return Err(WebSocketError::ProtocolError("Unexpected non-continuation data frame")); } if *dataframe.reserved() != [false; 3] { - return Err(WebSocketError::ProtocolError( - "Unsupported reserved bits received" - )); + return Err(WebSocketError::ProtocolError("Unsupported reserved bits received")); } data.extend(dataframe.payload().iter().cloned()); } Ok(match Opcode::new(opcode) { - Some(Opcode::Text) => Message::text(try!(bytes_to_string(&data[..]))), - Some(Opcode::Binary) => Message::binary(data), - Some(Opcode::Close) => { - if data.len() > 0 { - let status_code = try!((&data[..]).read_u16::()); - let reason = try!(bytes_to_string(&data[2..])); - Message::close_because(status_code, reason) - } else { - Message::close() - } - } - Some(Opcode::Ping) => Message::ping(data), - Some(Opcode::Pong) => Message::pong(data), - _ => return Err(WebSocketError::ProtocolError( - "Unsupported opcode received" - )), - }) + Some(Opcode::Text) => Message::text(try!(bytes_to_string(&data[..]))), + Some(Opcode::Binary) => Message::binary(data), + Some(Opcode::Close) => { + if data.len() > 0 { + let status_code = try!((&data[..]).read_u16::()); + let reason = try!(bytes_to_string(&data[2..])); + Message::close_because(status_code, reason) + } else { + Message::close() + } + } + Some(Opcode::Ping) => Message::ping(data), + Some(Opcode::Pong) => Message::pong(data), + _ => return Err(WebSocketError::ProtocolError("Unsupported opcode received")), + }) } } /// Trait representing the ability to convert /// self to a `Cow<'a, [u8]>` pub trait IntoCowBytes<'a> { - /// Consume `self` and produce a `Cow<'a, [u8]>` + /// Consume `self` and produce a `Cow<'a, [u8]>` fn into(self) -> Cow<'a, [u8]>; } diff --git a/src/receiver.rs b/src/receiver.rs index 85b7f09cea..aa1a7e6760 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -3,78 +3,66 @@ use std::io::Read; use std::io::Result as IoResult; -use dataframe::{ - DataFrame, - Opcode -}; -use result::{ - WebSocketResult, - WebSocketError -}; +use dataframe::{DataFrame, Opcode}; +use result::{WebSocketResult, WebSocketError}; use ws; use ws::dataframe::DataFrame as DataFrameable; use ws::receiver::Receiver as ReceiverTrait; -use ws::receiver::{ - MessageIterator, - DataFrameIterator, -}; -use stream::{ - AsTcpStream, - Stream, -}; +use ws::receiver::{MessageIterator, DataFrameIterator}; +use stream::{AsTcpStream, Stream}; pub use stream::Shutdown; // TODO: buffer the readers pub struct Reader - where R: Read + where R: Read { - pub stream: R, - pub receiver: Receiver, + pub stream: R, + pub receiver: Receiver, } impl Reader - where R: Read, + where R: Read { - /// Reads a single data frame from the remote endpoint. - pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe(&mut self.stream) - } - - /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> { - self.receiver.incoming_dataframes(&mut self.stream) - } - - /// Reads a single message from this receiver. - pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult - where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, - I: Iterator - { - self.receiver.recv_message(&mut self.stream) - } - - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R> - where M: ws::Message<'a, D>, - D: DataFrameable - { - self.receiver.incoming_messages(&mut self.stream) - } + /// Reads a single data frame from the remote endpoint. + pub fn recv_dataframe(&mut self) -> WebSocketResult { + self.receiver.recv_dataframe(&mut self.stream) + } + + /// Returns an iterator over incoming data frames. + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> { + self.receiver.incoming_dataframes(&mut self.stream) + } + + /// Reads a single message from this receiver. + pub fn recv_message<'m, M, I>(&mut self) -> WebSocketResult + where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, + I: Iterator + { + self.receiver.recv_message(&mut self.stream) + } + + pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R> + where M: ws::Message<'a, D>, + D: DataFrameable + { + self.receiver.incoming_messages(&mut self.stream) + } } impl Reader - where S: AsTcpStream + Stream + Read, + where S: AsTcpStream + Stream + Read { - /// Closes the receiver side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) - } - - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) - } + /// Closes the receiver side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Read) + } + + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Both) + } } /// A Receiver that wraps a Reader and provides a default implementation using @@ -96,33 +84,30 @@ impl Receiver { impl ws::Receiver for Receiver { - type F = DataFrame; + type F = DataFrame; /// Reads a single data frame from the remote endpoint. - fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult - where R: Read, - { + fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult + where R: Read + { DataFrame::read_dataframe(reader, self.mask) } /// Returns the data frames that constitute one message. - fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> - where R: Read, - { + fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> + where R: Read + { let mut finished = if self.buffer.is_empty() { let first = try!(self.recv_dataframe(reader)); if first.opcode == Opcode::Continuation { - return Err(WebSocketError::ProtocolError( - "Unexpected continuation data frame opcode" - )); + return Err(WebSocketError::ProtocolError("Unexpected continuation data frame opcode")); } let finished = first.finished; self.buffer.push(first); finished - } - else { + } else { false }; @@ -138,9 +123,7 @@ impl ws::Receiver for Receiver { return Ok(vec![next]); } // Others - _ => return Err(WebSocketError::ProtocolError( - "Unexpected data frame opcode" - )), + _ => return Err(WebSocketError::ProtocolError("Unexpected data frame opcode")), } } diff --git a/src/result.rs b/src/result.rs index 712c49c0b9..f87b2ce5cc 100644 --- a/src/result.rs +++ b/src/result.rs @@ -19,157 +19,157 @@ pub type WebSocketResult = Result; /// Represents a WebSocket error #[derive(Debug)] pub enum WebSocketError { - /// A WebSocket protocol error - ProtocolError(&'static str), - /// Invalid WebSocket request error - RequestError(&'static str), - /// Invalid WebSocket response error - ResponseError(&'static str), - /// Invalid WebSocket data frame error - DataFrameError(&'static str), - /// No data available - NoDataAvailable, - /// An input/output error - IoError(io::Error), - /// An HTTP parsing error - HttpError(HttpError), - /// A URL parsing error - UrlError(ParseError), - /// A WebSocket URL error - WebSocketUrlError(WSUrlErrorKind), - /// An SSL error - #[cfg(feature="ssl")] - SslError(SslError), - /// an ssl handshake failure - #[cfg(feature="ssl")] - SslHandshakeFailure, - /// an ssl handshake interruption - #[cfg(feature="ssl")] - SslHandshakeInterruption, - /// A UTF-8 error - Utf8Error(Utf8Error), + /// A WebSocket protocol error + ProtocolError(&'static str), + /// Invalid WebSocket request error + RequestError(&'static str), + /// Invalid WebSocket response error + ResponseError(&'static str), + /// Invalid WebSocket data frame error + DataFrameError(&'static str), + /// No data available + NoDataAvailable, + /// An input/output error + IoError(io::Error), + /// An HTTP parsing error + HttpError(HttpError), + /// A URL parsing error + UrlError(ParseError), + /// A WebSocket URL error + WebSocketUrlError(WSUrlErrorKind), + /// An SSL error + #[cfg(feature="ssl")] + SslError(SslError), + /// an ssl handshake failure + #[cfg(feature="ssl")] + SslHandshakeFailure, + /// an ssl handshake interruption + #[cfg(feature="ssl")] + SslHandshakeInterruption, + /// A UTF-8 error + Utf8Error(Utf8Error), } impl fmt::Display for WebSocketError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - try!(fmt.write_str("WebSocketError: ")); - try!(fmt.write_str(self.description())); - Ok(()) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + try!(fmt.write_str("WebSocketError: ")); + try!(fmt.write_str(self.description())); + Ok(()) + } } impl Error for WebSocketError { - fn description(&self) -> &str { - match *self { - WebSocketError::ProtocolError(_) => "WebSocket protocol error", - WebSocketError::RequestError(_) => "WebSocket request error", - WebSocketError::ResponseError(_) => "WebSocket response error", - WebSocketError::DataFrameError(_) => "WebSocket data frame error", - WebSocketError::NoDataAvailable => "No data available", - WebSocketError::IoError(_) => "I/O failure", - WebSocketError::HttpError(_) => "HTTP failure", - WebSocketError::UrlError(_) => "URL failure", - #[cfg(feature="ssl")] + fn description(&self) -> &str { + match *self { + WebSocketError::ProtocolError(_) => "WebSocket protocol error", + WebSocketError::RequestError(_) => "WebSocket request error", + WebSocketError::ResponseError(_) => "WebSocket response error", + WebSocketError::DataFrameError(_) => "WebSocket data frame error", + WebSocketError::NoDataAvailable => "No data available", + WebSocketError::IoError(_) => "I/O failure", + WebSocketError::HttpError(_) => "HTTP failure", + WebSocketError::UrlError(_) => "URL failure", + #[cfg(feature="ssl")] WebSocketError::SslError(_) => "SSL failure", - #[cfg(feature="ssl")] + #[cfg(feature="ssl")] WebSocketError::SslHandshakeFailure => "SSL Handshake failure", - #[cfg(feature="ssl")] + #[cfg(feature="ssl")] WebSocketError::SslHandshakeInterruption => "SSL Handshake interrupted", - WebSocketError::Utf8Error(_) => "UTF-8 failure", - WebSocketError::WebSocketUrlError(_) => "WebSocket URL failure", - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - WebSocketError::IoError(ref error) => Some(error), - WebSocketError::HttpError(ref error) => Some(error), - WebSocketError::UrlError(ref error) => Some(error), - #[cfg(feature="ssl")] + WebSocketError::Utf8Error(_) => "UTF-8 failure", + WebSocketError::WebSocketUrlError(_) => "WebSocket URL failure", + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + WebSocketError::IoError(ref error) => Some(error), + WebSocketError::HttpError(ref error) => Some(error), + WebSocketError::UrlError(ref error) => Some(error), + #[cfg(feature="ssl")] WebSocketError::SslError(ref error) => Some(error), - WebSocketError::Utf8Error(ref error) => Some(error), - WebSocketError::WebSocketUrlError(ref error) => Some(error), - _ => None, - } - } + WebSocketError::Utf8Error(ref error) => Some(error), + WebSocketError::WebSocketUrlError(ref error) => Some(error), + _ => None, + } + } } impl From for WebSocketError { - fn from(err: io::Error) -> WebSocketError { - if err.kind() == io::ErrorKind::UnexpectedEof { - return WebSocketError::NoDataAvailable; - } - WebSocketError::IoError(err) - } + fn from(err: io::Error) -> WebSocketError { + if err.kind() == io::ErrorKind::UnexpectedEof { + return WebSocketError::NoDataAvailable; + } + WebSocketError::IoError(err) + } } impl From for WebSocketError { - fn from(err: HttpError) -> WebSocketError { - WebSocketError::HttpError(err) - } + fn from(err: HttpError) -> WebSocketError { + WebSocketError::HttpError(err) + } } impl From for WebSocketError { - fn from(err: ParseError) -> WebSocketError { - WebSocketError::UrlError(err) - } + fn from(err: ParseError) -> WebSocketError { + WebSocketError::UrlError(err) + } } #[cfg(feature="ssl")] impl From for WebSocketError { - fn from(err: SslError) -> WebSocketError { - WebSocketError::SslError(err) - } + fn from(err: SslError) -> WebSocketError { + WebSocketError::SslError(err) + } } #[cfg(feature="ssl")] impl From> for WebSocketError { - fn from(err: SslHandshakeError) -> WebSocketError { - match err { - SslHandshakeError::SetupFailure(err) => WebSocketError::SslError(err), - SslHandshakeError::Failure(_) => WebSocketError::SslHandshakeFailure, - SslHandshakeError::Interrupted(_) => WebSocketError::SslHandshakeInterruption, - } - } + fn from(err: SslHandshakeError) -> WebSocketError { + match err { + SslHandshakeError::SetupFailure(err) => WebSocketError::SslError(err), + SslHandshakeError::Failure(_) => WebSocketError::SslHandshakeFailure, + SslHandshakeError::Interrupted(_) => WebSocketError::SslHandshakeInterruption, + } + } } impl From for WebSocketError { - fn from(err: Utf8Error) -> WebSocketError { - WebSocketError::Utf8Error(err) - } + fn from(err: Utf8Error) -> WebSocketError { + WebSocketError::Utf8Error(err) + } } impl From for WebSocketError { - fn from(err: WSUrlErrorKind) -> WebSocketError { - WebSocketError::WebSocketUrlError(err) - } + fn from(err: WSUrlErrorKind) -> WebSocketError { + WebSocketError::WebSocketUrlError(err) + } } /// Represents a WebSocket URL error #[derive(Debug)] pub enum WSUrlErrorKind { - /// Fragments are not valid in a WebSocket URL - CannotSetFragment, - /// The scheme provided is invalid for a WebSocket - InvalidScheme, - /// There is no hostname or IP address to connect to - NoHostName, + /// Fragments are not valid in a WebSocket URL + CannotSetFragment, + /// The scheme provided is invalid for a WebSocket + InvalidScheme, + /// There is no hostname or IP address to connect to + NoHostName, } impl fmt::Display for WSUrlErrorKind { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - try!(fmt.write_str("WebSocket Url Error: ")); - try!(fmt.write_str(self.description())); - Ok(()) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + try!(fmt.write_str("WebSocket Url Error: ")); + try!(fmt.write_str(self.description())); + Ok(()) + } } impl Error for WSUrlErrorKind { - fn description(&self) -> &str { - match *self { - WSUrlErrorKind::CannotSetFragment => "WebSocket URL cannot set fragment", - WSUrlErrorKind::InvalidScheme => "WebSocket URL invalid scheme", - WSUrlErrorKind::NoHostName => "WebSocket URL no host name provided", - } - } + fn description(&self) -> &str { + match *self { + WSUrlErrorKind::CannotSetFragment => "WebSocket URL cannot set fragment", + WSUrlErrorKind::InvalidScheme => "WebSocket URL invalid scheme", + WSUrlErrorKind::NoHostName => "WebSocket URL no host name provided", + } + } } diff --git a/src/sender.rs b/src/sender.rs index ed0ebefca0..e8860cdbef 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -10,44 +10,44 @@ use ws::sender::Sender as SenderTrait; pub use stream::Shutdown; pub struct Writer { - pub stream: W, - pub sender: Sender, + pub stream: W, + pub sender: Sender, } impl Writer - where W: Write, + where W: Write { - /// Sends a single data frame to the remote endpoint. - pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame, - W: Write, - { - self.sender.send_dataframe(&mut self.stream, dataframe) - } + /// Sends a single data frame to the remote endpoint. + pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write + { + self.sender.send_dataframe(&mut self.stream, dataframe) + } - /// Sends a single message to the remote endpoint. - pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> - where M: ws::Message<'m, D>, - D: DataFrame - { - self.sender.send_message(&mut self.stream, message) - } + /// Sends a single message to the remote endpoint. + pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()> + where M: ws::Message<'m, D>, + D: DataFrame + { + self.sender.send_message(&mut self.stream, message) + } } impl Writer - where S: AsTcpStream + Write, + where S: AsTcpStream + Write { - /// Closes the sender side of the connection, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Write) - } + /// Closes the sender side of the connection, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Write) + } - /// Shuts down both Sender and Receiver, will cause all pending and future IO to - /// return immediately with an appropriate value. - pub fn shutdown_all(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) - } + /// Shuts down both Sender and Receiver, will cause all pending and future IO to + /// return immediately with an appropriate value. + pub fn shutdown_all(&self) -> IoResult<()> { + self.stream.as_tcp().shutdown(Shutdown::Both) + } } /// A Sender that wraps a Writer and provides a default implementation using @@ -59,18 +59,16 @@ pub struct Sender { impl Sender { /// Create a new WebSocketSender using the specified Writer. pub fn new(mask: bool) -> Sender { - Sender { - mask: mask, - } + Sender { mask: mask } } } impl ws::Sender for Sender { /// Sends a single data frame to the remote endpoint. - fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame, - W: Write, - { + fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> + where D: DataFrame, + W: Write + { dataframe.write_to(writer, self.mask) } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 105f84ca0c..5a8af0be3d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,39 +1,21 @@ //! Provides an implementation of a WebSocket server -use std::net::{ - SocketAddr, - ToSocketAddrs, - TcpListener, - TcpStream, -}; -use std::io::{ - self, -}; +use std::net::{SocketAddr, ToSocketAddrs, TcpListener, TcpStream}; +use std::io; use std::convert::Into; #[cfg(feature="ssl")] -use openssl::ssl::{ - SslStream, - SslAcceptor, -}; -use stream::{ - Stream, -}; -use self::upgrade::{ - WsUpgrade, - IntoWs, -}; -pub use self::upgrade::{ - Request, - HyperIntoWsError, -}; +use openssl::ssl::{SslStream, SslAcceptor}; +use stream::Stream; +use self::upgrade::{WsUpgrade, IntoWs}; +pub use self::upgrade::{Request, HyperIntoWsError}; pub mod upgrade; pub struct InvalidConnection - where S: Stream, + where S: Stream { - pub stream: Option, - pub parsed: Option, - pub error: HyperIntoWsError, + pub stream: Option, + pub parsed: Option, + pub error: HyperIntoWsError, } pub type AcceptResult = Result, InvalidConnection>; @@ -50,9 +32,11 @@ impl OptionalSslAcceptor for NoSslAcceptor {} #[cfg(feature="ssl")] impl OptionalSslAcceptor for SslAcceptor {} -/// Represents a WebSocket server which can work with either normal (non-secure) connections, or secure WebSocket connections. +/// Represents a WebSocket server which can work with either normal +/// (non-secure) connections, or secure WebSocket connections. /// -/// This is a convenient way to implement WebSocket servers, however it is possible to use any sendable Reader and Writer to obtain +/// This is a convenient way to implement WebSocket servers, however +/// it is possible to use any sendable Reader and Writer to obtain /// a WebSocketClient, so if needed, an alternative server implementation can be used. ///#Non-secure Servers /// @@ -122,126 +106,135 @@ impl OptionalSslAcceptor for SslAcceptor {} /// # } /// ``` pub struct Server - where S: OptionalSslAcceptor, + where S: OptionalSslAcceptor { - pub listener: TcpListener, - ssl_acceptor: S, + pub listener: TcpListener, + ssl_acceptor: S, } impl Server - where S: OptionalSslAcceptor, + where S: OptionalSslAcceptor { - /// Get the socket address of this server - pub fn local_addr(&self) -> io::Result { - self.listener.local_addr() - } - - /// Create a new independently owned handle to the underlying socket. - pub fn try_clone(&self) -> io::Result> { - let inner = try!(self.listener.try_clone()); - Ok(Server { - listener: inner, - ssl_acceptor: self.ssl_acceptor.clone(), - }) - } + /// Get the socket address of this server + pub fn local_addr(&self) -> io::Result { + self.listener.local_addr() + } + + /// Create a new independently owned handle to the underlying socket. + pub fn try_clone(&self) -> io::Result> { + let inner = try!(self.listener.try_clone()); + Ok(Server { + listener: inner, + ssl_acceptor: self.ssl_acceptor.clone(), + }) + } } #[cfg(feature="ssl")] impl Server { - /// Bind this Server to this socket, utilising the given SslContext - pub fn bind_secure(addr: A, acceptor: SslAcceptor) -> io::Result - where A: ToSocketAddrs, - { - Ok(Server { - listener: try!(TcpListener::bind(&addr)), - ssl_acceptor: acceptor, - }) - } - - /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> AcceptResult> { - let stream = match self.listener.accept() { - Ok(s) => s.0, - Err(e) => return Err(InvalidConnection { - stream: None, - parsed: None, - error: e.into(), - }), - }; - - let stream = match self.ssl_acceptor.accept(stream) { - Ok(s) => s, - Err(err) => return Err(InvalidConnection { - stream: None, - parsed: None, - error: io::Error::new(io::ErrorKind::Other, err).into(), - }), - }; - - match stream.into_ws() { - Ok(u) => Ok(u), - Err((s, r, e)) => Err(InvalidConnection { + /// Bind this Server to this socket, utilising the given SslContext + pub fn bind_secure(addr: A, acceptor: SslAcceptor) -> io::Result + where A: ToSocketAddrs + { + Ok(Server { + listener: try!(TcpListener::bind(&addr)), + ssl_acceptor: acceptor, + }) + } + + /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest + pub fn accept(&mut self) -> AcceptResult> { + let stream = match self.listener.accept() { + Ok(s) => s.0, + Err(e) => { + return Err(InvalidConnection { + stream: None, + parsed: None, + error: e.into(), + }) + } + }; + + let stream = match self.ssl_acceptor.accept(stream) { + Ok(s) => s, + Err(err) => { + return Err(InvalidConnection { + stream: None, + parsed: None, + error: io::Error::new(io::ErrorKind::Other, err).into(), + }) + } + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, e)) => { + Err(InvalidConnection { stream: Some(s), parsed: r, error: e.into(), - }), - } - } - - /// Changes whether the Server is in nonblocking mode. - /// - /// If it is in nonblocking mode, accept() will return an error instead of blocking when there - /// are no incoming connections. - pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { - self.listener.set_nonblocking(nonblocking) - } + }) + } + } + } + + /// Changes whether the Server is in nonblocking mode. + /// + /// If it is in nonblocking mode, accept() will return an error instead of blocking when there + /// are no incoming connections. + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.listener.set_nonblocking(nonblocking) + } } #[cfg(feature="ssl")] impl Iterator for Server { - type Item = WsUpgrade>; + type Item = WsUpgrade>; - fn next(&mut self) -> Option<::Item> { - self.accept().ok() - } + fn next(&mut self) -> Option<::Item> { + self.accept().ok() + } } impl Server { - /// Bind this Server to this socket - pub fn bind(addr: A) -> io::Result { - Ok(Server { - listener: try!(TcpListener::bind(&addr)), - ssl_acceptor: NoSslAcceptor, - }) - } - - /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest - pub fn accept(&mut self) -> AcceptResult { - let stream = match self.listener.accept() { - Ok(s) => s.0, - Err(e) => return Err(InvalidConnection { - stream: None, - parsed: None, - error: e.into(), - }), - }; - - match stream.into_ws() { - Ok(u) => Ok(u), - Err((s, r, e)) => Err(InvalidConnection { + /// Bind this Server to this socket + pub fn bind(addr: A) -> io::Result { + Ok(Server { + listener: try!(TcpListener::bind(&addr)), + ssl_acceptor: NoSslAcceptor, + }) + } + + /// Wait for and accept an incoming WebSocket connection, returning a WebSocketRequest + pub fn accept(&mut self) -> AcceptResult { + let stream = match self.listener.accept() { + Ok(s) => s.0, + Err(e) => { + return Err(InvalidConnection { + stream: None, + parsed: None, + error: e.into(), + }) + } + }; + + match stream.into_ws() { + Ok(u) => Ok(u), + Err((s, r, e)) => { + Err(InvalidConnection { stream: Some(s), parsed: r, error: e.into(), - }), - } - } + }) + } + } + } } impl Iterator for Server { - type Item = WsUpgrade; + type Item = WsUpgrade; - fn next(&mut self) -> Option<::Item> { - self.accept().ok() - } + fn next(&mut self) -> Option<::Item> { + self.accept().ok() + } } - diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index 90b9f3d86a..42ddd3f382 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -1,12 +1,7 @@ extern crate hyper; -use hyper::net::{ - NetworkStream, -}; -use super::{ - IntoWs, - WsUpgrade, -}; +use hyper::net::NetworkStream; +use super::{IntoWs, WsUpgrade}; pub use hyper::http::h1::Incoming; pub use hyper::method::Method; @@ -14,13 +9,7 @@ pub use hyper::version::HttpVersion; pub use hyper::uri::RequestUri; pub use hyper::buffer::BufReader; use hyper::server::Request; -pub use hyper::header::{ - Headers, - Upgrade, - ProtocolName, - Connection, - ConnectionOption, -}; +pub use hyper::header::{Headers, Upgrade, ProtocolName, Connection, ConnectionOption}; use super::validate; use super::HyperIntoWsError; @@ -28,24 +17,25 @@ use super::HyperIntoWsError; pub struct HyperRequest<'a, 'b: 'a>(pub Request<'a, 'b>); impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { - type Stream = &'a mut &'b mut NetworkStream; - type Error = (Request<'a, 'b>, HyperIntoWsError); - - fn into_ws(self) -> Result, Self::Error> { - if let Err(e) = validate(&self.0.method, &self.0.version, &self.0.headers) { - return Err((self.0, e)); - } - - let (_, method, headers, uri, version, reader) = self.0.deconstruct(); - let stream = reader.into_inner().get_mut(); - - Ok(WsUpgrade { - stream: stream, - request: Incoming { - version: version, - headers: headers, - subject: (method, uri), - }, - }) - } + type Stream = &'a mut &'b mut NetworkStream; + type Error = (Request<'a, 'b>, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + if let Err(e) = validate(&self.0.method, &self.0.version, &self.0.headers) { + return Err((self.0, e)); + } + + let (_, method, headers, uri, version, reader) = + self.0.deconstruct(); + let stream = reader.into_inner().get_mut(); + + Ok(WsUpgrade { + stream: stream, + request: Incoming { + version: version, + headers: headers, + subject: (method, uri), + }, + }) + } } diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 72db067b5f..941dedafaf 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -7,27 +7,12 @@ use std::net::TcpStream; use std::io; use std::io::Result as IoResult; use std::io::Error as IoError; -use std::io::{ - Write, -}; -use std::fmt::{ - Formatter, - Display, - self, -}; -use stream::{ - Stream, - AsTcpStream, -}; +use std::io::Write; +use std::fmt::{self, Formatter, Display}; +use stream::{Stream, AsTcpStream}; use header::extensions::Extension; -use header::{ - WebSocketAccept, - WebSocketKey, - WebSocketVersion, - WebSocketProtocol, - WebSocketExtensions, - Origin, -}; +use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, + WebSocketExtensions, Origin}; use client::Client; use unicase::UniCase; @@ -38,14 +23,8 @@ pub use self::real_hyper::version::HttpVersion; pub use self::real_hyper::uri::RequestUri; pub use self::real_hyper::buffer::BufReader; pub use self::real_hyper::http::h1::parse_request; -pub use self::real_hyper::header::{ - Headers, - Upgrade, - Protocol, - ProtocolName, - Connection, - ConnectionOption, -}; +pub use self::real_hyper::header::{Headers, Upgrade, Protocol, ProtocolName, Connection, + ConnectionOption}; pub mod hyper; @@ -56,98 +35,100 @@ pub mod hyper; /// Users should then call `accept` or `deny` to complete the handshake /// and start a session. pub struct WsUpgrade - where S: Stream, + where S: Stream { - stream: S, - request: Request, + stream: S, + request: Request, } impl WsUpgrade - where S: Stream, + where S: Stream { - pub fn accept(self) -> IoResult> { - self.accept_with(&Headers::new()) - } - - pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult> { - let mut headers = Headers::new(); - headers.extend(custom_headers.iter()); - headers.set(WebSocketAccept::new( + pub fn accept(self) -> IoResult> { + self.accept_with(&Headers::new()) + } + + pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult> { + let mut headers = Headers::new(); + headers.extend(custom_headers.iter()); + headers.set(WebSocketAccept::new( // NOTE: we know there is a key because this is a valid request // i.e. to construct this you must go through the validate function self.request.headers.get::().unwrap() )); - headers.set(Connection(vec![ + headers.set(Connection(vec![ ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) ])); - headers.set(Upgrade(vec![ - Protocol::new(ProtocolName::WebSocket, None) - ])); - - try!(self.send(StatusCode::SwitchingProtocols, &headers)); - - Ok(Client::unchecked(self.stream)) - } - - pub fn reject(self) -> Result { - self.reject_with(&Headers::new()) - } - - pub fn reject_with(mut self, headers: &Headers) -> Result { - match self.send(StatusCode::BadRequest, headers) { - Ok(()) => Ok(self.stream), - Err(e) => Err((self.stream, e)) - } - } - - pub fn drop(self) { - ::std::mem::drop(self); - } - - pub fn protocols(&self) -> &[String] { - self.request.headers.get::() - .map(|p| p.0.as_slice()) - .unwrap_or(&[]) - } - - pub fn extensions(&self) -> &[Extension] { - self.request.headers.get::() - .map(|e| e.0.as_slice()) - .unwrap_or(&[]) - } - - pub fn key(&self) -> Option<&[u8; 16]> { - self.request.headers.get::().map(|k| &k.0) - } - - pub fn version(&self) -> Option<&WebSocketVersion> { - self.request.headers.get::() - } - - pub fn origin(&self) -> Option<&str> { - self.request.headers.get::().map(|o| &o.0 as &str) - } - - pub fn into_stream(self) -> S { - self.stream - } - - fn send(&mut self, - status: StatusCode, - headers: &Headers - ) -> IoResult<()> { - try!(write!(self.stream.writer(), "{} {}\r\n", self.request.version, status)); - try!(write!(self.stream.writer(), "{}\r\n", headers)); - Ok(()) - } + headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)])); + + try!(self.send(StatusCode::SwitchingProtocols, &headers)); + + Ok(Client::unchecked(self.stream)) + } + + pub fn reject(self) -> Result { + self.reject_with(&Headers::new()) + } + + pub fn reject_with(mut self, headers: &Headers) -> Result { + match self.send(StatusCode::BadRequest, headers) { + Ok(()) => Ok(self.stream), + Err(e) => Err((self.stream, e)), + } + } + + pub fn drop(self) { + ::std::mem::drop(self); + } + + pub fn protocols(&self) -> &[String] { + self.request + .headers + .get::() + .map(|p| p.0.as_slice()) + .unwrap_or(&[]) + } + + pub fn extensions(&self) -> &[Extension] { + self.request + .headers + .get::() + .map(|e| e.0.as_slice()) + .unwrap_or(&[]) + } + + pub fn key(&self) -> Option<&[u8; 16]> { + self.request.headers.get::().map(|k| &k.0) + } + + pub fn version(&self) -> Option<&WebSocketVersion> { + self.request.headers.get::() + } + + pub fn origin(&self) -> Option<&str> { + self.request.headers.get::().map(|o| &o.0 as &str) + } + + pub fn into_stream(self) -> S { + self.stream + } + + fn send(&mut self, status: StatusCode, headers: &Headers) -> IoResult<()> { + try!(write!(self.stream.writer(), + "{} {}\r\n", + self.request.version, + status)); + try!(write!(self.stream.writer(), "{}\r\n", headers)); + Ok(()) + } } impl WsUpgrade - where S: Stream + AsTcpStream, + where S: Stream + AsTcpStream { - pub fn tcp_stream(&self) -> &TcpStream { - self.stream.as_tcp() - } + pub fn tcp_stream(&self) -> &TcpStream { + self.stream.as_tcp() + } } /// Trait to take a stream or similar and attempt to recover the start of a @@ -160,12 +141,12 @@ impl WsUpgrade /// /// Note: the stream is owned because the websocket client expects to own its stream. pub trait IntoWs { - type Stream: Stream; - type Error; - /// Attempt to parse the start of a Websocket handshake, later with the returned - /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to - /// send a handshake rejection response. - fn into_ws(self) -> Result, Self::Error>; + type Stream: Stream; + type Error; + /// Attempt to parse the start of a Websocket handshake, later with the returned + /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to + /// send a handshake rejection response. + fn into_ws(self) -> Result, Self::Error>; } @@ -173,163 +154,166 @@ pub type Request = Incoming<(Method, RequestUri)>; pub struct RequestStreamPair(pub S, pub Request); impl IntoWs for S - where S: Stream, + where S: Stream { - type Stream = S; - type Error = (Self, Option, HyperIntoWsError); - - fn into_ws(mut self) -> Result, Self::Error> { - let request = { - let mut reader = BufReader::new(self.reader()); - parse_request(&mut reader) - }; - - let request = match request { - Ok(r) => r, - Err(e) => return Err((self, None, e.into())), - }; - - match validate(&request.subject.0, &request.version, &request.headers) { - Ok(_) => Ok(WsUpgrade { - stream: self, - request: request, - }), - Err(e) => Err((self, Some(request), e)), - } - } + type Stream = S; + type Error = (Self, Option, HyperIntoWsError); + + fn into_ws(mut self) -> Result, Self::Error> { + let request = { + let mut reader = BufReader::new(self.reader()); + parse_request(&mut reader) + }; + + let request = match request { + Ok(r) => r, + Err(e) => return Err((self, None, e.into())), + }; + + match validate(&request.subject.0, &request.version, &request.headers) { + Ok(_) => { + Ok(WsUpgrade { + stream: self, + request: request, + }) + } + Err(e) => Err((self, Some(request), e)), + } + } } impl IntoWs for RequestStreamPair - where S: Stream, + where S: Stream { - type Stream = S; - type Error = (S, Request, HyperIntoWsError); - - fn into_ws(self) -> Result, Self::Error> { - match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { - Ok(_) => Ok(WsUpgrade { - stream: self.0, - request: self.1, - }), - Err(e) => Err((self.0, self.1, e)), - } - } + type Stream = S; + type Error = (S, Request, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { + Ok(_) => { + Ok(WsUpgrade { + stream: self.0, + request: self.1, + }) + } + Err(e) => Err((self.0, self.1, e)), + } + } } #[derive(Debug)] pub enum HyperIntoWsError { - MethodNotGet, - UnsupportedHttpVersion, - UnsupportedWebsocketVersion, - NoSecWsKeyHeader, - NoWsUpgradeHeader, - NoUpgradeHeader, - NoWsConnectionHeader, - NoConnectionHeader, - UnknownNetworkStream, - /// IO error from reading the underlying socket - Io(io::Error), - /// Error while parsing an incoming request - Parsing(self::real_hyper::error::Error), + MethodNotGet, + UnsupportedHttpVersion, + UnsupportedWebsocketVersion, + NoSecWsKeyHeader, + NoWsUpgradeHeader, + NoUpgradeHeader, + NoWsConnectionHeader, + NoConnectionHeader, + UnknownNetworkStream, + /// IO error from reading the underlying socket + Io(io::Error), + /// Error while parsing an incoming request + Parsing(self::real_hyper::error::Error), } impl Display for HyperIntoWsError { - fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { - fmt.write_str(self.description()) - } + fn fmt(&self, fmt: &mut Formatter) -> Result<(), fmt::Error> { + fmt.write_str(self.description()) + } } impl Error for HyperIntoWsError { - fn description(&self) -> &str { - use self::HyperIntoWsError::*; - match self { - &MethodNotGet => "Request method must be GET", - &UnsupportedHttpVersion => "Unsupported request HTTP version", - &UnsupportedWebsocketVersion => "Unsupported WebSocket version", - &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", - &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", - &NoUpgradeHeader => "Missing Upgrade WebSocket header", - &NoWsConnectionHeader => "Invalid Connection WebSocket header", - &NoConnectionHeader => "Missing Connection WebSocket header", - &UnknownNetworkStream => "Cannot downcast to known impl of NetworkStream", - &Io(ref e) => e.description(), - &Parsing(ref e) => e.description(), - } - } - - fn cause(&self) -> Option<&Error> { - match *self { - HyperIntoWsError::Io(ref e) => Some(e), - HyperIntoWsError::Parsing(ref e) => Some(e), - _ => None, - } - } + fn description(&self) -> &str { + use self::HyperIntoWsError::*; + match self { + &MethodNotGet => "Request method must be GET", + &UnsupportedHttpVersion => "Unsupported request HTTP version", + &UnsupportedWebsocketVersion => "Unsupported WebSocket version", + &NoSecWsKeyHeader => "Missing Sec-WebSocket-Key header", + &NoWsUpgradeHeader => "Invalid Upgrade WebSocket header", + &NoUpgradeHeader => "Missing Upgrade WebSocket header", + &NoWsConnectionHeader => "Invalid Connection WebSocket header", + &NoConnectionHeader => "Missing Connection WebSocket header", + &UnknownNetworkStream => "Cannot downcast to known impl of NetworkStream", + &Io(ref e) => e.description(), + &Parsing(ref e) => e.description(), + } + } + + fn cause(&self) -> Option<&Error> { + match *self { + HyperIntoWsError::Io(ref e) => Some(e), + HyperIntoWsError::Parsing(ref e) => Some(e), + _ => None, + } + } } impl From for HyperIntoWsError { - fn from(err: io::Error) -> Self { - HyperIntoWsError::Io(err) - } + fn from(err: io::Error) -> Self { + HyperIntoWsError::Io(err) + } } impl From for HyperIntoWsError { - fn from(err: real_hyper::error::Error) -> Self { - HyperIntoWsError::Parsing(err) - } + fn from(err: real_hyper::error::Error) -> Self { + HyperIntoWsError::Parsing(err) + } } pub fn validate( - method: &Method, - version: &HttpVersion, - headers: &Headers -) -> Result<(), HyperIntoWsError> -{ - if *method != Method::Get { - return Err(HyperIntoWsError::MethodNotGet); - } - - if *version == HttpVersion::Http09 || *version == HttpVersion::Http10 { - return Err(HyperIntoWsError::UnsupportedHttpVersion); - } - - if let Some(version) = headers.get::() { - if version != &WebSocketVersion::WebSocket13 { - return Err(HyperIntoWsError::UnsupportedWebsocketVersion); - } - } - - if headers.get::().is_none() { - return Err(HyperIntoWsError::NoSecWsKeyHeader); - } - - match headers.get() { - Some(&Upgrade(ref upgrade)) => { - if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { - return Err(HyperIntoWsError::NoWsUpgradeHeader) - } - }, - None => return Err(HyperIntoWsError::NoUpgradeHeader), - }; - - fn check_connection_header(headers: &Vec) -> bool { - for header in headers { - if let &ConnectionOption::ConnectionHeader(ref h) = header { - if h as &str == "upgrade" { - return true; - } - } - } - false - } - - match headers.get() { - Some(&Connection(ref connection)) => { - if !check_connection_header(connection) { - return Err(HyperIntoWsError::NoWsConnectionHeader); - } - }, - None => return Err(HyperIntoWsError::NoConnectionHeader), - }; - - Ok(()) + method: &Method, + version: &HttpVersion, + headers: &Headers, +) -> Result<(), HyperIntoWsError> { + if *method != Method::Get { + return Err(HyperIntoWsError::MethodNotGet); + } + + if *version == HttpVersion::Http09 || *version == HttpVersion::Http10 { + return Err(HyperIntoWsError::UnsupportedHttpVersion); + } + + if let Some(version) = headers.get::() { + if version != &WebSocketVersion::WebSocket13 { + return Err(HyperIntoWsError::UnsupportedWebsocketVersion); + } + } + + if headers.get::().is_none() { + return Err(HyperIntoWsError::NoSecWsKeyHeader); + } + + match headers.get() { + Some(&Upgrade(ref upgrade)) => { + if upgrade.iter().all(|u| u.name != ProtocolName::WebSocket) { + return Err(HyperIntoWsError::NoWsUpgradeHeader); + } + } + None => return Err(HyperIntoWsError::NoUpgradeHeader), + }; + + fn check_connection_header(headers: &Vec) -> bool { + for header in headers { + if let &ConnectionOption::ConnectionHeader(ref h) = header { + if h as &str == "upgrade" { + return true; + } + } + } + false + } + + match headers.get() { + Some(&Connection(ref connection)) => { + if !check_connection_header(connection) { + return Err(HyperIntoWsError::NoWsConnectionHeader); + } + } + None => return Err(HyperIntoWsError::NoConnectionHeader), + }; + + Ok(()) } diff --git a/src/stream.rs b/src/stream.rs index 79c9bb8572..1a0962df1c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,72 +1,65 @@ //! Provides the default stream type for WebSocket connections. use std::ops::Deref; -use std::io::{ - self, - Read, - Write -}; +use std::io::{self, Read, Write}; pub use std::net::TcpStream; pub use std::net::Shutdown; #[cfg(feature="ssl")] -pub use openssl::ssl::{ - SslStream, - SslContext, -}; +pub use openssl::ssl::{SslStream, SslContext}; pub trait Splittable { - type Reader: Read; - type Writer: Write; + type Reader: Read; + type Writer: Write; - fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; + fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; } /// Represents a stream that can be read from, and written to. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. pub trait Stream { - type Reader: Read; - type Writer: Write; + type Reader: Read; + type Writer: Write; - /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut Self::Reader; + /// Get a mutable borrow to the reading component of this stream + fn reader(&mut self) -> &mut Self::Reader; - /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut Self::Writer; + /// Get a mutable borrow to the writing component of this stream + fn writer(&mut self) -> &mut Self::Writer; } pub struct ReadWritePair(pub R, pub W) - where R: Read, - W: Write; + where R: Read, + W: Write; impl Splittable for ReadWritePair - where R: Read, - W: Write, + where R: Read, + W: Write { - type Reader = R; - type Writer = W; + type Reader = R; + type Writer = W; - fn split(self) -> io::Result<(R, W)> { - Ok((self.0, self.1)) - } + fn split(self) -> io::Result<(R, W)> { + Ok((self.0, self.1)) + } } impl Stream for ReadWritePair - where R: Read, - W: Write, + where R: Read, + W: Write { - type Reader = R; - type Writer = W; - - #[inline] - fn reader(&mut self) -> &mut R { - &mut self.0 - } - - #[inline] - fn writer(&mut self) -> &mut W { - &mut self.1 - } + type Reader = R; + type Writer = W; + + #[inline] + fn reader(&mut self) -> &mut R { + &mut self.0 + } + + #[inline] + fn writer(&mut self) -> &mut W { + &mut self.1 + } } pub trait ReadWrite: Read + Write {} @@ -75,18 +68,18 @@ impl ReadWrite for S where S: Read + Write {} pub struct BoxedStream(pub Box); impl Stream for BoxedStream { - type Reader = Box; - type Writer = Box; - - #[inline] - fn reader(&mut self) -> &mut Self::Reader { - &mut self.0 - } - - #[inline] - fn writer(&mut self) -> &mut Self::Writer { - &mut self.0 - } + type Reader = Box; + type Writer = Box; + + #[inline] + fn reader(&mut self) -> &mut Self::Reader { + &mut self.0 + } + + #[inline] + fn writer(&mut self) -> &mut Self::Writer { + &mut self.0 + } } pub trait NetworkStream: Read + Write + AsTcpStream {} @@ -95,73 +88,73 @@ impl NetworkStream for S where S: Read + Write + AsTcpStream {} pub struct BoxedNetworkStream(pub Box); impl AsTcpStream for BoxedNetworkStream { - fn as_tcp(&self) -> &TcpStream { - self.0.deref().as_tcp() - } + fn as_tcp(&self) -> &TcpStream { + self.0.deref().as_tcp() + } } impl Stream for BoxedNetworkStream { - type Reader = Box; - type Writer = Box; - - #[inline] - fn reader(&mut self) -> &mut Self::Reader { - &mut self.0 - } - - #[inline] - fn writer(&mut self) -> &mut Self::Writer { - &mut self.0 - } + type Reader = Box; + type Writer = Box; + + #[inline] + fn reader(&mut self) -> &mut Self::Reader { + &mut self.0 + } + + #[inline] + fn writer(&mut self) -> &mut Self::Writer { + &mut self.0 + } } impl Splittable for TcpStream { - type Reader = TcpStream; - type Writer = TcpStream; + type Reader = TcpStream; + type Writer = TcpStream; - fn split(self) -> io::Result<(TcpStream, TcpStream)> { - self.try_clone().map(|s| (s, self)) - } + fn split(self) -> io::Result<(TcpStream, TcpStream)> { + self.try_clone().map(|s| (s, self)) + } } impl Stream for S - where S: Read + Write, + where S: Read + Write { - type Reader = Self; - type Writer = Self; - - #[inline] - fn reader(&mut self) -> &mut S { - self - } - - #[inline] - fn writer(&mut self) -> &mut S { - self - } + type Reader = Self; + type Writer = Self; + + #[inline] + fn reader(&mut self) -> &mut S { + self + } + + #[inline] + fn writer(&mut self) -> &mut S { + self + } } pub trait AsTcpStream { - fn as_tcp(&self) -> &TcpStream; + fn as_tcp(&self) -> &TcpStream; } impl AsTcpStream for TcpStream { - fn as_tcp(&self) -> &TcpStream { - &self - } + fn as_tcp(&self) -> &TcpStream { + &self + } } #[cfg(feature="ssl")] impl AsTcpStream for SslStream { - fn as_tcp(&self) -> &TcpStream { - self.get_ref() - } + fn as_tcp(&self) -> &TcpStream { + self.get_ref() + } } impl AsTcpStream for Box - where T: AsTcpStream, + where T: AsTcpStream { - fn as_tcp(&self) -> &TcpStream { - self.deref().as_tcp() - } + fn as_tcp(&self) -> &TcpStream { + self.deref().as_tcp() + } } diff --git a/src/ws/dataframe.rs b/src/ws/dataframe.rs index 2c8243b240..8e1669093b 100644 --- a/src/ws/dataframe.rs +++ b/src/ws/dataframe.rs @@ -13,112 +13,113 @@ use ws::util::mask; /// provide these methods. (If the payload is not known in advance then /// rewrite the write_payload method) pub trait DataFrame { - /// Is this dataframe the final dataframe of the message? - fn is_last(&self) -> bool; - /// What type of data does this dataframe contain? - fn opcode(&self) -> u8; - /// Reserved bits of this dataframe - fn reserved<'a>(&'a self) -> &'a [bool; 3]; - /// Entire payload of the dataframe. If not known then implement - /// write_payload as that is the actual method used when sending the - /// dataframe over the wire. - fn payload<'a>(&'a self) -> Cow<'a, [u8]>; + /// Is this dataframe the final dataframe of the message? + fn is_last(&self) -> bool; + /// What type of data does this dataframe contain? + fn opcode(&self) -> u8; + /// Reserved bits of this dataframe + fn reserved<'a>(&'a self) -> &'a [bool; 3]; + /// Entire payload of the dataframe. If not known then implement + /// write_payload as that is the actual method used when sending the + /// dataframe over the wire. + fn payload<'a>(&'a self) -> Cow<'a, [u8]>; - /// How long (in bytes) is this dataframe's payload - fn size(&self) -> usize { - self.payload().len() - } + /// How long (in bytes) is this dataframe's payload + fn size(&self) -> usize { + self.payload().len() + } - /// Write the payload to a writer - fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> - where W: Write { - try!(socket.write_all(&*self.payload())); - Ok(()) - } + /// Write the payload to a writer + fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> + where W: Write + { + try!(socket.write_all(&*self.payload())); + Ok(()) + } - /// Writes a DataFrame to a Writer. - fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> - where W: Write { - let mut flags = dfh::DataFrameFlags::empty(); - if self.is_last() { - flags.insert(dfh::FIN); - } - { - let reserved = self.reserved(); - if reserved[0] { - flags.insert(dfh::RSV1); - } - if reserved[1] { - flags.insert(dfh::RSV2); - } - if reserved[2] { - flags.insert(dfh::RSV3); - } - } + /// Writes a DataFrame to a Writer. + fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> + where W: Write + { + let mut flags = dfh::DataFrameFlags::empty(); + if self.is_last() { + flags.insert(dfh::FIN); + } + { + let reserved = self.reserved(); + if reserved[0] { + flags.insert(dfh::RSV1); + } + if reserved[1] { + flags.insert(dfh::RSV2); + } + if reserved[2] { + flags.insert(dfh::RSV3); + } + } - let masking_key = if mask { - Some(mask::gen_mask()) - } else { - None - }; + let masking_key = if mask { Some(mask::gen_mask()) } else { None }; - let header = dfh::DataFrameHeader { - flags: flags, - opcode: self.opcode() as u8, - mask: masking_key, - len: self.size() as u64, - }; + let header = dfh::DataFrameHeader { + flags: flags, + opcode: self.opcode() as u8, + mask: masking_key, + len: self.size() as u64, + }; - try!(dfh::write_header(writer, header)); + try!(dfh::write_header(writer, header)); - match masking_key { - Some(mask) => { - let mut masker = Masker::new(mask, writer); - try!(self.write_payload(&mut masker)) - }, - None => try!(self.write_payload(writer)), - }; - try!(writer.flush()); - Ok(()) - } + match masking_key { + Some(mask) => { + let mut masker = Masker::new(mask, writer); + try!(self.write_payload(&mut masker)) + } + None => try!(self.write_payload(writer)), + }; + try!(writer.flush()); + Ok(()) + } } impl<'a, D> DataFrame for &'a D -where D: DataFrame { - #[inline(always)] - fn is_last(&self) -> bool { - D::is_last(self) - } + where D: DataFrame +{ + #[inline(always)] + fn is_last(&self) -> bool { + D::is_last(self) + } - #[inline(always)] - fn opcode(&self) -> u8 { - D::opcode(self) - } + #[inline(always)] + fn opcode(&self) -> u8 { + D::opcode(self) + } - #[inline(always)] - fn reserved<'b>(&'b self) -> &'b [bool; 3] { - D::reserved(self) - } + #[inline(always)] + fn reserved<'b>(&'b self) -> &'b [bool; 3] { + D::reserved(self) + } - #[inline(always)] - fn payload<'b>(&'b self) -> Cow<'b, [u8]> { - D::payload(self) - } + #[inline(always)] + fn payload<'b>(&'b self) -> Cow<'b, [u8]> { + D::payload(self) + } - #[inline(always)] - fn size(&self) -> usize { - D::size(self) - } + #[inline(always)] + fn size(&self) -> usize { + D::size(self) + } - #[inline(always)] - fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> - where W: Write { - D::write_payload(self, socket) - } + #[inline(always)] + fn write_payload(&self, socket: &mut W) -> WebSocketResult<()> + where W: Write + { + D::write_payload(self, socket) + } - #[inline(always)] - fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> - where W: Write { - D::write_to(self, writer, mask) - } + #[inline(always)] + fn write_to(&self, writer: &mut W, mask: bool) -> WebSocketResult<()> + where W: Write + { + D::write_to(self, writer, mask) + } } diff --git a/src/ws/message.rs b/src/ws/message.rs index a21e110002..dc02aa1a63 100644 --- a/src/ws/message.rs +++ b/src/ws/message.rs @@ -7,12 +7,12 @@ use ws::dataframe::DataFrame; /// A trait for WebSocket messages pub trait Message<'a, F>: Sized -where F: DataFrame { + where F: DataFrame +{ /// The iterator type returned by dataframes type DataFrameIterator: Iterator; /// Attempt to form a message from a slice of data frames. - fn from_dataframes(frames: Vec) -> WebSocketResult - where D: DataFrame; + fn from_dataframes(frames: Vec) -> WebSocketResult where D: DataFrame; /// Turns this message into an iterator over data frames fn dataframes(&'a self) -> Self::DataFrameIterator; } diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index 3da8ade00f..9c218eea40 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -10,100 +10,100 @@ use ws::dataframe::DataFrame; use result::WebSocketResult; /// A trait for receiving data frames and messages. -pub trait Receiver: Sized -{ - type F: DataFrame; +pub trait Receiver: Sized { + type F: DataFrame; - /// Reads a single data frame from this receiver. - fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult - where R: Read; + /// Reads a single data frame from this receiver. + fn recv_dataframe(&mut self, reader: &mut R) -> WebSocketResult where R: Read; - /// Returns the data frames that constitute one message. - fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> - where R: Read; + /// Returns the data frames that constitute one message. + fn recv_message_dataframes(&mut self, reader: &mut R) -> WebSocketResult> + where R: Read; - /// Returns an iterator over incoming data frames. - fn incoming_dataframes<'a, R>(&'a mut self, reader: &'a mut R) -> DataFrameIterator<'a, Self, R> - where R: Read, - { - DataFrameIterator { - reader: reader, - inner: self, - } - } + /// Returns an iterator over incoming data frames. + fn incoming_dataframes<'a, R>(&'a mut self, reader: &'a mut R) -> DataFrameIterator<'a, Self, R> + where R: Read + { + DataFrameIterator { + reader: reader, + inner: self, + } + } - /// Reads a single message from this receiver. - fn recv_message<'m, D, M, I, R>(&mut self, reader: &mut R) -> WebSocketResult - where M: Message<'m, D, DataFrameIterator = I>, - I: Iterator, - D: DataFrame, - R: Read, - { - let dataframes = try!(self.recv_message_dataframes(reader)); - Message::from_dataframes(dataframes) - } + /// Reads a single message from this receiver. + fn recv_message<'m, D, M, I, R>(&mut self, reader: &mut R) -> WebSocketResult + where M: Message<'m, D, DataFrameIterator = I>, + I: Iterator, + D: DataFrame, + R: Read + { + let dataframes = try!(self.recv_message_dataframes(reader)); + Message::from_dataframes(dataframes) + } - /// Returns an iterator over incoming messages. - fn incoming_messages<'a, M, D, R>(&'a mut self, reader: &'a mut R) -> MessageIterator<'a, Self, D, M, R> - where M: Message<'a, D>, - D: DataFrame, - R: Read, - { - MessageIterator { - reader: reader, - inner: self, - _dataframe: PhantomData, - _message: PhantomData, - } - } + /// Returns an iterator over incoming messages. + fn incoming_messages<'a, M, D, R>( + &'a mut self, + reader: &'a mut R, + ) -> MessageIterator<'a, Self, D, M, R> + where M: Message<'a, D>, + D: DataFrame, + R: Read + { + MessageIterator { + reader: reader, + inner: self, + _dataframe: PhantomData, + _message: PhantomData, + } + } } /// An iterator over data frames from a Receiver. pub struct DataFrameIterator<'a, Recv, R> - where Recv: 'a + Receiver, - R: 'a + Read, + where Recv: 'a + Receiver, + R: 'a + Read { - reader: &'a mut R, - inner: &'a mut Recv, + reader: &'a mut R, + inner: &'a mut Recv, } impl<'a, Recv, R> Iterator for DataFrameIterator<'a, Recv, R> - where Recv: 'a + Receiver, - R: Read, + where Recv: 'a + Receiver, + R: Read { + type Item = WebSocketResult; - type Item = WebSocketResult; - - /// Get the next data frame from the receiver. Always returns `Some`. - fn next(&mut self) -> Option> { - Some(self.inner.recv_dataframe(self.reader)) - } + /// Get the next data frame from the receiver. Always returns `Some`. + fn next(&mut self) -> Option> { + Some(self.inner.recv_dataframe(self.reader)) + } } /// An iterator over messages from a Receiver. pub struct MessageIterator<'a, Recv, D, M, R> - where Recv: 'a + Receiver, - M: Message<'a, D>, - D: DataFrame, - R: 'a + Read, + where Recv: 'a + Receiver, + M: Message<'a, D>, + D: DataFrame, + R: 'a + Read { - reader: &'a mut R, - inner: &'a mut Recv, - _dataframe: PhantomData, - _message: PhantomData, + reader: &'a mut R, + inner: &'a mut Recv, + _dataframe: PhantomData, + _message: PhantomData, } impl<'a, Recv, D, M, I, R> Iterator for MessageIterator<'a, Recv, D, M, R> - where Recv: 'a + Receiver, - M: Message<'a, D, DataFrameIterator = I>, - I: Iterator, - D: DataFrame, - R: Read, + where Recv: 'a + Receiver, + M: Message<'a, D, DataFrameIterator = I>, + I: Iterator, + D: DataFrame, + R: Read { - type Item = WebSocketResult; + type Item = WebSocketResult; - /// Get the next message from the receiver. Always returns `Some`. - fn next(&mut self) -> Option> { - Some(self.inner.recv_message(self.reader)) - } + /// Get the next message from the receiver. Always returns `Some`. + fn next(&mut self) -> Option> { + Some(self.inner.recv_message(self.reader)) + } } diff --git a/src/ws/sender.rs b/src/ws/sender.rs index ca2bde11ea..c179fd275d 100644 --- a/src/ws/sender.rs +++ b/src/ws/sender.rs @@ -11,14 +11,14 @@ use result::WebSocketResult; pub trait Sender { /// Sends a single data frame using this sender. fn send_dataframe(&mut self, writer: &mut W, dataframe: &D) -> WebSocketResult<()> - where D: DataFrame, - W: Write; + where D: DataFrame, + W: Write; /// Sends a single message using this sender. fn send_message<'m, M, D, W>(&mut self, writer: &mut W, message: &'m M) -> WebSocketResult<()> - where M: Message<'m, D>, - D: DataFrame, - W: Write, + where M: Message<'m, D>, + D: DataFrame, + W: Write { for ref dataframe in message.dataframes() { try!(self.send_dataframe(writer, dataframe)); diff --git a/src/ws/util/header.rs b/src/ws/util/header.rs index bdc845bd69..c5de25f209 100644 --- a/src/ws/util/header.rs +++ b/src/ws/util/header.rs @@ -28,41 +28,35 @@ pub struct DataFrameHeader { /// The masking key, if any. pub mask: Option<[u8; 4]>, /// The length of the payload. - pub len: u64 + pub len: u64, } /// Writes a data frame header. pub fn write_header(writer: &mut W, header: DataFrameHeader) -> WebSocketResult<()> - where W: Write { + where W: Write +{ if header.opcode > 0xF { - return Err(WebSocketError::DataFrameError( - "Invalid data frame opcode" - )); + return Err(WebSocketError::DataFrameError("Invalid data frame opcode")); } if header.opcode >= 8 && header.len >= 126 { - return Err(WebSocketError::DataFrameError( - "Control frame length too long" - )); + return Err(WebSocketError::DataFrameError("Control frame length too long")); } // Write 'FIN', 'RSV1', 'RSV2', 'RSV3' and 'opcode' try!(writer.write_u8((header.flags.bits) | header.opcode)); - try!(writer.write_u8( - // Write the 'MASK' - if header.mask.is_some() { 0x80 } else { 0x00 } | + try!(writer.write_u8(// Write the 'MASK' + if header.mask.is_some() { 0x80 } else { 0x00 } | // Write the 'Payload len' if header.len <= 125 { header.len as u8 } else if header.len <= 65535 { 126 } - else { 127 } - )); + else { 127 })); // Write 'Extended payload length' if header.len >= 126 && header.len <= 65535 { try!(writer.write_u16::(header.len as u16)); - } - else if header.len > 65535 { + } else if header.len > 65535 { try!(writer.write_u64::(header.len)); } @@ -77,7 +71,8 @@ pub fn write_header(writer: &mut W, header: DataFrameHeader) -> WebSocketResu /// Reads a data frame header. pub fn read_header(reader: &mut R) -> WebSocketResult - where R: Read { + where R: Read +{ let byte0 = try!(reader.read_u8()); let byte1 = try!(reader.read_u8()); @@ -90,18 +85,14 @@ pub fn read_header(reader: &mut R) -> WebSocketResult 126 => { let len = try!(reader.read_u16::()) as u64; if len <= 125 { - return Err(WebSocketError::DataFrameError( - "Invalid data frame length" - )); + return Err(WebSocketError::DataFrameError("Invalid data frame length")); } len } 127 => { let len = try!(reader.read_u64::()); if len <= 65535 { - return Err(WebSocketError::DataFrameError( - "Invalid data frame length" - )); + return Err(WebSocketError::DataFrameError("Invalid data frame length")); } len } @@ -110,14 +101,10 @@ pub fn read_header(reader: &mut R) -> WebSocketResult if opcode >= 8 { if len >= 126 { - return Err(WebSocketError::DataFrameError( - "Control frame length too long" - )); + return Err(WebSocketError::DataFrameError("Control frame length too long")); } if !flags.contains(FIN) { - return Err(WebSocketError::ProtocolError( - "Illegal fragmented control frame" - )); + return Err(WebSocketError::ProtocolError("Illegal fragmented control frame")); } } @@ -126,19 +113,18 @@ pub fn read_header(reader: &mut R) -> WebSocketResult try!(reader.read_u8()), try!(reader.read_u8()), try!(reader.read_u8()), - try!(reader.read_u8()) + try!(reader.read_u8()), ]) - } - else { + } else { None }; Ok(DataFrameHeader { - flags: flags, - opcode: opcode, - mask: mask, - len: len - }) + flags: flags, + opcode: opcode, + mask: mask, + len: len, + }) } #[cfg(all(feature = "nightly", test))] @@ -153,7 +139,7 @@ mod tests { flags: FIN, opcode: 1, mask: None, - len: 43 + len: 43, }; assert_eq!(obtained, expected); } @@ -163,7 +149,7 @@ mod tests { flags: FIN, opcode: 1, mask: None, - len: 43 + len: 43, }; let expected = [0x81, 0x2B]; let mut obtained = Vec::with_capacity(2); @@ -179,7 +165,7 @@ mod tests { flags: RSV1, opcode: 2, mask: Some([2, 4, 8, 16]), - len: 512 + len: 512, }; assert_eq!(obtained, expected); } @@ -189,7 +175,7 @@ mod tests { flags: RSV1, opcode: 2, mask: Some([2, 4, 8, 16]), - len: 512 + len: 512, }; let expected = [0x42, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10]; let mut obtained = Vec::with_capacity(8); @@ -200,9 +186,7 @@ mod tests { #[bench] fn bench_read_header(b: &mut test::Bencher) { let header = vec![0x42u8, 0xFE, 0x02, 0x00, 0x02, 0x04, 0x08, 0x10]; - b.iter(|| { - read_header(&mut &header[..]).unwrap(); - }); + b.iter(|| { read_header(&mut &header[..]).unwrap(); }); } #[bench] fn bench_write_header(b: &mut test::Bencher) { @@ -210,11 +194,9 @@ mod tests { flags: RSV1, opcode: 2, mask: Some([2, 4, 8, 16]), - len: 512 + len: 512, }; let mut writer = Vec::with_capacity(8); - b.iter(|| { - write_header(&mut writer, header).unwrap(); - }); + b.iter(|| { write_header(&mut writer, header).unwrap(); }); } } diff --git a/src/ws/util/mask.rs b/src/ws/util/mask.rs index 21c9219dec..114727881a 100644 --- a/src/ws/util/mask.rs +++ b/src/ws/util/mask.rs @@ -7,39 +7,42 @@ use std::mem; /// Struct to pipe data into another writer, /// while masking the data being written pub struct Masker<'w, W> -where W: Write + 'w { - key: [u8; 4], - pos: usize, - end: &'w mut W, + where W: Write + 'w +{ + key: [u8; 4], + pos: usize, + end: &'w mut W, } impl<'w, W> Masker<'w, W> -where W: Write + 'w { - /// Create a new Masker with the key and the endpoint - /// to be writter to. - pub fn new(key: [u8; 4], endpoint: &'w mut W) -> Self { - Masker { - key: key, - pos: 0, - end: endpoint, - } - } + where W: Write + 'w +{ + /// Create a new Masker with the key and the endpoint + /// to be writter to. + pub fn new(key: [u8; 4], endpoint: &'w mut W) -> Self { + Masker { + key: key, + pos: 0, + end: endpoint, + } + } } impl<'w, W> Write for Masker<'w, W> -where W: Write + 'w { - fn write(&mut self, data: &[u8]) -> IoResult { - let mut buf = Vec::with_capacity(data.len()); - for &byte in data.iter() { - buf.push(byte ^ self.key[self.pos]); - self.pos = (self.pos + 1) % self.key.len(); - } - self.end.write(&buf) - } + where W: Write + 'w +{ + fn write(&mut self, data: &[u8]) -> IoResult { + let mut buf = Vec::with_capacity(data.len()); + for &byte in data.iter() { + buf.push(byte ^ self.key[self.pos]); + self.pos = (self.pos + 1) % self.key.len(); + } + self.end.write(&buf) + } - fn flush(&mut self) -> IoResult<()> { - self.end.flush() - } + fn flush(&mut self) -> IoResult<()> { + self.end.flush() + } } /// Generates a random masking key @@ -50,12 +53,12 @@ pub fn gen_mask() -> [u8; 4] { /// Masks data to send to a server and writes pub fn mask_data(mask: [u8; 4], data: &[u8]) -> Vec { - let mut out = Vec::with_capacity(data.len()); - let zip_iter = data.iter().zip(mask.iter().cycle()); - for (&buf_item, &key_item) in zip_iter { - out.push(buf_item ^ key_item); - } - out + let mut out = Vec::with_capacity(data.len()); + let zip_iter = data.iter().zip(mask.iter().cycle()); + for (&buf_item, &key_item) in zip_iter { + out.push(buf_item ^ key_item); + } + out } #[cfg(all(feature = "nightly", test))] @@ -79,16 +82,16 @@ mod tests { let buffer = b"The quick brown fox jumps over the lazy dog"; let key = gen_mask(); b.iter(|| { - let mut output = mask_data(key, buffer); - test::black_box(&mut output); - }); + let mut output = mask_data(key, buffer); + test::black_box(&mut output); + }); } #[bench] fn bench_gen_mask(b: &mut test::Bencher) { b.iter(|| { - let mut key = gen_mask(); - test::black_box(&mut key); - }); + let mut key = gen_mask(); + test::black_box(&mut key); + }); } } From c828d9f4c4df3230963359d5847bb2fc093163e0 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 02:35:34 -0400 Subject: [PATCH 27/32] added buffering reads and writes to roadmap --- ROADMAP.md | 11 +++++++++++ src/client/builder.rs | 3 +-- src/receiver.rs | 1 - 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ROADMAP.md b/ROADMAP.md index 3e8910742b..afd557c319 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -28,3 +28,14 @@ So maybe we should _just_ add `tokio` support, or maybe `mio` is still used and We need this to pass more autobahn tests! +### Buffer Reads and Writes + +In the old crate the stream was split up into a reader and writer stream so you could +have both a `BufReader` and a `BufWriter` to buffer your operations to gain some speed. +However is doesn't make sense to split the stream up anymore +(see [#83](https://github.com/cyderize/rust-websocket/issues/83)) +meaning that we should buffer reads and writes in some other way. + +Some work has begun on this, like [#91](https://github.com/cyderize/rust-websocket/pull/91), +but is this enough? And what about writing? + diff --git a/src/client/builder.rs b/src/client/builder.rs index 30bc2d2dbf..3037649ff5 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -231,7 +231,6 @@ impl<'u> ClientBuilder<'u> { self.connect_on(ssl_stream) } - // TODO: refactor and split apart into two parts, for when evented happens pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> where S: Stream { @@ -272,7 +271,7 @@ impl<'u> ClientBuilder<'u> { try!(write!(stream.writer(), "{}\r\n", self.headers)); // wait for a response - // TODO: we should buffer it all, how to set up stream for this? + // TODO: some extra data might get lost with this reader, try to avoid #72 let response = try!(parse_response(&mut BufReader::new(stream.reader()))); let status = StatusCode::from_u16(response.subject.0); diff --git a/src/receiver.rs b/src/receiver.rs index aa1a7e6760..6c55b1af9f 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -12,7 +12,6 @@ use ws::receiver::{MessageIterator, DataFrameIterator}; use stream::{AsTcpStream, Stream}; pub use stream::Shutdown; -// TODO: buffer the readers pub struct Reader where R: Read { From 66a132a44b7fd9e2753091426b6bfccc463f4b9a Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 11:21:48 -0400 Subject: [PATCH 28/32] easier wsupgrade api for protocols/exts, ability for client to verify protocols, etc. --- examples/hyper.rs | 3 +- examples/server.rs | 10 +----- src/client/builder.rs | 21 +---------- src/client/mod.rs | 33 ++++++++++++++++- src/lib.rs | 20 +++++++++++ src/server/upgrade/hyper.rs | 1 + src/server/upgrade/mod.rs | 71 +++++++++++++++++++++++++++---------- 7 files changed, 109 insertions(+), 50 deletions(-) diff --git a/examples/hyper.rs b/examples/hyper.rs index 33a52edaae..f5b8b21248 100644 --- a/examples/hyper.rs +++ b/examples/hyper.rs @@ -38,8 +38,7 @@ fn main() { return; } - // TODO: same check like in server.rs - let mut client = connection.accept().unwrap(); + let mut client = connection.use_protocol("rust-websocket").accept().unwrap(); let ip = client.peer_addr().unwrap(); diff --git a/examples/server.rs b/examples/server.rs index 776c396d34..dd11f5a098 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -4,14 +4,6 @@ use std::thread; use websocket::{Server, Message}; use websocket::message::Type; -// TODO: I think the .reject() call is only for malformed packets -// there should be an easy way to accept the socket with the given protocols -// this would mean there should be a way to accept or reject on the client -// Do you send the protocol you want to talk when you are not given it as an -// option? What is a rejection response? Does the client check for it? -// Client should expose what the decided protocols/extensions/etc are. -// can you accept only one protocol?? - fn main() { let server = Server::bind("127.0.0.1:2794").unwrap(); @@ -23,7 +15,7 @@ fn main() { return; } - let mut client = request.accept().unwrap(); + let mut client = request.use_protocol("rust-websocket").accept().unwrap(); let ip = client.peer_addr().unwrap(); diff --git a/src/client/builder.rs b/src/client/builder.rs index 3037649ff5..28ba2905b4 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -20,25 +20,6 @@ use result::{WSUrlErrorKind, WebSocketResult, WebSocketError}; use stream::Stream; use super::Client; -macro_rules! upsert_header { - ($headers:expr; $header:ty; { - Some($pat:pat) => $some_match:expr, - None => $default:expr - }) => {{ - match $headers.has::<$header>() { - true => { - match $headers.get_mut::<$header>() { - Some($pat) => { $some_match; }, - None => (), - }; - } - false => { - $headers.set($default); - }, - }; - }} -} - /// Build clients with a builder-style API #[derive(Clone, Debug)] pub struct ClientBuilder<'u> { @@ -305,6 +286,6 @@ impl<'u> ClientBuilder<'u> { return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); } - Ok(Client::unchecked(stream)) + Ok(Client::unchecked(stream, response.headers)) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 1cb50af356..877aad9b4d 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,6 +4,7 @@ extern crate url; use std::net::TcpStream; use std::net::SocketAddr; use std::io::Result as IoResult; +use hyper::header::Headers; use ws; use ws::sender::Sender as SenderTrait; @@ -12,6 +13,8 @@ use ws::receiver::Receiver as ReceiverTrait; use result::WebSocketResult; use stream::{AsTcpStream, Stream, Splittable, Shutdown}; use dataframe::DataFrame; +use header::{WebSocketProtocol, WebSocketExtensions, Origin}; +use header::extensions::Extension; use ws::dataframe::DataFrame as DataFrameable; use sender::Sender; @@ -54,6 +57,7 @@ pub struct Client where S: Stream { pub stream: S, + headers: Headers, sender: Sender, receiver: Receiver, } @@ -109,8 +113,9 @@ impl Client /// **without sending any handshake** this is meant to only be used with /// a stream that has a websocket connection already set up. /// If in doubt, don't use this! - pub fn unchecked(stream: S) -> Self { + pub fn unchecked(stream: S, headers: Headers) -> Self { Client { + headers: headers, stream: stream, // NOTE: these are always true & false, see // https://tools.ietf.org/html/rfc6455#section-5 @@ -152,6 +157,28 @@ impl Client self.receiver.recv_message(self.stream.reader()) } + pub fn headers(&self) -> &Headers { + &self.headers + } + + pub fn protocols(&self) -> &[String] { + self.headers + .get::() + .map(|p| p.0.as_slice()) + .unwrap_or(&[]) + } + + pub fn extensions(&self) -> &[Extension] { + self.headers + .get::() + .map(|e| e.0.as_slice()) + .unwrap_or(&[]) + } + + pub fn origin(&self) -> Option<&str> { + self.headers.get::().map(|o| &o.0 as &str) + } + pub fn stream_ref(&self) -> &S { &self.stream } @@ -160,6 +187,10 @@ impl Client &mut self.stream } + pub fn into_stream(self) -> S { + self.stream + } + /// Returns an iterator over incoming messages. /// ///```no_run diff --git a/src/lib.rs b/src/lib.rs index e1ca80cc63..92d74ba93b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,6 +60,26 @@ pub use self::stream::Stream; pub use self::ws::Sender; pub use self::ws::Receiver; +macro_rules! upsert_header { + ($headers:expr; $header:ty; { + Some($pat:pat) => $some_match:expr, + None => $default:expr + }) => {{ + match $headers.has::<$header>() { + true => { + match $headers.get_mut::<$header>() { + Some($pat) => { $some_match; }, + None => (), + }; + } + false => { + $headers.set($default); + }, + }; + }} +} + + pub mod ws; pub mod client; pub mod server; diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index 42ddd3f382..3d94a29251 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -30,6 +30,7 @@ impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { let stream = reader.into_inner().get_mut(); Ok(WsUpgrade { + headers: Headers::new(), stream: stream, request: Incoming { version: version, diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 941dedafaf..0be1d6950e 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -37,33 +37,65 @@ pub mod hyper; pub struct WsUpgrade where S: Stream { - stream: S, - request: Request, + pub headers: Headers, + pub stream: S, + pub request: Request, } impl WsUpgrade where S: Stream { - pub fn accept(self) -> IoResult> { + pub fn use_protocol

(mut self, protocol: P) -> Self + where P: Into + { + upsert_header!(self.headers; WebSocketProtocol; { + Some(protos) => protos.0.push(protocol.into()), + None => WebSocketProtocol(vec![protocol.into()]) + }); + self + } + + pub fn use_extension(mut self, extension: Extension) -> Self { + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.push(extension), + None => WebSocketExtensions(vec![extension]) + }); + self + } + + pub fn use_extensions(mut self, extensions: I) -> Self + where I: IntoIterator + { + let mut extensions: Vec = + extensions.into_iter().collect(); + upsert_header!(self.headers; WebSocketExtensions; { + Some(protos) => protos.0.append(&mut extensions), + None => WebSocketExtensions(extensions) + }); + self + } + + pub fn accept(self) -> Result, (S, IoError)> { self.accept_with(&Headers::new()) } - pub fn accept_with(mut self, custom_headers: &Headers) -> IoResult> { - let mut headers = Headers::new(); - headers.extend(custom_headers.iter()); - headers.set(WebSocketAccept::new( - // NOTE: we know there is a key because this is a valid request - // i.e. to construct this you must go through the validate function - self.request.headers.get::().unwrap() - )); - headers.set(Connection(vec![ + pub fn accept_with(mut self, custom_headers: &Headers) -> Result, (S, IoError)> { + self.headers.extend(custom_headers.iter()); + self.headers + .set(WebSocketAccept::new(// NOTE: we know there is a key because this is a valid request + // i.e. to construct this you must go through the validate function + self.request.headers.get::().unwrap())); + self.headers + .set(Connection(vec![ ConnectionOption::ConnectionHeader(UniCase("Upgrade".to_string())) ])); - headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)])); + self.headers.set(Upgrade(vec![Protocol::new(ProtocolName::WebSocket, None)])); - try!(self.send(StatusCode::SwitchingProtocols, &headers)); + if let Err(e) = self.send(StatusCode::SwitchingProtocols) { + return Err((self.stream, e)); + } - Ok(Client::unchecked(self.stream)) + Ok(Client::unchecked(self.stream, self.headers)) } pub fn reject(self) -> Result { @@ -71,7 +103,8 @@ impl WsUpgrade } pub fn reject_with(mut self, headers: &Headers) -> Result { - match self.send(StatusCode::BadRequest, headers) { + self.headers.extend(headers.iter()); + match self.send(StatusCode::BadRequest) { Ok(()) => Ok(self.stream), Err(e) => Err((self.stream, e)), } @@ -113,12 +146,12 @@ impl WsUpgrade self.stream } - fn send(&mut self, status: StatusCode, headers: &Headers) -> IoResult<()> { + fn send(&mut self, status: StatusCode) -> IoResult<()> { try!(write!(self.stream.writer(), "{} {}\r\n", self.request.version, status)); - try!(write!(self.stream.writer(), "{}\r\n", headers)); + try!(write!(self.stream.writer(), "{}\r\n", self.headers)); Ok(()) } } @@ -173,6 +206,7 @@ impl IntoWs for S match validate(&request.subject.0, &request.version, &request.headers) { Ok(_) => { Ok(WsUpgrade { + headers: Headers::new(), stream: self, request: request, }) @@ -192,6 +226,7 @@ impl IntoWs for RequestStreamPair match validate(&self.1.subject.0, &self.1.version, &self.1.headers) { Ok(_) => { Ok(WsUpgrade { + headers: Headers::new(), stream: self.0, request: self.1, }) From df581cf411194b87f717c99c685a6a94a959478c Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Thu, 30 Mar 2017 11:43:02 -0400 Subject: [PATCH 29/32] changed stream trait to impl Read and Write, not contain --- src/client/builder.rs | 19 +++-- src/client/mod.rs | 17 ++-- src/server/upgrade/hyper.rs | 2 + src/server/upgrade/mod.rs | 11 +-- src/stream.rs | 157 ++++++++++++++---------------------- 5 files changed, 82 insertions(+), 124 deletions(-) diff --git a/src/client/builder.rs b/src/client/builder.rs index 28ba2905b4..20aa001a0b 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -1,5 +1,4 @@ use std::borrow::Cow; -use std::io::Write; use std::net::TcpStream; pub use url::{Url, ParseError}; use url::Position; @@ -12,12 +11,11 @@ use unicase::UniCase; #[cfg(feature="ssl")] use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder}; #[cfg(feature="ssl")] -use stream::BoxedNetworkStream; use header::extensions::Extension; use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; use result::{WSUrlErrorKind, WebSocketResult, WebSocketError}; -use stream::Stream; +use stream::{Stream, NetworkStream}; use super::Client; /// Build clients with a builder-style API @@ -182,13 +180,14 @@ impl<'u> ClientBuilder<'u> { pub fn connect( &mut self, ssl_config: Option, - ) -> WebSocketResult> { + ) -> WebSocketResult>> { let tcp_stream = try!(self.establish_tcp(None)); - let boxed_stream = if self.url.scheme() == "wss" { - BoxedNetworkStream(Box::new(try!(self.wrap_ssl(tcp_stream, ssl_config)))) + let boxed_stream: Box = if + self.url.scheme() == "wss" { + Box::new(try!(self.wrap_ssl(tcp_stream, ssl_config))) } else { - BoxedNetworkStream(Box::new(tcp_stream)) + Box::new(tcp_stream) }; self.connect_on(boxed_stream) @@ -248,12 +247,12 @@ impl<'u> ClientBuilder<'u> { } // send request - try!(write!(stream.writer(), "GET {} {}\r\n", resource, self.version)); - try!(write!(stream.writer(), "{}\r\n", self.headers)); + try!(write!(stream, "GET {} {}\r\n", resource, self.version)); + try!(write!(stream, "{}\r\n", self.headers)); // wait for a response // TODO: some extra data might get lost with this reader, try to avoid #72 - let response = try!(parse_response(&mut BufReader::new(stream.reader()))); + let response = try!(parse_response(&mut BufReader::new(&mut stream))); let status = StatusCode::from_u16(response.subject.0); // validate diff --git a/src/client/mod.rs b/src/client/mod.rs index 877aad9b4d..8295226560 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -128,7 +128,7 @@ impl Client pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> where D: DataFrameable { - self.sender.send_dataframe(self.stream.writer(), dataframe) + self.sender.send_dataframe(&mut self.stream, dataframe) } /// Sends a single message to the remote endpoint. @@ -136,17 +136,17 @@ impl Client where M: ws::Message<'m, D>, D: DataFrameable { - self.sender.send_message(self.stream.writer(), message) + self.sender.send_message(&mut self.stream, message) } /// Reads a single data frame from the remote endpoint. pub fn recv_dataframe(&mut self) -> WebSocketResult { - self.receiver.recv_dataframe(self.stream.reader()) + self.receiver.recv_dataframe(&mut self.stream) } /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S::Reader> { - self.receiver.incoming_dataframes(self.stream.reader()) + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S> { + self.receiver.incoming_dataframes(&mut self.stream) } /// Reads a single message from this receiver. @@ -154,7 +154,7 @@ impl Client where M: ws::Message<'m, DataFrame, DataFrameIterator = I>, I: Iterator { - self.receiver.recv_message(self.stream.reader()) + self.receiver.recv_message(&mut self.stream) } pub fn headers(&self) -> &Headers { @@ -229,12 +229,11 @@ impl Client ///} ///# } ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self,) - -> MessageIterator<'a, Receiver, D, M, S::Reader> + pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S> where M: ws::Message<'a, D>, D: DataFrameable { - self.receiver.incoming_messages(self.stream.reader()) + self.receiver.incoming_messages(&mut self.stream) } } diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index 3d94a29251..a70c06182e 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -27,6 +27,8 @@ impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { let (_, method, headers, uri, version, reader) = self.0.deconstruct(); + + // TODO: some extra data might get lost with this reader, try to avoid #72 let stream = reader.into_inner().get_mut(); Ok(WsUpgrade { diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 0be1d6950e..1c75f1744f 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -7,7 +7,6 @@ use std::net::TcpStream; use std::io; use std::io::Result as IoResult; use std::io::Error as IoError; -use std::io::Write; use std::fmt::{self, Formatter, Display}; use stream::{Stream, AsTcpStream}; use header::extensions::Extension; @@ -147,11 +146,8 @@ impl WsUpgrade } fn send(&mut self, status: StatusCode) -> IoResult<()> { - try!(write!(self.stream.writer(), - "{} {}\r\n", - self.request.version, - status)); - try!(write!(self.stream.writer(), "{}\r\n", self.headers)); + try!(write!(&mut self.stream, "{} {}\r\n", self.request.version, status)); + try!(write!(&mut self.stream, "{}\r\n", self.headers)); Ok(()) } } @@ -194,7 +190,8 @@ impl IntoWs for S fn into_ws(mut self) -> Result, Self::Error> { let request = { - let mut reader = BufReader::new(self.reader()); + // TODO: some extra data might get lost with this reader, try to avoid #72 + let mut reader = BufReader::new(&mut self); parse_request(&mut reader) }; diff --git a/src/stream.rs b/src/stream.rs index 1a0962df1c..67cca057e5 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -1,37 +1,31 @@ //! Provides the default stream type for WebSocket connections. use std::ops::Deref; +use std::fmt::Arguments; use std::io::{self, Read, Write}; pub use std::net::TcpStream; pub use std::net::Shutdown; #[cfg(feature="ssl")] pub use openssl::ssl::{SslStream, SslContext}; -pub trait Splittable { - type Reader: Read; - type Writer: Write; - - fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; -} - /// Represents a stream that can be read from, and written to. /// This is an abstraction around readable and writable things to be able /// to speak websockets over ssl, tcp, unix sockets, etc. -pub trait Stream { +pub trait Stream: Read + Write {} + +impl Stream for S where S: Read + Write {} + +pub trait NetworkStream: Read + Write + AsTcpStream {} + +impl NetworkStream for S where S: Read + Write + AsTcpStream {} + +pub trait Splittable { type Reader: Read; type Writer: Write; - /// Get a mutable borrow to the reading component of this stream - fn reader(&mut self) -> &mut Self::Reader; - - /// Get a mutable borrow to the writing component of this stream - fn writer(&mut self) -> &mut Self::Writer; + fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; } -pub struct ReadWritePair(pub R, pub W) - where R: Read, - W: Write; - impl Splittable for ReadWritePair where R: Read, W: Write @@ -44,70 +38,6 @@ impl Splittable for ReadWritePair } } -impl Stream for ReadWritePair - where R: Read, - W: Write -{ - type Reader = R; - type Writer = W; - - #[inline] - fn reader(&mut self) -> &mut R { - &mut self.0 - } - - #[inline] - fn writer(&mut self) -> &mut W { - &mut self.1 - } -} - -pub trait ReadWrite: Read + Write {} -impl ReadWrite for S where S: Read + Write {} - -pub struct BoxedStream(pub Box); - -impl Stream for BoxedStream { - type Reader = Box; - type Writer = Box; - - #[inline] - fn reader(&mut self) -> &mut Self::Reader { - &mut self.0 - } - - #[inline] - fn writer(&mut self) -> &mut Self::Writer { - &mut self.0 - } -} - -pub trait NetworkStream: Read + Write + AsTcpStream {} -impl NetworkStream for S where S: Read + Write + AsTcpStream {} - -pub struct BoxedNetworkStream(pub Box); - -impl AsTcpStream for BoxedNetworkStream { - fn as_tcp(&self) -> &TcpStream { - self.0.deref().as_tcp() - } -} - -impl Stream for BoxedNetworkStream { - type Reader = Box; - type Writer = Box; - - #[inline] - fn reader(&mut self) -> &mut Self::Reader { - &mut self.0 - } - - #[inline] - fn writer(&mut self) -> &mut Self::Writer { - &mut self.0 - } -} - impl Splittable for TcpStream { type Reader = TcpStream; type Writer = TcpStream; @@ -117,23 +47,6 @@ impl Splittable for TcpStream { } } -impl Stream for S - where S: Read + Write -{ - type Reader = Self; - type Writer = Self; - - #[inline] - fn reader(&mut self) -> &mut S { - self - } - - #[inline] - fn writer(&mut self) -> &mut S { - self - } -} - pub trait AsTcpStream { fn as_tcp(&self) -> &TcpStream; } @@ -158,3 +71,51 @@ impl AsTcpStream for Box self.deref().as_tcp() } } + +pub struct ReadWritePair(pub R, pub W) + where R: Read, + W: Write; + +impl Read for ReadWritePair + where R: Read, + W: Write +{ + #[inline(always)] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } + #[inline(always)] + fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { + self.0.read_to_end(buf) + } + #[inline(always)] + fn read_to_string(&mut self, buf: &mut String) -> io::Result { + self.0.read_to_string(buf) + } + #[inline(always)] + fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { + self.0.read_exact(buf) + } +} + +impl Write for ReadWritePair + where R: Read, + W: Write +{ + #[inline(always)] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.1.write(buf) + } + #[inline(always)] + fn flush(&mut self) -> io::Result<()> { + self.1.flush() + } + #[inline(always)] + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + self.1.write_all(buf) + } + #[inline(always)] + fn write_fmt(&mut self, fmt: Arguments) -> io::Result<()> { + self.1.write_fmt(fmt) + } +} From 134e3cb6e665a331d0dc465d6df6e8ceeefff305 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Thu, 30 Mar 2017 21:46:15 -0400 Subject: [PATCH 30/32] keep track of all the used buffers in BufReaders so no data goes missing this is possible with the help of hyperium/hyper#1107 --- Cargo.toml | 2 +- src/client/builder.rs | 11 +++++---- src/client/mod.rs | 47 ++++++++++++++++++++++--------------- src/receiver.rs | 13 ++++++---- src/server/mod.rs | 12 +++++++--- src/server/upgrade/hyper.rs | 12 +++++++--- src/server/upgrade/mod.rs | 40 ++++++++++++++++++++++--------- 7 files changed, 90 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6fed692717..d280de4bc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ keywords = ["websocket", "websockets", "rfc6455"] license = "MIT" [dependencies] -hyper = "^0.10" +hyper = { git = "https://github.com/hyperium/hyper.git", branch = "0.10.x" } unicase = "^1.0" url = "^1.0" rustc-serialize = "^0.3" diff --git a/src/client/builder.rs b/src/client/builder.rs index 20aa001a0b..52e20b773a 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -10,12 +10,13 @@ use hyper::header::{Headers, Host, Connection, ConnectionOption, Upgrade, Protoc use unicase::UniCase; #[cfg(feature="ssl")] use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder}; -#[cfg(feature="ssl")] use header::extensions::Extension; use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; use result::{WSUrlErrorKind, WebSocketResult, WebSocketError}; -use stream::{Stream, NetworkStream}; +#[cfg(feature="ssl")] +use stream::NetworkStream; +use stream::Stream; use super::Client; /// Build clients with a builder-style API @@ -251,8 +252,8 @@ impl<'u> ClientBuilder<'u> { try!(write!(stream, "{}\r\n", self.headers)); // wait for a response - // TODO: some extra data might get lost with this reader, try to avoid #72 - let response = try!(parse_response(&mut BufReader::new(&mut stream))); + let mut reader = BufReader::new(stream); + let response = try!(parse_response(&mut reader)); let status = StatusCode::from_u16(response.subject.0); // validate @@ -285,6 +286,6 @@ impl<'u> ClientBuilder<'u> { return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); } - Ok(Client::unchecked(stream, response.headers)) + Ok(Client::unchecked(reader, response.headers)) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 8295226560..4210515d11 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,7 +4,9 @@ extern crate url; use std::net::TcpStream; use std::net::SocketAddr; use std::io::Result as IoResult; +use std::io::{Read, Write}; use hyper::header::Headers; +use hyper::buffer::BufReader; use ws; use ws::sender::Sender as SenderTrait; @@ -56,7 +58,7 @@ pub use self::builder::{ClientBuilder, Url, ParseError}; pub struct Client where S: Stream { - pub stream: S, + stream: BufReader, headers: Headers, sender: Sender, receiver: Receiver, @@ -66,13 +68,13 @@ impl Client { /// Shuts down the sending half of the client connection, will cause all pending /// and future IO to return immediately with an appropriate value. pub fn shutdown_sender(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Write) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Write) } /// Shuts down the receiving half of the client connection, will cause all pending /// and future IO to return immediately with an appropriate value. pub fn shutdown_receiver(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Read) } } @@ -82,27 +84,27 @@ impl Client /// Shuts down the client connection, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) } /// See `TcpStream.peer_addr()`. pub fn peer_addr(&self) -> IoResult { - self.stream.as_tcp().peer_addr() + self.stream.get_ref().as_tcp().peer_addr() } /// See `TcpStream.local_addr()`. pub fn local_addr(&self) -> IoResult { - self.stream.as_tcp().local_addr() + self.stream.get_ref().as_tcp().local_addr() } /// See `TcpStream.set_nodelay()`. pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { - self.stream.as_tcp().set_nodelay(nodelay) + self.stream.get_ref().as_tcp().set_nodelay(nodelay) } /// Changes whether the stream is in nonblocking mode. pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.stream.as_tcp().set_nonblocking(nonblocking) + self.stream.get_ref().as_tcp().set_nonblocking(nonblocking) } } @@ -113,7 +115,7 @@ impl Client /// **without sending any handshake** this is meant to only be used with /// a stream that has a websocket connection already set up. /// If in doubt, don't use this! - pub fn unchecked(stream: S, headers: Headers) -> Self { + pub fn unchecked(stream: BufReader, headers: Headers) -> Self { Client { headers: headers, stream: stream, @@ -128,7 +130,7 @@ impl Client pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> where D: DataFrameable { - self.sender.send_dataframe(&mut self.stream, dataframe) + self.sender.send_dataframe(self.stream.get_mut(), dataframe) } /// Sends a single message to the remote endpoint. @@ -136,7 +138,7 @@ impl Client where M: ws::Message<'m, D>, D: DataFrameable { - self.sender.send_message(&mut self.stream, message) + self.sender.send_message(self.stream.get_mut(), message) } /// Reads a single data frame from the remote endpoint. @@ -145,7 +147,7 @@ impl Client } /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S> { + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader> { self.receiver.incoming_dataframes(&mut self.stream) } @@ -180,15 +182,20 @@ impl Client } pub fn stream_ref(&self) -> &S { - &self.stream + self.stream.get_ref() } - pub fn stream_ref_mut(&mut self) -> &mut S { + pub fn writer_mut(&mut self) -> &mut Write { + self.stream.get_mut() + } + + pub fn reader_mut(&mut self) -> &mut Read { &mut self.stream } - pub fn into_stream(self) -> S { - self.stream + pub fn into_stream(self) -> (S, Option<(Vec, usize, usize)>) { + let (stream, buf, pos, cap) = self.stream.into_parts(); + (stream, Some((buf, pos, cap))) } /// Returns an iterator over incoming messages. @@ -229,7 +236,8 @@ impl Client ///} ///# } ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S> + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, BufReader> where M: ws::Message<'a, D>, D: DataFrameable { @@ -269,9 +277,10 @@ impl Client pub fn split (self,) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { - let (read, write) = try!(self.stream.split()); + let (stream, buf, pos, cap) = self.stream.into_parts(); + let (read, write) = try!(stream.split()); Ok((Reader { - stream: read, + stream: BufReader::from_parts(read, buf, pos, cap), receiver: self.receiver, }, Writer { diff --git a/src/receiver.rs b/src/receiver.rs index 6c55b1af9f..0f1bd25e54 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -3,6 +3,8 @@ use std::io::Read; use std::io::Result as IoResult; +use hyper::buffer::BufReader; + use dataframe::{DataFrame, Opcode}; use result::{WebSocketResult, WebSocketError}; use ws; @@ -15,7 +17,7 @@ pub use stream::Shutdown; pub struct Reader where R: Read { - pub stream: R, + pub stream: BufReader, pub receiver: Receiver, } @@ -28,7 +30,7 @@ impl Reader } /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> { + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader> { self.receiver.incoming_dataframes(&mut self.stream) } @@ -40,7 +42,8 @@ impl Reader self.receiver.recv_message(&mut self.stream) } - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R> + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, BufReader> where M: ws::Message<'a, D>, D: DataFrameable { @@ -54,13 +57,13 @@ impl Reader /// Closes the receiver side of the connection, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Read) } /// Shuts down both Sender and Receiver, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown_all(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 5a8af0be3d..dc3cbd283d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -5,7 +5,7 @@ use std::convert::Into; #[cfg(feature="ssl")] use openssl::ssl::{SslStream, SslAcceptor}; use stream::Stream; -use self::upgrade::{WsUpgrade, IntoWs}; +use self::upgrade::{WsUpgrade, IntoWs, Buffer}; pub use self::upgrade::{Request, HyperIntoWsError}; pub mod upgrade; @@ -15,6 +15,7 @@ pub struct InvalidConnection { pub stream: Option, pub parsed: Option, + pub buffer: Option, pub error: HyperIntoWsError, } @@ -150,6 +151,7 @@ impl Server { return Err(InvalidConnection { stream: None, parsed: None, + buffer: None, error: e.into(), }) } @@ -161,6 +163,7 @@ impl Server { return Err(InvalidConnection { stream: None, parsed: None, + buffer: None, error: io::Error::new(io::ErrorKind::Other, err).into(), }) } @@ -168,10 +171,11 @@ impl Server { match stream.into_ws() { Ok(u) => Ok(u), - Err((s, r, e)) => { + Err((s, r, b, e)) => { Err(InvalidConnection { stream: Some(s), parsed: r, + buffer: b, error: e.into(), }) } @@ -213,6 +217,7 @@ impl Server { return Err(InvalidConnection { stream: None, parsed: None, + buffer: None, error: e.into(), }) } @@ -220,10 +225,11 @@ impl Server { match stream.into_ws() { Ok(u) => Ok(u), - Err((s, r, e)) => { + Err((s, r, b, e)) => { Err(InvalidConnection { stream: Some(s), parsed: r, + buffer: b, error: e.into(), }) } diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index a70c06182e..1e44a6e1a2 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -1,7 +1,7 @@ extern crate hyper; use hyper::net::NetworkStream; -use super::{IntoWs, WsUpgrade}; +use super::{IntoWs, WsUpgrade, Buffer}; pub use hyper::http::h1::Incoming; pub use hyper::method::Method; @@ -28,12 +28,18 @@ impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { let (_, method, headers, uri, version, reader) = self.0.deconstruct(); - // TODO: some extra data might get lost with this reader, try to avoid #72 - let stream = reader.into_inner().get_mut(); + let reader = reader.into_inner(); + let (buf, pos, cap) = reader.take_buf(); + let stream = reader.get_mut(); Ok(WsUpgrade { headers: Headers::new(), stream: stream, + buffer: Some(Buffer { + buf: buf, + pos: pos, + cap: cap, + }), request: Incoming { version: version, headers: headers, diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 1c75f1744f..4aeb9ba0ab 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -27,6 +27,12 @@ pub use self::real_hyper::header::{Headers, Upgrade, Protocol, ProtocolName, Con pub mod hyper; +pub struct Buffer { + pub buf: Vec, + pub pos: usize, + pub cap: usize, +} + /// Intermediate representation of a half created websocket session. /// Should be used to examine the client's handshake /// accept the protocols requested, route the path, etc. @@ -39,6 +45,7 @@ pub struct WsUpgrade pub headers: Headers, pub stream: S, pub request: Request, + pub buffer: Option, } impl WsUpgrade @@ -94,7 +101,12 @@ impl WsUpgrade return Err((self.stream, e)); } - Ok(Client::unchecked(self.stream, self.headers)) + let stream = match self.buffer { + Some(Buffer { buf, pos, cap }) => BufReader::from_parts(self.stream, buf, pos, cap), + None => BufReader::new(self.stream), + }; + + Ok(Client::unchecked(stream, self.headers)) } pub fn reject(self) -> Result { @@ -186,29 +198,34 @@ impl IntoWs for S where S: Stream { type Stream = S; - type Error = (Self, Option, HyperIntoWsError); + type Error = (S, Option, Option, HyperIntoWsError); - fn into_ws(mut self) -> Result, Self::Error> { - let request = { - // TODO: some extra data might get lost with this reader, try to avoid #72 - let mut reader = BufReader::new(&mut self); - parse_request(&mut reader) - }; + fn into_ws(self) -> Result, Self::Error> { + let mut reader = BufReader::new(self); + let request = parse_request(&mut reader); + + let (stream, buf, pos, cap) = reader.into_parts(); + let buffer = Some(Buffer { + buf: buf, + cap: cap, + pos: pos, + }); let request = match request { Ok(r) => r, - Err(e) => return Err((self, None, e.into())), + Err(e) => return Err((stream, None, buffer, e.into())), }; match validate(&request.subject.0, &request.version, &request.headers) { Ok(_) => { Ok(WsUpgrade { headers: Headers::new(), - stream: self, + stream: stream, request: request, + buffer: buffer, }) } - Err(e) => Err((self, Some(request), e)), + Err(e) => Err((stream, Some(request), buffer, e)), } } } @@ -226,6 +243,7 @@ impl IntoWs for RequestStreamPair headers: Headers::new(), stream: self.0, request: self.1, + buffer: None, }) } Err(e) => Err((self.0, self.1, e)), From a363b1fe2794dcfe4024c4017835754297a5e680 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Wed, 29 Mar 2017 02:36:06 -0400 Subject: [PATCH 31/32] version bump to 0.18.0: major API change --- Cargo.toml | 2 +- README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d280de4bc1..8d61955c62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "websocket" -version = "0.17.2" +version = "0.18.0" authors = ["cyderize ", "Michael Eden "] description = "A WebSocket (RFC6455) library for Rust." diff --git a/README.md b/README.md index 1d965bc859..1d6702f0a1 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Rust-WebSocket provides a framework for dealing with WebSocket connections (both To add a library release version from [crates.io](https://crates.io/crates/websocket) to a Cargo project, add this to the 'dependencies' section of your Cargo.toml: ```INI -websocket = "0.17.1" +websocket = "0.18.0" ``` To add the library's Git repository to a Cargo project, add this to your Cargo.toml: From 3a60b54c4488c07149454e0bfa69967ae8f9f9f0 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Sat, 1 Apr 2017 01:20:39 -0400 Subject: [PATCH 32/32] docs and doc tests --- src/client/builder.rs | 295 ++++++++++++++++++++++++++++++- src/client/mod.rs | 123 +++++++++++-- src/header/extensions.rs | 2 + src/header/mod.rs | 2 +- src/header/protocol.rs | 2 + src/lib.rs | 2 +- src/receiver.rs | 6 + src/sender.rs | 6 + src/server/mod.rs | 33 +++- src/server/upgrade/from_hyper.rs | 88 +++++++++ src/server/upgrade/hyper.rs | 50 ------ src/server/upgrade/mod.rs | 119 ++++++++++--- src/stream.rs | 18 ++ src/ws/receiver.rs | 1 + 14 files changed, 644 insertions(+), 103 deletions(-) create mode 100644 src/server/upgrade/from_hyper.rs delete mode 100644 src/server/upgrade/hyper.rs diff --git a/src/client/builder.rs b/src/client/builder.rs index 52e20b773a..b16e41a36a 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -1,3 +1,5 @@ +//! Everything you need to create a client connection to a websocket. + use std::borrow::Cow; use std::net::TcpStream; pub use url::{Url, ParseError}; @@ -6,7 +8,8 @@ use hyper::version::HttpVersion; use hyper::status::StatusCode; use hyper::buffer::BufReader; use hyper::http::h1::parse_response; -use hyper::header::{Headers, Host, Connection, ConnectionOption, Upgrade, Protocol, ProtocolName}; +use hyper::header::{Headers, Header, HeaderFormat, Host, Connection, ConnectionOption, Upgrade, + Protocol, ProtocolName}; use unicase::UniCase; #[cfg(feature="ssl")] use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder}; @@ -20,6 +23,50 @@ use stream::Stream; use super::Client; /// Build clients with a builder-style API +/// This makes it easy to create and configure a websocket +/// connection: +/// +/// The easiest way to connect is like this: +/// +/// ```rust,no_run +/// use websocket::ClientBuilder; +/// +/// let client = ClientBuilder::new("ws://myapp.com") +/// .unwrap() +/// .connect_insecure() +/// .unwrap(); +/// ``` +/// +/// But there are so many more possibilities: +/// +/// ```rust,no_run +/// use websocket::ClientBuilder; +/// use websocket::header::{Headers, Cookie}; +/// +/// let default_protos = vec!["ping", "chat"]; +/// let mut my_headers = Headers::new(); +/// my_headers.set(Cookie(vec!["userid=1".to_owned()])); +/// +/// let mut builder = ClientBuilder::new("ws://myapp.com/room/discussion") +/// .unwrap() +/// .add_protocols(default_protos) // any IntoIterator +/// .add_protocol("video-chat") +/// .custom_headers(&my_headers); +/// +/// // connect to a chat server with a user +/// let client = builder.connect_insecure().unwrap(); +/// +/// // clone the builder and take it with you +/// let not_logged_in = builder +/// .clone() +/// .clear_header::() +/// .connect_insecure().unwrap(); +/// ``` +/// +/// You may have noticed we're not using SSL, have no fear, SSL is included! +/// This crate's openssl dependency is optional (and included by default). +/// One can use `connect_secure` to connect to an SSL service, or simply `connect` +/// to choose either SSL or not based on the protocol (`ws://` or `wss://`). #[derive(Clone, Debug)] pub struct ClientBuilder<'u> { url: Cow<'u, Url>, @@ -30,10 +77,37 @@ pub struct ClientBuilder<'u> { } impl<'u> ClientBuilder<'u> { + /// Create a client builder from an already parsed Url, + /// because there is no need to parse this will never error. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// use websocket::url::Url; + /// + /// // the parsing error will be handled outside the constructor + /// let url = Url::parse("ws://bitcoins.pizza").unwrap(); + /// + /// let builder = ClientBuilder::from_url(&url); + /// ``` + /// The path of a URL is optional if no port is given then port + /// 80 will be used in the case of `ws://` and port `443` will be + /// used in the case of `wss://`. pub fn from_url(address: &'u Url) -> Self { ClientBuilder::init(Cow::Borrowed(address)) } + /// Create a client builder from a URL string, this will + /// attempt to parse the URL immediately and return a `ParseError` + /// if the URL is invalid. URLs must be of the form: + /// `[ws or wss]://[domain]:[port]/[path]` + /// The path of a URL is optional if no port is given then port + /// 80 will be used in the case of `ws://` and port `443` will be + /// used in the case of `wss://`. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// let builder = ClientBuilder::new("wss://mycluster.club"); + /// ``` pub fn new(address: &str) -> Result { let url = try!(Url::parse(address)); Ok(ClientBuilder::init(Cow::Owned(url))) @@ -49,6 +123,18 @@ impl<'u> ClientBuilder<'u> { } } + /// Adds a user-defined protocol to the handshake, the server will be + /// given a list of these protocols and will send back the ones it accepts. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::WebSocketProtocol; + /// let builder = ClientBuilder::new("wss://my-twitch-clone.rs").unwrap() + /// .add_protocol("my-chat-proto"); + /// + /// let protos = &builder.get_header::().unwrap().0; + /// assert!(protos.contains(&"my-chat-proto".to_string())); + /// ``` pub fn add_protocol

(mut self, protocol: P) -> Self where P: Into { @@ -59,6 +145,19 @@ impl<'u> ClientBuilder<'u> { self } + /// Adds a user-defined protocols to the handshake. + /// This can take many kinds of iterators. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::WebSocketProtocol; + /// let builder = ClientBuilder::new("wss://my-twitch-clone.rs").unwrap() + /// .add_protocols(vec!["pubsub", "sub.events"]); + /// + /// let protos = &builder.get_header::().unwrap().0; + /// assert!(protos.contains(&"pubsub".to_string())); + /// assert!(protos.contains(&"sub.events".to_string())); + /// ``` pub fn add_protocols(mut self, protocols: I) -> Self where I: IntoIterator, S: Into @@ -75,11 +174,31 @@ impl<'u> ClientBuilder<'u> { self } + /// Removes all the currently set protocols. pub fn clear_protocols(mut self) -> Self { self.headers.remove::(); self } + /// Adds an extension to the connection. + /// Unlike protocols, extensions can be below the application level + /// (like compression). Currently no extensions are supported + /// out-of-the-box but one can still use them by using their own + /// implementation. Support is coming soon though. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::{WebSocketExtensions}; + /// # use websocket::header::extensions::Extension; + /// let builder = ClientBuilder::new("wss://skype-for-linux-lol.com").unwrap() + /// .add_extension(Extension { + /// name: "permessage-deflate".to_string(), + /// params: vec![], + /// }); + /// + /// let exts = &builder.get_header::().unwrap().0; + /// assert!(exts.first().unwrap().name == "permessage-deflate"); + /// ``` pub fn add_extension(mut self, extension: Extension) -> Self { upsert_header!(self.headers; WebSocketExtensions; { Some(protos) => protos.0.push(extension), @@ -88,6 +207,30 @@ impl<'u> ClientBuilder<'u> { self } + /// Adds some extensions to the connection. + /// Currently no extensions are supported out-of-the-box but one can + /// still use them by using their own implementation. Support is coming soon though. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::{WebSocketExtensions}; + /// # use websocket::header::extensions::Extension; + /// let builder = ClientBuilder::new("wss://moxie-chat.org").unwrap() + /// .add_extensions(vec![ + /// Extension { + /// name: "permessage-deflate".to_string(), + /// params: vec![], + /// }, + /// Extension { + /// name: "crypt-omemo".to_string(), + /// params: vec![], + /// }, + /// ]); + /// + /// # let exts = &builder.get_header::().unwrap().0; + /// # assert!(exts.first().unwrap().name == "permessage-deflate"); + /// # assert!(exts.last().unwrap().name == "crypt-omemo"); + /// ``` pub fn add_extensions(mut self, extensions: I) -> Self where I: IntoIterator { @@ -100,47 +243,96 @@ impl<'u> ClientBuilder<'u> { self } + /// Remove all the extensions added to the builder. pub fn clear_extensions(mut self) -> Self { self.headers.remove::(); self } + /// Add a custom `Sec-WebSocket-Key` header. + /// Use this only if you know what you're doing, and this almost + /// never has to be used. pub fn key(mut self, key: [u8; 16]) -> Self { self.headers.set(WebSocketKey(key)); self.key_set = true; self } + /// Remove the currently set `Sec-WebSocket-Key` header if any. pub fn clear_key(mut self) -> Self { self.headers.remove::(); self.key_set = false; self } + /// Set the version of the Websocket connection. + /// Currently this library only supports version 13 (from RFC6455), + /// but one could use this library to create the handshake then use an + /// implementation of another websocket version. pub fn version(mut self, version: WebSocketVersion) -> Self { self.headers.set(version); self.version_set = true; self } + /// Unset the websocket version to be the default (WebSocket 13). pub fn clear_version(mut self) -> Self { self.headers.remove::(); self.version_set = false; self } + /// Sets the Origin header of the handshake. + /// Normally in browsers this is used to protect against + /// unauthorized cross-origin use of a WebSocket server, but it is rarely + /// send by non-browser clients. Still, it can be useful. pub fn origin(mut self, origin: String) -> Self { self.headers.set(Origin(origin)); self } - pub fn custom_headers(mut self, edit: F) -> Self - where F: Fn(&mut Headers) + /// Remove the Origin header from the handshake. + pub fn clear_origin(mut self) -> Self { + self.headers.remove::(); + self + } + + /// This is a catch all to add random headers to your handshake, + /// the process here is more manual. + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// # use websocket::header::{Headers, Authorization}; + /// let mut headers = Headers::new(); + /// headers.set(Authorization("let me in".to_owned())); + /// + /// let builder = ClientBuilder::new("ws://moz.illest").unwrap() + /// .custom_headers(&headers); + /// + /// # let hds = &builder.get_header::>().unwrap().0; + /// # assert!(hds == &"let me in".to_string()); + /// ``` + pub fn custom_headers(mut self, custom_headers: &Headers) -> Self { + self.headers.extend(custom_headers.iter()); + self + } + + /// Remove a type of header from the handshake, this is to be used + /// with the catch all `custom_headers`. + pub fn clear_header(mut self) -> Self + where H: Header + HeaderFormat { - edit(&mut self.headers); + self.headers.remove::(); self } + /// Get a header to inspect it. + pub fn get_header(&self) -> Option<&H> + where H: Header + HeaderFormat + { + self.headers.get::() + } + fn establish_tcp(&mut self, secure: Option) -> WebSocketResult { let port = match (self.url.port(), secure) { (Some(port), _) => port, @@ -177,6 +369,25 @@ impl<'u> ClientBuilder<'u> { Ok(ssl_stream) } + /// Connect to a server (finally)! + /// This will use a `Box` to represent either an SSL + /// connection or a normal TCP connection, what to use will be decided + /// using the protocol of the URL passed in (e.g. `ws://` or `wss://`) + /// + /// If you have non-default SSL circumstances, you can use the `ssl_config` + /// parameter to configure those. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// # use websocket::Message; + /// let mut client = ClientBuilder::new("wss://supersecret.l33t").unwrap() + /// .connect(None) + /// .unwrap(); + /// + /// // send messages! + /// let message = Message::text("m337 47 7pm"); + /// client.send_message(&message).unwrap(); + /// ``` #[cfg(feature="ssl")] pub fn connect( &mut self, @@ -194,12 +405,30 @@ impl<'u> ClientBuilder<'u> { self.connect_on(boxed_stream) } + /// Create an insecure (plain TCP) connection to the client. + /// In this case no `Box` will be used you will just get a TcpStream, + /// giving you the ability to split the stream into a reader and writer + /// (since SSL streams cannot be cloned). + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// let mut client = ClientBuilder::new("wss://supersecret.l33t").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// // split into two (for some reason)! + /// let (receiver, sender) = client.split().unwrap(); + /// ``` pub fn connect_insecure(&mut self) -> WebSocketResult> { let tcp_stream = try!(self.establish_tcp(Some(false))); self.connect_on(tcp_stream) } + /// Create an SSL connection to the sever. + /// This will only use an `SslStream`, this is useful + /// when you want to be sure to connect over SSL or when you want access + /// to the `SslStream` functions (without having to go through a `Box`). #[cfg(feature="ssl")] pub fn connect_secure( &mut self, @@ -212,6 +441,36 @@ impl<'u> ClientBuilder<'u> { self.connect_on(ssl_stream) } + // TODO: similar ability for server? + /// Connects to a websocket server on any stream you would like. + /// Possible streams: + /// - Unix Sockets + /// - Logging Middle-ware + /// - SSH + /// + /// ```rust + /// # use websocket::ClientBuilder; + /// use websocket::stream::ReadWritePair; + /// use std::io::Cursor; + /// + /// let accept = b"HTTP/1.1 101 Switching Protocols\r + /// Upgrade: websocket\r + /// Connection: Upgrade\r + /// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r + /// \r\n"; + /// + /// let input = Cursor::new(&accept[..]); + /// let output = Cursor::new(Vec::new()); + /// + /// let client = ClientBuilder::new("wss://test.ws").unwrap() + /// .key(b"the sample nonce".clone()) + /// .connect_on(ReadWritePair(input, output)) + /// .unwrap(); + /// + /// let text = (client.into_stream().0).1.into_inner(); + /// let text = String::from_utf8(text).unwrap(); + /// assert!(text.contains("dGhlIHNhbXBsZSBub25jZQ=="), "{}", text); + /// ``` pub fn connect_on(&mut self, mut stream: S) -> WebSocketResult> where S: Stream { @@ -289,3 +548,31 @@ impl<'u> ClientBuilder<'u> { Ok(Client::unchecked(reader, response.headers)) } } + +mod tests { + #[test] + fn build_client_with_protocols() { + use super::*; + let builder = ClientBuilder::new("ws://127.0.0.1:8080/hello/world") + .unwrap() + .add_protocol("protobeard"); + + let protos = &builder.headers.get::().unwrap().0; + assert!(protos.contains(&"protobeard".to_string())); + assert!(protos.len() == 1); + + let builder = ClientBuilder::new("ws://example.org/hello") + .unwrap() + .add_protocol("rust-websocket") + .clear_protocols() + .add_protocols(vec!["electric", "boogaloo"]); + + let protos = &builder.headers.get::().unwrap().0; + + assert!(protos.contains(&"boogaloo".to_string())); + assert!(protos.contains(&"electric".to_string())); + assert!(!protos.contains(&"rust-websocket".to_string())); + } + + // TODO: a few more +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 4210515d11..2e87651c11 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -15,7 +15,7 @@ use ws::receiver::Receiver as ReceiverTrait; use result::WebSocketResult; use stream::{AsTcpStream, Stream, Splittable, Shutdown}; use dataframe::DataFrame; -use header::{WebSocketProtocol, WebSocketExtensions, Origin}; +use header::{WebSocketProtocol, WebSocketExtensions}; use header::extensions::Extension; use ws::dataframe::DataFrame as DataFrameable; @@ -29,18 +29,17 @@ pub use self::builder::{ClientBuilder, Url, ParseError}; /// Represents a WebSocket client, which can send and receive messages/data frames. /// -/// `D` is the data frame type, `S` is the type implementing `Sender` and `R` -/// is the type implementing `Receiver`. +/// The client just wraps around a `Stream` (which is something that can be read from +/// and written to) and handles the websocket protocol. TCP or SSL over TCP is common, +/// but any stream can be used. /// -/// For most cases, the data frame type will be `dataframe::DataFrame`, the Sender -/// type will be `client::Sender` and the receiver type -/// will be `client::Receiver`. -/// -/// A `Client` can be split into a `Sender` and a `Receiver` which can then be moved +/// A `Client` can also be split into a `Reader` and a `Writer` which can then be moved /// to different threads, often using a send loop and receiver loop concurrently, /// as shown in the client example in `examples/client.rs`. +/// This is only possible for streams that implement the `Splittable` trait, which +/// currently is only TCP streams. (it is unsafe to duplicate an SSL stream) /// -///#Connecting to a Server +///# Connecting to a Server /// ///```no_run ///extern crate websocket; @@ -48,8 +47,10 @@ pub use self::builder::{ClientBuilder, Url, ParseError}; /// ///use websocket::{ClientBuilder, Message}; /// -///let mut client = ClientBuilder::new("ws://127.0.0.1:1234").unwrap() -/// .connect(None).unwrap(); +///let mut client = ClientBuilder::new("ws://127.0.0.1:1234") +/// .unwrap() +/// .connect_insecure() +/// .unwrap(); /// ///let message = Message::text("Hello, World!"); ///client.send_message(&message).unwrap(); // Send message @@ -87,17 +88,20 @@ impl Client self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) } - /// See `TcpStream.peer_addr()`. + /// See [`TcpStream::peer_addr`] + /// (https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.peer_addr). pub fn peer_addr(&self) -> IoResult { self.stream.get_ref().as_tcp().peer_addr() } - /// See `TcpStream.local_addr()`. + /// See [`TcpStream::local_addr`] + /// (https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.local_addr). pub fn local_addr(&self) -> IoResult { self.stream.get_ref().as_tcp().local_addr() } - /// See `TcpStream.set_nodelay()`. + /// See [`TcpStream::set_nodelay`] + /// (https://doc.rust-lang.org/std/net/struct.TcpStream.html#method.set_nodelay). pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { self.stream.get_ref().as_tcp().set_nodelay(nodelay) } @@ -115,6 +119,7 @@ impl Client /// **without sending any handshake** this is meant to only be used with /// a stream that has a websocket connection already set up. /// If in doubt, don't use this! + #[doc(hidden)] pub fn unchecked(stream: BufReader, headers: Headers) -> Self { Client { headers: headers, @@ -159,10 +164,27 @@ impl Client self.receiver.recv_message(&mut self.stream) } + /// Access the headers that were sent in the server's handshake response. + /// This is a catch all for headers other than protocols and extensions. pub fn headers(&self) -> &Headers { &self.headers } + /// **If you supplied a protocol, you must check that it was accepted by + /// the server** using this function. + /// This is not done automatically because the terms of accepting a protocol + /// can get complicated, especially if some protocols depend on others, etc. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// let mut client = ClientBuilder::new("wss://test.fysh.in").unwrap() + /// .add_protocol("xmpp") + /// .connect_insecure() + /// .unwrap(); + /// + /// // be sure to check the protocol is there! + /// assert!(client.protocols().iter().any(|p| p as &str == "xmpp")); + /// ``` pub fn protocols(&self) -> &[String] { self.headers .get::() @@ -170,6 +192,9 @@ impl Client .unwrap_or(&[]) } + /// If you supplied a protocol, be sure to check if it was accepted by the + /// server here. Since no extensions are implemented out of the box yet, using + /// one will require its own implementation. pub fn extensions(&self) -> &[Extension] { self.headers .get::() @@ -177,22 +202,84 @@ impl Client .unwrap_or(&[]) } - pub fn origin(&self) -> Option<&str> { - self.headers.get::().map(|o| &o.0 as &str) - } - + /// Get a reference to the stream. + /// Useful to be able to set options on the stream. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// let mut client = ClientBuilder::new("ws://double.down").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// client.stream_ref().set_ttl(60).unwrap(); + /// ``` pub fn stream_ref(&self) -> &S { self.stream.get_ref() } + /// Get a handle to the writable portion of this stream. + /// This can be used to write custom extensions. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// use websocket::Message; + /// use websocket::ws::sender::Sender as SenderTrait; + /// use websocket::sender::Sender; + /// + /// let mut client = ClientBuilder::new("ws://the.room").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// let message = Message::text("Oh hi, Mark."); + /// let mut sender = Sender::new(true); + /// let mut buf = Vec::new(); + /// + /// sender.send_message(&mut buf, &message); + /// + /// /* transform buf somehow */ + /// + /// client.writer_mut().write_all(&buf); + /// ``` pub fn writer_mut(&mut self) -> &mut Write { self.stream.get_mut() } + /// Get a handle to the readable portion of this stream. + /// This can be used to transform raw bytes before they + /// are read in. + /// + /// ```rust,no_run + /// # use websocket::ClientBuilder; + /// use std::io::Cursor; + /// use websocket::Message; + /// use websocket::ws::receiver::Receiver as ReceiverTrait; + /// use websocket::receiver::Receiver; + /// + /// let mut client = ClientBuilder::new("ws://the.room").unwrap() + /// .connect_insecure() + /// .unwrap(); + /// + /// let mut receiver = Receiver::new(false); + /// let mut buf = Vec::new(); + /// + /// client.reader_mut().read_to_end(&mut buf); + /// + /// /* transform buf somehow */ + /// + /// let mut buf_reader = Cursor::new(&mut buf); + /// let message: Message = receiver.recv_message(&mut buf_reader).unwrap(); + /// ``` pub fn reader_mut(&mut self) -> &mut Read { &mut self.stream } + /// Deconstruct the client into its underlying stream and + /// maybe some of the buffer that was already read from the stream. + /// The client uses a buffered reader to read in messages, so some + /// bytes might already be read from the stream when this is called, + /// these buffered bytes are returned in the form + /// + /// `(byte_buffer: Vec, buffer_capacity: usize, buffer_position: usize)` pub fn into_stream(self) -> (S, Option<(Vec, usize, usize)>) { let (stream, buf, pos, cap) = self.stream.into_parts(); (stream, Some((buf, pos, cap))) diff --git a/src/header/extensions.rs b/src/header/extensions.rs index d1c2f20bf6..efbb7537f8 100644 --- a/src/header/extensions.rs +++ b/src/header/extensions.rs @@ -10,6 +10,8 @@ use result::{WebSocketResult, WebSocketError}; const INVALID_EXTENSION: &'static str = "Invalid Sec-WebSocket-Extensions extension name"; +// TODO: check if extension name is valid according to spec + /// Represents a Sec-WebSocket-Extensions header #[derive(PartialEq, Clone, Debug)] pub struct WebSocketExtensions(pub Vec); diff --git a/src/header/mod.rs b/src/header/mod.rs index 01a4aa79a4..48a73bfbcc 100644 --- a/src/header/mod.rs +++ b/src/header/mod.rs @@ -9,7 +9,7 @@ pub use self::protocol::WebSocketProtocol; pub use self::version::WebSocketVersion; pub use self::extensions::WebSocketExtensions; pub use self::origin::Origin; -pub use hyper::header::Headers; +pub use hyper::header::*; mod accept; mod key; diff --git a/src/header/protocol.rs b/src/header/protocol.rs index 51773cda8f..9a6e8f1cf0 100644 --- a/src/header/protocol.rs +++ b/src/header/protocol.rs @@ -4,6 +4,8 @@ use hyper; use std::fmt; use std::ops::Deref; +// TODO: only allow valid protocol names to be added + /// Represents a Sec-WebSocket-Protocol header #[derive(PartialEq, Clone, Debug)] pub struct WebSocketProtocol(pub Vec); diff --git a/src/lib.rs b/src/lib.rs index 92d74ba93b..1aa3202534 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -38,7 +38,7 @@ //! level. Their usage is explained in the module documentation. extern crate hyper; extern crate unicase; -extern crate url; +pub extern crate url; extern crate rustc_serialize as serialize; extern crate rand; extern crate byteorder; diff --git a/src/receiver.rs b/src/receiver.rs index 0f1bd25e54..ea3688094f 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -14,10 +14,14 @@ use ws::receiver::{MessageIterator, DataFrameIterator}; use stream::{AsTcpStream, Stream}; pub use stream::Shutdown; +/// This reader bundles an existing stream with a parsing algorithm. +/// It is used by the client in its `.split()` function as the reading component. pub struct Reader where R: Read { + /// the stream to be read from pub stream: BufReader, + /// the parser to parse bytes into messages pub receiver: Receiver, } @@ -42,6 +46,8 @@ impl Reader self.receiver.recv_message(&mut self.stream) } + /// An iterator over incoming messsages. + /// This iterator will block until new messages arrive and will never halt. pub fn incoming_messages<'a, M, D>(&'a mut self,) -> MessageIterator<'a, Receiver, D, M, BufReader> where M: ws::Message<'a, D>, diff --git a/src/sender.rs b/src/sender.rs index e8860cdbef..398520f18d 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -9,8 +9,14 @@ use ws; use ws::sender::Sender as SenderTrait; pub use stream::Shutdown; +/// A writer that bundles a stream with a serializer to send the messages. +/// This is used in the client's `.split()` function as the writing component. +/// +/// It can also be useful to use a websocket connection without a handshake. pub struct Writer { + /// The stream that websocket messages will be written to pub stream: W, + /// The serializer that will be used to serialize the messages pub sender: Sender, } diff --git a/src/server/mod.rs b/src/server/mod.rs index dc3cbd283d..44a5497101 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -10,15 +10,31 @@ pub use self::upgrade::{Request, HyperIntoWsError}; pub mod upgrade; +/// When a sever tries to accept a connection many things can go wrong. +/// +/// This struct is all the information that is recovered from a failed +/// websocket handshake, in case one wants to use the connection for something +/// else (such as HTTP). pub struct InvalidConnection where S: Stream { + /// if the stream was successfully setup it will be included here + /// on a failed connection. pub stream: Option, + /// the parsed request. **This is a normal HTTP request** meaning you can + /// simply run this server and handle both HTTP and Websocket connections. + /// If you already have a server you want to use, checkout the + /// `server::upgrade` module to integrate this crate with your server. pub parsed: Option, + /// the buffered data that was already taken from the stream pub buffer: Option, + /// the cause of the failed websocket connection setup pub error: HyperIntoWsError, } +/// Either the stream was established and it sent a websocket handshake +/// which represents the `Ok` variant, or there was an error (this is the +/// `Err` variant). pub type AcceptResult = Result, InvalidConnection>; /// Marker struct for a struct not being secure @@ -39,7 +55,7 @@ impl OptionalSslAcceptor for SslAcceptor {} /// This is a convenient way to implement WebSocket servers, however /// it is possible to use any sendable Reader and Writer to obtain /// a WebSocketClient, so if needed, an alternative server implementation can be used. -///#Non-secure Servers +///# Non-secure Servers /// /// ```no_run ///extern crate websocket; @@ -63,7 +79,7 @@ impl OptionalSslAcceptor for SslAcceptor {} /// # } /// ``` /// -///#Secure Servers +///# Secure Servers /// ```no_run ///extern crate websocket; ///extern crate openssl; @@ -106,10 +122,21 @@ impl OptionalSslAcceptor for SslAcceptor {} ///} /// # } /// ``` +/// +/// # A Hyper Server +/// This crates comes with hyper integration out of the box, you can create a hyper +/// server and serve websocket and HTTP **on the same port!** +/// check out the docs over at `websocket::server::upgrade::from_hyper` for an example. +/// +/// # A Custom Server +/// So you don't want to use any of our server implementations? That's O.K. +/// All it takes is implementing the `IntoWs` trait for your server's streams, +/// then calling `.into_ws()` on them. +/// check out the docs over at `websocket::server::upgrade` for more. pub struct Server where S: OptionalSslAcceptor { - pub listener: TcpListener, + listener: TcpListener, ssl_acceptor: S, } diff --git a/src/server/upgrade/from_hyper.rs b/src/server/upgrade/from_hyper.rs new file mode 100644 index 0000000000..da6ae1e134 --- /dev/null +++ b/src/server/upgrade/from_hyper.rs @@ -0,0 +1,88 @@ +//! Upgrade a hyper connection to a websocket one. +//! +//! Using this method, one can start a hyper server and check if each request +//! is a websocket upgrade request, if so you can use websockets and hyper on the +//! same port! +//! +//! ```rust,no_run +//! # extern crate hyper; +//! # extern crate websocket; +//! # fn main() { +//! use hyper::server::{Server, Request, Response}; +//! use websocket::Message; +//! use websocket::server::upgrade::IntoWs; +//! use websocket::server::upgrade::from_hyper::HyperRequest; +//! +//! Server::http("0.0.0.0:80").unwrap().handle(move |req: Request, res: Response| { +//! match HyperRequest(req).into_ws() { +//! Ok(upgrade) => { +//! // `accept` sends a successful handshake, no need to worry about res +//! let mut client = match upgrade.accept() { +//! Ok(c) => c, +//! Err(_) => panic!(), +//! }; +//! +//! client.send_message(&Message::text("its free real estate")); +//! }, +//! +//! Err((request, err)) => { +//! // continue using the request as normal, "echo uri" +//! res.send(b"Try connecting over ws instead.").unwrap(); +//! }, +//! }; +//! }) +//! .unwrap(); +//! # } +//! ``` + +use hyper::net::NetworkStream; +use super::{IntoWs, WsUpgrade, Buffer}; + +pub use hyper::http::h1::Incoming; +pub use hyper::method::Method; +pub use hyper::version::HttpVersion; +pub use hyper::uri::RequestUri; +pub use hyper::buffer::BufReader; +use hyper::server::Request; +pub use hyper::header::{Headers, Upgrade, ProtocolName, Connection, ConnectionOption}; + +use super::validate; +use super::HyperIntoWsError; + +/// A hyper request is implicitly defined as a stream from other `impl`s of Stream. +/// Until trait impl specialization comes along, we use this struct to differentiate +/// a hyper request (which already has parsed headers) from a normal stream. +pub struct HyperRequest<'a, 'b: 'a>(pub Request<'a, 'b>); + +impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { + type Stream = &'a mut &'b mut NetworkStream; + type Error = (Request<'a, 'b>, HyperIntoWsError); + + fn into_ws(self) -> Result, Self::Error> { + if let Err(e) = validate(&self.0.method, &self.0.version, &self.0.headers) { + return Err((self.0, e)); + } + + let (_, method, headers, uri, version, reader) = + self.0.deconstruct(); + + let reader = reader.into_inner(); + let (buf, pos, cap) = reader.take_buf(); + let stream = reader.get_mut(); + + Ok(WsUpgrade { + headers: Headers::new(), + stream: stream, + buffer: Some(Buffer { + buf: buf, + pos: pos, + cap: cap, + }), + request: Incoming { + version: version, + headers: headers, + subject: (method, uri), + }, + }) + } +} diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs deleted file mode 100644 index 1e44a6e1a2..0000000000 --- a/src/server/upgrade/hyper.rs +++ /dev/null @@ -1,50 +0,0 @@ -extern crate hyper; - -use hyper::net::NetworkStream; -use super::{IntoWs, WsUpgrade, Buffer}; - -pub use hyper::http::h1::Incoming; -pub use hyper::method::Method; -pub use hyper::version::HttpVersion; -pub use hyper::uri::RequestUri; -pub use hyper::buffer::BufReader; -use hyper::server::Request; -pub use hyper::header::{Headers, Upgrade, ProtocolName, Connection, ConnectionOption}; - -use super::validate; -use super::HyperIntoWsError; - -pub struct HyperRequest<'a, 'b: 'a>(pub Request<'a, 'b>); - -impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { - type Stream = &'a mut &'b mut NetworkStream; - type Error = (Request<'a, 'b>, HyperIntoWsError); - - fn into_ws(self) -> Result, Self::Error> { - if let Err(e) = validate(&self.0.method, &self.0.version, &self.0.headers) { - return Err((self.0, e)); - } - - let (_, method, headers, uri, version, reader) = - self.0.deconstruct(); - - let reader = reader.into_inner(); - let (buf, pos, cap) = reader.take_buf(); - let stream = reader.get_mut(); - - Ok(WsUpgrade { - headers: Headers::new(), - stream: stream, - buffer: Some(Buffer { - buf: buf, - pos: pos, - cap: cap, - }), - request: Incoming { - version: version, - headers: headers, - subject: (method, uri), - }, - }) - } -} diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 4aeb9ba0ab..51fa4332fe 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -1,7 +1,5 @@ //! Allows you to take an existing request or stream of data and convert it into a //! WebSocket client. -extern crate hyper as real_hyper; - use std::error::Error; use std::net::TcpStream; use std::io; @@ -15,21 +13,32 @@ use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, use client::Client; use unicase::UniCase; -use self::real_hyper::status::StatusCode; -pub use self::real_hyper::http::h1::Incoming; -pub use self::real_hyper::method::Method; -pub use self::real_hyper::version::HttpVersion; -pub use self::real_hyper::uri::RequestUri; -pub use self::real_hyper::buffer::BufReader; -pub use self::real_hyper::http::h1::parse_request; -pub use self::real_hyper::header::{Headers, Upgrade, Protocol, ProtocolName, Connection, - ConnectionOption}; - -pub mod hyper; - +use hyper::status::StatusCode; +use hyper::http::h1::Incoming; +use hyper::method::Method; +use hyper::version::HttpVersion; +use hyper::uri::RequestUri; +use hyper::buffer::BufReader; +use hyper::http::h1::parse_request; +use hyper::header::{Headers, Upgrade, Protocol, ProtocolName, Connection, ConnectionOption}; + +pub mod from_hyper; + +/// This crate uses buffered readers to read in the handshake quickly, in order to +/// interface with other use cases that don't use buffered readers the buffered readers +/// is deconstructed when it is returned to the user and given as the underlying +/// reader and the buffer. +/// +/// This struct represents bytes that have already been read in from the stream. +/// A slice of valid data in this buffer can be obtained by: `&buf[pos..cap]`. pub struct Buffer { + /// the contents of the buffered stream data pub buf: Vec, + /// the current position of cursor in the buffer + /// Any data before `pos` has already been read and parsed. pub pos: usize, + /// the last location of valid data + /// Any data after `cap` is not valid. pub cap: usize, } @@ -37,20 +46,25 @@ pub struct Buffer { /// Should be used to examine the client's handshake /// accept the protocols requested, route the path, etc. /// -/// Users should then call `accept` or `deny` to complete the handshake +/// Users should then call `accept` or `reject` to complete the handshake /// and start a session. pub struct WsUpgrade where S: Stream { + /// The headers that will be used in the handshake response. pub headers: Headers, + /// The stream that will be used to read from / write to. pub stream: S, + /// The handshake request, filled with useful metadata. pub request: Request, + /// Some buffered data from the stream, if it exists. pub buffer: Option, } impl WsUpgrade where S: Stream { + /// Select a protocol to use in the handshake response. pub fn use_protocol

(mut self, protocol: P) -> Self where P: Into { @@ -61,6 +75,7 @@ impl WsUpgrade self } + /// Select an extension to use in the handshake response. pub fn use_extension(mut self, extension: Extension) -> Self { upsert_header!(self.headers; WebSocketExtensions; { Some(protos) => protos.0.push(extension), @@ -69,6 +84,7 @@ impl WsUpgrade self } + /// Select multiple extensions to use in the connection pub fn use_extensions(mut self, extensions: I) -> Self where I: IntoIterator { @@ -81,10 +97,15 @@ impl WsUpgrade self } + /// Accept the handshake request and send a response, + /// if nothing goes wrong a client will be created. pub fn accept(self) -> Result, (S, IoError)> { self.accept_with(&Headers::new()) } + /// Accept the handshake request and send a response while + /// adding on a few headers. These headers are added before the required + /// headers are, so some might be overwritten. pub fn accept_with(mut self, custom_headers: &Headers) -> Result, (S, IoError)> { self.headers.extend(custom_headers.iter()); self.headers @@ -109,10 +130,12 @@ impl WsUpgrade Ok(Client::unchecked(stream, self.headers)) } + /// Reject the client's request to make a websocket connection. pub fn reject(self) -> Result { self.reject_with(&Headers::new()) } - + /// Reject the client's request to make a websocket connection + /// and send extra headers. pub fn reject_with(mut self, headers: &Headers) -> Result { self.headers.extend(headers.iter()); match self.send(StatusCode::BadRequest) { @@ -121,10 +144,12 @@ impl WsUpgrade } } + /// Drop the connection without saying anything. pub fn drop(self) { ::std::mem::drop(self); } + /// A list of protocols requested from the client. pub fn protocols(&self) -> &[String] { self.request .headers @@ -133,6 +158,7 @@ impl WsUpgrade .unwrap_or(&[]) } + /// A list of extensions requested from the client. pub fn extensions(&self) -> &[Extension] { self.request .headers @@ -141,22 +167,21 @@ impl WsUpgrade .unwrap_or(&[]) } + /// The client's websocket accept key. pub fn key(&self) -> Option<&[u8; 16]> { self.request.headers.get::().map(|k| &k.0) } + /// The client's websocket version. pub fn version(&self) -> Option<&WebSocketVersion> { self.request.headers.get::() } + /// Origin of the client pub fn origin(&self) -> Option<&str> { self.request.headers.get::().map(|o| &o.0 as &str) } - pub fn into_stream(self) -> S { - self.stream - } - fn send(&mut self, status: StatusCode) -> IoResult<()> { try!(write!(&mut self.stream, "{} {}\r\n", self.request.version, status)); try!(write!(&mut self.stream, "{}\r\n", self.headers)); @@ -167,6 +192,8 @@ impl WsUpgrade impl WsUpgrade where S: Stream + AsTcpStream { + /// Get a handle to the underlying TCP stream, useful to be able to set + /// TCP options, etc. pub fn tcp_stream(&self) -> &TcpStream { self.stream.as_tcp() } @@ -181,8 +208,36 @@ impl WsUpgrade /// Otherwise the original stream is returned along with an error. /// /// Note: the stream is owned because the websocket client expects to own its stream. +/// +/// This is already implemented for all Streams, which means all types with Read + Write. +/// +/// # Example +/// +/// ```rust,no_run +/// use std::net::TcpListener; +/// use std::net::TcpStream; +/// use websocket::server::upgrade::IntoWs; +/// use websocket::Client; +/// +/// let listener = TcpListener::bind("127.0.0.1:80").unwrap(); +/// +/// for stream in listener.incoming().filter_map(Result::ok) { +/// let mut client: Client = match stream.into_ws() { +/// Ok(upgrade) => { +/// match upgrade.accept() { +/// Ok(client) => client, +/// Err(_) => panic!(), +/// } +/// }, +/// Err(_) => panic!(), +/// }; +/// } +/// ``` pub trait IntoWs { + /// The type of stream this upgrade process is working with (TcpStream, etc.) type Stream: Stream; + /// An error value in case the stream is not asking for a websocket connection + /// or something went wrong. It is common to also include the stream here. type Error; /// Attempt to parse the start of a Websocket handshake, later with the returned /// `WsUpgrade` struct, call `accept to start a websocket client, and `reject` to @@ -191,7 +246,11 @@ pub trait IntoWs { } +/// A typical request from hyper pub type Request = Incoming<(Method, RequestUri)>; +/// If you have your requests separate from your stream you can use this struct +/// to upgrade the connection based on the request given +/// (the request should be a handshake). pub struct RequestStreamPair(pub S, pub Request); impl IntoWs for S @@ -251,21 +310,30 @@ impl IntoWs for RequestStreamPair } } +/// Errors that can occur when one tries to upgrade a connection to a +/// websocket connection. #[derive(Debug)] pub enum HyperIntoWsError { + /// The HTTP method in a valid websocket upgrade request must be GET MethodNotGet, + /// Currently HTTP 2 is not supported UnsupportedHttpVersion, + /// Currently only WebSocket13 is supported (RFC6455) UnsupportedWebsocketVersion, + /// A websocket upgrade request must contain a key NoSecWsKeyHeader, + /// A websocket upgrade request must ask to upgrade to a `websocket` NoWsUpgradeHeader, + /// A websocket upgrade request must contain an `Upgrade` header NoUpgradeHeader, + /// A websocket upgrade request's `Connection` header must be `Upgrade` NoWsConnectionHeader, + /// A websocket upgrade request must contain a `Connection` header NoConnectionHeader, - UnknownNetworkStream, /// IO error from reading the underlying socket Io(io::Error), /// Error while parsing an incoming request - Parsing(self::real_hyper::error::Error), + Parsing(::hyper::error::Error), } impl Display for HyperIntoWsError { @@ -286,7 +354,6 @@ impl Error for HyperIntoWsError { &NoUpgradeHeader => "Missing Upgrade WebSocket header", &NoWsConnectionHeader => "Invalid Connection WebSocket header", &NoConnectionHeader => "Missing Connection WebSocket header", - &UnknownNetworkStream => "Cannot downcast to known impl of NetworkStream", &Io(ref e) => e.description(), &Parsing(ref e) => e.description(), } @@ -307,13 +374,13 @@ impl From for HyperIntoWsError { } } -impl From for HyperIntoWsError { - fn from(err: real_hyper::error::Error) -> Self { +impl From<::hyper::error::Error> for HyperIntoWsError { + fn from(err: ::hyper::error::Error) -> Self { HyperIntoWsError::Parsing(err) } } -pub fn validate( +fn validate( method: &Method, version: &HttpVersion, headers: &Headers, diff --git a/src/stream.rs b/src/stream.rs index 67cca057e5..c6901202d7 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -15,14 +15,25 @@ pub trait Stream: Read + Write {} impl Stream for S where S: Read + Write {} +/// a `Stream` that can also be used as a borrow to a `TcpStream` +/// this is useful when you want to set `TcpStream` options on a +/// `Stream` like `nonblocking`. pub trait NetworkStream: Read + Write + AsTcpStream {} impl NetworkStream for S where S: Read + Write + AsTcpStream {} +/// some streams can be split up into separate reading and writing components +/// `TcpStream` is an example. This trait marks this ability so one can split +/// up the client into two parts. +/// +/// Notice however that this is not possible to do with SSL. pub trait Splittable { + /// The reading component of this type type Reader: Read; + /// The writing component of this type type Writer: Write; + /// Split apart this type into a reading and writing component. fn split(self) -> io::Result<(Self::Reader, Self::Writer)>; } @@ -47,7 +58,10 @@ impl Splittable for TcpStream { } } +/// The ability access a borrow to an underlying TcpStream, +/// so one can set options on the stream such as `nonblocking`. pub trait AsTcpStream { + /// Get a borrow of the TcpStream fn as_tcp(&self) -> &TcpStream; } @@ -72,6 +86,10 @@ impl AsTcpStream for Box } } +/// If you would like to combine an input stream and an output stream into a single +/// stream to talk websockets over then this is the struct for you! +/// +/// This is useful if you want to use different mediums for different directions. pub struct ReadWritePair(pub R, pub W) where R: Read, W: Write; diff --git a/src/ws/receiver.rs b/src/ws/receiver.rs index 9c218eea40..3e237166da 100644 --- a/src/ws/receiver.rs +++ b/src/ws/receiver.rs @@ -11,6 +11,7 @@ use result::WebSocketResult; /// A trait for receiving data frames and messages. pub trait Receiver: Sized { + /// The type of dataframe that incoming messages will be serialized to. type F: DataFrame; /// Reads a single data frame from this receiver.