Skip to content

Commit

Permalink
Add async_channel without cancel safety
Browse files Browse the repository at this point in the history
Signed-off-by: Klimenty Tsoutsman <klim@tsoutsman.com>
  • Loading branch information
tsoutsman committed Dec 16, 2023
1 parent e06c84a commit a430baf
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 0 deletions.
14 changes: 14 additions & 0 deletions kernel/async_channel/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[package]
name = "async_channel"
version = "0.1.0"
authors = ["Klim Tsoutsman <klim@tsoutsman.com>"]
description = "A bounded, multi-producer, multi-consumer asynchronous channel"
edition = "2021"

[dependencies]
async_wait_queue = { path = "../async_wait_queue" }
dreadnought = { path = "../dreadnought" }
futures = { version = "0.3.28", default-features = false }
mpmc = "0.1.6"
sync = { path = "../../libs/sync" }
sync_spin = { path = "../../libs/sync_spin" }
164 changes: 164 additions & 0 deletions kernel/async_channel/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//! A bounded, multi-producer, multi-consumer asynchronous channel.
//!
//! See [`Channel`] for more details.

#![no_std]

use core::{
pin::Pin,
task::{Context, Poll},
};

use async_wait_queue::WaitQueue;
use futures::stream::{FusedStream, Stream};
use mpmc::Queue;
use sync::DeadlockPrevention;
use sync_spin::Spin;

/// A bounded, multi-producer, multi-consumer asynchronous channel.
///
/// The channel can also be used outside of an asynchronous runtime with the
/// [`blocking_send`], and [`blocking_recv`] methods.
///
/// [`blocking_send`]: Self::blocking_send
/// [`blocking_recv`]: Self::blocking_recv
#[derive(Clone)]
pub struct Channel<T, P = Spin>
where
T: Send,
P: DeadlockPrevention,
{
inner: Queue<T>,
senders: WaitQueue<P>,
receivers: WaitQueue<P>,
}

impl<T, P> Channel<T, P>
where
T: Send,
P: DeadlockPrevention,
{
/// Creates a new channel.
///
/// The provided capacity dictates how many messages can be stored in the
/// queue before the sender blocks.
///
/// # Examples
///
/// ```
/// use async_channel::Channel;
///
/// let channel = Channel::new(2);
///
/// assert!(channel.try_send(1).is_ok());
/// assert!(channel.try_send(2).is_ok());
/// // The channel is full.
/// assert!(channel.try_send(3).is_err());
///
/// assert_eq!(channel.try_recv(), Some(1));
/// assert_eq!(channel.try_recv(), Some(2));
/// assert!(channel.try_recv().is_none());
/// ```
pub fn new(capacity: usize) -> Self {
Self {
inner: Queue::with_capacity(capacity),
senders: WaitQueue::new(),
receivers: WaitQueue::new(),
}
}

/// Sends `value`.
///
/// # Cancel safety
///
/// This method is cancel safe, in that if it is dropped prior to
/// completion, `value` is guaranteed to have not been set. However, in that
/// case `value` will be dropped.
pub async fn send(&self, value: T) {
let mut temp = Some(value);

self.senders
.wait_until(|| match self.inner.push(temp.take().unwrap()) {
Ok(()) => {
self.receivers.notify_one();
Some(())
}
Err(value) => {
temp = Some(value);
None
}
})
.await
}

/// Tries to send `value`.
///
/// # Errors
///
/// Returns an error containing `value` if the channel was full.
pub fn try_send(&self, value: T) -> Result<(), T> {
self.inner.push(value)?;
self.receivers.notify_one();
Ok(())
}

/// Blocks the current thread until `value` is sent.
pub fn blocking_send(&self, value: T) {
dreadnought::block_on(self.send(value))
}

/// Receives the next value.
///
/// # Cancel safety
///
/// This method is cancel safe.
pub async fn recv(&self) -> T {
let value = self.receivers.wait_until(|| self.inner.pop()).await;
self.senders.notify_one();
value
}

/// Tries to receive the next value.
pub fn try_recv(&self) -> Option<T> {
let value = self.inner.pop()?;
self.senders.notify_one();
Some(value)
}

/// Blocks the current thread until a value is received.
pub fn blocking_recv(&self) -> T {
dreadnought::block_on(self.recv())
}
}

impl<T, P> Stream for Channel<T, P>
where
T: Send,
P: DeadlockPrevention,
{
type Item = T;

fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match self
.receivers
.poll_wait_until(ctx, &mut || self.inner.pop())
{
Poll::Ready(value) => {
self.senders.notify_one();
Poll::Ready(Some(value))
}
Poll::Pending => Poll::Pending,
}
}
}

impl<T, P> FusedStream for Channel<T, P>
where
T: Send,
P: DeadlockPrevention,
{
fn is_terminated(&self) -> bool {
// NOTE: If we ever implement disconnections, this will need to be modified.
false
}
}
12 changes: 12 additions & 0 deletions kernel/async_wait_queue/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "async_wait_queue"
version = "0.1.0"
authors = ["Klim Tsoutsman <klim@tsoutsman.com>"]
description = "An asynchronous wait queue"
edition = "2021"

[dependencies]
dreadnought = { path = "../dreadnought" }
mpmc_queue = { path = "../../libs/mpmc_queue" }
sync = { path = "../../libs/sync" }
sync_spin = { path = "../../libs/sync_spin" }
105 changes: 105 additions & 0 deletions kernel/async_wait_queue/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
//! An asynchronous wait queue.
//!
//! See [`WaitQueue`] for more details.

#![no_std]

extern crate alloc;

use alloc::sync::Arc;
use core::{
future::poll_fn,
task::{Context, Poll, Waker},
};

use mpmc_queue::Queue;
use sync::DeadlockPrevention;
use sync_spin::Spin;

/// An asynchronous queue of tasks waiting to be notified.
#[derive(Clone)]
pub struct WaitQueue<P = Spin>
where
P: DeadlockPrevention,
{
inner: Arc<Queue<Waker, P>>,
}

impl<P> Default for WaitQueue<P>
where
P: DeadlockPrevention,
{
fn default() -> Self {
Self::new()
}
}

impl<P> WaitQueue<P>
where
P: DeadlockPrevention,
{
/// Creates a new empty wait queue.
pub fn new() -> Self {
Self {
inner: Arc::new(Queue::new()),
}
}

pub async fn wait_until<F, T>(&self, mut condition: F) -> T
where
F: FnMut() -> Option<T>,
{
poll_fn(move |context| self.poll_wait_until(context, &mut condition)).await
}

pub fn poll_wait_until<F, T>(&self, ctx: &mut Context, condition: &mut F) -> Poll<T>
where
F: FnMut() -> Option<T>,
{
let wrapped_condition = || {
if let Some(value) = condition() {
Ok(value)
} else {
Err(())
}
};

match self
.inner
.push_if_fail(ctx.waker().clone(), wrapped_condition)
{
Ok(value) => Poll::Ready(value),
Err(()) => Poll::Pending,
}
}

pub fn blocking_wait_until<F, T>(&self, condition: F) -> T
where
F: FnMut() -> Option<T>,
{
dreadnought::block_on(self.wait_until(condition))
}

/// Notifies the first task in the wait queue.
///
/// Returns whether or not a task was awoken.
pub fn notify_one(&self) -> bool {
match self.inner.pop() {
Some(waker) => {
waker.wake();
// From the `Waker` documentation:
// > As long as the executor keeps running and the task is not
// finished, it is guaranteed that each invocation of `wake()`
// will be followed by at least one `poll()` of the task to
// which this `Waker` belongs.
true
}
None => false,
}
}

/// Notifies all the tasks in the wait queue.
pub fn notify_all(&self) {
while self.notify_one() {}
}
}

0 comments on commit a430baf

Please sign in to comment.