Skip to content

Commit

Permalink
Higher-level API for memory limits
Browse files Browse the repository at this point in the history
  • Loading branch information
vi committed Jul 24, 2022
1 parent 3bde0f4 commit 74c82be
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 7 deletions.
38 changes: 35 additions & 3 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ use std::borrow::Cow;
use std::convert::Into;
pub use url::{ParseError, Url};

const DEFAULT_MAX_DATAFRAME_SIZE : usize = 1024*1024*100;
const DEFAULT_MAX_MESSAGE_SIZE : usize = 1024*1024*200;

#[cfg(any(feature = "sync", feature = "async"))]
mod common_imports {
pub use crate::header::WebSocketAccept;
Expand Down Expand Up @@ -114,6 +117,8 @@ pub struct ClientBuilder<'u> {
headers: Headers,
version_set: bool,
key_set: bool,
max_dataframe_size: usize,
max_message_size: usize,
}

impl<'u> ClientBuilder<'u> {
Expand Down Expand Up @@ -161,6 +166,8 @@ impl<'u> ClientBuilder<'u> {
version_set: false,
key_set: false,
headers: Headers::new(),
max_dataframe_size: DEFAULT_MAX_DATAFRAME_SIZE,
max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
}
}

Expand Down Expand Up @@ -289,6 +296,21 @@ impl<'u> ClientBuilder<'u> {
self
}

/// Set maximum dataframe size. Client will abort connection with error if it is exceed.
/// Values larger than `u32::MAX` are not supported.
pub fn max_dataframe_size(mut self, value: usize) -> Self {
self.max_dataframe_size = value;
self
}

/// Set maximum message size for which no more continuation dataframes are accepted.
/// Client will abort connection with error if it is exceed.
/// Values larger than `u32::MAX` are not supported.
pub fn max_message_size(mut self, value: usize) -> Self {
self.max_message_size = value;
self
}

/// Add a custom `Sec-WebSocket-Key` header.
/// Use this only if you know what you're doing, and this almost
/// never has to be used.
Expand Down Expand Up @@ -493,7 +515,7 @@ impl<'u> ClientBuilder<'u> {
// validate
self.validate(&response)?;

Ok(Client::unchecked(reader, response.headers, true, false))
Ok(Client::unchecked_with_limits(reader, response.headers, true, false, self.max_dataframe_size, self.max_dataframe_size))
}

/// Connect to a websocket server asynchronously.
Expand Down Expand Up @@ -553,6 +575,8 @@ impl<'u> ClientBuilder<'u> {
headers: self.headers,
version_set: self.version_set,
key_set: self.key_set,
max_dataframe_size: self.max_dataframe_size,
max_message_size: self.max_message_size,
};

// check if we should connect over ssl or not
Expand Down Expand Up @@ -637,6 +661,8 @@ impl<'u> ClientBuilder<'u> {
headers: self.headers,
version_set: self.version_set,
key_set: self.key_set,
max_dataframe_size: self.max_dataframe_size,
max_message_size: self.max_message_size,
};

// put it all together
Expand Down Expand Up @@ -687,6 +713,8 @@ impl<'u> ClientBuilder<'u> {
headers: self.headers,
version_set: self.version_set,
key_set: self.key_set,
max_dataframe_size: self.max_dataframe_size,
max_message_size: self.max_message_size,
};

let future = tcp_stream.and_then(move |stream| builder.async_connect_on(stream));
Expand Down Expand Up @@ -746,6 +774,8 @@ impl<'u> ClientBuilder<'u> {
headers: self.headers,
version_set: self.version_set,
key_set: self.key_set,
max_dataframe_size: self.max_dataframe_size,
max_message_size: self.max_message_size,
};
let resource = builder.build_request();
let framed = crate::codec::http::HttpClientCodec.framed(stream);
Expand All @@ -755,6 +785,8 @@ impl<'u> ClientBuilder<'u> {
subject: (Method::Get, RequestUri::AbsolutePath(resource)),
};

let max_dataframe_size = self.max_dataframe_size;
let max_message_size = self.max_message_size;
let future = framed
// send request
.send(request)
Expand All @@ -770,8 +802,8 @@ impl<'u> ClientBuilder<'u> {
.and_then(|message| builder.validate(&message).map(|()| (message, stream)))
})
// output the final client and metadata
.map(|(message, stream)| {
let codec = MessageCodec::default(Context::Client);
.map(move |(message, stream)| {
let codec = MessageCodec::new_with_limits(Context::Client, max_dataframe_size, max_message_size);
let client = update_framed_codec(stream, codec);
(client, message.headers)
});
Expand Down
17 changes: 17 additions & 0 deletions src/client/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,23 @@ where
}
}

#[doc(hidden)]
pub fn unchecked_with_limits(
stream: BufReader<S>,
headers: Headers,
out_mask: bool,
in_mask: bool,
max_dataframe_size: usize,
max_message_size: usize,
) -> Self {
Client {
headers,
stream,
sender: Sender::new(out_mask), // true
receiver: Receiver::new_with_limits(in_mask, max_dataframe_size, max_message_size), // false
}
}

/// Sends a single data frame to the remote endpoint.
pub fn send_dataframe<D>(&mut self, dataframe: &D) -> WebSocketResult<()>
where
Expand Down
22 changes: 20 additions & 2 deletions src/server/upgrade/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ use hyper::status::StatusCode;
use std::io::{self, ErrorKind};
use tokio_codec::{Decoder, Framed, FramedParts};

const DEFAULT_MAX_DATAFRAME_SIZE : usize = 1024*1024*100;
const DEFAULT_MAX_MESSAGE_SIZE : usize = 1024*1024*200;

/// An asynchronous websocket upgrade.
///
/// This struct is given when a connection is being upgraded to a websocket
Expand Down Expand Up @@ -84,7 +87,22 @@ where
self.internal_accept(Some(custom_headers))
}

fn internal_accept(mut self, custom_headers: Option<&Headers>) -> ClientNew<S> {
/// Like `accept`, but also allows to set memory limits for incoming messages and dataframes
pub fn accept_with_limits(self, max_dataframe_size: usize, max_message_size: usize) -> ClientNew<S> {
self.internal_accept_with_limits(None, max_dataframe_size, max_message_size)
}

/// Like `accept_with`, but also allows to set memory limits for incoming messages and dataframes
pub fn accept_with_headers_and_limits(self, custom_headers: &Headers, max_dataframe_size: usize, max_message_size: usize) -> ClientNew<S> {
self.internal_accept_with_limits(Some(custom_headers), max_dataframe_size, max_message_size)
}


fn internal_accept(self, custom_headers: Option<&Headers>) -> ClientNew<S> {
self.internal_accept_with_limits(custom_headers, DEFAULT_MAX_DATAFRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE)
}

fn internal_accept_with_limits(mut self, custom_headers: Option<&Headers>, max_dataframe_size: usize, max_message_size: usize) -> ClientNew<S> {
let status = self.prepare_headers(custom_headers);
let WsUpgrade {
headers,
Expand All @@ -104,7 +122,7 @@ where
headers: headers.clone(),
})
.map(move |s| {
let codec = MessageCodec::default(Context::Server);
let codec = MessageCodec::new_with_limits(Context::Server, max_dataframe_size, max_message_size);
let client = update_framed_codec(s, codec);
(client, headers)
})
Expand Down
24 changes: 22 additions & 2 deletions src/server/upgrade/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ use hyper::http::h1::Incoming;
use hyper::net::NetworkStream;
use hyper::status::StatusCode;

const DEFAULT_MAX_DATAFRAME_SIZE : usize = 1024*1024*100;
const DEFAULT_MAX_MESSAGE_SIZE : usize = 1024*1024*200;

/// This crate uses buffered readers to read in the handshake quickly, in order to
/// interface with other use cases that don't use buffered readers the buffered readers
/// is deconstructed when it is returned to the user and given as the underlying
Expand Down Expand Up @@ -61,7 +64,24 @@ where
self.internal_accept(Some(custom_headers))
}

fn internal_accept(mut self, headers: Option<&Headers>) -> Result<Client<S>, (S, io::Error)> {
/// Accept the handshake request and send a response,
/// if nothing goes wrong a client will be created.
pub fn accept_with_limits(self, max_dataframe_size: usize, max_message_size: usize) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept_with_limits(None, max_dataframe_size, max_message_size)
}

/// Accept the handshake request and send a response while
/// adding on a few headers. These headers are added before the required
/// headers are, so some might be overwritten.
pub fn accept_with_headers_and_limits(self, custom_headers: &Headers, max_dataframe_size: usize, max_message_size: usize) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept_with_limits(Some(custom_headers), max_dataframe_size, max_message_size)
}

fn internal_accept(self, headers: Option<&Headers>) -> Result<Client<S>, (S, io::Error)> {
self.internal_accept_with_limits(headers, DEFAULT_MAX_DATAFRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE)
}

fn internal_accept_with_limits(mut self, headers: Option<&Headers>, max_dataframe_size: usize, max_message_size: usize) -> Result<Client<S>, (S, io::Error)> {
let status = self.prepare_headers(headers);

if let Err(e) = self.send(status) {
Expand All @@ -73,7 +93,7 @@ where
None => BufReader::new(self.stream),
};

Ok(Client::unchecked(stream, self.headers, false, true))
Ok(Client::unchecked_with_limits(stream, self.headers, false, true, max_dataframe_size, max_message_size))
}

/// Reject the client's request to make a websocket connection.
Expand Down

0 comments on commit 74c82be

Please sign in to comment.