Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NIOLoopBound issues #3081

Merged
merged 8 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions Sources/Vapor/HTTP/BodyStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ public protocol BodyStreamWriter: Sendable {

extension BodyStreamWriter {
public func write(_ result: BodyStreamResult) -> EventLoopFuture<Void> {
// We need to ensure we're on the event loop here for write as there's
// no guarantee that users will be on the event loop
if self.eventLoop.inEventLoop {
return write0(result)
} else {
return self.eventLoop.flatSubmit {
self.write0(result)
}
}
}

private func write0(_ result: BodyStreamResult) -> EventLoopFuture<Void> {
let promise = self.eventLoop.makePromise(of: Void.self)
self.write(result, promise: promise)
return promise.futureResult
Expand Down
3 changes: 2 additions & 1 deletion Sources/Vapor/HTTP/Server/HTTPServerHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ final class HTTPServerHandler: ChannelInboundHandler, RemovableChannelHandler {
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let box = NIOLoopBound((context, self), eventLoop: context.eventLoop)
let request = self.unwrapInboundIn(data)
self.responder.respond(to: request).whenComplete { response in
// hop(to:) is required here to ensure we're on the correct event loop
self.responder.respond(to: request).hop(to: context.eventLoop).whenComplete { response in
let (context, handler) = box.value
handler.serialize(response, for: request, context: context)
}
Expand Down
129 changes: 117 additions & 12 deletions Tests/VaporTests/PipelineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@ import NIOEmbedded
import NIOCore

final class PipelineTests: XCTestCase {
var app: Application!

override func setUp() async throws {
app = Application(.testing)
}

override func tearDown() async throws {
app.shutdown()
}


func testEchoHandlers() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.on(.POST, "echo", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
request.body.drain { body in
Expand Down Expand Up @@ -59,9 +67,6 @@ final class PipelineTests: XCTestCase {
}

func testEOFFraming() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.on(.POST, "echo", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
request.body.drain { body in
Expand Down Expand Up @@ -89,9 +94,6 @@ final class PipelineTests: XCTestCase {
}

func testBadStreamLength() throws {
let app = Application(.testing)
defer { app.shutdown() }

app.on(.POST, "echo", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
writer.write(.buffer(.init(string: "a")), promise: nil)
Expand All @@ -117,9 +119,6 @@ final class PipelineTests: XCTestCase {
}

func testInvalidHttp() throws {
let app = Application(.testing)
defer { app.shutdown() }

let channel = EmbeddedChannel()
try channel.connect(to: .init(unixDomainSocketPath: "/foo")).wait()
try channel.pipeline.addVaporHTTP1Handlers(
Expand All @@ -140,6 +139,112 @@ final class PipelineTests: XCTestCase {
XCTAssertEqual(channel.isActive, false)
try XCTAssertNil(channel.readOutbound(as: ByteBuffer.self)?.string)
}

func testReturningResponseOnDifferentEventLoopDosentCrashLoopBoundBox() async throws {
struct ResponseThing: ResponseEncodable {
let eventLoop: EventLoop

func encodeResponse(for request: Vapor.Request) -> NIOCore.EventLoopFuture<Vapor.Response> {
let response = Response(status: .ok)
return eventLoop.future(response)
}
}

let eventLoop = app!.eventLoopGroup.next()
app.get("dont-crash") { req in
return ResponseThing(eventLoop: eventLoop)
}

try app.test(.GET, "dont-crash") { res in
XCTAssertEqual(res.status, .ok)
}

app.environment.arguments = ["serve"]
app.http.server.configuration.port = 0
try app.start()

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

let res = try await app.client.get("http://localhost:\(port)/dont-crash")
XCTAssertEqual(res.status, .ok)
}

func testReturningResponseFromMiddlewareOnDifferentEventLoopDosentCrashLoopBoundBox() async throws {
struct WrongEventLoopMiddleware: Middleware {
func respond(to request: Request, chainingTo next: Responder) -> EventLoopFuture<Response> {
next.respond(to: request).hop(to: request.application.eventLoopGroup.next())
}
}

app.grouped(WrongEventLoopMiddleware()).get("dont-crash") { req in
return "OK"
}

try app.test(.GET, "dont-crash") { res in
XCTAssertEqual(res.status, .ok)
}

app.environment.arguments = ["serve"]
app.http.server.configuration.port = 0
try app.start()

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

let res = try await app.client.get("http://localhost:\(port)/dont-crash")
XCTAssertEqual(res.status, .ok)
}

func testStreamingOffEventLoop() async throws {
let eventLoop = app.eventLoopGroup.next()
app.on(.POST, "stream", body: .stream) { request -> Response in
Response(body: .init(stream: { writer in
request.body.drain { body in
switch body {
case .buffer(let buffer):
return writer.write(.buffer(buffer)).hop(to: eventLoop)
case .error(let error):
return writer.write(.error(error)).hop(to: eventLoop)
case .end:
return writer.write(.end).hop(to: eventLoop)
}
}
}))
}

app.environment.arguments = ["serve"]
app.http.server.configuration.port = 0
try app.start()

XCTAssertNotNil(app.http.server.shared.localAddress)
guard let localAddress = app.http.server.shared.localAddress,
let port = localAddress.port else {
XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)")
return
}

struct ABody: Content {
let hello: String

init() {
self.hello = "hello"
}
}

let res = try await app.client.post("http://localhost:\(port)/stream", beforeSend: {
try $0.content.encode(ABody())
})
XCTAssertEqual(res.status, .ok)
}

override class func setUp() {
XCTAssert(isLoggingConfigured)
Expand Down