From c88ad9b9dc628617105370b4ba143c63e15a6b2c Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Sat, 21 Jan 2023 12:38:12 +1100 Subject: [PATCH] Add `Either::as_pin_mut` and `Either::as_pin_ref` (#2691) --- futures-util/src/future/either.rs | 58 +++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/futures-util/src/future/either.rs b/futures-util/src/future/either.rs index 9602de7a42..27e5064dfb 100644 --- a/futures-util/src/future/either.rs +++ b/futures-util/src/future/either.rs @@ -33,11 +33,31 @@ pub enum Either { } impl Either { - fn project(self: Pin<&mut Self>) -> Either, Pin<&mut B>> { + /// Convert `Pin<&Either>` to `Either, Pin<&B>>`, + /// pinned projections of the inner variants. + pub fn as_pin_ref(self: Pin<&Self>) -> Either, Pin<&B>> { + // SAFETY: We can use `new_unchecked` because the `inner` parts are + // guaranteed to be pinned, as they come from `self` which is pinned. unsafe { - match self.get_unchecked_mut() { - Either::Left(a) => Either::Left(Pin::new_unchecked(a)), - Either::Right(b) => Either::Right(Pin::new_unchecked(b)), + match *Pin::get_ref(self) { + Either::Left(ref inner) => Either::Left(Pin::new_unchecked(inner)), + Either::Right(ref inner) => Either::Right(Pin::new_unchecked(inner)), + } + } + } + + /// Convert `Pin<&mut Either>` to `Either, Pin<&mut B>>`, + /// pinned projections of the inner variants. + pub fn as_pin_mut(self: Pin<&mut Self>) -> Either, Pin<&mut B>> { + // SAFETY: `get_unchecked_mut` is fine because we don't move anything. + // We can use `new_unchecked` because the `inner` parts are guaranteed + // to be pinned, as they come from `self` which is pinned, and we never + // offer an unpinned `&mut A` or `&mut B` through `Pin<&mut Self>`. We + // also don't have an implementation of `Drop`, nor manual `Unpin`. + unsafe { + match *Pin::get_unchecked_mut(self) { + Either::Left(ref mut inner) => Either::Left(Pin::new_unchecked(inner)), + Either::Right(ref mut inner) => Either::Right(Pin::new_unchecked(inner)), } } } @@ -85,7 +105,7 @@ where type Output = A::Output; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll(cx), Either::Right(x) => x.poll(cx), } @@ -113,7 +133,7 @@ where type Item = A::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_next(cx), Either::Right(x) => x.poll_next(cx), } @@ -149,28 +169,28 @@ where type Error = A::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_ready(cx), Either::Right(x) => x.poll_ready(cx), } } fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.start_send(item), Either::Right(x) => x.start_send(item), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_flush(cx), Either::Right(x) => x.poll_flush(cx), } } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_close(cx), Either::Right(x) => x.poll_close(cx), } @@ -198,7 +218,7 @@ mod if_std { cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_read(cx, buf), Either::Right(x) => x.poll_read(cx, buf), } @@ -209,7 +229,7 @@ mod if_std { cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_read_vectored(cx, bufs), Either::Right(x) => x.poll_read_vectored(cx, bufs), } @@ -226,7 +246,7 @@ mod if_std { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_write(cx, buf), Either::Right(x) => x.poll_write(cx, buf), } @@ -237,21 +257,21 @@ mod if_std { cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_write_vectored(cx, bufs), Either::Right(x) => x.poll_write_vectored(cx, bufs), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_flush(cx), Either::Right(x) => x.poll_flush(cx), } } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_close(cx), Either::Right(x) => x.poll_close(cx), } @@ -268,7 +288,7 @@ mod if_std { cx: &mut Context<'_>, pos: SeekFrom, ) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_seek(cx, pos), Either::Right(x) => x.poll_seek(cx, pos), } @@ -281,14 +301,14 @@ mod if_std { B: AsyncBufRead, { fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.poll_fill_buf(cx), Either::Right(x) => x.poll_fill_buf(cx), } } fn consume(self: Pin<&mut Self>, amt: usize) { - match self.project() { + match self.as_pin_mut() { Either::Left(x) => x.consume(amt), Either::Right(x) => x.consume(amt), }