Skip to content

Commit

Permalink
Abortable streams (#2410)
Browse files Browse the repository at this point in the history
  • Loading branch information
ibraheemdev authored and taiki-e committed May 10, 2021
1 parent 0924ecb commit 90db30b
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 155 deletions.
185 changes: 185 additions & 0 deletions futures-util/src/abortable.rs
@@ -0,0 +1,185 @@
use crate::task::AtomicWaker;
use alloc::sync::Arc;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_core::Stream;
use pin_project_lite::pin_project;

pin_project! {
/// A future/stream which can be remotely short-circuited using an `AbortHandle`.
#[derive(Debug, Clone)]
#[must_use = "futures/streams do nothing unless you poll them"]
pub struct Abortable<T> {
#[pin]
task: T,
inner: Arc<AbortInner>,
}
}

impl<T> Abortable<T> {
/// Creates a new `Abortable` future/stream using an existing `AbortRegistration`.
/// `AbortRegistration`s can be acquired through `AbortHandle::new`.
///
/// When `abort` is called on the handle tied to `reg` or if `abort` has
/// already been called, the future/stream will complete immediately without making
/// any further progress.
///
/// # Examples:
///
/// Usage with futures:
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future::{Abortable, AbortHandle, Aborted};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let future = Abortable::new(async { 2 }, abort_registration);
/// abort_handle.abort();
/// assert_eq!(future.await, Err(Aborted));
/// # });
/// ```
///
/// Usage with streams:
///
/// ```
/// # futures::executor::block_on(async {
/// # use futures::future::{Abortable, AbortHandle};
/// # use futures::stream::{self, StreamExt};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration);
/// abort_handle.abort();
/// assert_eq!(stream.next().await, None);
/// # });
/// ```
pub fn new(task: T, reg: AbortRegistration) -> Self {
Self { task, inner: reg.inner }
}

/// Checks whether the task has been aborted. Note that all this
/// method indicates is whether [`AbortHandle::abort`] was *called*.
/// This means that it will return `true` even if:
/// * `abort` was called after the task had completed.
/// * `abort` was called while the task was being polled - the task may still be running and
/// will not be stopped until `poll` returns.
pub fn is_aborted(&self) -> bool {
self.inner.aborted.load(Ordering::Relaxed)
}
}

/// A registration handle for an `Abortable` task.
/// Values of this type can be acquired from `AbortHandle::new` and are used
/// in calls to `Abortable::new`.
#[derive(Debug)]
pub struct AbortRegistration {
inner: Arc<AbortInner>,
}

/// A handle to an `Abortable` task.
#[derive(Debug, Clone)]
pub struct AbortHandle {
inner: Arc<AbortInner>,
}

impl AbortHandle {
/// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
/// to abort a running future or stream.
///
/// This function is usually paired with a call to [`Abortable::new`].
pub fn new_pair() -> (Self, AbortRegistration) {
let inner =
Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) });

(Self { inner: inner.clone() }, AbortRegistration { inner })
}
}

// Inner type storing the waker to awaken and a bool indicating that it
// should be aborted.
#[derive(Debug)]
struct AbortInner {
waker: AtomicWaker,
aborted: AtomicBool,
}

/// Indicator that the `Abortable` task was aborted.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Aborted;

impl fmt::Display for Aborted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "`Abortable` future has been aborted")
}
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}

impl<T> Abortable<T> {
fn try_poll<I>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>,
) -> Poll<Result<I, Aborted>> {
// Check if the task has been aborted
if self.is_aborted() {
return Poll::Ready(Err(Aborted));
}

// attempt to complete the task
if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) {
return Poll::Ready(Ok(x));
}

// Register to receive a wakeup if the task is aborted in the future
self.inner.waker.register(cx.waker());

// Check to see if the task was aborted between the first check and
// registration.
// Checking with `is_aborted` which uses `Relaxed` is sufficient because
// `register` introduces an `AcqRel` barrier.
if self.is_aborted() {
return Poll::Ready(Err(Aborted));
}

Poll::Pending
}
}

impl<Fut> Future for Abortable<Fut>
where
Fut: Future,
{
type Output = Result<Fut::Output, Aborted>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.try_poll(cx, |fut, cx| fut.poll(cx))
}
}

impl<St> Stream for Abortable<St>
where
St: Stream,
{
type Item = St::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.try_poll(cx, |stream, cx| stream.poll_next(cx)).map(Result::ok).map(Option::flatten)
}
}

impl AbortHandle {
/// Abort the `Abortable` stream/future associated with this handle.
///
/// Notifies the Abortable task associated with this handle that it
/// should abort. Note that if the task is currently being polled on
/// another thread, it will not immediately stop running. Instead, it will
/// continue to run until its poll method returns.
pub fn abort(&self) {
self.inner.aborted.store(true, Ordering::Relaxed);
self.inner.waker.wake();
}
}
158 changes: 4 additions & 154 deletions futures-util/src/future/abortable.rs
@@ -1,101 +1,8 @@
use super::assert_future;
use crate::task::AtomicWaker;
use alloc::sync::Arc;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use crate::future::{AbortHandle, Abortable, Aborted};
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use pin_project_lite::pin_project;

pin_project! {
/// A future which can be remotely short-circuited using an `AbortHandle`.
#[derive(Debug, Clone)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Abortable<Fut> {
#[pin]
future: Fut,
inner: Arc<AbortInner>,
}
}

impl<Fut> Abortable<Fut>
where
Fut: Future,
{
/// Creates a new `Abortable` future using an existing `AbortRegistration`.
/// `AbortRegistration`s can be acquired through `AbortHandle::new`.
///
/// When `abort` is called on the handle tied to `reg` or if `abort` has
/// already been called, the future will complete immediately without making
/// any further progress.
///
/// Example:
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future::{Abortable, AbortHandle, Aborted};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let future = Abortable::new(async { 2 }, abort_registration);
/// abort_handle.abort();
/// assert_eq!(future.await, Err(Aborted));
/// # });
/// ```
pub fn new(future: Fut, reg: AbortRegistration) -> Self {
assert_future::<Result<Fut::Output, Aborted>, _>(Self { future, inner: reg.inner })
}
}

/// A registration handle for a `Abortable` future.
/// Values of this type can be acquired from `AbortHandle::new` and are used
/// in calls to `Abortable::new`.
#[derive(Debug)]
pub struct AbortRegistration {
inner: Arc<AbortInner>,
}

/// A handle to a `Abortable` future.
#[derive(Debug, Clone)]
pub struct AbortHandle {
inner: Arc<AbortInner>,
}

impl AbortHandle {
/// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
/// to abort a running future.
///
/// This function is usually paired with a call to `Abortable::new`.
///
/// Example:
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::future::{Abortable, AbortHandle, Aborted};
///
/// let (abort_handle, abort_registration) = AbortHandle::new_pair();
/// let future = Abortable::new(async { 2 }, abort_registration);
/// abort_handle.abort();
/// assert_eq!(future.await, Err(Aborted));
/// # });
/// ```
pub fn new_pair() -> (Self, AbortRegistration) {
let inner =
Arc::new(AbortInner { waker: AtomicWaker::new(), cancel: AtomicBool::new(false) });

(Self { inner: inner.clone() }, AbortRegistration { inner })
}
}

// Inner type storing the waker to awaken and a bool indicating that it
// should be cancelled.
#[derive(Debug)]
struct AbortInner {
waker: AtomicWaker,
cancel: AtomicBool,
}

/// Creates a new `Abortable` future and a `AbortHandle` which can be used to stop it.
/// Creates a new `Abortable` future and an `AbortHandle` which can be used to stop it.
///
/// This function is a convenient (but less flexible) alternative to calling
/// `AbortHandle::new` and `Abortable::new` manually.
Expand All @@ -107,63 +14,6 @@ where
Fut: Future,
{
let (handle, reg) = AbortHandle::new_pair();
(Abortable::new(future, reg), handle)
}

/// Indicator that the `Abortable` future was aborted.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct Aborted;

impl fmt::Display for Aborted {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "`Abortable` future has been aborted")
}
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}

impl<Fut> Future for Abortable<Fut>
where
Fut: Future,
{
type Output = Result<Fut::Output, Aborted>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Check if the future has been aborted
if self.inner.cancel.load(Ordering::Relaxed) {
return Poll::Ready(Err(Aborted));
}

// attempt to complete the future
if let Poll::Ready(x) = self.as_mut().project().future.poll(cx) {
return Poll::Ready(Ok(x));
}

// Register to receive a wakeup if the future is aborted in the... future
self.inner.waker.register(cx.waker());

// Check to see if the future was aborted between the first check and
// registration.
// Checking with `Relaxed` is sufficient because `register` introduces an
// `AcqRel` barrier.
if self.inner.cancel.load(Ordering::Relaxed) {
return Poll::Ready(Err(Aborted));
}

Poll::Pending
}
}

impl AbortHandle {
/// Abort the `Abortable` future associated with this handle.
///
/// Notifies the Abortable future associated with this handle that it
/// should abort. Note that if the future is currently being polled on
/// another thread, it will not immediately stop running. Instead, it will
/// continue to run until its poll method returns.
pub fn abort(&self) {
self.inner.cancel.store(true, Ordering::Relaxed);
self.inner.waker.wake();
}
let abortable = assert_future::<Result<Fut::Output, Aborted>, _>(Abortable::new(future, reg));
(abortable, handle)
}
4 changes: 3 additions & 1 deletion futures-util/src/future/mod.rs
Expand Up @@ -112,7 +112,9 @@ cfg_target_has_atomic! {
#[cfg(feature = "alloc")]
mod abortable;
#[cfg(feature = "alloc")]
pub use self::abortable::{abortable, Abortable, AbortHandle, AbortRegistration, Aborted};
pub use crate::abortable::{Abortable, AbortHandle, AbortRegistration, Aborted};
#[cfg(feature = "alloc")]
pub use abortable::abortable;
}

// Just a helper function to ensure the futures we're returning all have the
Expand Down
5 changes: 5 additions & 0 deletions futures-util/src/lib.rs
Expand Up @@ -336,5 +336,10 @@ pub use crate::io::{
#[cfg(feature = "alloc")]
pub mod lock;

cfg_target_has_atomic! {
#[cfg(feature = "alloc")]
mod abortable;
}

mod fns;
mod unfold_state;
19 changes: 19 additions & 0 deletions futures-util/src/stream/abortable.rs
@@ -0,0 +1,19 @@
use super::assert_stream;
use crate::stream::{AbortHandle, Abortable};
use crate::Stream;

/// Creates a new `Abortable` stream and an `AbortHandle` which can be used to stop it.
///
/// This function is a convenient (but less flexible) alternative to calling
/// `AbortHandle::new` and `Abortable::new` manually.
///
/// This function is only available when the `std` or `alloc` feature of this
/// library is activated, and it is activated by default.
pub fn abortable<St>(stream: St) -> (Abortable<St>, AbortHandle)
where
St: Stream,
{
let (handle, reg) = AbortHandle::new_pair();
let abortable = assert_stream::<St::Item, _>(Abortable::new(stream, reg));
(abortable, handle)
}

0 comments on commit 90db30b

Please sign in to comment.