Skip to content

Commit

Permalink
Make Room.connect cancellable (livekit#273)
Browse files Browse the repository at this point in the history
* engine connect

* connect flow

* cancellable completer

* cancellable WebSocket

* completer cancel test

* comment

* check cancel for queue actor
  • Loading branch information
hiroshihorie authored Nov 9, 2023
1 parent 3a07312 commit 29d61f1
Show file tree
Hide file tree
Showing 12 changed files with 121 additions and 61 deletions.
1 change: 0 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ let package = Package(
.macOS(.v10_15),
],
products: [
// Products define the executables and libraries a package produces, and make them visible to other packages.
.library(
name: "LiveKit",
targets: ["LiveKit"]
Expand Down
12 changes: 6 additions & 6 deletions Sources/LiveKit/Broadcast/BroadcastScreenCapturer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
* limitations under the License.
*/

import Foundation
#if os(iOS)

#if canImport(UIKit)
import UIKit
#endif
import Foundation

@_implementationOnly import WebRTC
#if canImport(UIKit)
import UIKit
#endif

#if os(iOS)
@_implementationOnly import WebRTC

class BroadcastScreenCapturer: BufferCapturer {
static let kRTCScreensharingSocketFD = "rtc_SSFD"
Expand Down
15 changes: 15 additions & 0 deletions Sources/LiveKit/Core/Engine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class Engine: MulticastDelegate<EngineDelegate> {
}

try await cleanUp()
try Task.checkCancellation()

_state.mutate { $0.connectionState = .connecting }

Expand All @@ -154,13 +155,19 @@ class Engine: MulticastDelegate<EngineDelegate> {
// Connect sequence successful
log("Connect sequence completed")

// Final check if cancelled, don't fire connected events
try Task.checkCancellation()

// update internal vars (only if connect succeeded)
_state.mutate {
$0.url = url
$0.token = token
$0.connectionState = .connected
}

} catch is CancellationError {
// Cancelled by .user
try await cleanUp(reason: .user)
} catch {
try await cleanUp(reason: .networkError(error))
}
Expand Down Expand Up @@ -344,10 +351,18 @@ extension Engine {
connectOptions: _state.connectOptions,
reconnectMode: _state.reconnectMode,
adaptiveStream: room._state.options.adaptiveStream)
// Check cancellation after WebSocket connected
try Task.checkCancellation()

let jr = try await signalClient.joinResponseCompleter.wait()
// Check cancellation after received join response
try Task.checkCancellation()

_state.mutate { $0.connectStopwatch.split(label: "signal") }
try await configureTransports(joinResponse: jr)
// Check cancellation after configuring transports
try Task.checkCancellation()

try await signalClient.resumeResponseQueue()
try await primaryTransportConnectedCompleter.wait()
_state.mutate { $0.connectStopwatch.split(label: "engine") }
Expand Down
14 changes: 5 additions & 9 deletions Sources/LiveKit/Core/SignalClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ class SignalClient: MulticastDelegate<SignalClientDelegate> {
$0.connectionState = .connecting
}

let socket = WebSocket(url: url)

do {
try await socket.connect()
let socket = try await WebSocket(url: url)
_webSocket = socket
_state.mutate { $0.connectionState = .connected }

Expand Down Expand Up @@ -156,10 +154,8 @@ class SignalClient: MulticastDelegate<SignalClientDelegate> {
pingIntervalTimer = nil
pingTimeoutTimer = nil

if let socket = _webSocket {
socket.reset()
_webSocket = nil
}
_webSocket?.close()
_webSocket = nil

latestJoinResponse = nil

Expand Down Expand Up @@ -311,7 +307,7 @@ private extension SignalClient {

extension SignalClient {
func resumeResponseQueue() async throws {
await _responseQueue.resume { response in
try await _responseQueue.resume { response in
await processSignalResponse(response)
}
}
Expand All @@ -321,7 +317,7 @@ extension SignalClient {

extension SignalClient {
func sendQueuedRequests() async throws {
await _requestQueue.resume { element in
try await _requestQueue.resume { element in
do {
try await sendRequest(element, enqueueIfReconnecting: false)
} catch {
Expand Down
2 changes: 1 addition & 1 deletion Sources/LiveKit/Core/Transport.swift
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class Transport: MulticastDelegate<TransportDelegate> {
func set(remoteDescription sd: LKRTCSessionDescription) async throws {
try await _pc.setRemoteDescription(sd)

await _pendingCandidatesQueue.resume { candidate in
try await _pendingCandidatesQueue.resume { candidate in
do {
try await add(iceCandidate: candidate)
} catch {
Expand Down
2 changes: 2 additions & 0 deletions Sources/LiveKit/Errors.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public enum TrackError: LiveKitError {
}

public enum SignalClientError: LiveKitError {
case cancelled
case state(message: String? = nil)
case socketError(rawError: Error?)
case close(message: String? = nil)
Expand All @@ -105,6 +106,7 @@ public enum SignalClientError: LiveKitError {

public var description: String {
switch self {
case .cancelled: return buildDescription("cancelled")
case let .state(message): return buildDescription("state", message)
case let .socketError(rawError): return buildDescription("socketError", rawError: rawError)
case let .close(message): return buildDescription("close", message)
Expand Down
42 changes: 25 additions & 17 deletions Sources/LiveKit/Support/AsyncCompleter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ class AsyncCompleter<T>: Loggable {

public func cancel() {
_cancelTimer()
if _continuation != nil {
log("\(label) cancelled")
}
_continuation?.resume(throwing: AsyncCompleterError.cancelled)
_continuation = nil
_returningValue = nil
Expand Down Expand Up @@ -140,24 +143,29 @@ class AsyncCompleter<T>: Loggable {
// Cancel any previous waits
cancel()

// Create a timed continuation
return try await withCheckedThrowingContinuation { continuation in
// Store reference to continuation
_continuation = continuation

// Create time-out block
let timeOutBlock = DispatchWorkItem { [weak self] in
guard let self else { return }
self.log("\(self.label) timedOut")
self._continuation?.resume(throwing: AsyncCompleterError.timedOut)
self._continuation = nil
self.cancel()
// Create a cancel-aware timed continuation
return try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
// Store reference to continuation
_continuation = continuation

// Create time-out block
let timeOutBlock = DispatchWorkItem { [weak self] in
guard let self else { return }
self.log("\(self.label) timedOut")
self._continuation?.resume(throwing: AsyncCompleterError.timedOut)
self._continuation = nil
self.cancel()
}

// Schedule time-out block
_queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock)
// Store reference to time-out block
_timeOutBlock = timeOutBlock
}

// Schedule time-out block
_queue.asyncAfter(deadline: .now() + _timeOut, execute: timeOutBlock)
// Store reference to time-out block
_timeOutBlock = timeOutBlock
} onCancel: {
// Cancel completer when Task gets cancelled
cancel()
}
}
}
6 changes: 4 additions & 2 deletions Sources/LiveKit/Support/AsyncQueueActor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ actor AsyncQueueActor<T> {
}

/// Mark as `.resumed` and process each element with an async `block`.
func resume(_ block: (T) async -> Void) async {
func resume(_ block: (T) async throws -> Void) async throws {
state = .resumed
if queue.isEmpty { return }
for element in queue {
await block(element)
// Check cancellation before processing next block...
try Task.checkCancellation()
try await block(element)
}
queue.removeAll()
}
Expand Down
23 changes: 13 additions & 10 deletions Sources/LiveKit/Support/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,27 @@ class WebSocket: NSObject, Loggable, AsyncSequence, URLSessionWebSocketDelegate
waitForNextValue()
}

init(url: URL) {
init(url: URL) async throws {
request = URLRequest(url: url,
cachePolicy: .useProtocolCachePolicy,
timeoutInterval: .defaultSocketConnect)
super.init()
try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { continuation in
connectContinuation = continuation
task.resume()
}
} onCancel: {
// Cancel(reset) when Task gets cancelled
close()
}
}

deinit {
reset()
}

public func connect() async throws {
try await withCheckedThrowingContinuation { continuation in
connectContinuation = continuation
task.resume()
}
close()
}

func reset() {
func close() {
task.cancel(with: .goingAway, reason: nil)
connectContinuation?.resume(throwing: SignalClientError.socketError(rawError: nil))
connectContinuation = nil
Expand Down
45 changes: 44 additions & 1 deletion Tests/LiveKitTests/CompleterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,48 @@ class CompleterTests: XCTestCase {

override func tearDown() async throws {}

func testCompleter() async throws {}
func testCompleterReuse() async throws {
let completer = AsyncCompleter<Void>(label: "Test01", timeOut: .seconds(1))
do {
try await completer.wait()
} catch AsyncCompleterError.timedOut {
print("Timed out 1")
}
// Re-use
do {
try await completer.wait()
} catch AsyncCompleterError.timedOut {
print("Timed out 2")
}
}

func testCompleterCancel() async throws {
let completer = AsyncCompleter<Void>(label: "cancel-test", timeOut: .never)
do {
// Run Tasks in parallel
try await withThrowingTaskGroup(of: Void.self) { group in

group.addTask {
print("Task 1: Waiting...")
try await completer.wait()
}

group.addTask {
print("Task 2: Started...")
// Cancel after 1 second
try await Task.sleep(until: .now + .seconds(1), clock: .continuous)
print("Task 2: Cancelling completer...")
completer.cancel()
}

try await group.waitForAll()
}
} catch let error as AsyncCompleterError where error == .timedOut {
print("Completer timed out")
} catch let error as AsyncCompleterError where error == .cancelled {
print("Completer cancelled")
} catch {
print("Unknown error: \(error)")
}
}
}
9 changes: 4 additions & 5 deletions Tests/LiveKitTests/TimerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
*/

@testable import LiveKit
import Promises
import XCTest

class TimerTests: XCTestCase {
Expand All @@ -35,10 +34,10 @@ class TimerTests: XCTestCase {
if self.counter == 3 {
print("suspending timer for 3s...")
self.timer.suspend()
Promise(()).delay(3).then {
print("restarting timer...")
self.timer.restart()
}
// Promise(()).delay(3).then {
// print("restarting timer...")
// self.timer.restart()
// }
}

if self.counter == 5 {
Expand Down
11 changes: 2 additions & 9 deletions Tests/LiveKitTests/WebSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,13 @@
import XCTest

class WebSocketTests: XCTestCase {
lazy var socket: WebSocket = {
let url = URL(string: "wss://socketsbay.com/wss/v2/1/demo/")!
return WebSocket(url: url)
}()

override func setUpWithError() throws {}

override func tearDown() async throws {}

func testCompleter1() async throws {
// Read messages

func testWebSocket01() async throws {
print("Connecting...")
try await socket.connect()
let socket = try await WebSocket(url: URL(string: "wss://socketsbay.com/wss/v2/1/demo/")!)

print("Connected. Waiting for messages...")
do {
Expand Down

0 comments on commit 29d61f1

Please sign in to comment.