diff --git a/src/qos_core/src/io/stream.rs b/src/qos_core/src/io/stream.rs index 49785305b..745e86907 100644 --- a/src/qos_core/src/io/stream.rs +++ b/src/qos_core/src/io/stream.rs @@ -25,11 +25,14 @@ const MAX_RETRY: usize = 25; const BACKOFF_MILLISECONDS: u64 = 10; const BACKLOG: usize = 128; -const MEGABYTE: usize = 1024 * 1024; +const MIB: usize = 1024 * 1024; -/// Maximum payload size for a single recv / send call. We're being generous with 128MB. +/// Maximum payload size for a single recv / send call. We're being generous with 128MiB. /// The goal here is to avoid server crashes if the payload size exceeds the available system memory. -pub const MAX_PAYLOAD_SIZE: usize = 128 * MEGABYTE; +pub const MAX_PAYLOAD_SIZE: usize = 128 * MIB; + +/// Even though we allow for big payloads we start by allocating a small buffer first. Then allocate more as needed. +pub const INITIAL_RECV_BUF_SIZE: usize = 2 * MIB; /// Socket address. #[derive(Clone, Debug, PartialEq, Eq)] @@ -234,33 +237,37 @@ impl Stream { return Err(IOError::OversizedPayload(length)); } - // Read the buffer - let mut buf = vec![0; length]; - { - let mut received_bytes = 0; - while received_bytes < length { - received_bytes += match recv( - self.fd, - &mut buf[received_bytes..length], - MsgFlags::empty(), - ) { - Ok(0) => { - return Err(IOError::RecvConnectionClosed); - } - Ok(size) => size, - Err(nix::Error::EINTR) => { - return Err(IOError::RecvInterrupted); - } - Err(nix::Error::EAGAIN) => { - return Err(IOError::RecvTimeout); - } - Err(err) => { - return Err(IOError::NixError(err)); - } - }; + // Allocate conservatively to avoid clients setting 128MB as their declared length and keeping the connection open. + // We'd only need a few of these to run out of memory. This "as needed" allocation ensures clients have skin in the game. + let initial_recv_buf_len = + core::cmp::min(length, INITIAL_RECV_BUF_SIZE); + let mut recv_buf = vec![0u8; initial_recv_buf_len]; + + let mut received_bytes = 0; + while received_bytes < length { + // If the receive buffer is full, double it. + if received_bytes == recv_buf.len() { + // Using `saturating_mul` here out of paranoia; it's cheap enough to saturate instead of overflow. + // We either double the recv buffer capacity, or set it to `length` if doubling would exceed it. + let new_len = + core::cmp::min(recv_buf.len().saturating_mul(2), length); + recv_buf.resize(new_len, 0); } + + received_bytes += match recv( + self.fd, + &mut recv_buf[received_bytes..], + MsgFlags::empty(), + ) { + Ok(0) => return Err(IOError::RecvConnectionClosed), + Ok(size) => size, + Err(nix::Error::EINTR) => return Err(IOError::RecvInterrupted), + Err(nix::Error::EAGAIN) => return Err(IOError::RecvTimeout), + Err(err) => return Err(IOError::NixError(err)), + }; } - Ok(buf) + + Ok(recv_buf) } } @@ -533,7 +540,7 @@ mod test { } }); - // Sending a request that is strictly less than the max size should work + // Sending a request that is exactly the max size should work // (the response will be exactly max size) let client = Stream::connect(&addr, timeval()).unwrap(); let req = vec![1u8; MAX_PAYLOAD_SIZE];