Skip to content

Commit

Permalink
Provide option to implement shouldUpgrade on webSockets (#2487)
Browse files Browse the repository at this point in the history
* Provide option to implement shouldUpgrade on ws

* Apply feedback for better readability

* Add API docs

* Add tests

* Fix code style

Co-authored-by: Siemen Sikkema <siemensikkema@users.noreply.github.com>

Co-authored-by: Siemen Sikkema <siemensikkema@users.noreply.github.com>
  • Loading branch information
code28 and siemensikkema committed Oct 27, 2020
1 parent fb59cb5 commit 1d1fcf3
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 6 deletions.
6 changes: 3 additions & 3 deletions Sources/Vapor/HTTP/Server/HTTPServerUpgradeHandler.swift
Expand Up @@ -51,9 +51,9 @@ final class HTTPServerUpgradeHandler: ChannelDuplexHandler, RemovableChannelHand
self.upgradeState = .upgraded
if res.status == .switchingProtocols, let upgrader = res.upgrader {
switch upgrader {
case .webSocket(let maxFrameSize, let onUpgrade):
let webSocketUpgrader = NIOWebSocketServerUpgrader(maxFrameSize: maxFrameSize.value, automaticErrorHandling: false, shouldUpgrade: { channel, _ in
return channel.eventLoop.makeSucceededFuture([:])
case .webSocket(let maxFrameSize, let shouldUpgrade, let onUpgrade):
let webSocketUpgrader = NIOWebSocketServerUpgrader(maxFrameSize: maxFrameSize.value, automaticErrorHandling: false, shouldUpgrade: { _, _ in
return shouldUpgrade()
}, upgradePipelineHandler: { channel, req in
return WebSocket.server(on: channel, onUpgrade: onUpgrade)
})
Expand Down
2 changes: 1 addition & 1 deletion Sources/Vapor/Response/Response.swift
Expand Up @@ -36,7 +36,7 @@ public final class Response: CustomStringConvertible {
var forHeadRequest: Bool

internal enum Upgrader {
case webSocket(maxFrameSize: WebSocketMaxFrameSize, onUpgrade: (WebSocket) -> ())
case webSocket(maxFrameSize: WebSocketMaxFrameSize, shouldUpgrade: (() -> EventLoopFuture<HTTPHeaders?>), onUpgrade: (WebSocket) -> ())
}

internal var upgrader: Upgrader?
Expand Down
30 changes: 28 additions & 2 deletions Sources/Vapor/Routing/RoutesBuilder+WebSocket.swift
Expand Up @@ -11,24 +11,50 @@ public struct WebSocketMaxFrameSize: ExpressibleByIntegerLiteral {
}

extension RoutesBuilder {
/// Adds a route for opening a web socket connection
/// - parameters:
/// - path: Path components separated by commas.
/// - maxFrameSize: The maximum allowed frame size. See `NIOWebSocketServerUpgrader`.
/// - shouldUpgrade: Closure to apply before upgrade to web socket happens.
/// Returns additional `HTTPHeaders` for response, `nil` to deny upgrading.
/// See `NIOWebSocketServerUpgrader`.
/// - onUpgrade: Closure to apply after web socket is upgraded successfully.
/// - returns: `Route` instance for newly created web socket endpoint
@discardableResult
public func webSocket(
_ path: PathComponent...,
maxFrameSize: WebSocketMaxFrameSize = .`default`,
shouldUpgrade: @escaping ((Request) -> EventLoopFuture<HTTPHeaders?>) = {
$0.eventLoop.makeSucceededFuture([:])
},
onUpgrade: @escaping (Request, WebSocket) -> ()
) -> Route {
return self.webSocket(path, maxFrameSize: maxFrameSize, onUpgrade: onUpgrade)
return self.webSocket(path, maxFrameSize: maxFrameSize, shouldUpgrade: shouldUpgrade, onUpgrade: onUpgrade)
}

/// Adds a route for opening a web socket connection
/// - parameters:
/// - path: Array of path components.
/// - maxFrameSize: The maximum allowed frame size. See `NIOWebSocketServerUpgrader`.
/// - shouldUpgrade: Closure to apply before upgrade to web socket happens.
/// Returns additional `HTTPHeaders` for response, `nil` to deny upgrading.
/// See `NIOWebSocketServerUpgrader`.
/// - onUpgrade: Closure to apply after web socket is upgraded successfully.
/// - returns: `Route` instance for newly created web socket endpoint
@discardableResult
public func webSocket(
_ path: [PathComponent],
maxFrameSize: WebSocketMaxFrameSize = .`default`,
shouldUpgrade: @escaping ((Request) -> EventLoopFuture<HTTPHeaders?>) = {
$0.eventLoop.makeSucceededFuture([:])
},
onUpgrade: @escaping (Request, WebSocket) -> ()
) -> Route {
return self.on(.GET, path) { request -> Response in
let res = Response(status: .switchingProtocols)
res.upgrader = .webSocket(maxFrameSize: maxFrameSize, onUpgrade: { ws in
res.upgrader = .webSocket(maxFrameSize: maxFrameSize, shouldUpgrade: {
shouldUpgrade(request)
}, onUpgrade: { ws in
onUpgrade(request, ws)
})
return res
Expand Down
21 changes: 21 additions & 0 deletions Tests/VaporTests/RouteTests.swift
Expand Up @@ -351,4 +351,25 @@ final class RouteTests: XCTestCase {
XCTAssertEqual(res.status, .ok)
}
}

func testWebsocketUpgrade() throws {
let app = Application(.testing)
defer { app.shutdown() }

let testMarkerHeaderKey = "TestMarker"
let testMarkerHeaderValue = "addedInShouldUpgrade"

app.routes.webSocket("customshouldupgrade", shouldUpgrade: { req in
return req.eventLoop.future([testMarkerHeaderKey: testMarkerHeaderValue])
}, onUpgrade: { _, _ in })

try app.testable(method: .running).test(.GET, "customshouldupgrade", beforeRequest: { req in
req.headers.replaceOrAdd(name: HTTPHeaders.Name.secWebSocketVersion, value: "13")
req.headers.replaceOrAdd(name: HTTPHeaders.Name.secWebSocketKey, value: "zyFJtLIpI2ASsmMHJ4Cf0A==")
req.headers.replaceOrAdd(name: .connection, value: "Upgrade")
req.headers.replaceOrAdd(name: .upgrade, value: "websocket")
}) { res in
XCTAssertEqual(res.headers.first(name: testMarkerHeaderKey), testMarkerHeaderValue)
}
}
}

0 comments on commit 1d1fcf3

Please sign in to comment.