Skip to content

Commit

Permalink
WebSocket: Migrate to async/await from Promises lib (livekit#245)
Browse files Browse the repository at this point in the history
* progress

* format

* progress

* clean up

* ref

* optimize

* optimize
  • Loading branch information
hiroshihorie committed Oct 26, 2023
1 parent ec36cc3 commit 980146b
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 146 deletions.
87 changes: 53 additions & 34 deletions Sources/LiveKit/Core/SignalClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,42 +108,59 @@ internal class SignalClient: MulticastDelegate<SignalClientDelegate> {
$0.connectionState = .connecting
}

return WebSocket.connect(url: url,
onMessage: self.onWebSocketMessage,
onDisconnect: { reason in
self.webSocket = nil
self.cleanUp(reason: reason)
})
.then(on: queue) { (webSocket: WebSocket) -> Void in
self.webSocket = webSocket
self._state.mutate { $0.connectionState = .connected }
}.recover(on: queue) { error -> Promise<Void> in
// Skip validation if reconnect mode
if reconnectMode != nil { throw error }
// Catch first, then throw again after getting validation response
// Re-build url with validate mode
guard let validateUrl = Utils.buildUrl(urlString,
token,
connectOptions: connectOptions,
adaptiveStream: adaptiveStream,
validate: true) else {

return Promise(InternalError.parse(message: "Failed to parse validation url"))
let socket = WebSocket(url: url)

return Promise<Void> { resolve, reject in
Task {
do {
try await socket.connect()
self.webSocket = socket
self._state.mutate { $0.connectionState = .connected }
resolve(())

Task.detached {
self.log("Did enter WebSocket message loop...")
do {
for try await message in socket {
self.onWebSocketMessage(message: message)
}
} catch {
//
self.cleanUp(reason: .networkError(error))
}
self.log("Did exit WebSocket message loop...")
}
} catch {
reject(error)
}
}
}.recover(on: queue) { error -> Promise<Void> in
// Skip validation if reconnect mode
if reconnectMode != nil { throw error }
// Catch first, then throw again after getting validation response
// Re-build url with validate mode
guard let validateUrl = Utils.buildUrl(urlString,
token,
connectOptions: connectOptions,
adaptiveStream: adaptiveStream,
validate: true) else {

return Promise(InternalError.parse(message: "Failed to parse validation url"))
}

self.log("Validating with url: \(validateUrl)")
self.log("Validating with url: \(validateUrl)")

return HTTP().get(on: self.queue, url: validateUrl).then(on: self.queue) { data in
guard let string = String(data: data, encoding: .utf8) else {
throw SignalClientError.connect(message: "Failed to decode string")
}
self.log("validate response: \(string)")
// re-throw with validation response
throw SignalClientError.connect(message: "Validation response: \"\(string)\"")
return HTTP().get(on: self.queue, url: validateUrl).then(on: self.queue) { data in
guard let string = String(data: data, encoding: .utf8) else {
throw SignalClientError.connect(message: "Failed to decode string")
}
}.catch(on: queue) { error in
self.cleanUp(reason: .networkError(error))
self.log("validate response: \(string)")
// re-throw with validation response
throw SignalClientError.connect(message: "Validation response: \"\(string)\"")
}
}.catch(on: queue) { error in
self.cleanUp(reason: .networkError(error))
}
}

func cleanUp(reason: DisconnectReason? = nil) {
Expand All @@ -156,9 +173,11 @@ internal class SignalClient: MulticastDelegate<SignalClientDelegate> {
pingTimeoutTimer = nil

if let socket = webSocket {
socket.cleanUp(reason: reason, notify: false)
socket.onMessage = nil
socket.onDisconnect = nil
// socket.cleanUp(reason: reason, notify: false)
// socket.onMessage = nil
// socket.onDisconnect = nil
// self.webSocket?.cancel()
socket.reset()
self.webSocket = nil
}

Expand Down
191 changes: 79 additions & 112 deletions Sources/LiveKit/Support/WebSocket.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 LiveKit
* Copyright 2022-2023 LiveKit
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,163 +17,130 @@
import Foundation
import Promises

internal class WebSocket: NSObject, URLSessionWebSocketDelegate, Loggable {
internal typealias WebSocketStream = AsyncThrowingStream<URLSessionWebSocketTask.Message, Error>

private let queue = DispatchQueue(label: "LiveKitSDK.webSocket", qos: .default)
internal class WebSocket: NSObject, Loggable, AsyncSequence, URLSessionWebSocketDelegate {

typealias OnMessage = (URLSessionWebSocketTask.Message) -> Void
typealias OnDisconnect = (_ reason: DisconnectReason?) -> Void
typealias AsyncIterator = WebSocketStream.Iterator
typealias Element = URLSessionWebSocketTask.Message

public var onMessage: OnMessage?
public var onDisconnect: OnDisconnect?
private var streamContinuation: WebSocketStream.Continuation?
private var connectContinuation: CheckedContinuation<Void, Error>?

private let operationQueue = OperationQueue()
private let request: URLRequest

private var disconnected = false
private var connectPromise: Promise<WebSocket>?

private lazy var session: URLSession = {
private lazy var urlSession: URLSession = {
let config = URLSessionConfiguration.default
// explicitly set timeout intervals
config.timeoutIntervalForRequest = TimeInterval(60)
config.timeoutIntervalForResource = TimeInterval(604_800)
log("URLSessionConfiguration.timeoutIntervalForRequest: \(config.timeoutIntervalForRequest)")
log("URLSessionConfiguration.timeoutIntervalForResource: \(config.timeoutIntervalForResource)")
return URLSession(configuration: config,
delegate: self,
delegateQueue: operationQueue)
return URLSession(configuration: config, delegate: self, delegateQueue: nil)
}()

private lazy var task: URLSessionWebSocketTask = {
session.webSocketTask(with: request)
urlSession.webSocketTask(with: request)
}()

static func connect(url: URL,
onMessage: OnMessage? = nil,
onDisconnect: OnDisconnect? = nil) -> Promise<WebSocket> {

return WebSocket(url: url,
onMessage: onMessage,
onDisconnect: onDisconnect).connect()
}
private lazy var stream: WebSocketStream = {
return WebSocketStream { continuation in
streamContinuation = continuation
waitForNextValue()
}
}()

private init(url: URL,
onMessage: OnMessage? = nil,
onDisconnect: OnDisconnect? = nil) {
init(url: URL) {

request = URLRequest(url: url,
cachePolicy: .useProtocolCachePolicy,
timeoutInterval: .defaultSocketConnect)

self.onMessage = onMessage
self.onDisconnect = onDisconnect
super.init()
task.resume()
}

deinit {
log()
reset()
}

private func connect() -> Promise<WebSocket> {
connectPromise = Promise<WebSocket>.pending()
return connectPromise!
}

internal func cleanUp(reason: DisconnectReason?, notify: Bool = true) {

log("reason: \(String(describing: reason))")
public func connect() async throws {

guard !disconnected else {
log("dispose can be called only once", .warning)
return
}

// mark as disconnected, this instance cannot be re-used
disconnected = true

task.cancel()
session.invalidateAndCancel()

if let promise = connectPromise {
let sdkError = NetworkError.disconnected(message: "WebSocket disconnected")
promise.reject(sdkError)
connectPromise = nil
}

if notify {
onDisconnect?(reason)
try await withCheckedThrowingContinuation { continuation in
connectContinuation = continuation
task.resume()
}
}

public func send(data: Data) -> Promise<Void> {
let message = URLSessionWebSocketTask.Message.data(data)
return Promise(on: queue) { resolve, fail in
self.task.send(message) { error in
if let error = error {
fail(error)
return
}
resolve(())
}
}
func reset() {
task.cancel(with: .goingAway, reason: nil)
connectContinuation?.resume(throwing: SignalClientError.socketError(rawError: nil))
connectContinuation = nil
streamContinuation?.finish()
streamContinuation = nil
}

private func receive(task: URLSessionWebSocketTask,
result: Result<URLSessionWebSocketTask.Message, Error>) {
switch result {
case .failure(let error):
log("Failed to receive \(error)", .error)
// MARK: - AsyncSequence

case .success(let message):
onMessage?(message)
queue.async { task.receive { self.receive(task: task, result: $0) } }
}
func makeAsyncIterator() -> AsyncIterator {
return stream.makeAsyncIterator()
}

// MARK: - URLSessionWebSocketDelegate

internal func urlSession(_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
didOpenWithProtocol protocol: String?) {

guard !disconnected else {
private func waitForNextValue() {
guard task.closeCode == .invalid else {
streamContinuation?.finish()
streamContinuation = nil
return
}

if let promise = connectPromise {
promise.fulfill(self)
connectPromise = nil
}
task.receive(completionHandler: { [weak self] result in
guard let continuation = self?.streamContinuation else {
return
}

queue.async { webSocketTask.receive { self.receive(task: webSocketTask, result: $0) } }
do {
let message = try result.get()
continuation.yield(message)
self?.waitForNextValue()
} catch {
continuation.finish(throwing: error)
self?.streamContinuation = nil
}
})
}

internal func urlSession(_ session: URLSession,
webSocketTask: URLSessionWebSocketTask,
didCloseWith closeCode: URLSessionWebSocketTask.CloseCode,
reason: Data?) {
// MARK: - Send

guard !disconnected else {
return
}
public func send(data: Data) async throws {
let message = URLSessionWebSocketTask.Message.data(data)
try await task.send(message)
}

let sdkError = NetworkError.disconnected(message: "WebSocket did close with code: \(closeCode) reason: \(String(describing: reason))")
// MARK: - URLSessionWebSocketDelegate

cleanUp(reason: .networkError(sdkError))
func urlSession(_ session: URLSession, webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol protocol: String?) {
connectContinuation?.resume()
connectContinuation = nil
}

internal func urlSession(_ session: URLSession,
task: URLSessionTask,
didCompleteWithError error: Error?) {

guard !disconnected else {
return
}
func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
log("didCompleteWithError: \(String(describing: error))", .error)
let error = error ?? NetworkError.disconnected(message: "WebSocket didCompleteWithError")
connectContinuation?.resume(throwing: error)
connectContinuation = nil
streamContinuation?.finish()
streamContinuation = nil
}
}

let sdkError = NetworkError.disconnected(message: "WebSocket disconnected", rawError: error)
internal extension WebSocket {

cleanUp(reason: .networkError(sdkError))
// Deprecate
func send(data: Data) -> Promise<Void> {
Promise { [self] resolve, fail in
Task {
do {
try await self.send(data: data)
resolve(())
} catch {
fail(error)
}
}
}
}
}
Loading

0 comments on commit 980146b

Please sign in to comment.