diff --git a/tokio-util/src/codec/length_delimited.rs b/tokio-util/src/codec/length_delimited.rs index a182dcaec0c..92d76b2cd28 100644 --- a/tokio-util/src/codec/length_delimited.rs +++ b/tokio-util/src/codec/length_delimited.rs @@ -386,6 +386,10 @@ use std::{cmp, fmt, mem}; /// `Builder` enables constructing configured length delimited codecs. Note /// that not all configuration settings apply to both encoding and decoding. See /// the documentation for specific methods for more detail. +/// +/// Note that the if the value of [`Builder::max_frame_length`] becomes larger than +/// what can actually fit in [`Builder::length_field_length`], it will be clipped to +/// the maximum value that can fit. #[derive(Debug, Clone, Copy)] pub struct Builder { // Maximum frame length @@ -935,8 +939,12 @@ impl Builder { /// # } /// ``` pub fn new_codec(&self) -> LengthDelimitedCodec { + let mut builder = *self; + + builder.adjust_max_frame_len(); + LengthDelimitedCodec { - builder: *self, + builder, state: DecodeState::Head, } } @@ -1018,6 +1026,35 @@ impl Builder { self.num_skip .unwrap_or(self.length_field_offset + self.length_field_len) } + + fn adjust_max_frame_len(&mut self) { + // This function is basically `std::u64::saturating_add_signed`. Since it + // requires MSRV 1.66, its implementation is copied here. + // + // TODO: use the method from std when MSRV becomes >= 1.66 + fn saturating_add_signed(num: u64, rhs: i64) -> u64 { + let (res, overflow) = num.overflowing_add(rhs as u64); + if overflow == (rhs < 0) { + res + } else if overflow { + u64::MAX + } else { + 0 + } + } + + // Calculate the maximum number that can be represented using `length_field_len` bytes. + let max_number = match 1u64.checked_shl((8 * self.length_field_len) as u32) { + Some(shl) => shl - 1, + None => u64::MAX, + }; + + let max_allowed_len = saturating_add_signed(max_number, self.length_adjustment as i64); + + if self.max_frame_len as u64 > max_allowed_len { + self.max_frame_len = usize::try_from(max_allowed_len).unwrap_or(usize::MAX); + } + } } impl Default for Builder { diff --git a/tokio-util/tests/length_delimited.rs b/tokio-util/tests/length_delimited.rs index ed5590f9644..091a5b449e4 100644 --- a/tokio-util/tests/length_delimited.rs +++ b/tokio-util/tests/length_delimited.rs @@ -689,6 +689,66 @@ fn encode_overflow() { codec.encode(Bytes::from("hello"), &mut buf).unwrap(); } +#[test] +fn frame_does_not_fit() { + let codec = LengthDelimitedCodec::builder() + .length_field_length(1) + .max_frame_length(256) + .new_codec(); + + assert_eq!(codec.max_frame_length(), 255); +} + +#[test] +fn neg_adjusted_frame_does_not_fit() { + let codec = LengthDelimitedCodec::builder() + .length_field_length(1) + .length_adjustment(-1) + .new_codec(); + + assert_eq!(codec.max_frame_length(), 254); +} + +#[test] +fn pos_adjusted_frame_does_not_fit() { + let codec = LengthDelimitedCodec::builder() + .length_field_length(1) + .length_adjustment(1) + .new_codec(); + + assert_eq!(codec.max_frame_length(), 256); +} + +#[test] +fn max_allowed_frame_fits() { + let codec = LengthDelimitedCodec::builder() + .length_field_length(std::mem::size_of::()) + .max_frame_length(usize::MAX) + .new_codec(); + + assert_eq!(codec.max_frame_length(), usize::MAX); +} + +#[test] +fn smaller_frame_len_not_adjusted() { + let codec = LengthDelimitedCodec::builder() + .max_frame_length(10) + .length_field_length(std::mem::size_of::()) + .new_codec(); + + assert_eq!(codec.max_frame_length(), 10); +} + +#[test] +fn max_allowed_length_field() { + let codec = LengthDelimitedCodec::builder() + .length_field_length(8) + .max_frame_length(usize::MAX) + .new_codec(); + + assert_eq!(codec.max_frame_length(), usize::MAX); +} + // ===== Test utils ===== struct Mock {