Skip to content

Commit

Permalink
add get_or_try_init
Browse files Browse the repository at this point in the history
  • Loading branch information
b-naber committed Mar 31, 2021
1 parent cdc3edc commit 9b6d749
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 2 deletions.
63 changes: 62 additions & 1 deletion tokio/src/sync/once_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ impl<T> OnceCell<T> {
unsafe { self.get_unchecked() }
} else {
// After acquire().await we have either acquired a permit while self.value
// is still uninitialized, or current thread is awoken after another thread
// is still uninitialized, or the current thread is awoken after another thread
// has intialized the value and closed the semaphore, in which case self.initialized
// is true and we don't set the value here
match self.semaphore.acquire().await {
Expand Down Expand Up @@ -258,6 +258,67 @@ impl<T> OnceCell<T> {
}
}

/// Tries to initialize the value of the OnceCell using the async function `f`.
/// If the value of the OnceCell was already initialized prior to this call,
/// a reference to that initialized value is returned. If some other thread
/// initiated the initialization prior to this call and the initialization
/// hasn't completed, this call waits until the initialization is finished.
/// If the function argument `f` returns an error, `get_or_try_init`
/// returns that error, otherwise the result of `f` will be stored in the cell.
///
/// This will deadlock if `f` tries to initialize the cell itself.
pub async fn get_or_try_init<E, F, Fut>(&self, f: F) -> Result<&T, E>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, E>>,
{
if self.initialized() {
// SAFETY: once the value is initialized, no mutable references are given out, so
// we can give out arbitrarily many immutable references
unsafe { Ok(self.get_unchecked()) }
} else {
// After acquire().await we have either acquired a permit while self.value
// is still uninitialized, or the current thread is awoken after another thread
// has intialized the value and closed the semaphore, in which case self.initialized
// is true and we don't set the value here
match self.semaphore.acquire().await {
Ok(_permit) => {
if !self.initialized() {
// If `f()` panics or `select!` is called, this `get_or_try_init` call
// is aborted and the semaphore permit is dropped.
let value = f().await;

match value {
Ok(value) => {
// SAFETY: There is only one permit on the semaphore, hence only one
// mutable reference is created
unsafe { self.set_value(value) };

// SAFETY: once the value is initialized, no mutable references are given out, so
// we can give out arbitrarily many immutable references
unsafe { Ok(self.get_unchecked()) }
},
Err(e) => Err(e),
}
} else {
unreachable!("acquired semaphore after value was already initialized.");
}
}
Err(_) => {
if self.initialized() {
// SAFETY: once the value is initialized, no mutable references are given out, so
// we can give out arbitrarily many immutable references
unsafe { Ok(self.get_unchecked()) }
} else {
unreachable!(
"Semaphore closed, but the OnceCell has not been initialized."
);
}
}
}
}
}

/// Moves the value out of the cell, destroying the cell in the process.
///
/// Returns `None` if the cell is uninitialized.
Expand Down
36 changes: 35 additions & 1 deletion tokio/tests/sync_once_cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ async fn func2() -> u32 {
10
}

async fn func_err() -> Result<u32, ()> {
Err(())
}

async fn func_ok() -> Result<u32, ()> {
Ok(10)
}

async fn func_panic() -> u32 {
time::sleep(Duration::from_millis(1)).await;
panic!();
Expand Down Expand Up @@ -144,6 +152,31 @@ fn set_while_initializing() {
});
}

#[test]
fn get_or_try_init() {
let rt = runtime::Builder::new_current_thread()
.enable_time()
.start_paused(true)
.build()
.unwrap();

static ONCE: OnceCell<u32> = OnceCell::const_new();

rt.block_on(async {
let handle1 = rt.spawn(async { ONCE.get_or_try_init(func_err).await });
let handle2 = rt.spawn(async { ONCE.get_or_try_init(func_ok).await });

time::advance(Duration::from_millis(1)).await;
time::resume();

let result1 = handle1.await;
assert!(result1.is_err());

let result2 = handle2.await.unwrap();
assert_eq!(*result2.unwrap(), 10);
});
}

#[test]
fn drop_cell() {
static NUM_DROPS: AtomicU32 = AtomicU32::new(0);
Expand All @@ -161,7 +194,7 @@ fn drop_cell() {
{
let once_cell = OnceCell::new();
let prev = once_cell.set(fooer);
assert!(prev == ())
assert!(prev.is_ok())
}
assert!(NUM_DROPS.load(Ordering::Acquire) == 1);
}
Expand All @@ -182,6 +215,7 @@ fn drop_cell_new_with() {

{
let once_cell = OnceCell::new_with(Some(fooer));
assert!(once_cell.initialized());
}
assert!(NUM_DROPS.load(Ordering::Acquire) == 1);
}
Expand Down

0 comments on commit 9b6d749

Please sign in to comment.