Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Server shutdown improvements #2472

Merged
merged 1 commit into from Aug 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 29 additions & 0 deletions Sources/Vapor/HTTP/Headers/HTTPHeaders+Connection.swift
@@ -0,0 +1,29 @@
extension HTTPHeaders {
public struct Connection: ExpressibleByStringLiteral, Equatable {
public static let close: Self = "close"
public static let keepAlive: Self = "keep-alive"

public let value: String

public init(value: String) {
self.value = value
}

public init(stringLiteral value: String) {
self.init(value: value)
}
}

public var connection: Connection? {
get {
self.first(name: .connection).flatMap(Connection.init(value:))
}
set {
if let value = newValue {
self.replaceOrAdd(name: .connection, value: value.value)
} else {
self.remove(name: .connection)
}
}
}
}
4 changes: 2 additions & 2 deletions Sources/Vapor/HTTP/Server/HTTPServer.swift
Expand Up @@ -361,7 +361,7 @@ extension ChannelPipeline {
handlers.append(serverResEncoder)

// add server request -> response delegate
let handler = HTTPServerHandler(responder: responder)
let handler = HTTPServerHandler(responder: responder, logger: application.logger)
handlers.append(handler)

return self.addHandlers(handlers).flatMap {
Expand Down Expand Up @@ -426,7 +426,7 @@ extension ChannelPipeline {
)
handlers.append(serverReqDecoder)
// add server request -> response delegate
let handler = HTTPServerHandler(responder: responder)
let handler = HTTPServerHandler(responder: responder, logger: application.logger)

// add HTTP upgrade handler
let upgrader = HTTPServerUpgradeHandler(
Expand Down
24 changes: 21 additions & 3 deletions Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Expand Up @@ -5,9 +5,13 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias OutboundOut = Response

let responder: Responder
let logger: Logger
var isShuttingDown: Bool

init(responder: Responder) {
init(responder: Responder, logger: Logger) {
self.responder = responder
self.logger = logger
self.isShuttingDown = false
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
Expand All @@ -34,12 +38,16 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
case 2:
context.write(self.wrapOutboundOut(response), promise: nil)
default:
response.headers.add(name: .connection, value: request.isKeepAlive ? "keep-alive" : "close")
let keepAlive = !self.isShuttingDown && request.isKeepAlive
if self.isShuttingDown {
self.logger.debug("In-flight request has completed")
}
response.headers.add(name: .connection, value: keepAlive ? "keep-alive" : "close")
let done = context.write(self.wrapOutboundOut(response))
done.whenComplete { result in
switch result {
case .success:
if !request.isKeepAlive {
if !keepAlive {
context.close(mode: .output, promise: nil)
}
case .failure(let error):
Expand All @@ -48,4 +56,14 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
}
}
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case is ChannelShouldQuiesceEvent:
self.logger.trace("HTTP handler will no longer respect keep-alive")
self.isShuttingDown = true
default:
self.logger.trace("Unhandled user event: \(event)")
}
}
}
19 changes: 16 additions & 3 deletions Sources/Vapor/HTTP/Server/HTTPServerRequestDecoder.swift
Expand Up @@ -17,13 +17,14 @@ final class HTTPServerRequestDecoder: ChannelInboundHandler, RemovableChannelHan
var requestState: RequestState
var bodyStreamState: HTTPBodyStreamState

private let logger: Logger
var logger: Logger {
self.application.logger
}
var application: Application

init(application: Application) {
self.application = application
self.requestState = .ready
self.logger = Logger(label: "codes.vapor.server")
self.bodyStreamState = .init()
}

Expand Down Expand Up @@ -179,7 +180,8 @@ final class HTTPServerRequestDecoder: ChannelInboundHandler, RemovableChannelHan
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event is HTTPServerResponseEncoder.ResponseEndSentEvent {
switch event {
case is HTTPServerResponseEncoder.ResponseEndSentEvent:
switch self.requestState {
case .streamingBody(let bodyStream):
// Response ended during request stream.
Expand All @@ -197,6 +199,17 @@ final class HTTPServerRequestDecoder: ChannelInboundHandler, RemovableChannelHan
// Response ended after request had been read.
break
}
case is ChannelShouldQuiesceEvent:
switch self.requestState {
case .ready:
self.logger.trace("Closing keep-alive HTTP connection since server is going away")
context.channel.close(mode: .all, promise: nil)
default:
self.logger.debug("A request is currently in-flight")
context.fireUserInboundEventTriggered(event)
}
default:
self.logger.trace("Unhandled user event: \(event)")
}
}
}
Expand Down
23 changes: 22 additions & 1 deletion Tests/VaporTests/ServerTests.swift
Expand Up @@ -277,7 +277,7 @@ final class ServerTests: XCTestCase {
request: request,
delegate: response
).wait()

XCTAssertEqual(context.server, ["foo", "bar", "baz"])
XCTAssertEqual(context.client, ["foo", "bar", "baz"])
}
Expand Down Expand Up @@ -315,6 +315,27 @@ final class ServerTests: XCTestCase {
XCTAssertEqual(b.status, .ok)
}

func testQuiesceKeepAliveConnections() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.get("hello") { req in
"world"
}

let port = 1337
app.http.server.configuration.port = port
try app.start()

let request = try HTTPClient.Request(
url: "http://localhost:\(port)/hello",
method: .GET,
headers: ["connection": "keep-alive"]
)
let a = try app.http.client.shared.execute(request: request).wait()
XCTAssertEqual(a.headers.connection, .keepAlive)
}

override class func setUp() {
XCTAssertTrue(isLoggingConfigured)
}
Expand Down
80 changes: 80 additions & 0 deletions Tests/VaporTests/WebSocketTests.swift
Expand Up @@ -76,4 +76,84 @@ final class WebSocketTests: XCTestCase {

try XCTAssertEqual(promise.futureResult.wait(), "foo")
}

func testLifecycleShutdown() throws {
let app = Application(.testing)
app.http.server.configuration.port = 1337

final class WebSocketManager: LifecycleHandler {
private let lock: Lock
private var connections: Set<WebSocket>

init() {
self.lock = .init()
self.connections = .init()
}

func track(_ ws: WebSocket) {
self.lock.lock()
defer { self.lock.unlock() }
self.connections.insert(ws)
ws.onClose.whenComplete { _ in
self.lock.lock()
defer { self.lock.unlock() }
self.connections.remove(ws)
}
}

func broadcast(_ message: String) {
self.lock.lock()
defer { self.lock.unlock() }
for ws in self.connections {
ws.send(message)
}
}

/// Closes all active WebSocket connections
func shutdown(_ app: Application) {
self.lock.lock()
defer { self.lock.unlock() }
app.logger.debug("Shutting down \(self.connections.count) WebSocket(s)")
try! EventLoopFuture<Void>.andAllSucceed(
self.connections.map { $0.close() } ,
on: app.eventLoopGroup.next()
).wait()
}
}

let webSockets = WebSocketManager()
app.lifecycle.use(webSockets)

app.webSocket("watcher") { req, ws in
webSockets.track(ws)
ws.send("hello")
}

try app.start()

let clientGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { try! clientGroup.syncShutdownGracefully() }
let connectPromise = app.eventLoopGroup.next().makePromise(of: WebSocket.self)
WebSocket.connect(to: "ws://localhost:1337/watcher", on: clientGroup) { ws in
connectPromise.succeed(ws)
}.cascadeFailure(to: connectPromise)

let ws = try connectPromise.futureResult.wait()
app.shutdown()
try ws.onClose.wait()
}

override class func setUp() {
XCTAssertTrue(isLoggingConfigured)
}
}

extension WebSocket: Hashable {
public static func == (lhs: WebSocket, rhs: WebSocket) -> Bool {
lhs === rhs
}

public func hash(into hasher: inout Hasher) {
ObjectIdentifier(self).hash(into: &hasher)
}
}