From 37d9d1759405bf237334ee8731ea2af6e134c8ff Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Sat, 9 Apr 2022 19:17:06 +0200 Subject: [PATCH 01/11] Start converting WebSocket to Swift Concurrency --- Package.resolved | 13 +- Package.swift | 11 +- ...cketTaskCloseCode+WebSocketCloseCode.swift | 28 +- ...essionWebSocketTaskMessage+WebSocket.swift | 25 + Sources/WebSocket/WebSocket.swift | 917 ++++++++++++------ Sources/WebSocket/WebSocketClient.swift | 69 ++ Sources/WebSocket/WebSocketCloseCode.swift | 46 + Sources/WebSocket/WebSocketError.swift | 11 +- Sources/WebSocket/WebSocketEvent.swift | 31 + ...ocketMessage+URLSessionWebSocketTask.swift | 27 - Sources/WebSocket/WebSocketMessage.swift | 19 + Sources/WebSocket/WebSocketOptions.swift | 17 + ...RLSessionWebSocketTaskCloseCodeTests.swift | 1 - Tests/WebSocketTests/WebSocketTests.swift | 453 +++++---- 14 files changed, 1066 insertions(+), 602 deletions(-) create mode 100644 Sources/WebSocket/WebSocketClient.swift create mode 100644 Sources/WebSocket/WebSocketCloseCode.swift create mode 100644 Sources/WebSocket/WebSocketEvent.swift delete mode 100644 Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift create mode 100644 Sources/WebSocket/WebSocketMessage.swift create mode 100644 Sources/WebSocket/WebSocketOptions.swift diff --git a/Package.resolved b/Package.resolved index bd5d45f..4ae6b60 100644 --- a/Package.resolved +++ b/Package.resolved @@ -6,8 +6,8 @@ "repositoryURL": "https://github.com/apple/swift-nio.git", "state": { "branch": null, - "revision": "43931b7a7daf8120a487601530c8bc03ce711992", - "version": "2.25.1" + "revision": "d6e3762e0a5f7ede652559f53623baf11006e17c", + "version": "2.39.0" } }, { @@ -18,15 +18,6 @@ "revision": "f01e4a1ee5fbf586d612a8dc0bc068603f6b9450", "version": "3.0.0" } - }, - { - "package": "WebSocketProtocol", - "repositoryURL": "https://github.com/shareup/websocket-protocol.git", - "state": { - "branch": null, - "revision": "bd6257e3c4b23484dfc73c550b025f96c8e151f6", - "version": "2.3.1" - } } ] }, diff --git a/Package.swift b/Package.swift index 4c80549..19d2558 100644 --- a/Package.swift +++ b/Package.swift @@ -4,7 +4,7 @@ import PackageDescription let package = Package( name: "WebSocket", platforms: [ - .macOS(.v10_15), .iOS(.v13), .tvOS(.v13), .watchOS(.v6), + .macOS(.v11), .iOS(.v14), .tvOS(.v14), .watchOS(.v7), ], products: [ .library( @@ -17,16 +17,11 @@ let package = Package( url: "https://github.com/shareup/synchronized.git", from: "3.0.0" ), - .package( - name: "WebSocketProtocol", - url: "https://github.com/shareup/websocket-protocol.git", - from: "2.3.2" - ), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.0.0")], + .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.39.0")], targets: [ .target( name: "WebSocket", - dependencies: ["Synchronized", "WebSocketProtocol"]), + dependencies: ["Synchronized"]), .testTarget( name: "WebSocketTests", dependencies: [ diff --git a/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift b/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift index 67f1f29..875c542 100644 --- a/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift +++ b/Sources/WebSocket/URLSessionWebSocketTaskCloseCode+WebSocketCloseCode.swift @@ -1,8 +1,32 @@ import Foundation -import WebSocketProtocol -public extension URLSessionWebSocketTask.CloseCode { +extension URLSessionWebSocketTask.CloseCode { init?(_ closeCode: WebSocketCloseCode) { self.init(rawValue: closeCode.rawValue) } } + +extension WebSocketCloseCode { + init?(_ closeCode: URLSessionWebSocketTask.CloseCode?) { + guard let closeCode = closeCode else { return nil } + self.init(rawValue: closeCode.rawValue) + } + + var urlSessionCloseCode: URLSessionWebSocketTask.CloseCode { + switch self { + case .invalid: return .invalid + case .normalClosure: return .normalClosure + case .goingAway: return .goingAway + case .protocolError: return .protocolError + case .unsupportedData: return .unsupportedData + case .noStatusReceived: return .noStatusReceived + case .abnormalClosure: return .abnormalClosure + case .invalidFramePayloadData: return .invalidFramePayloadData + case .policyViolation: return .policyViolation + case .messageTooBig: return .messageTooBig + case .mandatoryExtensionMissing: return .mandatoryExtensionMissing + case .internalServerError: return .internalServerError + case .tlsHandshakeFailure: return .tlsHandshakeFailure + } + } +} diff --git a/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift b/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift index 73946a3..de1fa69 100644 --- a/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift +++ b/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift @@ -13,3 +13,28 @@ extension URLSessionWebSocketTask.Message: CustomDebugStringConvertible { } } } + +extension WebSocketMessage { + init(_ message: URLSessionWebSocketTask.Message) { + switch message { + case let .data(data): + self = .data(data) + case let .string(string): + self = .text(string) + @unknown default: + assertionFailure("Unknown WebSocket Message type") + self = .text("") + } + } +} + +extension Result: CustomDebugStringConvertible where Success == WebSocketMessage { + public var debugDescription: String { + switch self { + case let .success(message): + return message.debugDescription + case let .failure(error): + return error.localizedDescription + } + } +} diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index e3b87c3..593db8b 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -2,381 +2,666 @@ import Combine import Foundation import os.log import Synchronized -import WebSocketProtocol - -public final class WebSocket: WebSocketProtocol { - public typealias Output = Result - public typealias Failure = Swift.Error - - private enum State: CustomDebugStringConvertible { - case unopened - case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) - case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) - case closing - case closed(WebSocketError) - - var webSocketSessionAndTask: (URLSession, URLSessionWebSocketTask)? { - switch self { - case let .connecting(session, task, _), let .open(session, task, _): - return (session, task) - case .unopened, .closing, .closed: - return nil - } - } - var debugDescription: String { - switch self { - case .unopened: return "unopened" - case .connecting: return "connecting" - case .open: return "open" - case .closing: return "closing" - case .closed: return "closed" - } - } - } +final actor WebSocket { + let url: URL + let options: WebSocketOptions - /// The maximum number of bytes to buffer before the receive call fails with an error. - /// Default: 1 MiB - public var maximumMessageSize: Int = 1024 * 1024 { - didSet { sync { - guard let (_, task) = state.webSocketSessionAndTask else { return } - task.maximumMessageSize = maximumMessageSize - } } + var isOpen: Bool { + get async { + guard case .open = state else { return false } + return true + } } - public var isOpen: Bool { sync { - guard case .open = state else { return false } - return true - } } - - public var isClosed: Bool { sync { - guard case .closed = state else { return false } - return true - } } - - private let lock = RecursiveLock() - private func sync(_ block: () throws -> T) rethrows -> T { try lock.locked(block) } - - private let url: URL + var isClosed: Bool { get async { await !isOpen } } - private let timeoutIntervalForRequest: TimeInterval - private let timeoutIntervalForResource: TimeInterval + private var onStateChange: (WebSocketEvent) -> Void + private var state: WebSocketState = .unopened - private var state: State = .unopened - private let subject = PassthroughSubject() - - private let subjectQueue: DispatchQueue - - public convenience init(url: URL) { - self.init(url: url, publisherQueue: DispatchQueue.global()) - } - - public init( + init( url: URL, - timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds - timeoutIntervalForResource: TimeInterval = 604_800, // 7 days - publisherQueue: DispatchQueue = DispatchQueue.global() - ) { + options: WebSocketOptions = .init(), + onStateChange: @escaping (WebSocketEvent) -> Void + ) async { self.url = url - self.timeoutIntervalForRequest = timeoutIntervalForRequest - self.timeoutIntervalForResource = timeoutIntervalForResource - subjectQueue = DispatchQueue( - label: "app.shareup.websocket.subjectqueue", - qos: .default, - autoreleaseFrequency: .workItem, - target: publisherQueue - ) + self.options = options + self.onStateChange = onStateChange + connect() } - deinit { - close() + func setOnStateChange(_ block: @escaping (WebSocketEvent) -> Void) async { + onStateChange = block } - public func connect() { - sync { - os_log( - "connect: oldstate=%{public}@", - log: .webSocket, - type: .debug, - state.debugDescription - ) + func close(_ code: WebSocketCloseCode) async { + switch state { + case let .connecting(session, task, _), let .open(session, task, _): + state = .closed(code.urlSessionCloseCode, nil) + onStateChange(.close(code, nil)) + task.cancel(with: code.urlSessionCloseCode, reason: nil) + session.finishTasksAndInvalidate() - switch state { - case .closed, .unopened: - let delegate = WebSocketDelegate( - onOpen: onOpen, - onClose: onClose, - onCompletion: onCompletion - ) - - let config = URLSessionConfiguration.default - config.timeoutIntervalForRequest = timeoutIntervalForRequest - config.timeoutIntervalForResource = timeoutIntervalForResource - - let session = URLSession( - configuration: config, - delegate: delegate, - delegateQueue: nil - ) - - let task = session.webSocketTask(with: url) - task.maximumMessageSize = maximumMessageSize - state = .connecting(session, task, delegate) - task.resume() - receiveFromWebSocket() - - default: - break - } + case .unopened, .closing, .closed: + break } } - public func receive(subscriber: S) - where S.Input == Result, S.Failure == Swift.Error - { - subject.receive(subscriber: subscriber) - } - - private func receiveFromWebSocket() { - let task: URLSessionWebSocketTask? = sync { - let webSocketTask = self.state.webSocketSessionAndTask?.1 - guard let task = webSocketTask, case .running = task.state else { return nil } - return task - } + func send(_ message: URLSessionWebSocketTask.Message) async throws { + // Mirrors the document behavior of JavaScript's `WebSocket` + // http://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send + switch state { + case let .open(session, task, _): + try await task.send(message) - task?.receive - { [weak self, weak task] (result: Result) in - guard let self = self else { return } + case .unopened, .connecting: + throw WebSocketError.sendMessageWhileConnecting - let _result = result.map { WebSocketMessage($0) } + case .closing, .closed: + break + } + } - guard task?.state == .running - else { - os_log( - "receive message in incorrect task state: message=%s taskstate=%{public}@", - log: .webSocket, - type: .debug, - _result.debugDescription, - "\(task?.state.rawValue ?? -1)" - ) - return - } + func receive() async throws -> URLSessionWebSocketTask.Message { + switch state { + case let .open(_, task, _): + return try await task.receive() - os_log("receive: %s", log: .webSocket, type: .debug, _result.debugDescription) - self.subjectQueue.async { [weak self] in self?.subject.send(_result) } - self.receiveFromWebSocket() - } + case .unopened, .connecting, .closing, .closed: + throw WebSocketError.receiveMessageWhenNotOpen + } } +} - public func send( - _ string: String, - completionHandler: @escaping (Error?) -> Void = { _ in } - ) { - os_log("send: %s", log: .webSocket, type: .debug, string) - send(.string(string), completionHandler: completionHandler) +private extension WebSocket { + func setState(_ state: WebSocketState) async { + self.state = state } - public func send(_ data: Data, completionHandler: @escaping (Error?) -> Void = { _ in }) { - os_log("send: %lld bytes", log: .webSocket, type: .debug, data.count) - send(.data(data), completionHandler: completionHandler) - } + func connect() { + os_log( + "connect: oldstate=%{public}@", + log: .webSocket, + type: .debug, + state.debugDescription + ) - private func send( - _ message: URLSessionWebSocketTask.Message, - completionHandler: @escaping (Error?) -> Void - ) { - let task: URLSessionWebSocketTask? = sync { - guard case let .open(_, task, _) = state, task.state == .running - else { - os_log( - "send message in incorrect task state: message=%s taskstate=%{public}@", - log: .webSocket, - type: .debug, - message.debugDescription, - "\(self.state.webSocketSessionAndTask?.1.state.rawValue ?? -1)" - ) - completionHandler(WebSocketError.notOpen) - return nil - } - return task - } + switch state { + case .closed, .unopened: + let delegate = WebSocketDelegate(onStateChange: onDelegateEvent) - task?.send(message, completionHandler: completionHandler) - } + let config = URLSessionConfiguration.default + config.timeoutIntervalForRequest = options.timeoutIntervalForRequest + config.timeoutIntervalForResource = options.timeoutIntervalForResource - public func close(_ closeCode: WebSocketCloseCode) { - let task: URLSessionWebSocketTask? = sync { - os_log( - "close: oldstate=%{public}@ code=%lld", - log: .webSocket, - type: .debug, - state.debugDescription, - closeCode.rawValue + let session = URLSession( + configuration: config, + delegate: delegate, + delegateQueue: nil ) - guard let (_, task) = state.webSocketSessionAndTask, task.state == .running - else { return nil } - state = .closing - return task - } + let task = session.webSocketTask(with: url) + task.maximumMessageSize = options.maximumMessageSize + state = .connecting(session, task, delegate) - let code = URLSessionWebSocketTask.CloseCode(closeCode) ?? .invalid - task?.cancel(with: code, reason: nil) - } -} + task.resume() -private typealias OnOpenHandler = (URLSession, URLSessionWebSocketTask, String?) -> Void -private typealias OnCloseHandler = ( - URLSession, - URLSessionWebSocketTask, - URLSessionWebSocketTask.CloseCode, - Data? -) -> Void -private typealias OnCompletionHandler = (URLSession, URLSessionTask, Error?) -> Void - -private let normalCloseCodes: [URLSessionWebSocketTask.CloseCode] = [.goingAway, .normalClosure] - -// MARK: onOpen and onClose + default: + break + } + } -private extension WebSocket { - var onOpen: OnOpenHandler { - { [weak self] webSocketSession, webSocketTask, _ in + var onDelegateEvent: (WebSocketDelegateEvent) async -> Void { + { [weak self] (event: WebSocketDelegateEvent) in guard let self = self else { return } - self.sync { - os_log( - "onOpen: oldstate=%{public}@", - log: .webSocket, - type: .debug, - self.state.debugDescription - ) - - guard case let .connecting(session, task, delegate) = self.state else { - os_log( - "receive onOpen callback in incorrect state: oldstate=%{public}@", - log: .webSocket, - type: .error, - self.state.debugDescription - ) - self.state = .open( - webSocketSession, - webSocketTask, - webSocketSession.delegate as! WebSocketDelegate + switch (await self.state, event) { + case let (.connecting(s1, t1, delegate), .open(s2, t2, _)): + guard s1 === s2 && t1 === t2 else { return } + await self.setState(.open(s2, t2, delegate)) + await self.onStateChange(.open) + + case let (.connecting(s1, t1, _), .close(s2, t2, closeCode, reason)), + let (.open(s1, t1, _), .close(s2, t2, closeCode, reason)): + guard s1 === s2, t1 === t2 else { return } + if let closeCode = closeCode { + await self.setState(.closed(closeCode, reason)) + } else { + await self.setState(.closed(.abnormalClosure, nil)) + } + await self.onStateChange(.close(.init(closeCode), reason)) + s2.invalidateAndCancel() + + case let (.connecting(s1, t1, _), .complete(s2, t2, error)), + let (.open(s1, t1, _), .complete(s2, t2, error)): + guard s1 === s2, t1 === t2 else { return } + if let error = error { + await self.setState( + .closed( + .internalServerError, + Data(error.localizedDescription.utf8)) ) - return + await self.onStateChange(.error(error as NSError)) + } else { + await self.setState(.closed(.internalServerError, nil)) + await self.onStateChange(.close(nil, nil)) } + s2.invalidateAndCancel() - assert(session === webSocketSession) - assert(task === webSocketTask) + case let (.closing, .close(session, _, closeCode, reason)): + if let closeCode = closeCode { + await self.setState(.closed(closeCode, reason)) + } else { + await self.setState(.closed(.abnormalClosure, nil)) + } + await self.onStateChange(.close(.init(closeCode), reason)) + session.invalidateAndCancel() - self.state = .open(webSocketSession, webSocketTask, delegate) - } + case (.unopened, _): + return - self.subjectQueue.async { [weak self] in self?.subject.send(.success(.open)) } - } - } + case (.closed, _): + return - var onClose: OnCloseHandler { - { [weak self] _, _, closeCode, reason in - guard let self = self else { return } + case (.closing, .open), (.closing, .complete): + return - self.sync { - os_log( - "onClose: oldstate=%{public}@ code=%lld", - log: .webSocket, - type: .debug, - self.state.debugDescription, - closeCode.rawValue - ) - - if case .closed = self.state { return } - self.state = .closed(WebSocketError.closed(closeCode, reason)) - - self.subjectQueue.async { [weak self] in - if normalCloseCodes.contains(closeCode) { - self?.subject.send(completion: .finished) - } else { - self?.subject.send( - completion: .failure(WebSocketError.closed(closeCode, reason)) - ) - } - } + case (.open, .open): + return } } } +} - var onCompletion: OnCompletionHandler { - { [weak self] webSocketSession, _, error in - defer { webSocketSession.invalidateAndCancel() } - guard let self = self else { return } +private enum WebSocketState: CustomDebugStringConvertible { + case unopened + case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) + case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) + case closing + case closed(URLSessionWebSocketTask.CloseCode, Data?) + + var webSocketSessionAndTask: (URLSession, URLSessionWebSocketTask)? { + switch self { + case let .connecting(session, task, _), let .open(session, task, _): + return (session, task) + case .unopened, .closing, .closed: + return nil + } + } - os_log("onCompletion", log: .webSocket, type: .debug) - - // "The only errors your delegate receives through the error parameter - // are client-side errors, such as being unable to resolve the hostname - // or connect to the host." - // - // https://developer.apple.com/documentation/foundation/urlsessiontaskdelegate/1411610-urlsession - // - // When receiving these errors, `onClose` is not called because the connection - // was never actually opened. - guard let error = error else { return } - self.sync { - os_log( - "onCompletion: oldstate=%{public}@ error=%@", - log: .webSocket, - type: .debug, - self.state.debugDescription, - error.localizedDescription - ) - - if case .closed = self.state { return } - self.state = .closed(.notOpen) - - self.subjectQueue.async { [weak self] in - self?.subject.send(completion: .failure(error)) - } - } + var debugDescription: String { + switch self { + case .unopened: return "unopened" + case .connecting: return "connecting" + case .open: return "open" + case .closing: return "closing" + case .closed: return "closed" } } } // MARK: URLSessionWebSocketDelegate +private enum WebSocketDelegateEvent { + case open(URLSession, URLSessionWebSocketTask, String?) + case close(URLSession, URLSessionWebSocketTask, URLSessionWebSocketTask.CloseCode?, Data?) + case complete(URLSession, URLSessionTask, Error?) +} + private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { - private let onOpen: OnOpenHandler - private let onClose: OnCloseHandler - private let onCompletion: OnCompletionHandler - - init(onOpen: @escaping OnOpenHandler, - onClose: @escaping OnCloseHandler, - onCompletion: @escaping OnCompletionHandler) - { - self.onOpen = onOpen - self.onClose = onClose - self.onCompletion = onCompletion + private var onStateChange: (WebSocketDelegateEvent) async -> Void + + init(onStateChange: @escaping (WebSocketDelegateEvent) async -> Void) { + self.onStateChange = onStateChange super.init() } - func urlSession(_ webSocketSession: URLSession, - webSocketTask: URLSessionWebSocketTask, - didOpenWithProtocol protocol: String?) - { - onOpen(webSocketSession, webSocketTask, `protocol`) + func urlSession( + _ webSocketSession: URLSession, + webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol protocol: String? + ) { + Swift.print("$$$ \(#function)") + Task { await onStateChange(.open(webSocketSession, webSocketTask, `protocol`)) } + } + + func urlSession( + _ session: URLSession, + webSocketTask: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? + ) { + Swift.print("$$$ \(#function)") + Task { await onStateChange(.close(session, webSocketTask, closeCode, reason)) } } - func urlSession(_ session: URLSession, - webSocketTask: URLSessionWebSocketTask, - didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, - reason: Data?) - { - onClose(session, webSocketTask, closeCode, reason) + func urlSession( + _ session: URLSession, + task: URLSessionTask, + didCompleteWithError error: Error? + ) { + Swift.print("$$$ \(#function)") + Task { await onStateChange(.complete(session, task, error)) } } - func urlSession(_ session: URLSession, - task: URLSessionTask, - didCompleteWithError error: Error?) - { - onCompletion(session, task, error) + func urlSession(_ session: URLSession, didBecomeInvalidWithError error: Error?) { + Swift.print("$$$ \(#function): \(String(describing: error))") } } + +//private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { +// private let onOpen: OnOpenHandler +// private let onClose: OnCloseHandler +// private let onCompletion: OnCompletionHandler +// +// init( +// onOpen: @escaping OnOpenHandler, +// onClose: @escaping OnCloseHandler, +// onCompletion: @escaping OnCompletionHandler +// ) { +// self.onOpen = onOpen +// self.onClose = onClose +// self.onCompletion = onCompletion +// super.init() +// } +// +// func urlSession( +// _ webSocketSession: URLSession, +// webSocketTask: URLSessionWebSocketTask, +// didOpenWithProtocol protocol: String? +// ) { +// onOpen(webSocketSession, webSocketTask, `protocol`) +// } +// +// func urlSession(_ session: URLSession, +// webSocketTask: URLSessionWebSocketTask, +// didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, +// reason: Data?) +// { +// onClose(session, webSocketTask, closeCode, reason) +// } +// +// func urlSession(_ session: URLSession, +// task: URLSessionTask, +// didCompleteWithError error: Error?) +// { +// onCompletion(session, task, error) +// } +//} + +//public final class WebSocket: WebSocketProtocol { +// public typealias Output = Result +// public typealias Failure = Swift.Error +// +// private enum State: CustomDebugStringConvertible { +// case unopened +// case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) +// case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) +// case closing +// case closed(WebSocketError) +// +// var webSocketSessionAndTask: (URLSession, URLSessionWebSocketTask)? { +// switch self { +// case let .connecting(session, task, _), let .open(session, task, _): +// return (session, task) +// case .unopened, .closing, .closed: +// return nil +// } +// } +// +// var debugDescription: String { +// switch self { +// case .unopened: return "unopened" +// case .connecting: return "connecting" +// case .open: return "open" +// case .closing: return "closing" +// case .closed: return "closed" +// } +// } +// } +// +// /// The maximum number of bytes to buffer before the receive call fails with an error. +// /// Default: 1 MiB +// public var maximumMessageSize: Int = 1024 * 1024 { +// didSet { sync { +// guard let (_, task) = state.webSocketSessionAndTask else { return } +// task.maximumMessageSize = maximumMessageSize +// } } +// } +// +// public var isOpen: Bool { sync { +// guard case .open = state else { return false } +// return true +// } } +// +// public var isClosed: Bool { sync { +// guard case .closed = state else { return false } +// return true +// } } +// +// private let lock = RecursiveLock() +// private func sync(_ block: () throws -> T) rethrows -> T { try lock.locked(block) } +// +// private let url: URL +// +// private let timeoutIntervalForRequest: TimeInterval +// private let timeoutIntervalForResource: TimeInterval +// +// private var state: State = .unopened +// private let subject = PassthroughSubject() +// +// private let subjectQueue: DispatchQueue +// +// public convenience init(url: URL) { +// self.init(url: url, publisherQueue: DispatchQueue.global()) +// } +// +// public init( +// url: URL, +// timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds +// timeoutIntervalForResource: TimeInterval = 604_800, // 7 days +// publisherQueue: DispatchQueue = DispatchQueue.global() +// ) { +// self.url = url +// self.timeoutIntervalForRequest = timeoutIntervalForRequest +// self.timeoutIntervalForResource = timeoutIntervalForResource +// subjectQueue = DispatchQueue( +// label: "app.shareup.websocket.subjectqueue", +// qos: .default, +// autoreleaseFrequency: .workItem, +// target: publisherQueue +// ) +// } +// +// deinit { +// close() +// } +// +// public func connect() { +// sync { +// os_log( +// "connect: oldstate=%{public}@", +// log: .webSocket, +// type: .debug, +// state.debugDescription +// ) +// +// switch state { +// case .closed, .unopened: +// let delegate = WebSocketDelegate( +// onOpen: onOpen, +// onClose: onClose, +// onCompletion: onCompletion +// ) +// +// let config = URLSessionConfiguration.default +// config.timeoutIntervalForRequest = timeoutIntervalForRequest +// config.timeoutIntervalForResource = timeoutIntervalForResource +// +// let session = URLSession( +// configuration: config, +// delegate: delegate, +// delegateQueue: nil +// ) +// +// let task = session.webSocketTask(with: url) +// task.maximumMessageSize = maximumMessageSize +// state = .connecting(session, task, delegate) +// task.resume() +// receiveFromWebSocket() +// +// default: +// break +// } +// } +// } +// +// public func receive(subscriber: S) +// where S.Input == Result, S.Failure == Swift.Error +// { +// subject.receive(subscriber: subscriber) +// } +// +// private func receiveFromWebSocket() { +// let task: URLSessionWebSocketTask? = sync { +// let webSocketTask = self.state.webSocketSessionAndTask?.1 +// guard let task = webSocketTask, case .running = task.state else { return nil } +// return task +// } +// +// task?.receive +// { [weak self, weak task] (result: Result) in +// guard let self = self else { return } +// +// let _result = result.map { WebSocketMessage($0) } +// +// guard task?.state == .running +// else { +// os_log( +// "receive message in incorrect task state: message=%s taskstate=%{public}@", +// log: .webSocket, +// type: .debug, +// _result.debugDescription, +// "\(task?.state.rawValue ?? -1)" +// ) +// return +// } +// +// os_log("receive: %s", log: .webSocket, type: .debug, _result.debugDescription) +// self.subjectQueue.async { [weak self] in self?.subject.send(_result) } +// self.receiveFromWebSocket() +// } +// } +// +// public func send( +// _ string: String, +// completionHandler: @escaping (Error?) -> Void = { _ in } +// ) { +// os_log("send: %s", log: .webSocket, type: .debug, string) +// send(.string(string), completionHandler: completionHandler) +// } +// +// public func send(_ data: Data, completionHandler: @escaping (Error?) -> Void = { _ in }) { +// os_log("send: %lld bytes", log: .webSocket, type: .debug, data.count) +// send(.data(data), completionHandler: completionHandler) +// } +// +// private func send( +// _ message: URLSessionWebSocketTask.Message, +// completionHandler: @escaping (Error?) -> Void +// ) { +// let task: URLSessionWebSocketTask? = sync { +// guard case let .open(_, task, _) = state, task.state == .running +// else { +// os_log( +// "send message in incorrect task state: message=%s taskstate=%{public}@", +// log: .webSocket, +// type: .debug, +// message.debugDescription, +// "\(self.state.webSocketSessionAndTask?.1.state.rawValue ?? -1)" +// ) +// completionHandler(WebSocketError.notOpen) +// return nil +// } +// return task +// } +// +// task?.send(message, completionHandler: completionHandler) +// } +// +// public func close(_ closeCode: WebSocketCloseCode) { +// let task: URLSessionWebSocketTask? = sync { +// os_log( +// "close: oldstate=%{public}@ code=%lld", +// log: .webSocket, +// type: .debug, +// state.debugDescription, +// closeCode.rawValue +// ) +// +// guard let (_, task) = state.webSocketSessionAndTask, task.state == .running +// else { return nil } +// state = .closing +// return task +// } +// +// let code = URLSessionWebSocketTask.CloseCode(closeCode) ?? .invalid +// task?.cancel(with: code, reason: nil) +// } +//} +// +//private typealias OnOpenHandler = (URLSession, URLSessionWebSocketTask, String?) -> Void +//private typealias OnCloseHandler = ( +// URLSession, +// URLSessionWebSocketTask, +// URLSessionWebSocketTask.CloseCode, +// Data? +//) -> Void +//private typealias OnCompletionHandler = (URLSession, URLSessionTask, Error?) -> Void +// +//private let normalCloseCodes: [URLSessionWebSocketTask.CloseCode] = [.goingAway, .normalClosure] +// +//// MARK: onOpen and onClose +// +//private extension WebSocket { +// var onOpen: OnOpenHandler { +// { [weak self] webSocketSession, webSocketTask, _ in +// guard let self = self else { return } +// +// self.sync { +// os_log( +// "onOpen: oldstate=%{public}@", +// log: .webSocket, +// type: .debug, +// self.state.debugDescription +// ) +// +// guard case let .connecting(session, task, delegate) = self.state else { +// os_log( +// "receive onOpen callback in incorrect state: oldstate=%{public}@", +// log: .webSocket, +// type: .error, +// self.state.debugDescription +// ) +// self.state = .open( +// webSocketSession, +// webSocketTask, +// webSocketSession.delegate as! WebSocketDelegate +// ) +// return +// } +// +// assert(session === webSocketSession) +// assert(task === webSocketTask) +// +// self.state = .open(webSocketSession, webSocketTask, delegate) +// } +// +// self.subjectQueue.async { [weak self] in self?.subject.send(.success(.open)) } +// } +// } +// +// var onClose: OnCloseHandler { +// { [weak self] _, _, closeCode, reason in +// guard let self = self else { return } +// +// self.sync { +// os_log( +// "onClose: oldstate=%{public}@ code=%lld", +// log: .webSocket, +// type: .debug, +// self.state.debugDescription, +// closeCode.rawValue +// ) +// +// if case .closed = self.state { return } +// self.state = .closed(WebSocketError.closed(closeCode, reason)) +// +// self.subjectQueue.async { [weak self] in +// if normalCloseCodes.contains(closeCode) { +// self?.subject.send(completion: .finished) +// } else { +// self?.subject.send( +// completion: .failure(WebSocketError.closed(closeCode, reason)) +// ) +// } +// } +// } +// } +// } +// +// var onCompletion: OnCompletionHandler { +// { [weak self] webSocketSession, _, error in +// defer { webSocketSession.invalidateAndCancel() } +// guard let self = self else { return } +// +// os_log("onCompletion", log: .webSocket, type: .debug) +// +// // "The only errors your delegate receives through the error parameter +// // are client-side errors, such as being unable to resolve the hostname +// // or connect to the host." +// // +// // https://developer.apple.com/documentation/foundation/urlsessiontaskdelegate/1411610-urlsession +// // +// // When receiving these errors, `onClose` is not called because the connection +// // was never actually opened. +// guard let error = error else { return } +// self.sync { +// os_log( +// "onCompletion: oldstate=%{public}@ error=%@", +// log: .webSocket, +// type: .debug, +// self.state.debugDescription, +// error.localizedDescription +// ) +// +// if case .closed = self.state { return } +// self.state = .closed(.notOpen) +// +// self.subjectQueue.async { [weak self] in +// self?.subject.send(completion: .failure(error)) +// } +// } +// } +// } +//} +// +//// MARK: URLSessionWebSocketDelegate +// +//private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { +// private let onOpen: OnOpenHandler +// private let onClose: OnCloseHandler +// private let onCompletion: OnCompletionHandler +// +// init(onOpen: @escaping OnOpenHandler, +// onClose: @escaping OnCloseHandler, +// onCompletion: @escaping OnCompletionHandler) +// { +// self.onOpen = onOpen +// self.onClose = onClose +// self.onCompletion = onCompletion +// super.init() +// } +// +// func urlSession(_ webSocketSession: URLSession, +// webSocketTask: URLSessionWebSocketTask, +// didOpenWithProtocol protocol: String?) +// { +// onOpen(webSocketSession, webSocketTask, `protocol`) +// } +// +// func urlSession(_ session: URLSession, +// webSocketTask: URLSessionWebSocketTask, +// didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, +// reason: Data?) +// { +// onClose(session, webSocketTask, closeCode, reason) +// } +// +// func urlSession(_ session: URLSession, +// task: URLSessionTask, +// didCompleteWithError error: Error?) +// { +// onCompletion(session, task, error) +// } +//} diff --git a/Sources/WebSocket/WebSocketClient.swift b/Sources/WebSocket/WebSocketClient.swift new file mode 100644 index 0000000..c54268f --- /dev/null +++ b/Sources/WebSocket/WebSocketClient.swift @@ -0,0 +1,69 @@ +import Foundation + +public struct WebSocketClient { + public var onStateChange: (@escaping (WebSocketEvent) -> Void) async -> Void + + /// Sends a close frame to the server with the given close code. + public var close: (WebSocketCloseCode) async throws -> Void + + /// Sends the WebSocket binary message. + public var sendBinary: (Data) async throws -> Void + + /// Sends the WebSocket text message. + public var sendText: (String) async throws -> Void + + /// Receives a message from the WebSocket. + public var receiveMessage: () async throws -> WebSocketMessage +} + +public extension WebSocketClient { + /// Calls `WebSocketProtocol.close(closeCode: .goingAway)`. + func close() async throws { + try await self.close(.goingAway) + } + + func receiveText() async throws -> String { + guard case let .text(text) = try await self.receiveMessage() + else { throw WebSocketError.expectedTextReceivedData } + return text + } + + func receiveData() async throws -> Data { + guard case let .data(data) = try await self.receiveMessage() + else { throw WebSocketError.expectedDataReceivedText } + return data + } +} + +public extension WebSocketClient { + static func system( + url: URL, + options: WebSocketOptions = .init(), + onStateChange: @escaping (WebSocketEvent) -> Void + ) async -> Self { + let ws = await WebSocket( + url: url, + options: options, + onStateChange: onStateChange + ) + + return Self( + onStateChange: { await ws.setOnStateChange($0) }, + close: { await ws.close($0) }, + sendBinary: { try await ws.send(.data($0)) }, + sendText: { try await ws.send(.string($0)) }, + receiveMessage: { + switch try await ws.receive() { + case let .data(data): + return .data(data) + + case let .string(text): + return .text(text) + + @unknown default: + throw WebSocketError.receiveUnknownMessageType + } + } + ) + } +} diff --git a/Sources/WebSocket/WebSocketCloseCode.swift b/Sources/WebSocket/WebSocketCloseCode.swift new file mode 100644 index 0000000..5ad47c6 --- /dev/null +++ b/Sources/WebSocket/WebSocketCloseCode.swift @@ -0,0 +1,46 @@ +import Foundation + +/// A code indicating why a WebSocket connection closed. +/// +/// Mirrors [URLSessionWebSocketTask](https://developer.apple.com/documentation/foundation/urlsessionwebsockettask/closecode). +public enum WebSocketCloseCode: Int, CaseIterable { + + /// A code that indicates the connection is still open. + case invalid = 0 + + /// A code that indicates normal connection closure. + case normalClosure = 1000 + + /// A code that indicates an endpoint is going away. + case goingAway = 1001 + + /// A code that indicates an endpoint terminated the connection due to a protocol error. + case protocolError = 1002 + + /// A code that indicates an endpoint terminated the connection after receiving a type of data it can’t accept. + case unsupportedData = 1003 + + /// A reserved code that indicates an endpoint expected a status code and didn’t receive one. + case noStatusReceived = 1005 + + /// A reserved code that indicates the connection closed without a close control frame. + case abnormalClosure = 1006 + + /// A code that indicates the server terminated the connection because it received data inconsistent with the message’s type. + case invalidFramePayloadData = 1007 + + /// A code that indicates an endpoint terminated the connection because it received a message that violates its policy. + case policyViolation = 1008 + + /// A code that indicates an endpoint is terminating the connection because it received a message too big for it to process. + case messageTooBig = 1009 + + /// A code that indicates the client terminated the connection because the server didn’t negotiate a required extension. + case mandatoryExtensionMissing = 1010 + + /// A code that indicates the server terminated the connection because it encountered an unexpected condition. + case internalServerError = 1011 + + /// A reserved code that indicates the connection closed due to the failure to perform a TLS handshake. + case tlsHandshakeFailure = 1015 +} diff --git a/Sources/WebSocket/WebSocketError.swift b/Sources/WebSocket/WebSocketError.swift index 7000884..55fe10a 100644 --- a/Sources/WebSocket/WebSocketError.swift +++ b/Sources/WebSocket/WebSocketError.swift @@ -1,8 +1,9 @@ import Foundation -public enum WebSocketError: Error { - case invalidURL(URL) - case invalidURLComponents(URLComponents) - case notOpen - case closed(URLSessionWebSocketTask.CloseCode, Data?) +public enum WebSocketError: Error, Hashable { + case sendMessageWhileConnecting + case receiveMessageWhenNotOpen + case receiveUnknownMessageType + case expectedTextReceivedData + case expectedDataReceivedText } diff --git a/Sources/WebSocket/WebSocketEvent.swift b/Sources/WebSocket/WebSocketEvent.swift new file mode 100644 index 0000000..e2e218b --- /dev/null +++ b/Sources/WebSocket/WebSocketEvent.swift @@ -0,0 +1,31 @@ +import Foundation + +/// Lifecycle events related to the opening or closing of the WebSocket. +public enum WebSocketEvent: Hashable, CustomStringConvertible { + /// Fired when a connection with a WebSocket is opened. + case open + + /// Fired when a connection with a WebSocket is closed. + case close(WebSocketCloseCode?, Data?) + + /// Fired when a connection with a WebSocket has been closed because of an error, + /// such as when some data couldn't be sent. + case error(NSError?) + + public var description: String { + switch self { + case .open: + return "open" + + case let .close(code, reason): + if let reason = reason { + return "close(\(code?.rawValue ?? -1), \(String(data: reason, encoding: .utf8) ?? ""))" + } else { + return "close(\(code?.rawValue ?? -1))" + } + + case let .error(error): + return "error(\(error?.localizedDescription ?? ""))" + } + } +} diff --git a/Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift b/Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift deleted file mode 100644 index 7a96fad..0000000 --- a/Sources/WebSocket/WebSocketMessage+URLSessionWebSocketTask.swift +++ /dev/null @@ -1,27 +0,0 @@ -import Foundation -import WebSocketProtocol - -extension WebSocketMessage { - init(_ message: URLSessionWebSocketTask.Message) { - switch message { - case let .data(data): - self = .binary(data) - case let .string(string): - self = .text(string) - @unknown default: - assertionFailure("Unknown WebSocket Message type") - self = .text("") - } - } -} - -extension Result: CustomDebugStringConvertible where Success == WebSocketMessage { - public var debugDescription: String { - switch self { - case let .success(message): - return message.debugDescription - case let .failure(error): - return error.localizedDescription - } - } -} diff --git a/Sources/WebSocket/WebSocketMessage.swift b/Sources/WebSocket/WebSocketMessage.swift new file mode 100644 index 0000000..e963377 --- /dev/null +++ b/Sources/WebSocket/WebSocketMessage.swift @@ -0,0 +1,19 @@ +import Foundation + +/// An enumeration of the types of messages that can be received. +public enum WebSocketMessage: CustomStringConvertible, CustomDebugStringConvertible, Hashable { + /// A WebSocket message that contains a block of data. + case data(Data) + + /// A WebSocket message that contains a UTF-8 formatted string. + case text(String) + + public var description: String { + switch self { + case let .data(data): return "\(data.count) bytes" + case let .text(text): return text + } + } + + public var debugDescription: String { self.description } +} diff --git a/Sources/WebSocket/WebSocketOptions.swift b/Sources/WebSocket/WebSocketOptions.swift new file mode 100644 index 0000000..d88d0e2 --- /dev/null +++ b/Sources/WebSocket/WebSocketOptions.swift @@ -0,0 +1,17 @@ +import Foundation + +public struct WebSocketOptions: Hashable { + public var maximumMessageSize: Int + public var timeoutIntervalForRequest: TimeInterval + public var timeoutIntervalForResource: TimeInterval + + public init( + maximumMessageSize: Int = 1024 * 1024, // 1 MiB + timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds + timeoutIntervalForResource: TimeInterval = 604_800 // 7 days + ) { + self.maximumMessageSize = maximumMessageSize + self.timeoutIntervalForRequest = timeoutIntervalForRequest + self.timeoutIntervalForResource = timeoutIntervalForResource + } +} diff --git a/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift b/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift index 66df139..c638c0e 100644 --- a/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift +++ b/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift @@ -1,5 +1,4 @@ @testable import WebSocket -import WebSocketProtocol import XCTest class URLSessionWebSocketTaskCloseCodeTests: XCTestCase { diff --git a/Tests/WebSocketTests/WebSocketTests.swift b/Tests/WebSocketTests/WebSocketTests.swift index 3853d3d..05b94da 100644 --- a/Tests/WebSocketTests/WebSocketTests.swift +++ b/Tests/WebSocketTests/WebSocketTests.swift @@ -1,285 +1,274 @@ import Combine @testable import WebSocket -import WebSocketProtocol import XCTest private var ports = (50000 ... 52000).map { UInt16($0) } +// NOTE: If `WebSocketTests` is not marked as `@MainActor`, calls to +// `wait(for:timeout:)` prevent other asyncronous events from running. +// Using `await waitForExpectations(timeout:handler:)` works properly +// because it's already marked as `@MainActor`. + +@MainActor class WebSocketTests: XCTestCase { func url(_ port: UInt16) -> URL { URL(string: "ws://0.0.0.0:\(port)/socket")! } - func testCanConnectToAndDisconnectFromServer() throws { - try withServer { _, client in - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValueAndThen(WebSocketMessage.open, client.close()) - ) - defer { sub.cancel() } + func testCanConnectToAndDisconnectFromServer() async throws { + let openEx = expectation(description: "Should have opened") + let closeEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndClient { event in + switch event { + case .open: + openEx.fulfill() - client.connect() - waitForExpectations(timeout: 2) + case let .close(closeCode, _): + XCTAssertEqual(.normalClosure, closeCode) + closeEx.fulfill() - XCTAssertFalse(client.isOpen) - XCTAssertTrue(client.isClosed) + case let .error(error): + XCTFail("Should not have received error: \(String(describing: error))") + } } - } + defer { server.close() } - func testCompleteWhenServerIsUnreachable() throws { - try withServer { server, client in - server.close() + wait(for: [openEx], timeout: 0.5) - let sub = client.sink( - receiveCompletion: expectFailure(), - receiveValue: { result in - switch result { - case .failure: - // It's possible to receive or not receive an error. - // Clients need to be resilient in the face of this reality. - break - case let .success(message): - XCTFail("Should not have received message: \(message)") - } - } - ) - defer { sub.cancel() } + let isOpen = await client.isOpen + XCTAssertTrue(isOpen) - client.connect() - waitForExpectations(timeout: 0.2) - - XCTAssertTrue(client.isClosed) - } + await client.close(.normalClosure) + wait(for: [closeEx], timeout: 0.5) } - func testCompleteWhenRemoteCloses() throws { - try withServer { _, client in - var invalidUTF8Bytes = [0x192, 0x193] as [UInt16] - let bytes = withUnsafeBytes(of: &invalidUTF8Bytes) { Array($0) } - let data = Data(bytes: bytes, count: bytes.count) - - let openEx = self.expectation(description: "Should have opened") - let errorEx = self.expectation(description: "Should have erred") - - let sub = client.sink( - receiveCompletion: expectFailure(), - receiveValue: { result in - switch result { - case .success(.open): - XCTAssertTrue(client.isOpen) - XCTAssertFalse(client.isClosed) - client.send(data) - openEx.fulfill() - case let .failure(error as NSError): - XCTAssertEqual("NSPOSIXErrorDomain", error.domain) - XCTAssertEqual(57, error.code) - errorEx.fulfill() - default: - break - } - } - ) - defer { sub.cancel() } - - client.connect() - waitForExpectations(timeout: 2) - - XCTAssertFalse(client.isOpen) - XCTAssertTrue(client.isClosed) +// func testCustom() async throws { +// let (server, client) = await makeServerAndClient() +// +// try await Task.sleep(nanoseconds: NSEC_PER_SEC * 10000) +// server.close() +// } + + func testErrorWhenServerIsUnreachable() async throws { + let ex = expectation(description: "Should have errored") + let (server, client) = await makeOfflineServerAndClient { event in + guard case let .error(error) = event else { + return XCTFail("Should not have received \(event)") + } + XCTAssertEqual(-1004, error?.code) + ex.fulfill() } - } - - func testEchoPush() throws { - try withEchoServer { _, client in - let message = "hello" - let completion = self.expectNoError() + defer { server.close() } - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValuesAndThen([ - .open: { client.send(message, completionHandler: completion) }, - .text(message): { client.close() }, - ]) - ) - defer { sub.cancel() } + waitForExpectations(timeout: 0.5) - client.connect() - waitForExpectations(timeout: 2) - } + let isClosed = await client.isClosed + XCTAssertTrue(isClosed) } - func testEchoBinaryPush() throws { - try withEchoServer { _, client in - let message = "hello" - let binary = message.data(using: .utf8)! - let completion = self.expectNoError() - - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValuesAndThen([ - .open: { client.send(binary, completionHandler: completion) }, - .text(message): { client.close() }, - ]) - ) - defer { sub.cancel() } + func testErrorWhenRemoteCloses() async throws { + var invalidUTF8Bytes = [0x192, 0x193] as [UInt16] + let bytes = withUnsafeBytes(of: &invalidUTF8Bytes) { Array($0) } + let data = Data(bytes: bytes, count: bytes.count) - client.connect() - waitForExpectations(timeout: 2) - } - } + let openEx = expectation(description: "Should have opened") + let errorEx = expectation(description: "Should have errored") - func testJoinLobbyAndEcho() throws { - let joinPush = "[1,1,\"room:lobby\",\"phx_join\",{}]" - let echoPush1 = "[1,2,\"room:lobby\",\"echo\",{\"echo\":\"one\"}]" - let echoPush2 = "[1,3,\"room:lobby\",\"echo\",{\"echo\":\"two\"}]" + let (server, client) = await makeServerAndClient { event in + switch event { + case .open: + openEx.fulfill() - let joinReply = "[1,1,\"room:lobby\",\"phx_reply\",{\"response\":{},\"status\":\"ok\"}]" - let echoReply1 = - "[1,2,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"one\"},\"status\":\"ok\"}]" - let echoReply2 = - "[1,3,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"two\"},\"status\":\"ok\"}]" + case .close: + Swift.print("$$$ CLOSED") + XCTFail("Should not have closed") - let joinCompletion = expectNoError() - let echo1Completion = expectNoError() - let echo2Completion = expectNoError() + case let .error(error): + Swift.print("$$$ ERROR: \(String(describing: error))") + errorEx.fulfill() + } + } + defer { server.close() } - try withReplyServer([joinReply, echoReply1, echoReply2]) { _, client in - let sub = client.sink( - receiveCompletion: expectFinished(), - receiveValue: expectValuesAndThen([ - .open: { client.send(joinPush, completionHandler: joinCompletion) }, - .text(joinReply): { client.send(echoPush1, completionHandler: echo1Completion) - }, - .text(echoReply1): { client.send(echoPush2, completionHandler: echo2Completion) - }, - .text(echoReply2): { client.close() }, - ]) - ) - defer { sub.cancel() } + wait(for: [openEx], timeout: 0.5) + let isOpen = await client.isOpen + XCTAssertTrue(isOpen) - client.connect() - waitForExpectations(timeout: 2) - } + try await client.send(.data(data)) + wait(for: [errorEx], timeout: 0.5) +// let isClosed = await client.isClosed +// XCTAssertTrue(isClosed) } - func testCanSendFromTwoThreadsSimultaneously() throws { - let queueCount = 8 - let queues = (0 ..< queueCount).map { DispatchQueue(label: "\($0)") } - - let messageCount = 100 - let sendMessages: (WebSocket) -> Void = { client in - (0 ..< messageCount).forEach { messageIndex in - (0 ..< queueCount).forEach { queueIndex in - queues[queueIndex].async { client.send("\(queueIndex)-\(messageIndex)") } - } - } + func testEchoPush() async throws { + let openEx = expectation(description: "Should have opened") + let (server, client) = await makeEchoServerAndClient { event in + guard case .open = event else { return } + openEx.fulfill() } + defer { server.close() } - let receiveMessageEx = expectation( - description: "Should have received \(queueCount * messageCount) messages" - ) - receiveMessageEx.expectedFulfillmentCount = queueCount * messageCount + wait(for: [openEx], timeout: 0.5) - try withEchoServer { _, client in - let sub = client.sink( - receiveCompletion: { _ in }, - receiveValue: { message in - switch message { - case .success(.open): - sendMessages(client) - case .success(.text): - receiveMessageEx.fulfill() - default: - XCTFail() - } - } - ) - defer { sub.cancel() } + try await client.send(.string("hello")) + guard case let .string(text) = try await client.receive() + else { return XCTFail("Should have received text") } - client.connect() - waitForExpectations(timeout: 10) - client.close() - } + XCTAssertEqual("hello", text) } + +// func testEchoPush() throws { +// try withEchoServer { _, client in +// let message = "hello" +// let completion = self.expectNoError() +// +// let sub = client.sink( +// receiveCompletion: expectFinished(), +// receiveValue: expectValuesAndThen([ +// .open: { client.send(message, completionHandler: completion) }, +// .text(message): { client.close() }, +// ]) +// ) +// defer { sub.cancel() } +// +// client.connect() +// waitForExpectations(timeout: 2) +// } +// } +// +// func testEchoBinaryPush() throws { +// try withEchoServer { _, client in +// let message = "hello" +// let data = message.data(using: .utf8)! +// let completion = self.expectNoError() +// +// let sub = client.sink( +// receiveCompletion: expectFinished(), +// receiveValue: expectValuesAndThen([ +// .open: { client.send(data, completionHandler: completion) }, +// .text(message): { client.close() }, +// ]) +// ) +// defer { sub.cancel() } +// +// client.connect() +// waitForExpectations(timeout: 2) +// } +// } +// +// func testJoinLobbyAndEcho() throws { +// let joinPush = "[1,1,\"room:lobby\",\"phx_join\",{}]" +// let echoPush1 = "[1,2,\"room:lobby\",\"echo\",{\"echo\":\"one\"}]" +// let echoPush2 = "[1,3,\"room:lobby\",\"echo\",{\"echo\":\"two\"}]" +// +// let joinReply = "[1,1,\"room:lobby\",\"phx_reply\",{\"response\":{},\"status\":\"ok\"}]" +// let echoReply1 = +// "[1,2,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"one\"},\"status\":\"ok\"}]" +// let echoReply2 = +// "[1,3,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"two\"},\"status\":\"ok\"}]" +// +// let joinCompletion = expectNoError() +// let echo1Completion = expectNoError() +// let echo2Completion = expectNoError() +// +// try withReplyServer([joinReply, echoReply1, echoReply2]) { _, client in +// let sub = client.sink( +// receiveCompletion: expectFinished(), +// receiveValue: expectValuesAndThen([ +// .open: { client.send(joinPush, completionHandler: joinCompletion) }, +// .text(joinReply): { client.send(echoPush1, completionHandler: echo1Completion) +// }, +// .text(echoReply1): { client.send(echoPush2, completionHandler: echo2Completion) +// }, +// .text(echoReply2): { client.close() }, +// ]) +// ) +// defer { sub.cancel() } +// +// client.connect() +// waitForExpectations(timeout: 2) +// } +// } +// +// func testCanSendFromTwoThreadsSimultaneously() throws { +// let queueCount = 8 +// let queues = (0 ..< queueCount).map { DispatchQueue(label: "\($0)") } +// +// let messageCount = 100 +// let sendMessages: (WebSocket) -> Void = { client in +// (0 ..< messageCount).forEach { messageIndex in +// (0 ..< queueCount).forEach { queueIndex in +// queues[queueIndex].async { client.send("\(queueIndex)-\(messageIndex)") } +// } +// } +// } +// +// let receiveMessageEx = expectation( +// description: "Should have received \(queueCount * messageCount) messages" +// ) +// receiveMessageEx.expectedFulfillmentCount = queueCount * messageCount +// +// try withEchoServer { _, client in +// let sub = client.sink( +// receiveCompletion: { _ in }, +// receiveValue: { message in +// switch message { +// case .success(.open): +// sendMessages(client) +// case .success(.text): +// receiveMessageEx.fulfill() +// default: +// XCTFail() +// } +// } +// ) +// defer { sub.cancel() } +// +// client.connect() +// waitForExpectations(timeout: 10) +// client.close() +// } +// } } private extension WebSocketTests { - func withServer(_ block: (WebSocketServer, WebSocket) throws -> Void) throws { + func makeServerAndClient( + _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } + ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() let server = WebSocketServer(port: port, replyProvider: .reply { nil }) - let client = WebSocket(url: url(port)) - try withExtendedLifetime((server, client)) { server.listen(); try block(server, client) } + let client = await WebSocket(url: url(port), onStateChange: onStateChange) + server.listen() + return (server, client) } - func withEchoServer(_ block: (WebSocketServer, WebSocket) throws -> Void) throws { + func makeOfflineServerAndClient( + _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } + ) async -> (WebSocketServer, WebSocket) { + let port = ports.removeFirst() + let server = WebSocketServer(port: port, replyProvider: .reply { nil }) + let client = await WebSocket(url: url(port), onStateChange: onStateChange) + return (server, client) + } + + func makeEchoServerAndClient( + _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } + ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() let server = WebSocketServer(port: port, replyProvider: .echo) - let client = WebSocket(url: url(port)) - try withExtendedLifetime((server, client)) { server.listen(); try block(server, client) } + let client = await WebSocket(url: url(port), onStateChange: onStateChange) + server.listen() + return (server, client) } - func withReplyServer( + func makeReplyServerAndClient( _ replies: [String?], - _ block: (WebSocketServer, WebSocket) throws -> Void - ) throws { + _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } + ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() var replies = replies let provider: () -> String? = { replies.removeFirst() } let server = WebSocketServer(port: port, replyProvider: .reply(provider)) - let client = WebSocket(url: url(port)) - try withExtendedLifetime((server, client)) { server.listen(); try block(server, client) } - } -} - -private extension WebSocketTests { - func expectValueAndThen( - _ value: T, - _ block: @escaping @autoclosure () -> Void - ) -> (Result) -> Void { - expectValuesAndThen([value: block]) - } - - func expectValuesAndThen< - T: Hashable, - E: Error - >(_ values: [T: () -> Void]) -> (Result) -> Void { - var values = values - let expectation = self - .expectation(description: "Should have received \(String(describing: values))") - return { (result: Result) in - guard case let .success(value) = result else { - return XCTFail("Unexpected result: \(String(describing: result))") - } - - let block = values.removeValue(forKey: value) - XCTAssertNotNil(block) - block?() - - if values.isEmpty { - expectation.fulfill() - } - } - } - - func expectFinished() -> (Subscribers.Completion) -> Void { - let expectation = self.expectation(description: "Should have finished successfully") - return { completion in - guard case Subscribers.Completion.finished = completion else { return } - expectation.fulfill() - } - } - - func expectFailure() -> (Subscribers.Completion) -> Void where E: Error { - let expectation = self.expectation(description: "Should have failed") - return { completion in - guard case Subscribers.Completion.failure = completion else { return } - expectation.fulfill() - } - } - - func expectNoError() -> (Error?) -> Void { - let expectation = self.expectation(description: "Should not have had an error") - return { error in - XCTAssertNil(error) - expectation.fulfill() - } + let client = await WebSocket(url: url(port), onStateChange: onStateChange) + server.listen() + return (server, client) } } From d4505bb26559076a26e49ba20cd625a790a02227 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Sun, 10 Apr 2022 00:53:49 +0200 Subject: [PATCH 02/11] Add logs --- Sources/WebSocket/WebSocket.swift | 47 ++++++++++++++----- .../Server/WebSocketServer.swift | 4 +- Tests/WebSocketTests/WebSocketTests.swift | 4 +- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index 593db8b..efcd20d 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -35,6 +35,14 @@ final actor WebSocket { } func close(_ code: WebSocketCloseCode) async { + os_log( + "close: oldstate=%{public}@ code=%lld", + log: .webSocket, + type: .debug, + state.description, + code.rawValue + ) + switch state { case let .connecting(session, task, _), let .open(session, task, _): state = .closed(code.urlSessionCloseCode, nil) @@ -51,13 +59,26 @@ final actor WebSocket { // Mirrors the document behavior of JavaScript's `WebSocket` // http://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send switch state { - case let .open(session, task, _): + case let .open(_, task, _): + os_log("send: %s", log: .webSocket, type: .debug, message.debugDescription) try await task.send(message) case .unopened, .connecting: + os_log( + "send message while connecting: %s", + log: .webSocket, + type: .error, + message.debugDescription + ) throw WebSocketError.sendMessageWhileConnecting case .closing, .closed: + os_log( + "send message while closed: %s", + log: .webSocket, + type: .debug, + message.debugDescription + ) break } } @@ -65,9 +86,18 @@ final actor WebSocket { func receive() async throws -> URLSessionWebSocketTask.Message { switch state { case let .open(_, task, _): - return try await task.receive() + let message = try await task.receive() + os_log("receive: %s", log: .webSocket, type: .debug, message.debugDescription) + return message + case .unopened, .connecting, .closing, .closed: + os_log( + "receive in incorrect state: %s", + log: .webSocket, + type: .error, + state.description + ) throw WebSocketError.receiveMessageWhenNotOpen } } @@ -83,7 +113,7 @@ private extension WebSocket { "connect: oldstate=%{public}@", log: .webSocket, type: .debug, - state.debugDescription + state.description ) switch state { @@ -173,7 +203,7 @@ private extension WebSocket { } } -private enum WebSocketState: CustomDebugStringConvertible { +private enum WebSocketState: CustomStringConvertible { case unopened case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) @@ -189,7 +219,7 @@ private enum WebSocketState: CustomDebugStringConvertible { } } - var debugDescription: String { + var description: String { switch self { case .unopened: return "unopened" case .connecting: return "connecting" @@ -221,7 +251,6 @@ private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { webSocketTask: URLSessionWebSocketTask, didOpenWithProtocol protocol: String? ) { - Swift.print("$$$ \(#function)") Task { await onStateChange(.open(webSocketSession, webSocketTask, `protocol`)) } } @@ -231,7 +260,6 @@ private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, reason: Data? ) { - Swift.print("$$$ \(#function)") Task { await onStateChange(.close(session, webSocketTask, closeCode, reason)) } } @@ -240,13 +268,8 @@ private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { task: URLSessionTask, didCompleteWithError error: Error? ) { - Swift.print("$$$ \(#function)") Task { await onStateChange(.complete(session, task, error)) } } - - func urlSession(_ session: URLSession, didBecomeInvalidWithError error: Error?) { - Swift.print("$$$ \(#function): \(String(describing: error))") - } } //private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { diff --git a/Tests/WebSocketTests/Server/WebSocketServer.swift b/Tests/WebSocketTests/Server/WebSocketServer.swift index 5c4c44c..5bcb859 100644 --- a/Tests/WebSocketTests/Server/WebSocketServer.swift +++ b/Tests/WebSocketTests/Server/WebSocketServer.swift @@ -59,8 +59,8 @@ final class WebSocketServer { .EventLoopFuture { head.uri.starts(with: "/socket") ? - channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : channel - .closeFuture + channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : + channel.closeFuture } private var replyProvider: (String) -> String? { diff --git a/Tests/WebSocketTests/WebSocketTests.swift b/Tests/WebSocketTests/WebSocketTests.swift index 05b94da..06d4f7e 100644 --- a/Tests/WebSocketTests/WebSocketTests.swift +++ b/Tests/WebSocketTests/WebSocketTests.swift @@ -94,8 +94,8 @@ class WebSocketTests: XCTestCase { try await client.send(.data(data)) wait(for: [errorEx], timeout: 0.5) -// let isClosed = await client.isClosed -// XCTAssertTrue(isClosed) + let isClosed = await client.isClosed + XCTAssertTrue(isClosed) } func testEchoPush() async throws { From f60953f5c8f70a8d8d515aec2c732578170a60fb Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Sun, 10 Apr 2022 23:37:52 +0200 Subject: [PATCH 03/11] Add WebSocketServer2 --- Package.resolved | 9 + Package.swift | 6 +- .../Server/WebSocketServer.swift | 447 +++++++-------- .../Server/WebSocketServer2.swift | 515 ++++++++++++++++++ Tests/WebSocketTests/WebSocketTests.swift | 182 ++----- 5 files changed, 809 insertions(+), 350 deletions(-) create mode 100644 Tests/WebSocketTests/Server/WebSocketServer2.swift diff --git a/Package.resolved b/Package.resolved index 4ae6b60..6ae1c7e 100644 --- a/Package.resolved +++ b/Package.resolved @@ -10,6 +10,15 @@ "version": "2.39.0" } }, + { + "package": "swift-nio-ssl", + "repositoryURL": "https://github.com/apple/swift-nio-ssl.git", + "state": { + "branch": null, + "revision": "b5260a31c2a72a89fa684f5efb3054d8725a2316", + "version": "2.18.0" + } + }, { "package": "Synchronized", "repositoryURL": "https://github.com/shareup/synchronized.git", diff --git a/Package.swift b/Package.swift index 19d2558..4250a45 100644 --- a/Package.swift +++ b/Package.swift @@ -17,7 +17,9 @@ let package = Package( url: "https://github.com/shareup/synchronized.git", from: "3.0.0" ), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.39.0")], + .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.39.0"), + .package(name: "swift-nio-ssl", url: "https://github.com/apple/swift-nio-ssl.git", from: "2.18.0"), + ], targets: [ .target( name: "WebSocket", @@ -28,6 +30,8 @@ let package = Package( .product(name: "NIO", package: "swift-nio"), .product(name: "NIOHTTP1", package: "swift-nio"), .product(name: "NIOWebSocket", package: "swift-nio"), + .product(name: "NIOFoundationCompat", package: "swift-nio"), + .product(name: "NIOSSL", package: "swift-nio-ssl"), "WebSocket", ]) ] diff --git a/Tests/WebSocketTests/Server/WebSocketServer.swift b/Tests/WebSocketTests/Server/WebSocketServer.swift index 5bcb859..8ecd2d2 100644 --- a/Tests/WebSocketTests/Server/WebSocketServer.swift +++ b/Tests/WebSocketTests/Server/WebSocketServer.swift @@ -1,223 +1,224 @@ -import Foundation -import NIO -import NIOHTTP1 -import NIOWebSocket - -enum ReplyType { - case echo - case reply(() -> String?) - case matchReply((String) -> String?) -} - -final class WebSocketServer { - let port: UInt16 - - private let replyType: ReplyType - private let eventLoopGroup: EventLoopGroup - - private var serverChannel: Channel? - - init(port: UInt16, replyProvider: ReplyType) { - self.port = port - replyType = replyProvider - eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - } - - func listen() { - do { - var addr = sockaddr_in() - addr.sin_port = in_port_t(port).bigEndian - let address = SocketAddress(addr, host: "0.0.0.0") - - let bootstrap = makeBootstrap() - serverChannel = try bootstrap.bind(to: address).wait() - - guard let localAddress = serverChannel?.localAddress else { - throw NIO.ChannelError.unknownLocalAddress - } - print("WebSocketServer running on \(localAddress)") - } catch let error as NIO.IOError { - print("Failed to start server: \(error.errnoCode) '\(error.localizedDescription)'") - } catch { - print("Failed to start server: \(String(describing: error))") - } - } - - func close() { - do { try serverChannel?.close().wait() } - catch { print("Failed to wait on server: \(error)") } - } - - private func shouldUpgrade(channel _: Channel, - head: HTTPRequestHead) -> EventLoopFuture - { - let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil - return eventLoopGroup.next().makeSucceededFuture(headers) - } - - private func upgradePipelineHandler(channel: Channel, head: HTTPRequestHead) -> NIO - .EventLoopFuture - { - head.uri.starts(with: "/socket") ? - channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : - channel.closeFuture - } - - private var replyProvider: (String) -> String? { - { [weak self] (input: String) -> String? in - guard let self = self else { return nil } - switch self.replyType { - case .echo: - return input - case let .reply(iterator): - return iterator() - case let .matchReply(matcher): - return matcher(input) - } - } - } - - private func makeBootstrap() -> ServerBootstrap { - let reuseAddrOpt = ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR) - return ServerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.backlog, value: 256) - .serverChannelOption(reuseAddrOpt, value: 1) - .childChannelInitializer { channel in - let connectionUpgrader = NIOWebSocketServerUpgrader( - shouldUpgrade: self.shouldUpgrade, - upgradePipelineHandler: self.upgradePipelineHandler - ) - - let config: NIOHTTPServerUpgradeConfiguration = ( - upgraders: [connectionUpgrader], - completionHandler: { _ in } - ) - - return channel.pipeline.configureHTTPServerPipeline( - position: .first, - withPipeliningAssistance: true, - withServerUpgrade: config, - withErrorHandling: true - ) - } - .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .childChannelOption(reuseAddrOpt, value: 1) - .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) - } -} - -private class WebSocketHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame - - private let replyProvider: (String) -> String? - private var awaitingClose = false - - init(replyProvider: @escaping (String) -> String?) { - self.replyProvider = replyProvider - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = unwrapInboundIn(data) - - switch frame.opcode { - case .connectionClose: - onClose(context: context, frame: frame) - case .ping: - onPing(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - onText(context: context, text: text) - case .binary: - let buffer = frame.unmaskedData - var data = Data(capacity: buffer.readableBytes) - buffer.withUnsafeReadableBytes { data.append(contentsOf: $0) } - onBinary(context: context, binary: data) - default: - onError(context: context) - } - } - - private func onBinary(context: ChannelHandlerContext, binary: Data) { - do { - // Obviously, this would need to be changed to actually handle data input - if let text = String(data: binary, encoding: .utf8) { - onText(context: context, text: text) - } else { - throw NIO.IOError(errnoCode: EBADMSG, reason: "Invalid message") - } - } catch { - onError(context: context) - } - } - - private func onText(context: ChannelHandlerContext, text: String) { - guard let reply = replyProvider(text) else { return } - - var replyBuffer = context.channel.allocator.buffer(capacity: reply.utf8.count) - replyBuffer.writeString(reply) - - let frame = WebSocketFrame(fin: true, opcode: .text, data: replyBuffer) - - _ = context.channel.writeAndFlush(frame) - } - - private func onPing(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data - - if let maskingKey = frame.maskKey { - frameData.webSocketUnmask(maskingKey) - } - - let pong = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - context.write(wrapOutboundOut(pong), promise: nil) - } - - private func onClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - if awaitingClose { - // We sent the initial close and were waiting for the client's response - context.close(promise: nil) - } else { - // The close came from the client. - var data = frame.unmaskedData - let closeDataCode = data.readSlice(length: 2) ?? context.channel.allocator - .buffer(capacity: 0) - let closeFrame = WebSocketFrame( - fin: true, - opcode: .connectionClose, - data: closeDataCode - ) - _ = context.write(wrapOutboundOut(closeFrame)).map { () in - context.close(promise: nil) - } - } - } - - private func onError(context: ChannelHandlerContext) { - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - awaitingClose = true - } - - func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } - - func channelActive(context: ChannelHandlerContext) { - print("Channel active: \(String(describing: context.channel.remoteAddress))") - } - - func channelInactive(context: ChannelHandlerContext) { - print("Channel closed: \(String(describing: context.localAddress))") - } - - func errorCaught(context: ChannelHandlerContext, error: Error) { - print("Error: \(error)") - context.close(promise: nil) - } -} +//import Foundation +//import NIO +//import NIOHTTP1 +//import NIOWebSocket +// +//enum ReplyType { +// case echo +// case reply(() -> String?) +// case matchReply((String) -> String?) +//} +// +//final class WebSocketServer { +// let port: UInt16 +// +// private let replyType: ReplyType +// private let eventLoopGroup: EventLoopGroup +// +// private var serverChannel: Channel? +// +// init(port: UInt16, replyProvider: ReplyType) { +// self.port = port +// replyType = replyProvider +// eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// } +// +// func listen() { +// do { +// var addr = sockaddr_in() +// addr.sin_port = in_port_t(port).bigEndian +// let address = SocketAddress(addr, host: "0.0.0.0") +// +// let bootstrap = makeBootstrap() +// serverChannel = try bootstrap.bind(to: address).wait() +// +// guard let localAddress = serverChannel?.localAddress else { +// throw NIO.ChannelError.unknownLocalAddress +// } +// print("WebSocketServer running on \(localAddress)") +// } catch let error as NIO.IOError { +// print("Failed to start server: \(error.errnoCode) '\(error.localizedDescription)'") +// } catch { +// print("Failed to start server: \(String(describing: error))") +// } +// } +// +// func close() { +// do { try serverChannel?.close().wait() } +// catch { print("Failed to wait on server: \(error)") } +// } +// +// private func shouldUpgrade(channel _: Channel, +// head: HTTPRequestHead) -> EventLoopFuture +// { +// let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil +// return eventLoopGroup.next().makeSucceededFuture(headers) +// } +// +// private func upgradePipelineHandler( +// channel: Channel, +// head: HTTPRequestHead +// ) -> NIO.EventLoopFuture { +// head.uri.starts(with: "/socket") ? +// channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : +// channel.closeFuture +// } +// +// private var replyProvider: (String) -> String? { +// { [weak self] (input: String) -> String? in +// guard let self = self else { return nil } +// switch self.replyType { +// case .echo: +// return input +// case let .reply(iterator): +// return iterator() +// case let .matchReply(matcher): +// return matcher(input) +// } +// } +// } +// +// private func makeBootstrap() -> ServerBootstrap { +// let reuseAddrOpt = ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR) +// return ServerBootstrap(group: eventLoopGroup) +// .serverChannelOption(ChannelOptions.backlog, value: 256) +// .serverChannelOption(reuseAddrOpt, value: 1) +// .childChannelInitializer { channel in +// let connectionUpgrader = NIOWebSocketServerUpgrader( +// shouldUpgrade: self.shouldUpgrade, +// upgradePipelineHandler: self.upgradePipelineHandler +// ) +// +// let config: NIOHTTPServerUpgradeConfiguration = ( +// upgraders: [connectionUpgrader], +// completionHandler: { _ in } +// ) +// +// return channel.pipeline.configureHTTPServerPipeline( +// position: .first, +// withPipeliningAssistance: true, +// withServerUpgrade: config, +// withErrorHandling: true +// ) +// } +// .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) +// .childChannelOption(reuseAddrOpt, value: 1) +// .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) +// } +//} +// +//private class WebSocketHandler: ChannelInboundHandler { +// typealias InboundIn = WebSocketFrame +// typealias OutboundOut = WebSocketFrame +// +// private let replyProvider: (String) -> String? +// private var awaitingClose = false +// +// init(replyProvider: @escaping (String) -> String?) { +// self.replyProvider = replyProvider +// } +// +// func channelRead(context: ChannelHandlerContext, data: NIOAny) { +// let frame = unwrapInboundIn(data) +// +// switch frame.opcode { +// case .connectionClose: +// onClose(context: context, frame: frame) +// case .ping: +// onPing(context: context, frame: frame) +// case .text: +// var data = frame.unmaskedData +// let text = data.readString(length: data.readableBytes) ?? "" +// onText(context: context, text: text) +// case .binary: +// let buffer = frame.unmaskedData +// var data = Data(capacity: buffer.readableBytes) +// buffer.withUnsafeReadableBytes { data.append(contentsOf: $0) } +// onBinary(context: context, binary: data) +// default: +// onError(context: context) +// } +// } +// +// private func onBinary(context: ChannelHandlerContext, binary: Data) { +// do { +// // Obviously, this would need to be changed to actually handle data input +// if let text = String(data: binary, encoding: .utf8) { +// onText(context: context, text: text) +// } else { +// throw NIO.IOError(errnoCode: EBADMSG, reason: "Invalid message") +// } +// } catch { +// onError(context: context) +// } +// } +// +// private func onText(context: ChannelHandlerContext, text: String) { +// guard let reply = replyProvider(text) else { return } +// +// var replyBuffer = context.channel.allocator.buffer(capacity: reply.utf8.count) +// replyBuffer.writeString(reply) +// +// let frame = WebSocketFrame(fin: true, opcode: .text, data: replyBuffer) +// +// _ = context.channel.writeAndFlush(frame) +// } +// +// private func onPing(context: ChannelHandlerContext, frame: WebSocketFrame) { +// var frameData = frame.data +// +// if let maskingKey = frame.maskKey { +// frameData.webSocketUnmask(maskingKey) +// } +// +// let pong = WebSocketFrame(fin: true, opcode: .pong, data: frameData) +// context.write(wrapOutboundOut(pong), promise: nil) +// } +// +// private func onClose(context: ChannelHandlerContext, frame: WebSocketFrame) { +// if awaitingClose { +// // We sent the initial close and were waiting for the client's response +// context.close(promise: nil) +// } else { +// // The close came from the client. +// var data = frame.unmaskedData +// let closeDataCode = data.readSlice(length: 2) ?? context.channel.allocator +// .buffer(capacity: 0) +// let closeFrame = WebSocketFrame( +// fin: true, +// opcode: .connectionClose, +// data: closeDataCode +// ) +// _ = context.write(wrapOutboundOut(closeFrame)).map { () in +// context.close(promise: nil) +// } +// } +// } +// +// private func onError(context: ChannelHandlerContext) { +// var data = context.channel.allocator.buffer(capacity: 2) +// data.write(webSocketErrorCode: .protocolError) +// let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) +// context.write(wrapOutboundOut(frame)).whenComplete { (_: Result) in +// context.close(mode: .output, promise: nil) +// } +// awaitingClose = true +// } +// +// func channelReadComplete(context: ChannelHandlerContext) { +// context.flush() +// } +// +// func channelActive(context: ChannelHandlerContext) { +// print("Channel active: \(String(describing: context.channel.remoteAddress))") +// } +// +// func channelInactive(context: ChannelHandlerContext) { +// print("Channel closed: \(String(describing: context.localAddress))") +// } +// +// func errorCaught(context: ChannelHandlerContext, error: Error) { +// print("Error: \(error)") +// context.close(promise: nil) +// } +//} diff --git a/Tests/WebSocketTests/Server/WebSocketServer2.swift b/Tests/WebSocketTests/Server/WebSocketServer2.swift new file mode 100644 index 0000000..b6edd51 --- /dev/null +++ b/Tests/WebSocketTests/Server/WebSocketServer2.swift @@ -0,0 +1,515 @@ +import NIO +import NIOWebSocket +import NIOHTTP1 +import NIOSSL +import Foundation +import NIOFoundationCompat + +enum ReplyType { + case echo + case reply(() -> String?) + case matchReply((String) -> String?) +} + +final class WebSocketServer { + enum PeerType { + case server + case client + } + + var eventLoop: EventLoop { + return channel.eventLoop + } + + var isClosed: Bool { !self.channel.isActive } + private(set) var closeCode: WebSocketErrorCode? + + var onClose: EventLoopFuture { + self.channel.closeFuture + } + + let port: UInt16 + + private let replyType: ReplyType + private let eventLoopGroup: EventLoopGroup + + private var channel: Channel! + private var onTextCallback: (WebSocketServer, String) -> () + private var onBinaryCallback: (WebSocketServer, ByteBuffer) -> () + private var onPongCallback: (WebSocketServer) -> () + private var onPingCallback: (WebSocketServer) -> () + private var frameSequence: WebSocketFrameSequence? + private let type: PeerType + private var waitingForPong: Bool + private var waitingForClose: Bool + private var scheduledTimeoutTask: Scheduled? + + init(port: UInt16, replyProvider: ReplyType) throws { + self.port = port + replyType = replyProvider + eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + + self.type = .server + self.onTextCallback = { _, _ in } + self.onBinaryCallback = { _, _ in } + self.onPongCallback = { _ in } + self.onPingCallback = { _ in } + self.waitingForPong = false + self.waitingForClose = false + self.scheduledTimeoutTask = nil + + var addr = sockaddr_in() + addr.sin_port = in_port_t(port).bigEndian + let address = SocketAddress(addr, host: "0.0.0.0") + + let bootstrap = makeBootstrap() + let channel = try bootstrap.bind(to: address).wait() + + guard let localAddress = channel.localAddress else { + throw NIO.ChannelError.unknownLocalAddress + } + + self.channel = channel + + print("WebSocketServer running on \(localAddress)") + } + + deinit { + assert(self.isClosed, "WebSocketServer was not closed before deinit.") + } + + func onText(_ callback: @escaping (WebSocketServer, String) -> ()) { + self.onTextCallback = callback + } + + func onBinary(_ callback: @escaping (WebSocketServer, ByteBuffer) -> ()) { + self.onBinaryCallback = callback + } + + func onPong(_ callback: @escaping (WebSocketServer) -> ()) { + self.onPongCallback = callback + } + + func onPing(_ callback: @escaping (WebSocketServer) -> ()) { + self.onPingCallback = callback + } + + /// If set, this will trigger automatic pings on the connection. If ping is not answered before + /// the next ping is sent, then the WebSocketServer will be presumed innactive and will be closed + /// automatically. + /// These pings can also be used to keep the WebSocketServer alive if there is some other timeout + /// mechanism shutting down innactive connections, such as a Load Balancer deployed in + /// front of the server. + var pingInterval: TimeAmount? { + didSet { + if pingInterval != nil { + if scheduledTimeoutTask == nil { + waitingForPong = false + self.pingAndScheduleNextTimeoutTask() + } + } else { + scheduledTimeoutTask?.cancel() + } + } + } + + func send(_ text: S, promise: EventLoopPromise? = nil) + where S: Collection, S.Element == Character + { + let string = String(text) + var buffer = channel.allocator.buffer(capacity: text.count) + buffer.writeString(string) + self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise) + } + + func send(_ binary: [UInt8], promise: EventLoopPromise? = nil) { + self.send(raw: binary, opcode: .binary, fin: true, promise: promise) + } + + func sendPing(promise: EventLoopPromise? = nil) { + self.send( + raw: Data(), + opcode: .ping, + fin: true, + promise: promise + ) + } + + func send( + raw data: Data, + opcode: WebSocketOpcode, + fin: Bool = true, + promise: EventLoopPromise? = nil + ) + where Data: DataProtocol + { + var buffer = channel.allocator.buffer(capacity: data.count) + buffer.writeBytes(data) + let frame = WebSocketFrame( + fin: fin, + opcode: opcode, + maskKey: self.makeMaskKey(), + data: buffer + ) + self.channel.writeAndFlush(frame, promise: promise) + } + + func close(code: WebSocketErrorCode = .goingAway) -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(code: code, promise: promise) + return promise.futureResult + } + + func close( + code: WebSocketErrorCode = .goingAway, + promise: EventLoopPromise? + ) { + guard !self.isClosed else { + promise?.succeed(()) + return + } + guard !self.waitingForClose else { + promise?.succeed(()) + return + } + self.waitingForClose = true + self.closeCode = code + + let codeAsInt = UInt16(webSocketErrorCode: code) + let codeToSend: WebSocketErrorCode + if codeAsInt == 1005 || codeAsInt == 1006 { + /// Code 1005 and 1006 are used to report errors to the application, but must never be sent over + /// the wire (per https://tools.ietf.org/html/rfc6455#section-7.4) + codeToSend = .normalClosure + } else { + codeToSend = code + } + + var buffer = channel.allocator.buffer(capacity: 2) + buffer.write(webSocketErrorCode: codeToSend) + + self.send(raw: buffer.readableBytesView, opcode: .connectionClose, fin: true, promise: promise) + } + + func makeMaskKey() -> WebSocketMaskingKey? { + switch type { + case .client: + var bytes: [UInt8] = [] + for _ in 0..<4 { + bytes.append(.random(in: .min ..< .max)) + } + return WebSocketMaskingKey(bytes) + case .server: + return nil + } + } + + func handle(incoming frame: WebSocketFrame) { + switch frame.opcode { + case .connectionClose: + if self.waitingForClose { + // peer confirmed close, time to close channel + self.channel.close(mode: .all, promise: nil) + } else { + // peer asking for close, confirm and close output side channel + let promise = self.eventLoop.makePromise(of: Void.self) + var data = frame.data + let maskingKey = frame.maskKey + if let maskingKey = maskingKey { + data.webSocketUnmask(maskingKey) + } + self.close( + code: data.readWebSocketErrorCode() ?? .unknown(1005), + promise: promise + ) + promise.futureResult.whenComplete { _ in + self.channel.close(mode: .all, promise: nil) + } + } + case .ping: + if frame.fin { + var frameData = frame.data + let maskingKey = frame.maskKey + if let maskingKey = maskingKey { + frameData.webSocketUnmask(maskingKey) + } + self.send( + raw: frameData.readableBytesView, + opcode: .pong, + fin: true, + promise: nil + ) + } else { + self.close(code: .protocolError, promise: nil) + } + case .text, .binary, .pong: + // create a new frame sequence or use existing + var frameSequence: WebSocketFrameSequence + if let existing = self.frameSequence { + frameSequence = existing + } else { + frameSequence = WebSocketFrameSequence(type: frame.opcode) + } + // append this frame and update the sequence + frameSequence.append(frame) + self.frameSequence = frameSequence + case .continuation: + // we must have an existing sequence + if var frameSequence = self.frameSequence { + // append this frame and update + frameSequence.append(frame) + self.frameSequence = frameSequence + } else { + self.close(code: .protocolError, promise: nil) + } + default: + // We ignore all other frames. + break + } + + // if this frame was final and we have a non-nil frame sequence, + // output it to the websocket and clear storage + if let frameSequence = self.frameSequence, frame.fin { + switch frameSequence.type { + case .binary: + self.onBinaryCallback(self, frameSequence.binaryBuffer) + case .text: + self.onTextCallback(self, frameSequence.textBuffer) + case .pong: + self.waitingForPong = false + self.onPongCallback(self) + case .ping: + self.onPingCallback(self) + default: break + } + self.frameSequence = nil + } + } + + private func pingAndScheduleNextTimeoutTask() { + guard channel.isActive, let pingInterval = pingInterval else { + return + } + + if waitingForPong { + // We never received a pong from our last ping, so the connection has timed out + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(code: .unknown(1006), promise: promise) + promise.futureResult.whenComplete { _ in + // Usually, closing a WebSocketServer is done by sending the close frame and waiting + // for the peer to respond with their close frame. We are in a timeout situation, + // so the other side likely will never send the close frame. We just close the + // channel ourselves. + self.channel.close(mode: .all, promise: nil) + } + } else { + self.sendPing() + self.waitingForPong = true + self.scheduledTimeoutTask = self.eventLoop.scheduleTask( + deadline: .now() + pingInterval, + self.pingAndScheduleNextTimeoutTask + ) + } + } + + private func makeBootstrap() -> ServerBootstrap { + let reuseAddrOpt = ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR) + return ServerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.backlog, value: 256) + .serverChannelOption(reuseAddrOpt, value: 1) + .childChannelInitializer { channel in + let connectionUpgrader = NIOWebSocketServerUpgrader( + shouldUpgrade: self.shouldUpgrade, + upgradePipelineHandler: self.upgradePipelineHandler + ) + + let config: NIOHTTPServerUpgradeConfiguration = ( + upgraders: [connectionUpgrader], + completionHandler: { _ in } + ) + + return channel.pipeline.configureHTTPServerPipeline( + position: .first, + withPipeliningAssistance: true, + withServerUpgrade: config, + withErrorHandling: true + ) + } + .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) + .childChannelOption(reuseAddrOpt, value: 1) + .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) + } + + private func shouldUpgrade(channel _: Channel, + head: HTTPRequestHead) -> EventLoopFuture + { + let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil + return eventLoopGroup.next().makeSucceededFuture(headers) + } + + private func upgradePipelineHandler( + channel: Channel, + head: HTTPRequestHead + ) -> NIO.EventLoopFuture { + head.uri.starts(with: "/socket") ? + channel.pipeline.addHandler(WebSocketHandler(replyType: replyType)) : + channel.closeFuture + } +} + +private final class WebSocketHandler: ChannelInboundHandler { + typealias InboundIn = WebSocketFrame + typealias OutboundOut = WebSocketFrame + + private var awaitingClose: Bool = false + private let replyType: ReplyType + + init(replyType: ReplyType) { + self.replyType = replyType + } + + private func replyProvider(input: String) -> String? { + switch replyType { + case .echo: + return input + case let .reply(iterator): + return iterator() + case let .matchReply(matcher): + return matcher(input) + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let frame = self.unwrapInboundIn(data) + + func handleText(_ text: String) { + guard let reply = replyProvider(input: text) else { return } + let buffer = context.channel.allocator.buffer(data: Data(reply.utf8)) + let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) + context.writeAndFlush(self.wrapOutboundOut(frame)).whenFailure { _ in + context.close(promise: nil) + } + } + + switch frame.opcode { + case .connectionClose: + self.receivedClose(context: context, frame: frame) + case .ping: + self.pong(context: context, frame: frame) + case .text: + var data = frame.unmaskedData + let text = data.readString(length: data.readableBytes) ?? "" + handleText(text) + + case .binary: + let buffer = frame.unmaskedData + var data = Data(capacity: buffer.readableBytes) + buffer.withUnsafeReadableBytes { data.append(contentsOf: $0) } + + if let text = String(data: data, encoding: .utf8) { + handleText(text) + } else { + closeOnError(context: context) + } + + case .continuation, .pong: + // We ignore these frames. + break + default: + // Unknown frames are errors. + self.closeOnError(context: context) + } + } + + func channelReadComplete(context: ChannelHandlerContext) { + context.flush() + } + + private func sendTime(context: ChannelHandlerContext) { + guard context.channel.isActive else { return } + + // We can't send if we sent a close message. + guard !self.awaitingClose else { return } + + // We can't really check for error here, but it's also not the purpose of the + // example so let's not worry about it. + let theTime = NIODeadline.now().uptimeNanoseconds + var buffer = context.channel.allocator.buffer(capacity: 12) + buffer.writeString("\(theTime)") + + let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) + context.writeAndFlush(self.wrapOutboundOut(frame)).map { + context.eventLoop.scheduleTask(in: .seconds(1), { self.sendTime(context: context) }) + }.whenFailure { (_: Error) in + context.close(promise: nil) + } + } + + private func receivedClose(context: ChannelHandlerContext, frame: WebSocketFrame) { + // Handle a received close frame. In websockets, we're just going to send the close + // frame and then close, unless we already sent our own close frame. + if awaitingClose { + // Cool, we started the close and were waiting for the user. We're done. + context.close(promise: nil) + } else { + // This is an unsolicited close. We're going to send a response frame and + // then, when we've sent it, close up shop. We should send back the close code the remote + // peer sent us, unless they didn't send one at all. + var data = frame.unmaskedData + let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() + let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) + _ = context.write(self.wrapOutboundOut(closeFrame)).map { () in + context.close(promise: nil) + } + } + } + + private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { + var frameData = frame.data + let maskingKey = frame.maskKey + + if let maskingKey = maskingKey { + frameData.webSocketUnmask(maskingKey) + } + + let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) + context.write(self.wrapOutboundOut(responseFrame), promise: nil) + } + + private func closeOnError(context: ChannelHandlerContext) { + // We have hit an error, we want to close. We do that by sending a close frame and then + // shutting down the write side of the connection. + var data = context.channel.allocator.buffer(capacity: 2) + data.write(webSocketErrorCode: .protocolError) + let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) + context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result) in + context.close(mode: .output, promise: nil) + } + awaitingClose = true + } +} + +private struct WebSocketFrameSequence { + var binaryBuffer: ByteBuffer + var textBuffer: String + var type: WebSocketOpcode + + init(type: WebSocketOpcode) { + self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0) + self.textBuffer = .init() + self.type = type + } + + mutating func append(_ frame: WebSocketFrame) { + var data = frame.unmaskedData + switch type { + case .binary: + self.binaryBuffer.writeBuffer(&data) + case .text: + if let string = data.readString(length: data.readableBytes) { + self.textBuffer += string + } + default: break + } + } +} diff --git a/Tests/WebSocketTests/WebSocketTests.swift b/Tests/WebSocketTests/WebSocketTests.swift index 06d4f7e..7e9bde6 100644 --- a/Tests/WebSocketTests/WebSocketTests.swift +++ b/Tests/WebSocketTests/WebSocketTests.swift @@ -40,13 +40,6 @@ class WebSocketTests: XCTestCase { wait(for: [closeEx], timeout: 0.5) } -// func testCustom() async throws { -// let (server, client) = await makeServerAndClient() -// -// try await Task.sleep(nanoseconds: NSEC_PER_SEC * 10000) -// server.close() -// } - func testErrorWhenServerIsUnreachable() async throws { let ex = expectation(description: "Should have errored") let (server, client) = await makeOfflineServerAndClient { event in @@ -115,118 +108,55 @@ class WebSocketTests: XCTestCase { XCTAssertEqual("hello", text) } -// func testEchoPush() throws { -// try withEchoServer { _, client in -// let message = "hello" -// let completion = self.expectNoError() -// -// let sub = client.sink( -// receiveCompletion: expectFinished(), -// receiveValue: expectValuesAndThen([ -// .open: { client.send(message, completionHandler: completion) }, -// .text(message): { client.close() }, -// ]) -// ) -// defer { sub.cancel() } -// -// client.connect() -// waitForExpectations(timeout: 2) -// } -// } -// -// func testEchoBinaryPush() throws { -// try withEchoServer { _, client in -// let message = "hello" -// let data = message.data(using: .utf8)! -// let completion = self.expectNoError() -// -// let sub = client.sink( -// receiveCompletion: expectFinished(), -// receiveValue: expectValuesAndThen([ -// .open: { client.send(data, completionHandler: completion) }, -// .text(message): { client.close() }, -// ]) -// ) -// defer { sub.cancel() } -// -// client.connect() -// waitForExpectations(timeout: 2) -// } -// } -// -// func testJoinLobbyAndEcho() throws { -// let joinPush = "[1,1,\"room:lobby\",\"phx_join\",{}]" -// let echoPush1 = "[1,2,\"room:lobby\",\"echo\",{\"echo\":\"one\"}]" -// let echoPush2 = "[1,3,\"room:lobby\",\"echo\",{\"echo\":\"two\"}]" -// -// let joinReply = "[1,1,\"room:lobby\",\"phx_reply\",{\"response\":{},\"status\":\"ok\"}]" -// let echoReply1 = -// "[1,2,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"one\"},\"status\":\"ok\"}]" -// let echoReply2 = -// "[1,3,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"two\"},\"status\":\"ok\"}]" -// -// let joinCompletion = expectNoError() -// let echo1Completion = expectNoError() -// let echo2Completion = expectNoError() -// -// try withReplyServer([joinReply, echoReply1, echoReply2]) { _, client in -// let sub = client.sink( -// receiveCompletion: expectFinished(), -// receiveValue: expectValuesAndThen([ -// .open: { client.send(joinPush, completionHandler: joinCompletion) }, -// .text(joinReply): { client.send(echoPush1, completionHandler: echo1Completion) -// }, -// .text(echoReply1): { client.send(echoPush2, completionHandler: echo2Completion) -// }, -// .text(echoReply2): { client.close() }, -// ]) -// ) -// defer { sub.cancel() } -// -// client.connect() -// waitForExpectations(timeout: 2) -// } -// } -// -// func testCanSendFromTwoThreadsSimultaneously() throws { -// let queueCount = 8 -// let queues = (0 ..< queueCount).map { DispatchQueue(label: "\($0)") } -// -// let messageCount = 100 -// let sendMessages: (WebSocket) -> Void = { client in -// (0 ..< messageCount).forEach { messageIndex in -// (0 ..< queueCount).forEach { queueIndex in -// queues[queueIndex].async { client.send("\(queueIndex)-\(messageIndex)") } -// } -// } -// } -// -// let receiveMessageEx = expectation( -// description: "Should have received \(queueCount * messageCount) messages" -// ) -// receiveMessageEx.expectedFulfillmentCount = queueCount * messageCount -// -// try withEchoServer { _, client in -// let sub = client.sink( -// receiveCompletion: { _ in }, -// receiveValue: { message in -// switch message { -// case .success(.open): -// sendMessages(client) -// case .success(.text): -// receiveMessageEx.fulfill() -// default: -// XCTFail() -// } -// } -// ) -// defer { sub.cancel() } -// -// client.connect() -// waitForExpectations(timeout: 10) -// client.close() -// } -// } + func testEchoBinaryPush() async throws { + let openEx = expectation(description: "Should have opened") + let (server, client) = await makeEchoServerAndClient { event in + guard case .open = event else { return } + openEx.fulfill() + } + defer { server.close() } + + wait(for: [openEx], timeout: 0.5) + + try await client.send(.data(Data("hello".utf8))) + guard case let .string(text) = try await client.receive() + else { return XCTFail("Should have received text") } + + XCTAssertEqual("hello", text) + } + + func testJoinLobbyAndEcho() async throws { + var pushes = [ + "[1,1,\"room:lobby\",\"phx_join\",{}]", + "[1,2,\"room:lobby\",\"echo\",{\"echo\":\"one\"}]", + "[1,3,\"room:lobby\",\"echo\",{\"echo\":\"two\"}]", + ] + + let replies = [ + "[1,1,\"room:lobby\",\"phx_reply\",{\"response\":{},\"status\":\"ok\"}]", + "[1,2,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"one\"},\"status\":\"ok\"}]", + "[1,3,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"two\"},\"status\":\"ok\"}]", + ] + + let openEx = expectation(description: "Should have opened") + + let (server, client) = await makeReplyServerAndClient(replies) { event in + guard case .open = event else { return } + openEx.fulfill() + } + defer { server.close() } + + wait(for: [openEx], timeout: 0.5) + + try await client.send(.string(pushes.removeFirst())) + try await client.send(.string(pushes.removeFirst())) + try await client.send(.string(pushes.removeFirst())) + + for expected in replies { + guard case let .string(reply) = try await client.receive() else { return XCTFail() } + XCTAssertEqual(expected, reply) + } + } } private extension WebSocketTests { @@ -234,9 +164,9 @@ private extension WebSocketTests { _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() - let server = WebSocketServer(port: port, replyProvider: .reply { nil }) + let server = try! WebSocketServer(port: port, replyProvider: .reply { nil }) let client = await WebSocket(url: url(port), onStateChange: onStateChange) - server.listen() +// server.listen() return (server, client) } @@ -244,7 +174,7 @@ private extension WebSocketTests { _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() - let server = WebSocketServer(port: port, replyProvider: .reply { nil }) + let server = try! WebSocketServer(port: 1, replyProvider: .reply { nil }) let client = await WebSocket(url: url(port), onStateChange: onStateChange) return (server, client) } @@ -253,9 +183,9 @@ private extension WebSocketTests { _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() - let server = WebSocketServer(port: port, replyProvider: .echo) + let server = try! WebSocketServer(port: port, replyProvider: .echo) let client = await WebSocket(url: url(port), onStateChange: onStateChange) - server.listen() +// server.listen() return (server, client) } @@ -266,9 +196,9 @@ private extension WebSocketTests { let port = ports.removeFirst() var replies = replies let provider: () -> String? = { replies.removeFirst() } - let server = WebSocketServer(port: port, replyProvider: .reply(provider)) + let server = try! WebSocketServer(port: port, replyProvider: .reply(provider)) let client = await WebSocket(url: url(port), onStateChange: onStateChange) - server.listen() +// server.listen() return (server, client) } } From 717f256b13bd5cb2132fff5775cf0b3f037a4306 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Sun, 17 Apr 2022 13:02:33 +0200 Subject: [PATCH 04/11] Clean up WebSocket --- Sources/WebSocket/WebSocket.swift | 432 +----------------------------- 1 file changed, 7 insertions(+), 425 deletions(-) diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index efcd20d..7eb7d57 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -79,7 +79,6 @@ final actor WebSocket { type: .debug, message.debugDescription ) - break } } @@ -90,7 +89,6 @@ final actor WebSocket { os_log("receive: %s", log: .webSocket, type: .debug, message.debugDescription) return message - case .unopened, .connecting, .closing, .closed: os_log( "receive in incorrect state: %s", @@ -141,18 +139,18 @@ private extension WebSocket { } } - var onDelegateEvent: (WebSocketDelegateEvent) async -> Void { + var onDelegateEvent: (WebSocketDelegateEvent) async -> Void { { [weak self] (event: WebSocketDelegateEvent) in guard let self = self else { return } switch (await self.state, event) { case let (.connecting(s1, t1, delegate), .open(s2, t2, _)): - guard s1 === s2 && t1 === t2 else { return } + guard s1 === s2, t1 === t2 else { return } await self.setState(.open(s2, t2, delegate)) await self.onStateChange(.open) case let (.connecting(s1, t1, _), .close(s2, t2, closeCode, reason)), - let (.open(s1, t1, _), .close(s2, t2, closeCode, reason)): + let (.open(s1, t1, _), .close(s2, t2, closeCode, reason)): guard s1 === s2, t1 === t2 else { return } if let closeCode = closeCode { await self.setState(.closed(closeCode, reason)) @@ -163,13 +161,14 @@ private extension WebSocket { s2.invalidateAndCancel() case let (.connecting(s1, t1, _), .complete(s2, t2, error)), - let (.open(s1, t1, _), .complete(s2, t2, error)): + let (.open(s1, t1, _), .complete(s2, t2, error)): guard s1 === s2, t1 === t2 else { return } if let error = error { await self.setState( .closed( .internalServerError, - Data(error.localizedDescription.utf8)) + Data(error.localizedDescription.utf8) + ) ) await self.onStateChange(.error(error as NSError)) } else { @@ -232,7 +231,7 @@ private enum WebSocketState: CustomStringConvertible { // MARK: URLSessionWebSocketDelegate -private enum WebSocketDelegateEvent { +private enum WebSocketDelegateEvent { case open(URLSession, URLSessionWebSocketTask, String?) case close(URLSession, URLSessionWebSocketTask, URLSessionWebSocketTask.CloseCode?, Data?) case complete(URLSession, URLSessionTask, Error?) @@ -271,420 +270,3 @@ private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { Task { await onStateChange(.complete(session, task, error)) } } } - -//private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { -// private let onOpen: OnOpenHandler -// private let onClose: OnCloseHandler -// private let onCompletion: OnCompletionHandler -// -// init( -// onOpen: @escaping OnOpenHandler, -// onClose: @escaping OnCloseHandler, -// onCompletion: @escaping OnCompletionHandler -// ) { -// self.onOpen = onOpen -// self.onClose = onClose -// self.onCompletion = onCompletion -// super.init() -// } -// -// func urlSession( -// _ webSocketSession: URLSession, -// webSocketTask: URLSessionWebSocketTask, -// didOpenWithProtocol protocol: String? -// ) { -// onOpen(webSocketSession, webSocketTask, `protocol`) -// } -// -// func urlSession(_ session: URLSession, -// webSocketTask: URLSessionWebSocketTask, -// didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, -// reason: Data?) -// { -// onClose(session, webSocketTask, closeCode, reason) -// } -// -// func urlSession(_ session: URLSession, -// task: URLSessionTask, -// didCompleteWithError error: Error?) -// { -// onCompletion(session, task, error) -// } -//} - -//public final class WebSocket: WebSocketProtocol { -// public typealias Output = Result -// public typealias Failure = Swift.Error -// -// private enum State: CustomDebugStringConvertible { -// case unopened -// case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) -// case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) -// case closing -// case closed(WebSocketError) -// -// var webSocketSessionAndTask: (URLSession, URLSessionWebSocketTask)? { -// switch self { -// case let .connecting(session, task, _), let .open(session, task, _): -// return (session, task) -// case .unopened, .closing, .closed: -// return nil -// } -// } -// -// var debugDescription: String { -// switch self { -// case .unopened: return "unopened" -// case .connecting: return "connecting" -// case .open: return "open" -// case .closing: return "closing" -// case .closed: return "closed" -// } -// } -// } -// -// /// The maximum number of bytes to buffer before the receive call fails with an error. -// /// Default: 1 MiB -// public var maximumMessageSize: Int = 1024 * 1024 { -// didSet { sync { -// guard let (_, task) = state.webSocketSessionAndTask else { return } -// task.maximumMessageSize = maximumMessageSize -// } } -// } -// -// public var isOpen: Bool { sync { -// guard case .open = state else { return false } -// return true -// } } -// -// public var isClosed: Bool { sync { -// guard case .closed = state else { return false } -// return true -// } } -// -// private let lock = RecursiveLock() -// private func sync(_ block: () throws -> T) rethrows -> T { try lock.locked(block) } -// -// private let url: URL -// -// private let timeoutIntervalForRequest: TimeInterval -// private let timeoutIntervalForResource: TimeInterval -// -// private var state: State = .unopened -// private let subject = PassthroughSubject() -// -// private let subjectQueue: DispatchQueue -// -// public convenience init(url: URL) { -// self.init(url: url, publisherQueue: DispatchQueue.global()) -// } -// -// public init( -// url: URL, -// timeoutIntervalForRequest: TimeInterval = 60, // 60 seconds -// timeoutIntervalForResource: TimeInterval = 604_800, // 7 days -// publisherQueue: DispatchQueue = DispatchQueue.global() -// ) { -// self.url = url -// self.timeoutIntervalForRequest = timeoutIntervalForRequest -// self.timeoutIntervalForResource = timeoutIntervalForResource -// subjectQueue = DispatchQueue( -// label: "app.shareup.websocket.subjectqueue", -// qos: .default, -// autoreleaseFrequency: .workItem, -// target: publisherQueue -// ) -// } -// -// deinit { -// close() -// } -// -// public func connect() { -// sync { -// os_log( -// "connect: oldstate=%{public}@", -// log: .webSocket, -// type: .debug, -// state.debugDescription -// ) -// -// switch state { -// case .closed, .unopened: -// let delegate = WebSocketDelegate( -// onOpen: onOpen, -// onClose: onClose, -// onCompletion: onCompletion -// ) -// -// let config = URLSessionConfiguration.default -// config.timeoutIntervalForRequest = timeoutIntervalForRequest -// config.timeoutIntervalForResource = timeoutIntervalForResource -// -// let session = URLSession( -// configuration: config, -// delegate: delegate, -// delegateQueue: nil -// ) -// -// let task = session.webSocketTask(with: url) -// task.maximumMessageSize = maximumMessageSize -// state = .connecting(session, task, delegate) -// task.resume() -// receiveFromWebSocket() -// -// default: -// break -// } -// } -// } -// -// public func receive(subscriber: S) -// where S.Input == Result, S.Failure == Swift.Error -// { -// subject.receive(subscriber: subscriber) -// } -// -// private func receiveFromWebSocket() { -// let task: URLSessionWebSocketTask? = sync { -// let webSocketTask = self.state.webSocketSessionAndTask?.1 -// guard let task = webSocketTask, case .running = task.state else { return nil } -// return task -// } -// -// task?.receive -// { [weak self, weak task] (result: Result) in -// guard let self = self else { return } -// -// let _result = result.map { WebSocketMessage($0) } -// -// guard task?.state == .running -// else { -// os_log( -// "receive message in incorrect task state: message=%s taskstate=%{public}@", -// log: .webSocket, -// type: .debug, -// _result.debugDescription, -// "\(task?.state.rawValue ?? -1)" -// ) -// return -// } -// -// os_log("receive: %s", log: .webSocket, type: .debug, _result.debugDescription) -// self.subjectQueue.async { [weak self] in self?.subject.send(_result) } -// self.receiveFromWebSocket() -// } -// } -// -// public func send( -// _ string: String, -// completionHandler: @escaping (Error?) -> Void = { _ in } -// ) { -// os_log("send: %s", log: .webSocket, type: .debug, string) -// send(.string(string), completionHandler: completionHandler) -// } -// -// public func send(_ data: Data, completionHandler: @escaping (Error?) -> Void = { _ in }) { -// os_log("send: %lld bytes", log: .webSocket, type: .debug, data.count) -// send(.data(data), completionHandler: completionHandler) -// } -// -// private func send( -// _ message: URLSessionWebSocketTask.Message, -// completionHandler: @escaping (Error?) -> Void -// ) { -// let task: URLSessionWebSocketTask? = sync { -// guard case let .open(_, task, _) = state, task.state == .running -// else { -// os_log( -// "send message in incorrect task state: message=%s taskstate=%{public}@", -// log: .webSocket, -// type: .debug, -// message.debugDescription, -// "\(self.state.webSocketSessionAndTask?.1.state.rawValue ?? -1)" -// ) -// completionHandler(WebSocketError.notOpen) -// return nil -// } -// return task -// } -// -// task?.send(message, completionHandler: completionHandler) -// } -// -// public func close(_ closeCode: WebSocketCloseCode) { -// let task: URLSessionWebSocketTask? = sync { -// os_log( -// "close: oldstate=%{public}@ code=%lld", -// log: .webSocket, -// type: .debug, -// state.debugDescription, -// closeCode.rawValue -// ) -// -// guard let (_, task) = state.webSocketSessionAndTask, task.state == .running -// else { return nil } -// state = .closing -// return task -// } -// -// let code = URLSessionWebSocketTask.CloseCode(closeCode) ?? .invalid -// task?.cancel(with: code, reason: nil) -// } -//} -// -//private typealias OnOpenHandler = (URLSession, URLSessionWebSocketTask, String?) -> Void -//private typealias OnCloseHandler = ( -// URLSession, -// URLSessionWebSocketTask, -// URLSessionWebSocketTask.CloseCode, -// Data? -//) -> Void -//private typealias OnCompletionHandler = (URLSession, URLSessionTask, Error?) -> Void -// -//private let normalCloseCodes: [URLSessionWebSocketTask.CloseCode] = [.goingAway, .normalClosure] -// -//// MARK: onOpen and onClose -// -//private extension WebSocket { -// var onOpen: OnOpenHandler { -// { [weak self] webSocketSession, webSocketTask, _ in -// guard let self = self else { return } -// -// self.sync { -// os_log( -// "onOpen: oldstate=%{public}@", -// log: .webSocket, -// type: .debug, -// self.state.debugDescription -// ) -// -// guard case let .connecting(session, task, delegate) = self.state else { -// os_log( -// "receive onOpen callback in incorrect state: oldstate=%{public}@", -// log: .webSocket, -// type: .error, -// self.state.debugDescription -// ) -// self.state = .open( -// webSocketSession, -// webSocketTask, -// webSocketSession.delegate as! WebSocketDelegate -// ) -// return -// } -// -// assert(session === webSocketSession) -// assert(task === webSocketTask) -// -// self.state = .open(webSocketSession, webSocketTask, delegate) -// } -// -// self.subjectQueue.async { [weak self] in self?.subject.send(.success(.open)) } -// } -// } -// -// var onClose: OnCloseHandler { -// { [weak self] _, _, closeCode, reason in -// guard let self = self else { return } -// -// self.sync { -// os_log( -// "onClose: oldstate=%{public}@ code=%lld", -// log: .webSocket, -// type: .debug, -// self.state.debugDescription, -// closeCode.rawValue -// ) -// -// if case .closed = self.state { return } -// self.state = .closed(WebSocketError.closed(closeCode, reason)) -// -// self.subjectQueue.async { [weak self] in -// if normalCloseCodes.contains(closeCode) { -// self?.subject.send(completion: .finished) -// } else { -// self?.subject.send( -// completion: .failure(WebSocketError.closed(closeCode, reason)) -// ) -// } -// } -// } -// } -// } -// -// var onCompletion: OnCompletionHandler { -// { [weak self] webSocketSession, _, error in -// defer { webSocketSession.invalidateAndCancel() } -// guard let self = self else { return } -// -// os_log("onCompletion", log: .webSocket, type: .debug) -// -// // "The only errors your delegate receives through the error parameter -// // are client-side errors, such as being unable to resolve the hostname -// // or connect to the host." -// // -// // https://developer.apple.com/documentation/foundation/urlsessiontaskdelegate/1411610-urlsession -// // -// // When receiving these errors, `onClose` is not called because the connection -// // was never actually opened. -// guard let error = error else { return } -// self.sync { -// os_log( -// "onCompletion: oldstate=%{public}@ error=%@", -// log: .webSocket, -// type: .debug, -// self.state.debugDescription, -// error.localizedDescription -// ) -// -// if case .closed = self.state { return } -// self.state = .closed(.notOpen) -// -// self.subjectQueue.async { [weak self] in -// self?.subject.send(completion: .failure(error)) -// } -// } -// } -// } -//} -// -//// MARK: URLSessionWebSocketDelegate -// -//private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { -// private let onOpen: OnOpenHandler -// private let onClose: OnCloseHandler -// private let onCompletion: OnCompletionHandler -// -// init(onOpen: @escaping OnOpenHandler, -// onClose: @escaping OnCloseHandler, -// onCompletion: @escaping OnCompletionHandler) -// { -// self.onOpen = onOpen -// self.onClose = onClose -// self.onCompletion = onCompletion -// super.init() -// } -// -// func urlSession(_ webSocketSession: URLSession, -// webSocketTask: URLSessionWebSocketTask, -// didOpenWithProtocol protocol: String?) -// { -// onOpen(webSocketSession, webSocketTask, `protocol`) -// } -// -// func urlSession(_ session: URLSession, -// webSocketTask: URLSessionWebSocketTask, -// didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, -// reason: Data?) -// { -// onClose(session, webSocketTask, closeCode, reason) -// } -// -// func urlSession(_ session: URLSession, -// task: URLSessionTask, -// didCompleteWithError error: Error?) -// { -// onCompletion(session, task, error) -// } -//} From 4a708a468d277ac324a7f9316019bb7fb8ffc746 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Sun, 17 Apr 2022 23:59:34 +0200 Subject: [PATCH 05/11] Replace URLSessionWebSocketTask with NWConnection --- Package.resolved | 34 - Package.swift | 22 +- README.md | 43 +- Sources/WebSocket/SystemWebSocket.swift | 640 ++++++++++++++++++ ...essionWebSocketTaskMessage+WebSocket.swift | 2 +- Sources/WebSocket/WebSocket.swift | 323 +++------ Sources/WebSocket/WebSocketClient.swift | 69 -- Sources/WebSocket/WebSocketCloseCode.swift | 35 +- Sources/WebSocket/WebSocketCloseResult.swift | 8 + Sources/WebSocket/WebSocketError.swift | 16 +- Sources/WebSocket/WebSocketEvent.swift | 31 - Sources/WebSocket/WebSocketMessage.swift | 5 +- .../Server/WebSocketServer.swift | 451 ++++++------ .../Server/WebSocketServer2.swift | 515 -------------- .../WebSocketTests/SystemWebSocketTests.swift | 348 ++++++++++ Tests/WebSocketTests/WebSocketTests.swift | 204 ------ 16 files changed, 1382 insertions(+), 1364 deletions(-) delete mode 100644 Package.resolved create mode 100644 Sources/WebSocket/SystemWebSocket.swift delete mode 100644 Sources/WebSocket/WebSocketClient.swift create mode 100644 Sources/WebSocket/WebSocketCloseResult.swift delete mode 100644 Sources/WebSocket/WebSocketEvent.swift delete mode 100644 Tests/WebSocketTests/Server/WebSocketServer2.swift create mode 100644 Tests/WebSocketTests/SystemWebSocketTests.swift delete mode 100644 Tests/WebSocketTests/WebSocketTests.swift diff --git a/Package.resolved b/Package.resolved deleted file mode 100644 index 6ae1c7e..0000000 --- a/Package.resolved +++ /dev/null @@ -1,34 +0,0 @@ -{ - "object": { - "pins": [ - { - "package": "swift-nio", - "repositoryURL": "https://github.com/apple/swift-nio.git", - "state": { - "branch": null, - "revision": "d6e3762e0a5f7ede652559f53623baf11006e17c", - "version": "2.39.0" - } - }, - { - "package": "swift-nio-ssl", - "repositoryURL": "https://github.com/apple/swift-nio-ssl.git", - "state": { - "branch": null, - "revision": "b5260a31c2a72a89fa684f5efb3054d8725a2316", - "version": "2.18.0" - } - }, - { - "package": "Synchronized", - "repositoryURL": "https://github.com/shareup/synchronized.git", - "state": { - "branch": null, - "revision": "f01e4a1ee5fbf586d612a8dc0bc068603f6b9450", - "version": "3.0.0" - } - } - ] - }, - "version": 1 -} diff --git a/Package.swift b/Package.swift index 4250a45..823ab7a 100644 --- a/Package.swift +++ b/Package.swift @@ -10,29 +10,17 @@ let package = Package( .library( name: "WebSocket", targets: ["WebSocket"] - )], - dependencies: [ - .package( - name: "Synchronized", - url: "https://github.com/shareup/synchronized.git", - from: "3.0.0" ), - .package(name: "swift-nio", url: "https://github.com/apple/swift-nio.git", from: "2.39.0"), - .package(name: "swift-nio-ssl", url: "https://github.com/apple/swift-nio-ssl.git", from: "2.18.0"), ], + dependencies: [], targets: [ .target( name: "WebSocket", - dependencies: ["Synchronized"]), + dependencies: [] + ), .testTarget( name: "WebSocketTests", - dependencies: [ - .product(name: "NIO", package: "swift-nio"), - .product(name: "NIOHTTP1", package: "swift-nio"), - .product(name: "NIOWebSocket", package: "swift-nio"), - .product(name: "NIOFoundationCompat", package: "swift-nio"), - .product(name: "NIOSSL", package: "swift-nio-ssl"), - "WebSocket", - ]) + dependencies: ["WebSocket"] + ), ] ) diff --git a/README.md b/README.md index 24c97dd..78bf3ae 100644 --- a/README.md +++ b/README.md @@ -2,31 +2,32 @@ ## _(macOS, iOS, iPadOS, tvOS, and watchOS)_ -A concrete implementation of a WebSocket client implemented by wrapping Apple's `URLSessionWebSocketTask` and conforming to [`WebSocketProtocol`](https://github.com/shareup/websocket-protocol). `WebSocket` exposes a simple API and conforms to Apple's Combine [`Publisher`](https://developer.apple.com/documentation/combine/publisher). +A concrete implementation of a WebSocket client implemented by wrapping Apple's [`NWConnection`](https://developer.apple.com/documentation/network/nwconnection). + +The public "interface" of `WebSocket` is a simple struct whose public "methods" are exposed as closures. The reason for this design is to make it easy to inject fake `WebSocket`s into your code for testing purposes. + +The actual implementation is `SystemWebSocket`, but this type is not publicly accessible. Instead, you can access it via `WebSocket.system(url:)`. `SystemWebSocket` tries its best to mirror the documented behavior of web browsers' [`WebSocket`](http://developer.mozilla.org/en-US/docs/Web/API/WebSocket). Please report any deviations as bugs. + +`WebSocket` exposes a simple API, makes heavy use of [Swift Concurrency](https://developer.apple.com/documentation/swift/swift_standard_library/concurrency), and conforms to Apple's Combine [`Publisher`](https://developer.apple.com/documentation/combine/publisher). ## Usage ```swift -let socket = WebSocket(url: url(49999)) - -let sub = socket.sink( - receiveCompletion: { print("Socket closed: \(String(describing: $0))") }, - receiveValue: { (result) in - switch result { - case .success(.open): - socket.send("First message") - case .success(.string(let incoming)): - print("Received \(incoming)") - case .failure: - socket.close() - default: - break - } - } -) -defer { sub.cancel() } - -socket.connect() +// `WebSocket` starts connecting to the specified `URL` immediately. +let socket = WebSocket.system(url: url(49999)) + +// Wait for `WebSocket` to be ready to send and receive messages. +try await socket.open() + +// Send a message to the server +try await socket.send(.text("hello")) + +// Receive messages from the server +for await message in socket.messages { + print(message) +} + +try await socket.close() ``` ## Tests diff --git a/Sources/WebSocket/SystemWebSocket.swift b/Sources/WebSocket/SystemWebSocket.swift new file mode 100644 index 0000000..91fc3a3 --- /dev/null +++ b/Sources/WebSocket/SystemWebSocket.swift @@ -0,0 +1,640 @@ +import Combine +import Foundation +import Network +import os.log + +final actor SystemWebSocket: Publisher { + typealias Output = WebSocketMessage + typealias Failure = Never + + var isOpen: Bool { get async { + guard case .open = state else { return false } + return true + } } + + var isClosed: Bool { get async { + guard case .closed = state else { return false } + return true + } } + + private let url: URL + private let options: WebSocketOptions + private var _onOpen: WebSocketOnOpen + private var _onClose: WebSocketOnClose + private var state: WebSocketState = .unopened + + private var messageIndex = 0 // Used to identify sent messages + + private let subject = PassthroughSubject() + + private let webSocketQueue: DispatchQueue = .init( + label: "app.shareup.websocket.websocketqueue", + attributes: [], + autoreleaseFrequency: .workItem, + target: .global(qos: .default) + ) + + // Deliver messages to the subscribers on a separate queue because it's a bad idea + // to let the subscribers, who could potentially be doing long-running tasks with the + // data we send them, block our network queue. + private let subscriberQueue = DispatchQueue( + label: "app.shareup.websocket.subjectqueue", + attributes: [], + target: DispatchQueue.global(qos: .default) + ) + + init( + url: URL, + options: WebSocketOptions = .init(), + onOpen: @escaping WebSocketOnOpen = {}, + onClose: @escaping WebSocketOnClose = { _ in } + ) async throws { + self.url = url + self.options = options + _onOpen = onOpen + _onClose = onClose + try connect() + } + + deinit { + switch state { + case let .connecting(connection), let .open(connection): + connection.forceCancel() + default: + break + } + } + + nonisolated func receive(subscriber: S) + where S.Input == WebSocketMessage, S.Failure == Never + { + subject + .receive(on: subscriberQueue) + .receive(subscriber: subscriber) + } + + func open(timeout: TimeInterval? = nil) async throws { + switch state { + case .open: + return + + case .closing, .closed: + throw WebSocketError.openAfterConnectionClosed + + case .unopened, .connecting: + do { + try await withThrowingTaskGroup(of: Void.self) { (group: inout ThrowingTaskGroup) in + _ = group.addTaskUnlessCancelled { [weak self] in + guard let self = self else { return } + let _timeout = UInt64(timeout ?? self.options.timeoutIntervalForRequest) + try await Task.sleep(nanoseconds: _timeout * NSEC_PER_SEC) + throw CancellationError() + } + + _ = group.addTaskUnlessCancelled { [weak self] in + guard let self = self else { return } + while await !self.isOpen { + try await Task.sleep(nanoseconds: 10 * NSEC_PER_MSEC) + } + } + + let _ = try await group.next() + group.cancelAll() + } + } catch { + doClose() + throw error + } + } + } + + func send(_ message: WebSocketMessage) async throws { + // Mirrors the document behavior of JavaScript's `WebSocket` + // http://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send + switch state { + case let .open(connection): + messageIndex += 1 + + os_log( + "send: index=%d message=%s", + log: .webSocket, + type: .debug, + messageIndex, + message.debugDescription + ) + + let context = NWConnection.ContentContext( + identifier: String(messageIndex), + metadata: [message.metadata] + ) + + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + connection.send( + content: message.contentAsData, + contentContext: context, + isComplete: true, + completion: .contentProcessed { (error: NWError?) in + if let error = error { + cont.resume(throwing: error) + } else { + cont.resume() + } + } + ) + } + + case .unopened, .connecting: + os_log( + "send message while connecting: %s", + log: .webSocket, + type: .error, + message.debugDescription + ) + throw WebSocketError.sendMessageWhileConnecting + + case .closing, .closed: + os_log( + "send message while closed: %s", + log: .webSocket, + type: .debug, + message.debugDescription + ) + } + } + + func close(_ closeCode: WebSocketCloseCode = .normalClosure) async throws { + switch state { + case let .connecting(conn), let .open(conn): + os_log( + "close connection: code=%d state=%{public}s", + log: .webSocket, + type: .debug, + closeCode.rawValue, + state.description + ) + + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + conn.send( + content: nil, + contentContext: .finalMessage, + isComplete: true, + completion: .contentProcessed { (error: Error?) in + if let error = error { + cont.resume(throwing: error) + } else { + cont.resume() + } + } + ) + } + startClosing(connection: conn, error: closeCode.error) + + case .unopened, .closing, .closed: + doClose() + } + } + + func forceClose(_ closeCode: WebSocketCloseCode) { + os_log( + "force close connection: code=%d state=%{public}s", + log: .webSocket, + type: .debug, + closeCode.rawValue, + state.description + ) + + doClose() + } + + func onOpen(_ block: @escaping WebSocketOnOpen) { + _onOpen = block + } + + func onClose(_ block: @escaping WebSocketOnClose) { + _onClose = block + } +} + +private extension SystemWebSocket { + var isUnopened: Bool { + switch state { + case .unopened: return true + default: return false + } + } + + func setState(_ state: WebSocketState) async { + self.state = state + } + + func connect() throws { + precondition(isUnopened) + + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + throw WebSocketError.invalidURL(url) + } + + let parameters = try self.parameters(with: components) + let connection = NWConnection(to: .url(url), using: parameters) + state = .connecting(connection) + connection.stateUpdateHandler = connectionStateUpdateHandler + connection.start(queue: webSocketQueue) + } + + func openReadyConnection(_ connection: NWConnection) { + os_log( + "open connection: connection_state=%{public}s", + log: .webSocket, + type: .debug, + connection.state.debugDescription + ) + + state = .open(connection) + _onOpen() + connection.receiveMessage(completion: onReceiveMessage) + } + + func startClosing(connection: NWConnection, error: NWError? = nil) { + state = .closing(error) + subject.send(completion: .finished) + connection.cancel() + } + + func doClose() { + // TODO: Switch to using `state.description` + os_log( + "do close connection: state=%{public}s", + log: .webSocket, + type: .debug, + state.debugDescription + ) + + switch state { + case .closing(nil): + state = .closed(nil) + subject.send(completion: .finished) + _onClose(normalClosure) + + case let .closing(.some(err)): + state = .closed(.connectionError(err)) + subject.send(completion: .finished) + _onClose(closureWithError(err)) + + case .unopened: + state = .closed(nil) + subject.send(completion: .finished) + _onClose(abnormalClosure) + + case let .connecting(conn), let .open(conn): + state = .closed(nil) + subject.send(completion: .finished) + _onClose(abnormalClosure) + conn.forceCancel() + + case .closed: + // `PassthroughSubject` only sends completions once. + subject.send(completion: .finished) + } + } + + func doCloseWithError(_ error: WebSocketError) { + // TODO: Switch to using `state.description` + os_log( + "do close connection: state=%{public}s error=%{public}s", + log: .webSocket, + type: .debug, + state.debugDescription, + String(describing: error) + ) + + switch state { + case let .closing(.some(err)): + state = .closed(.connectionError(err)) + subject.send(completion: .finished) + _onClose(closureWithError(err)) + + case .closing(nil): + state = .closed(error) + subject.send(completion: .finished) + _onClose(closureWithError(error)) + + case .unopened: + state = .closed(error) + subject.send(completion: .finished) + _onClose(closureWithError(error)) + + case let .connecting(conn), let .open(conn): + state = .closed(nil) + subject.send(completion: .finished) + _onClose(closureWithError(error)) + conn.forceCancel() + + case .closed: + // `PassthroughSubject` only sends completions once. + subject.send(completion: .finished) + } + } +} + +private extension SystemWebSocket { + func host(with urlComponents: URLComponents) throws -> NWEndpoint.Host { + guard let host = urlComponents.host else { + throw WebSocketError.invalidURLComponents(urlComponents) + } + return NWEndpoint.Host(host) + } + + func port(with urlComponents: URLComponents) throws -> NWEndpoint.Port { + if let raw = urlComponents.port, let port = NWEndpoint.Port(rawValue: UInt16(raw)) { + return port + } else if urlComponents.scheme == "ws" { + return NWEndpoint.Port.http + } else if urlComponents.scheme == "wss" { + return NWEndpoint.Port.https + } else { + throw WebSocketError.invalidURLComponents(urlComponents) + } + } + + func parameters(with urlComponents: URLComponents) throws -> NWParameters { + let parameters: NWParameters + switch urlComponents.scheme { + case "ws": + parameters = .tcp + case "wss": + parameters = .tls + default: + throw WebSocketError.invalidURLComponents(urlComponents) + } + + let webSocketOptions = NWProtocolWebSocket.Options() + webSocketOptions.maximumMessageSize = options.maximumMessageSize + webSocketOptions.autoReplyPing = true + + parameters.defaultProtocolStack.applicationProtocols.insert(webSocketOptions, at: 0) + + return parameters + } +} + +private extension SystemWebSocket { + var connectionStateUpdateHandler: (NWConnection.State) -> Void { + { [weak self] (connectionState: NWConnection.State) in + Task { [weak self] in + guard let self = self else { return } + + let state = await self.state + + // TODO: Switch to using `state.description` + os_log( + "connection state update: connection_state=%{public}s state=%{public}s", + log: .webSocket, + type: .debug, + connectionState.debugDescription, + state.debugDescription + ) + + switch connectionState { + case .setup: + break + + case let .waiting(error): + await self.doCloseWithError(.connectionError(error)) + + case .preparing: + break + + case .ready: + switch state { + case let .connecting(conn): + await self.openReadyConnection(conn) + + case .open: + // TODO: Handle betterPathUpdate here? + break + + case .unopened, .closing, .closed: + // TODO: Switch to using `state.description` + os_log( + "unexpected connection ready: state=%{public}s", + log: .webSocket, + type: .error, + state.debugDescription + ) + } + + case let .failed(error): + switch state { + case let .connecting(conn), let .open(conn): + await self.startClosing(connection: conn, error: error) + + case .unopened, .closing, .closed: + break + } + + case .cancelled: + switch state { + case let .connecting(conn), let .open(conn): + await self.startClosing(connection: conn) + + case .unopened, .closing: + await self.doClose() + + case .closed: + break + } + + @unknown default: + assertionFailure("Unknown state '\(state)'") + } + } + } + } + + var onReceiveMessage: (Data?, NWConnection.ContentContext?, Bool, NWError?) -> Void { + { [weak self] data, context, isMessageComplete, error in + guard let self = self else { return } + guard isMessageComplete else { return } + + Task { + switch (data, context, error) { + case let (.some(data), .some(context), .none): + await self.handleSuccessfulMessage(data: data, context: context) + case let (.none, _, .some(error)): + await self.handleMessageWithError(error) + default: + await self.handleUnknownMessage(data: data, context: context, error: error) + } + } + } + } + + func handleSuccessfulMessage(data: Data, context: NWConnection.ContentContext) { + guard case let .open(connection) = state else { return } + + switch context.websocketMessageType { + case .binary: + os_log( + "receive binary: size=%d", + log: .webSocket, + type: .debug, + data.count + ) + subject.send(.data(data)) + + case .text: + guard let text = String(data: data, encoding: .utf8) else { + startClosing(connection: connection, error: .posix(.EBADMSG)) + return + } + os_log( + "receive text: content=%s", + log: .webSocket, + type: .debug, + text + ) + subject.send(.text(text)) + + case .close: + doClose() + + case .pong: + // TODO: Handle pongs at some point + break + + default: + let messageType = String(describing: context.websocketMessageType) + assertionFailure("Unexpected message type: \(messageType)") + } + + connection.receiveMessage(completion: onReceiveMessage) + } + + func handleMessageWithError(_ error: NWError) { + switch state { + case let .connecting(conn), let .open(conn): + + startClosing(connection: conn, error: error) + + case .unopened, .closing, .closed: + // TODO: Should we call `doClose()` here, instead? + break + } + } + + func handleUnknownMessage(data: Data?, context: NWConnection.ContentContext?, error: NWError?) { + func describeInputs() -> String { + String(describing: String(data: data ?? Data(), encoding: .utf8)) + " " + + String(describing: context) + " " + String(describing: error) + } + + // TODO: Switch to using `state.description` + os_log( + "unknown message: state=%{public}s message=%s", + log: .webSocket, + type: .error, + state.debugDescription, + describeInputs() + ) + + doCloseWithError(WebSocketError.receiveUnknownMessageType) + } +} + +private extension WebSocketMessage { + var metadata: NWProtocolWebSocket.Metadata { + switch self { + case .data: return .init(opcode: .binary) + case .text: return .init(opcode: .text) + } + } + + var contentAsData: Data { + switch self { + case let .data(data): return data + case let .text(text): return Data(text.utf8) + } + } +} + +private enum WebSocketState: CustomStringConvertible, CustomDebugStringConvertible { + case unopened + case connecting(NWConnection) + case open(NWConnection) + case closing(NWError?) + case closed(WebSocketError?) + + var description: String { + switch self { + case .unopened: return "unopened" + case .connecting: return "connecting" + case .open: return "open" + case .closing: return "closing" + case .closed: return "closed" + } + } + + var debugDescription: String { + switch self { + case .unopened: return "unopened" + case let .connecting(conn): return "connecting(\(String(reflecting: conn)))" + case let .open(conn): return "open(\(String(reflecting: conn)))" + case let .closing(error): return "closing(\(error.debugDescription))" + case let .closed(error): return "closed(\(error.debugDescription))" + } + } +} + +private extension NWConnection.ContentContext { + var webSocketMetadata: NWProtocolWebSocket.Metadata? { + let definition = NWProtocolWebSocket.definition + return protocolMetadata(definition: definition) as? NWProtocolWebSocket.Metadata + } + + var websocketMessageType: NWProtocolWebSocket.Opcode? { + webSocketMetadata?.opcode + } +} + +private extension NWError { + var shouldCloseConnectionWhileConnectingOrOpen: Bool { + switch self { + case .posix(.ECANCELED), .posix(.ENOTCONN): + return false + default: + print("Unhandled error in '\(#function)': \(debugDescription)") + return true + } + } + + var closeCode: WebSocketCloseCode { + switch self { + case .posix(.ECANCELED): + return .normalClosure + default: + print("Unhandled error in '\(#function)': \(debugDescription)") + return .normalClosure + } + } +} + +private extension NWConnection.State { + var debugDescription: String { + switch self { + case .setup: return "setup" + case let .waiting(error): return "waiting(\(String(reflecting: error)))" + case .preparing: return "preparing" + case .ready: return "ready" + case let .failed(error): return "failed(\(String(reflecting: error)))" + case .cancelled: return "cancelled" + @unknown default: return "unknown" + } + } +} + +private extension Optional where Wrapped == NWError { + var debugDescription: String { + guard case let .some(error) = self else { return "" } + return String(reflecting: error) + } +} diff --git a/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift b/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift index de1fa69..ccb58b8 100644 --- a/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift +++ b/Sources/WebSocket/URLSessionWebSocketTaskMessage+WebSocket.swift @@ -6,7 +6,7 @@ extension URLSessionWebSocketTask.Message: CustomDebugStringConvertible { case let .string(text): return text case let .data(data): - return "\(data.count) bytes" + return "<\(data.count) bytes>" @unknown default: assertionFailure("Unsupported message: \(self)") return "" diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index 7eb7d57..6d74aa3 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -1,272 +1,111 @@ -import Combine import Foundation -import os.log -import Synchronized - -final actor WebSocket { - let url: URL - let options: WebSocketOptions +import Combine - var isOpen: Bool { - get async { - guard case .open = state else { return false } - return true - } - } +public typealias WebSocketOnOpen = () -> Void +public typealias WebSocketOnClose = (WebSocketCloseResult) -> Void - var isClosed: Bool { get async { await !isOpen } } +public struct WebSocket { + /// Sets a closure to be called when the WebSocket connects successfully. + public var onOpen: (@escaping WebSocketOnOpen) async -> Void - private var onStateChange: (WebSocketEvent) -> Void - private var state: WebSocketState = .unopened + /// Sets a closure to be called when the WebSocket closes. + public var onClose: (@escaping WebSocketOnClose) async -> Void - init( - url: URL, - options: WebSocketOptions = .init(), - onStateChange: @escaping (WebSocketEvent) -> Void - ) async { - self.url = url - self.options = options - self.onStateChange = onStateChange - connect() - } + /// Opens the WebSocket connect with an optional timeout. After this function + /// is awaited, the WebSocket connection is open ready to be used. If the + /// connection fails or times out, an error is thrown. + public var open: (TimeInterval?) async throws -> Void - func setOnStateChange(_ block: @escaping (WebSocketEvent) -> Void) async { - onStateChange = block - } - - func close(_ code: WebSocketCloseCode) async { - os_log( - "close: oldstate=%{public}@ code=%lld", - log: .webSocket, - type: .debug, - state.description, - code.rawValue - ) - - switch state { - case let .connecting(session, task, _), let .open(session, task, _): - state = .closed(code.urlSessionCloseCode, nil) - onStateChange(.close(code, nil)) - task.cancel(with: code.urlSessionCloseCode, reason: nil) - session.finishTasksAndInvalidate() - - case .unopened, .closing, .closed: - break - } - } + /// Sends a close frame to the server with the given close code. + public var close: (WebSocketCloseCode) async throws -> Void - func send(_ message: URLSessionWebSocketTask.Message) async throws { - // Mirrors the document behavior of JavaScript's `WebSocket` - // http://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send - switch state { - case let .open(_, task, _): - os_log("send: %s", log: .webSocket, type: .debug, message.debugDescription) - try await task.send(message) + /// Sends a text or binary message. + public var send: (WebSocketMessage) async throws -> Void - case .unopened, .connecting: - os_log( - "send message while connecting: %s", - log: .webSocket, - type: .error, - message.debugDescription - ) - throw WebSocketError.sendMessageWhileConnecting + /// Publishes messages received from WebSocket. Finishes when the + /// WebSocket connection closes. + public var messagesPublisher: () -> AnyPublisher - case .closing, .closed: - os_log( - "send message while closed: %s", - log: .webSocket, - type: .debug, - message.debugDescription - ) - } - } - - func receive() async throws -> URLSessionWebSocketTask.Message { - switch state { - case let .open(_, task, _): - let message = try await task.receive() - os_log("receive: %s", log: .webSocket, type: .debug, message.debugDescription) - return message - - case .unopened, .connecting, .closing, .closed: - os_log( - "receive in incorrect state: %s", - log: .webSocket, - type: .error, - state.description - ) - throw WebSocketError.receiveMessageWhenNotOpen + public init( + onOpen: @escaping ((@escaping WebSocketOnOpen)) async -> Void = { _ in }, + onClose: @escaping ((@escaping WebSocketOnClose)) async -> Void = { _ in }, + open: @escaping (TimeInterval?) async throws -> Void = { _ in }, + close: @escaping (WebSocketCloseCode) async throws -> Void = { _ in }, + send: @escaping (WebSocketMessage) async throws -> Void = { _ in }, + messagesPublisher: @escaping () -> AnyPublisher = { + Empty(completeImmediately: false).eraseToAnyPublisher() } + ) { + self.onOpen = onOpen + self.onClose = onClose + self.open = open + self.close = close + self.send = send + self.messagesPublisher = messagesPublisher } } -private extension WebSocket { - func setState(_ state: WebSocketState) async { - self.state = state +public extension WebSocket { + /// Calls `WebSocket.open(nil)`. + func open() async throws { + try await open(nil) } - func connect() { - os_log( - "connect: oldstate=%{public}@", - log: .webSocket, - type: .debug, - state.description - ) - - switch state { - case .closed, .unopened: - let delegate = WebSocketDelegate(onStateChange: onDelegateEvent) - - let config = URLSessionConfiguration.default - config.timeoutIntervalForRequest = options.timeoutIntervalForRequest - config.timeoutIntervalForResource = options.timeoutIntervalForResource - - let session = URLSession( - configuration: config, - delegate: delegate, - delegateQueue: nil - ) - - let task = session.webSocketTask(with: url) - task.maximumMessageSize = options.maximumMessageSize - state = .connecting(session, task, delegate) - - task.resume() - - default: - break - } + /// Calls `WebSocket.close(closeCode: .goingAway)`. + func close() async throws { + try await close(.goingAway) } - var onDelegateEvent: (WebSocketDelegateEvent) async -> Void { - { [weak self] (event: WebSocketDelegateEvent) in - guard let self = self else { return } - - switch (await self.state, event) { - case let (.connecting(s1, t1, delegate), .open(s2, t2, _)): - guard s1 === s2, t1 === t2 else { return } - await self.setState(.open(s2, t2, delegate)) - await self.onStateChange(.open) + /// The WebSocket's received messages as an asynchronous stream. + var messages: AsyncStream { + var cancellable: AnyCancellable? - case let (.connecting(s1, t1, _), .close(s2, t2, closeCode, reason)), - let (.open(s1, t1, _), .close(s2, t2, closeCode, reason)): - guard s1 === s2, t1 === t2 else { return } - if let closeCode = closeCode { - await self.setState(.closed(closeCode, reason)) - } else { - await self.setState(.closed(.abnormalClosure, nil)) + return AsyncStream { cont in + func finish() { + if cancellable != nil { + cont.finish() + cancellable = nil } - await self.onStateChange(.close(.init(closeCode), reason)) - s2.invalidateAndCancel() - - case let (.connecting(s1, t1, _), .complete(s2, t2, error)), - let (.open(s1, t1, _), .complete(s2, t2, error)): - guard s1 === s2, t1 === t2 else { return } - if let error = error { - await self.setState( - .closed( - .internalServerError, - Data(error.localizedDescription.utf8) - ) - ) - await self.onStateChange(.error(error as NSError)) - } else { - await self.setState(.closed(.internalServerError, nil)) - await self.onStateChange(.close(nil, nil)) - } - s2.invalidateAndCancel() - - case let (.closing, .close(session, _, closeCode, reason)): - if let closeCode = closeCode { - await self.setState(.closed(closeCode, reason)) - } else { - await self.setState(.closed(.abnormalClosure, nil)) - } - await self.onStateChange(.close(.init(closeCode), reason)) - session.invalidateAndCancel() - - case (.unopened, _): - return - - case (.closed, _): - return - - case (.closing, .open), (.closing, .complete): - return - - case (.open, .open): - return } - } - } -} -private enum WebSocketState: CustomStringConvertible { - case unopened - case connecting(URLSession, URLSessionWebSocketTask, WebSocketDelegate) - case open(URLSession, URLSessionWebSocketTask, WebSocketDelegate) - case closing - case closed(URLSessionWebSocketTask.CloseCode, Data?) + let _cancellable = self.messagesPublisher() + .handleEvents(receiveCancel: { finish() }) + .sink( + receiveCompletion: { _ in finish() }, + receiveValue: { cont.yield($0) } + ) - var webSocketSessionAndTask: (URLSession, URLSessionWebSocketTask)? { - switch self { - case let .connecting(session, task, _), let .open(session, task, _): - return (session, task) - case .unopened, .closing, .closed: - return nil - } - } - - var description: String { - switch self { - case .unopened: return "unopened" - case .connecting: return "connecting" - case .open: return "open" - case .closing: return "closing" - case .closed: return "closed" + cancellable = _cancellable } } } -// MARK: URLSessionWebSocketDelegate - -private enum WebSocketDelegateEvent { - case open(URLSession, URLSessionWebSocketTask, String?) - case close(URLSession, URLSessionWebSocketTask, URLSessionWebSocketTask.CloseCode?, Data?) - case complete(URLSession, URLSessionTask, Error?) -} - -private class WebSocketDelegate: NSObject, URLSessionWebSocketDelegate { - private var onStateChange: (WebSocketDelegateEvent) async -> Void - - init(onStateChange: @escaping (WebSocketDelegateEvent) async -> Void) { - self.onStateChange = onStateChange - super.init() - } - - func urlSession( - _ webSocketSession: URLSession, - webSocketTask: URLSessionWebSocketTask, - didOpenWithProtocol protocol: String? - ) { - Task { await onStateChange(.open(webSocketSession, webSocketTask, `protocol`)) } - } - - func urlSession( - _ session: URLSession, - webSocketTask: URLSessionWebSocketTask, - didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, - reason: Data? - ) { - Task { await onStateChange(.close(session, webSocketTask, closeCode, reason)) } +public extension WebSocket { + /// System WebSocket implementation powered by the Network Framework. + static func system( + url: URL, + options: WebSocketOptions = .init(), + onOpen: @escaping WebSocketOnOpen = {}, + onClose: @escaping WebSocketOnClose = { _ in } + ) async throws -> Self { + let ws = try await SystemWebSocket( + url: url, + options: options, + onOpen: onOpen, + onClose: onClose + ) + return try await .system(ws) } - func urlSession( - _ session: URLSession, - task: URLSessionTask, - didCompleteWithError error: Error? - ) { - Task { await onStateChange(.complete(session, task, error)) } + // This is only intended for use in tests. + internal static func system(_ ws: SystemWebSocket) async throws -> Self { + Self( + onOpen: { onOpen in await ws.onOpen(onOpen) }, + onClose: { onClose in await ws.onClose(onClose) }, + open: { timeout in try await ws.open(timeout: timeout) }, + close: { code in try await ws.close(code) }, + send: { message in try await ws.send(message) }, + messagesPublisher: { ws.eraseToAnyPublisher() } + ) } } diff --git a/Sources/WebSocket/WebSocketClient.swift b/Sources/WebSocket/WebSocketClient.swift deleted file mode 100644 index c54268f..0000000 --- a/Sources/WebSocket/WebSocketClient.swift +++ /dev/null @@ -1,69 +0,0 @@ -import Foundation - -public struct WebSocketClient { - public var onStateChange: (@escaping (WebSocketEvent) -> Void) async -> Void - - /// Sends a close frame to the server with the given close code. - public var close: (WebSocketCloseCode) async throws -> Void - - /// Sends the WebSocket binary message. - public var sendBinary: (Data) async throws -> Void - - /// Sends the WebSocket text message. - public var sendText: (String) async throws -> Void - - /// Receives a message from the WebSocket. - public var receiveMessage: () async throws -> WebSocketMessage -} - -public extension WebSocketClient { - /// Calls `WebSocketProtocol.close(closeCode: .goingAway)`. - func close() async throws { - try await self.close(.goingAway) - } - - func receiveText() async throws -> String { - guard case let .text(text) = try await self.receiveMessage() - else { throw WebSocketError.expectedTextReceivedData } - return text - } - - func receiveData() async throws -> Data { - guard case let .data(data) = try await self.receiveMessage() - else { throw WebSocketError.expectedDataReceivedText } - return data - } -} - -public extension WebSocketClient { - static func system( - url: URL, - options: WebSocketOptions = .init(), - onStateChange: @escaping (WebSocketEvent) -> Void - ) async -> Self { - let ws = await WebSocket( - url: url, - options: options, - onStateChange: onStateChange - ) - - return Self( - onStateChange: { await ws.setOnStateChange($0) }, - close: { await ws.close($0) }, - sendBinary: { try await ws.send(.data($0)) }, - sendText: { try await ws.send(.string($0)) }, - receiveMessage: { - switch try await ws.receive() { - case let .data(data): - return .data(data) - - case let .string(text): - return .text(text) - - @unknown default: - throw WebSocketError.receiveUnknownMessageType - } - } - ) - } -} diff --git a/Sources/WebSocket/WebSocketCloseCode.swift b/Sources/WebSocket/WebSocketCloseCode.swift index 5ad47c6..c33552a 100644 --- a/Sources/WebSocket/WebSocketCloseCode.swift +++ b/Sources/WebSocket/WebSocketCloseCode.swift @@ -1,10 +1,10 @@ import Foundation +import Network /// A code indicating why a WebSocket connection closed. /// /// Mirrors [URLSessionWebSocketTask](https://developer.apple.com/documentation/foundation/urlsessionwebsockettask/closecode). public enum WebSocketCloseCode: Int, CaseIterable { - /// A code that indicates the connection is still open. case invalid = 0 @@ -44,3 +44,36 @@ public enum WebSocketCloseCode: Int, CaseIterable { /// A reserved code that indicates the connection closed due to the failure to perform a TLS handshake. case tlsHandshakeFailure = 1015 } + +extension WebSocketCloseCode { + var error: NWError? { + switch self { + case .invalid: + return nil + case .normalClosure: + return nil + case .goingAway: + return nil + case .protocolError: + return .posix(.EPROTO) + case .unsupportedData: + return .posix(.EBADMSG) + case .noStatusReceived: + return nil + case .abnormalClosure: + return nil + case .invalidFramePayloadData: + return nil + case .policyViolation: + return nil + case .messageTooBig: + return .posix(.EMSGSIZE) + case .mandatoryExtensionMissing: + return nil + case .internalServerError: + return nil + case .tlsHandshakeFailure: + return .tls(errSSLHandshakeFail) + } + } +} diff --git a/Sources/WebSocket/WebSocketCloseResult.swift b/Sources/WebSocket/WebSocketCloseResult.swift new file mode 100644 index 0000000..458c6f4 --- /dev/null +++ b/Sources/WebSocket/WebSocketCloseResult.swift @@ -0,0 +1,8 @@ +import Foundation + +public typealias WebSocketCloseResult = Result<(code: WebSocketCloseCode, reason: Data?), Error> + +internal let normalClosure: WebSocketCloseResult = .success((.normalClosure, nil)) +internal let abnormalClosure: WebSocketCloseResult = .success((.abnormalClosure, nil)) +internal let closureWithError: (Error) -> WebSocketCloseResult = { e in .failure(e) } + diff --git a/Sources/WebSocket/WebSocketError.swift b/Sources/WebSocket/WebSocketError.swift index 55fe10a..a7a7cc0 100644 --- a/Sources/WebSocket/WebSocketError.swift +++ b/Sources/WebSocket/WebSocketError.swift @@ -1,9 +1,19 @@ import Foundation +import Network -public enum WebSocketError: Error, Hashable { +public enum WebSocketError: Error, Equatable { + case invalidURL(URL) + case invalidURLComponents(URLComponents) + case openAfterConnectionClosed case sendMessageWhileConnecting case receiveMessageWhenNotOpen case receiveUnknownMessageType - case expectedTextReceivedData - case expectedDataReceivedText + case connectionError(NWError) +} + +extension Optional where Wrapped == WebSocketError { + var debugDescription: String { + guard case let .some(error) = self else { return "" } + return String(reflecting: error) + } } diff --git a/Sources/WebSocket/WebSocketEvent.swift b/Sources/WebSocket/WebSocketEvent.swift deleted file mode 100644 index e2e218b..0000000 --- a/Sources/WebSocket/WebSocketEvent.swift +++ /dev/null @@ -1,31 +0,0 @@ -import Foundation - -/// Lifecycle events related to the opening or closing of the WebSocket. -public enum WebSocketEvent: Hashable, CustomStringConvertible { - /// Fired when a connection with a WebSocket is opened. - case open - - /// Fired when a connection with a WebSocket is closed. - case close(WebSocketCloseCode?, Data?) - - /// Fired when a connection with a WebSocket has been closed because of an error, - /// such as when some data couldn't be sent. - case error(NSError?) - - public var description: String { - switch self { - case .open: - return "open" - - case let .close(code, reason): - if let reason = reason { - return "close(\(code?.rawValue ?? -1), \(String(data: reason, encoding: .utf8) ?? ""))" - } else { - return "close(\(code?.rawValue ?? -1))" - } - - case let .error(error): - return "error(\(error?.localizedDescription ?? ""))" - } - } -} diff --git a/Sources/WebSocket/WebSocketMessage.swift b/Sources/WebSocket/WebSocketMessage.swift index e963377..0cca3e7 100644 --- a/Sources/WebSocket/WebSocketMessage.swift +++ b/Sources/WebSocket/WebSocketMessage.swift @@ -1,6 +1,7 @@ import Foundation +import Network -/// An enumeration of the types of messages that can be received. +/// An enumeration of the types of messages that can be sent or received. public enum WebSocketMessage: CustomStringConvertible, CustomDebugStringConvertible, Hashable { /// A WebSocket message that contains a block of data. case data(Data) @@ -15,5 +16,5 @@ public enum WebSocketMessage: CustomStringConvertible, CustomDebugStringConverti } } - public var debugDescription: String { self.description } + public var debugDescription: String { description } } diff --git a/Tests/WebSocketTests/Server/WebSocketServer.swift b/Tests/WebSocketTests/Server/WebSocketServer.swift index 8ecd2d2..0d34995 100644 --- a/Tests/WebSocketTests/Server/WebSocketServer.swift +++ b/Tests/WebSocketTests/Server/WebSocketServer.swift @@ -1,224 +1,227 @@ -//import Foundation -//import NIO -//import NIOHTTP1 -//import NIOWebSocket -// -//enum ReplyType { -// case echo -// case reply(() -> String?) -// case matchReply((String) -> String?) -//} -// -//final class WebSocketServer { -// let port: UInt16 -// -// private let replyType: ReplyType -// private let eventLoopGroup: EventLoopGroup -// -// private var serverChannel: Channel? -// -// init(port: UInt16, replyProvider: ReplyType) { -// self.port = port -// replyType = replyProvider -// eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) -// } -// -// func listen() { -// do { -// var addr = sockaddr_in() -// addr.sin_port = in_port_t(port).bigEndian -// let address = SocketAddress(addr, host: "0.0.0.0") -// -// let bootstrap = makeBootstrap() -// serverChannel = try bootstrap.bind(to: address).wait() -// -// guard let localAddress = serverChannel?.localAddress else { -// throw NIO.ChannelError.unknownLocalAddress -// } -// print("WebSocketServer running on \(localAddress)") -// } catch let error as NIO.IOError { -// print("Failed to start server: \(error.errnoCode) '\(error.localizedDescription)'") -// } catch { -// print("Failed to start server: \(String(describing: error))") -// } -// } -// -// func close() { -// do { try serverChannel?.close().wait() } -// catch { print("Failed to wait on server: \(error)") } -// } -// -// private func shouldUpgrade(channel _: Channel, -// head: HTTPRequestHead) -> EventLoopFuture -// { -// let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil -// return eventLoopGroup.next().makeSucceededFuture(headers) -// } -// -// private func upgradePipelineHandler( -// channel: Channel, -// head: HTTPRequestHead -// ) -> NIO.EventLoopFuture { -// head.uri.starts(with: "/socket") ? -// channel.pipeline.addHandler(WebSocketHandler(replyProvider: replyProvider)) : -// channel.closeFuture -// } -// -// private var replyProvider: (String) -> String? { -// { [weak self] (input: String) -> String? in -// guard let self = self else { return nil } -// switch self.replyType { -// case .echo: -// return input -// case let .reply(iterator): -// return iterator() -// case let .matchReply(matcher): -// return matcher(input) -// } -// } -// } -// -// private func makeBootstrap() -> ServerBootstrap { -// let reuseAddrOpt = ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR) -// return ServerBootstrap(group: eventLoopGroup) -// .serverChannelOption(ChannelOptions.backlog, value: 256) -// .serverChannelOption(reuseAddrOpt, value: 1) -// .childChannelInitializer { channel in -// let connectionUpgrader = NIOWebSocketServerUpgrader( -// shouldUpgrade: self.shouldUpgrade, -// upgradePipelineHandler: self.upgradePipelineHandler -// ) -// -// let config: NIOHTTPServerUpgradeConfiguration = ( -// upgraders: [connectionUpgrader], -// completionHandler: { _ in } -// ) -// -// return channel.pipeline.configureHTTPServerPipeline( -// position: .first, -// withPipeliningAssistance: true, -// withServerUpgrade: config, -// withErrorHandling: true -// ) -// } -// .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) -// .childChannelOption(reuseAddrOpt, value: 1) -// .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) -// } -//} -// -//private class WebSocketHandler: ChannelInboundHandler { -// typealias InboundIn = WebSocketFrame -// typealias OutboundOut = WebSocketFrame -// -// private let replyProvider: (String) -> String? -// private var awaitingClose = false -// -// init(replyProvider: @escaping (String) -> String?) { -// self.replyProvider = replyProvider -// } -// -// func channelRead(context: ChannelHandlerContext, data: NIOAny) { -// let frame = unwrapInboundIn(data) -// -// switch frame.opcode { -// case .connectionClose: -// onClose(context: context, frame: frame) -// case .ping: -// onPing(context: context, frame: frame) -// case .text: -// var data = frame.unmaskedData -// let text = data.readString(length: data.readableBytes) ?? "" -// onText(context: context, text: text) -// case .binary: -// let buffer = frame.unmaskedData -// var data = Data(capacity: buffer.readableBytes) -// buffer.withUnsafeReadableBytes { data.append(contentsOf: $0) } -// onBinary(context: context, binary: data) -// default: -// onError(context: context) -// } -// } -// -// private func onBinary(context: ChannelHandlerContext, binary: Data) { -// do { -// // Obviously, this would need to be changed to actually handle data input -// if let text = String(data: binary, encoding: .utf8) { -// onText(context: context, text: text) -// } else { -// throw NIO.IOError(errnoCode: EBADMSG, reason: "Invalid message") -// } -// } catch { -// onError(context: context) -// } -// } -// -// private func onText(context: ChannelHandlerContext, text: String) { -// guard let reply = replyProvider(text) else { return } -// -// var replyBuffer = context.channel.allocator.buffer(capacity: reply.utf8.count) -// replyBuffer.writeString(reply) -// -// let frame = WebSocketFrame(fin: true, opcode: .text, data: replyBuffer) -// -// _ = context.channel.writeAndFlush(frame) -// } -// -// private func onPing(context: ChannelHandlerContext, frame: WebSocketFrame) { -// var frameData = frame.data -// -// if let maskingKey = frame.maskKey { -// frameData.webSocketUnmask(maskingKey) -// } -// -// let pong = WebSocketFrame(fin: true, opcode: .pong, data: frameData) -// context.write(wrapOutboundOut(pong), promise: nil) -// } -// -// private func onClose(context: ChannelHandlerContext, frame: WebSocketFrame) { -// if awaitingClose { -// // We sent the initial close and were waiting for the client's response -// context.close(promise: nil) -// } else { -// // The close came from the client. -// var data = frame.unmaskedData -// let closeDataCode = data.readSlice(length: 2) ?? context.channel.allocator -// .buffer(capacity: 0) -// let closeFrame = WebSocketFrame( -// fin: true, -// opcode: .connectionClose, -// data: closeDataCode -// ) -// _ = context.write(wrapOutboundOut(closeFrame)).map { () in -// context.close(promise: nil) -// } -// } -// } -// -// private func onError(context: ChannelHandlerContext) { -// var data = context.channel.allocator.buffer(capacity: 2) -// data.write(webSocketErrorCode: .protocolError) -// let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) -// context.write(wrapOutboundOut(frame)).whenComplete { (_: Result) in -// context.close(mode: .output, promise: nil) -// } -// awaitingClose = true -// } -// -// func channelReadComplete(context: ChannelHandlerContext) { -// context.flush() -// } -// -// func channelActive(context: ChannelHandlerContext) { -// print("Channel active: \(String(describing: context.channel.remoteAddress))") -// } -// -// func channelInactive(context: ChannelHandlerContext) { -// print("Channel closed: \(String(describing: context.localAddress))") -// } -// -// func errorCaught(context: ChannelHandlerContext, error: Error) { -// print("Error: \(error)") -// context.close(promise: nil) -// } -//} +import Foundation +import Network +import Combine +import WebSocket + +enum WebSocketServerError: Error { + case couldNotCreatePort(UInt16) +} + +enum WebSocketServerOutput: Hashable { + case die + case message(WebSocketMessage) +} + +private typealias E = WebSocketServerError + +final class WebSocketServer { + let port: UInt16 + let maximumMessageSize: Int + + // Publisher provided by consumers of `WebSocketServer` to provide the output + // `WebSocketServer` should send to its clients. + private let outputPublisher: AnyPublisher + private var outputPublisherSubscription: AnyCancellable? + + // Publisher the repeats everything sent to it by clients. + private let inputSubject = PassthroughSubject() + + private var listener: NWListener + private var connections: [NWConnection] = [] + + private let queue = DispatchQueue( + label: "app.shareup.websocketserverqueue", + qos: .default, + autoreleaseFrequency: .workItem, + target: .global() + ) + + init( + port: UInt16, + outputPublisher: P, + usesTLS: Bool = false, + maximumMessageSize: Int = 1024 * 1024 + ) throws where P.Output == WebSocketServerOutput, P.Failure == Error { + self.port = port + self.outputPublisher = outputPublisher.eraseToAnyPublisher() + self.maximumMessageSize = maximumMessageSize + + let parameters = NWParameters(tls: usesTLS ? .init() : nil) + parameters.allowLocalEndpointReuse = true + parameters.includePeerToPeer = true + parameters.acceptLocalOnly = true + + let options = NWProtocolWebSocket.Options() + options.autoReplyPing = true + options.maximumMessageSize = maximumMessageSize + + parameters.defaultProtocolStack.applicationProtocols.insert(options, at: 0) + + guard let port = NWEndpoint.Port(rawValue: port) + else { throw E.couldNotCreatePort(port) } + + listener = try NWListener(using: parameters, on: port) + + start() + } + + func forceClose() { + queue.sync { + connections.forEach { connection in + connection.forceCancel() + } + connections.removeAll() + listener.cancel() + } + } + + var inputPublisher: AnyPublisher { + inputSubject.eraseToAnyPublisher() + } +} + +private extension WebSocketServer { + func start() { + listener.newConnectionHandler = onNewConnection + + listener.stateUpdateHandler = { [weak self] state in + guard let self = self else { return } + switch state { + case .failed: + self.close() + + default: + break + } + } + + listener.start(queue: queue) + } + + func broadcastMessage(_ message: WebSocketMessage) { + let context: NWConnection.ContentContext + let content: Data + + switch message { + case let .data(data): + let metadata: NWProtocolWebSocket.Metadata = .init(opcode: .binary) + context = .init(identifier: String(message.hashValue), metadata: [metadata]) + content = data + + case let .text(string): + let metadata: NWProtocolWebSocket.Metadata = .init(opcode: .text) + context = .init(identifier: String(message.hashValue), metadata: [metadata]) + content = Data(string.utf8) + } + + connections.forEach { connection in + connection.send( + content: content, + contentContext: context, + isComplete: true, + completion: .contentProcessed({ [weak self] error in + guard let _ = error else { return } + self?.closeConnection(connection) + }) + ) + } + } + + func close() { + connections.forEach { closeConnection($0) } + connections.removeAll() + listener.cancel() + } + + func closeConnection(_ connection: NWConnection) { + connection.send( + content: nil, + contentContext: .finalMessage, + isComplete: true, + completion: .contentProcessed({ _ in + connection.cancel() + }) + ) + } + + func cancelConnection(_ connection: NWConnection) { + connection.forceCancel() + connections.removeAll(where: { $0 === connection }) + } + + var onNewConnection: (NWConnection) -> Void { + { [weak self] (newConnection: NWConnection) in + guard let self = self else { return } + + self.connections.append(newConnection) + + func receive() { + newConnection.receiveMessage { [weak self] (data, context, _, error) in + guard let self = self else { return } + guard error == nil else { return self.closeConnection(newConnection) } + + guard let data = data, + let context = context, + let _metadata = context.protocolMetadata.first, + let metadata = _metadata as? NWProtocolWebSocket.Metadata + else { return } + + switch metadata.opcode { + case .binary: + self.inputSubject.send(.data(data)) + + case .text: + if let text = String(data: data, encoding: .utf8) { + self.inputSubject.send(.text(text)) + } + + default: + break + } + + receive() + } + } + receive() + + newConnection.stateUpdateHandler = { [weak self] state in + guard let self = self else { return } + + switch state { + case .ready: + guard self.outputPublisherSubscription == nil else { break } + self.outputPublisherSubscription = self.outputPublisher + .receive(on: self.queue) + .sink( + receiveCompletion: { [weak self] completion in + guard let self = self else { return } + guard case .failure = completion else { + self.cancelConnection(newConnection) + return + } + self.close() + }, + receiveValue: { [weak self] (output: WebSocketServerOutput) in + guard let self = self else { return } + switch output { + case .die: + self.cancelConnection(newConnection) + + case let .message(message): + self.broadcastMessage(message) + } + } + ) + + case .failed: + self.cancelConnection(newConnection) + + default: + break + } + } + + newConnection.start(queue: self.queue) + } + } +} diff --git a/Tests/WebSocketTests/Server/WebSocketServer2.swift b/Tests/WebSocketTests/Server/WebSocketServer2.swift deleted file mode 100644 index b6edd51..0000000 --- a/Tests/WebSocketTests/Server/WebSocketServer2.swift +++ /dev/null @@ -1,515 +0,0 @@ -import NIO -import NIOWebSocket -import NIOHTTP1 -import NIOSSL -import Foundation -import NIOFoundationCompat - -enum ReplyType { - case echo - case reply(() -> String?) - case matchReply((String) -> String?) -} - -final class WebSocketServer { - enum PeerType { - case server - case client - } - - var eventLoop: EventLoop { - return channel.eventLoop - } - - var isClosed: Bool { !self.channel.isActive } - private(set) var closeCode: WebSocketErrorCode? - - var onClose: EventLoopFuture { - self.channel.closeFuture - } - - let port: UInt16 - - private let replyType: ReplyType - private let eventLoopGroup: EventLoopGroup - - private var channel: Channel! - private var onTextCallback: (WebSocketServer, String) -> () - private var onBinaryCallback: (WebSocketServer, ByteBuffer) -> () - private var onPongCallback: (WebSocketServer) -> () - private var onPingCallback: (WebSocketServer) -> () - private var frameSequence: WebSocketFrameSequence? - private let type: PeerType - private var waitingForPong: Bool - private var waitingForClose: Bool - private var scheduledTimeoutTask: Scheduled? - - init(port: UInt16, replyProvider: ReplyType) throws { - self.port = port - replyType = replyProvider - eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - - self.type = .server - self.onTextCallback = { _, _ in } - self.onBinaryCallback = { _, _ in } - self.onPongCallback = { _ in } - self.onPingCallback = { _ in } - self.waitingForPong = false - self.waitingForClose = false - self.scheduledTimeoutTask = nil - - var addr = sockaddr_in() - addr.sin_port = in_port_t(port).bigEndian - let address = SocketAddress(addr, host: "0.0.0.0") - - let bootstrap = makeBootstrap() - let channel = try bootstrap.bind(to: address).wait() - - guard let localAddress = channel.localAddress else { - throw NIO.ChannelError.unknownLocalAddress - } - - self.channel = channel - - print("WebSocketServer running on \(localAddress)") - } - - deinit { - assert(self.isClosed, "WebSocketServer was not closed before deinit.") - } - - func onText(_ callback: @escaping (WebSocketServer, String) -> ()) { - self.onTextCallback = callback - } - - func onBinary(_ callback: @escaping (WebSocketServer, ByteBuffer) -> ()) { - self.onBinaryCallback = callback - } - - func onPong(_ callback: @escaping (WebSocketServer) -> ()) { - self.onPongCallback = callback - } - - func onPing(_ callback: @escaping (WebSocketServer) -> ()) { - self.onPingCallback = callback - } - - /// If set, this will trigger automatic pings on the connection. If ping is not answered before - /// the next ping is sent, then the WebSocketServer will be presumed innactive and will be closed - /// automatically. - /// These pings can also be used to keep the WebSocketServer alive if there is some other timeout - /// mechanism shutting down innactive connections, such as a Load Balancer deployed in - /// front of the server. - var pingInterval: TimeAmount? { - didSet { - if pingInterval != nil { - if scheduledTimeoutTask == nil { - waitingForPong = false - self.pingAndScheduleNextTimeoutTask() - } - } else { - scheduledTimeoutTask?.cancel() - } - } - } - - func send(_ text: S, promise: EventLoopPromise? = nil) - where S: Collection, S.Element == Character - { - let string = String(text) - var buffer = channel.allocator.buffer(capacity: text.count) - buffer.writeString(string) - self.send(raw: buffer.readableBytesView, opcode: .text, fin: true, promise: promise) - } - - func send(_ binary: [UInt8], promise: EventLoopPromise? = nil) { - self.send(raw: binary, opcode: .binary, fin: true, promise: promise) - } - - func sendPing(promise: EventLoopPromise? = nil) { - self.send( - raw: Data(), - opcode: .ping, - fin: true, - promise: promise - ) - } - - func send( - raw data: Data, - opcode: WebSocketOpcode, - fin: Bool = true, - promise: EventLoopPromise? = nil - ) - where Data: DataProtocol - { - var buffer = channel.allocator.buffer(capacity: data.count) - buffer.writeBytes(data) - let frame = WebSocketFrame( - fin: fin, - opcode: opcode, - maskKey: self.makeMaskKey(), - data: buffer - ) - self.channel.writeAndFlush(frame, promise: promise) - } - - func close(code: WebSocketErrorCode = .goingAway) -> EventLoopFuture { - let promise = self.eventLoop.makePromise(of: Void.self) - self.close(code: code, promise: promise) - return promise.futureResult - } - - func close( - code: WebSocketErrorCode = .goingAway, - promise: EventLoopPromise? - ) { - guard !self.isClosed else { - promise?.succeed(()) - return - } - guard !self.waitingForClose else { - promise?.succeed(()) - return - } - self.waitingForClose = true - self.closeCode = code - - let codeAsInt = UInt16(webSocketErrorCode: code) - let codeToSend: WebSocketErrorCode - if codeAsInt == 1005 || codeAsInt == 1006 { - /// Code 1005 and 1006 are used to report errors to the application, but must never be sent over - /// the wire (per https://tools.ietf.org/html/rfc6455#section-7.4) - codeToSend = .normalClosure - } else { - codeToSend = code - } - - var buffer = channel.allocator.buffer(capacity: 2) - buffer.write(webSocketErrorCode: codeToSend) - - self.send(raw: buffer.readableBytesView, opcode: .connectionClose, fin: true, promise: promise) - } - - func makeMaskKey() -> WebSocketMaskingKey? { - switch type { - case .client: - var bytes: [UInt8] = [] - for _ in 0..<4 { - bytes.append(.random(in: .min ..< .max)) - } - return WebSocketMaskingKey(bytes) - case .server: - return nil - } - } - - func handle(incoming frame: WebSocketFrame) { - switch frame.opcode { - case .connectionClose: - if self.waitingForClose { - // peer confirmed close, time to close channel - self.channel.close(mode: .all, promise: nil) - } else { - // peer asking for close, confirm and close output side channel - let promise = self.eventLoop.makePromise(of: Void.self) - var data = frame.data - let maskingKey = frame.maskKey - if let maskingKey = maskingKey { - data.webSocketUnmask(maskingKey) - } - self.close( - code: data.readWebSocketErrorCode() ?? .unknown(1005), - promise: promise - ) - promise.futureResult.whenComplete { _ in - self.channel.close(mode: .all, promise: nil) - } - } - case .ping: - if frame.fin { - var frameData = frame.data - let maskingKey = frame.maskKey - if let maskingKey = maskingKey { - frameData.webSocketUnmask(maskingKey) - } - self.send( - raw: frameData.readableBytesView, - opcode: .pong, - fin: true, - promise: nil - ) - } else { - self.close(code: .protocolError, promise: nil) - } - case .text, .binary, .pong: - // create a new frame sequence or use existing - var frameSequence: WebSocketFrameSequence - if let existing = self.frameSequence { - frameSequence = existing - } else { - frameSequence = WebSocketFrameSequence(type: frame.opcode) - } - // append this frame and update the sequence - frameSequence.append(frame) - self.frameSequence = frameSequence - case .continuation: - // we must have an existing sequence - if var frameSequence = self.frameSequence { - // append this frame and update - frameSequence.append(frame) - self.frameSequence = frameSequence - } else { - self.close(code: .protocolError, promise: nil) - } - default: - // We ignore all other frames. - break - } - - // if this frame was final and we have a non-nil frame sequence, - // output it to the websocket and clear storage - if let frameSequence = self.frameSequence, frame.fin { - switch frameSequence.type { - case .binary: - self.onBinaryCallback(self, frameSequence.binaryBuffer) - case .text: - self.onTextCallback(self, frameSequence.textBuffer) - case .pong: - self.waitingForPong = false - self.onPongCallback(self) - case .ping: - self.onPingCallback(self) - default: break - } - self.frameSequence = nil - } - } - - private func pingAndScheduleNextTimeoutTask() { - guard channel.isActive, let pingInterval = pingInterval else { - return - } - - if waitingForPong { - // We never received a pong from our last ping, so the connection has timed out - let promise = self.eventLoop.makePromise(of: Void.self) - self.close(code: .unknown(1006), promise: promise) - promise.futureResult.whenComplete { _ in - // Usually, closing a WebSocketServer is done by sending the close frame and waiting - // for the peer to respond with their close frame. We are in a timeout situation, - // so the other side likely will never send the close frame. We just close the - // channel ourselves. - self.channel.close(mode: .all, promise: nil) - } - } else { - self.sendPing() - self.waitingForPong = true - self.scheduledTimeoutTask = self.eventLoop.scheduleTask( - deadline: .now() + pingInterval, - self.pingAndScheduleNextTimeoutTask - ) - } - } - - private func makeBootstrap() -> ServerBootstrap { - let reuseAddrOpt = ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR) - return ServerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.backlog, value: 256) - .serverChannelOption(reuseAddrOpt, value: 1) - .childChannelInitializer { channel in - let connectionUpgrader = NIOWebSocketServerUpgrader( - shouldUpgrade: self.shouldUpgrade, - upgradePipelineHandler: self.upgradePipelineHandler - ) - - let config: NIOHTTPServerUpgradeConfiguration = ( - upgraders: [connectionUpgrader], - completionHandler: { _ in } - ) - - return channel.pipeline.configureHTTPServerPipeline( - position: .first, - withPipeliningAssistance: true, - withServerUpgrade: config, - withErrorHandling: true - ) - } - .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) - .childChannelOption(reuseAddrOpt, value: 1) - .childChannelOption(ChannelOptions.maxMessagesPerRead, value: 1) - } - - private func shouldUpgrade(channel _: Channel, - head: HTTPRequestHead) -> EventLoopFuture - { - let headers = head.uri.starts(with: "/socket") ? HTTPHeaders() : nil - return eventLoopGroup.next().makeSucceededFuture(headers) - } - - private func upgradePipelineHandler( - channel: Channel, - head: HTTPRequestHead - ) -> NIO.EventLoopFuture { - head.uri.starts(with: "/socket") ? - channel.pipeline.addHandler(WebSocketHandler(replyType: replyType)) : - channel.closeFuture - } -} - -private final class WebSocketHandler: ChannelInboundHandler { - typealias InboundIn = WebSocketFrame - typealias OutboundOut = WebSocketFrame - - private var awaitingClose: Bool = false - private let replyType: ReplyType - - init(replyType: ReplyType) { - self.replyType = replyType - } - - private func replyProvider(input: String) -> String? { - switch replyType { - case .echo: - return input - case let .reply(iterator): - return iterator() - case let .matchReply(matcher): - return matcher(input) - } - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let frame = self.unwrapInboundIn(data) - - func handleText(_ text: String) { - guard let reply = replyProvider(input: text) else { return } - let buffer = context.channel.allocator.buffer(data: Data(reply.utf8)) - let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) - context.writeAndFlush(self.wrapOutboundOut(frame)).whenFailure { _ in - context.close(promise: nil) - } - } - - switch frame.opcode { - case .connectionClose: - self.receivedClose(context: context, frame: frame) - case .ping: - self.pong(context: context, frame: frame) - case .text: - var data = frame.unmaskedData - let text = data.readString(length: data.readableBytes) ?? "" - handleText(text) - - case .binary: - let buffer = frame.unmaskedData - var data = Data(capacity: buffer.readableBytes) - buffer.withUnsafeReadableBytes { data.append(contentsOf: $0) } - - if let text = String(data: data, encoding: .utf8) { - handleText(text) - } else { - closeOnError(context: context) - } - - case .continuation, .pong: - // We ignore these frames. - break - default: - // Unknown frames are errors. - self.closeOnError(context: context) - } - } - - func channelReadComplete(context: ChannelHandlerContext) { - context.flush() - } - - private func sendTime(context: ChannelHandlerContext) { - guard context.channel.isActive else { return } - - // We can't send if we sent a close message. - guard !self.awaitingClose else { return } - - // We can't really check for error here, but it's also not the purpose of the - // example so let's not worry about it. - let theTime = NIODeadline.now().uptimeNanoseconds - var buffer = context.channel.allocator.buffer(capacity: 12) - buffer.writeString("\(theTime)") - - let frame = WebSocketFrame(fin: true, opcode: .text, data: buffer) - context.writeAndFlush(self.wrapOutboundOut(frame)).map { - context.eventLoop.scheduleTask(in: .seconds(1), { self.sendTime(context: context) }) - }.whenFailure { (_: Error) in - context.close(promise: nil) - } - } - - private func receivedClose(context: ChannelHandlerContext, frame: WebSocketFrame) { - // Handle a received close frame. In websockets, we're just going to send the close - // frame and then close, unless we already sent our own close frame. - if awaitingClose { - // Cool, we started the close and were waiting for the user. We're done. - context.close(promise: nil) - } else { - // This is an unsolicited close. We're going to send a response frame and - // then, when we've sent it, close up shop. We should send back the close code the remote - // peer sent us, unless they didn't send one at all. - var data = frame.unmaskedData - let closeDataCode = data.readSlice(length: 2) ?? ByteBuffer() - let closeFrame = WebSocketFrame(fin: true, opcode: .connectionClose, data: closeDataCode) - _ = context.write(self.wrapOutboundOut(closeFrame)).map { () in - context.close(promise: nil) - } - } - } - - private func pong(context: ChannelHandlerContext, frame: WebSocketFrame) { - var frameData = frame.data - let maskingKey = frame.maskKey - - if let maskingKey = maskingKey { - frameData.webSocketUnmask(maskingKey) - } - - let responseFrame = WebSocketFrame(fin: true, opcode: .pong, data: frameData) - context.write(self.wrapOutboundOut(responseFrame), promise: nil) - } - - private func closeOnError(context: ChannelHandlerContext) { - // We have hit an error, we want to close. We do that by sending a close frame and then - // shutting down the write side of the connection. - var data = context.channel.allocator.buffer(capacity: 2) - data.write(webSocketErrorCode: .protocolError) - let frame = WebSocketFrame(fin: true, opcode: .connectionClose, data: data) - context.write(self.wrapOutboundOut(frame)).whenComplete { (_: Result) in - context.close(mode: .output, promise: nil) - } - awaitingClose = true - } -} - -private struct WebSocketFrameSequence { - var binaryBuffer: ByteBuffer - var textBuffer: String - var type: WebSocketOpcode - - init(type: WebSocketOpcode) { - self.binaryBuffer = ByteBufferAllocator().buffer(capacity: 0) - self.textBuffer = .init() - self.type = type - } - - mutating func append(_ frame: WebSocketFrame) { - var data = frame.unmaskedData - switch type { - case .binary: - self.binaryBuffer.writeBuffer(&data) - case .text: - if let string = data.readString(length: data.readableBytes) { - self.textBuffer += string - } - default: break - } - } -} diff --git a/Tests/WebSocketTests/SystemWebSocketTests.swift b/Tests/WebSocketTests/SystemWebSocketTests.swift new file mode 100644 index 0000000..76d8960 --- /dev/null +++ b/Tests/WebSocketTests/SystemWebSocketTests.swift @@ -0,0 +1,348 @@ +import Combine +@testable import WebSocket +import XCTest + +private var ports = (50000 ... 52000).map { UInt16($0) } + +// NOTE: If `WebSocketTests` is not marked as `@MainActor`, calls to +// `wait(for:timeout:)` prevent other asyncronous events from running. +// Using `await waitForExpectations(timeout:handler:)` works properly +// because it's already marked as `@MainActor`. + +@MainActor +class SystemWebSocketTests: XCTestCase { + var subject: PassthroughSubject! + + @MainActor + override func setUp() async throws { + try await super.setUp() + subject = .init() + } + + func testCanConnectToAndDisconnectFromServer() async throws { + let openEx = expectation(description: "Should have opened") + let closeEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndClient( + onOpen: { openEx.fulfill() }, + onClose: { result in + switch result { + case let .success(close): + XCTAssertEqual(.normalClosure, close.code) + XCTAssertNil(close.reason) + closeEx.fulfill() + + case let .failure(error): + XCTFail("Should not have received error: \(error)") + } + } + ) + defer { server.forceClose() } + + wait(for: [openEx], timeout: 2) + + let isOpen = await client.isOpen + XCTAssertTrue(isOpen) + + try await client.close() + wait(for: [closeEx], timeout: 2) + } + + func testErrorWhenServerIsUnreachable() async throws { + let ex = expectation(description: "Should have errored") + let (server, client) = await makeOfflineServerAndClient( + onOpen: { XCTFail("Should not have opened") }, + onClose: { result in + switch result { + case let .success(close): + XCTFail("Should not have closed successfully: \(String(reflecting: close))") + + case let .failure(error): + guard let webSocketError = error as? WebSocketError, + case let .connectionError(nwerror) = webSocketError, + case let .posix(posix) = nwerror + else { return XCTFail("Closed with incorrect error: \(error)") } + XCTAssertEqual(.ECONNREFUSED, posix) + ex.fulfill() + } + } + ) + defer { server.forceClose() } + + waitForExpectations(timeout: 2) + + let isClosed = await client.isClosed + XCTAssertTrue(isClosed) + } + + func testErrorWhenRemoteCloses() async throws { + let errorEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndClient( + onClose: { result in + switch result { + case let .success(close): + XCTFail("Should not have closed successfully: \(String(reflecting: close))") + + case let .failure(error): + guard let err = error as? WebSocketError, + case .receiveUnknownMessageType = err + else { return XCTFail("Should have received unknown message error") } + errorEx.fulfill() + } + } + ) + defer { server.forceClose() } + + try await client.open() + + subject.send(.die) + wait(for: [errorEx], timeout: 2) + } + + func testWebSocketCannotBeOpenedTwice() async throws { + var closeCount = 0 + + let firstCloseEx = expectation(description: "Should have closed once") + let secondCloseEx = expectation(description: "Should not have closed more than once") + secondCloseEx.isInverted = true + + let (server, client) = await makeServerAndClient( + onClose: { result in + closeCount += 1 + if closeCount == 1 { + firstCloseEx.fulfill() + } else { + secondCloseEx.fulfill() + } + } + ) + defer { server.forceClose() } + + try await client.open() + + try await client.close() + wait(for: [firstCloseEx], timeout: 2) + + do { + try await client.open() + XCTFail("Should not have successfully reopened") + } catch { + guard let wserror = error as? WebSocketError, + case .openAfterConnectionClosed = wserror + else { return XCTFail("Received wrong error: \(error)") } + } + + wait(for: [secondCloseEx], timeout: 0.1) + } + + func testPushAndReceiveText() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + let sentEx = expectation(description: "Server should have received message") + let sentSub = server.inputPublisher + .sink(receiveValue: { message in + guard case let .text(text) = message + else { return XCTFail("Should have received text") } + XCTAssertEqual("hello", text) + sentEx.fulfill() + }) + defer { sentSub.cancel() } + + try await client.open() + + let receivedEx = expectation(description: "Should have received message") + let receivedSub = client.sink { message in + defer { receivedEx.fulfill() } + guard case let .text(text) = message + else { return XCTFail("Should have received text") } + XCTAssertEqual("hi, to you too!", text) + } + defer { receivedSub.cancel() } + + try await client.send(.text("hello")) + wait(for: [sentEx], timeout: 2) + subject.send(.message(.text("hi, to you too!"))) + wait(for: [receivedEx], timeout: 2) + } + + @available(iOS 15.0, macOS 12.0, *) + func testPushAndReceiveTextWithAsyncPublisher() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + try await client.open() + + try await client.send(.text("hello")) + subject.send(.message(.text("hi, to you too!"))) + + for await message in client.values { + guard case let .text(text) = message else { + XCTFail("Should have received text") + break + } + XCTAssertEqual("hi, to you too!", text) + break + } + } + + func testPushAndReceiveData() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + let sentEx = expectation(description: "Server should have received message") + let sentSub = server.inputPublisher + .sink(receiveValue: { message in + guard case let .data(data) = message + else { return XCTFail("Should have received data") } + XCTAssertEqual(Data("hello".utf8), data) + sentEx.fulfill() + }) + defer { sentSub.cancel() } + + try await client.open() + + let receivedEx = expectation(description: "Should have received message") + let receivedSub = client.sink { message in + defer { receivedEx.fulfill() } + guard case let .data(data) = message + else { return XCTFail("Should have received data") } + XCTAssertEqual(Data("hi, to you too!".utf8), data) + } + defer { receivedSub.cancel() } + + try await client.send(.data(Data("hello".utf8))) + wait(for: [sentEx], timeout: 2) + subject.send(.message(.data(Data("hi, to you too!".utf8)))) + wait(for: [receivedEx], timeout: 2) + } + + @available(iOS 15.0, macOS 12.0, *) + func testPushAndReceiveDataWithAsyncPublisher() async throws { + let (server, client) = await makeServerAndClient() + defer { server.forceClose() } + + try await client.open() + + try await client.send(.data(Data("hello bytes".utf8))) + subject.send(.message(.data(Data("howdy".utf8)))) + + for await message in client.values { + guard case let .data(data) = message else { + XCTFail("Should have received data") + break + } + XCTAssertEqual("howdy", String(data: data, encoding: .utf8)) + break + } + } + + func testWrappedSystemWebSocket() async throws { + let openEx = expectation(description: "Should have opened") + let closeEx = expectation(description: "Should have closed") + let (server, client) = await makeServerAndWrappedClient( + onOpen: { openEx.fulfill() }, + onClose: { result in + switch result { + case let .success((code, reason)): + XCTAssertEqual(.normalClosure, code) + XCTAssertNil(reason) + closeEx.fulfill() + case let .failure(error): + XCTFail("Should not have failed: \(error)") + } + } + ) + defer { server.forceClose() } + + var messagesToSend: [WebSocketMessage] = [ + .text("one"), + .data(Data("two".utf8)), + .text("three"), + ] + + var messagesToReceive: [WebSocketMessage] = [ + .text("one"), + .data(Data("two".utf8)), + .text("three"), + ] + + let sentSub = server.inputPublisher + .sink(receiveValue: { message in + let expected = messagesToSend.removeFirst() + XCTAssertEqual(expected, message) + }) + defer { sentSub.cancel() } + + // These two lines are redundant, but the goal + // is to test everything in `WebSocket`. + try await client.open() + wait(for: [openEx], timeout: 2) + + // These messages have to be sent after the `AsyncStream` is + // subscribed to below. So, we send them asynchronously. + let firstMessageToReceive = try XCTUnwrap(messagesToReceive.first) + let firstMessageToSend = try XCTUnwrap(messagesToSend.first) + Task.detached { + await self.subject.send(.message(firstMessageToReceive)) + try await client.send(firstMessageToSend) + } + + for await message in client.messages { + let expected = messagesToReceive.removeFirst() + XCTAssertEqual(expected, message) + + if let messageToSend = messagesToSend.first, + let messageToReceive = messagesToReceive.first { + try await client.send(messageToSend) + subject.send(.message(messageToReceive)) + } else { + try await client.close() + } + } + + XCTAssertTrue(messagesToSend.isEmpty) + XCTAssertTrue(messagesToReceive.isEmpty) + + wait(for: [closeEx], timeout: 2) + } +} + +private let empty: Empty = Empty( + completeImmediately: false, + outputType: WebSocketServerOutput.self, + failureType: Error.self +) + +private extension SystemWebSocketTests { + func url(_ port: UInt16) -> URL { URL(string: "ws://0.0.0.0:\(port)/socket")! } + + func makeServerAndClient( + onOpen: @escaping () -> Void = {}, + onClose: @escaping (WebSocketCloseResult) -> Void = { _ in } + ) async -> (WebSocketServer, SystemWebSocket) { + let port = ports.removeFirst() + let server = try! WebSocketServer(port: port, outputPublisher: subject) + let client = try! await SystemWebSocket(url: url(port), onOpen: onOpen, onClose: onClose) + return (server, client) + } + + func makeOfflineServerAndClient( + onOpen: @escaping () -> Void = {}, + onClose: @escaping (WebSocketCloseResult) -> Void = { _ in } + ) async -> (WebSocketServer, SystemWebSocket) { + let port = ports.removeFirst() + let server = try! WebSocketServer(port: 1, outputPublisher: empty) + let client = try! await SystemWebSocket(url: url(port), onOpen: onOpen, onClose: onClose) + return (server, client) + } + + func makeServerAndWrappedClient( + onOpen: @escaping () -> Void = {}, + onClose: @escaping (WebSocketCloseResult) -> Void = { _ in } + ) async -> (WebSocketServer, WebSocket) { + let port = ports.removeFirst() + let server = try! WebSocketServer(port: port, outputPublisher: subject) + let client = try! await SystemWebSocket(url: url(port), onOpen: onOpen, onClose: onClose) + return (server, try! await .system(client)) + } +} diff --git a/Tests/WebSocketTests/WebSocketTests.swift b/Tests/WebSocketTests/WebSocketTests.swift deleted file mode 100644 index 7e9bde6..0000000 --- a/Tests/WebSocketTests/WebSocketTests.swift +++ /dev/null @@ -1,204 +0,0 @@ -import Combine -@testable import WebSocket -import XCTest - -private var ports = (50000 ... 52000).map { UInt16($0) } - -// NOTE: If `WebSocketTests` is not marked as `@MainActor`, calls to -// `wait(for:timeout:)` prevent other asyncronous events from running. -// Using `await waitForExpectations(timeout:handler:)` works properly -// because it's already marked as `@MainActor`. - -@MainActor -class WebSocketTests: XCTestCase { - func url(_ port: UInt16) -> URL { URL(string: "ws://0.0.0.0:\(port)/socket")! } - - func testCanConnectToAndDisconnectFromServer() async throws { - let openEx = expectation(description: "Should have opened") - let closeEx = expectation(description: "Should have closed") - let (server, client) = await makeServerAndClient { event in - switch event { - case .open: - openEx.fulfill() - - case let .close(closeCode, _): - XCTAssertEqual(.normalClosure, closeCode) - closeEx.fulfill() - - case let .error(error): - XCTFail("Should not have received error: \(String(describing: error))") - } - } - defer { server.close() } - - wait(for: [openEx], timeout: 0.5) - - let isOpen = await client.isOpen - XCTAssertTrue(isOpen) - - await client.close(.normalClosure) - wait(for: [closeEx], timeout: 0.5) - } - - func testErrorWhenServerIsUnreachable() async throws { - let ex = expectation(description: "Should have errored") - let (server, client) = await makeOfflineServerAndClient { event in - guard case let .error(error) = event else { - return XCTFail("Should not have received \(event)") - } - XCTAssertEqual(-1004, error?.code) - ex.fulfill() - } - defer { server.close() } - - waitForExpectations(timeout: 0.5) - - let isClosed = await client.isClosed - XCTAssertTrue(isClosed) - } - - func testErrorWhenRemoteCloses() async throws { - var invalidUTF8Bytes = [0x192, 0x193] as [UInt16] - let bytes = withUnsafeBytes(of: &invalidUTF8Bytes) { Array($0) } - let data = Data(bytes: bytes, count: bytes.count) - - let openEx = expectation(description: "Should have opened") - let errorEx = expectation(description: "Should have errored") - - let (server, client) = await makeServerAndClient { event in - switch event { - case .open: - openEx.fulfill() - - case .close: - Swift.print("$$$ CLOSED") - XCTFail("Should not have closed") - - case let .error(error): - Swift.print("$$$ ERROR: \(String(describing: error))") - errorEx.fulfill() - } - } - defer { server.close() } - - wait(for: [openEx], timeout: 0.5) - let isOpen = await client.isOpen - XCTAssertTrue(isOpen) - - try await client.send(.data(data)) - wait(for: [errorEx], timeout: 0.5) - let isClosed = await client.isClosed - XCTAssertTrue(isClosed) - } - - func testEchoPush() async throws { - let openEx = expectation(description: "Should have opened") - let (server, client) = await makeEchoServerAndClient { event in - guard case .open = event else { return } - openEx.fulfill() - } - defer { server.close() } - - wait(for: [openEx], timeout: 0.5) - - try await client.send(.string("hello")) - guard case let .string(text) = try await client.receive() - else { return XCTFail("Should have received text") } - - XCTAssertEqual("hello", text) - } - - func testEchoBinaryPush() async throws { - let openEx = expectation(description: "Should have opened") - let (server, client) = await makeEchoServerAndClient { event in - guard case .open = event else { return } - openEx.fulfill() - } - defer { server.close() } - - wait(for: [openEx], timeout: 0.5) - - try await client.send(.data(Data("hello".utf8))) - guard case let .string(text) = try await client.receive() - else { return XCTFail("Should have received text") } - - XCTAssertEqual("hello", text) - } - - func testJoinLobbyAndEcho() async throws { - var pushes = [ - "[1,1,\"room:lobby\",\"phx_join\",{}]", - "[1,2,\"room:lobby\",\"echo\",{\"echo\":\"one\"}]", - "[1,3,\"room:lobby\",\"echo\",{\"echo\":\"two\"}]", - ] - - let replies = [ - "[1,1,\"room:lobby\",\"phx_reply\",{\"response\":{},\"status\":\"ok\"}]", - "[1,2,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"one\"},\"status\":\"ok\"}]", - "[1,3,\"room:lobby\",\"phx_reply\",{\"response\":{\"echo\":\"two\"},\"status\":\"ok\"}]", - ] - - let openEx = expectation(description: "Should have opened") - - let (server, client) = await makeReplyServerAndClient(replies) { event in - guard case .open = event else { return } - openEx.fulfill() - } - defer { server.close() } - - wait(for: [openEx], timeout: 0.5) - - try await client.send(.string(pushes.removeFirst())) - try await client.send(.string(pushes.removeFirst())) - try await client.send(.string(pushes.removeFirst())) - - for expected in replies { - guard case let .string(reply) = try await client.receive() else { return XCTFail() } - XCTAssertEqual(expected, reply) - } - } -} - -private extension WebSocketTests { - func makeServerAndClient( - _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } - ) async -> (WebSocketServer, WebSocket) { - let port = ports.removeFirst() - let server = try! WebSocketServer(port: port, replyProvider: .reply { nil }) - let client = await WebSocket(url: url(port), onStateChange: onStateChange) -// server.listen() - return (server, client) - } - - func makeOfflineServerAndClient( - _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } - ) async -> (WebSocketServer, WebSocket) { - let port = ports.removeFirst() - let server = try! WebSocketServer(port: 1, replyProvider: .reply { nil }) - let client = await WebSocket(url: url(port), onStateChange: onStateChange) - return (server, client) - } - - func makeEchoServerAndClient( - _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } - ) async -> (WebSocketServer, WebSocket) { - let port = ports.removeFirst() - let server = try! WebSocketServer(port: port, replyProvider: .echo) - let client = await WebSocket(url: url(port), onStateChange: onStateChange) -// server.listen() - return (server, client) - } - - func makeReplyServerAndClient( - _ replies: [String?], - _ onStateChange: @escaping (WebSocketEvent) -> Void = { _ in } - ) async -> (WebSocketServer, WebSocket) { - let port = ports.removeFirst() - var replies = replies - let provider: () -> String? = { replies.removeFirst() } - let server = try! WebSocketServer(port: port, replyProvider: .reply(provider)) - let client = await WebSocket(url: url(port), onStateChange: onStateChange) -// server.listen() - return (server, client) - } -} From 7a6e7baac380e37fa023e9a115d15fa9ee2037e0 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 00:03:20 +0200 Subject: [PATCH 06/11] Format --- .swiftformat | 4 +++- Sources/WebSocket/SystemWebSocket.swift | 15 ++++++++---- Sources/WebSocket/WebSocket.swift | 6 ++--- Sources/WebSocket/WebSocketCloseResult.swift | 1 - .../Server/WebSocketServer.swift | 14 +++++------ .../WebSocketTests/SystemWebSocketTests.swift | 23 +++++++++++++++---- ...RLSessionWebSocketTaskCloseCodeTests.swift | 6 +++-- 7 files changed, 46 insertions(+), 23 deletions(-) diff --git a/.swiftformat b/.swiftformat index 7a40186..35d6950 100644 --- a/.swiftformat +++ b/.swiftformat @@ -1,6 +1,8 @@ --funcattributes prev-line --minversion 0.47.2 ---maxwidth 100 +--maxwidth 96 --typeattributes prev-line --wraparguments before-first +--wrapparameters before-first --wrapcollections before-first +--xcodeindentation enabled diff --git a/Sources/WebSocket/SystemWebSocket.swift b/Sources/WebSocket/SystemWebSocket.swift index 91fc3a3..305a53a 100644 --- a/Sources/WebSocket/SystemWebSocket.swift +++ b/Sources/WebSocket/SystemWebSocket.swift @@ -72,7 +72,7 @@ final actor SystemWebSocket: Publisher { .receive(on: subscriberQueue) .receive(subscriber: subscriber) } - + func open(timeout: TimeInterval? = nil) async throws { switch state { case .open: @@ -83,7 +83,10 @@ final actor SystemWebSocket: Publisher { case .unopened, .connecting: do { - try await withThrowingTaskGroup(of: Void.self) { (group: inout ThrowingTaskGroup) in + try await withThrowingTaskGroup( + of: Void + .self + ) { (group: inout ThrowingTaskGroup) in _ = group.addTaskUnlessCancelled { [weak self] in guard let self = self else { return } let _timeout = UInt64(timeout ?? self.options.timeoutIntervalForRequest) @@ -98,7 +101,7 @@ final actor SystemWebSocket: Publisher { } } - let _ = try await group.next() + _ = try await group.next() group.cancelAll() } } catch { @@ -522,7 +525,11 @@ private extension SystemWebSocket { } } - func handleUnknownMessage(data: Data?, context: NWConnection.ContentContext?, error: NWError?) { + func handleUnknownMessage( + data: Data?, + context: NWConnection.ContentContext?, + error: NWError? + ) { func describeInputs() -> String { String(describing: String(data: data ?? Data(), encoding: .utf8)) + " " + String(describing: context) + " " + String(describing: error) diff --git a/Sources/WebSocket/WebSocket.swift b/Sources/WebSocket/WebSocket.swift index 6d74aa3..e0547eb 100644 --- a/Sources/WebSocket/WebSocket.swift +++ b/Sources/WebSocket/WebSocket.swift @@ -1,5 +1,5 @@ -import Foundation import Combine +import Foundation public typealias WebSocketOnOpen = () -> Void public typealias WebSocketOnClose = (WebSocketCloseResult) -> Void @@ -27,8 +27,8 @@ public struct WebSocket { public var messagesPublisher: () -> AnyPublisher public init( - onOpen: @escaping ((@escaping WebSocketOnOpen)) async -> Void = { _ in }, - onClose: @escaping ((@escaping WebSocketOnClose)) async -> Void = { _ in }, + onOpen: @escaping (@escaping WebSocketOnOpen) async -> Void = { _ in }, + onClose: @escaping (@escaping WebSocketOnClose) async -> Void = { _ in }, open: @escaping (TimeInterval?) async throws -> Void = { _ in }, close: @escaping (WebSocketCloseCode) async throws -> Void = { _ in }, send: @escaping (WebSocketMessage) async throws -> Void = { _ in }, diff --git a/Sources/WebSocket/WebSocketCloseResult.swift b/Sources/WebSocket/WebSocketCloseResult.swift index 458c6f4..048421d 100644 --- a/Sources/WebSocket/WebSocketCloseResult.swift +++ b/Sources/WebSocket/WebSocketCloseResult.swift @@ -5,4 +5,3 @@ public typealias WebSocketCloseResult = Result<(code: WebSocketCloseCode, reason internal let normalClosure: WebSocketCloseResult = .success((.normalClosure, nil)) internal let abnormalClosure: WebSocketCloseResult = .success((.abnormalClosure, nil)) internal let closureWithError: (Error) -> WebSocketCloseResult = { e in .failure(e) } - diff --git a/Tests/WebSocketTests/Server/WebSocketServer.swift b/Tests/WebSocketTests/Server/WebSocketServer.swift index 0d34995..735eb1e 100644 --- a/Tests/WebSocketTests/Server/WebSocketServer.swift +++ b/Tests/WebSocketTests/Server/WebSocketServer.swift @@ -1,6 +1,6 @@ +import Combine import Foundation import Network -import Combine import WebSocket enum WebSocketServerError: Error { @@ -119,10 +119,10 @@ private extension WebSocketServer { content: content, contentContext: context, isComplete: true, - completion: .contentProcessed({ [weak self] error in + completion: .contentProcessed { [weak self] error in guard let _ = error else { return } self?.closeConnection(connection) - }) + } ) } } @@ -138,9 +138,9 @@ private extension WebSocketServer { content: nil, contentContext: .finalMessage, isComplete: true, - completion: .contentProcessed({ _ in + completion: .contentProcessed { _ in connection.cancel() - }) + } ) } @@ -156,14 +156,14 @@ private extension WebSocketServer { self.connections.append(newConnection) func receive() { - newConnection.receiveMessage { [weak self] (data, context, _, error) in + newConnection.receiveMessage { [weak self] data, context, _, error in guard let self = self else { return } guard error == nil else { return self.closeConnection(newConnection) } guard let data = data, let context = context, let _metadata = context.protocolMetadata.first, - let metadata = _metadata as? NWProtocolWebSocket.Metadata + let metadata = _metadata as? NWProtocolWebSocket.Metadata else { return } switch metadata.opcode { diff --git a/Tests/WebSocketTests/SystemWebSocketTests.swift b/Tests/WebSocketTests/SystemWebSocketTests.swift index 76d8960..44ff10e 100644 --- a/Tests/WebSocketTests/SystemWebSocketTests.swift +++ b/Tests/WebSocketTests/SystemWebSocketTests.swift @@ -106,7 +106,7 @@ class SystemWebSocketTests: XCTestCase { secondCloseEx.isInverted = true let (server, client) = await makeServerAndClient( - onClose: { result in + onClose: { _ in closeCount += 1 if closeCount == 1 { firstCloseEx.fulfill() @@ -292,7 +292,8 @@ class SystemWebSocketTests: XCTestCase { XCTAssertEqual(expected, message) if let messageToSend = messagesToSend.first, - let messageToReceive = messagesToReceive.first { + let messageToReceive = messagesToReceive.first + { try await client.send(messageToSend) subject.send(.message(messageToReceive)) } else { @@ -322,7 +323,11 @@ private extension SystemWebSocketTests { ) async -> (WebSocketServer, SystemWebSocket) { let port = ports.removeFirst() let server = try! WebSocketServer(port: port, outputPublisher: subject) - let client = try! await SystemWebSocket(url: url(port), onOpen: onOpen, onClose: onClose) + let client = try! await SystemWebSocket( + url: url(port), + onOpen: onOpen, + onClose: onClose + ) return (server, client) } @@ -332,7 +337,11 @@ private extension SystemWebSocketTests { ) async -> (WebSocketServer, SystemWebSocket) { let port = ports.removeFirst() let server = try! WebSocketServer(port: 1, outputPublisher: empty) - let client = try! await SystemWebSocket(url: url(port), onOpen: onOpen, onClose: onClose) + let client = try! await SystemWebSocket( + url: url(port), + onOpen: onOpen, + onClose: onClose + ) return (server, client) } @@ -342,7 +351,11 @@ private extension SystemWebSocketTests { ) async -> (WebSocketServer, WebSocket) { let port = ports.removeFirst() let server = try! WebSocketServer(port: port, outputPublisher: subject) - let client = try! await SystemWebSocket(url: url(port), onOpen: onOpen, onClose: onClose) + let client = try! await SystemWebSocket( + url: url(port), + onOpen: onOpen, + onClose: onClose + ) return (server, try! await .system(client)) } } diff --git a/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift b/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift index c638c0e..7ea92f7 100644 --- a/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift +++ b/Tests/WebSocketTests/URLSessionWebSocketTaskCloseCodeTests.swift @@ -6,13 +6,15 @@ class URLSessionWebSocketTaskCloseCodeTests: XCTestCase { let urlSessionCloseCodes: [URLSessionWebSocketTask.CloseCode] = [ .invalid, .normalClosure, .goingAway, .protocolError, .unsupportedData, .noStatusReceived, .abnormalClosure, .invalidFramePayloadData, .policyViolation, - .messageTooBig, .mandatoryExtensionMissing, .internalServerError, .tlsHandshakeFailure, + .messageTooBig, .mandatoryExtensionMissing, .internalServerError, + .tlsHandshakeFailure, ] let closeCodes: [WebSocketCloseCode] = [ .invalid, .normalClosure, .goingAway, .protocolError, .unsupportedData, .noStatusReceived, .abnormalClosure, .invalidFramePayloadData, .policyViolation, - .messageTooBig, .mandatoryExtensionMissing, .internalServerError, .tlsHandshakeFailure, + .messageTooBig, .mandatoryExtensionMissing, .internalServerError, + .tlsHandshakeFailure, ] zip(urlSessionCloseCodes, closeCodes).forEach { urlSessionCloseCode, closeCode in From b6fac24ace78d2843015a6a1ea4bffc8ff48551a Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 00:15:26 +0200 Subject: [PATCH 07/11] Add CI --- .github/workflows/{swift.yml => ci.yml} | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) rename .github/workflows/{swift.yml => ci.yml} (77%) diff --git a/.github/workflows/swift.yml b/.github/workflows/ci.yml similarity index 77% rename from .github/workflows/swift.yml rename to .github/workflows/ci.yml index 05ffd87..2fbe620 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/ci.yml @@ -1,10 +1,10 @@ -name: Swift +name: Build and Test on: push: - branches: [ master ] + branches: [ main ] pull_request: - branches: [ master ] + branches: [ main ] jobs: build: @@ -17,4 +17,3 @@ jobs: run: swift build -v - name: Run tests run: swift test -v - From 4135c03fcfec8bdbf0e49ee582743d9dac0474f1 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 00:29:44 +0200 Subject: [PATCH 08/11] Update ci.yml --- .github/workflows/ci.yml | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2fbe620..519a14f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,19 +1,15 @@ -name: Build and Test - -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] - +name: Test +on: push jobs: - build: - name: Build and Test + test: + name: Test runs-on: macos-latest - steps: - - uses: actions/checkout@v2 - - name: Build - run: swift build -v - - name: Run tests - run: swift test -v + uses: actions/checkout@v2 + # Available environments: https://github.com/actions/virtual-environments/blob/main/images/macos/macos-12-Readme.md#xcode + - name: Switch Xcode to 13.3 + run: xcversion select 13.3 + - name: Resolve package dependencies + run: swift package resolve + - name: Test + run: swift test --skip-update From 5ad3e3598f7a4fcef1d8db12165d29f1b3489dd4 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 00:32:16 +0200 Subject: [PATCH 09/11] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 519a14f..ecaf70f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,5 +1,5 @@ name: Test -on: push +on: [push, pull_request] jobs: test: name: Test From 178110cd13888e7b0212842ba89fdf300ee51a69 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 00:36:55 +0200 Subject: [PATCH 10/11] Fix ci.yml --- .github/workflows/ci.yml | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ecaf70f..20bf00b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,15 +1,12 @@ -name: Test -on: [push, pull_request] +on: push + jobs: test: - name: Test runs-on: macos-latest + steps: - uses: actions/checkout@v2 + - uses: actions/checkout@v2 # Available environments: https://github.com/actions/virtual-environments/blob/main/images/macos/macos-12-Readme.md#xcode - - name: Switch Xcode to 13.3 - run: xcversion select 13.3 - - name: Resolve package dependencies - run: swift package resolve - - name: Test - run: swift test --skip-update + - run: xcversion select 13.3 + - run: swift package resolve + - run: swift test --skip-update From 3965f2b15f493977079e721e4267762aa51997a1 Mon Sep 17 00:00:00 2001 From: Anthony Drendel Date: Mon, 18 Apr 2022 00:39:50 +0200 Subject: [PATCH 11/11] Specify macos-12 in ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 20bf00b..bb81615 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ on: push jobs: test: - runs-on: macos-latest + runs-on: macos-12 steps: - uses: actions/checkout@v2