Skip to content

Commit

Permalink
Merge branch 'main' into optional-async
Browse files Browse the repository at this point in the history
  • Loading branch information
palfrey committed Aug 6, 2022
2 parents ad95b26 + d108a62 commit c2b0024
Show file tree
Hide file tree
Showing 12 changed files with 321 additions and 164 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ extern crate serial_test;
for earlier versions.

You can then either add `#[serial]` or `#[serial(some_text)]` to tests as required.

For each test, a timeout can be specified with the `timeout_ms` parameter to the [serial](macro@serial) attribute. Note that
the timeout is counted from the first invocation of the test, not from the time the previous test was completed. This can
lead to some unpredictable behavior based on the number of parallel tests run on the system.
```rust
#[test]
#[serial(timeout_ms = 1000)]
fn test_serial_one() {
// Do things
}
#[test]
#[serial(test_name, timeout_ms = 1000)]
fn test_serial_another() {
// Do things
}
```
7 changes: 2 additions & 5 deletions serial_test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ log = { version = "0.4", optional = true }
futures = { version = "^0.3", default_features = false, features = [
"executor",
], optional = true}
dashmap = { version = "5"}

[dev-dependencies]
itertools = "0.10"
tokio = { version = "^1.17", features = ["macros", "rt"] }

[features]
default = ["logging", "timeout", "async"]
default = ["logging", "async"]

## Switches on debug logging (and requires the `log` package)
logging = ["log"]
Expand All @@ -37,10 +38,6 @@ async = ["futures"]
## The file_locks feature unlocks the `file_serial`/`file_parallel` macros
file_locks = ["fslock"]

## The `timeout` feature lets tests time out after a certain amount of time
## if not enabled tests will wait indefinitely to be started
timeout = []

docsrs = ["document-features"]

# docs.rs-specific configuration
Expand Down
70 changes: 20 additions & 50 deletions serial_test/src/code_lock.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
use crate::rwlock::{Locks, MutexGuardWrapper};
use dashmap::{try_result::TryResult, DashMap};
use lazy_static::lazy_static;
#[cfg(all(feature = "logging", feature = "timeout"))]
#[cfg(all(feature = "logging"))]
use log::debug;
use parking_lot::RwLock;
#[cfg(feature = "timeout")]
use std::time::Instant;
use std::{
collections::HashMap,
ops::{Deref, DerefMut},
sync::{atomic::AtomicU32, Arc},
time::Duration,
time::{Duration, Instant},
};

pub(crate) struct UniqueReentrantMutex {
Expand Down Expand Up @@ -45,16 +41,11 @@ impl UniqueReentrantMutex {
}

lazy_static! {
pub(crate) static ref LOCKS: Arc<RwLock<HashMap<String, UniqueReentrantMutex>>> =
Arc::new(RwLock::new(HashMap::new()));
pub(crate) static ref LOCKS: Arc<DashMap<String, UniqueReentrantMutex>> =
Arc::new(DashMap::new());
static ref MUTEX_ID: Arc<AtomicU32> = Arc::new(AtomicU32::new(1));
}

#[cfg(feature = "timeout")]
lazy_static! {
static ref MAX_WAIT: Arc<RwLock<Duration>> = Arc::new(RwLock::new(Duration::from_secs(60)));
}

impl Default for UniqueReentrantMutex {
fn default() -> Self {
Self {
Expand All @@ -64,60 +55,39 @@ impl Default for UniqueReentrantMutex {
}
}

/// Sets the maximum amount of time the serial locks will wait to unlock.
/// By default, this is set to 60 seconds, which is almost always much longer than is needed.
/// This is deliberately set high to try and avoid situations where we accidentally hit the limits
/// but is set at all so we can timeout rather than hanging forever.
///
/// However, sometimes if you've got a *lot* of serial tests it might theoretically not be enough,
/// hence this method.
///
/// This function is only available when the `timeout` feature is enabled.
#[cfg(feature = "timeout")]
pub fn set_max_wait(max_wait: Duration) {
*MAX_WAIT.write() = max_wait;
}

#[cfg(feature = "timeout")]
pub(crate) fn wait_duration() -> Duration {
*MAX_WAIT.read()
}

pub(crate) fn check_new_key(name: &str) {
#[cfg(feature = "timeout")]
pub(crate) fn check_new_key(name: &str, max_wait: Option<Duration>) {
let start = Instant::now();
loop {
#[cfg(all(feature = "logging", feature = "timeout"))]
#[cfg(all(feature = "logging"))]
{
let duration = start.elapsed();
debug!("Waiting for '{}' {:?}", name, duration);
}
// Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else
let try_unlock = LOCKS.try_read_recursive_for(Duration::from_secs(1));
if let Some(unlock) = try_unlock {
if unlock.deref().contains_key(name) {
match LOCKS.try_get(name) {
TryResult::Present(_) => {
return;
}
drop(unlock); // so that we don't hold the read lock and so the writer can maybe succeed
} else {
continue; // wasn't able to get read lock
}
TryResult::Locked => {
continue; // wasn't able to get read lock
}
TryResult::Absent => {} // do the write path below
};

// This is the rare path, which avoids the multi-writer situation mostly
let try_lock = LOCKS.try_write_for(Duration::from_secs(1));
let try_entry = LOCKS.try_entry(name.to_string());

if let Some(mut lock) = try_lock {
lock.deref_mut().entry(name.to_string()).or_default();
if let Some(entry) = try_entry {
entry.or_default();
return;
}

// If the try_lock fails, then go around the loop again
// If the try_entry fails, then go around the loop again
// Odds are another test was also locking on the write and has now written the key

#[cfg(feature = "timeout")]
{
if let Some(max_wait) = max_wait {
let duration = start.elapsed();
if duration > wait_duration() {
if duration > max_wait {
panic!("Timeout waiting for '{}' {:?}", name, duration);
}
}
Expand Down
3 changes: 0 additions & 3 deletions serial_test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ mod parallel_file_lock;
#[cfg(feature = "file_locks")]
mod serial_file_lock;

#[cfg(feature = "timeout")]
pub use code_lock::set_max_wait;

#[cfg(feature = "async")]
pub use parallel_code_lock::{local_async_parallel_core, local_async_parallel_core_with_return};

Expand Down
72 changes: 43 additions & 29 deletions serial_test/src/parallel_code_lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
use crate::code_lock::{check_new_key, LOCKS};
#[cfg(feature = "async")]
use futures::FutureExt;
use std::{ops::Deref, panic};
use std::{panic, time::Duration};

#[doc(hidden)]
pub fn local_parallel_core_with_return<E>(
name: &str,
max_wait: Option<Duration>,
function: fn() -> Result<(), E>,
) -> Result<(), E> {
check_new_key(name);
check_new_key(name, max_wait);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = panic::catch_unwind(function);
unlock.deref()[name].end_parallel();
lock.end_parallel();
match res {
Ok(ret) => ret,
Err(err) => {
Expand All @@ -25,15 +26,15 @@ pub fn local_parallel_core_with_return<E>(
}

#[doc(hidden)]
pub fn local_parallel_core(name: &str, function: fn()) {
check_new_key(name);
pub fn local_parallel_core(name: &str, max_wait: Option<Duration>, function: fn()) {
check_new_key(name, max_wait);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = panic::catch_unwind(|| {
function();
});
unlock.deref()[name].end_parallel();
lock.end_parallel();
if let Err(err) = res {
panic::resume_unwind(err);
}
Expand All @@ -43,14 +44,15 @@ pub fn local_parallel_core(name: &str, function: fn()) {
#[cfg(feature = "async")]
pub async fn local_async_parallel_core_with_return<E>(
name: &str,
max_wait: Option<Duration>,
fut: impl std::future::Future<Output = Result<(), E>> + panic::UnwindSafe,
) -> Result<(), E> {
check_new_key(name);
check_new_key(name, max_wait);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = fut.catch_unwind().await;
unlock.deref()[name].end_parallel();
lock.end_parallel();
match res {
Ok(ret) => ret,
Err(err) => {
Expand All @@ -63,14 +65,15 @@ pub async fn local_async_parallel_core_with_return<E>(
#[cfg(feature = "async")]
pub async fn local_async_parallel_core(
name: &str,
max_wait: Option<Duration>,
fut: impl std::future::Future<Output = ()> + panic::UnwindSafe,
) {
check_new_key(name);
check_new_key(name, max_wait);

let unlock = LOCKS.read_recursive();
unlock.deref()[name].start_parallel();
let lock = LOCKS.get(name).unwrap();
lock.start_parallel();
let res = fut.catch_unwind().await;
unlock.deref()[name].end_parallel();
lock.end_parallel();
if let Err(err) = res {
panic::resume_unwind(err);
}
Expand All @@ -82,18 +85,20 @@ mod tests {
use crate::{local_async_parallel_core, local_async_parallel_core_with_return};

use crate::{code_lock::LOCKS, local_parallel_core, local_parallel_core_with_return};
use std::{io::Error, ops::Deref, panic};
use std::{io::Error, panic};

#[test]
fn unlock_on_assert_sync_without_return() {
let _ = panic::catch_unwind(|| {
local_parallel_core("unlock_on_assert_sync_without_return", || {
local_parallel_core("unlock_on_assert_sync_without_return", None, || {
assert!(false);
})
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_sync_without_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_sync_without_return")
.unwrap()
.parallel_count(),
0
);
}
Expand All @@ -103,15 +108,18 @@ mod tests {
let _ = panic::catch_unwind(|| {
local_parallel_core_with_return(
"unlock_on_assert_sync_with_return",
None,
|| -> Result<(), Error> {
assert!(false);
Ok(())
},
)
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_sync_with_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_sync_with_return")
.unwrap()
.parallel_count(),
0
);
}
Expand All @@ -123,17 +131,20 @@ mod tests {
assert!(false);
}
async fn call_serial_test_fn() {
local_async_parallel_core("unlock_on_assert_async_without_return", demo_assert()).await
local_async_parallel_core("unlock_on_assert_async_without_return", None, demo_assert())
.await
}
// as per https://stackoverflow.com/a/66529014/320546
let _ = panic::catch_unwind(|| {
let handle = tokio::runtime::Handle::current();
let _enter_guard = handle.enter();
futures::executor::block_on(call_serial_test_fn());
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_async_without_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_async_without_return")
.unwrap()
.parallel_count(),
0
);
}
Expand All @@ -150,6 +161,7 @@ mod tests {
async fn call_serial_test_fn() {
local_async_parallel_core_with_return(
"unlock_on_assert_async_with_return",
None,
demo_assert(),
)
.await;
Expand All @@ -161,9 +173,11 @@ mod tests {
let _enter_guard = handle.enter();
futures::executor::block_on(call_serial_test_fn());
});
let unlock = LOCKS.read_recursive();
assert_eq!(
unlock.deref()["unlock_on_assert_async_with_return"].parallel_count(),
LOCKS
.get("unlock_on_assert_async_with_return")
.unwrap()
.parallel_count(),
0
);
}
Expand Down
Loading

0 comments on commit c2b0024

Please sign in to comment.