From d682e05fdb64c9f7da01af096a73cd11bb7ab755 Mon Sep 17 00:00:00 2001 From: Tim Condon <0xTim@users.noreply.github.com> Date: Tue, 7 Nov 2023 13:14:41 +0000 Subject: [PATCH] Make Async Request Body actually work (#3096) * Add an actual failing test for async request body * Fix event loop crash in async body * Fix logic bug with backpressure in async request body * Add backpressure test * Update Tests/AsyncTests/AsyncRequestTests.swift Co-authored-by: Gwynne Raskind * Try working out failing test * Tell the delegate we stopped * Add test to ensure we clean up correctly * Disable dodgy test --------- Co-authored-by: Gwynne Raskind --- NOTICES.txt | 9 + ...cy.swift => RequestBody+Concurrency.swift} | 14 +- .../Vapor/Request/Request+BodyStream.swift | 21 +- Tests/AsyncTests/AsyncRequestTests.swift | 230 +++++++++++++++++- 4 files changed, 253 insertions(+), 21 deletions(-) rename Sources/Vapor/Concurrency/{Request+Concurrency.swift => RequestBody+Concurrency.swift} (93%) diff --git a/NOTICES.txt b/NOTICES.txt index 684c812166..c1fcff76c8 100644 --- a/NOTICES.txt +++ b/NOTICES.txt @@ -19,3 +19,12 @@ from Swift Metrics. * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/apple/swift-metrics + +This product contains an implementation of AsyncLazySequence taken from Async +HTTP Client + + * LICENSE (Apache License 2.0): + * https://www.apache.org/licenses/LICENSE-2.0 + * HOMEPAGE: + * https://github.com/swift-server/async-http-client + diff --git a/Sources/Vapor/Concurrency/Request+Concurrency.swift b/Sources/Vapor/Concurrency/RequestBody+Concurrency.swift similarity index 93% rename from Sources/Vapor/Concurrency/Request+Concurrency.swift rename to Sources/Vapor/Concurrency/RequestBody+Concurrency.swift index 892a187893..f7b392a3c4 100644 --- a/Sources/Vapor/Concurrency/Request+Concurrency.swift +++ b/Sources/Vapor/Concurrency/RequestBody+Concurrency.swift @@ -1,4 +1,3 @@ -#if compiler(>=5.7) import NIOCore import NIOConcurrencyHelpers @@ -13,11 +12,12 @@ extension Request.Body { /// in `Request.Body/makeAsyncIterator()` method. fileprivate final class AsyncSequenceDelegate: @unchecked Sendable, NIOAsyncSequenceProducerDelegate { private enum State { + case notCalledYet case noSignalReceived case waitingForSignalFromConsumer(EventLoopPromise) } - private var _state: State = .noSignalReceived + private var _state: State = .notCalledYet private let eventLoop: any EventLoop init(eventLoop: any EventLoop) { @@ -27,6 +27,9 @@ extension Request.Body { private func produceMore0() { self.eventLoop.preconditionInEventLoop() switch self._state { + case .notCalledYet: + // We can just return here to sign to the producer that we want more data + break case .noSignalReceived: preconditionFailure() case .waitingForSignalFromConsumer(let promise): @@ -38,6 +41,9 @@ extension Request.Body { private func didTerminate0() { self.eventLoop.preconditionInEventLoop() switch self._state { + case .notCalledYet: + // Means didn't hit the backpressure limits, so just return + break case .noSignalReceived: // we will inform the producer, since the next write will fail. break @@ -50,7 +56,7 @@ extension Request.Body { func registerBackpressurePromise(_ promise: EventLoopPromise) { self.eventLoop.preconditionInEventLoop() switch self._state { - case .noSignalReceived: + case .noSignalReceived, .notCalledYet: self._state = .waitingForSignalFromConsumer(promise) case .waitingForSignalFromConsumer: preconditionFailure() @@ -140,6 +146,7 @@ extension Request.Body: AsyncSequence { // The consumer dropped the sequence. // Inform the producer that we don't want more data // by returning an error in the future. + delegate.didTerminate() return request.eventLoop.makeFailedFuture(CancellationError()) case .stopProducing: // The consumer is too slow. @@ -166,4 +173,3 @@ extension Request.Body: AsyncSequence { return AsyncIterator(underlying: producer.sequence.makeAsyncIterator()) } } -#endif diff --git a/Sources/Vapor/Request/Request+BodyStream.swift b/Sources/Vapor/Request/Request+BodyStream.swift index c7e844c58d..204dab202a 100644 --- a/Sources/Vapor/Request/Request+BodyStream.swift +++ b/Sources/Vapor/Request/Request+BodyStream.swift @@ -27,8 +27,17 @@ extension Request { self.allocator = byteBufferAllocator } - /// `read(_:)` **must** be called when on an `EventLoop` - func read(_ handler: @escaping (BodyStreamResult, EventLoopPromise?) -> ()) { + func read(_ handler: @escaping @Sendable (BodyStreamResult, EventLoopPromise?) -> ()) { + if self.eventLoop.inEventLoop { + read0(handler) + } else { + self.eventLoop.execute { + self.read0(handler) + } + } + } + + func read0(_ handler: @escaping @Sendable (BodyStreamResult, EventLoopPromise?) -> ()) { self.eventLoop.preconditionInEventLoop() self.handlerBuffer.value.handler = handler for (result, promise) in self.handlerBuffer.value.buffer { @@ -72,17 +81,17 @@ extension Request { // See https://github.com/vapor/vapor/issues/2906 return eventLoop.flatSubmit { let promise = eventLoop.makePromise(of: ByteBuffer.self) - var data = self.allocator.buffer(capacity: 0) + let data = NIOLoopBoundBox(self.allocator.buffer(capacity: 0), eventLoop: eventLoop) self.read { chunk, next in switch chunk { case .buffer(var buffer): - if let max = max, data.readableBytes + buffer.readableBytes >= max { + if let max = max, data.value.readableBytes + buffer.readableBytes >= max { promise.fail(Abort(.payloadTooLarge)) } else { - data.writeBuffer(&buffer) + data.value.writeBuffer(&buffer) } case .error(let error): promise.fail(error) - case .end: promise.succeed(data) + case .end: promise.succeed(data.value) } next?.succeed(()) } diff --git a/Tests/AsyncTests/AsyncRequestTests.swift b/Tests/AsyncTests/AsyncRequestTests.swift index a372818e57..1fbb978658 100644 --- a/Tests/AsyncTests/AsyncRequestTests.swift +++ b/Tests/AsyncTests/AsyncRequestTests.swift @@ -1,8 +1,10 @@ -#if compiler(>=5.7) && canImport(_Concurrency) import XCTVapor import XCTest import Vapor import NIOCore +import AsyncHTTPClient +import Atomics +import NIOConcurrencyHelpers fileprivate extension String { static func randomDigits(length: Int = 999) -> String { @@ -16,9 +18,22 @@ fileprivate extension String { final class AsyncRequestTests: XCTestCase { - func testStreamingRequest() throws { - let app = Application(.testing) - defer { app.shutdown() } + var app: Application! + var eventLoopGroup: EventLoopGroup! + + override func setUp() async throws { + eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 4) + app = Application(.testing, .shared(eventLoopGroup)) + } + + override func tearDown() async throws { + app.shutdown() + try await eventLoopGroup.shutdownGracefully() + } + + func testStreamingRequest() async throws { + app.http.server.configuration.hostname = "127.0.0.1" + app.http.server.configuration.port = 0 let testValue = String.randomDigits() @@ -33,13 +48,206 @@ final class AsyncRequestTests: XCTestCase { return string } - try app.testable().test(.POST, "/stream", beforeRequest: { req in - req.body = ByteBuffer(string: testValue) - }) { res in - XCTAssertEqual(res.status, .ok) - let returnedString = try XCTUnwrap(try res.content.decode(String.self)) - XCTAssertEqual(testValue, returnedString) + app.environment.arguments = ["serve"] + XCTAssertNoThrow(try app.start()) + + XCTAssertNotNil(app.http.server.shared.localAddress) + guard let localAddress = app.http.server.shared.localAddress, + let ip = localAddress.ipAddress, + let port = localAddress.port else { + return XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)") + } + + var request = HTTPClientRequest(url: "http://\(ip):\(port)/stream") + request.method = .POST + request.body = .stream(testValue.utf8.async, length: .unknown) + + let response: HTTPClientResponse = try await app.http.client.shared.execute(request, timeout: .seconds(5)) + XCTAssertEqual(response.status, .ok) + let body = try await response.body.collect(upTo: 1024 * 1024) + XCTAssertEqual(body.string, testValue) + } + + func testStreamingRequestBodyCleansUp() async throws { + app.http.server.configuration.hostname = "127.0.0.1" + app.http.server.configuration.port = 0 + + let bytesTheServerRead = ManagedAtomic(0) + + app.on(.POST, "hello", body: .stream) { req async throws -> Response in + var bodyIterator = req.body.makeAsyncIterator() + let firstChunk = try await bodyIterator.next() + bytesTheServerRead.wrappingIncrement(by: firstChunk?.readableBytes ?? 0, ordering: .relaxed) + throw Abort(.internalServerError) + } + + app.environment.arguments = ["serve"] + XCTAssertNoThrow(try app.start()) + + XCTAssertNotNil(app.http.server.shared.localAddress) + guard let localAddress = app.http.server.shared.localAddress, + let ip = localAddress.ipAddress, + let port = localAddress.port else { + XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)") + return + } + + var oneMBBB = ByteBuffer(repeating: 0x41, count: 1024 * 1024) + let oneMB = try XCTUnwrap(oneMBBB.readData(length: oneMBBB.readableBytes)) + var request = HTTPClientRequest(url: "http://\(ip):\(port)/hello") + request.method = .POST + request.body = .stream(oneMB.async, length: .known(oneMB.count)) + let response = try await app.http.client.shared.execute(request, timeout: .seconds(5)) + + XCTAssertGreaterThan(bytesTheServerRead.load(ordering: .relaxed), 0) + XCTAssertEqual(response.status, .internalServerError) + } + + // TODO: Re-enable once it reliably works and doesn't cause issues with trying to shut the application down + // This may require some work in Vapor + func _testRequestBodyBackpressureWorksWithAsyncStreaming() async throws { + app.http.server.configuration.hostname = "127.0.0.1" + app.http.server.configuration.port = 0 + + let numberOfTimesTheServerGotOfferedBytes = ManagedAtomic(0) + let bytesTheServerSaw = ManagedAtomic(0) + let bytesTheClientSent = ManagedAtomic(0) + let serverSawEnd = ManagedAtomic(false) + let serverSawRequest = ManagedAtomic(false) + + let requestHandlerTask: NIOLockedValueBox?> = .init(nil) + + app.on(.POST, "hello", body: .stream) { req async throws -> Response in + requestHandlerTask.withLockedValue { + $0 = Task { + XCTAssertTrue(serverSawRequest.compareExchange(expected: false, desired: true, ordering: .relaxed).exchanged) + var bodyIterator = req.body.makeAsyncIterator() + let firstChunk = try await bodyIterator.next() // read only first chunk + numberOfTimesTheServerGotOfferedBytes.wrappingIncrement(ordering: .relaxed) + bytesTheServerSaw.wrappingIncrement(by: firstChunk?.readableBytes ?? 0, ordering: .relaxed) + defer { + _ = bodyIterator // make sure to not prematurely cancelling the sequence + } + try await Task.sleep(nanoseconds: 10_000_000_000) // wait "forever" + serverSawEnd.store(true, ordering: .relaxed) + return Response(status: .ok) + } + } + + do { + let task = requestHandlerTask.withLockedValue { $0 } + return try await task!.value + } catch { + throw Abort(.internalServerError) + } + } + + app.environment.arguments = ["serve"] + XCTAssertNoThrow(try app.start()) + + XCTAssertNotNil(app.http.server.shared.localAddress) + guard let localAddress = app.http.server.shared.localAddress, + let ip = localAddress.ipAddress, + let port = localAddress.port else { + XCTFail("couldn't get ip/port from \(app.http.server.shared.localAddress.debugDescription)") + return + } + + final class ResponseDelegate: HTTPClientResponseDelegate { + typealias Response = Void + + private let bytesTheClientSent: ManagedAtomic + + init(bytesTheClientSent: ManagedAtomic) { + self.bytesTheClientSent = bytesTheClientSent + } + + func didFinishRequest(task: HTTPClient.Task) throws -> Response { + return () + } + + func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { + self.bytesTheClientSent.wrappingIncrement(by: part.readableBytes, ordering: .relaxed) + } + } + + let tenMB = ByteBuffer(repeating: 0x41, count: 10 * 1024 * 1024) + let request = try! HTTPClient.Request(url: "http://\(ip):\(port)/hello", + method: .POST, + headers: [:], + body: .byteBuffer(tenMB)) + let delegate = ResponseDelegate(bytesTheClientSent: bytesTheClientSent) + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + XCTAssertThrowsError(try httpClient.execute(request: request, + delegate: delegate, + deadline: .now() + .milliseconds(500)).wait()) { error in + if let error = error as? HTTPClientError { + XCTAssert(error == .readTimeout || error == .deadlineExceeded) + } else { + XCTFail("unexpected error: \(error)") + } + } + + XCTAssertEqual(1, numberOfTimesTheServerGotOfferedBytes.load(ordering: .relaxed)) + XCTAssertGreaterThan(tenMB.readableBytes, bytesTheServerSaw.load(ordering: .relaxed)) + XCTAssertGreaterThan(tenMB.readableBytes, bytesTheClientSent.load(ordering: .relaxed)) + XCTAssertEqual(0, bytesTheClientSent.load(ordering: .relaxed)) // We'd only see this if we sent the full 10 MB. + XCTAssertFalse(serverSawEnd.load(ordering: .relaxed)) + XCTAssertTrue(serverSawRequest.load(ordering: .relaxed)) + + requestHandlerTask.withLockedValue { $0?.cancel() } + try await httpClient.shutdown() + } +} + +// This was taken from AsyncHTTPClients's AsyncRequestTests.swift code. +// The license for the original work is reproduced below. See NOTICES.txt for +// more. + +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +struct AsyncLazySequence: AsyncSequence { + typealias Element = Base.Element + struct AsyncIterator: AsyncIteratorProtocol { + var iterator: Base.Iterator + init(iterator: Base.Iterator) { + self.iterator = iterator } + + mutating func next() async throws -> Base.Element? { + self.iterator.next() + } + } + + var base: Base + + init(base: Base) { + self.base = base + } + + func makeAsyncIterator() -> AsyncIterator { + .init(iterator: self.base.makeIterator()) + } +} + +extension AsyncLazySequence: Sendable where Base: Sendable {} +extension AsyncLazySequence.AsyncIterator: Sendable where Base.Iterator: Sendable {} + +extension Sequence { + /// Turns `self` into an `AsyncSequence` by vending each element of `self` asynchronously. + var async: AsyncLazySequence { + .init(base: self) } } -#endif