Skip to content

Commit fd350e4

Browse files
authored
Use trylock to eliminate the remaining race condition in Test.cancel(). (#1415)
This PR fixes the race condition in `Test.cancel()` that could occur if an unstructured task, created from within a test's task, called `Test.cancel()` at just the right moment. The order of events for the race is: - Unstructured task is created and inherits task-locals including the reference to the test's unsafe current task; - Test's task starts tearing down; - Unstructured task calls `takeUnsafeCurrentTask()` and gets a reference to the unsafe current task; - Test's task finishes tearing down; - Unstructured task calls `UnsafeCurrentTask.cancel()`. The fix is to use `trylock` semantics when cancelling the unsafe current task. If the test's task is still alive, the task is cancelled while the lock is held, which will block the test's task from being torn down as it has a lock-guarded call to clear the unsafe current task reference. If the test's task is no longer alive, the reference is already `nil` by the time the unstructured task acquires the lock and it bails early. If we recursively call `cancel()` (which can happen via the concurrency-level cancellation handler), the `trylock` means we won't acquire the lock a second time, so we won't end up deadlocking or aborting (which is what prevents calling `cancel()` while holding the lock in the current implementation). It is possible for `cancel()` to trigger user code, especially if the user has set up a cancellation handler, but there is no code path that can then lead to a deadlock because the only user-accessible calls that might touch this lock use `trylock`. I hope some part of that made sense. ### Checklist: - [x] Code and documentation should follow the style of the [Style Guide](https://github.com/apple/swift-testing/blob/main/Documentation/StyleGuide.md). - [x] If public symbols are renamed or modified, DocC references should be updated.
1 parent b0aef48 commit fd350e4

File tree

2 files changed

+120
-38
lines changed

2 files changed

+120
-38
lines changed

Sources/Testing/Support/Locked.swift

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ struct Locked<T> {
2626
/// A type providing storage for the underlying lock and wrapped value.
2727
#if SWT_TARGET_OS_APPLE && canImport(os)
2828
private typealias _Storage = ManagedBuffer<T, os_unfair_lock_s>
29+
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
30+
private final class _Storage: ManagedBuffer<T, pthread_mutex_t> {
31+
deinit {
32+
withUnsafeMutablePointerToElements { lock in
33+
_ = pthread_mutex_destroy(lock)
34+
}
35+
}
36+
}
2937
#else
3038
private final class _Storage {
3139
let mutex: Mutex<T>
@@ -49,6 +57,11 @@ extension Locked: RawRepresentable {
4957
_storage.withUnsafeMutablePointerToElements { lock in
5058
lock.initialize(to: .init())
5159
}
60+
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
61+
_storage = _Storage.create(minimumCapacity: 1, makingHeaderWith: { _ in rawValue }) as! _Storage
62+
_storage.withUnsafeMutablePointerToElements { lock in
63+
_ = pthread_mutex_init(lock, nil)
64+
}
5265
#else
5366
nonisolated(unsafe) let rawValue = rawValue
5467
_storage = _Storage(rawValue)
@@ -77,20 +90,72 @@ extension Locked {
7790
/// synchronous caller. Wherever possible, use actor isolation or other Swift
7891
/// concurrency tools.
7992
func withLock<R>(_ body: (inout T) throws -> sending R) rethrows -> sending R where R: ~Copyable {
93+
nonisolated(unsafe) let result: R
8094
#if SWT_TARGET_OS_APPLE && canImport(os)
81-
nonisolated(unsafe) let result = try _storage.withUnsafeMutablePointers { rawValue, lock in
95+
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
8296
os_unfair_lock_lock(lock)
8397
defer {
8498
os_unfair_lock_unlock(lock)
8599
}
86100
return try body(&rawValue.pointee)
87101
}
88-
return result
102+
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
103+
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
104+
pthread_mutex_lock(lock)
105+
defer {
106+
pthread_mutex_unlock(lock)
107+
}
108+
return try body(&rawValue.pointee)
109+
}
89110
#else
90-
try _storage.mutex.withLock { rawValue in
111+
result = try _storage.mutex.withLock { rawValue in
91112
try body(&rawValue)
92113
}
93114
#endif
115+
return result
116+
}
117+
118+
/// Try to acquire the lock and invoke a function while it is held.
119+
///
120+
/// - Parameters:
121+
/// - body: A closure to invoke while the lock is held.
122+
///
123+
/// - Returns: Whatever is returned by `body`, or `nil` if the lock could not
124+
/// be acquired.
125+
///
126+
/// - Throws: Whatever is thrown by `body`.
127+
///
128+
/// This function can be used to synchronize access to shared data from a
129+
/// synchronous caller. Wherever possible, use actor isolation or other Swift
130+
/// concurrency tools.
131+
func withLockIfAvailable<R>(_ body: (inout T) throws -> sending R) rethrows -> sending R? where R: ~Copyable {
132+
nonisolated(unsafe) let result: R?
133+
#if SWT_TARGET_OS_APPLE && canImport(os)
134+
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
135+
guard os_unfair_lock_trylock(lock) else {
136+
return nil
137+
}
138+
defer {
139+
os_unfair_lock_unlock(lock)
140+
}
141+
return try body(&rawValue.pointee)
142+
}
143+
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
144+
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
145+
guard 0 == pthread_mutex_trylock(lock) else {
146+
return nil
147+
}
148+
defer {
149+
pthread_mutex_unlock(lock)
150+
}
151+
return try body(&rawValue.pointee)
152+
}
153+
#else
154+
result = try _storage.mutex.withLockIfAvailable { rawValue in
155+
return try body(&rawValue)
156+
}
157+
#endif
158+
return result
94159
}
95160
}
96161

Sources/Testing/Test+Cancellation.swift

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,8 @@ protocol TestCancellable: Sendable {
2525

2626
// MARK: - Tracking the current task
2727

28-
/// A structure describing a reference to a task that is associated with some
29-
/// ``TestCancellable`` value.
30-
private struct _TaskReference: Sendable {
28+
/// A structure that is able to cancel a task.
29+
private struct _TaskCanceller: Sendable {
3130
/// The unsafe underlying reference to the associated task.
3231
private nonisolated(unsafe) var _unsafeCurrentTask = Locked<UnsafeCurrentTask?>()
3332

@@ -45,25 +44,46 @@ private struct _TaskReference: Sendable {
4544
_unsafeCurrentTask = withUnsafeCurrentTask { Locked(rawValue: $0) }
4645
}
4746

48-
/// Take this instance's reference to its associated task.
49-
///
50-
/// - Returns: An `UnsafeCurrentTask` instance, or `nil` if it was already
51-
/// taken or if it was never available.
52-
///
53-
/// This function consumes the reference to the task. After the first call,
54-
/// subsequent calls on the same instance return `nil`.
55-
func takeUnsafeCurrentTask() -> UnsafeCurrentTask? {
47+
/// Clear this instance's reference to its associated task without first
48+
/// cancelling it.
49+
func clear() {
5650
_unsafeCurrentTask.withLock { unsafeCurrentTask in
57-
let result = unsafeCurrentTask
5851
unsafeCurrentTask = nil
59-
return result
6052
}
6153
}
54+
55+
/// Cancel this instance's associated task and clear the reference to it.
56+
///
57+
/// - Returns: Whether or not this instance's task was cancelled.
58+
///
59+
/// After the first call to this function _starts_, subsequent calls on the
60+
/// same instance return `false`. In other words, if another thread calls this
61+
/// function before it has returned (or the same thread calls it recursively),
62+
/// it returns `false` without cancelling the task a second time.
63+
func cancel(with skipInfo: SkipInfo) -> Bool {
64+
// trylock means a recursive call to this function won't ruin our day, nor
65+
// should interleaving locks.
66+
_unsafeCurrentTask.withLockIfAvailable { unsafeCurrentTask in
67+
defer {
68+
unsafeCurrentTask = nil
69+
}
70+
if let unsafeCurrentTask {
71+
// The task is still valid, so we'll cancel it.
72+
$_currentSkipInfo.withValue(skipInfo) {
73+
unsafeCurrentTask.cancel()
74+
}
75+
return true
76+
}
77+
78+
// The task has already been cancelled and/or cleared.
79+
return false
80+
} ?? false
81+
}
6282
}
6383

64-
/// A dictionary of tracked tasks, keyed by types that conform to
84+
/// A dictionary of cancellable tasks keyed by types that conform to
6585
/// ``TestCancellable``.
66-
@TaskLocal private var _currentTaskReferences = [ObjectIdentifier: _TaskReference]()
86+
@TaskLocal private var _currentTaskCancellers = [ObjectIdentifier: _TaskCanceller]()
6787

6888
/// The instance of ``SkipInfo`` to propagate to children of the current task.
6989
///
@@ -87,16 +107,15 @@ extension TestCancellable {
87107
/// the current task, test, or test case is cancelled, it records a
88108
/// corresponding cancellation event.
89109
func withCancellationHandling<R>(_ body: () async throws -> R) async rethrows -> R {
90-
let taskReference = _TaskReference()
91-
var currentTaskReferences = _currentTaskReferences
92-
currentTaskReferences[ObjectIdentifier(Self.self)] = taskReference
93-
return try await $_currentTaskReferences.withValue(currentTaskReferences) {
94-
// Before returning, explicitly clear the stored task. This minimizes
95-
// the potential race condition that can occur if test code creates an
96-
// unstructured task and calls `Test.cancel()` in it after the test body
97-
// has finished.
110+
let taskCanceller = _TaskCanceller()
111+
var currentTaskCancellers = _currentTaskCancellers
112+
currentTaskCancellers[ObjectIdentifier(Self.self)] = taskCanceller
113+
return try await $_currentTaskCancellers.withValue(currentTaskCancellers) {
114+
// Before returning, explicitly clear the stored task so that an
115+
// unstructured task that inherits the task local isn't able to
116+
// accidentally cancel the task after it has been deallocated.
98117
defer {
99-
_ = taskReference.takeUnsafeCurrentTask()
118+
taskCanceller.clear()
100119
}
101120

102121
return try await withTaskCancellationHandler {
@@ -121,18 +140,16 @@ extension TestCancellable {
121140
/// - testAndTestCase: The test and test case to use when posting an event.
122141
/// - skipInfo: Information about the cancellation event.
123142
private func _cancel<T>(_ cancellableValue: T?, for testAndTestCase: (Test?, Test.Case?), skipInfo: SkipInfo) where T: TestCancellable {
124-
if cancellableValue != nil {
125-
// If the current test case is still running, take its task property (which
126-
// signals to subsequent callers that it has been cancelled.)
127-
let task = _currentTaskReferences[ObjectIdentifier(T.self)]?.takeUnsafeCurrentTask()
128-
129-
// If we just cancelled the current test case's task, post a corresponding
130-
// event with the relevant skip info.
131-
if let task {
132-
$_currentSkipInfo.withValue(skipInfo) {
133-
task.cancel()
134-
}
143+
if cancellableValue != nil, let taskCanceller = _currentTaskCancellers[ObjectIdentifier(T.self)] {
144+
// Try to cancel the task associated with `T`, if any. If we succeed, post a
145+
// corresponding event with the relevant skip info. If we fail, we still
146+
// attempt to cancel the current *task* in order to honor our API contract.
147+
if taskCanceller.cancel(with: skipInfo) {
135148
Event.post(T.makeCancelledEventKind(with: skipInfo), for: testAndTestCase)
149+
} else {
150+
withUnsafeCurrentTask { task in
151+
task?.cancel()
152+
}
136153
}
137154
} else {
138155
// The current task isn't associated with a test/case, so just cancel the

0 commit comments

Comments
 (0)