diff --git a/FlyingFox/Sources/WebSocket/WSCloseCode.swift b/FlyingFox/Sources/WebSocket/WSCloseCode.swift index 6ad3076..162365b 100644 --- a/FlyingFox/Sources/WebSocket/WSCloseCode.swift +++ b/FlyingFox/Sources/WebSocket/WSCloseCode.swift @@ -31,15 +31,17 @@ import Foundation -public struct WSCloseCode: RawRepresentable, Sendable, Hashable { - public var rawValue: UInt16 - - public init(rawValue: UInt16) { - self.rawValue = rawValue - } +public struct WSCloseCode: Sendable, Hashable { + public var code: UInt16 + public var reason: String public init(_ code: UInt16) { - self.rawValue = code + self.code = code + self.reason = "" + } + public init(_ code: UInt16, reason: String) { + self.code = code + self.reason = reason } } @@ -47,19 +49,19 @@ public extension WSCloseCode { // The following codes are based on: // https://developer.mozilla.org/en-US/docs/Web/API/CloseEvent/code - static let normalClosure = WSCloseCode(1000) - static let goingAway = WSCloseCode(1001) - static let protocolError = WSCloseCode(1002) - static let unsupportedData = WSCloseCode(1003) - static let noStatusReceived = WSCloseCode(1005) - static let abnormalClosure = WSCloseCode(1006) - static let invalidFramePayloadData = WSCloseCode(1007) - static let policyViolation = WSCloseCode(1008) - static let messageTooBig = WSCloseCode(1009) - static let mandatoryExtensionMissing = WSCloseCode(1010) - static let internalServerError = WSCloseCode(1011) - static let serviceRestart = WSCloseCode(1012) - static let tryAgainLater = WSCloseCode(1013) - static let badGateway = WSCloseCode(1014) - static let tlsHandshakeFailure = WSCloseCode(1015) + static let normalClosure = WSCloseCode(1000) + static let goingAway = WSCloseCode(1001, reason: "Going Away") + static let protocolError = WSCloseCode(1002, reason: "Protocol Error") + static let unsupportedData = WSCloseCode(1003, reason: "Unsupported Data") + static let noStatusReceived = WSCloseCode(1005, reason: "No Status Received") + static let abnormalClosure = WSCloseCode(1006, reason: "Abnormal Closure") + static let invalidFramePayload = WSCloseCode(1007, reason: "Invalid Frame Payload") + static let policyViolation = WSCloseCode(1008, reason: "Policy Violation") + static let messageTooBig = WSCloseCode(1009, reason: "Message Too Big") + static let mandatoryExtensionMissing = WSCloseCode(1010, reason: "Mandatory Extension Missing") + static let internalServerError = WSCloseCode(1011, reason: "Internal Server Error") + static let serviceRestart = WSCloseCode(1012, reason: "Service Restart") + static let tryAgainLater = WSCloseCode(1013, reason: "Try Again Later") + static let badGateway = WSCloseCode(1014, reason: "Bad Gateway") + static let tlsHandshakeFailure = WSCloseCode(1015, reason: "TLS Handshake Failure") } diff --git a/FlyingFox/Sources/WebSocket/WSFrame.swift b/FlyingFox/Sources/WebSocket/WSFrame.swift index f9e43ca..652a6f3 100644 --- a/FlyingFox/Sources/WebSocket/WSFrame.swift +++ b/FlyingFox/Sources/WebSocket/WSFrame.swift @@ -93,15 +93,14 @@ public struct WSFrame: Sendable, Hashable { public extension WSFrame { static func close(message: String = "", mask: Mask? = nil) -> Self { close( - code: message.isEmpty ? .normalClosure : .protocolError, - message: message, + code: message.isEmpty ? .normalClosure : WSCloseCode(WSCloseCode.protocolError.code, reason: message), mask: mask ) } - static func close(code: WSCloseCode, message: String, mask: Mask? = nil) -> Self { - var payload = Data([UInt8(code.rawValue >> 8), UInt8(code.rawValue & 0xFF)]) - if let data = message.data(using: .utf8) { + static func close(code: WSCloseCode, mask: Mask? = nil) -> Self { + var payload = Data([UInt8(code.code >> 8), UInt8(code.code & 0xFF)]) + if let data = code.reason.data(using: .utf8) { payload.append(contentsOf: data) } return WSFrame( diff --git a/FlyingFox/Sources/WebSocket/WSHandler.swift b/FlyingFox/Sources/WebSocket/WSHandler.swift index fd9ee6d..29e3928 100644 --- a/FlyingFox/Sources/WebSocket/WSHandler.swift +++ b/FlyingFox/Sources/WebSocket/WSHandler.swift @@ -119,10 +119,18 @@ public struct MessageFrameWSHandler: WSHandler { } } group.addTask { - for await message in messagesOut { - for frame in makeFrames(for: message) { - framesOut.yield(frame) + do { + for await message in messagesOut { + for frame in makeFrames(for: message) { + framesOut.yield(frame) + if frame.opcode == .close { + throw FrameError.closed(frame) + } + } } + framesOut.finish(throwing: nil) + } catch { + framesOut.finish(throwing: nil) } } await group.next()! @@ -140,23 +148,22 @@ public struct MessageFrameWSHandler: WSHandler { case .binary: return .data(frame.payload) case .close: - let (code, reason) = try makeCloseCode(from: frame.payload) - return .close(code: code, reason: reason) + return try .close(makeCloseCode(from: frame.payload)) default: return nil } } - func makeCloseCode(from payload: Data) throws -> (WSCloseCode, String) { + func makeCloseCode(from payload: Data) throws -> WSCloseCode { guard payload.count >= 2 else { - return (.noStatusReceived, "") + return .noStatusReceived } let statusCode = payload.withUnsafeBytes { $0.load(as: UInt16.self).bigEndian } guard let reason = String(data: payload.dropFirst(2), encoding: .utf8) else { throw FrameError.invalid("Invalid UTF8 Sequence") } - return (WSCloseCode(statusCode), reason) + return WSCloseCode(statusCode, reason: reason) } func makeResponseFrames(for frame: WSFrame) throws -> WSFrame? { @@ -178,8 +185,8 @@ public struct MessageFrameWSHandler: WSHandler { return Self.makeFrames(opcode: .text, payload: string.data(using: .utf8)!, size: frameSize) case let .data(data): return Self.makeFrames(opcode: .binary, payload: data, size: frameSize) - case let .close(code: code, reason: message): - return [WSFrame.close(code: code, message: message)] + case let .close(code): + return [WSFrame.close(code: code)] } } diff --git a/FlyingFox/Sources/WebSocket/WSMessage.swift b/FlyingFox/Sources/WebSocket/WSMessage.swift index 3c93baa..832a46f 100644 --- a/FlyingFox/Sources/WebSocket/WSMessage.swift +++ b/FlyingFox/Sources/WebSocket/WSMessage.swift @@ -34,7 +34,7 @@ import Foundation public enum WSMessage: @unchecked Sendable, Hashable { case text(String) case data(Data) - case close(code: WSCloseCode = .normalClosure, reason: String = "") + case close(WSCloseCode = .normalClosure) } public protocol WSMessageHandler: Sendable { diff --git a/FlyingFox/Tests/WebSocket/WSFrameTests.swift b/FlyingFox/Tests/WebSocket/WSFrameTests.swift index 9057780..dfd2a9b 100644 --- a/FlyingFox/Tests/WebSocket/WSFrameTests.swift +++ b/FlyingFox/Tests/WebSocket/WSFrameTests.swift @@ -70,7 +70,7 @@ struct WSFrameTests { ) ) #expect( - WSFrame.close(code: WSCloseCode(4999), message: "Err") == .make( + WSFrame.close(code: WSCloseCode(4999, reason: "Err")) == .make( fin: true, opcode: .close, mask: nil, @@ -78,7 +78,7 @@ struct WSFrameTests { ) ) #expect( - WSFrame.close(code: WSCloseCode(4999), message: "Err", mask: .mock) == .make( + WSFrame.close(code: WSCloseCode(4999, reason: "Err"), mask: .mock) == .make( fin: true, opcode: .close, mask: .mock, diff --git a/FlyingFox/Tests/WebSocket/WSHandlerTests.swift b/FlyingFox/Tests/WebSocket/WSHandlerTests.swift index 949bb5b..b40c737 100644 --- a/FlyingFox/Tests/WebSocket/WSHandlerTests.swift +++ b/FlyingFox/Tests/WebSocket/WSHandlerTests.swift @@ -63,11 +63,11 @@ struct WSHandlerTests { #expect( try handler.makeMessage(for: .make(fin: true, opcode: .close, payload: payload)) == - .close(code: WSCloseCode(4999), reason: "fish") + .close(WSCloseCode(4999, reason: "fish")) ) #expect( try handler.makeMessage(for: .make(fin: true, opcode: .close)) == - .close(code: .noStatusReceived, reason: "") + .close(.noStatusReceived) ) } diff --git a/FlyingFox/XCTests/WebSocket/WSFrameTests.swift b/FlyingFox/XCTests/WebSocket/WSFrameTests.swift index 42475f1..074681f 100644 --- a/FlyingFox/XCTests/WebSocket/WSFrameTests.swift +++ b/FlyingFox/XCTests/WebSocket/WSFrameTests.swift @@ -65,7 +65,7 @@ final class WSFrameTests: XCTestCase { payload: Data([0x03, 0xEA, .ascii("E"), .ascii("r"), .ascii("r")])) ) XCTAssertEqual( - WSFrame.close(code: WSCloseCode(4999), message: "Err"), + WSFrame.close(code: WSCloseCode(4999, reason: "Err")), .make( fin: true, opcode: .close, @@ -74,7 +74,7 @@ final class WSFrameTests: XCTestCase { ) ) XCTAssertEqual( - WSFrame.close(code: WSCloseCode(4999), message: "Err", mask: .mock), + WSFrame.close(code: WSCloseCode(4999, reason: "Err"), mask: .mock), .make( fin: true, opcode: .close, diff --git a/FlyingFox/XCTests/WebSocket/WSHandlerTests.swift b/FlyingFox/XCTests/WebSocket/WSHandlerTests.swift index a91674f..c754edc 100644 --- a/FlyingFox/XCTests/WebSocket/WSHandlerTests.swift +++ b/FlyingFox/XCTests/WebSocket/WSHandlerTests.swift @@ -65,11 +65,11 @@ final class WSHandlerTests: XCTestCase { XCTAssertEqual( try handler.makeMessage(for: .make(fin: true, opcode: .close, payload: payload)), - .close(code: WSCloseCode(4999), reason: "fish") + .close(WSCloseCode(4999, reason: "fish")) ) XCTAssertEqual( try handler.makeMessage(for: .make(fin: true, opcode: .close)), - .close(code: .noStatusReceived, reason: "") + .close(.noStatusReceived) ) }