Commit

Replace the linked list with a safer and less allocation-heavy alternative (#38)

* Use slab to avoid unsafe code

* Move send+sync impls down to Mutex

* Code review

* Unwrap the key earlier

* Reduce the scope of one of the unsafe blocks.
notgull committed Nov 21, 2022
1 parent 6496571 commit 0235e55
Showing 7 changed files with 224 additions and 312 deletions.
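For context, a minimal sketch of the `slab` pattern this commit moves to (illustrative names, assuming the published `slab` crate API, not the crate's actual internals): entries live in a `Slab` and are addressed by plain `usize` keys, so no raw `NonNull<Entry>` pointers into a linked list need to be handed around.

```rust
use slab::Slab;

// Stand-in for a listener's per-entry state.
enum State {
    Created,
    Notified,
}

fn main() {
    let mut listeners: Slab<State> = Slab::new();

    // `insert` returns a plain `usize` key instead of a raw pointer into a
    // linked list, so entries can be stored and removed without unsafe code.
    let key = listeners.insert(State::Created);

    // Look the entry up by key to notify it...
    listeners[key] = State::Notified;

    // ...and remove it by key, getting its final state back.
    let state = listeners.remove(key);
    assert!(matches!(state, State::Notified));
}
```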
1 change: 1 addition & 0 deletions Cargo.toml
@@ -24,6 +24,7 @@ __test = []
[dependencies]
crossbeam-utils = { version = "0.8.12", default-features = false }
parking = { version = "2.0.0", optional = true }
slab = { version = "0.4.7", default-features = false }

[dev-dependencies]
waker-fn = "1"
27 changes: 7 additions & 20 deletions src/inner.rs
@@ -1,6 +1,6 @@
//! The inner mechanism powering the `Event` type.

use crate::list::{Entry, List};
use crate::list::List;
use crate::node::Node;
use crate::queue::Queue;
use crate::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
@@ -11,7 +11,6 @@ use alloc::vec;
use alloc::vec::Vec;

use core::ops;
use core::ptr::NonNull;

/// Inner state of [`Event`].
pub(crate) struct Inner {
@@ -25,14 +24,6 @@ pub(crate) struct Inner {

/// Queue of nodes waiting to be processed.
queue: Queue,

/// A single cached list entry to avoid allocations on the fast path of the insertion.
///
/// This field can only be written to when the `cache_used` field in the `list` structure
/// is false, or the user has a pointer to the `Entry` identical to this one and that user
/// has exclusive access to that `Entry`. An immutable pointer to this field is kept in
/// the `list` structure when it is in use.
cache: UnsafeCell<Entry>,
}

impl Inner {
@@ -42,7 +33,6 @@ impl Inner {
notified: AtomicUsize::new(core::usize::MAX),
list: Mutex::new(List::new()),
queue: Queue::new(),
cache: UnsafeCell::new(Entry::new()),
}
}

@@ -62,12 +52,6 @@ impl Inner {
// Acquire and drop the lock to make sure that the queue is flushed.
let _guard = self.lock();
}

/// Returns the pointer to the single cached list entry.
#[inline(always)]
pub(crate) fn cache_ptr(&self) -> NonNull<Entry> {
unsafe { NonNull::new_unchecked(self.cache.get()) }
}
}

/// The guard returned by [`Inner::lock`].
@@ -88,11 +72,11 @@ impl ListGuard<'_> {
guard: &mut MutexGuard<'_, List>,
) {
// Process the start node.
tasks.extend(start_node.apply(guard, self.inner));
tasks.extend(start_node.apply(guard));

// Process all remaining nodes.
while let Some(node) = self.inner.queue.pop() {
tasks.extend(node.apply(guard, self.inner));
tasks.extend(node.apply(guard));
}
}
}
@@ -125,7 +109,7 @@ impl Drop for ListGuard<'_> {
}

// Update the atomic `notified` counter.
let notified = if list.notified < list.len {
let notified = if list.notified < list.len() {
list.notified
} else {
core::usize::MAX
@@ -224,3 +208,6 @@ impl<'a, T> ops::DerefMut for MutexGuard<'a, T> {
unsafe { &mut *self.mutex.value.get() }
}
}

unsafe impl<T: Send> Send for Mutex<T> {}
unsafe impl<T: Send> Sync for Mutex<T> {}
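A hedged sketch of the reasoning behind the two impls added above, moved here from `Event` (simplified, not the crate's exact code): the hand-rolled `Mutex` wraps an `UnsafeCell`, which is `!Sync`, so the lock itself has to assert thread-safety, bounded by `T: Send`.

```rust
use core::cell::UnsafeCell;
use core::sync::atomic::AtomicBool;

// Simplified stand-in for the crate's internal lock.
pub struct Mutex<T> {
    locked: AtomicBool,
    value: UnsafeCell<T>,
}

// The `locked` flag guarantees exclusive access to `value`, so sharing the
// mutex across threads is sound whenever the protected value is `Send`.
// Keeping the unsafe impls on the lock, rather than as blanket impls on
// `Event`, narrows the amount of code the safety argument has to cover.
unsafe impl<T: Send> Send for Mutex<T> {}
unsafe impl<T: Send> Sync for Mutex<T> {}
```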
158 changes: 92 additions & 66 deletions src/lib.rs
@@ -78,9 +78,10 @@ use alloc::sync::Arc;

use core::fmt;
use core::future::Future;
use core::mem::ManuallyDrop;
use core::mem::{self, ManuallyDrop};
use core::num::NonZeroUsize;
use core::pin::Pin;
use core::ptr::{self, NonNull};
use core::ptr;
use core::sync::atomic::{self, AtomicPtr, AtomicUsize, Ordering};
use core::task::{Context, Poll, Waker};
use core::usize;
@@ -92,7 +93,7 @@ use std::time::{Duration, Instant};

use inner::Inner;
use list::{Entry, State};
use node::Node;
use node::{Node, TaskWaiting};

#[cfg(feature = "std")]
use parking::Unparker;
@@ -168,9 +169,6 @@ pub struct Event {
inner: AtomicPtr<Inner>,
}

unsafe impl Send for Event {}
unsafe impl Sync for Event {}

#[cfg(feature = "std")]
impl UnwindSafe for Event {}
#[cfg(feature = "std")]
@@ -210,31 +208,31 @@ impl Event {
let inner = self.inner();

// Try to acquire a lock in the inner list.
let entry = unsafe {
if let Some(mut lock) = (*inner).lock() {
let entry = lock.alloc((*inner).cache_ptr());
lock.insert(entry);
let state = {
let inner = unsafe { &*inner };
if let Some(mut lock) = inner.lock() {
let entry = lock.insert(Entry::new());

entry
ListenerState::HasNode(entry)
} else {
// Push entries into the queue indicating that we want to push a listener.
let (node, entry) = Node::listener();
(*inner).push(node);
inner.push(node);

// Indicate that there are nodes waiting to be notified.
(*inner)
inner
.notified
.compare_exchange(usize::MAX, 0, Ordering::AcqRel, Ordering::Relaxed)
.ok();

entry
ListenerState::Queued(entry)
}
};

// Register the listener.
let listener = EventListener {
inner: unsafe { Arc::clone(&ManuallyDrop::new(Arc::from_raw(inner))) },
entry: Some(entry),
state,
};

// Make sure the listener is registered before whatever happens next.
@@ -529,12 +527,20 @@ pub struct EventListener {
/// A reference to [`Event`]'s inner state.
inner: Arc<Inner>,

/// A pointer to this listener's entry in the linked list.
entry: Option<NonNull<Entry>>,
/// The current state of the listener.
state: ListenerState,
}

unsafe impl Send for EventListener {}
unsafe impl Sync for EventListener {}
enum ListenerState {
/// The listener has a node inside of the linked list.
HasNode(NonZeroUsize),

/// The listener has already been notified and has discarded its entry.
Discarded,

/// The listener has an entry in the queue that may or may not have a task waiting.
Queued(Arc<TaskWaiting>),
}

#[cfg(feature = "std")]
impl UnwindSafe for EventListener {}
@@ -605,11 +611,26 @@ impl EventListener {

fn wait_internal(mut self, deadline: Option<Instant>) -> bool {
// Take out the entry pointer and set it to `None`.
let entry = match self.entry.take() {
None => unreachable!("cannot wait twice on an `EventListener`"),
Some(entry) => entry,
};
let (parker, unparker) = parking::pair();
let entry = match self.state.take() {
ListenerState::HasNode(entry) => entry,
ListenerState::Queued(task_waiting) => {
// This listener is stuck in the backup queue.
// Wait for the task to be notified.
loop {
match task_waiting.status() {
Some(entry_id) => break entry_id,
None => {
// Register a task and park until it is notified.
task_waiting.register(Task::Thread(unparker.clone()));

parker.park();
}
}
}
}
ListenerState::Discarded => panic!("Cannot wait on a discarded listener"),
};

// Wait for the lock to be available.
let lock = || {
Expand All @@ -628,22 +649,15 @@ impl EventListener {

// Set this listener's state to `Waiting`.
{
let e = unsafe { entry.as_ref() };

if e.is_queued() {
// Write a task to be woken once the lock is acquired.
e.write_task(Task::Thread(unparker));
} else {
let mut list = lock();
let mut list = lock();

// If the listener was notified, we're done.
match e.state().replace(State::Notified(false)) {
State::Notified(_) => {
list.remove(entry, self.inner.cache_ptr());
return true;
}
_ => e.state().set(State::Task(Task::Thread(unparker))),
// If the listener was notified, we're done.
match list.state(entry).replace(State::Notified(false)) {
State::Notified(_) => {
list.remove(entry);
return true;
}
_ => list.state(entry).set(State::Task(Task::Thread(unparker))),
}
}

@@ -658,7 +672,7 @@ impl EventListener {
if now >= deadline {
// Remove the entry and check if notified.
let mut list = lock();
let state = list.remove(entry, self.inner.cache_ptr());
let state = list.remove(entry);
return state.is_notified();
}

@@ -668,17 +682,16 @@ impl EventListener {
}

let mut list = lock();
let e = unsafe { entry.as_ref() };

// Do a dummy replace operation in order to take out the state.
match e.state().replace(State::Notified(false)) {
match list.state(entry).replace(State::Notified(false)) {
State::Notified(_) => {
// If this listener has been notified, remove it from the list and return.
list.remove(entry, self.inner.cache_ptr());
list.remove(entry);
return true;
}
// Otherwise, set the state back to `Waiting`.
state => e.state().set(state),
state => list.state(entry).set(state),
}
}
}
@@ -706,10 +719,10 @@ impl EventListener {
/// ```
pub fn discard(mut self) -> bool {
// If this listener has never picked up a notification...
if let Some(entry) = self.entry.take() {
if let ListenerState::HasNode(entry) = self.state.take() {
// Remove the listener from the list and return `true` if it was notified.
if let Some(mut lock) = self.inner.lock() {
let state = lock.remove(entry, self.inner.cache_ptr());
let state = lock.remove(entry);

if let State::Notified(_) = state {
return true;
@@ -772,6 +785,30 @@ impl Future for EventListener {

#[allow(unreachable_patterns)]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let entry = match self.state {
ListenerState::Discarded => {
unreachable!("cannot poll a completed `EventListener` future")
}
ListenerState::HasNode(ref entry) => *entry,
ListenerState::Queued(ref task_waiting) => {
loop {
// See if the task waiting has been completed.
match task_waiting.status() {
Some(entry_id) => {
self.state = ListenerState::HasNode(entry_id);
break entry_id;
}
None => {
// If not, wait for it to complete.
task_waiting.register(Task::Waker(cx.waker().clone()));
if task_waiting.status().is_none() {
return Poll::Pending;
}
}
}
}
}
};
let mut list = match self.inner.lock() {
Some(list) => list,
None => {
@@ -787,20 +824,15 @@ }
}
}
};

let entry = match self.entry {
None => unreachable!("cannot poll a completed `EventListener` future"),
Some(entry) => entry,
};
let state = unsafe { entry.as_ref().state() };
let state = list.state(entry);

// Do a dummy replace operation in order to take out the state.
match state.replace(State::Notified(false)) {
State::Notified(_) => {
// If this listener has been notified, remove it from the list and return.
list.remove(entry, self.inner.cache_ptr());
list.remove(entry);
drop(list);
self.entry = None;
self.state = ListenerState::Discarded;
return Poll::Ready(());
}
State::Created => {
@@ -827,12 +859,11 @@ impl Future for EventListener {
impl Drop for EventListener {
fn drop(&mut self) {
// If this listener has never picked up a notification...
if let Some(entry) = self.entry.take() {
if let ListenerState::HasNode(entry) = self.state.take() {
match self.inner.lock() {
Some(mut list) => {
// But if a notification was delivered to it...
if let State::Notified(additional) = list.remove(entry, self.inner.cache_ptr())
{
if let State::Notified(additional) = list.remove(entry) {
// Then pass it on to another active listener.
list.notify(1, additional);
}
Expand All @@ -849,6 +880,12 @@ impl Drop for EventListener {
}
}

impl ListenerState {
fn take(&mut self) -> Self {
mem::replace(self, ListenerState::Discarded)
}
}

/// Equivalent to `atomic::fence(Ordering::SeqCst)`, but in some cases faster.
#[inline]
fn full_fence() {
Expand Down Expand Up @@ -877,17 +914,6 @@ fn full_fence() {
}
}

/// Indicate that we're using spin-based contention and that we should yield the CPU.
#[inline]
fn yield_now() {
#[cfg(feature = "std")]
std::thread::yield_now();

#[cfg(not(feature = "std"))]
#[allow(deprecated)]
sync::atomic::spin_loop_hint();
}

#[cfg(any(feature = "__test", test))]
impl Event {
/// Locks the event.
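For orientation, a short example of the crate's public API, which this commit leaves unchanged; only the internal storage behind `listen()` differs. This is a sketch assuming the published `event-listener` API.

```rust
use std::sync::Arc;
use std::thread;

use event_listener::Event;

fn main() {
    let event = Arc::new(Event::new());

    // Register interest first so the notification below cannot be missed.
    let listener = event.listen();

    let notifier = Arc::clone(&event);
    thread::spawn(move || {
        // Wake exactly one registered listener.
        notifier.notify(1);
    });

    // Block the current thread until notified.
    listener.wait();
}
```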
