diff --git a/protobuf/src/error.rs b/protobuf/src/error.rs index 32e41a4a5..85af7f9a9 100644 --- a/protobuf/src/error.rs +++ b/protobuf/src/error.rs @@ -19,6 +19,7 @@ pub enum WireError { Utf8Error, InvalidEnumValue(i32), OverRecursionLimit, + TruncatedMessage, Other, } @@ -57,6 +58,7 @@ impl Error for ProtobufError { WireError::IncompleteMap => "incomplete map", WireError::UnexpectedEof => "unexpected EOF", WireError::OverRecursionLimit => "over recursion limit", + WireError::TruncatedMessage => "truncated message", WireError::Other => "other error", } } diff --git a/protobuf/src/stream.rs b/protobuf/src/stream.rs index 0c331863b..a5629c7fb 100644 --- a/protobuf/src/stream.rs +++ b/protobuf/src/stream.rs @@ -33,6 +33,9 @@ const OUTPUT_STREAM_BUFFER_SIZE: usize = 8 * 1024; // Default recursion level limit. 100 is the default value of C++'s implementation. const DEFAULT_RECURSION_LIMIT: u32 = 100; +// Max allocated vec when reading length-delimited from unknown input stream +const READ_RAW_BYTES_MAX_ALLOC: usize = 10_000_000; + pub mod wire_format { // TODO: temporary @@ -623,14 +626,34 @@ impl<'a> CodedInputStream<'a> { /// Read raw bytes into the supplied vector. The vector will be resized as needed and /// overwritten. pub fn read_raw_bytes_into(&mut self, count: u32, target: &mut Vec) -> ProtobufResult<()> { + let count = count as usize; + + // TODO: also do some limits when reading from unlimited source + if count as u64 > self.source.bytes_until_limit() { + return Err(ProtobufError::WireError(WireError::TruncatedMessage)); + } + unsafe { target.set_len(0); } - target.reserve(count as usize); - unsafe { - target.set_len(count as usize); + + if count >= READ_RAW_BYTES_MAX_ALLOC { + // avoid calling `reserve` on buf with very large buffer: could be a malformed message + + let mut take = self.by_ref().take(count as u64); + take.read_to_end(target)?; + + if target.len() != count { + return Err(ProtobufError::WireError(WireError::TruncatedMessage)); + } + } else { + target.reserve(count); + unsafe { + target.set_len(count); + } + + self.source.read_exact(target)?; } - self.read(target)?; Ok(()) } @@ -1255,6 +1278,7 @@ mod test { use super::wire_format; use super::CodedInputStream; use super::CodedOutputStream; + use super::READ_RAW_BYTES_MAX_ALLOC; fn test_read_partial(hex: &str, mut callback: F) where @@ -1425,6 +1449,32 @@ mod test { }); } + #[test] + fn test_input_stream_read_raw_bytes_into_huge() { + let mut v = Vec::new(); + for i in 0..READ_RAW_BYTES_MAX_ALLOC + 1000 { + v.push((i % 10) as u8); + } + + let mut slice: &[u8] = v.as_slice(); + + let mut is = CodedInputStream::new(&mut slice); + + let mut buf = Vec::new(); + + is.read_raw_bytes_into(READ_RAW_BYTES_MAX_ALLOC as u32 + 10, &mut buf).expect("read"); + + assert_eq!(READ_RAW_BYTES_MAX_ALLOC + 10, buf.len()); + + buf.clear(); + + is.read_raw_bytes_into(1000 - 10, &mut buf).expect("read"); + + assert_eq!(1000 - 10, buf.len()); + + assert!(is.eof().expect("eof")); + } + fn test_write(expected: &str, mut gen: F) where F : FnMut(&mut CodedOutputStream) -> ProtobufResult<()>,