From 17a48918af8a5aa1d1838d3c30b27ab641536bdc Mon Sep 17 00:00:00 2001 From: suikammd Date: Wed, 28 Jun 2023 14:02:53 +0800 Subject: [PATCH] codec: add borrow framed --- tokio-util/src/codec/framed.rs | 78 ++++++++++++++++++++++++++++ tokio-util/src/codec/framed_impl.rs | 20 +++++++ tokio-util/src/codec/framed_read.rs | 69 ++++++++++++++++++++++++ tokio-util/src/codec/framed_write.rs | 69 ++++++++++++++++++++++++ tokio-util/tests/framed.rs | 27 ++++++++++ tokio-util/tests/framed_read.rs | 19 +++++++ tokio-util/tests/framed_write.rs | 26 ++++++++++ 7 files changed, 308 insertions(+) diff --git a/tokio-util/src/codec/framed.rs b/tokio-util/src/codec/framed.rs index 8a344f90db2..fa8c15986d6 100644 --- a/tokio-util/src/codec/framed.rs +++ b/tokio-util/src/codec/framed.rs @@ -30,6 +30,22 @@ pin_project! { } } +pin_project! { + /// A borrowed unified [`Stream`] and [`Sink`] interface to an underlying I/O object, using + /// the `Encoder` and `Decoder` traits to encode and decode frames. + /// + /// You can create a `BorrowFramed` instance by using the `with_codec` function of Framed + /// + /// [`Stream`]: futures_core::Stream + /// [`Sink`]: futures_sink::Sink + /// [`AsyncRead`]: tokio::io::AsyncRead + /// [`Decoder::framed`]: crate::codec::Decoder::framed() + pub struct BorrowFramed<'borrow, T, U> { + #[pin] + inner: FramedImpl<&'borrow mut T, U, &'borrow mut RWFrames>, + } +} + impl Framed where T: AsyncRead + AsyncWrite, @@ -224,6 +240,29 @@ impl Framed { }) } + /// Maps the codec `U` to `C` temporarily using &mut self + /// preserving the read and write buffers wrapped by `Framed`. + /// + /// Note that care should be taken to not tamper with the underlying codec + /// as it may corrupt the stream of frames otherwise being worked with. + pub fn with_codec(&mut self, map: F) -> BorrowFramed<'_, T, C> + where + F: FnOnce(&mut U) -> C, + { + let FramedImpl { + inner, + state, + codec, + } = &mut self.inner; + BorrowFramed { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + /// Returns a mutable reference to the underlying codec wrapped by /// `Framed`. /// @@ -341,6 +380,45 @@ where } } +// This impl just defers to the underlying FramedImpl +impl<'borrow, T, U> Stream for BorrowFramed<'borrow, T, U> +where + T: AsyncRead + Unpin, + U: Decoder, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying FramedImpl +impl<'borrow, T, I, U> Sink for BorrowFramed<'borrow, T, U> +where + T: AsyncWrite + Unpin, + U: Encoder, + U::Error: From, +{ + type Error = U::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + /// `FramedParts` contains an export of the data of a Framed transport. /// It can be used to construct a new [`Framed`] with a different codec. /// It contains all current buffers and the inner transport. diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs index 8f3fa49b0d9..4e112a3e9c4 100644 --- a/tokio-util/src/codec/framed_impl.rs +++ b/tokio-util/src/codec/framed_impl.rs @@ -115,6 +115,26 @@ impl BorrowMut for RWFrames { &mut self.write } } +impl Borrow for &mut RWFrames { + fn borrow(&self) -> &ReadFrame { + &self.read + } +} +impl BorrowMut for &mut RWFrames { + fn borrow_mut(&mut self) -> &mut ReadFrame { + &mut self.read + } +} +impl Borrow for &mut RWFrames { + fn borrow(&self) -> &WriteFrame { + &self.write + } +} +impl BorrowMut for &mut RWFrames { + fn borrow_mut(&mut self) -> &mut WriteFrame { + &mut self.write + } +} impl Stream for FramedImpl where T: AsyncRead, diff --git a/tokio-util/src/codec/framed_read.rs b/tokio-util/src/codec/framed_read.rs index 184c567b498..b95c2a33a35 100644 --- a/tokio-util/src/codec/framed_read.rs +++ b/tokio-util/src/codec/framed_read.rs @@ -22,6 +22,17 @@ pin_project! { } } +pin_project! { + /// A [`Stream`] of messages decoded from an [`AsyncRead`]. + /// + /// [`Stream`]: futures_core::Stream + /// [`AsyncRead`]: tokio::io::AsyncRead + pub struct BorrowFramedRead<'borrow, T, D> { + #[pin] + inner: FramedImpl<&'borrow mut T, D, &'borrow mut ReadFrame>, + } +} + // ===== impl FramedRead ===== impl FramedRead @@ -129,6 +140,27 @@ impl FramedRead { } } + /// Maps the decoder `D` to `C` temporarily using &mut self, + /// preserving the read buffer wrapped by `Framed`. + pub fn with_decoder(&mut self, map: F) -> BorrowFramedRead<'_, T, C> + where + F: FnOnce(&mut D) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = &mut self.inner; + BorrowFramedRead { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + /// Returns a mutable reference to the underlying decoder. pub fn decoder_pin_mut(self: Pin<&mut Self>) -> &mut D { self.project().inner.project().codec @@ -197,3 +229,40 @@ where .finish() } } + +// This impl just defers to the underlying FramedImpl +impl<'borrow, T, D> Stream for BorrowFramedRead<'borrow, T, D> +where + T: AsyncRead + Unpin, + D: Decoder, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_next(cx) + } +} + +// This impl just defers to the underlying T: Sink +impl<'borrow, T, I, D> Sink for BorrowFramedRead<'borrow, T, D> +where + T: Sink + Unpin, +{ + type Error = T::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.project().inner.poll_close(cx) + } +} diff --git a/tokio-util/src/codec/framed_write.rs b/tokio-util/src/codec/framed_write.rs index 3f0a3408157..5b2096a74f7 100644 --- a/tokio-util/src/codec/framed_write.rs +++ b/tokio-util/src/codec/framed_write.rs @@ -22,6 +22,16 @@ pin_project! { } } +pin_project! { + /// A [`Sink`] of frames encoded to an `AsyncWrite`. + /// + /// [`Sink`]: futures_sink::Sink + pub struct BorrowFramedWrite<'borrow, T, E> { + #[pin] + inner: FramedImpl<&'borrow mut T, E, &'borrow mut WriteFrame>, + } +} + impl FramedWrite where T: AsyncWrite, @@ -109,6 +119,27 @@ impl FramedWrite { } } + /// Maps the encoder `E` to `C` temporarily using &mut self, + /// preserving the write buffer wrapped by `Framed`. + pub fn with_encoder(&mut self, map: F) -> BorrowFramedWrite<'_, T, C> + where + F: FnOnce(&mut E) -> C, + { + // This could be potentially simplified once rust-lang/rust#86555 hits stable + let FramedImpl { + inner, + state, + codec, + } = &mut self.inner; + BorrowFramedWrite { + inner: FramedImpl { + inner, + state, + codec: map(codec), + }, + } + } + /// Returns a mutable reference to the underlying encoder. pub fn encoder_pin_mut(self: Pin<&mut Self>) -> &mut E { self.project().inner.project().codec @@ -186,3 +217,41 @@ where .finish() } } + +// This impl just defers to the underlying FramedImpl +impl<'borrow, T, I, E> Sink for BorrowFramedWrite<'borrow, T, E> +where + T: AsyncWrite + Unpin, + E: Encoder, + E::Error: From, +{ + type Error = E::Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_ready(cx) + } + + fn start_send(self: Pin<&mut Self>, item: I) -> Result<(), Self::Error> { + self.project().inner.start_send(item) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.poll_close(cx) + } +} + +// This impl just defers to the underlying T: Stream +impl<'borrow, T, D> Stream for BorrowFramedWrite<'borrow, T, D> +where + T: Stream + Unpin, +{ + type Item = T::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project().inner.project().inner.poll_next(cx) + } +} diff --git a/tokio-util/tests/framed.rs b/tokio-util/tests/framed.rs index ec8cdf00d09..94ec8a58414 100644 --- a/tokio-util/tests/framed.rs +++ b/tokio-util/tests/framed.rs @@ -150,3 +150,30 @@ fn external_buf_does_not_shrink() { assert_eq!(read_buf.capacity(), INITIAL_CAPACITY * 2); } + +#[tokio::test] +async fn borrow_framed() { + let mut parts = FramedParts::new(DontReadIntoThis, U32Codec::default()); + parts.read_buf = BytesMut::from( + &[ + 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, 84, 0, 0, 0, 0, 0, 0, 0, 84, 0, 0, 0, 42, + ][..], + ); + + let mut framed = Framed::from_parts(parts); + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 4); + + let mut borrow_framed = framed.with_codec(|codec| U64Codec { + read_bytes: codec.read_bytes, + }); + assert_eq!(assert_ok!(borrow_framed.next().await.unwrap()), 84); + assert_eq!(assert_ok!(borrow_framed.next().await.unwrap()), 84); + + let num = assert_ok!(framed.next().await.unwrap()); + + assert_eq!(num, 42); + assert_eq!(framed.codec().read_bytes, 8); +} diff --git a/tokio-util/tests/framed_read.rs b/tokio-util/tests/framed_read.rs index 2a9e27e22f5..ac9ea81d218 100644 --- a/tokio-util/tests/framed_read.rs +++ b/tokio-util/tests/framed_read.rs @@ -118,6 +118,25 @@ fn read_multi_frame_in_packet_after_codec_changed() { }); } +#[test] +fn borrow_framed_read() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08\x00\x00\x00\x04".to_vec()), + }; + let mut framed = FramedRead::new(mock, U32Decoder); + + task.enter(|cx, _| { + assert_read!(pin!(framed).poll_next(cx), 0x04); + + let mut borrow_framed = framed.with_decoder(|_| U64Decoder); + assert_read!(pin!(borrow_framed).poll_next(cx), 0x08); + + assert_read!(pin!(framed).poll_next(cx), 0x04); + assert!(assert_ready!(pin!(framed).poll_next(cx)).is_none()); + }); +} + #[test] fn read_not_ready() { let mut task = task::spawn(()); diff --git a/tokio-util/tests/framed_write.rs b/tokio-util/tests/framed_write.rs index 39091c0b1b5..32b2486791f 100644 --- a/tokio-util/tests/framed_write.rs +++ b/tokio-util/tests/framed_write.rs @@ -104,6 +104,32 @@ fn write_multi_frame_after_codec_changed() { }); } +#[test] +fn borrow_framed_write() { + let mut task = task::spawn(()); + let mock = mock! { + Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()), + }; + let mut framed = FramedWrite::new(mock, U32Encoder); + + task.enter(|cx, _| { + assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok()); + assert!(pin!(framed).start_send(0x04).is_ok()); + + let mut borrow_framed = framed.with_encoder(|_| U64Encoder); + assert!(assert_ready!(pin!(borrow_framed).poll_ready(cx)).is_ok()); + assert!(pin!(borrow_framed).start_send(0x08).is_ok()); + + // Nothing written yet + assert_eq!(1, framed.get_ref().calls.len()); + + // Flush the writes + assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok()); + + assert_eq!(0, framed.get_ref().calls.len()); + }); +} + #[test] fn write_hits_backpressure() { const ITER: usize = 2 * 1024;