Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ where
}

consuming func consumeAndConclude<Return>(
body: (consuming Underlying) async throws -> Return
body: (consuming sending Underlying) async throws -> Return
) async throws -> (Return, FinalElement) {
let (result, trailers) = try await self.base.consumeAndConclude { [logger] reader in
let wrappedReader = RequestBodyAsyncReader(
Expand Down Expand Up @@ -166,7 +166,7 @@ where
}

consuming func produceAndConclude<Return>(
body: (consuming ResponseBodyAsyncWriter) async throws -> (Return, HTTPFields?)
body: (consuming sending ResponseBodyAsyncWriter) async throws -> (Return, HTTPFields?)
) async throws -> Return {
let logger = self.logger
return try await self.base.produceAndConclude { writer in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public protocol ConcludingAsyncReader<Underlying, FinalElement>: ~Copyable {
/// }
/// ```
consuming func consumeAndConclude<Return>(
body: (consuming Underlying) async throws -> Return
body: (consuming sending Underlying) async throws -> Return
) async throws -> (Return, FinalElement)
}

Expand Down Expand Up @@ -62,7 +62,7 @@ extension ConcludingAsyncReader where Self: ~Copyable {
/// }
/// ```
public consuming func consumeAndConclude(
body: (consuming Underlying) async throws -> Void
body: (consuming sending Underlying) async throws -> Void
) async throws -> FinalElement {
let (_, finalElement) = try await self.consumeAndConclude { reader in
try await body(reader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public protocol ConcludingAsyncWriter<Underlying, FinalElement>: ~Copyable {
/// }
/// ```
consuming func produceAndConclude<Return>(
body: (consuming Underlying) async throws -> (Return, FinalElement)
body: (consuming sending Underlying) async throws -> (Return, FinalElement)
) async throws -> Return
}

Expand All @@ -56,7 +56,7 @@ extension ConcludingAsyncWriter where Self: ~Copyable {
/// }
/// ```
public consuming func produceAndConclude(
body: (consuming Underlying) async throws -> FinalElement
body: (consuming sending Underlying) async throws -> FinalElement
) async throws {
try await self.produceAndConclude { writer in
((), try await body(writer))
Expand Down
58 changes: 42 additions & 16 deletions Sources/HTTPServer/HTTPRequestConcludingAsyncReader.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
public import HTTPTypes
import NIOCore
import NIOHTTPTypes
import Synchronization

/// A specialized reader for HTTP request bodies and trailers that manages the reading process
/// and captures the final trailer fields.
Expand All @@ -23,7 +24,7 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
public typealias ReadFailure = any Error

/// The HTTP trailer fields captured at the end of the request.
fileprivate var state: ReaderState?
fileprivate var state: ReaderState

/// The iterator that provides HTTP request parts from the underlying channel.
private var iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
Expand All @@ -32,9 +33,11 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
///
/// - Parameter iterator: The NIO async channel inbound stream iterator to use for reading request parts.
fileprivate init(
iterator: consuming sending NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator
iterator: consuming sending NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
readerState: ReaderState
) {
self.iterator = iterator
self.state = readerState
}

/// Reads a chunk of request body data.
Expand All @@ -53,18 +56,28 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
// TODO: Add ByteBuffer span interfaces
return try await body(Array(buffer: element).span)
case .end(let trailers):
self.state?.trailers = trailers
self.state?.finishedReading = true
self.state.wrapped.withLock { state in
state.trailers = trailers
state.finishedReading = true
}
return try await body(nil)
case .none:
return try await body(nil)
}
}
}

final class ReaderState {
var trailers: HTTPFields? = nil
var finishedReading: Bool = false
final class ReaderState: Sendable {
struct Wrapped {
var trailers: HTTPFields? = nil
var finishedReading: Bool = false
}

let wrapped: Mutex<Wrapped>

init() {
self.wrapped = .init(.init())
}
}

/// The underlying reader type for the HTTP request body.
Expand All @@ -76,10 +89,9 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
/// The type of errors that can occur during reading operations.
public typealias Failure = any Error

/// The internal reader that provides HTTP request parts from the underlying channel.
private var partsReader: RequestBodyAsyncReader
private var iterator: NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator?

fileprivate let readerState: ReaderState
internal var state: ReaderState

/// Initializes a new HTTP request body and trailers reader with the given NIO async channel iterator.
///
Expand All @@ -88,8 +100,8 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
iterator: consuming sending NIOAsyncChannelInboundStream<HTTPRequestPart>.AsyncIterator,
readerState: ReaderState
) {
self.partsReader = RequestBodyAsyncReader(iterator: iterator)
self.readerState = readerState
self.iterator = iterator
self.state = readerState
}

/// Processes the request body reading operation and captures the final trailer fields.
Expand Down Expand Up @@ -118,11 +130,16 @@ public struct HTTPRequestConcludingAsyncReader: ConcludingAsyncReader, ~Copyable
/// }
/// ```
public consuming func consumeAndConclude<Return>(
body: (consuming RequestBodyAsyncReader) async throws -> Return
body: (consuming sending RequestBodyAsyncReader) async throws -> Return
) async throws -> (Return, HTTPFields?) {
self.partsReader.state = self.readerState
let result = try await body(self.partsReader)
return (result, self.readerState.trailers)
if let iterator = self.iterator.sendingTake() {
let partsReader = RequestBodyAsyncReader(iterator: iterator, readerState: self.state)
let result = try await body(partsReader)
let trailers = self.state.wrapped.withLock { $0.trailers }
return (result, trailers)
} else {
fatalError("consumeAndConclude called more than once")
}
}
}

Expand All @@ -131,3 +148,12 @@ extension HTTPRequestConcludingAsyncReader: Sendable {}

@available(*, unavailable)
extension HTTPRequestConcludingAsyncReader.RequestBodyAsyncReader: Sendable {}


extension Optional {
mutating func sendingTake() -> sending Self {
let result = consume self
self = nil
return result
}
}
19 changes: 15 additions & 4 deletions Sources/HTTPServer/HTTPResponseConcludingAsyncWriter.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
public import HTTPTypes
import NIOCore
import NIOHTTPTypes
import Synchronization

/// A specialized writer for HTTP response bodies and trailers that manages the writing process
/// and the final trailer fields.
Expand Down Expand Up @@ -48,8 +49,16 @@ public struct HTTPResponseConcludingAsyncWriter: ConcludingAsyncWriter, ~Copyabl
}
}

final class WriterState {
var finishedWriting: Bool = false
final class WriterState: Sendable {
struct Wrapped {
var finishedWriting: Bool = false
}

let wrapped: Mutex<Wrapped>

init() {
self.wrapped = .init(.init())
}
}

/// The underlying writer type for the HTTP response body.
Expand Down Expand Up @@ -102,12 +111,14 @@ public struct HTTPResponseConcludingAsyncWriter: ConcludingAsyncWriter, ~Copyabl
/// }
/// ```
public consuming func produceAndConclude<Return>(
body: (consuming ResponseBodyAsyncWriter) async throws -> (Return, FinalElement)
body: (consuming sending ResponseBodyAsyncWriter) async throws -> (Return, FinalElement)
) async throws -> Return {
let responseBodyAsyncWriter = ResponseBodyAsyncWriter(writer: self.writer)
let (result, finalElement) = try await body(responseBodyAsyncWriter)
try await self.writer.write(.end(finalElement))
self.writerState.finishedWriting = true
self.writerState.wrapped.withLock { state in
state.finishedWriting = true
}
return result
}
}
Expand Down
2 changes: 1 addition & 1 deletion Sources/HTTPServer/HTTPResponseSender.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public import HTTPTypes
public struct HTTPResponseSender<ResponseWriter: ConcludingAsyncWriter & ~Copyable>: ~Copyable {
private let _sendResponse: (HTTPResponse) async throws -> ResponseWriter

package init(
public init(
_ sendResponse: @escaping (HTTPResponse) async throws -> ResponseWriter
) {
self._sendResponse = sendResponse
Expand Down
5 changes: 3 additions & 2 deletions Sources/HTTPServer/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import NIOPosix
import NIOSSL
import X509
import SwiftASN1
import Synchronization

/// A generic HTTP server that can handle incoming HTTP requests.
///
Expand Down Expand Up @@ -363,12 +364,12 @@ public final class Server<RequestHandler: HTTPServerRequestHandler> {
}
)
} catch {
if !readerState.finishedReading {
if !readerState.wrapped.withLock({ $0.finishedReading }) {
// TODO: do something - we didn't finish reading but we threw
// if h2 reset stream; if h1 try draining request?
fatalError("Didn't finish reading but threw.")
}
if !writerState.finishedWriting {
if !writerState.wrapped.withLock({ $0.finishedWriting }) {
// TODO: this means we didn't write a response end and we threw
// we need to do something, possibly just close the connection or
// reset the stream with the appropriate error.
Expand Down
Loading