From 4b2ac57067246a35d07037964ad509de90336f7b Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Tue, 8 Jul 2025 00:48:04 +0200 Subject: [PATCH 1/9] =?UTF-8?q?Implement=20`COPY=20=E2=80=A6=20FROM=20STDI?= =?UTF-8?q?N`=20queries?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This implements support for COPY operations using `COPY … FROM STDIN` queries for fast data transfer from the client to the backend. --- .../PostgresConnection+CopyFrom.swift | 211 +++++++++ .../ConnectionStateMachine.swift | 117 ++++- .../ExtendedQueryStateMachine.swift | 206 +++++++- .../New/Extensions/AnyErrorContinuation.swift | 13 + Sources/PostgresNIO/New/PSQLTask.swift | 17 + .../New/PostgresChannelHandler.swift | 64 ++- .../PSQLIntegrationTests.swift | 108 +++++ .../Extensions/PostgresFrontendMessage.swift | 122 +++++ .../New/PostgresConnectionTests.swift | 446 +++++++++++++++--- 9 files changed, 1218 insertions(+), 86 deletions(-) create mode 100644 Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift create mode 100644 Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift new file mode 100644 index 00000000..b654a899 --- /dev/null +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -0,0 +1,211 @@ +/// Handle to send data for a `COPY ... FROM STDIN` query to the backend. +public struct PostgresCopyFromWriter: Sendable { + /// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed. + /// + /// The `PostgresCopyFromWriter` should cancel the data transfer. + public struct CopyCancellationError: Error { + /// The error that the backend sent us which cancelled the data transfer. + /// + /// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before + /// new data is written by `write`. + public let underlyingError: PSQLError + } + + private let channelHandler: NIOLoopBound + private let eventLoop: any EventLoop + + init(handler: PostgresChannelHandler, eventLoop: any EventLoop) { + self.channelHandler = NIOLoopBound(handler, eventLoop: eventLoop) + self.eventLoop = eventLoop + } + + private func writeAssumingInEventLoop(_ byteBuffer: ByteBuffer, _ continuation: CheckedContinuation) { + precondition(eventLoop.inEventLoop) + let promise = eventLoop.makePromise(of: Void.self) + self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise) + promise.futureResult.map { + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyData(byteBuffer) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyData(byteBuffer) + } + } + }.whenComplete { result in + continuation.resume(with: result) + } + } + + /// Send data for a `COPY ... FROM STDIN` operation to the backend. + /// + /// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws + /// a `CopyCancellationError`. + public func write(_ byteBuffer: ByteBuffer) async throws { + // Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the + // `writeData` closure. It is likely that the user would forget to do so. + try Task.checkCancellation() + + // TODO: Listen for task cancellation while we are waiting for backpressure to clear. + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + writeAssumingInEventLoop(byteBuffer, continuation) + } else { + eventLoop.execute { + writeAssumingInEventLoop(byteBuffer, continuation) + } + } + } + } + + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to + /// the backend. + func done() async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyDone(continuation: continuation) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyDone(continuation: continuation) + } + } + } + } + + /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to + /// the backend. + func failed(error: any Error) async throws { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + // TODO: Is it OK to use string interpolation to construct an error description to be sent to the backend + // here? We could also use a generic description, it doesn't really matter since we throw the user's error + // in `copyFrom`. + if eventLoop.inEventLoop { + self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation) + } else { + eventLoop.execute { + self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation) + } + } + } + } +} + +/// Specifies the format in which data is transferred to the backend in a COPY operation. +public enum PostgresCopyFromFormat: Sendable { + /// Options that can be used to modify the `text` format of a COPY operation. + public struct TextOptions: Sendable { + /// The delimiter that separates columns in the data. + /// + /// See the `DELIMITER` option in Postgres's `COPY` command. + /// + /// Uses the default delimiter of the format + public var delimiter: UnicodeScalar? = nil + + public init() {} + } + + case text(TextOptions) +} + +/// Create a `COPY ... FROM STDIN` query based on the given parameters. +/// +/// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be +/// copied by the caller. +private func buildCopyFromQuery( + table: StaticString, + columns: [StaticString] = [], + format: PostgresCopyFromFormat +) -> PostgresQuery { + // TODO: Should we put the table and column names in quotes to make them case-sensitive? + var query = "COPY \(table)" + if !columns.isEmpty { + query += "(" + columns.map(\.description).joined(separator: ",") + ")" + } + query += " FROM STDIN" + var queryOptions: [String] = [] + switch format { + case .text(let options): + queryOptions.append("FORMAT text") + if let delimiter = options.delimiter { + // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection. + queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'") + } + } + precondition(!queryOptions.isEmpty) + query += " WITH (" + query += queryOptions.map { "\($0)" }.joined(separator: ",") + query += ")" + return "\(unescaped: query)" +} + +extension PostgresConnection { + /// Copy data into a table using a `COPY FROM STDIN` query. + /// + /// - Parameters: + /// - table: The name of the table into which to copy the data. + /// - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied. + /// - format: Options that specify the format of the data that is produced by `writeData`. + /// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the + /// writer provided by the closure to send data to the backend and return from the closure once all data is sent. + /// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown + /// by the `copyFrom` function. + /// + /// - Note: The table and column names are inserted into the SQL query verbatim. They are forced to be compile-time + /// specified to avoid runtime SQL injection attacks. + public func copyFrom( + table: StaticString, + columns: [StaticString] = [], + format: PostgresCopyFromFormat = .text(.init()), + logger: Logger, + file: String = #fileID, + line: Int = #line, + writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void + ) async throws { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.id)" + let writer: PostgresCopyFromWriter = try await withCheckedThrowingContinuation { continuation in + let context = ExtendedQueryContext( + copyFromQuery: buildCopyFromQuery(table: table, columns: columns, format: format), + triggerCopy: continuation, + logger: logger + ) + self.channel.write(HandlerTask.extendedQuery(context), promise: nil) + } + + do { + try await writeData(writer) + } catch { + // We need to send a `CopyFail` to the backend to put it out of copy mode. This will most likely throw, most + // notably for the following two reasons. In both of them, it's better to ignore the error thrown by + // `writer.failed` and instead throw the error from `writeData`: + // - We send `CopyFail` and the backend replies with an `ErrorResponse` that relays the `CopyFail` message. + // This took the backend out of copy mode but it's more informative to the user to see the error they + // threw instead of the one that got relayed back, so it's better to ignore the error here. + // - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts + // the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger + // a `Sync` that takes the backend out of copy mode. If `writeData` threw the `CopyCancellationError` + // from the `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it + // doesn't matter that we ignore the error here. If the user threw some other error, it's better to honor + // the user's error. + try? await writer.failed(error: error) + + if let error = error as? PostgresCopyFromWriter.CopyCancellationError { + // If we receive a `CopyCancellationError` that is with almost certain likelihood because + // `PostgresCopyFromWriter.write` threw it - otherwise the user must have saved a previous + // `PostgresCopyFromWriter` error, which is very unlikely. + // Throw the underlying error because that contains the error message that was sent by the backend and + // is most actionable by the user. + throw error.underlyingError + } else { + throw error + } + } + + // `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during + // the transfer of the last bit of data so that the user didn't call `PostgresCopyFromWriter.write` again, which + // would have checked the error state. In either of these cases, calling `writer.done` puts the backend out of + // copy mode, so we don't need to send another `CopyFail`. Thus, this must not be handled in the `do` block + // above. + try await writer.done() + } + +} diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 8560b948..decd0c1a 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -88,7 +88,27 @@ struct ConnectionStateMachine { case sendParseDescribeBindExecuteSync(PostgresQuery) case sendBindExecuteSync(PSQLExecuteStatement) case failQuery(EventLoopPromise, with: PSQLError, cleanupContext: CleanUpContext?) + /// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a + /// `Sync` message to the backend. + case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool, cleanupContext: CleanUpContext?) + /// Fail a query's execution by resuming the continuation with the given error and send a `Sync` message to the + /// backend. case succeedQuery(EventLoopPromise, with: QueryResult) + /// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend. + case succeedQueryContinuation(CheckedContinuation, sync: Bool) + + /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation. + /// + /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state + /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling + /// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`. + case triggerCopyData(CheckedContinuation) + + /// Send a `CopyDone` and `Sync` message to the backend. + case sendCopyDoneAndSync + + /// Send a `CopyFail` message to the backend with the given error message. + case sendCopyFail(message: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend @@ -107,6 +127,14 @@ struct ConnectionStateMachine { case failClose(CloseCommandContext, with: PSQLError, cleanupContext: CleanUpContext?) } + enum ChannelWritabilityChangedAction { + /// No action needs to be taken based on the writability change. + case none + + /// Resume the given continuation successfully. + case succeedPromise(EventLoopPromise) + } + private var state: State private let requireBackendKeyData: Bool private var taskQueue = CircularBuffer() @@ -587,6 +615,8 @@ struct ConnectionStateMachine { switch queryContext.query { case .executeStatement(_, let promise), .unnamed(_, let promise): return .failQuery(promise, with: psqlErrror, cleanupContext: nil) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(.copyFromWriter(triggerCopy), with: psqlErrror, sync: false, cleanupContext: nil) case .prepareStatement(_, _, _, let promise): return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil) } @@ -660,6 +690,16 @@ struct ConnectionStateMachine { preconditionFailure("Invalid state: \(self.state)") } } + + mutating func channelWritabilityChanged(isWritable: Bool) -> ChannelWritabilityChangedAction { + guard case .extendedQuery(var queryState, let connectionContext) = state else { + return .none + } + self.state = .modifying // avoid CoW + let action = queryState.channelWritabilityChanged(isWritable: isWritable) + self.state = .extendedQuery(queryState, connectionContext) + return action + } // MARK: - Running Queries - @@ -752,10 +792,55 @@ struct ConnectionStateMachine { return self.modify(with: action) } - mutating func copyInResponseReceived( - _ copyInResponse: PostgresBackendMessage.CopyInResponse - ) -> ConnectionAction { - return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + + self.state = .modifying // avoid CoW + let action = queryState.copyInResponseReceived(copyInResponse) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + + /// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data. + /// + /// The promise may be failed if the backend indicated that it can't handle any more data by sending an + /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer + /// should be aborted to avoid unnecessary work. + mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise) + self.state = .extendedQuery(queryState, connectionContext) + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + mutating func sendCopyDone(continuation: CheckedContinuation) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.sendCopyDone(continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + mutating func sendCopyFail(message: String, continuation: CheckedContinuation) -> ConnectionAction { + guard case .extendedQuery(var queryState, let connectionContext) = self.state else { + preconditionFailure("Copy mode is only supported for extended queries") + } + + self.state = .modifying // avoid CoW + let action = queryState.sendCopyFail(message: message, continuation: continuation) + self.state = .extendedQuery(queryState, connectionContext) + return self.modify(with: action) } mutating func emptyQueryResponseReceived() -> ConnectionAction { @@ -866,14 +951,21 @@ struct ConnectionStateMachine { .forwardRows, .forwardStreamComplete, .wait, - .read: + .read, + .triggerCopyData, + .sendCopyDoneAndSync, + .sendCopyFail, + .succeedQueryContinuation: preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") case .evaluateErrorAtConnectionLevel: return .closeConnectionAndCleanup(cleanupContext) - case .failQuery(let queryContext, with: let error): - return .failQuery(queryContext, with: error, cleanupContext: cleanupContext) + case .failQuery(let promise, with: let error): + return .failQuery(promise, with: error, cleanupContext: cleanupContext) + + case .failQueryContinuation(let continuation, with: let error, let sync): + return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext) case .forwardStreamError(let error, let read): return .forwardStreamError(error, read: read, cleanupContext: cleanupContext) @@ -1044,8 +1136,19 @@ extension ConnectionStateMachine { case .failQuery(let requestContext, with: let error): let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) return .failQuery(requestContext, with: error, cleanupContext: cleanupContext) + case .failQueryContinuation(let continuation, with: let error, let sync): + let cleanupContext = self.setErrorAndCreateCleanupContextIfNeeded(error) + return .failQueryContinuation(continuation, with: error, sync: sync, cleanupContext: cleanupContext) case .succeedQuery(let requestContext, with: let result): return .succeedQuery(requestContext, with: result) + case .succeedQueryContinuation(let continuation, let sync): + return .succeedQueryContinuation(continuation, sync: sync) + case .triggerCopyData(let triggerCopy): + return .triggerCopyData(triggerCopy) + case .sendCopyDoneAndSync: + return .sendCopyDoneAndSync + case .sendCopyFail(message: let message): + return .sendCopyFail(message: message) case .forwardRows(let buffer): return .forwardRows(buffer) case .forwardStreamComplete(let buffer, let commandTag): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 5708b6b9..8ecda6ab 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -2,6 +2,15 @@ import NIOCore struct ExtendedQueryStateMachine { + private enum CopyingDataState { + /// The write channel is ready to handle more data. + case readyToSend + + /// The write channel has backpressure. Once that is relieved, we should resume the attached continuation to + /// allow more data to be sent by the client. + case pendingBackpressureRelieve(EventLoopPromise) + } + private enum State { case initialized(ExtendedQueryContext) case messagesSent(ExtendedQueryContext) @@ -12,6 +21,19 @@ struct ExtendedQueryStateMachine { case noDataMessageReceived(ExtendedQueryContext) case emptyQueryResponseReceived + /// We are currently copying data to the backend using `CopyData` messages. + case copyingData(CopyingDataState) + + /// We copied data to the backend and are done with that, either by sending a `CopyDone` or `CopyFail` message. + /// We are now expecting a `CommandComplete` or `ErrorResponse`. + /// + /// Once that is received the continuation is resumed. + /// + /// `successful` identifies whether copying was finished with a `CopyDone` or a `CopyFail` message. This is + /// necessary because we send a `Sync` after `CopyDone` but only send the `Sync` for `CopyFail` once we receive + /// the `ErrorResponse` from the backend. + case copyingFinished(CheckedContinuation, successful: Bool) + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -32,13 +54,31 @@ struct ExtendedQueryStateMachine { // --- general actions case failQuery(EventLoopPromise, with: PSQLError) + /// Fail a query's execution by resuming the continuation with the given error. When `sync` is `true`, send a + /// `Sync` message to the backend. + case failQueryContinuation(AnyErrorContinuation, with: PSQLError, sync: Bool) case succeedQuery(EventLoopPromise, with: QueryResult) + /// Succeed the continuation with a void result. When `sync` is `true`, send a `Sync` message to the backend. + case succeedQueryContinuation(CheckedContinuation, sync: Bool) case evaluateErrorAtConnectionLevel(PSQLError) case succeedPreparedStatementCreation(EventLoopPromise, with: RowDescription?) case failPreparedStatementCreation(EventLoopPromise, with: PSQLError) + /// Trigger a data transfer returning a `PostgresCopyFromWriter` to the given continuation. + /// + /// Once the data transfer is triggered, it will send `CopyData` messages to the backend. After that the state + /// machine needs to be prodded again to send a `CopyDone` or `CopyFail` by calling + /// `PostgresChannelHandler.sendCopyDone` or `PostgresChannelHandler.sendCopyFail`. + case triggerCopyData(CheckedContinuation) + + /// Send a `CopyDone` and `Sync` message to the backend. + case sendCopyDoneAndSync + + /// Send a `CopyFail` message to the backend with the given error message. + case sendCopyFail(message: String) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -63,7 +103,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed(let query, _): + case .unnamed(let query, _), .copyFrom(let query, _): return self.avoidingStateMachineCoW { state -> Action in state = .messagesSent(queryContext) return .sendParseDescribeBindExecuteSync(query) @@ -108,10 +148,21 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: .queryCancelled) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(.copyFromWriter(triggerCopy), with: .queryCancelled, sync: false) + case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } + case .copyingData: + return .sendCopyFail(message: "Copy cancelled") + + case .copyingFinished: + // We already finished the copy and are awaiting the `CommandComplete` or `ErrorResponse` from it. There's + // nothing we can do to cancel that. + return .wait + case .streaming(let columns, var streamStateMachine): precondition(!self.isCancelled) self.isCancelled = true @@ -160,7 +211,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return self.avoidingStateMachineCoW { state -> Action in state = .noDataMessageReceived(queryContext) return .wait @@ -198,7 +249,7 @@ struct ExtendedQueryStateMachine { } switch queryContext.query { - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return .wait case .prepareStatement(_, _, _, let eventLoopPromise): @@ -219,6 +270,10 @@ struct ExtendedQueryStateMachine { case .prepareStatement: return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) + case .copyFrom: + // The COPY commands don't return row descriptions, so we should never be in the + // `rowDescriptionReceived` state. + return .evaluateErrorAtConnectionLevel(.unexpectedBackendMessage(.bindComplete)) } case .noDataMessageReceived(let queryContext): @@ -235,7 +290,9 @@ struct ExtendedQueryStateMachine { .streaming, .drain, .commandComplete, - .error: + .error, + .copyingData, + .copyingFinished: return self.setAndFireError(.unexpectedBackendMessage(.bindComplete)) case .modifying: @@ -274,7 +331,9 @@ struct ExtendedQueryStateMachine { .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, - .error: + .error, + .copyingData, + .copyingFinished: return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) case .modifying: preconditionFailure("Invalid state") @@ -291,10 +350,19 @@ struct ExtendedQueryStateMachine { let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } - + case .copyFrom: + // We expect to transition through `copyingData` to `copyingFinished` before receiving a + // `CommandCompleted` message for copy queries. + return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) case .prepareStatement: preconditionFailure("Invalid state: \(self.state)") } + + case .copyingFinished(let continuation, let successful): + return self.avoidingStateMachineCoW { state -> Action in + state = .commandComplete(commandTag: commandTag) + return .succeedQueryContinuation(continuation, sync: !successful) + } case .streaming(_, var demandStateMachine): return self.avoidingStateMachineCoW { state -> Action in @@ -315,17 +383,82 @@ struct ExtendedQueryStateMachine { .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, - .error: + .error, + .copyingData: return self.setAndFireError(.unexpectedBackendMessage(.commandComplete(commandTag))) case .modifying: preconditionFailure("Invalid state") } } - mutating func copyInResponseReceived( - _ copyInResponse: PostgresBackendMessage.CopyInResponse - ) -> Action { - return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + mutating func copyInResponseReceived(_ copyInResponse: PostgresBackendMessage.CopyInResponse) -> Action { + guard case .bindCompleteReceived(let queryContext) = self.state, + case .copyFrom(_, let triggerCopy) = queryContext.query else { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + return avoidingStateMachineCoW { state in + // We can assume that we have no backpressure here. Before sending data, `checkBackendCanReceiveCopyData` + // will be called, which checks if the channel to the backend is indeed writable. + state = .copyingData(.readyToSend) + return .triggerCopyData(triggerCopy) + } + } + + /// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data. + /// + /// The promise may be failed if the backend indicated that it can't handle any more data by sending an + /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer + /// should be aborted to avoid unnecessary work. + mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) { + if case .error(let error) = self.state { + // The backend sent us an ErrorResponse during the copy operation. Indicate to the client that it should + // abort the data transfer. + promise.fail(PostgresCopyFromWriter.CopyCancellationError(underlyingError: error)) + return + } + guard case .copyingData(.readyToSend) = self.state else { + preconditionFailure("Not ready to send data") + } + if channelIsWritable { + promise.succeed() + return + } + return avoidingStateMachineCoW { state in + state = .copyingData(.pendingBackpressureRelieve(promise)) + } + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + mutating func sendCopyDone(continuation: CheckedContinuation) -> Action { + if case .error(let error) = self.state { + // The backend sent us an ErrorResponse during the copy operation. We need to send a `Sync` to get out of + // copy mode and communicate the error to the user. There's no need for `CopyDone` anymore. + return .failQueryContinuation(.void(continuation), with: error, sync: true) + } + guard case .copyingData = self.state else { + preconditionFailure("Must be in copy mode to send CopyDone") + } + return avoidingStateMachineCoW { state in + state = .copyingFinished(continuation, successful: true) + return .sendCopyDoneAndSync + } + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + mutating func sendCopyFail(message: String, continuation: CheckedContinuation) -> Action { + if case .error(let error) = self.state { + // The backend sent us an ErrorResponse during the copy operation. We need to send a `Sync` to get out of + // copy mode and communicate the error to the user. There's no need for `CopyFail` anymore. + return .failQueryContinuation(.void(continuation), with: error, sync: true) + } + guard case .copyingData = self.state else { + preconditionFailure("Must be in copy mode to send CopyFail") + } + return avoidingStateMachineCoW { state in + state = .copyingFinished(continuation, successful: false) + return .sendCopyFail(message: message) + } + } mutating func emptyQueryResponseReceived() -> Action { @@ -342,7 +475,7 @@ struct ExtendedQueryStateMachine { return .succeedQuery(eventLoopPromise, with: result) } - case .prepareStatement(_, _, _, _): + case .prepareStatement, .copyFrom: return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) } } @@ -359,6 +492,8 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .rowDescriptionReceived, .noDataMessageReceived: return self.setAndFireError(error) + case .copyingData, .copyingFinished: + return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) case .commandComplete, .emptyQueryResponseReceived: @@ -409,7 +544,9 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived, .emptyQueryResponseReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: preconditionFailure("Requested to consume next row without anything going on.") case .commandComplete, .error: @@ -433,7 +570,9 @@ struct ExtendedQueryStateMachine { .noDataMessageReceived, .emptyQueryResponseReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: return .wait case .streaming(let columns, var demandStateMachine): @@ -460,7 +599,9 @@ struct ExtendedQueryStateMachine { .parameterDescriptionReceived, .noDataMessageReceived, .rowDescriptionReceived, - .bindCompleteReceived: + .bindCompleteReceived, + .copyingData, + .copyingFinished: return .read case .streaming(let columns, var demandStateMachine): precondition(!self.isCancelled) @@ -486,6 +627,16 @@ struct ExtendedQueryStateMachine { preconditionFailure("Invalid state") } } + + mutating func channelWritabilityChanged(isWritable: Bool) -> ConnectionStateMachine.ChannelWritabilityChangedAction { + guard case .copyingData(.pendingBackpressureRelieve(let promise)) = state else { + return .none + } + return self.avoidingStateMachineCoW { state in + state = .copyingData(.readyToSend) + return .succeedPromise(promise) + } + } // MARK: Private Methods @@ -505,11 +656,22 @@ struct ExtendedQueryStateMachine { switch context.query { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return .failQuery(eventLoopPromise, with: error) + case .copyFrom(_, let triggerCopy): + return .failQueryContinuation(.copyFromWriter(triggerCopy), with: error, sync: false) case .prepareStatement(_, _, _, let eventLoopPromise): return .failPreparedStatementCreation(eventLoopPromise, with: error) } } - + case .copyingData: + self.state = .error(error) + // Store the error. We expect the next chunk of data to be written almost immediately, which will call + // `checkBackendCanReceiveCopyData`, which handles the error. If the user is done writing data, we expect a + // `CopyDone` or `CopyFail` message soon, which also checks for the error case, so there's nothing that we + // need to actively do here. + return .wait + case .copyingFinished(let continuation, let successful): + self.state = .error(error) + return .failQueryContinuation(.void(continuation), with: error, sync: !successful) case .drain: self.state = .error(error) return .evaluateErrorAtConnectionLevel(error) @@ -542,11 +704,19 @@ struct ExtendedQueryStateMachine { switch context.query { case .prepareStatement: return true - case .unnamed, .executeStatement: + case .unnamed, .copyFrom, .executeStatement: return false } - case .initialized, .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, .bindCompleteReceived, .streaming, .drain: + case .initialized, + .messagesSent, + .parseCompleteReceived, + .parameterDescriptionReceived, + .bindCompleteReceived, + .streaming, + .drain, + .copyingData, + .copyingFinished: return false case .modifying: diff --git a/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift b/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift new file mode 100644 index 00000000..b141c928 --- /dev/null +++ b/Sources/PostgresNIO/New/Extensions/AnyErrorContinuation.swift @@ -0,0 +1,13 @@ +/// Enum that abstracts over continuations that have `any Error` as the failure type. Cases are expected to get added +/// for the success types that we care about. +enum AnyErrorContinuation { + case void(CheckedContinuation) + case copyFromWriter(CheckedContinuation) + + func resume(throwing error: any Error) { + switch self { + case .void(let continuation): continuation.resume(throwing: error) + case .copyFromWriter(let continuation): continuation.resume(throwing: error) + } + } +} diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 6106fd21..820e8a20 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -19,6 +19,8 @@ enum PSQLTask { switch extendedQueryContext.query { case .unnamed(_, let eventLoopPromise): eventLoopPromise.fail(error) + case .copyFrom(_, let triggerCopy): + triggerCopy.resume(throwing: error) case .executeStatement(_, let eventLoopPromise): eventLoopPromise.fail(error) case .prepareStatement(_, _, _, let eventLoopPromise): @@ -34,6 +36,12 @@ enum PSQLTask { final class ExtendedQueryContext: Sendable { enum Query { case unnamed(PostgresQuery, EventLoopPromise) + /// A `COPY ... FROM STDIN` query that copies data from the frontend into a table. + /// + /// When `triggerCopy` is resumed, the `PostgresConnection` that created this query should send data to the + /// backend via `CopyData` messages and finalize the data transfer by calling `sendCopyDone` or `sendCopyFail` + /// on the `PostgresChannelHandler`. + case copyFrom(PostgresQuery, triggerCopy: CheckedContinuation) case executeStatement(PSQLExecuteStatement, EventLoopPromise) case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise) } @@ -50,6 +58,15 @@ final class ExtendedQueryContext: Sendable { self.logger = logger } + init( + copyFromQuery query: PostgresQuery, + triggerCopy: CheckedContinuation, + logger: Logger + ) { + self.query = .copyFrom(query, triggerCopy: triggerCopy) + self.logger = logger + } + init( executeStatement: PSQLExecuteStatement, logger: Logger, diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index baf801e5..cafe0f21 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -171,10 +171,48 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.run(action, with: context) } + /// Succeed the promise when the channel to the backend is writable and the backend is ready to receive more data. + /// + /// The promise may be failed if the backend indicated that it can't handle any more data by sending an + /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer + /// should be aborted to avoid unnecessary work. + func checkBackendCanReceiveCopyData(promise: EventLoopPromise) { + self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext!.channel.isWritable, promise: promise) + } + + /// Send a `CopyData` message to the backend using the given data. + func sendCopyData(_ data: ByteBuffer) { + self.encoder.copyDataHeader(dataLength: UInt32(data.readableBytes)) + self.handlerContext!.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + self.handlerContext!.writeAndFlush(self.wrapOutboundOut(data), promise: nil) + } + + /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. + func sendCopyDone(continuation: CheckedContinuation) { + let action = self.state.sendCopyDone(continuation: continuation) + self.run(action, with: self.handlerContext!) + } + + /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. + func sendCopyFail(message: String, continuation: CheckedContinuation) { + let action = self.state.sendCopyFail(message: message, continuation: continuation) + self.run(action, with: self.handlerContext!) + } + func channelReadComplete(context: ChannelHandlerContext) { let action = self.state.channelReadComplete() self.run(action, with: context) } + + func channelWritabilityChanged(context: ChannelHandlerContext) { + let action = self.state.channelWritabilityChanged(isWritable: context.channel.isWritable) + switch action { + case .none: + break + case .succeedPromise(let promise): + promise.succeed() + } + } func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { self.logger.trace("User inbound event received", metadata: [ @@ -355,12 +393,36 @@ final class PostgresChannelHandler: ChannelDuplexHandler { self.sendParseDescribeBindExecuteAndSyncMessage(query: query, context: context) case .succeedQuery(let promise, with: let result): self.succeedQuery(promise, result: result, context: context) + case .succeedQueryContinuation(let continuation, let sync): + if sync { + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + continuation.resume() case .failQuery(let promise, with: let error, let cleanupContext): promise.fail(error) if let cleanupContext = cleanupContext { self.closeConnectionAndCleanup(cleanupContext, context: context) } - + case .failQueryContinuation(let continuation, with: let error, let sync, let cleanupContext): + if sync { + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + } + if let cleanupContext = cleanupContext { + self.closeConnectionAndCleanup(cleanupContext, context: context) + } + continuation.resume(throwing: error) + case .triggerCopyData(let triggerCopy): + let writer = PostgresCopyFromWriter(handler: self, eventLoop: eventLoop) + triggerCopy.resume(returning: writer) + case .sendCopyDoneAndSync: + self.encoder.copyDone() + self.encoder.sync() + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + case .sendCopyFail(message: let message): + self.encoder.copyFail(message: message) + context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) case .forwardRows(let rows): self.rowStream!.receive(rows) diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index d541899b..35581edb 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -379,4 +379,112 @@ final class IntegrationTests: XCTestCase { } } + func testCopyIntoFrom() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + var options = PostgresCopyFromFormat.TextOptions() + options.delimiter = "," + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], format: .text(options), logger: .psqlTest) { writer in + let records: [(id: Int, name: String)] = [ + (1, "Alice"), + (42, "Bob") + ] + for record in records { + var buffer = ByteBuffer() + buffer.writeString("\(record.id),\(record.name)\n") + try await writer.write(buffer) + } + } + let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) } + guard rows.count == 2 else { + XCTFail("Expected 2 columns, received \(rows.count)") + return + } + XCTAssertEqual(rows[0].0, 1) + XCTAssertEqual(rows[0].1, "Alice") + XCTAssertEqual(rows[1].0, 42) + XCTAssertEqual(rows[1].1, "Bob") + } + + func testCopyIntoFromIsTerminatedByThrowingErrorFromClosure() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + do { + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in + throw MyError() + } + XCTFail("Expected error to be thrown") + } catch { + XCTAssert(error is MyError, "Expected error of type MyError, got \(String(reflecting: error))") + } + } + + + func testCopyIntoFromHasBadFormat() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + do { + try await conn.copyFrom(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + } + XCTFail("Expected error to be thrown") + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") // invalid_text_representation + } + } + + func testSyntaxErrorInGeneratedQuery() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + do { + // Use some form of input that generates an invalid query, the exact manner of its invalidness doesn't matter + try await conn.copyFrom(table: "", logger: .psqlTest) { writer in + XCTFail("Did not expect to call writeData") + } + XCTFail("Expected error to be thrown") + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror + } + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 5fc8144b..d2dd5681 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -323,3 +323,125 @@ extension PostgresFrontendMessage { } } } + +/// Convenience accessors to get a specific case or `nil` if the enum is of a different case. +extension PostgresFrontendMessage { + var bind: Bind? { + guard case .bind(let bind) = self else { + return nil + } + return bind + } + + var cancel: Cancel? { + guard case .cancel(let cancel) = self else { + return nil + } + return cancel + } + + var copyData: CopyData? { + guard case .copyData(let copyData) = self else { + return nil + } + return copyData + } + + var copyDone: Void? { + guard case .copyDone = self else { + return nil + } + return () + } + + var copyFail: CopyFail? { + guard case .copyFail(let copyFail) = self else { + return nil + } + return copyFail + } + + var close: Close? { + guard case .close(let close) = self else { + return nil + } + return close + } + + var describe: Describe? { + guard case .describe(let describe) = self else { + return nil + } + return describe + } + + var execute: Execute? { + guard case .execute(let execute) = self else { + return nil + } + return execute + } + + var flush: Void? { + guard case .flush = self else { + return nil + } + return () + } + + var parse: Parse? { + guard case .parse(let parse) = self else { + return nil + } + return parse + } + + var password: Password? { + guard case .password(let password) = self else { + return nil + } + return password + } + + var saslInitialResponse: SASLInitialResponse? { + guard case .saslInitialResponse(let saslInitialResponse) = self else { + return nil + } + return saslInitialResponse + } + + var saslResponse: SASLResponse? { + guard case .saslResponse(let saslResponse) = self else { + return nil + } + return saslResponse + } + + var sslRequest: Void? { + guard case .sslRequest = self else { + return nil + } + return () + } + + var sync: Void? { + guard case .sync = self else { + return nil + } + return () + } + + var startup: Startup? { + guard case .startup(let startup) = self else { + return nil + } + return startup + } + + var terminate: Void? { + guard case .terminate = self else { + return nil + } + return () + } +} diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index d0f8e2b0..9c3dabfc 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -4,6 +4,9 @@ import NIOEmbedded import XCTest import Logging @testable import PostgresNIO +#if canImport(Synchronization) +import Synchronization +#endif class PostgresConnectionTests: XCTestCase { @@ -70,8 +73,8 @@ class PostgresConnectionTests: XCTestCase { }() async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) - let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + let message = try await channel.waitForPostgresFrontendMessage(\.startup) + XCTAssertEqual(message, .versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -95,10 +98,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -107,10 +107,7 @@ class PostgresConnectionTests: XCTestCase { let unlistenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -155,10 +152,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -168,10 +162,7 @@ class PostgresConnectionTests: XCTestCase { let unlistenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -204,10 +195,7 @@ class PostgresConnectionTests: XCTestCase { let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -267,8 +255,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) } - let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(terminate, .terminate) + try await channel.waitForPostgresFrontendMessage(\.terminate) try await channel.closeFuture.get() XCTAssertEqual(channel.isActive, false) @@ -283,7 +270,7 @@ class PostgresConnectionTests: XCTestCase { } } - func testCloseClosesImmediatly() async throws { + func testCloseClosesImmediately() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in @@ -638,6 +625,244 @@ class PostgresConnectionTests: XCTestCase { } } + func testCopyFromSucceeds() async throws { + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + } validateCopyRequest: { copyRequest in + XCTAssertEqual(copyRequest.parse.query, "COPY copy_table(id,name) FROM STDIN WITH (FORMAT text)") + XCTAssertEqual(copyRequest.bind.parameters, []) + } mockBackend: { channel in + let data = try await channel.waitForCopyData() + XCTAssertEqual(String(buffer: data.data), "1\tAlice\n") + XCTAssertEqual(data.result, .done) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + } + } + + + func testCopyFromWithOptions() async throws { + var options = PostgresCopyFromFormat.TextOptions() + options.delimiter = "," + try await assertCopyFrom(format: .text(options)) { writer in + try await writer.write(ByteBuffer(staticString: "1,Alice\n")) + } validateCopyRequest: { copyRequest in + XCTAssertEqual(copyRequest.parse.query, #"COPY copy_table(id,name) FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#) + XCTAssertEqual(copyRequest.bind.parameters, []) + } mockBackend: { channel in + let data = try await channel.waitForCopyData() + XCTAssertEqual(String(buffer: data.data), "1,Alice\n") + XCTAssertEqual(data.result, .done) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + } + } + + func testCopyFromWriterFails() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + try await assertCopyFrom { writer in + throw MyError() + } validateCopyFromError: { error in + XCTAssert(error is MyError, "Expected error of type MyError, got \(error)") + } mockBackend: { channel in + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.result, .failed(message: "My error")) + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: "COPY from stdin failed: My error", + .sqlState : "57014" // query_canceled + ]))) + } + } + + func testCopyFromBackendSendsErrorBeforeCopyDone() async throws { + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + } validateCopyFromError: { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + } mockBackend: { channel in + let copyDataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(copyDataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromBackendSendsErrorAfterCopyDone() async throws { + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + } validateCopyFromError: { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + } mockBackend: { channel in + _ = try await channel.waitForCopyData() + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + } + } + + func testCopyFromBackendSendsErrorBeforeUserThrowsUnrelatedErrorFromClosure() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + // If the user throws an error and we receive an error from the server, we should prefer throwing the user error + // from `copyFrom` since it's likely the more actionable for the user. + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + throw MyError() + } validateCopyFromError: { error in + XCTAssert(error is MyError, "Expected MyError, got \(error)") + } mockBackend: { channel in + let copyDataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(copyDataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromWriterThrowsErrorAfterBackendSentError() async throws { + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + do { + try await writer.write(ByteBuffer(staticString: "2\tBob\n")) + XCTFail("Expected error to be thrown") + } catch { + XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)") + throw error + } + } validateCopyFromError: { error in + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") + } mockBackend: { channel in + let dataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(dataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromCallerDoesNotRethrowCopyCancellationError() async throws { + struct MyError: Error, CustomStringConvertible { + var description: String { "My error" } + } + + // Stream to indicate that the backend did send an error + let (signalStream, signalContinuation) = AsyncStream.makeStream(of: Void.self) + + try await assertCopyFrom { writer in + try await writer.write(ByteBuffer(staticString: "1Alice\n")) + var iterator = signalStream.makeAsyncIterator() + await iterator.next() + do { + try await writer.write(ByteBuffer(staticString: "2\tBob\n")) + XCTFail("Expected error to be thrown") + } catch { + XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)") + throw MyError() + } + } validateCopyFromError: { error in + XCTAssert(error is MyError, "Expected MyError, got \(error)") + } mockBackend: { channel in + let dataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) + XCTAssertEqual(dataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"invalid input syntax for type integer: "1Alice""#, + .sqlState : "22P02" // invalid_text_representation + ]))) + signalContinuation.yield() + } + } + + func testCopyFromQueryHasSyntaxError() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + do { + try await connection.copyFrom(table: "", logger: .psqlTest) { _ in + XCTFail("Did not expect to call writeData") + } + + } catch { + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") + } + // Send another query to ensure that the state machine is back in the idle state afterwards and can + // handle new queries. We don't wait for this to finish, just to receive the initiation on the other + // side of the + _ = connection.simpleQuery("DUMMY") + } + + _ = try await channel.waitForUnpreparedRequest() + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: #"syntax error"#, + .sqlState : "42601" // scanner_yyerror + ]))) + + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + _ = try await channel.waitForUnpreparedRequest() // Await the dummy query messages + } + } + + func testCopyFromHasWriteBackpressure() async throws { + #if !canImport(Synchronization) + throw XCTSkip("Test uses Synchronization which is not available") + #else + guard #available(macOS 15, *) else { + throw XCTSkip("Test uses Atomic which is not available") + } + // `true` while the `writeData` closure is executing the `PostgresCopyFromWriter.write` function, ie. while it + // is blocked for backpressure to be relieved. + let isWriting = Atomic(false) + + try await assertCopyFrom { writer in + isWriting.store(true, ordering: .sequentiallyConsistent) + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + isWriting.store(false, ordering: .sequentiallyConsistent) + } preCopyInResponse: { channel in + channel.isWritable = false + } mockBackend: { channel in + XCTAssert(isWriting.load(ordering: .sequentiallyConsistent)) + + channel.isWritable = true + channel.pipeline.fireChannelWritabilityChanged() + + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.data, ByteBuffer(staticString: "1\tAlice\n")) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + } + #endif + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in @@ -655,8 +880,8 @@ class PostgresConnectionTests: XCTestCase { let logger = self.logger async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) - let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) + let message = try await channel.waitForPostgresFrontendMessage(\.startup) + XCTAssertEqual(message, .versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -669,40 +894,140 @@ class PostgresConnectionTests: XCTestCase { return (connection, channel) } + + /// Validate the behavior of a `COPY FROM` query. + /// + /// Also checks that the connection returns to an idle state after performing the copy and is capable + /// of handling another query. + /// + /// - Parameters: + /// - table: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - columns: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - format: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - writeData: Forwarded to the `copyFrom` call in `PostgresConnection`. + /// - validateCopyFromError: When not `nil`, we expect the `copyFrom` call to throw. This closure can be used to + /// inspect the thrown error and assert that it has the correct shape. + /// - preCopyInResponse: Called before the `CopyInResponse` is sent to the frontend. + /// - validateCopyRequest: Can be used to verify the shape of the `COPY` query that is received by the backend. + /// - mockBackend: determines how the backend behaves, starting after the point where the backend has sent the + /// `CopyInResponse` and ending in the state where the backend has sent a `CommandComplete` or `ErrorResponse` + /// and is now expecting a `Sync` to return back to the idle state. + private func assertCopyFrom( + table: StaticString = "copy_table", + columns: [StaticString] = ["id", "name"], + format: PostgresCopyFromFormat = .text(.init()), + writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void, + validateCopyFromError: (@Sendable (any Error) -> Void)? = nil, + preCopyInResponse: (_ channel: NIOAsyncTestingChannel) -> Void = { _ in }, + validateCopyRequest: (UnpreparedRequest) -> Void = { _ in }, + mockBackend: (_ channel: NIOAsyncTestingChannel) async throws -> Void, + file: StaticString = #file, + line: UInt = #line + ) async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + do { + try await connection.copyFrom(table: table, columns: columns, format: format, logger: .psqlTest, writeData: writeData) + if validateCopyFromError != nil { + XCTFail("Expected `copyFrom` to throw but it did not") + } + } catch { + if let validateCopyFromError { + validateCopyFromError(error) + } else { + throw error + } + } + // Send another query to ensure that the state machine is back in the idle state afterwards and can + // handle new queries. We don't wait for this to finish, just to receive the initiation on the other + // side of the + _ = connection.simpleQuery("DUMMY") + } + + let copyRequest = try await channel.waitForUnpreparedRequest() + validateCopyRequest(copyRequest) + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + preCopyInResponse(channel) + try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: columns.count)))) + + try await mockBackend(channel) + + try await channel.waitForPostgresFrontendMessage(\.sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + _ = try await channel.waitForUnpreparedRequest() // Await the dummy query messages + } + } } extension NIOAsyncTestingChannel { + /// Wait for a `PostgresFrontendMessage` such that `transform` returns a non-nil value. + /// + /// The intention of this is to be used with the convenience accessors on `PostgresFrontendMessage` for the + /// different cases, eg. to wait for a `parse` message + /// + /// ```swift + /// try await self.waitForPostgresFrontendMessage(\.parse) + /// ``` + func waitForPostgresFrontendMessage( + _ transform: (PostgresFrontendMessage) -> T?, + file: StaticString = #file, + line: UInt = #line + ) async throws -> T { + let message = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + let payload = try XCTUnwrap(transform(message), "Received unexpected payload: \(message)", file: file, line: line) + return payload + } func waitForUnpreparedRequest() async throws -> UnpreparedRequest { - let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - - guard case .parse(let parse) = parse, - case .describe(let describe) = describe, - case .bind(let bind) = bind, - case .execute(let execute) = execute, - case .sync = sync - else { - fatalError() - } + let parse = try await self.waitForPostgresFrontendMessage(\.parse) + let describe = try await self.waitForPostgresFrontendMessage(\.describe) + let bind = try await self.waitForPostgresFrontendMessage(\.bind) + let execute = try await self.waitForPostgresFrontendMessage(\.execute) + try await self.waitForPostgresFrontendMessage(\.sync) return UnpreparedRequest(parse: parse, describe: describe, bind: bind, execute: execute) } - func waitForPrepareRequest() async throws -> PrepareRequest { - let parse = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let describe = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + struct CopyDataRequest { + enum Result: Equatable { + /// The data copy finished successfully with a `CopyDone` message. + case done + /// The data copy finished with a `CopyFail` message containing the following error message. + case failed(message: String) + } - guard case .parse(let parse) = parse, - case .describe(let describe) = describe, - case .sync = sync - else { - fatalError("Unexpected message") + /// The data that was transferred. + var data: ByteBuffer + + /// The `CopyDone` or `CopyFail` message that finalized the data transfer. + var result: Result + } + + func waitForCopyData() async throws -> CopyDataRequest { + var copiedData = ByteBuffer() + while true { + let message = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) + switch message { + case .copyData(let data): + copiedData.writeImmutableBuffer(data.data) + case .copyDone: + return CopyDataRequest(data: copiedData, result: .done) + case .copyFail(let message): + return CopyDataRequest(data: copiedData, result: .failed(message: message.message)) + default: + fatalError("Unexpected message") + } } + } + + func waitForPrepareRequest() async throws -> PrepareRequest { + let parse = try await self.waitForPostgresFrontendMessage(\.parse) + let describe = try await self.waitForPostgresFrontendMessage(\.describe) + try await self.waitForPostgresFrontendMessage(\.sync) return PrepareRequest(parse: parse, describe: describe) } @@ -722,16 +1047,9 @@ extension NIOAsyncTestingChannel { } func waitForPreparedRequest() async throws -> PreparedRequest { - let bind = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let execute = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - let sync = try await self.waitForOutboundWrite(as: PostgresFrontendMessage.self) - - guard case .bind(let bind) = bind, - case .execute(let execute) = execute, - case .sync = sync - else { - fatalError() - } + let bind = try await self.waitForPostgresFrontendMessage(\.bind) + let execute = try await self.waitForPostgresFrontendMessage(\.execute) + try await self.waitForPostgresFrontendMessage(\.sync) return PreparedRequest(bind: bind, execute: execute) } @@ -751,6 +1069,14 @@ extension NIOAsyncTestingChannel { try await self.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) try await self.testingEventLoop.executeInContext { self.read() } } + + /// Send the messages up to `BindComplete` for an unnamed query that does not bind any parameters. + func sendUnpreparedRequestWithNoParametersBindResponse() async throws { + try await writeInbound(PostgresBackendMessage.parseComplete) + try await writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await writeInbound(PostgresBackendMessage.noData) + try await writeInbound(PostgresBackendMessage.bindComplete) + } } struct UnpreparedRequest { From c5f2928a6c543587a29b77b51215ed163ccdef24 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Mon, 7 Jul 2025 23:32:42 +0200 Subject: [PATCH 2/9] Implement cancellation during copy operations --- .../PostgresConnection+CopyFrom.swift | 37 +++++++-- .../ConnectionStateMachine.swift | 17 ++++- .../ExtendedQueryStateMachine.swift | 18 ++++- .../New/PostgresChannelHandler.swift | 44 ++++++++--- .../ExtendedQueryStateMachineTests.swift | 6 +- .../New/PostgresConnectionTests.swift | 76 ++++++++++++++++--- 6 files changed, 162 insertions(+), 36 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index b654a899..51900b51 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -23,13 +23,17 @@ public struct PostgresCopyFromWriter: Sendable { precondition(eventLoop.inEventLoop) let promise = eventLoop.makePromise(of: Void.self) self.channelHandler.value.checkBackendCanReceiveCopyData(promise: promise) - promise.futureResult.map { + promise.futureResult.flatMap { if eventLoop.inEventLoop { - self.channelHandler.value.sendCopyData(byteBuffer) + return eventLoop.makeCompletedFuture(withResultOf: { + try self.channelHandler.value.sendCopyData(byteBuffer) + }) } else { + let promise = eventLoop.makePromise(of: Void.self) eventLoop.execute { - self.channelHandler.value.sendCopyData(byteBuffer) + promise.completeWith(Result(catching: { try self.channelHandler.value.sendCopyData(byteBuffer) })) } + return promise.futureResult } }.whenComplete { result in continuation.resume(with: result) @@ -45,13 +49,32 @@ public struct PostgresCopyFromWriter: Sendable { // `writeData` closure. It is likely that the user would forget to do so. try Task.checkCancellation() - // TODO: Listen for task cancellation while we are waiting for backpressure to clear. - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + try await withTaskCancellationHandler { + do { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + if eventLoop.inEventLoop { + writeAssumingInEventLoop(byteBuffer, continuation) + } else { + eventLoop.execute { + writeAssumingInEventLoop(byteBuffer, continuation) + } + } + } + } catch { + if Task.isCancelled { + // If the task was cancelled, we might receive a postgres error which is an artifact about how we + // communicate the cancellation to the state machine. Throw a `CancellationError` to the user + // instead, which looks more like native Swift Concurrency code. + throw CancellationError() + } + throw error + } + } onCancel: { if eventLoop.inEventLoop { - writeAssumingInEventLoop(byteBuffer, continuation) + self.channelHandler.value.cancel() } else { eventLoop.execute { - writeAssumingInEventLoop(byteBuffer, continuation) + self.channelHandler.value.cancel() } } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index decd0c1a..a25b6d41 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -110,6 +110,12 @@ struct ConnectionStateMachine { /// Send a `CopyFail` message to the backend with the given error message. case sendCopyFail(message: String) + /// Fail the promise with the given error and close the connection. + /// + /// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we + /// can't recover the connection because we can't send any messages to the backend, so we need to close it. + case failPromiseAndCloseConnection(EventLoopPromise, error: PSQLError, cleanupContext: CleanUpContext) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -867,9 +873,10 @@ struct ConnectionStateMachine { // MARK: Consumer - mutating func cancelQueryStream() -> ConnectionAction { + mutating func cancel() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state else { - preconditionFailure("Tried to cancel stream without active query") + // We are not in a state in which we can cancel. Do nothing. + return .wait } self.state = .modifying // avoid CoW @@ -955,7 +962,8 @@ struct ConnectionStateMachine { .triggerCopyData, .sendCopyDoneAndSync, .sendCopyFail, - .succeedQueryContinuation: + .succeedQueryContinuation, + .failPromiseAndCloseConnection: preconditionFailure("Invalid query state machine action in state: \(self.state), action: \(action)") case .evaluateErrorAtConnectionLevel: @@ -1149,6 +1157,9 @@ extension ConnectionStateMachine { return .sendCopyDoneAndSync case .sendCopyFail(message: let message): return .sendCopyFail(message: message) + case .failPromiseAndCloseConnection(let promise, error: let error): + let cleanupContext = self.setErrorAndCreateCleanupContext(error) + return .failPromiseAndCloseConnection(promise, error: error, cleanupContext: cleanupContext) case .forwardRows(let buffer): return .forwardRows(buffer) case .forwardStreamComplete(let buffer, let commandTag): diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 8ecda6ab..84da37c6 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -79,6 +79,12 @@ struct ExtendedQueryStateMachine { /// Send a `CopyFail` message to the backend with the given error message. case sendCopyFail(message: String) + /// Fail the promise with the given error and close the connection. + /// + /// This is used when we want to cancel a COPY operation while waiting for backpressure relieve. In that case we + /// can't recover the connection because we can't send any messages to the backend, so we need to close it. + case failPromiseAndCloseConnection(EventLoopPromise, error: PSQLError) + // --- streaming actions // actions if query has requested next row but we are waiting for backend case forwardRows([DataRow]) @@ -155,8 +161,16 @@ struct ExtendedQueryStateMachine { return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled) } - case .copyingData: - return .sendCopyFail(message: "Copy cancelled") + case .copyingData(.readyToSend): + // We can't initiate an exit from the copy state here because `copyingFinished`, which is the state that is + // reached after sending a `CopyFail` requires a continuation that waits for the `CommandComplete` or + // `ErrorResponse`. Instead, we assume that the next call to `CopyFromWriter.write` checks cancellation and + // initiates the `CopyFail` with the cancellation. + return .wait + + case .copyingData(.pendingBackpressureRelieve(let promise)): + self.state = .error(.queryCancelled) + return .failPromiseAndCloseConnection(promise, error: .queryCancelled) case .copyingFinished: // We already finished the copy and are awaiting the `CommandComplete` or `ErrorResponse` from it. There's diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index cafe0f21..6f896327 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -177,26 +177,50 @@ final class PostgresChannelHandler: ChannelDuplexHandler { /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer /// should be aborted to avoid unnecessary work. func checkBackendCanReceiveCopyData(promise: EventLoopPromise) { - self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext!.channel.isWritable, promise: promise) + guard let handlerContext else { + promise.fail(PostgresError.connectionClosed) + return + } + self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext.channel.isWritable, promise: promise) + } + + /// Cancel the currently executing operation, if it is cancellable. + func cancel() { + guard let handlerContext else { + return + } + let action = self.state.cancel() + self.run(action, with: handlerContext) } /// Send a `CopyData` message to the backend using the given data. - func sendCopyData(_ data: ByteBuffer) { + func sendCopyData(_ data: ByteBuffer) throws { + guard let handlerContext else { + throw PostgresError.connectionClosed + } self.encoder.copyDataHeader(dataLength: UInt32(data.readableBytes)) - self.handlerContext!.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) - self.handlerContext!.writeAndFlush(self.wrapOutboundOut(data), promise: nil) + handlerContext.write(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil) + handlerContext.writeAndFlush(self.wrapOutboundOut(data), promise: nil) } /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. func sendCopyDone(continuation: CheckedContinuation) { + guard let handlerContext else { + continuation.resume(throwing: PostgresError.connectionClosed) + return + } let action = self.state.sendCopyDone(continuation: continuation) - self.run(action, with: self.handlerContext!) + self.run(action, with: handlerContext) } /// Put the state machine out of the copying mode and send a `CopyFail` message to the backend. func sendCopyFail(message: String, continuation: CheckedContinuation) { + guard let handlerContext else { + continuation.resume(throwing: PostgresError.connectionClosed) + return + } let action = self.state.sendCopyFail(message: message, continuation: continuation) - self.run(action, with: self.handlerContext!) + self.run(action, with: handlerContext) } func channelReadComplete(context: ChannelHandlerContext) { @@ -489,6 +513,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { } case .forwardNotificationToListeners(let notification): self.forwardNotificationToListeners(notification, context: context) + case .failPromiseAndCloseConnection(let promise, let error, let cleanupContext): + promise.fail(error) + self.closeConnectionAndCleanup(cleanupContext, context: context) } } @@ -860,11 +887,10 @@ extension PostgresChannelHandler: PSQLRowsDataSource { } func cancel(for stream: PSQLRowStream) { - guard self.rowStream === stream, let handlerContext = self.handlerContext else { + guard self.rowStream === stream else { return } - let action = self.state.cancelQueryStream() - self.run(action, with: handlerContext) + self.cancel() } } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 872664af..758f83e2 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -140,7 +140,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(promise, with: .init(value: .rowDescription(expected), logger: logger))) - XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) + XCTAssertEqual(state.cancel(), .forwardStreamError(.queryCancelled, read: false, cleanupContext: nil)) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test1")]), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.readEventCaught(), .read) @@ -188,7 +188,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) XCTAssertEqual(state.readEventCaught(), .wait) - XCTAssertEqual(state.cancelQueryStream(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil)) + XCTAssertEqual(state.cancel(), .forwardStreamError(.queryCancelled, read: true, cleanupContext: nil)) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test2")]), .wait) XCTAssertEqual(state.dataRowReceived([ByteBuffer(string: "test3")]), .wait) @@ -287,7 +287,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - XCTAssertEqual(state.cancelQueryStream(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) + XCTAssertEqual(state.cancel(), .failQuery(promise, with: .queryCancelled, cleanupContext: .none)) let serverError = PostgresBackendMessage.ErrorResponse(fields: [.severity: "Error", .sqlState: "123"]) XCTAssertEqual(state.errorReceived(serverError), .wait) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 9c3dabfc..68429423 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -631,7 +631,7 @@ class PostgresConnectionTests: XCTestCase { } validateCopyRequest: { copyRequest in XCTAssertEqual(copyRequest.parse.query, "COPY copy_table(id,name) FROM STDIN WITH (FORMAT text)") XCTAssertEqual(copyRequest.bind.parameters, []) - } mockBackend: { channel in + } mockBackend: { channel, _ in let data = try await channel.waitForCopyData() XCTAssertEqual(String(buffer: data.data), "1\tAlice\n") XCTAssertEqual(data.result, .done) @@ -648,7 +648,7 @@ class PostgresConnectionTests: XCTestCase { } validateCopyRequest: { copyRequest in XCTAssertEqual(copyRequest.parse.query, #"COPY copy_table(id,name) FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#) XCTAssertEqual(copyRequest.bind.parameters, []) - } mockBackend: { channel in + } mockBackend: { channel, _ in let data = try await channel.waitForCopyData() XCTAssertEqual(String(buffer: data.data), "1,Alice\n") XCTAssertEqual(data.result, .done) @@ -665,7 +665,7 @@ class PostgresConnectionTests: XCTestCase { throw MyError() } validateCopyFromError: { error in XCTAssert(error is MyError, "Expected error of type MyError, got \(error)") - } mockBackend: { channel in + } mockBackend: { channel, _ in let data = try await channel.waitForCopyData() XCTAssertEqual(data.result, .failed(message: "My error")) try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ @@ -685,7 +685,7 @@ class PostgresConnectionTests: XCTestCase { await iterator.next() } validateCopyFromError: { error in XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") - } mockBackend: { channel in + } mockBackend: { channel, _ in let copyDataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) XCTAssertEqual(copyDataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) @@ -702,7 +702,7 @@ class PostgresConnectionTests: XCTestCase { try await writer.write(ByteBuffer(staticString: "1Alice\n")) } validateCopyFromError: { error in XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") - } mockBackend: { channel in + } mockBackend: { channel, _ in _ = try await channel.waitForCopyData() try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ .message: #"invalid input syntax for type integer: "1Alice""#, @@ -728,7 +728,7 @@ class PostgresConnectionTests: XCTestCase { throw MyError() } validateCopyFromError: { error in XCTAssert(error is MyError, "Expected MyError, got \(error)") - } mockBackend: { channel in + } mockBackend: { channel, _ in let copyDataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) XCTAssertEqual(copyDataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) @@ -757,7 +757,7 @@ class PostgresConnectionTests: XCTestCase { } } validateCopyFromError: { error in XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") - } mockBackend: { channel in + } mockBackend: { channel, _ in let dataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) XCTAssertEqual(dataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) @@ -790,7 +790,7 @@ class PostgresConnectionTests: XCTestCase { } } validateCopyFromError: { error in XCTAssert(error is MyError, "Expected MyError, got \(error)") - } mockBackend: { channel in + } mockBackend: { channel, _ in let dataMessage = try await channel.waitForPostgresFrontendMessage(\.copyData) XCTAssertEqual(dataMessage, PostgresFrontendMessage.CopyData(data: ByteBuffer(staticString: "1Alice\n"))) @@ -850,7 +850,7 @@ class PostgresConnectionTests: XCTestCase { isWriting.store(false, ordering: .sequentiallyConsistent) } preCopyInResponse: { channel in channel.isWritable = false - } mockBackend: { channel in + } mockBackend: { channel, _ in XCTAssert(isWriting.load(ordering: .sequentiallyConsistent)) channel.isWritable = true @@ -863,6 +863,57 @@ class PostgresConnectionTests: XCTestCase { #endif } + func testCopyFromCancelled() async throws { + try await assertCopyFrom { writer in + while true { + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + try await Task.sleep(for: .milliseconds(10)) + } + } validateCopyFromError: { error in + XCTAssert(error is CancellationError, "Expected CancellationError, got \(error)") + } mockBackend: { channel, cancelCopy in + cancelCopy() + + let data = try await channel.waitForCopyData() + XCTAssertEqual(data.result, .failed(message: "CancellationError()")) + + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .message: "COPY from stdin failed: CancellationError()", + .sqlState : "57014" // query_canceled + ]))) + } + } + + func testCopyFromCancelledWhileWaitingForBackpressureRelieve() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + do { + try await connection.copyFrom(table: "test", logger: .psqlTest) { writer in + try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) + } + XCTFail("Expected `copyFrom` to throw but it did not") + } catch { + XCTAssert(error is CancellationError, "Expected CancellationError, got \(error)") + } + } + + _ = try await channel.waitForUnpreparedRequest() + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + channel.isWritable = false + try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: 2)))) + + // Wait for the `PostgresCopyFromWriter.write` call to execute and hit the write backpressure before we cancel the task. + try await Task.sleep(for: .milliseconds(200)) + taskGroup.cancelAll() + + // Check that the connection got closed because of the cancellation. + try await connection.closeFuture.get() + } + } + func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in @@ -911,7 +962,8 @@ class PostgresConnectionTests: XCTestCase { /// - validateCopyRequest: Can be used to verify the shape of the `COPY` query that is received by the backend. /// - mockBackend: determines how the backend behaves, starting after the point where the backend has sent the /// `CopyInResponse` and ending in the state where the backend has sent a `CommandComplete` or `ErrorResponse` - /// and is now expecting a `Sync` to return back to the idle state. + /// and is now expecting a `Sync` to return back to the idle state. The closure may call the `cancelCopyFrom` + /// closure that is passed to it to cancel the COPY operation. private func assertCopyFrom( table: StaticString = "copy_table", columns: [StaticString] = ["id", "name"], @@ -920,7 +972,7 @@ class PostgresConnectionTests: XCTestCase { validateCopyFromError: (@Sendable (any Error) -> Void)? = nil, preCopyInResponse: (_ channel: NIOAsyncTestingChannel) -> Void = { _ in }, validateCopyRequest: (UnpreparedRequest) -> Void = { _ in }, - mockBackend: (_ channel: NIOAsyncTestingChannel) async throws -> Void, + mockBackend: (_ channel: NIOAsyncTestingChannel, _ cancelCopy: () -> Void) async throws -> Void, file: StaticString = #file, line: UInt = #line ) async throws { @@ -953,7 +1005,7 @@ class PostgresConnectionTests: XCTestCase { preCopyInResponse(channel) try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: columns.count)))) - try await mockBackend(channel) + try await mockBackend(channel, { taskGroup.cancelAll() }) try await channel.waitForPostgresFrontendMessage(\.sync) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) From cca91b87a043020e31ebd383053e61ca76958eef Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Tue, 8 Jul 2025 11:27:12 +0200 Subject: [PATCH 3/9] Address review comments --- .../PostgresConnection+CopyFrom.swift | 77 ++++++++----------- .../ExtendedQueryStateMachine.swift | 2 +- .../New/PostgresConnectionTests.swift | 22 +++--- 3 files changed, 44 insertions(+), 57 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index 51900b51..0428a616 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -1,16 +1,5 @@ /// Handle to send data for a `COPY ... FROM STDIN` query to the backend. public struct PostgresCopyFromWriter: Sendable { - /// The backend failed the copy data transfer, which means that no more data sent by the frontend would be processed. - /// - /// The `PostgresCopyFromWriter` should cancel the data transfer. - public struct CopyCancellationError: Error { - /// The error that the backend sent us which cancelled the data transfer. - /// - /// Note that this error is related to previous `write` calls since a `CopyCancellationError` is thrown before - /// new data is written by `write`. - public let underlyingError: PSQLError - } - private let channelHandler: NIOLoopBound private let eventLoop: any EventLoop @@ -42,9 +31,9 @@ public struct PostgresCopyFromWriter: Sendable { /// Send data for a `COPY ... FROM STDIN` operation to the backend. /// - /// If the backend encountered an error during the data transfer and thus cannot process any more data, this throws - /// a `CopyCancellationError`. - public func write(_ byteBuffer: ByteBuffer) async throws { + /// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy + /// operation, eg. to indicate that a **previous** `write` call had an invalid format. + public func write(_ byteBuffer: ByteBuffer, isolation: isolated (any Actor)? = #isolation) async throws { // Check for cancellation. This is cheap and makes sure that we regularly check for cancellation in the // `writeData` closure. It is likely that the user would forget to do so. try Task.checkCancellation() @@ -82,7 +71,7 @@ public struct PostgresCopyFromWriter: Sendable { /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyDone` message to /// the backend. - func done() async throws { + func done(isolation: isolated (any Actor)? = #isolation) async throws { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in if eventLoop.inEventLoop { self.channelHandler.value.sendCopyDone(continuation: continuation) @@ -96,16 +85,13 @@ public struct PostgresCopyFromWriter: Sendable { /// Finalize the data transfer, putting the state machine out of the copy mode and sending a `CopyFail` message to /// the backend. - func failed(error: any Error) async throws { + func failed(error: any Error, isolation: isolated (any Actor)? = #isolation) async throws { try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - // TODO: Is it OK to use string interpolation to construct an error description to be sent to the backend - // here? We could also use a generic description, it doesn't really matter since we throw the user's error - // in `copyFrom`. if eventLoop.inEventLoop { - self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation) + self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation) } else { eventLoop.execute { - self.channelHandler.value.sendCopyFail(message: "\(error)", continuation: continuation) + self.channelHandler.value.sendCopyFail(message: "Client failed copy", continuation: continuation) } } } @@ -113,20 +99,29 @@ public struct PostgresCopyFromWriter: Sendable { } /// Specifies the format in which data is transferred to the backend in a COPY operation. -public enum PostgresCopyFromFormat: Sendable { +/// +/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings +/// and their default values. +public struct PostgresCopyFromFormat: Sendable { /// Options that can be used to modify the `text` format of a COPY operation. public struct TextOptions: Sendable { /// The delimiter that separates columns in the data. /// /// See the `DELIMITER` option in Postgres's `COPY` command. - /// - /// Uses the default delimiter of the format public var delimiter: UnicodeScalar? = nil public init() {} } - case text(TextOptions) + enum Format { + case text(TextOptions) + } + + var format: Format + + public static func text(_ options: TextOptions) -> PostgresCopyFromFormat { + return PostgresCopyFromFormat(format: .text(options)) + } } /// Create a `COPY ... FROM STDIN` query based on the given parameters. @@ -138,14 +133,17 @@ private func buildCopyFromQuery( columns: [StaticString] = [], format: PostgresCopyFromFormat ) -> PostgresQuery { - // TODO: Should we put the table and column names in quotes to make them case-sensitive? - var query = "COPY \(table)" + var query = """ + COPY "\(table)" + """ if !columns.isEmpty { - query += "(" + columns.map(\.description).joined(separator: ",") + ")" + query += "(" + query += columns.map { #"""# + $0.description + #"""# }.joined(separator: ",") + query += ")" } query += " FROM STDIN" var queryOptions: [String] = [] - switch format { + switch format.format { case .text(let options): queryOptions.append("FORMAT text") if let delimiter = options.delimiter { @@ -179,6 +177,7 @@ extension PostgresConnection { columns: [StaticString] = [], format: PostgresCopyFromFormat = .text(.init()), logger: Logger, + isolation: isolated (any Actor)? = #isolation, file: String = #fileID, line: Int = #line, writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void @@ -205,22 +204,13 @@ extension PostgresConnection { // threw instead of the one that got relayed back, so it's better to ignore the error here. // - The backend sent us an `ErrorResponse` during the copy, eg. because of an invalid format. This puts // the `ExtendedQueryStateMachine` in the error state. Trying to send a `CopyFail` will throw but trigger - // a `Sync` that takes the backend out of copy mode. If `writeData` threw the `CopyCancellationError` - // from the `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it - // doesn't matter that we ignore the error here. If the user threw some other error, it's better to honor - // the user's error. + // a `Sync` that takes the backend out of copy mode. If `writeData` threw the error from from the + // `PostgresCopyFromWriter.write` call, `writer.failed` will throw with the same error, so it doesn't + // matter that we ignore the error here. If the user threw some other error, it's better to honor the + // user's error. try? await writer.failed(error: error) - if let error = error as? PostgresCopyFromWriter.CopyCancellationError { - // If we receive a `CopyCancellationError` that is with almost certain likelihood because - // `PostgresCopyFromWriter.write` threw it - otherwise the user must have saved a previous - // `PostgresCopyFromWriter` error, which is very unlikely. - // Throw the underlying error because that contains the error message that was sent by the backend and - // is most actionable by the user. - throw error.underlyingError - } else { - throw error - } + throw error } // `writer.done` may fail, eg. because the backend sends an error response after receiving `CopyDone` or during @@ -230,5 +220,4 @@ extension PostgresConnection { // above. try await writer.done() } - } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 84da37c6..215c3d7e 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -427,7 +427,7 @@ struct ExtendedQueryStateMachine { if case .error(let error) = self.state { // The backend sent us an ErrorResponse during the copy operation. Indicate to the client that it should // abort the data transfer. - promise.fail(PostgresCopyFromWriter.CopyCancellationError(underlyingError: error)) + promise.fail(error) return } guard case .copyingData(.readyToSend) = self.state else { diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 68429423..031ada76 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -629,7 +629,7 @@ class PostgresConnectionTests: XCTestCase { try await assertCopyFrom { writer in try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) } validateCopyRequest: { copyRequest in - XCTAssertEqual(copyRequest.parse.query, "COPY copy_table(id,name) FROM STDIN WITH (FORMAT text)") + XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text)"#) XCTAssertEqual(copyRequest.bind.parameters, []) } mockBackend: { channel, _ in let data = try await channel.waitForCopyData() @@ -646,7 +646,7 @@ class PostgresConnectionTests: XCTestCase { try await assertCopyFrom(format: .text(options)) { writer in try await writer.write(ByteBuffer(staticString: "1,Alice\n")) } validateCopyRequest: { copyRequest in - XCTAssertEqual(copyRequest.parse.query, #"COPY copy_table(id,name) FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#) + XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table"("id","name") FROM STDIN WITH (FORMAT text,DELIMITER U&'\002c')"#) XCTAssertEqual(copyRequest.bind.parameters, []) } mockBackend: { channel, _ in let data = try await channel.waitForCopyData() @@ -657,9 +657,7 @@ class PostgresConnectionTests: XCTestCase { } func testCopyFromWriterFails() async throws { - struct MyError: Error, CustomStringConvertible { - var description: String { "My error" } - } + struct MyError: Error {} try await assertCopyFrom { writer in throw MyError() @@ -667,9 +665,9 @@ class PostgresConnectionTests: XCTestCase { XCTAssert(error is MyError, "Expected error of type MyError, got \(error)") } mockBackend: { channel, _ in let data = try await channel.waitForCopyData() - XCTAssertEqual(data.result, .failed(message: "My error")) + XCTAssertEqual(data.result, .failed(message: "Client failed copy")) try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ - .message: "COPY from stdin failed: My error", + .message: "COPY from stdin failed: Client failed copy", .sqlState : "57014" // query_canceled ]))) } @@ -752,7 +750,7 @@ class PostgresConnectionTests: XCTestCase { try await writer.write(ByteBuffer(staticString: "2\tBob\n")) XCTFail("Expected error to be thrown") } catch { - XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)") + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") throw error } } validateCopyFromError: { error in @@ -769,7 +767,7 @@ class PostgresConnectionTests: XCTestCase { } } - func testCopyFromCallerDoesNotRethrowCopyCancellationError() async throws { + func testCopyFromCallerDoesNotRethrowFromWriteCall() async throws { struct MyError: Error, CustomStringConvertible { var description: String { "My error" } } @@ -785,7 +783,7 @@ class PostgresConnectionTests: XCTestCase { try await writer.write(ByteBuffer(staticString: "2\tBob\n")) XCTFail("Expected error to be thrown") } catch { - XCTAssert(error is PostgresCopyFromWriter.CopyCancellationError, "Received unexpected error: \(error)") + XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "22P02") throw MyError() } } validateCopyFromError: { error in @@ -875,10 +873,10 @@ class PostgresConnectionTests: XCTestCase { cancelCopy() let data = try await channel.waitForCopyData() - XCTAssertEqual(data.result, .failed(message: "CancellationError()")) + XCTAssertEqual(data.result, .failed(message: "Client failed copy")) try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ - .message: "COPY from stdin failed: CancellationError()", + .message: "COPY from stdin failed: Client failed copy", .sqlState : "57014" // query_canceled ]))) } From cb05b24aec452593eb234fa09504cdc8bd9a6a33 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Tue, 8 Jul 2025 12:06:00 +0200 Subject: [PATCH 4/9] Communicate side effects of `checkBackendCanReceiveCopyData` via an action --- .../ConnectionStateMachine.swift | 16 ++++++++++++++-- .../ExtendedQueryStateMachine.swift | 8 ++++---- .../PostgresNIO/New/PostgresChannelHandler.swift | 10 +++++++++- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index a25b6d41..6b0a8059 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -141,6 +141,17 @@ struct ConnectionStateMachine { case succeedPromise(EventLoopPromise) } + enum CheckBackendCanReceiveCopyDataAction { + /// Don't perform any action. + case none + + /// Succeed the promise with a Void result. + case succeedPromise(EventLoopPromise) + + /// Fail the promise with the given error. + case failPromise(EventLoopPromise, error: any Error) + } + private var state: State private let requireBackendKeyData: Bool private var taskQueue = CircularBuffer() @@ -815,14 +826,15 @@ struct ConnectionStateMachine { /// The promise may be failed if the backend indicated that it can't handle any more data by sending an /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer /// should be aborted to avoid unnecessary work. - mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) { + mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) -> CheckBackendCanReceiveCopyDataAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state else { preconditionFailure("Copy mode is only supported for extended queries") } self.state = .modifying // avoid CoW - queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise) + let action = queryState.checkBackendCanReceiveCopyData(channelIsWritable: channelIsWritable, promise: promise) self.state = .extendedQuery(queryState, connectionContext) + return action } /// Put the state machine out of the copying mode and send a `CopyDone` message to the backend. diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 215c3d7e..bba44f0f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -423,22 +423,22 @@ struct ExtendedQueryStateMachine { /// The promise may be failed if the backend indicated that it can't handle any more data by sending an /// `ErrorResponse`. This is mostly the case when malformed data is sent to it. In that case, the data transfer /// should be aborted to avoid unnecessary work. - mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) { + mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) -> ConnectionStateMachine.CheckBackendCanReceiveCopyDataAction { if case .error(let error) = self.state { // The backend sent us an ErrorResponse during the copy operation. Indicate to the client that it should // abort the data transfer. promise.fail(error) - return + return . failPromise(promise, error: error) } guard case .copyingData(.readyToSend) = self.state else { preconditionFailure("Not ready to send data") } if channelIsWritable { - promise.succeed() - return + return .succeedPromise(promise) } return avoidingStateMachineCoW { state in state = .copyingData(.pendingBackpressureRelieve(promise)) + return .none } } diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 6f896327..d1470dee 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -181,7 +181,15 @@ final class PostgresChannelHandler: ChannelDuplexHandler { promise.fail(PostgresError.connectionClosed) return } - self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext.channel.isWritable, promise: promise) + let action = self.state.checkBackendCanReceiveCopyData(channelIsWritable: handlerContext.channel.isWritable, promise: promise) + switch action { + case .none: + break + case .succeedPromise(let promise): + promise.succeed() + case .failPromise(let promise, error: let error): + promise.fail(error) +} } /// Cancel the currently executing operation, if it is cancellable. From 568d258bc9456159daf4d8a3dc41506404cb373b Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Tue, 8 Jul 2025 12:39:35 +0200 Subject: [PATCH 5/9] Disable `copyFrom` when using a Swift <6 compiler Swift 5.10 does not support `#isolation`, which we use. --- .../PostgresNIO/Connection/PostgresConnection+CopyFrom.swift | 4 ++++ Tests/IntegrationTests/PSQLIntegrationTests.swift | 2 ++ Tests/PostgresNIOTests/New/PostgresConnectionTests.swift | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index 0428a616..51319f41 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -29,6 +29,7 @@ public struct PostgresCopyFromWriter: Sendable { } } + #if compiler(>=6.0) /// Send data for a `COPY ... FROM STDIN` operation to the backend. /// /// - Throws: If an error occurs during the write of if the backend sent an `ErrorResponse` during the copy @@ -96,6 +97,7 @@ public struct PostgresCopyFromWriter: Sendable { } } } + #endif } /// Specifies the format in which data is transferred to the backend in a COPY operation. @@ -124,6 +126,7 @@ public struct PostgresCopyFromFormat: Sendable { } } +#if compiler(>=6.0) /// Create a `COPY ... FROM STDIN` query based on the given parameters. /// /// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be @@ -221,3 +224,4 @@ extension PostgresConnection { try await writer.done() } } +#endif diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 35581edb..267e70e9 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -379,6 +379,7 @@ final class IntegrationTests: XCTestCase { } } + #if compiler(>=6.0) func testCopyIntoFrom() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -487,4 +488,5 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror } } + #endif } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 031ada76..4e43ed38 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -625,6 +625,7 @@ class PostgresConnectionTests: XCTestCase { } } + #if compiler(>=6.0) func testCopyFromSucceeds() async throws { try await assertCopyFrom { writer in try await writer.write(ByteBuffer(staticString: "1\tAlice\n")) @@ -911,6 +912,7 @@ class PostgresConnectionTests: XCTestCase { try await connection.closeFuture.get() } } + #endif func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() @@ -944,6 +946,7 @@ class PostgresConnectionTests: XCTestCase { return (connection, channel) } + #if compiler(>=6.0) /// Validate the behavior of a `COPY FROM` query. /// /// Also checks that the connection returns to an idle state after performing the copy and is capable @@ -1011,6 +1014,7 @@ class PostgresConnectionTests: XCTestCase { _ = try await channel.waitForUnpreparedRequest() // Await the dummy query messages } } + #endif } extension NIOAsyncTestingChannel { From 1602d85f28df7671c976fda0341a8002fa22b680 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Thu, 24 Jul 2025 13:26:32 +0200 Subject: [PATCH 6/9] Allow `String` to be used for table and column names in `COPY FROM` --- .../Connection/PostgresConnection+CopyFrom.swift | 15 +++++++++------ .../New/PostgresConnectionTests.swift | 4 ++-- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index 51319f41..b7f4935a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -131,9 +131,12 @@ public struct PostgresCopyFromFormat: Sendable { /// /// An empty `columns` array signifies that no columns should be specified in the query and that all columns will be /// copied by the caller. +/// +/// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be +/// susceptible to SQL injection. Ensure no untrusted data is contained in these strings. private func buildCopyFromQuery( - table: StaticString, - columns: [StaticString] = [], + table: String, + columns: [String] = [], format: PostgresCopyFromFormat ) -> PostgresQuery { var query = """ @@ -173,11 +176,11 @@ extension PostgresConnection { /// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown /// by the `copyFrom` function. /// - /// - Note: The table and column names are inserted into the SQL query verbatim. They are forced to be compile-time - /// specified to avoid runtime SQL injection attacks. + /// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be + /// susceptible to SQL injection. Ensure no untrusted data is contained in these strings. public func copyFrom( - table: StaticString, - columns: [StaticString] = [], + table: String, + columns: [String] = [], format: PostgresCopyFromFormat = .text(.init()), logger: Logger, isolation: isolated (any Actor)? = #isolation, diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 4e43ed38..82c493a3 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -966,8 +966,8 @@ class PostgresConnectionTests: XCTestCase { /// and is now expecting a `Sync` to return back to the idle state. The closure may call the `cancelCopyFrom` /// closure that is passed to it to cancel the COPY operation. private func assertCopyFrom( - table: StaticString = "copy_table", - columns: [StaticString] = ["id", "name"], + table: String = "copy_table", + columns: [String] = ["id", "name"], format: PostgresCopyFromFormat = .text(.init()), writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void, validateCopyFromError: (@Sendable (any Error) -> Void)? = nil, From d52ec9373f21c4153128ca7d19f0c87d1cc8f5b2 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Thu, 21 Aug 2025 09:47:55 +0200 Subject: [PATCH 7/9] =?UTF-8?q?Don=E2=80=99t=20require=20`writeData`=20to?= =?UTF-8?q?=20be=20`Sendable`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../PostgresNIO/Connection/PostgresConnection+CopyFrom.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index b7f4935a..5f60e352 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -186,7 +186,7 @@ extension PostgresConnection { isolation: isolated (any Actor)? = #isolation, file: String = #fileID, line: Int = #line, - writeData: @escaping @Sendable (PostgresCopyFromWriter) async throws -> Void + writeData: (PostgresCopyFromWriter) async throws -> Void ) async throws { var logger = logger logger[postgresMetadataKey: .connectionID] = "\(self.id)" From ca3f4861b4a6e66b98fb89dc380d8ede3e426c0e Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Thu, 21 Aug 2025 09:49:03 +0200 Subject: [PATCH 8/9] Support binary data transfer in `COPY FROM` My benchmark of transferring the integers from 0 to 1,000,000 both as an integer and as a string was about the same speed as the old text-based transfer. I believe that the binary transfer will start to show significant benefits when transferring binary data, other fields that don't need to be represented as fields and also means that the user doesn't need to worry about escapping their data. --- .../PostgresConnection+CopyFrom.swift | 164 ++++++++++++++++++ .../PSQLIntegrationTests.swift | 35 ++++ .../New/PostgresConnectionTests.swift | 62 +++++++ 3 files changed, 261 insertions(+) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index 5f60e352..b647aedc 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -1,3 +1,111 @@ +import NIO + +#if compiler(>=6.0) +/// Handle to send binary data for a `COPY ... FROM STDIN` query to the backend. +/// +/// It takes care of serializing `PostgresEncodable` column types into the binary format that Postgres expects. +public struct PostgresBinaryCopyFromWriter: ~Copyable { + /// Handle to serialize columns into a row that is being written by `PostgresBinaryCopyFromWriter`. + public struct ColumnWriter: ~Copyable { + /// The `PostgresBinaryCopyFromWriter` that is gathering the serialized data. + /// + /// We need to model this as `UnsafeMutablePointer` because we can't express in the Swift type system that + /// `ColumnWriter` never exceeds the lifetime of `PostgresBinaryCopyFromWriter`. + @usableFromInline + let underlying: UnsafeMutablePointer + + /// The number of columns that have been written by this `ColumnWriter`. + @usableFromInline + var columns: UInt16 = 0 + + @usableFromInline + init(underlying: UnsafeMutablePointer) { + self.underlying = underlying + } + + /// Serialize a single column to a row. + /// + /// - Important: It is critical that that data type encoded here exactly matches the data type in the + /// databasse. For example, if the database stores an a 4-bit integer the corresponding `writeColumn` must + /// be called with an `Int32`. Serializing an integer of a different width will cause a deserialization + /// failure in the backend. + @inlinable + public mutating func writeColumn(_ column: (some PostgresEncodable)?) throws { + columns += 1 + try underlying.pointee.writeColumn(column) + } + } + + /// The underlying `PostgresCopyFromWriter` that sends the serialized data to the backend. + @usableFromInline let underlying: PostgresCopyFromWriter + + /// The buffer in which we accumulate binary data. Once this buffer exceeds `bufferSize`, we flush it to + /// the backend. + @usableFromInline var buffer = ByteBuffer() + + /// Once `buffer` exceeds this size, it gets flushed to the backend. + @usableFromInline let bufferSize: Int + + init(underlying: PostgresCopyFromWriter, bufferSize: Int) { + self.underlying = underlying + // Allocate 10% more than the buffer size because we only flush the buffer once it has exceeded `bufferSize` + buffer.reserveCapacity(bufferSize + bufferSize / 10) + self.bufferSize = bufferSize + } + + /// Serialize a single row to the backend. Call `writeColumn` on `columnWriter` for every column that should be + /// included in the row. + @inlinable + public mutating func writeRow(_ body: (_ columnWriter: inout ColumnWriter) throws -> Void) async throws { + // Write a placeholder for the number of columns + let columnIndex = buffer.writerIndex + buffer.writeInteger(UInt16(0)) + + let columns = try withUnsafeMutablePointer(to: &self) { pointerToSelf in + // Important: We need to ensure that `pointerToSel` (and thus `ColumnWriter`) does not exceed the lifetime + // of `self` because it is holding an unsafe reference to it. + // + // We achieve this because `ColumnWriter` is non-Copyable and thus the client can't store a copy to it. + // Futhermore `columnWriter` is destroyed before the end of `withUnsafeMutablePointer`, which holds `self` + // alive. + var columnWriter = ColumnWriter(underlying: pointerToSelf) + + try body(&columnWriter) + + return columnWriter.columns + } + + // Fill in the number of columns + buffer.setInteger(columns, at: columnIndex) + + if buffer.readableBytes > bufferSize { + try await flush() + } + } + + /// Serialize a single column to the buffer. Should only be called by `ColumnWriter`. + @inlinable + mutating func writeColumn(_ column: (some PostgresEncodable)?) throws { + if let column { + try buffer.writeLengthPrefixed(as: Int32.self) { buffer in + let startIndex = buffer.writerIndex + try column.encode(into: &buffer, context: .default) + return buffer.writerIndex - startIndex + } + } else { + buffer.writeInteger(Int32(-1)) + } + } + + /// Flush any pending data in the buffer to the backend. + @usableFromInline + mutating func flush(isolation: (any Actor)? = #isolation) async throws { + try await underlying.write(buffer) + buffer.clear() + } +} +#endif + /// Handle to send data for a `COPY ... FROM STDIN` query to the backend. public struct PostgresCopyFromWriter: Sendable { private let channelHandler: NIOLoopBound @@ -115,8 +223,14 @@ public struct PostgresCopyFromFormat: Sendable { public init() {} } + /// Options that can be used to modify the `binary` format of a COPY operation. + public struct BinaryOptions: Sendable { + public init() {} + } + enum Format { case text(TextOptions) + case binary(BinaryOptions) } var format: Format @@ -124,6 +238,10 @@ public struct PostgresCopyFromFormat: Sendable { public static func text(_ options: TextOptions) -> PostgresCopyFromFormat { return PostgresCopyFromFormat(format: .text(options)) } + + public static func binary(_ options: BinaryOptions) -> PostgresCopyFromFormat { + return PostgresCopyFromFormat(format: .binary(options)) + } } #if compiler(>=6.0) @@ -156,6 +274,8 @@ private func buildCopyFromQuery( // Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection. queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'") } + case .binary: + queryOptions.append("FORMAT binary") } precondition(!queryOptions.isEmpty) query += " WITH (" @@ -165,6 +285,50 @@ private func buildCopyFromQuery( } extension PostgresConnection { + /// Copy data into a table using a `COPY
FROM STDIN` query, transferring data in a binary format. + /// + /// - Parameters: + /// - table: The name of the table into which to copy the data. + /// - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied. + /// - bufferSize: How many bytes to accumulate a local buffer before flushing it to the database. Can affect + /// performance characteristics of the copy operation. + /// - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `write` on the + /// writer provided by the closure to send data to the backend and return from the closure once all data is sent. + /// Throw an error from the closure to fail the data transfer. The error thrown by the closure will be rethrown + /// by the `copyFrom` function. + /// + /// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be + /// susceptible to SQL injection. Ensure no untrusted data is contained in these strings. + public func copyFromBinary( + table: String, + columns: [String] = [], + options: PostgresCopyFromFormat.BinaryOptions = .init(), + bufferSize: Int = 100_000, + logger: Logger, + isolation: isolated (any Actor)? = #isolation, + file: String = #fileID, + line: Int = #line, + writeData: (inout PostgresBinaryCopyFromWriter) async throws -> Void + ) async throws { + try await copyFrom(table: table, columns: columns, format: .binary(PostgresCopyFromFormat.BinaryOptions()), logger: logger) { writer in + var header = ByteBuffer() + header.writeString("PGCOPY\n") + header.writeInteger(UInt8(0xff)) + header.writeString("\r\n\0") + + // Flag fields + header.writeInteger(UInt32(0)) + + // Header extension area length + header.writeInteger(UInt32(0)) + try await writer.write(header) + + var binaryWriter = PostgresBinaryCopyFromWriter(underlying: writer, bufferSize: bufferSize) + try await writeData(&binaryWriter) + try await binaryWriter.flush() + } + } + /// Copy data into a table using a `COPY
FROM STDIN` query. /// /// - Parameters: diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 267e70e9..76b1af6a 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -488,5 +488,40 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror } } + + func testCopyFromBinary() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PostgresConnection.test(on: eventLoop).get() + defer { XCTAssertNoThrow(try conn.close().wait()) } + + _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get() + _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get() + + try await conn.copyFromBinary(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in + let records: [(id: Int, name: String)] = [ + (1, "Alice"), + (42, "Bob") + ] + for record in records { + try await writer.writeRow { columnWriter in + try columnWriter.writeColumn(Int32(record.id)) + try columnWriter.writeColumn(record.name) + } + } + } + let rows = try await conn.query("SELECT id, name FROM copy_table").get().rows.map { try $0.decode((Int, String).self) } + guard rows.count == 2 else { + XCTFail("Expected 2 columns, received \(rows.count)") + return + } + XCTAssertEqual(rows[0].0, 1) + XCTAssertEqual(rows[0].1, "Alice") + XCTAssertEqual(rows[1].0, 42) + XCTAssertEqual(rows[1].1, "Bob") + } + #endif } diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 82c493a3..6fe78985 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -912,6 +912,68 @@ class PostgresConnectionTests: XCTestCase { try await connection.closeFuture.get() } } + + func testCopyFromBinary() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + try await connection.copyFromBinary(table: "copy_table", logger: .psqlTest) { writer in + try await writer.writeRow { columnWriter in + try columnWriter.writeColumn(Int32(1)) + try columnWriter.writeColumn("Alice") + } + try await writer.writeRow { columnWriter in + try columnWriter.writeColumn(Int32(2)) + try columnWriter.writeColumn("Bob") + } + } + } + + let copyRequest = try await channel.waitForUnpreparedRequest() + XCTAssertEqual(copyRequest.parse.query, #"COPY "copy_table" FROM STDIN WITH (FORMAT binary)"#) + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary]))) + + let copyData = try await channel.waitForCopyData() + XCTAssertEqual(copyData.result, .done) + var data = copyData.data + // Signature + XCTAssertEqual(data.readString(length: 7), "PGCOPY\n") + XCTAssertEqual(data.readInteger(as: UInt8.self), 0xff) + XCTAssertEqual(data.readString(length: 3), "\r\n\0") + // Flags + XCTAssertEqual(data.readInteger(as: UInt32.self), 0) + // Header extension area length + XCTAssertEqual(data.readInteger(as: UInt32.self), 0) + + struct Row: Equatable { + let id: Int32 + let name: String + } + var rows: [Row] = [] + while data.readableBytes > 0 { + // Number of columns + XCTAssertEqual(data.readInteger(as: UInt16.self), 2) + // 'id' column + XCTAssertEqual(data.readInteger(as: UInt32.self), 4) + let id = try XCTUnwrap(data.readInteger(as: Int32.self)) + // 'name' column length + let nameLength = try XCTUnwrap(data.readInteger(as: UInt32.self)) + let name = try XCTUnwrap(data.readString(length: Int(nameLength))) + rows.append(Row(id: id, name: name)) + } + XCTAssertEqual(rows, [ + Row(id: 1, name: "Alice"), + Row(id: 2, name: "Bob") + ]) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 1")) + + try await channel.waitForPostgresFrontendMessage(\.sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + } #endif func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { From 6dd723c1788ec1c38f8a4b66dd22223ff99a3a88 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Thu, 24 Jul 2025 21:50:52 +0200 Subject: [PATCH 9/9] Work around compiler assertion failure Work around https://github.com/swiftlang/swift/issues/83309 --- .../Connection/PostgresConnection+CopyFrom.swift | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index b647aedc..2c77871e 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -32,7 +32,17 @@ public struct PostgresBinaryCopyFromWriter: ~Copyable { @inlinable public mutating func writeColumn(_ column: (some PostgresEncodable)?) throws { columns += 1 - try underlying.pointee.writeColumn(column) + try invokeWriteColumn(on: underlying, column) + } + + // Needed to work around https://github.com/swiftlang/swift/issues/83309, copying the implementation into + // `writeColumn` causes an assertion failure when thread sanitizer is enabled. + @inlinable + func invokeWriteColumn( + on writer: UnsafeMutablePointer, + _ column: (some PostgresEncodable)? + ) throws { + try writer.pointee.writeColumn(column) } }