diff --git a/Sources/Vapor/Response/Response.swift b/Sources/Vapor/Response/Response.swift index 727b308931..259bfa90c5 100644 --- a/Sources/Vapor/Response/Response.swift +++ b/Sources/Vapor/Response/Response.swift @@ -33,7 +33,7 @@ public final class Response: CustomStringConvertible { } internal enum Upgrader { - case webSocket(onUpgrade: (WebSocket) -> ()) + case webSocket(maxFrameSize: WebSocketMaxFrameSize, onUpgrade: (WebSocket) -> ()) } internal var upgrader: Upgrader? diff --git a/Sources/Vapor/Routing/RoutesBuilder+WebSocket.swift b/Sources/Vapor/Routing/RoutesBuilder+WebSocket.swift index d6eb8d2c4a..f8ae9ffdcd 100644 --- a/Sources/Vapor/Routing/RoutesBuilder+WebSocket.swift +++ b/Sources/Vapor/Routing/RoutesBuilder+WebSocket.swift @@ -1,12 +1,18 @@ +public enum WebSocketMaxFrameSize { + case `default` + case override(Int) +} + extension RoutesBuilder { @discardableResult public func webSocket( _ path: PathComponent..., + maxFrameSize: WebSocketMaxFrameSize = .`default`, onUpgrade: @escaping (Request, WebSocket) -> () ) -> Route { return self.on(.GET, path) { request -> Response in let res = Response(status: .switchingProtocols) - res.upgrader = .webSocket(onUpgrade: { ws in + res.upgrader = .webSocket(maxFrameSize: maxFrameSize, onUpgrade: { ws in onUpgrade(request, ws) }) return res diff --git a/Sources/Vapor/Server/HTTPServerUpgradeHandler.swift b/Sources/Vapor/Server/HTTPServerUpgradeHandler.swift index 2ea7713933..0e0298bccb 100644 --- a/Sources/Vapor/Server/HTTPServerUpgradeHandler.swift +++ b/Sources/Vapor/Server/HTTPServerUpgradeHandler.swift @@ -51,8 +51,15 @@ final class HTTPServerUpgradeHandler: ChannelDuplexHandler, RemovableChannelHand self.upgradeState = .upgraded if res.status == .switchingProtocols, let upgrader = res.upgrader { switch upgrader { - case .webSocket(let onUpgrade): - let webSocketUpgrader = NIOWebSocketServerUpgrader(automaticErrorHandling: false, shouldUpgrade: { channel, _ in + case .webSocket(let maxFrameSize, let onUpgrade): + let maxFrameSizeBytes: Int + switch maxFrameSize { + case .`default`: + maxFrameSizeBytes = 1 << 14 + case .override(let bytes): + maxFrameSizeBytes = bytes + } + let webSocketUpgrader = NIOWebSocketServerUpgrader(maxFrameSize: maxFrameSizeBytes, automaticErrorHandling: false, shouldUpgrade: { channel, _ in return channel.eventLoop.makeSucceededFuture([:]) }, upgradePipelineHandler: { channel, req in return WebSocket.server(on: channel, onUpgrade: onUpgrade)