diff --git a/src/receiver.rs b/src/receiver.rs index f74153c69b..9ab5593281 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -14,6 +14,10 @@ use crate::ws; use crate::ws::receiver::Receiver as ReceiverTrait; use crate::ws::receiver::{DataFrameIterator, MessageIterator}; +const DEFAULT_MAX_DATAFRAME_SIZE : usize = 1024*1024*100; +const DEFAULT_MAX_MESSAGE_SIZE : usize = 1024*1024*200; +const MAX_DATAFRAMES_IN_ONE_MESSAGE: usize = 1024*1024; + /// This reader bundles an existing stream with a parsing algorithm. /// It is used by the client in its `.split()` function as the reading component. pub struct Reader @@ -74,14 +78,33 @@ where pub struct Receiver { buffer: Vec, mask: bool, + // u32s instead uf usizes to economize used memory by this struct + max_dataframe_size: u32, + max_message_size: u32, } impl Receiver { /// Create a new Receiver using the specified Reader. + /// + /// Uses built-in limits for dataframe and message sizes. pub fn new(mask: bool) -> Receiver { + Receiver::new_with_limits(mask, DEFAULT_MAX_DATAFRAME_SIZE, DEFAULT_MAX_MESSAGE_SIZE) + } + + /// Create a new Receiver using the specified Reader, with configurable limits + /// + /// Sizes should not be larger than `u32::MAX`. + /// + /// Note that `max_message_size` denotes message size where no new dataframes would be read, + /// so actual maximum message size is larger. + pub fn new_with_limits(mask: bool, max_dataframe_size: usize, max_message_size: usize) -> Receiver { + let max_dataframe_size: u32 = max_dataframe_size.min(u32::MAX as usize) as u32; + let max_message_size: u32 = max_message_size.min(u32::MAX as usize) as u32; Receiver { buffer: Vec::new(), mask, + max_dataframe_size, + max_message_size, } } } @@ -96,7 +119,7 @@ impl ws::Receiver for Receiver { where R: Read, { - DataFrame::read_dataframe(reader, self.mask) + DataFrame::read_dataframe_with_limit(reader, self.mask, self.max_dataframe_size as usize) } /// Returns the data frames that constitute one message. @@ -104,6 +127,7 @@ impl ws::Receiver for Receiver { where R: Read, { + let mut current_message_length : usize = self.buffer.iter().map(|x|x.data.len()).sum(); let mut finished = if self.buffer.is_empty() { let first = self.recv_dataframe(reader)?; @@ -114,6 +138,7 @@ impl ws::Receiver for Receiver { } let finished = first.finished; + current_message_length += first.data.len(); self.buffer.push(first); finished } else { @@ -126,7 +151,10 @@ impl ws::Receiver for Receiver { match next.opcode as u8 { // Continuation opcode - 0 => self.buffer.push(next), + 0 => { + current_message_length += next.data.len(); + self.buffer.push(next) + } // Control frame 8..=15 => { return Ok(vec![next]); @@ -138,6 +166,19 @@ impl ws::Receiver for Receiver { )); } } + + if !finished { + if self.buffer.len() >= MAX_DATAFRAMES_IN_ONE_MESSAGE { + return Err(WebSocketError::ProtocolError( + "Exceeded count of data frames in one WebSocket message", + )); + } + if current_message_length >= self.max_message_size as usize { + return Err(WebSocketError::ProtocolError( + "Exceeded maximum WebSocket message size", + )); + } + } } Ok(::std::mem::replace(&mut self.buffer, Vec::new())) diff --git a/websocket-base/src/dataframe.rs b/websocket-base/src/dataframe.rs index c25221c922..eb3d2e4643 100644 --- a/websocket-base/src/dataframe.rs +++ b/websocket-base/src/dataframe.rs @@ -96,6 +96,25 @@ impl DataFrame { DataFrame::read_dataframe_body(header, data, should_be_masked) } + + /// Reads a DataFrame from a Reader, or error out if header declares exceeding limit you specify + pub fn read_dataframe_with_limit(reader: &mut R, should_be_masked: bool, limit: usize) -> WebSocketResult + where + R: Read, + { + let header = dfh::read_header(reader)?; + + if header.len > limit as u64 { + return Err(io::Error::new(io::ErrorKind::InvalidData, "exceeded DataFrame length limit").into()); + } + let mut data: Vec = Vec::with_capacity(header.len as usize); + let read = reader.take(header.len).read_to_end(&mut data)?; + if (read as u64) < header.len { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "incomplete payload").into()); + } + + DataFrame::read_dataframe_body(header, data, should_be_masked) + } } impl DataFrameable for DataFrame {