From 02bb7ca3942b46968e5d4fd1db5b0708f87a2c01 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Tue, 19 Oct 2021 16:51:35 +0200 Subject: [PATCH 1/2] Make async queries possible better caching performance improvements. More stuff Compiling tests Compile errors for swift<5.5 compiles! compile tests Decoding optionals works better! Clean up --- Package.swift | 1 + .../PostgresConnection+Database.swift | 4 +- .../PostgresNIO/Data/PostgresData+Array.swift | 2 +- .../Message/PostgresMessage+Error.swift | 2 +- ...PostgresMessage+ParameterDescription.swift | 2 +- .../PostgresMessage+RowDescription.swift | 4 +- .../ConnectionStateMachine.swift | 12 +- .../ExtendedQueryStateMachine.swift | 18 +- .../PrepareStatementStateMachine.swift | 4 +- .../RowStreamStateMachine.swift | 18 +- .../New/Data/Array+PSQLCodable.swift | 60 +- .../New/Data/Bool+PSQLCodable.swift | 19 +- .../New/Data/Bytes+PSQLCodable.swift | 16 +- .../New/Data/Date+PSQLCodable.swift | 8 +- .../New/Data/Float+PSQLCodable.swift | 20 +- .../New/Data/Int+PSQLCodable.swift | 50 +- .../New/Data/Optional+PSQLCodable.swift | 27 +- .../New/Data/String+PSQLCodable.swift | 8 +- .../New/Data/UUID+PSQLCodable.swift | 8 +- .../New/Extensions/ByteBuffer+PSQL.swift | 14 +- .../New/Extensions/Logging+PSQL.swift | 14 +- .../New/Messages/Authentication.swift | 1 - .../PostgresNIO/New/Messages/DataRow.swift | 162 ++++- .../New/Messages/RowDescription.swift | 141 ++-- Sources/PostgresNIO/New/PSQL+JSON.swift | 4 +- .../PostgresNIO/New/PSQLBackendMessage.swift | 1 + .../PostgresNIO/New/PSQLChannelHandler.swift | 2 +- Sources/PostgresNIO/New/PSQLCodable.swift | 45 +- Sources/PostgresNIO/New/PSQLConnection.swift | 95 ++- Sources/PostgresNIO/New/PSQLData.swift | 134 ++-- Sources/PostgresNIO/New/PSQLError.swift | 31 +- .../New/PSQLPreparedStatement.swift | 2 +- Sources/PostgresNIO/New/PSQLRow.swift | 155 +++- .../New/PSQLRowSequence+Decoding.swift | 62 ++ Sources/PostgresNIO/New/PSQLRowSequence.swift | 663 ++++++++++++++++++ Sources/PostgresNIO/New/PSQLRowStream.swift | 287 +++++--- Sources/PostgresNIO/New/PSQLTask.swift | 6 +- Sources/PostgresNIO/Postgres+PSQLCompat.swift | 10 +- Sources/PostgresNIO/Utilities/NIOUtils.swift | 2 +- .../PSQLConnection+AsyncTests.swift | 383 ++++++++++ .../PSQLIntegrationTests.swift | 4 +- .../ExtendedQueryStateMachineTests.swift | 16 +- .../PrepareStatementStateMachineTests.swift | 6 +- .../New/Data/String+PSQLCodableTests.swift | 6 +- .../New/Data/UUID+PSQLCodableTests.swift | 6 +- .../PSQLBackendMessage+Equatable.swift | 16 +- .../PSQLBackendMessageEncoder.swift | 17 +- .../New/Messages/DataRowTests.swift | 12 +- .../New/Messages/RowDescriptionTests.swift | 10 +- .../PostgresNIOTests/New/PSQLDataTests.swift | 3 - .../New/PSQLRowSequenceTests.swift | 177 +++++ 51 files changed, 2230 insertions(+), 540 deletions(-) create mode 100644 Sources/PostgresNIO/New/PSQLRowSequence+Decoding.swift create mode 100644 Sources/PostgresNIO/New/PSQLRowSequence.swift create mode 100644 Tests/IntegrationTests/PSQLConnection+AsyncTests.swift create mode 100644 Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift diff --git a/Package.swift b/Package.swift index 64c261b3..05088dab 100644 --- a/Package.swift +++ b/Package.swift @@ -26,6 +26,7 @@ let package = Package( .product(name: "Metrics", package: "swift-metrics"), .product(name: "NIO", package: "swift-nio"), .product(name: "NIOCore", package: "swift-nio"), + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOPosix", package: "swift-nio"), .product(name: "NIOTLS", 
package: "swift-nio"), .product(name: "NIOSSL", package: "swift-nio-ssl"), diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift index 725f17d8..68e6c96c 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Database.swift @@ -50,7 +50,7 @@ extension PostgresConnection: PostgresDatabase { let lookupTable = PostgresRow.LookupTable(rowDescription: .init(fields: fields), resultFormat: [.binary]) return rows.all().map { allrows in let r = allrows.map { psqlRow -> PostgresRow in - let columns = psqlRow.data.columns.map { + let columns = psqlRow.data.map { PostgresMessage.DataRow.Column(value: $0) } return PostgresRow(dataRow: .init(columns: columns), lookupTable: lookupTable) @@ -112,7 +112,7 @@ extension PSQLRowStream { func iterateRowsWithoutBackpressureOption(lookupTable: PostgresRow.LookupTable, onRow: @escaping (PostgresRow) throws -> ()) -> EventLoopFuture { self.onRow { psqlRow in - let columns = psqlRow.data.columns.map { + let columns = psqlRow.data.map { PostgresMessage.DataRow.Column(value: $0) } diff --git a/Sources/PostgresNIO/Data/PostgresData+Array.swift b/Sources/PostgresNIO/Data/PostgresData+Array.swift index bbb420bc..0eda8dd2 100644 --- a/Sources/PostgresNIO/Data/PostgresData+Array.swift +++ b/Sources/PostgresNIO/Data/PostgresData+Array.swift @@ -81,7 +81,7 @@ extension PostgresData { return nil } assert(b == 0, "Array b field did not equal zero") - guard let type = value.readInteger(as: PostgresDataType.self) else { + guard let type = value.readRawRepresentableInteger(as: PostgresDataType.self) else { return nil } guard isNotEmpty == 1 else { diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift index 51b9be7e..e9828efb 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Error.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Error.swift @@ -10,7 +10,7 @@ extension PostgresMessage { /// Parses an instance of this message type from a byte buffer. public static func parse(from buffer: inout ByteBuffer) throws -> Error { var fields: [Field: String] = [:] - while let field = buffer.readInteger(as: Field.self) { + while let field = buffer.readRawRepresentableInteger(as: Field.self) { guard let string = buffer.readNullTerminatedString() else { throw PostgresError.protocol("Could not read error response string.") } diff --git a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift index 3dfdb8e1..6337b649 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+ParameterDescription.swift @@ -6,7 +6,7 @@ extension PostgresMessage { /// Parses an instance of this message type from a byte buffer. 
public static func parse(from buffer: inout ByteBuffer) throws -> ParameterDescription { guard let dataTypes = try buffer.read(array: PostgresDataType.self, { buffer in - guard let dataType = buffer.readInteger(as: PostgresDataType.self) else { + guard let dataType = buffer.readRawRepresentableInteger(as: PostgresDataType.self) else { throw PostgresError.protocol("Could not parse data type integer in parameter description message.") } return dataType diff --git a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift index 48a90c18..271efcb4 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+RowDescription.swift @@ -20,7 +20,7 @@ extension PostgresMessage { guard let columnAttributeNumber = buffer.readInteger(as: Int16.self) else { throw PostgresError.protocol("Could not read row description field column attribute number") } - guard let dataType = buffer.readInteger(as: PostgresDataType.self) else { + guard let dataType = buffer.readRawRepresentableInteger(as: PostgresDataType.self) else { throw PostgresError.protocol("Could not read row description field data type") } guard let dataTypeSize = buffer.readInteger(as: Int16.self) else { @@ -29,7 +29,7 @@ extension PostgresMessage { guard let dataTypeModifier = buffer.readInteger(as: Int32.self) else { throw PostgresError.protocol("Could not read row description field data type modifier") } - guard let formatCode = buffer.readInteger(as: PostgresFormatCode.self) else { + guard let formatCode = buffer.readRawRepresentableInteger(as: PostgresFormatCode.self) else { throw PostgresError.protocol("Could not read row description field format code") } return .init( diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 1af28a3b..27dd40dc 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -87,18 +87,18 @@ struct ConnectionStateMachine { case sendParseDescribeBindExecuteSync(query: String, binds: [PSQLEncodable]) case sendBindExecuteSync(statementName: String, binds: [PSQLEncodable]) case failQuery(ExtendedQueryContext, with: PSQLError, cleanupContext: CleanUpContext?) - case succeedQuery(ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) + case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRows(CircularBuffer) - case forwardStreamComplete(CircularBuffer, commandTag: String) + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?) // Prepare statement actions case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: PSQLBackendMessage.RowDescription?) + case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError, cleanupContext: CleanUpContext?) 
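The message-parsing hunks above repeatedly replace `buffer.readInteger(as:)` with `buffer.readRawRepresentableInteger(as:)` when reading enum-backed wire fields such as `PostgresDataType` and `PostgresFormatCode`. The helper's definition is not part of this excerpt; the sketch below shows one plausible shape for it (read the fixed-width raw value, then map it through the failable `RawRepresentable` initializer), followed by a usage example with a hypothetical `ExampleFormat` enum.

    import NIOCore

    extension ByteBuffer {
        // Sketch of the helper the renamed call implies; the real definition
        // (in NIOUtils.swift or in NIO itself) may differ. Reads the fixed-width
        // raw value first, then converts it via the failable RawRepresentable init.
        mutating func readRawRepresentableIntegerSketch<E>(
            endianness: Endianness = .big,
            as type: E.Type = E.self
        ) -> E? where E: RawRepresentable, E.RawValue: FixedWidthInteger {
            guard let rawValue = self.readInteger(endianness: endianness, as: E.RawValue.self) else {
                return nil
            }
            return E(rawValue: rawValue)
        }
    }

    // Hypothetical enum standing in for a wire field like PostgresFormatCode.
    enum ExampleFormat: Int16 {
        case text = 0
        case binary = 1
    }

    var buffer = ByteBufferAllocator().buffer(capacity: 2)
    buffer.writeInteger(Int16(1))
    let format = buffer.readRawRepresentableIntegerSketch(as: ExampleFormat.self) // .binary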
// Close actions @@ -713,7 +713,7 @@ struct ConnectionStateMachine { } } - mutating func rowDescriptionReceived(_ description: PSQLBackendMessage.RowDescription) -> ConnectionAction { + mutating func rowDescriptionReceived(_ description: RowDescription) -> ConnectionAction { switch self.state { case .extendedQuery(var queryState, let connectionContext) where !queryState.isComplete: return self.avoidingStateMachineCoW { machine -> ConnectionAction in @@ -791,7 +791,7 @@ struct ConnectionStateMachine { } } - mutating func dataRowReceived(_ dataRow: PSQLBackendMessage.DataRow) -> ConnectionAction { + mutating func dataRowReceived(_ dataRow: DataRow) -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.dataRow(dataRow))) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 4818ca19..67fe219f 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -8,13 +8,13 @@ struct ExtendedQueryStateMachine { case parseCompleteReceived(ExtendedQueryContext) case parameterDescriptionReceived(ExtendedQueryContext) - case rowDescriptionReceived(ExtendedQueryContext, [PSQLBackendMessage.RowDescription.Column]) + case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) - case streaming([PSQLBackendMessage.RowDescription.Column], RowStreamStateMachine) + case streaming([RowDescription.Column], RowStreamStateMachine) case commandComplete(commandTag: String) case error(PSQLError) @@ -28,13 +28,13 @@ struct ExtendedQueryStateMachine { // --- general actions case failQuery(ExtendedQueryContext, with: PSQLError) - case succeedQuery(ExtendedQueryContext, columns: [PSQLBackendMessage.RowDescription.Column]) + case succeedQuery(ExtendedQueryContext, columns: [RowDescription.Column]) case succeedQueryNoRowsComming(ExtendedQueryContext, commandTag: String) // --- streaming actions // actions if query has requested next row but we are waiting for backend - case forwardRows(CircularBuffer) - case forwardStreamComplete(CircularBuffer, commandTag: String) + case forwardRows([DataRow]) + case forwardStreamComplete([DataRow], commandTag: String) case forwardStreamError(PSQLError, read: Bool) case read @@ -105,7 +105,7 @@ struct ExtendedQueryStateMachine { } } - mutating func rowDescriptionReceived(_ rowDescription: PSQLBackendMessage.RowDescription) -> Action { + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } @@ -119,7 +119,7 @@ struct ExtendedQueryStateMachine { // In Postgres extended queries we always request the response rows to be returned in // `.binary` format. 
- let columns = rowDescription.columns.map { column -> PSQLBackendMessage.RowDescription.Column in + let columns = rowDescription.columns.map { column -> RowDescription.Column in var column = column column.format = .binary return column @@ -155,12 +155,12 @@ struct ExtendedQueryStateMachine { } } - mutating func dataRowReceived(_ dataRow: PSQLBackendMessage.DataRow) -> Action { + mutating func dataRowReceived(_ dataRow: DataRow) -> Action { switch self.state { case .streaming(let columns, var demandStateMachine): // When receiving a data row, we must ensure that the data row column count // matches the previously received row description column count. - guard dataRow.columns.count == columns.count else { + guard dataRow.columnCount == columns.count else { return self.setAndFireError(.unexpectedBackendMessage(.dataRow(dataRow))) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift index 98e18dbc..947c8f97 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/PrepareStatementStateMachine.swift @@ -15,7 +15,7 @@ struct PrepareStatementStateMachine { enum Action { case sendParseDescribeSync(name: String, query: String) - case succeedPreparedStatementCreation(PrepareStatementContext, with: PSQLBackendMessage.RowDescription?) + case succeedPreparedStatementCreation(PrepareStatementContext, with: RowDescription?) case failPreparedStatementCreation(PrepareStatementContext, with: PSQLError) case read @@ -72,7 +72,7 @@ struct PrepareStatementStateMachine { return .succeedPreparedStatementCreation(queryContext, with: nil) } - mutating func rowDescriptionReceived(_ rowDescription: PSQLBackendMessage.RowDescription) -> Action { + mutating func rowDescriptionReceived(_ rowDescription: RowDescription) -> Action { guard case .parameterDescriptionReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.rowDescription(rowDescription))) } diff --git a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift index 165ba4f3..08953fb2 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/RowStreamStateMachine.swift @@ -13,15 +13,15 @@ struct RowStreamStateMachine { private enum State { /// The state machines expects further writes to `channelRead`. The writes are appended to the buffer. - case waitingForRows(CircularBuffer) + case waitingForRows([DataRow]) /// The state machines expects a call to `demandMoreResponseBodyParts` or `read`. The buffer is /// empty. It is preserved for performance reasons. - case waitingForReadOrDemand(CircularBuffer) + case waitingForReadOrDemand([DataRow]) /// The state machines expects a call to `read`. The buffer is empty. It is preserved for performance reasons. - case waitingForRead(CircularBuffer) + case waitingForRead([DataRow]) /// The state machines expects a call to `demandMoreResponseBodyParts`. The buffer is empty. It is /// preserved for performance reasons. 
- case waitingForDemand(CircularBuffer) + case waitingForDemand([DataRow]) case modifying } @@ -29,10 +29,12 @@ struct RowStreamStateMachine { private var state: State init() { - self.state = .waitingForRows(CircularBuffer(initialCapacity: 32)) + var buffer = [DataRow]() + buffer.reserveCapacity(32) + self.state = .waitingForRows(buffer) } - mutating func receivedRow(_ newRow: PSQLBackendMessage.DataRow) { + mutating func receivedRow(_ newRow: DataRow) { switch self.state { case .waitingForRows(var buffer): self.state = .modifying @@ -66,7 +68,7 @@ struct RowStreamStateMachine { } } - mutating func channelReadComplete() -> CircularBuffer? { + mutating func channelReadComplete() -> [DataRow]? { switch self.state { case .waitingForRows(let buffer): if buffer.isEmpty { @@ -139,7 +141,7 @@ struct RowStreamStateMachine { } } - mutating func end() -> CircularBuffer { + mutating func end() -> [DataRow] { switch self.state { case .waitingForRows(let buffer): return buffer diff --git a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift index d2211885..265701f4 100644 --- a/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Array+PSQLCodable.swift @@ -2,81 +2,81 @@ import NIOCore import struct Foundation.UUID /// A type, of which arrays can be encoded into and decoded from a postgres binary format -protocol PSQLArrayElement: PSQLCodable { +public protocol PSQLArrayElement: PSQLCodable { static var psqlArrayType: PSQLDataType { get } static var psqlArrayElementType: PSQLDataType { get } } extension Bool: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .boolArray } - static var psqlArrayElementType: PSQLDataType { .bool } + public static var psqlArrayType: PSQLDataType { .boolArray } + public static var psqlArrayElementType: PSQLDataType { .bool } } extension ByteBuffer: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .byteaArray } - static var psqlArrayElementType: PSQLDataType { .bytea } + public static var psqlArrayType: PSQLDataType { .byteaArray } + public static var psqlArrayElementType: PSQLDataType { .bytea } } extension UInt8: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .charArray } - static var psqlArrayElementType: PSQLDataType { .char } + public static var psqlArrayType: PSQLDataType { .charArray } + public static var psqlArrayElementType: PSQLDataType { .char } } extension Int16: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .int2Array } - static var psqlArrayElementType: PSQLDataType { .int2 } + public static var psqlArrayType: PSQLDataType { .int2Array } + public static var psqlArrayElementType: PSQLDataType { .int2 } } extension Int32: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .int4Array } - static var psqlArrayElementType: PSQLDataType { .int4 } + public static var psqlArrayType: PSQLDataType { .int4Array } + public static var psqlArrayElementType: PSQLDataType { .int4 } } extension Int64: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .int8Array } - static var psqlArrayElementType: PSQLDataType { .int8 } + public static var psqlArrayType: PSQLDataType { .int8Array } + public static var psqlArrayElementType: PSQLDataType { .int8 } } extension Int: PSQLArrayElement { #if (arch(i386) || arch(arm)) - static var psqlArrayType: PSQLDataType { .int4Array } - static var psqlArrayElementType: PSQLDataType { .int4 } + public static var psqlArrayType: PSQLDataType { .int4Array } + public static var 
psqlArrayElementType: PSQLDataType { .int4 } #else - static var psqlArrayType: PSQLDataType { .int8Array } - static var psqlArrayElementType: PSQLDataType { .int8 } + public static var psqlArrayType: PSQLDataType { .int8Array } + public static var psqlArrayElementType: PSQLDataType { .int8 } #endif } extension Float: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .float4Array } - static var psqlArrayElementType: PSQLDataType { .float4 } + public static var psqlArrayType: PSQLDataType { .float4Array } + public static var psqlArrayElementType: PSQLDataType { .float4 } } extension Double: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .float8Array } - static var psqlArrayElementType: PSQLDataType { .float8 } + public static var psqlArrayType: PSQLDataType { .float8Array } + public static var psqlArrayElementType: PSQLDataType { .float8 } } extension String: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .textArray } - static var psqlArrayElementType: PSQLDataType { .text } + public static var psqlArrayType: PSQLDataType { .textArray } + public static var psqlArrayElementType: PSQLDataType { .text } } extension UUID: PSQLArrayElement { - static var psqlArrayType: PSQLDataType { .uuidArray } - static var psqlArrayElementType: PSQLDataType { .uuid } + public static var psqlArrayType: PSQLDataType { .uuidArray } + public static var psqlArrayElementType: PSQLDataType { .uuid } } extension Array: PSQLEncodable where Element: PSQLArrayElement { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { Element.psqlArrayType } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + public func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { // 0 if empty, 1 if not buffer.writeInteger(self.isEmpty ? 0 : 1, as: UInt32.self) // b @@ -102,7 +102,7 @@ extension Array: PSQLEncodable where Element: PSQLArrayElement { extension Array: PSQLDecodable where Element: PSQLArrayElement { - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Array { + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Array { guard case .binary = format else { // currently we only support decoding arrays in binary format. 
throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) @@ -116,7 +116,7 @@ extension Array: PSQLDecodable where Element: PSQLArrayElement { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } - guard let elementType = buffer.readInteger(as: PSQLDataType.self) else { + guard let elementType = buffer.readRawRepresentableInteger(as: PSQLDataType.self) else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } diff --git a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift index 9ab2cc0f..c1f22807 100644 --- a/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bool+PSQLCodable.swift @@ -1,15 +1,22 @@ import NIOCore -extension Bool: PSQLCodable { - var psqlType: PSQLDataType { +extension Bool: PSQLEncodable { + public var psqlType: PSQLDataType { .bool } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Bool { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + byteBuffer.writeInteger(self ? 1 : 0, as: UInt8.self) + } +} + +extension Bool: PSQLDecodable { + + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Bool { guard type == .bool else { throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: buffer, context: context) } @@ -43,8 +50,4 @@ extension Bool: PSQLCodable { } } } - - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { - byteBuffer.writeInteger(self ? 
1 : 0, as: UInt8.self) - } } diff --git a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift index be8b2dd8..500b2e28 100644 --- a/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Bytes+PSQLCodable.swift @@ -17,38 +17,38 @@ extension PSQLEncodable where Self: Sequence, Self.Element == UInt8 { } extension ByteBuffer: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .bytea } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { var copyOfSelf = self // dirty hack byteBuffer.writeBuffer(©OfSelf) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { return buffer } } extension Data: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .bytea } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeBytes(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { return buffer.readData(length: buffer.readableBytes, byteTransferStrategy: .automatic)! 
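Since `PSQLEncodable`, `PSQLDecodable`, `PSQLEncodingContext`, and `PSQLDecodingContext` become `public` in this patch, conformances like the ones above can now be written outside the package as well. A minimal sketch under that assumption; `AccountID` is a hypothetical user-defined type, not something the patch adds, and it simply delegates to the existing `Int64` conformance.

    import NIOCore
    import PostgresNIO

    // Hypothetical wrapper around a 64-bit identifier, stored as Postgres int8.
    struct AccountID: PSQLCodable {
        var value: Int64

        var psqlType: PSQLDataType { .int8 }
        var psqlFormat: PSQLFormat { .binary }

        func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) {
            byteBuffer.writeInteger(self.value, as: Int64.self)
        }

        static func decode(from buffer: inout ByteBuffer,
                           type: PSQLDataType,
                           format: PSQLFormat,
                           context: PSQLDecodingContext) throws -> AccountID {
            // Let the existing Int64 conformance handle format/type validation.
            AccountID(value: try Int64.decode(from: &buffer, type: type, format: format, context: context))
        }
    }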
} } diff --git a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift index f78a915b..bb028d1e 100644 --- a/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Date+PSQLCodable.swift @@ -2,15 +2,15 @@ import NIOCore import struct Foundation.Date extension Date: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .timestamptz } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .timestamp, .timestamptz: guard buffer.readableBytes == 8, let microseconds = buffer.readInteger(as: Int64.self) else { @@ -29,7 +29,7 @@ extension Date: PSQLCodable { } } - func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) { + public func encode(into buffer: inout ByteBuffer, context: PSQLEncodingContext) { let seconds = self.timeIntervalSince(Self._psqlDateStart) * Double(Self._microsecondsPerSecond) buffer.writeInteger(Int64(seconds)) } diff --git a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift index e86894a2..bc78fa4e 100644 --- a/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Float+PSQLCodable.swift @@ -1,15 +1,16 @@ import NIOCore extension Float: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .float4 } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Float { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Float { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.readFloat() else { @@ -31,21 +32,23 @@ extension Float: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeFloat(self) } } extension Double: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .float8 } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Double { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Double { switch (format, type) { case (.binary, .float4): guard buffer.readableBytes == 4, let float = buffer.readFloat() else { @@ -67,7 +70,8 @@ extension Double: PSQLCodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeDouble(self) } } diff --git a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift index 2c421e92..2495ec57 100644 --- 
a/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Int+PSQLCodable.swift @@ -1,16 +1,17 @@ import NIOCore extension UInt8: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .char } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch type { case .bpchar, .char: guard buffer.readableBytes == 1, let value = buffer.readInteger(as: UInt8.self) else { @@ -24,23 +25,25 @@ extension UInt8: PSQLCodable { } // encoding - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: UInt8.self) } } extension Int16: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .int2 } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -58,22 +61,24 @@ extension Int16: PSQLCodable { } // encoding - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int16.self) } } extension Int32: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .int4 } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -96,22 +101,24 @@ extension Int32: PSQLCodable { } // encoding - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int32.self) } } extension Int64: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .int8 } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ 
-139,13 +146,14 @@ extension Int64: PSQLCodable { } // encoding - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int64.self) } } extension Int: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { switch self.bitWidth { case Int32.bitWidth: return .int4 @@ -156,12 +164,13 @@ extension Int: PSQLCodable { } } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } // decoding - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + @inlinable + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { switch (format, type) { case (.binary, .int2): guard buffer.readableBytes == 2, let value = buffer.readInteger(as: Int16.self) else { @@ -189,7 +198,8 @@ extension Int: PSQLCodable { } // encoding - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + @inlinable + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeInteger(self, as: Int.self) } } diff --git a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift index 99332221..41b5655e 100644 --- a/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/Optional+PSQLCodable.swift @@ -1,17 +1,24 @@ import NIOCore extension Optional: PSQLDecodable where Wrapped: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Optional { - preconditionFailure("This code path should never be hit.") - // The code path for decoding an optional should be: - // -> PSQLData.decode(as: String?.self) - // -> PSQLData.decodeIfPresent(String.self) - // -> String.decode(from: type:) + typealias ActualType = Wrapped + + public static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Optional { + preconditionFailure("This should not be called") + } + + public static func decodeRaw(from byteBuffer: inout ByteBuffer?, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch byteBuffer { + case .some(var buffer): + return try ActualType.decode(from: &buffer, type: type, format: format, context: context) + case .none: + return nil + } } } extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { switch self { case .some(let value): return value.psqlType @@ -20,7 +27,7 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { switch self { case .some(let value): return value.psqlFormat @@ -29,11 +36,11 @@ extension Optional: PSQLEncodable where Wrapped: PSQLEncodable { } } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + public func encodeRaw(into 
byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { switch self { case .none: byteBuffer.writeInteger(-1, as: Int32.self) diff --git a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift index cff48330..d52b8be7 100644 --- a/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PSQLCodable.swift @@ -2,19 +2,19 @@ import NIOCore import struct Foundation.UUID extension String: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .text } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { byteBuffer.writeString(self) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> String { + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> String { switch (format, type) { case (_, .varchar), (_, .text), diff --git a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift index 5e259c4b..15362de9 100644 --- a/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift +++ b/Sources/PostgresNIO/New/Data/UUID+PSQLCodable.swift @@ -4,15 +4,15 @@ import typealias Foundation.uuid_t extension UUID: PSQLCodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { .uuid } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) { let uuid = self.uuid byteBuffer.writeBytes([ uuid.0, uuid.1, uuid.2, uuid.3, @@ -22,7 +22,7 @@ extension UUID: PSQLCodable { ]) } - static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> UUID { + public static func decode(from buffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> UUID { switch (format, type) { case (.binary, .uuid): guard let uuid = buffer.readUUID() else { diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index 45197cc0..f364f007 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -23,18 +23,28 @@ internal extension ByteBuffer { self.writeInteger(messageID.rawValue) } + @inlinable mutating func readFloat() -> Float? { - return self.readInteger(as: UInt32.self).map { Float(bitPattern: $0) } + guard let uint32 = self.readInteger(as: UInt32.self) else { + return nil + } + return Float(bitPattern: uint32) } + @inlinable mutating func readDouble() -> Double? 
{ - return self.readInteger(as: UInt64.self).map { Double(bitPattern: $0) } + guard let uint64 = self.readInteger(as: UInt64.self) else { + return nil + } + return Double(bitPattern: uint64) } + @inlinable mutating func writeFloat(_ float: Float) { self.writeInteger(float.bitPattern) } + @inlinable mutating func writeDouble(_ double: Double) { self.writeInteger(double.bitPattern) } diff --git a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift index 90e91177..d0143ea2 100644 --- a/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/Logging+PSQL.swift @@ -79,7 +79,7 @@ extension Logger { extension Logger { /// See `Logger.trace(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func trace(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, @@ -88,7 +88,7 @@ extension Logger { } /// See `Logger.debug(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func debug(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, @@ -97,7 +97,7 @@ extension Logger { } /// See `Logger.info(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func info(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, @@ -106,7 +106,7 @@ extension Logger { } /// See `Logger.notice(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func notice(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, @@ -115,7 +115,7 @@ extension Logger { } /// See `Logger.warning(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func warning(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, @@ -124,7 +124,7 @@ extension Logger { } /// See `Logger.error(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func error(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, @@ -133,7 +133,7 @@ extension Logger { } /// See `Logger.critical(_:metadata:source:file:function:line:)` - @usableFromInline + @inlinable func critical(_ message: @autoclosure () -> Logger.Message, metadata: @autoclosure () -> PSQLLoggingMetadata, source: @autoclosure () -> String? = nil, diff --git a/Sources/PostgresNIO/New/Messages/Authentication.swift b/Sources/PostgresNIO/New/Messages/Authentication.swift index 5ce5b857..d02ca205 100644 --- a/Sources/PostgresNIO/New/Messages/Authentication.swift +++ b/Sources/PostgresNIO/New/Messages/Authentication.swift @@ -1,7 +1,6 @@ import NIOCore extension PSQLBackendMessage { - enum Authentication: PayloadDecodable { case ok case kerberosV5 diff --git a/Sources/PostgresNIO/New/Messages/DataRow.swift b/Sources/PostgresNIO/New/Messages/DataRow.swift index 3047ccc2..0b8adf4f 100644 --- a/Sources/PostgresNIO/New/Messages/DataRow.swift +++ b/Sources/PostgresNIO/New/Messages/DataRow.swift @@ -1,34 +1,150 @@ import NIOCore -extension PSQLBackendMessage { +/// A backend data row message. 
+/// +/// - NOTE: This struct is not part of the ``PSQLBackendMessage`` namespace even +/// though this is where it actually belongs. The reason for this is, that we want +/// this type to be @usableFromInline. If a type is made @usableFromInline in an +/// enclosing type, the enclosing type must be @usableFromInline as well. +/// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick +/// the Swift compiler. +@usableFromInline +struct DataRow: PSQLBackendMessage.PayloadDecodable, Equatable { - struct DataRow: PayloadDecodable, Equatable { - - var columns: [ByteBuffer?] + @usableFromInline + var columnCount: Int16 + + @usableFromInline + var bytes: ByteBuffer + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + try buffer.ensureAtLeastNBytesRemaining(2) + let columnCount = buffer.readInteger(as: Int16.self)! + let firstColumnIndex = buffer.readerIndex - static func decode(from buffer: inout ByteBuffer) throws -> Self { + for _ in 0..= 0 else { - result.append(nil) - continue - } - - try buffer.ensureAtLeastNBytesRemaining(bufferLength) - let columnBuffer = buffer.readSlice(length: Int(bufferLength))! - - result.append(columnBuffer) + guard bufferLength >= 0 else { + // if buffer length is negative, this means that the value is null + continue } - return DataRow(columns: result) + try buffer.ensureAtLeastNBytesRemaining(bufferLength) + buffer.moveReaderIndex(forwardBy: bufferLength) + } + + try buffer.ensureExactNBytesRemaining(0) + + buffer.moveReaderIndex(to: firstColumnIndex) + let columnSlice = buffer.readSlice(length: buffer.readableBytes)! + return DataRow(columnCount: columnCount, bytes: columnSlice) + } +} + +extension DataRow: Sequence { + @usableFromInline + typealias Element = ByteBuffer? + + // There is no contiguous storage available... Sadly + @inlinable + func withContiguousStorageIfAvailable(_ body: (UnsafeBufferPointer) throws -> R) rethrows -> R? { + nil + } +} + +extension DataRow: Collection { + + @usableFromInline + struct ByteIndex: Comparable { + var offset: Int + + init(_ index: Int) { + self.offset = index + } + + @usableFromInline + static func == (lhs: Self, rhs: Self) -> Bool { + lhs.offset == rhs.offset + } + + @usableFromInline + static func < (lhs: Self, rhs: Self) -> Bool { + lhs.offset < rhs.offset + } + + @usableFromInline + static func <= (lhs: Self, rhs: Self) -> Bool { + lhs.offset <= rhs.offset + } + + @usableFromInline + static func >= (lhs: Self, rhs: Self) -> Bool { + lhs.offset >= rhs.offset } + + @usableFromInline + static func > (lhs: Self, rhs: Self) -> Bool { + lhs.offset > rhs.offset + } + } + + @usableFromInline + typealias Index = DataRow.ByteIndex + + @usableFromInline + var startIndex: ByteIndex { + ByteIndex(self.bytes.readerIndex) + } + + @usableFromInline + var endIndex: ByteIndex { + ByteIndex(self.bytes.readerIndex + self.bytes.readableBytes) + } + + @usableFromInline + var count: Int { + Int(self.columnCount) + } + + @usableFromInline + func index(after index: ByteIndex) -> ByteIndex { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + var elementLength = Int(self.bytes.getInteger(at: index.offset, as: Int32.self)!) + if elementLength < 0 { + elementLength = 0 + } + return ByteIndex(index.offset + MemoryLayout.size + elementLength) + } + + @usableFromInline + subscript(index: ByteIndex) -> Element { + guard index < self.endIndex else { + preconditionFailure("index out of bounds") + } + let elementLength = Int(self.bytes.getInteger(at: index.offset, as: Int32.self)!) 
+ if elementLength < 0 { + return nil + } + return self.bytes.getSlice(at: index.offset + MemoryLayout.size, length: elementLength)! + } +} + +extension DataRow { + @usableFromInline + subscript(column index: Int) -> Element { + guard index < self.columnCount else { + preconditionFailure("index out of bounds") + } + + var byteIndex = self.startIndex + for _ in 0.. Self { + try buffer.ensureAtLeastNBytesRemaining(2) + let columnCount = buffer.readInteger(as: Int16.self)! + + guard columnCount >= 0 else { + throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) } - static func decode(from buffer: inout ByteBuffer) throws -> Self { - try buffer.ensureAtLeastNBytesRemaining(2) - let columnCount = buffer.readInteger(as: Int16.self)! - - guard columnCount >= 0 else { - throw PSQLPartialDecodingError.integerMustBePositiveOrNull(columnCount) + var result = [Column]() + result.reserveCapacity(Int(columnCount)) + + for _ in 0..(_ value: T, into buffer: inout ByteBuffer) throws } -protocol PSQLJSONDecoder { +public protocol PSQLJSONDecoder { func decode(_ type: T.Type, from buffer: ByteBuffer) throws -> T } diff --git a/Sources/PostgresNIO/New/PSQLBackendMessage.swift b/Sources/PostgresNIO/New/PSQLBackendMessage.swift index d65f4623..8b1cf3d2 100644 --- a/Sources/PostgresNIO/New/PSQLBackendMessage.swift +++ b/Sources/PostgresNIO/New/PSQLBackendMessage.swift @@ -20,6 +20,7 @@ protocol PSQLMessagePayloadDecodable { /// /// All messages are defined in the official Postgres Documentation in the section /// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html) + enum PSQLBackendMessage { typealias PayloadDecodable = PSQLMessagePayloadDecodable diff --git a/Sources/PostgresNIO/New/PSQLChannelHandler.swift b/Sources/PostgresNIO/New/PSQLChannelHandler.swift index 20f3c065..c1f3c016 100644 --- a/Sources/PostgresNIO/New/PSQLChannelHandler.swift +++ b/Sources/PostgresNIO/New/PSQLChannelHandler.swift @@ -465,7 +465,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler { private func succeedQueryWithRowStream( _ queryContext: ExtendedQueryContext, - columns: [PSQLBackendMessage.RowDescription.Column], + columns: [RowDescription.Column], context: ChannelHandlerContext) { let rows = PSQLRowStream( diff --git a/Sources/PostgresNIO/New/PSQLCodable.swift b/Sources/PostgresNIO/New/PSQLCodable.swift index b5434edd..60c78197 100644 --- a/Sources/PostgresNIO/New/PSQLCodable.swift +++ b/Sources/PostgresNIO/New/PSQLCodable.swift @@ -1,7 +1,7 @@ import NIOCore /// A type that can encode itself to a postgres wire binary representation. -protocol PSQLEncodable { +public protocol PSQLEncodable { /// identifies the data type that we will encode into `byteBuffer` in `encode` var psqlType: PSQLDataType { get } @@ -19,7 +19,8 @@ protocol PSQLEncodable { } /// A type that can decode itself from a postgres wire binary representation. -protocol PSQLDecodable { +public protocol PSQLDecodable { + typealias ActualType = Self /// Decode an entity from the `byteBuffer` in postgres wire format /// @@ -33,13 +34,32 @@ protocol PSQLDecodable { /// to use when decoding json and metadata to create better errors. /// - Returns: A decoded object static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self + + /// Decode an entity from the `byteBuffer` in postgres wire format. + /// This method has a default implementation and may be overriden + /// only for special cases, like `Optional`s. 
+ static func decodeRaw(from byteBuffer: inout ByteBuffer?, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self +} + +extension PSQLDecodable { + + @inlinable + public static func decodeRaw(from byteBuffer: inout ByteBuffer?, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Self { + switch byteBuffer { + case .some(var buffer): + return try self.decode(from: &buffer, type: type, format: format, context: context) + case .none: + throw PSQLCastingError.missingData(targetType: Self.self, type: type, context: context) + } + } + } /// A type that can be encoded into and decoded from a postgres binary format -protocol PSQLCodable: PSQLEncodable, PSQLDecodable {} +public protocol PSQLCodable: PSQLEncodable, PSQLDecodable {} extension PSQLEncodable { - func encodeRaw(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + public func encodeRaw(into buffer: inout ByteBuffer, context: PSQLEncodingContext) throws { // The length of the parameter value, in bytes (this count does not include // itself). Can be zero. let lengthIndex = buffer.writerIndex @@ -54,20 +74,21 @@ extension PSQLEncodable { } } -struct PSQLEncodingContext { - let jsonEncoder: PSQLJSONEncoder +public struct PSQLEncodingContext { + public let jsonEncoder: PSQLJSONEncoder } -struct PSQLDecodingContext { +public struct PSQLDecodingContext { - let jsonDecoder: PSQLJSONDecoder + public let jsonDecoder: PSQLJSONDecoder - let columnIndex: Int - let columnName: String + public let columnIndex: Int + public let columnName: String - let file: String - let line: Int + public let file: String + public let line: Int + @inlinable init(jsonDecoder: PSQLJSONDecoder, columnName: String, columnIndex: Int, file: String, line: Int) { self.jsonDecoder = jsonDecoder self.columnName = columnName diff --git a/Sources/PostgresNIO/New/PSQLConnection.swift b/Sources/PostgresNIO/New/PSQLConnection.swift index d6c31542..189e6e90 100644 --- a/Sources/PostgresNIO/New/PSQLConnection.swift +++ b/Sources/PostgresNIO/New/PSQLConnection.swift @@ -6,58 +6,58 @@ import class Foundation.JSONEncoder import class Foundation.JSONDecoder import struct Foundation.UUID import Logging +import Foundation -@usableFromInline -final class PSQLConnection { +public final class PSQLConnection { - struct Configuration { + public struct Configuration { - struct Coders { - var jsonEncoder: PSQLJSONEncoder - var jsonDecoder: PSQLJSONDecoder + public struct Coders { + public var jsonEncoder: PSQLJSONEncoder + public var jsonDecoder: PSQLJSONDecoder - init(jsonEncoder: PSQLJSONEncoder, jsonDecoder: PSQLJSONDecoder) { + public init(jsonEncoder: PSQLJSONEncoder, jsonDecoder: PSQLJSONDecoder) { self.jsonEncoder = jsonEncoder self.jsonDecoder = jsonDecoder } - static var foundation: Coders { + public static var foundation: Coders { Coders(jsonEncoder: JSONEncoder(), jsonDecoder: JSONDecoder()) } } - struct Authentication { - var username: String - var database: String? = nil - var password: String? = nil + public struct Authentication { + public var username: String + public var database: String? = nil + public var password: String? = nil - init(username: String, password: String?, database: String?) { + public init(username: String, password: String?, database: String?) { self.username = username self.database = database self.password = password } } - enum Connection { + public enum Connection { case unresolved(host: String, port: Int) case resolved(address: SocketAddress, serverName: String?) 
} - var connection: Connection + public var connection: Connection /// The authentication properties to send to the Postgres server during startup auth handshake - var authentication: Authentication? + public var authentication: Authentication? - var tlsConfiguration: TLSConfiguration? - var coders: Coders + public var tlsConfiguration: TLSConfiguration? + public var coders: Coders - init(host: String, - port: Int = 5432, - username: String, - database: String? = nil, - password: String? = nil, - tlsConfiguration: TLSConfiguration? = nil, - coders: Coders = .foundation) + public init(host: String, + port: Int = 5432, + username: String, + database: String? = nil, + password: String? = nil, + tlsConfiguration: TLSConfiguration? = nil, + coders: Coders = .foundation) { self.connection = .unresolved(host: host, port: port) self.authentication = Authentication(username: username, password: password, database: database) @@ -107,7 +107,7 @@ final class PSQLConnection { self.jsonDecoder = jsonDecoder } deinit { - assert(self.isClosed, "PostgresConnection deinitialized before being closed.") + precondition(self.isClosed, "PSQLConnection deinitialized before being closed.") } func close() -> EventLoopFuture { @@ -147,7 +147,7 @@ final class PSQLConnection { // MARK: Prepared statements func prepareStatement(_ query: String, with name: String, logger: Logger) -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: PSQLBackendMessage.RowDescription?.self) + let promise = self.channel.eventLoop.makePromise(of: RowDescription?.self) let context = PrepareStatementContext( name: name, query: query, @@ -290,6 +290,47 @@ extension PSQLConnection.Configuration { } } +#if swift(>=5.5) && canImport(_Concurrency) +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension PSQLConnection { + + public static func connect( + configuration: PSQLConnection.Configuration, + logger: Logger, + on eventLoop: EventLoop + ) async throws -> PSQLConnection { + try await Self.connect(configuration: configuration, logger: logger, on: eventLoop).get() + } + + public func close() async throws { + try await self.close().get() + } + + public func query(_ query: String, logger: Logger) async throws -> PSQLRowSequence { + try await self.query(query, [], logger: logger) + } + + public func query(_ query: String, _ bind: [PSQLEncodable], logger: Logger) async throws -> PSQLRowSequence { + var logger = logger + logger[postgresMetadataKey: .connectionID] = "\(self.connectionID)" + guard bind.count <= Int(Int16.max) else { + throw PSQLError.tooManyParameters + } + let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self) + let context = ExtendedQueryContext( + query: query, + bind: bind, + logger: logger, + jsonDecoder: self.jsonDecoder, + promise: promise) + + self.channel.write(PSQLTask.extendedQuery(context), promise: nil) + + return try await promise.futureResult.map({ $0.asyncSequence() }).get() + } +} +#endif + // copy and pasted from NIOSSL: private extension String { func isIPAddress() -> Bool { diff --git a/Sources/PostgresNIO/New/PSQLData.swift b/Sources/PostgresNIO/New/PSQLData.swift index 840d798a..30515731 100644 --- a/Sources/PostgresNIO/New/PSQLData.swift +++ b/Sources/PostgresNIO/New/PSQLData.swift @@ -5,144 +5,130 @@ import NIOCore /// Currently there a two wire formats supported: /// - text /// - binary -enum PSQLFormat: Int16 { +public enum PSQLFormat: Int16 { case text = 0 case binary = 1 } +@usableFromInline struct PSQLData: Equatable { - @usableFromInline var bytes: 
ByteBuffer? - @usableFromInline var dataType: PSQLDataType - @usableFromInline var format: PSQLFormat + @usableFromInline + var bytes: ByteBuffer? + @usableFromInline + var dataType: PSQLDataType + @usableFromInline + var format: PSQLFormat /// use this only for testing + @inlinable init(bytes: ByteBuffer?, dataType: PSQLDataType, format: PSQLFormat) { self.bytes = bytes self.dataType = dataType self.format = format } - @inlinable - func decode(as: Optional.Type, context: PSQLDecodingContext) throws -> T? { - try self.decodeIfPresent(as: T.self, context: context) - } - @inlinable func decode(as type: T.Type, context: PSQLDecodingContext) throws -> T { - switch self.bytes { - case .none: - throw PSQLCastingError.missingData(targetType: type, type: self.dataType, context: context) - case .some(var buffer): - return try T.decode(from: &buffer, type: self.dataType, format: self.format, context: context) - } - } - - @inlinable - func decodeIfPresent(as: T.Type, context: PSQLDecodingContext) throws -> T? { - switch self.bytes { - case .none: - return nil - case .some(var buffer): - return try T.decode(from: &buffer, type: self.dataType, format: self.format, context: context) - } + var buffer = self.bytes + return try T.decodeRaw(from: &buffer, type: self.dataType, format: self.format, context: context) } } -struct PSQLDataType: RawRepresentable, Equatable, CustomStringConvertible { - typealias RawValue = Int32 +public struct PSQLDataType: RawRepresentable, Equatable, CustomStringConvertible { + public typealias RawValue = Int32 /// The raw data type code recognized by PostgreSQL. - var rawValue: Int32 + public var rawValue: Int32 /// `0` - static let null = PSQLDataType(0) + public static let null = PSQLDataType(0) /// `16` - static let bool = PSQLDataType(16) + public static let bool = PSQLDataType(16) /// `17` - static let bytea = PSQLDataType(17) + public static let bytea = PSQLDataType(17) /// `18` - static let char = PSQLDataType(18) + public static let char = PSQLDataType(18) /// `19` - static let name = PSQLDataType(19) + public static let name = PSQLDataType(19) /// `20` - static let int8 = PSQLDataType(20) + public static let int8 = PSQLDataType(20) /// `21` - static let int2 = PSQLDataType(21) + public static let int2 = PSQLDataType(21) /// `23` - static let int4 = PSQLDataType(23) + public static let int4 = PSQLDataType(23) /// `24` - static let regproc = PSQLDataType(24) + public static let regproc = PSQLDataType(24) /// `25` - static let text = PSQLDataType(25) + public static let text = PSQLDataType(25) /// `26` - static let oid = PSQLDataType(26) + public static let oid = PSQLDataType(26) /// `114` - static let json = PSQLDataType(114) + public static let json = PSQLDataType(114) /// `194` pg_node_tree - static let pgNodeTree = PSQLDataType(194) + public static let pgNodeTree = PSQLDataType(194) /// `600` - static let point = PSQLDataType(600) + public static let point = PSQLDataType(600) /// `700` - static let float4 = PSQLDataType(700) + public static let float4 = PSQLDataType(700) /// `701` - static let float8 = PSQLDataType(701) + public static let float8 = PSQLDataType(701) /// `790` - static let money = PSQLDataType(790) + public static let money = PSQLDataType(790) /// `1000` _bool - static let boolArray = PSQLDataType(1000) + public static let boolArray = PSQLDataType(1000) /// `1001` _bytea - static let byteaArray = PSQLDataType(1001) + public static let byteaArray = PSQLDataType(1001) /// `1002` _char - static let charArray = PSQLDataType(1002) + public static let 
charArray = PSQLDataType(1002) /// `1003` _name - static let nameArray = PSQLDataType(1003) + public static let nameArray = PSQLDataType(1003) /// `1005` _int2 - static let int2Array = PSQLDataType(1005) + public static let int2Array = PSQLDataType(1005) /// `1007` _int4 - static let int4Array = PSQLDataType(1007) + public static let int4Array = PSQLDataType(1007) /// `1009` _text - static let textArray = PSQLDataType(1009) + public static let textArray = PSQLDataType(1009) /// `1015` _varchar - static let varcharArray = PSQLDataType(1015) + public static let varcharArray = PSQLDataType(1015) /// `1016` _int8 - static let int8Array = PSQLDataType(1016) + public static let int8Array = PSQLDataType(1016) /// `1017` _point - static let pointArray = PSQLDataType(1017) + public static let pointArray = PSQLDataType(1017) /// `1021` _float4 - static let float4Array = PSQLDataType(1021) + public static let float4Array = PSQLDataType(1021) /// `1022` _float8 - static let float8Array = PSQLDataType(1022) + public static let float8Array = PSQLDataType(1022) /// `1034` _aclitem - static let aclitemArray = PSQLDataType(1034) + public static let aclitemArray = PSQLDataType(1034) /// `1042` - static let bpchar = PSQLDataType(1042) + public static let bpchar = PSQLDataType(1042) /// `1043` - static let varchar = PSQLDataType(1043) + public static let varchar = PSQLDataType(1043) /// `1082` - static let date = PSQLDataType(1082) + public static let date = PSQLDataType(1082) /// `1083` - static let time = PSQLDataType(1083) + public static let time = PSQLDataType(1083) /// `1114` - static let timestamp = PSQLDataType(1114) + public static let timestamp = PSQLDataType(1114) /// `1115` _timestamp - static let timestampArray = PSQLDataType(1115) + public static let timestampArray = PSQLDataType(1115) /// `1184` - static let timestamptz = PSQLDataType(1184) + public static let timestamptz = PSQLDataType(1184) /// `1266` - static let timetz = PSQLDataType(1266) + public static let timetz = PSQLDataType(1266) /// `1700` - static let numeric = PSQLDataType(1700) + public static let numeric = PSQLDataType(1700) /// `2278` - static let void = PSQLDataType(2278) + public static let void = PSQLDataType(2278) /// `2950` - static let uuid = PSQLDataType(2950) + public static let uuid = PSQLDataType(2950) /// `2951` _uuid - static let uuidArray = PSQLDataType(2951) + public static let uuidArray = PSQLDataType(2951) /// `3802` - static let jsonb = PSQLDataType(3802) + public static let jsonb = PSQLDataType(3802) /// `3807` _jsonb - static let jsonbArray = PSQLDataType(3807) + public static let jsonbArray = PSQLDataType(3807) /// Returns `true` if the type's raw value is greater than `2^14`. /// This _appears_ to be true for all user-defined types, but I don't @@ -155,7 +141,7 @@ struct PSQLDataType: RawRepresentable, Equatable, CustomStringConvertible { self.rawValue = rawValue } - init(rawValue: Int32) { + public init(rawValue: Int32) { self.init(rawValue) } @@ -210,7 +196,7 @@ struct PSQLDataType: RawRepresentable, Equatable, CustomStringConvertible { } /// See `CustomStringConvertible`. - var description: String { + public var description: String { return self.knownSQLName ?? 
"UNKNOWN \(self.rawValue)" } } diff --git a/Sources/PostgresNIO/New/PSQLError.swift b/Sources/PostgresNIO/New/PSQLError.swift index 0cadc9ee..24da0c17 100644 --- a/Sources/PostgresNIO/New/PSQLError.swift +++ b/Sources/PostgresNIO/New/PSQLError.swift @@ -80,6 +80,7 @@ struct PSQLError: Error { } } +@usableFromInline struct PSQLCastingError: Error { let columnName: String @@ -90,11 +91,34 @@ struct PSQLCastingError: Error { let targetType: PSQLDecodable.Type let postgresType: PSQLDataType - let postgresData: ByteBuffer? + let cellData: ByteBuffer? let description: String let underlying: Error? + init( + columnName: String, + columnIndex: Int, + file: String, + line: Int, + targetType: PSQLDecodable.Type, + postgresType: PSQLDataType, + cellData: ByteBuffer?, + description: String, + underlying: Error? + ) { + self.columnName = columnName + self.columnIndex = columnIndex + self.file = file + self.line = line + self.targetType = targetType + self.postgresType = postgresType + self.cellData = cellData + self.description = description + self.underlying = underlying + } + + @usableFromInline static func missingData(targetType: PSQLDecodable.Type, type: PSQLDataType, context: PSQLDecodingContext) -> Self { PSQLCastingError( columnName: context.columnName, @@ -103,7 +127,7 @@ struct PSQLCastingError: Error { line: context.line, targetType: targetType, postgresType: type, - postgresData: nil, + cellData: nil, description: """ Failed to cast Postgres data type \(type.description) to Swift type \(targetType) \ because of missing data in \(context.file) line \(context.line). @@ -112,6 +136,7 @@ struct PSQLCastingError: Error { ) } + @usableFromInline static func failure(targetType: PSQLDecodable.Type, type: PSQLDataType, postgresData: ByteBuffer, @@ -126,7 +151,7 @@ struct PSQLCastingError: Error { line: context.line, targetType: targetType, postgresType: type, - postgresData: postgresData, + cellData: postgresData, description: description ?? """ Failed to cast Postgres data type \(type.description) to Swift type \(targetType) \ in \(context.file) line \(context.line)." diff --git a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift index c5a08be9..fbdfd868 100644 --- a/Sources/PostgresNIO/New/PSQLPreparedStatement.swift +++ b/Sources/PostgresNIO/New/PSQLPreparedStatement.swift @@ -10,5 +10,5 @@ struct PSQLPreparedStatement { let connection: PSQLConnection /// The `RowDescription` to apply to all `DataRow`s when executing this `PSQLPreparedStatement` - let rowDescription: PSQLBackendMessage.RowDescription? + let rowDescription: RowDescription? } diff --git a/Sources/PostgresNIO/New/PSQLRow.swift b/Sources/PostgresNIO/New/PSQLRow.swift index c5efb53a..47b10af5 100644 --- a/Sources/PostgresNIO/New/PSQLRow.swift +++ b/Sources/PostgresNIO/New/PSQLRow.swift @@ -1,34 +1,26 @@ +import NIOCore /// `PSQLRow` represents a single row that was received from the Postgres Server. 
-struct PSQLRow { +public struct PSQLRow { + @usableFromInline internal let lookupTable: [String: Int] - internal let data: PSQLBackendMessage.DataRow + @usableFromInline + internal let data: DataRow - internal let columns: [PSQLBackendMessage.RowDescription.Column] + @usableFromInline + internal let columns: [RowDescription.Column] + @usableFromInline internal let jsonDecoder: PSQLJSONDecoder - internal init(data: PSQLBackendMessage.DataRow, lookupTable: [String: Int], columns: [PSQLBackendMessage.RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { + internal init(data: DataRow, lookupTable: [String: Int], columns: [RowDescription.Column], jsonDecoder: PSQLJSONDecoder) { self.data = data self.lookupTable = lookupTable self.columns = columns self.jsonDecoder = jsonDecoder } - - /// Access the raw Postgres data in the n-th column - subscript(index: Int) -> PSQLData { - PSQLData(bytes: self.data.columns[index], dataType: self.columns[index].dataType, format: self.columns[index].format) - } - - // TBD: Should this be optional? - /// Access the raw Postgres data in the column indentified by name - subscript(column columnName: String) -> PSQLData? { - guard let index = self.lookupTable[columnName] else { - return nil - } - - return self[index] - } - +} + +extension PSQLRow { /// Access the data in the provided column and decode it into the target type. /// /// - Parameters: @@ -36,7 +28,8 @@ struct PSQLRow { /// - type: The type to decode the data into /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. - func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + @inlinable + public func decode(column: String, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { guard let index = self.lookupTable[column] else { preconditionFailure("A column '\(column)' does not exist.") } @@ -51,16 +44,128 @@ struct PSQLRow { /// - type: The type to decode the data into /// - Throws: The error of the decoding implementation. See also `PSQLDecodable` protocol for this. /// - Returns: The decoded value of Type T. 
- func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + @inlinable + public func decode(column index: Int, as type: T.Type, file: String = #file, line: Int = #line) throws -> T { + precondition(index < self.data.columnCount) + let column = self.columns[index] + let context = PSQLDecodingContext( + jsonDecoder: self.jsonDecoder, + columnName: column.name, + columnIndex: index, + file: file, + line: line) - let decodingContext = PSQLDecodingContext( - jsonDecoder: jsonDecoder, + guard var cellSlice = self.data[column: index] else { + throw PSQLCastingError.missingData(targetType: T.self, type: column.dataType, context: context) + } + + return try T.decode(from: &cellSlice, type: column.dataType, format: column.format, context: context) + } + + @inlinable + public func decode(column index: Int, as type: Optional.Type, file: String = #file, line: Int = #line) throws -> Optional { + precondition(index < self.data.columnCount) + + guard var cellSlice = self.data[column: index] else { + return nil + } + + let column = self.columns[index] + let context = PSQLDecodingContext( + jsonDecoder: self.jsonDecoder, columnName: column.name, columnIndex: index, file: file, line: line) + + return try T.decode(from: &cellSlice, type: column.dataType, format: column.format, context: context) + } +} + +extension PSQLRow { + + @inlinable + public func decode(_ t0: T0.Type, file: String = #file, line: Int = #line) throws -> T0 + where T0: PSQLDecodable + { + var buffer = self.data.bytes + + return try ( + self.decodeNextColumn(t0, from: &buffer, index: 0, file: file, line: line) + ) + } + + @inlinable + public func decode(_ t0: T0.Type, _ t1: T1.Type, file: String = #file, line: Int = #line) throws -> (T0, T1) + where T0: PSQLDecodable, T1: PSQLDecodable + { + assert(self.columns.count >= 2) + var buffer = self.data.bytes + + return try ( + self.decodeNextColumn(t0, from: &buffer, index: 0, file: file, line: line), + self.decodeNextColumn(t1, from: &buffer, index: 1, file: file, line: line) + ) + } + + @inlinable + public func decode(_ t0: T0.Type, _ t1: T1.Type, _ t2: T2.Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2) + where T0: PSQLDecodable, T1: PSQLDecodable, T2: PSQLDecodable + { + assert(self.columns.count >= 3) + var buffer = self.data.bytes - return try self[index].decode(as: T.self, context: decodingContext) + return try ( + self.decodeNextColumn(t0, from: &buffer, index: 0, file: file, line: line), + self.decodeNextColumn(t1, from: &buffer, index: 1, file: file, line: line), + self.decodeNextColumn(t2, from: &buffer, index: 2, file: file, line: line) + ) + } + + @inlinable + public func decode(_ t0: T0.Type, _ t1: T1.Type, _ t2: T2.Type, _ t3: T3.Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3) + where T0: PSQLDecodable, T1: PSQLDecodable, T2: PSQLDecodable, T3: PSQLDecodable + { + assert(self.columns.count >= 4) + var buffer = self.data.bytes + + return try ( + self.decodeNextColumn(t0, from: &buffer, index: 0, file: file, line: line), + self.decodeNextColumn(t1, from: &buffer, index: 1, file: file, line: line), + self.decodeNextColumn(t2, from: &buffer, index: 2, file: file, line: line), + self.decodeNextColumn(t3, from: &buffer, index: 3, file: file, line: line) + ) + } + + @inlinable + public func decode(_ t0: T0.Type, _ t1: T1.Type, _ t2: T2.Type, _ t3: T3.Type, _ t4: T4.Type, file: String = #file, line: Int = #line) throws -> (T0, T1, T2, T3, T4) + where T0: PSQLDecodable, T1: PSQLDecodable, T2: 
PSQLDecodable, T3: PSQLDecodable, T4: PSQLDecodable + { + assert(self.columns.count >= 5) + var buffer = self.data.bytes + + return try ( + self.decodeNextColumn(t0, from: &buffer, index: 0, file: file, line: line), + self.decodeNextColumn(t1, from: &buffer, index: 1, file: file, line: line), + self.decodeNextColumn(t2, from: &buffer, index: 2, file: file, line: line), + self.decodeNextColumn(t3, from: &buffer, index: 3, file: file, line: line), + self.decodeNextColumn(t4, from: &buffer, index: 4, file: file, line: line) + ) + } + + @inlinable + func decodeNextColumn(_ t: T.Type, from buffer: inout ByteBuffer, index: Int, file: String, line: Int) throws -> T { + var slice = buffer.readLengthPrefixedSlice(as: Int32.self) + + let dc0 = PSQLDecodingContext( + jsonDecoder: jsonDecoder, + columnName: self.columns[index].name, + columnIndex: index, + file: file, + line: line + ) + let r = try T.decodeRaw(from: &slice, type: self.columns[index].dataType, format: self.columns[index].format, context: dc0) + return r } } diff --git a/Sources/PostgresNIO/New/PSQLRowSequence+Decoding.swift b/Sources/PostgresNIO/New/PSQLRowSequence+Decoding.swift new file mode 100644 index 00000000..cbd8ab14 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLRowSequence+Decoding.swift @@ -0,0 +1,62 @@ +#if swift(>=5.5) + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension PSQLRowSequence { + + @inlinable + public func decode(_ t0: T0.Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence + where T0: PSQLDecodable + { + self.map { try $0.decode(t0, file: file, line: line) } + } + + @inlinable + public func decode(_ t0: T0.Type, _ t1: T1.Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence + where T0: PSQLDecodable, T1: PSQLDecodable + { + self.map { try $0.decode(t0, t1, file: file, line: line) } + } + + @inlinable + public func decode(_ t0: T0.Type, _ t1: T1.Type, _ t2: T2.Type, file: String = #file, line: Int = #line) -> AsyncThrowingMapSequence + where T0: PSQLDecodable, T1: PSQLDecodable, T2: PSQLDecodable + { + self.map { try $0.decode(t0, t1, t2, file: file, line: line) } + } + + @inlinable + public func decode( + _ t0: T0.Type, + _ t1: T1.Type, + _ t2: T2.Type, + _ t3: T3.Type, + file: String = #file, line: Int = #line + ) -> AsyncThrowingMapSequence + where T0: PSQLDecodable, + T1: PSQLDecodable, + T2: PSQLDecodable, + T3: PSQLDecodable + { + self.map { try $0.decode(t0, t1, t2, t3, file: file, line: line) } + } + + @inlinable + public func decode( + _ t0: T0.Type, + _ t1: T1.Type, + _ t2: T2.Type, + _ t3: T3.Type, + _ t4: T4.Type, + file: String = #file, line: Int = #line + ) -> AsyncThrowingMapSequence + where T0: PSQLDecodable, + T1: PSQLDecodable, + T2: PSQLDecodable, + T3: PSQLDecodable, + T4: PSQLDecodable + { + self.map { try $0.decode(t0, t1, t2, t3, t4, file: file, line: line) } + } +} + +#endif diff --git a/Sources/PostgresNIO/New/PSQLRowSequence.swift b/Sources/PostgresNIO/New/PSQLRowSequence.swift new file mode 100644 index 00000000..ea1c05e6 --- /dev/null +++ b/Sources/PostgresNIO/New/PSQLRowSequence.swift @@ -0,0 +1,663 @@ +import NIOCore +import NIOConcurrencyHelpers + +#if swift(>=5.5) && canImport(_Concurrency) + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +public struct PSQLRowSequence: AsyncSequence { + public typealias Element = PSQLRow + public typealias AsyncIterator = Iterator + + final class _Internal { + + let consumer: AsyncStreamConsumer + + init(consumer: AsyncStreamConsumer) { + self.consumer = consumer + } + + deinit { + // 
if no iterator was created, we need to cancel the stream + self.consumer.sequenceDeinitialized() + } + + func makeAsyncIterator() -> Iterator { + self.consumer.makeAsyncIterator() + } + } + + let _internal: _Internal + + init(_ consumer: AsyncStreamConsumer) { + self._internal = .init(consumer: consumer) + } + + public func makeAsyncIterator() -> Iterator { + self._internal.makeAsyncIterator() + } +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension PSQLRowSequence { + public struct Iterator: AsyncIteratorProtocol { + public typealias Element = PSQLRow + + let _internal: _Internal + + init(consumer: AsyncStreamConsumer) { + self._internal = _Internal(consumer: consumer) + } + + public mutating func next() async throws -> PSQLRow? { + try await self._internal.next() + } + + final class _Internal { + struct ID: Hashable { + let objectID: ObjectIdentifier + + init(_ object: _Internal) { + self.objectID = ObjectIdentifier(object) + } + } + + var id: ID { ID(self) } + + let consumer: AsyncStreamConsumer + + init(consumer: AsyncStreamConsumer) { + self.consumer = consumer + } + + func next() async throws -> PSQLRow? { + try await self.consumer.next(for: self.id) + } + } + } +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +final class AsyncStreamConsumer { + let lock = Lock() + + let lookupTable: [String: Int] + let columns: [RowDescription.Column] + let jsonDecoder: PSQLJSONDecoder + private var state: StateMachine + + init( + lookupTable: [String: Int], + columns: [RowDescription.Column], + jsonDecoder: PSQLJSONDecoder + ) { + self.state = StateMachine() + + self.lookupTable = lookupTable + self.columns = columns + self.jsonDecoder = jsonDecoder + } + + func startCompleted(_ buffer: CircularBuffer, commandTag: String) { + self.lock.withLock { + self.state.finished(buffer, commandTag: commandTag) + } + } + + func startStreaming(_ buffer: CircularBuffer, upstream: PSQLRowStream) { + self.lock.withLock { + self.state.buffered(buffer, upstream: upstream) + } + } + + func startFailed(_ error: Error) { + self.lock.withLock { + self.state.failed(error) + } + } + + func receive(_ newRows: [DataRow]) { + let receiveAction = self.lock.withLock { + self.state.receive(newRows) + } + + switch receiveAction { + case .succeed(let continuation, let data, signalDemandTo: let source): + let row = PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.columns, + jsonDecoder: self.jsonDecoder + ) + continuation.resume(returning: row) + source?.demand() + + case .none: + break + } + } + + func receive(completion result: Result) { + let completionAction = self.lock.withLock { + self.state.receive(completion: result) + } + + switch completionAction { + case .succeed(let continuation): + continuation.resume(returning: nil) + + case .fail(let continuation, let error): + continuation.resume(throwing: error) + + case .none: + break + } + } + + func sequenceDeinitialized() { + let action = self.lock.withLock { + self.state.sequenceDeinitialized() + } + + switch action { + case .cancelStream(let source): + source.cancel() + case .none: + break + } + } + + func makeAsyncIterator() -> PSQLRowSequence.Iterator { + let iterator = PSQLRowSequence.Iterator(consumer: self) + self.lock.withLock { + self.state.registerAsyncIteratorID(ObjectIdentifier(iterator._internal)) + } + return iterator + } + + func next(for id: PSQLRowSequence.Iterator._Internal.ID) async throws -> PSQLRow? 
{ + self.lock.lock() + switch self.state.next() { + case .returnNil: + self.lock.unlock() + return nil + + case .returnRow(let data, signalDemandTo: let source): + self.lock.unlock() + source?.demand() + return PSQLRow( + data: data, + lookupTable: self.lookupTable, + columns: self.columns, + jsonDecoder: self.jsonDecoder + ) + + case .throwError(let error): + self.lock.unlock() + throw error + + case .hitSlowPath: + return try await withCheckedThrowingContinuation { continuation in + let slowPathAction = self.state.next(for: continuation) + self.lock.unlock() + switch slowPathAction { + case .signalDemand(let source): + source.demand() + case .none: + break + } + } + } + } + +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension AsyncStreamConsumer { + struct StateMachine { + enum UpstreamState { + enum DemandState { + case canAskForMore + case waitingForMore(CheckedContinuation?) + } + + case initialized + case streaming(AdaptiveRowBuffer, PSQLRowStream, DemandState) + case finished(AdaptiveRowBuffer, String) + case failed(Error) + case done + + case modifying + } + + enum DownstreamState { + case sequenceCreated + case iteratorCreated(ObjectIdentifier) + } + + var upstreamState: UpstreamState + var downstreamState: DownstreamState + + init() { + self.upstreamState = .initialized + self.downstreamState = .sequenceCreated + } + + mutating func buffered(_ buffer: CircularBuffer, upstream: PSQLRowStream) { + guard case .initialized = self.upstreamState else { + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + let adaptive = AdaptiveRowBuffer(buffer) + self.upstreamState = .streaming(adaptive, upstream, buffer.isEmpty ? .waitingForMore(nil) : .canAskForMore) + } + + mutating func finished(_ buffer: CircularBuffer, commandTag: String) { + guard case .initialized = self.upstreamState else { + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + let adaptive = AdaptiveRowBuffer(buffer) + self.upstreamState = .finished(adaptive, commandTag) + } + + mutating func failed(_ error: Error) { + guard case .initialized = self.upstreamState else { + preconditionFailure("Invalid upstream state: \(self.upstreamState)") + } + self.upstreamState = .failed(error) + } + + mutating func registerAsyncIteratorID(_ id: ObjectIdentifier) { + switch self.downstreamState { + case .sequenceCreated: + self.downstreamState = .iteratorCreated(id) + case .iteratorCreated: + preconditionFailure("An iterator already exists") + } + } + + enum SequenceDeinitializedAction { + case cancelStream(PSQLRowStream) + case none + } + + mutating func sequenceDeinitialized() -> SequenceDeinitializedAction { + switch (self.downstreamState, self.upstreamState) { + case (.sequenceCreated, .initialized): + preconditionFailure() + + case (.sequenceCreated, .streaming(_, let source, _)): + return .cancelStream(source) + + case (.sequenceCreated, .finished), + (.sequenceCreated, .done), + (.sequenceCreated, .failed): + return .none + + case (.iteratorCreated, _): + return .none + + case (_, .modifying): + preconditionFailure() + } + } + + enum NextFastPathAction { + case hitSlowPath + case throwError(Error) + case returnRow(DataRow, signalDemandTo: PSQLRowStream?) 
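+            // A `returnRow` that carries a stream in `signalDemandTo` tells the caller to invoke
+            // `demand()` on that stream once the lock has been released, so that the connection is
+            // asked for further rows while the current row is handed to the iterator.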
+ case returnNil + } + + mutating func next() -> NextFastPathAction { + switch self.upstreamState { + case .initialized: + preconditionFailure() + + case .streaming(var buffer, let source, .canAskForMore): + self.upstreamState = .modifying + guard let (data, demand) = buffer.popFirst() else { + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .hitSlowPath + } + if demand { + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .returnRow(data, signalDemandTo: source) + } + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .returnRow(data, signalDemandTo: nil) + + case .streaming(var buffer, let source, .waitingForMore(.none)): + self.upstreamState = .modifying + guard let (data, _) = buffer.popFirst() else { + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .hitSlowPath + } + + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .returnRow(data, signalDemandTo: nil) + + case .streaming(_, _, .waitingForMore(.some)): + preconditionFailure() + + case .finished(var buffer, let commandTag): + self.upstreamState = .modifying + guard let (data, _) = buffer.popFirst() else { + self.upstreamState = .done + return .returnNil + } + + self.upstreamState = .finished(buffer, commandTag) + return .returnRow(data, signalDemandTo: nil) + + case .failed(let error): + self.upstreamState = .done + return .throwError(error) + + case .done: + return .returnNil + + case .modifying: + preconditionFailure() + } + } + + enum NextSlowPathAction { + case signalDemand(PSQLRowStream) + case none + } + + mutating func next(for continuation: CheckedContinuation) -> NextSlowPathAction { + switch self.upstreamState { + case .initialized: + preconditionFailure() + + case .streaming(let buffer, let source, .canAskForMore): + precondition(buffer.isEmpty) + self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) + return .signalDemand(source) + + case .streaming(let buffer, let source, .waitingForMore(.none)): + precondition(buffer.isEmpty) + self.upstreamState = .streaming(buffer, source, .waitingForMore(continuation)) + return .none + + case .streaming(_, _, .waitingForMore(.some)): + preconditionFailure() + + case .finished: + preconditionFailure() + + case .failed: + preconditionFailure() + + case .done: + preconditionFailure() + + case .modifying: + preconditionFailure() + } + } + + enum ReceiveAction { + case succeed(CheckedContinuation, DataRow, signalDemandTo: PSQLRowStream?) 
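+            // `succeed` resumes a continuation that was parked in `next(for:)` with the first of
+            // the newly received rows; a non-nil `signalDemandTo` additionally instructs the caller
+            // to request more rows from the stream after resuming the continuation.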
+ case none + } + + mutating func receive(_ newRows: [DataRow]) -> ReceiveAction { + precondition(!newRows.isEmpty) + + switch self.upstreamState { + case .streaming(var buffer, let source, .waitingForMore(.some(let continuation))): + buffer.append(contentsOf: newRows) + let (first, demand) = buffer.removeFirst() + if demand { + self.upstreamState = .streaming(buffer, source, .waitingForMore(.none)) + return .succeed(continuation, first, signalDemandTo: source) + } + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .succeed(continuation, first, signalDemandTo: nil) + + case .streaming(var buffer, let source, .waitingForMore(.none)): + buffer.append(contentsOf: newRows) + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .none + + case .streaming(var buffer, let source, .canAskForMore): + buffer.append(contentsOf: newRows) + self.upstreamState = .streaming(buffer, source, .canAskForMore) + return .none + + case .initialized, .finished, .done: + preconditionFailure() + + case .failed: + return .none + + case .modifying: + preconditionFailure() + } + } + + enum CompletionResult { + case succeed(CheckedContinuation) + case fail(CheckedContinuation, Error) + case none + } + + mutating func receive(completion result: Result) -> CompletionResult { + switch result { + case .success(let commandTag): + return self.receiveEnd(commandTag: commandTag) + case .failure(let error): + return self.receiveError(error) + } + } + + mutating func receiveEnd(commandTag: String) -> CompletionResult { + switch self.upstreamState { + case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): + precondition(buffer.isEmpty) + self.upstreamState = .done + return .succeed(continuation) + + case .streaming(let buffer, _, .waitingForMore(.none)): + self.upstreamState = .finished(buffer, commandTag) + return .none + + case .streaming(let buffer, _, .canAskForMore): + self.upstreamState = .finished(buffer, commandTag) + return .none + + case .initialized, .finished, .done: + preconditionFailure() + + case .failed: + return .none + + case .modifying: + preconditionFailure() + } + } + + mutating func receiveError(_ error: Error) -> CompletionResult { + switch self.upstreamState { + case .streaming(let buffer, _, .waitingForMore(.some(let continuation))): + precondition(buffer.isEmpty) + self.upstreamState = .done + return .fail(continuation, error) + + case .streaming(let buffer, _, .waitingForMore(.none)): + precondition(buffer.isEmpty) + self.upstreamState = .failed(error) + return .none + + case .streaming(_, _, .canAskForMore): + self.upstreamState = .failed(error) + return .none + + case .initialized, .finished, .done: + preconditionFailure() + + case .failed: + return .none + + case .modifying: + preconditionFailure() + } + } + } +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension PSQLRowSequence { + func collect() async throws -> [PSQLRow] { + var result = [PSQLRow]() + for try await row in self { + result.append(row) + } + return result + } +} + +struct AdaptiveRowBuffer { + public let minimum: Int + public let maximum: Int + + private var circularBuffer: CircularBuffer + private var target: Int + private var canShrink: Bool = false + + private var hasDemand: Bool { + self.circularBuffer.count < self.maximum + } + + var isEmpty: Bool { + self.circularBuffer.isEmpty + } + + init() { + self.minimum = 1 + self.maximum = 16384 + self.target = 256 + self.circularBuffer = CircularBuffer() + } + + init(_ circularBuffer: CircularBuffer) { + 
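+        // Fixed bounds for the adaptive buffer: at least one row is always requested, at most
+        // 16384 rows are buffered, and a stream that already has buffered rows starts with a
+        // target of 64. `removeFirst()` signals demand while fewer rows than the target are
+        // buffered; the target doubles whenever the buffer runs dry and is halved again when
+        // rows arrive faster than they are consumed (see `append(contentsOf:)`).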
self.minimum = 1 + self.maximum = 16384 + self.target = 64 + self.circularBuffer = circularBuffer + } + + mutating func append(contentsOf newRows: Rows) where Rows.Element == DataRow { + self.circularBuffer.append(contentsOf: newRows) +// print("buffer size: \(self.circularBuffer.count)") + if self.circularBuffer.count >= self.target, self.canShrink, self.target > self.minimum { + self.target &>>= 1 +// print("shrink: \(self.target)") + } + self.canShrink = true + } + + mutating func removeFirst() -> (DataRow, Bool) { + let element = self.circularBuffer.removeFirst() + + // If the buffer is drained now, we should double our target size. + if self.circularBuffer.count == 0, self.target < self.maximum { + self.target = self.target * 2 + self.canShrink = false +// print("grow: \(self.target)") + } + + return (element, self.circularBuffer.count < self.target) + } + + mutating func popFirst() -> (DataRow, Bool)? { + guard !self.circularBuffer.isEmpty else { + return nil + } + return self.removeFirst() + } +} + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +public struct PSQLDecodedRowSequence: AsyncSequence { + public typealias AsyncIterator = Iterator + public typealias Element = (T0, T1, T2) + + @usableFromInline + let upstream: PSQLRowSequence + @usableFromInline + let file: String + @usableFromInline + let line: Int + + @inlinable + init(_ upstream: PSQLRowSequence, _ t0: T0.Type, _ t1: T1.Type, _ t2: T2.Type, file: String, line: Int) { + self.upstream = upstream + self.file = file + self.line = line + } + + @inlinable + public func makeAsyncIterator() -> Iterator { + Iterator(self.upstream.makeAsyncIterator(), file: self.file, line: self.line) + } + + public struct Iterator: AsyncIteratorProtocol { + public typealias Element = (T0, T1, T2) + + @usableFromInline + var upstream: PSQLRowSequence.Iterator + + @usableFromInline + let file: String + + @usableFromInline + let line: Int + + @inlinable + init(_ upstream: PSQLRowSequence.Iterator, file: String, line: Int) { + self.upstream = upstream + self.file = file + self.line = line + } + + @inlinable + public mutating func next() async throws -> Element? { + try await self.upstream.next()?.decode(T0.self, T1.self, T2.self, file: self.file, line: self.line) + } + } +} + +#endif + +struct AbstractStreamConsumer { + var consumer: Any? + + init() { + self.consumer = nil + } + + #if swift(>=5.5) && canImport(_Concurrency) + @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) + init(_ consumer: AsyncStreamConsumer) { + self.consumer = consumer + } + #endif + + func receive(_ newRows: [DataRow]) { + #if swift(>=5.5) && canImport(_Concurrency) + if #available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) { + (self.consumer as! AsyncStreamConsumer).receive(newRows) + } + #endif + } + + func receive(completion result: Result) { + #if swift(>=5.5) && canImport(_Concurrency) + if #available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) { + (self.consumer as! 
AsyncStreamConsumer).receive(completion: result) + } + #endif + } +} diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index 768255fb..df5e33a7 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -2,7 +2,6 @@ import NIOCore import Logging final class PSQLRowStream { - enum RowSource { case stream(PSQLRowsDataSource) case noRows(Result) @@ -11,48 +10,49 @@ final class PSQLRowStream { let eventLoop: EventLoop let logger: Logger - private enum UpstreamState { - case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + private enum BufferState { + case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) + case finished(buffer: CircularBuffer, commandTag: String) case failure(Error) - case consumed(Result) - case modifying } private enum DownstreamState { - case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise) - case waitingForAll(EventLoopPromise<[PSQLRow]>) - case consuming + case waitingForConsumer(BufferState) + case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) + case waitingForAll([PSQLRow], EventLoopPromise<[PSQLRow]>, PSQLRowsDataSource) + case consumed(Result) + + case asyncSequence(AbstractStreamConsumer, PSQLRowsDataSource) } - internal let rowDescription: [PSQLBackendMessage.RowDescription.Column] + internal let rowDescription: [RowDescription.Column] private let lookupTable: [String: Int] - private var upstreamState: UpstreamState private var downstreamState: DownstreamState private let jsonDecoder: PSQLJSONDecoder - init(rowDescription: [PSQLBackendMessage.RowDescription.Column], + init(rowDescription: [RowDescription.Column], queryContext: ExtendedQueryContext, eventLoop: EventLoop, rowSource: RowSource) { - let buffer = CircularBuffer() - - self.downstreamState = .consuming + let bufferState: BufferState switch rowSource { case .stream(let dataSource): - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + bufferState = .streaming(buffer: .init(), dataSource: dataSource) case .noRows(.success(let commandTag)): - self.upstreamState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), commandTag: commandTag) case .noRows(.failure(let error)): - self.upstreamState = .failure(error) + bufferState = .failure(error) } + self.downstreamState = .waitingForConsumer(bufferState) + self.eventLoop = eventLoop self.logger = queryContext.logger self.jsonDecoder = queryContext.jsonDecoder self.rowDescription = rowDescription + var lookup = [String: Int]() lookup.reserveCapacity(rowDescription.count) rowDescription.enumerated().forEach { (index, column) in @@ -60,6 +60,70 @@ final class PSQLRowStream { } self.lookupTable = lookup } + + // MARK: Async Sequence + + #if swift(>=5.5) && canImport(_Concurrency) + @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) + func asyncSequence() -> PSQLRowSequence { + self.eventLoop.preconditionInEventLoop() + + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + let consumer = AsyncStreamConsumer( + lookupTable: self.lookupTable, + columns: self.rowDescription, + jsonDecoder: self.jsonDecoder + ) + + switch bufferState { + case .streaming(let bufferedRows, let dataSource): + consumer.startStreaming(bufferedRows, upstream: self) + self.downstreamState = 
.asyncSequence(.init(consumer), dataSource) + + case .finished(let buffer, let commandTag): + consumer.startCompleted(buffer, commandTag: commandTag) + self.downstreamState = .consumed(.success(commandTag)) + + case .failure(let error): + consumer.startFailed(error) + self.downstreamState = .consumed(.failure(error)) + } + + return PSQLRowSequence(consumer) + } + #endif + + func demand() { + if self.eventLoop.inEventLoop { + self.demand0() + } else { + self.eventLoop.execute { + self.demand0() + } + } + } + + private func demand0() { + switch self.downstreamState { + case .waitingForConsumer, .iteratingRows, .waitingForAll: + preconditionFailure("Invalid state: \(self.downstreamState)") + + case .consumed: + break + + case .asyncSequence(_, let dataSource): + dataSource.request(for: self) + } + } + + func cancel() { + preconditionFailure("Unimplemented") + } + + // MARK: Consume in array func all() -> EventLoopFuture<[PSQLRow]> { if self.eventLoop.inEventLoop { @@ -74,40 +138,35 @@ final class PSQLRowStream { private func all0() -> EventLoopFuture<[PSQLRow]> { self.eventLoop.preconditionInEventLoop() - guard case .consuming = self.downstreamState else { - preconditionFailure("Invalid state") + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") } - switch self.upstreamState { - case .streaming(_, let dataSource): - dataSource.request(for: self) + switch bufferState { + case .streaming(let bufferedRows, let dataSource): let promise = self.eventLoop.makePromise(of: [PSQLRow].self) - self.downstreamState = .waitingForAll(promise) + let rows = bufferedRows.map { data in + PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + } + self.downstreamState = .waitingForAll(rows, promise, dataSource) return promise.futureResult case .finished(let buffer, let commandTag): - self.upstreamState = .modifying - let rows = buffer.map { PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) } - self.downstreamState = .consuming - self.upstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(commandTag)) return self.eventLoop.makeSucceededFuture(rows) - case .consumed: - preconditionFailure("We already signaled, that the stream has completed, why are we asked again?") - - case .modifying: - preconditionFailure("Invalid state") - case .failure(let error): - self.upstreamState = .consumed(.failure(error)) + self.downstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) } } + // MARK: Consume on EventLoop + func onRow(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { if self.eventLoop.inEventLoop { return self.onRow0(onRow) @@ -121,7 +180,11 @@ final class PSQLRowStream { private func onRow0(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture { self.eventLoop.preconditionInEventLoop() - switch self.upstreamState { + guard case .waitingForConsumer(let bufferState) = self.downstreamState else { + preconditionFailure("Invalid state: \(self.downstreamState)") + } + + switch bufferState { case .streaming(var buffer, let dataSource): let promise = self.eventLoop.makePromise(of: Void.self) do { @@ -136,12 +199,11 @@ final class PSQLRowStream { } buffer.removeAll() - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) - self.downstreamState = .iteratingRows(onRow: onRow, promise) + self.downstreamState = 
.iteratingRows(onRow: onRow, promise, dataSource) // immediately request more dataSource.request(for: self) } catch { - self.upstreamState = .failure(error) + self.downstreamState = .consumed(.failure(error)) dataSource.cancel(for: self) promise.fail(error) } @@ -160,22 +222,15 @@ final class PSQLRowStream { try onRow(row) } - self.upstreamState = .consumed(.success(commandTag)) - self.downstreamState = .consuming + self.downstreamState = .consumed(.success(commandTag)) return self.eventLoop.makeSucceededVoidFuture() } catch { - self.upstreamState = .consumed(.failure(error)) + self.downstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) } - case .consumed: - preconditionFailure("We already signaled, that the stream has completed, why are we asked again?") - - case .modifying: - preconditionFailure("Invalid state") - case .failure(let error): - self.upstreamState = .consumed(.failure(error)) + self.downstreamState = .consumed(.failure(error)) return self.eventLoop.makeFailedFuture(error) } } @@ -186,20 +241,22 @@ final class PSQLRowStream { ]) } - internal func receive(_ newRows: CircularBuffer) { + internal func receive(_ newRows: [DataRow]) { precondition(!newRows.isEmpty, "Expected to get rows!") self.eventLoop.preconditionInEventLoop() self.logger.trace("Row stream received rows", metadata: [ "row_count": "\(newRows.count)" ]) - guard case .streaming(var buffer, let dataSource) = self.upstreamState else { - preconditionFailure("Invalid state") - } - switch self.downstreamState { - case .iteratingRows(let onRow, let promise): - precondition(buffer.isEmpty) + case .waitingForConsumer(.streaming(buffer: var buffer, dataSource: let dataSource)): + buffer.append(contentsOf: newRows) + self.downstreamState = .waitingForConsumer(.streaming(buffer: buffer, dataSource: dataSource)) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can new rows be received, if an end was already signalled?") + + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { let row = PSQLRow( @@ -214,82 +271,92 @@ final class PSQLRowStream { dataSource.request(for: self) } catch { dataSource.cancel(for: self) - self.upstreamState = .failure(error) + self.downstreamState = .consumed(.failure(error)) promise.fail(error) return } - case .waitingForAll: - self.upstreamState = .modifying - buffer.append(contentsOf: newRows) - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + + case .waitingForAll(var rows, let promise, let dataSource): + newRows.forEach { data in + let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) + rows.append(row) + } + self.downstreamState = .waitingForAll(rows, promise, dataSource) - // immediately request more - dataSource.request(for: self) + case .asyncSequence(let consumer, _): + consumer.receive(newRows) - case .consuming: - // this might happen, if the query has finished while the user is consuming data - // we don't need to ask for more since the user is consuming anyway - self.upstreamState = .modifying - buffer.append(contentsOf: newRows) - self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource) + case .consumed(.success): + preconditionFailure("How can we receive further rows, if we are supposed to be done") + + case .consumed(.failure): + break } } internal func receive(completion result: Result) { self.eventLoop.preconditionInEventLoop() - guard case .streaming(let 
oldBuffer, _) = self.upstreamState else { - preconditionFailure("Invalid state") + switch result { + case .success(let commandTag): + self.receiveEnd(commandTag) + case .failure(let error): + self.receiveError(error) } + } + private func receiveEnd(_ commandTag: String) { switch self.downstreamState { - case .iteratingRows(_, let promise): - precondition(oldBuffer.isEmpty) - self.downstreamState = .consuming - self.upstreamState = .consumed(result) - switch result { - case .success: - promise.succeed(()) - case .failure(let error): - promise.fail(error) - } + case .waitingForConsumer(.streaming(buffer: let buffer, _)): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can we get another end, if an end was already signalled?") - case .consuming: - switch result { - case .success(let commandTag): - self.upstreamState = .finished(buffer: oldBuffer, commandTag: commandTag) - case .failure(let error): - self.upstreamState = .failure(error) - } - - case .waitingForAll(let promise): - switch result { - case .failure(let error): - self.upstreamState = .consumed(.failure(error)) - promise.fail(error) - case .success(let commandTag): - let rows = oldBuffer.map { - PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder) - } - self.upstreamState = .consumed(.success(commandTag)) - promise.succeed(rows) - } + case .iteratingRows(_, let promise, _): + self.downstreamState = .consumed(.success(commandTag)) + promise.succeed(()) + + case .waitingForAll(let rows, let promise, _): + self.downstreamState = .consumed(.success(commandTag)) + promise.succeed(rows) + + case .asyncSequence(let consumer, _): + consumer.receive(completion: .success(commandTag)) + self.downstreamState = .consumed(.success(commandTag)) + + case .consumed: + break } } - - func cancel() { - guard case .streaming(_, let dataSource) = self.upstreamState else { - // We don't need to cancel any upstream resource. 
All needed data is already - // included in this - return - } - dataSource.cancel(for: self) + private func receiveError(_ error: Error) { + switch self.downstreamState { + case .waitingForConsumer(.streaming): + self.downstreamState = .waitingForConsumer(.failure(error)) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + preconditionFailure("How can we get another end, if an end was already signalled?") + + case .iteratingRows(_, let promise, _): + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + + case .waitingForAll(_, let promise, _): + self.downstreamState = .consumed(.failure(error)) + promise.fail(error) + + case .asyncSequence(let consumer, _): + consumer.receive(completion: .failure(error)) + self.downstreamState = .consumed(.failure(error)) + + case .consumed: + break + } } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.upstreamState else { + guard case .consumed(.success(let commandTag)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } return commandTag diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index af3e8ee4..1f7a06d6 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -21,7 +21,7 @@ enum PSQLTask { final class ExtendedQueryContext { enum Query { case unnamed(String) - case preparedStatement(name: String, rowDescription: PSQLBackendMessage.RowDescription?) + case preparedStatement(name: String, rowDescription: RowDescription?) } let query: Query @@ -65,12 +65,12 @@ final class PrepareStatementContext { let name: String let query: String let logger: Logger - let promise: EventLoopPromise + let promise: EventLoopPromise init(name: String, query: String, logger: Logger, - promise: EventLoopPromise) + promise: EventLoopPromise) { self.name = name self.query = query diff --git a/Sources/PostgresNIO/Postgres+PSQLCompat.swift b/Sources/PostgresNIO/Postgres+PSQLCompat.swift index 7af85fd3..45993800 100644 --- a/Sources/PostgresNIO/Postgres+PSQLCompat.swift +++ b/Sources/PostgresNIO/Postgres+PSQLCompat.swift @@ -28,20 +28,20 @@ struct PostgresJSONEncoderWrapper: PSQLJSONEncoder { } extension PostgresData: PSQLEncodable { - var psqlType: PSQLDataType { + public var psqlType: PSQLDataType { PSQLDataType(Int32(self.type.rawValue)) } - var psqlFormat: PSQLFormat { + public var psqlFormat: PSQLFormat { .binary } - func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + public func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { preconditionFailure("Should never be hit, since `encodeRaw` is implemented.") } // encoding - func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { + public func encodeRaw(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) throws { switch self.value { case .none: byteBuffer.writeInteger(-1, as: Int32.self) @@ -53,7 +53,7 @@ extension PostgresData: PSQLEncodable { } extension PostgresData: PSQLDecodable { - static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> PostgresData { + public static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> PostgresData { let myBuffer = byteBuffer.readSlice(length: byteBuffer.readableBytes)! 
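        // The remaining bytes of the cell are wrapped verbatim into a `PostgresData` value; the
        // incoming `format` parameter is not inspected here and the result is always tagged with
        // the `.binary` format code.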
return PostgresData(type: PostgresDataType(UInt32(type.rawValue)), typeModifier: nil, formatCode: .binary, value: myBuffer) diff --git a/Sources/PostgresNIO/Utilities/NIOUtils.swift b/Sources/PostgresNIO/Utilities/NIOUtils.swift index 75ab8c20..544cd80a 100644 --- a/Sources/PostgresNIO/Utilities/NIOUtils.swift +++ b/Sources/PostgresNIO/Utilities/NIOUtils.swift @@ -2,7 +2,7 @@ import Foundation import NIOCore internal extension ByteBuffer { - mutating func readInteger(endianness: Endianness = .big, as rawRepresentable: E.Type) -> E? where E: RawRepresentable, E.RawValue: FixedWidthInteger { + mutating func readRawRepresentableInteger(endianness: Endianness = .big, as rawRepresentable: E.Type) -> E? where E: RawRepresentable, E.RawValue: FixedWidthInteger { guard let rawValue = readInteger(endianness: endianness, as: E.RawValue.self) else { return nil } diff --git a/Tests/IntegrationTests/PSQLConnection+AsyncTests.swift b/Tests/IntegrationTests/PSQLConnection+AsyncTests.swift new file mode 100644 index 00000000..89f80f0c --- /dev/null +++ b/Tests/IntegrationTests/PSQLConnection+AsyncTests.swift @@ -0,0 +1,383 @@ +// +// File.swift +// File +// +// Created by Fabian Fett on 19.10.21. +// + +import XCTest +import NIOCore +import NIOPosix +import Logging +@testable import PostgresNIO + +#if swift(>=5.5) && canImport(_Concurrency) +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +final class PSQLConnection_AsyncIntegrationTests: XCTestCase { + + func testConnectAndClose() { XCTAsyncTest { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let conn = try await PSQLConnection.test(on: eventLoop) + try await conn.close() + } } + + func testAuthenticationFailure() { XCTAsyncTest { + // If the postgres server trusts every connection, it is really hard to create an + // authentication failure. + try XCTSkipIf(env("POSTGRES_HOST_AUTH_METHOD") == "trust") + + let config = PSQLConnection.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: 5432, + username: env("POSTGRES_USER") ?? "postgres", + database: env("POSTGRES_DB"), + password: "wrong_password", + tlsConfiguration: nil) + + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + var logger = Logger.psqlTest + logger.logLevel = .info + + do { + let conn = try await PSQLConnection.connect(configuration: config, logger: logger, on: eventLoopGroup.next()) + XCTFail("Did not expect to create a connection here") + try await conn.close() + } catch { + XCTAssertTrue(error is PSQLError) + } + } } + + func testQueryVersion() { + PSQLConnection.withTestConnection { connection in + let rowSequence = try await connection.query("SELECT version()", logger: .psqlTest) + let rows = try await rowSequence.collect() + XCTAssertEqual(rows.count, 1) + XCTAssert(try rows.first?.decode(column: 0, as: String.self).contains("PostgreSQL") ?? false) + } + } + + func testQueryVersionDecodingError() { + PSQLConnection.withTestConnection { connection in + let rowSequence = try await connection.query("SELECT version()", logger: .psqlTest) + let rows = try await rowSequence.collect() + XCTAssertEqual(rows.count, 1) + XCTAssertThrowsError(try rows.first?.decode(Int?.self)) { + XCTAssertEqual(($0 as? PSQLCastingError)?.line, #line - 1) + XCTAssertEqual(($0 as? PSQLCastingError)?.file, #file) +// XCTAssert(($0 as? 
PSQLCastingError)?.targetType is Optional) + XCTAssertEqual(($0 as? PSQLCastingError)?.postgresType, .text) + XCTAssertEqual(($0 as? PSQLCastingError)?.columnIndex, 0) + } + } + } + + func testQuery10kItems() { + PSQLConnection.withTestConnection { connection in + let rowSequence = try await connection.query("SELECT generate_series(1, 100000);", logger: .psqlTest) + var received: Int64 = 0 + for try await row in rowSequence { + var number: Int64? + XCTAssertNoThrow(number = try row.decode(column: 0, as: Int64.self)) + received += 1 + XCTAssertEqual(number, received) + } + XCTAssertEqual(received, 100000) + } + } + + func test1kRoundTrips() { + PSQLConnection.withTestConnection { connection in + for _ in 0..<1_000 { + let rows = try await connection.query("SELECT version()", logger: .psqlTest).collect() + var version: String? + XCTAssertNoThrow(version = try rows.first?.decode(String.self)) + XCTAssertEqual(version?.contains("PostgreSQL"), true) + } + } + + } + +// func testQuerySelectParameter() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query("SELECT $1::TEXT as foo", ["hello"], logger: .psqlTest).wait()) +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) +// var foo: String? +// XCTAssertNoThrow(foo = try rows?.first?.decode(column: 0, as: String.self)) +// XCTAssertEqual(foo, "hello") +// } +// +// func testDecodeIntegers() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query(""" +// SELECT +// 1::SMALLINT as smallint, +// -32767::SMALLINT as smallint_min, +// 32767::SMALLINT as smallint_max, +// 1::INT as int, +// -2147483647::INT as int_min, +// 2147483647::INT as int_max, +// 1::BIGINT as bigint, +// -9223372036854775807::BIGINT as bigint_min, +// 9223372036854775807::BIGINT as bigint_max +// """, logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? 
+// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// let row = rows?.first +// +// XCTAssertEqual(try row?.decode(column: "smallint", as: Int16.self), 1) +// XCTAssertEqual(try row?.decode(column: "smallint_min", as: Int16.self), -32_767) +// XCTAssertEqual(try row?.decode(column: "smallint_max", as: Int16.self), 32_767) +// XCTAssertEqual(try row?.decode(column: "int", as: Int32.self), 1) +// XCTAssertEqual(try row?.decode(column: "int_min", as: Int32.self), -2_147_483_647) +// XCTAssertEqual(try row?.decode(column: "int_max", as: Int32.self), 2_147_483_647) +// XCTAssertEqual(try row?.decode(column: "bigint", as: Int64.self), 1) +// XCTAssertEqual(try row?.decode(column: "bigint_min", as: Int64.self), -9_223_372_036_854_775_807) +// XCTAssertEqual(try row?.decode(column: "bigint_max", as: Int64.self), 9_223_372_036_854_775_807) +// } +// +// func testEncodeAndDecodeIntArray() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// let array: [Int64] = [1, 2, 3] +// XCTAssertNoThrow(stream = try conn?.query("SELECT $1::int8[] as array", [array], logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// XCTAssertEqual(try rows?.first?.decode(column: "array", as: [Int64].self), array) +// } +// +// func testDecodeEmptyIntegerArray() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query("SELECT '{}'::int[] as array", logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// XCTAssertEqual(try rows?.first?.decode(column: "array", as: [Int64].self), []) +// } +// +// func testDoubleArraySerialization() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// let doubles: [Double] = [3.14, 42] +// XCTAssertNoThrow(stream = try conn?.query("SELECT $1::double precision[] as doubles", [doubles], logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// XCTAssertEqual(try rows?.first?.decode(column: "doubles", as: [Double].self), doubles) +// } +// +// func testDecodeDates() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? 
+// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query(""" +// SELECT +// '2016-01-18 01:02:03 +0042'::DATE as date, +// '2016-01-18 01:02:03 +0042'::TIMESTAMP as timestamp, +// '2016-01-18 01:02:03 +0042'::TIMESTAMPTZ as timestamptz +// """, logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// let row = rows?.first +// +// XCTAssertEqual(try row?.decode(column: "date", as: Date.self).description, "2016-01-18 00:00:00 +0000") +// XCTAssertEqual(try row?.decode(column: "timestamp", as: Date.self).description, "2016-01-18 01:02:03 +0000") +// XCTAssertEqual(try row?.decode(column: "timestamptz", as: Date.self).description, "2016-01-18 00:20:03 +0000") +// } +// +// func testDecodeUUID() { +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query(""" +// SELECT '2c68f645-9ca6-468b-b193-ee97f241c2f8'::UUID as uuid +// """, logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// +// XCTAssertEqual(try rows?.first?.decode(column: "uuid", as: UUID.self), UUID(uuidString: "2c68f645-9ca6-468b-b193-ee97f241c2f8")) +// } +// +// func testRoundTripJSONB() { +// struct Object: Codable, PSQLCodable { +// let foo: Int +// let bar: Int +// } +// +// let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) +// defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } +// let eventLoop = eventLoopGroup.next() +// +// var conn: PSQLConnection? +// XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait()) +// defer { XCTAssertNoThrow(try conn?.close().wait()) } +// +// do { +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query(""" +// select $1::jsonb as jsonb +// """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// var result: Object? +// XCTAssertNoThrow(result = try rows?.first?.decode(column: "jsonb", as: Object.self)) +// XCTAssertEqual(result?.foo, 1) +// XCTAssertEqual(result?.bar, 2) +// } +// +// do { +// var stream: PSQLRowStream? +// XCTAssertNoThrow(stream = try conn?.query(""" +// select $1::json as json +// """, [Object(foo: 1, bar: 2)], logger: .psqlTest).wait()) +// +// var rows: [PSQLRow]? +// XCTAssertNoThrow(rows = try stream?.all().wait()) +// XCTAssertEqual(rows?.count, 1) +// var result: Object? +// XCTAssertNoThrow(result = try rows?.first?.decode(column: "json", as: Object.self)) +// XCTAssertEqual(result?.foo, 1) +// XCTAssertEqual(result?.bar, 2) +// } +// } +} + + +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +extension PSQLConnection { + + static func withTestConnection(_ closure: @escaping (PSQLConnection) async throws -> ()) { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { try! 
eventLoopGroup.syncShutdownGracefully() } + let eventLoop = eventLoopGroup.next() + + let dispatchGroup = DispatchGroup() + dispatchGroup.enter() + Task { + defer { + dispatchGroup.leave() + } + + let conn = try await PSQLConnection.test(on: eventLoop) + defer { Task { try await conn.close() } } + + do { + try await closure(conn) + } catch { + XCTFail("\(error)") + } + } + + dispatchGroup.wait() + } + + static func test(on eventLoop: EventLoop, logLevel: Logger.Level = .info) async throws -> PSQLConnection { + var logger = Logger(label: "psql.connection.test") + logger.logLevel = logLevel + let config = PSQLConnection.Configuration( + host: env("POSTGRES_HOSTNAME") ?? "localhost", + port: 5432, + username: env("POSTGRES_USER") ?? "postgres", + database: env("POSTGRES_DB"), + password: env("POSTGRES_PASSWORD"), + tlsConfiguration: nil) + + return try await PSQLConnection.connect(configuration: config, logger: logger, on: eventLoop) + } + +} +#endif + +#if swift(>=5.5) && canImport(_Concurrency) +// NOTE: workaround until we have async test support on linux +// https://github.com/apple/swift-corelibs-xctest/pull/326 +extension XCTestCase { + @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) + func XCTAsyncTest( + expectationDescription: String = "Async operation", + timeout: TimeInterval = 3, + file: StaticString = #file, + line: Int = #line, + operation: @escaping () async throws -> Void + ) { + let expectation = self.expectation(description: expectationDescription) + Task { + do { try await operation() } + catch { + XCTFail("Error thrown while executing async function @ \(file):\(line): \(error)") + Thread.callStackSymbols.forEach { print($0) } + } + expectation.fulfill() + } + self.wait(for: [expectation], timeout: timeout) + } +} +#endif + diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index dabe9f1c..5cb1d686 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -106,7 +106,7 @@ final class IntegrationTests: XCTestCase { var rows: [PSQLRow]? XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var version: String? - XCTAssertNoThrow(version = try rows?.first?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(version = try rows?.first?.decode(String.self)) XCTAssertEqual(version?.contains("PostgreSQL"), true) } } @@ -125,7 +125,7 @@ final class IntegrationTests: XCTestCase { var rows: [PSQLRow]? XCTAssertNoThrow(rows = try XCTUnwrap(stream).all().wait()) var foo: String? - XCTAssertNoThrow(foo = try rows?.first?.decode(column: 0, as: String.self)) + XCTAssertNoThrow(foo = try rows?.first?.decode(String.self)) XCTAssertEqual(foo, "hello") } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index e1076a6e..39360645 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -40,25 +40,25 @@ class ExtendedQueryStateMachineTests: XCTestCase { // We need to ensure that even though the row description from the wire says that we // will receive data in `.text` format, we will actually receive it in binary format, // since we requested it in binary with our bind message. 
- let input: [PSQLBackendMessage.RowDescription.Column] = [ + let input: [RowDescription.Column] = [ .init(name: "version", tableOID: 0, columnAttributeNumber: 0, dataType: .text, dataTypeSize: -1, dataTypeModifier: -1, format: .text) ] - let expected: [PSQLBackendMessage.RowDescription.Column] = input.map { + let expected: [RowDescription.Column] = input.map { .init(name: $0.name, tableOID: $0.tableOID, columnAttributeNumber: $0.columnAttributeNumber, dataType: $0.dataType, dataTypeSize: $0.dataTypeSize, dataTypeModifier: $0.dataTypeModifier, format: .binary) } XCTAssertEqual(state.rowDescriptionReceived(.init(columns: input)), .wait) XCTAssertEqual(state.bindCompleteReceived(), .succeedQuery(queryContext, columns: expected)) - let row1: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test1")] + let row1: DataRow = [ByteBuffer(string: "test1")] XCTAssertEqual(state.dataRowReceived(row1), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardRows([row1])) XCTAssertEqual(state.readEventCaught(), .wait) XCTAssertEqual(state.requestQueryRows(), .read) - let row2: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test2")] - let row3: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test3")] - let row4: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test4")] + let row2: DataRow = [ByteBuffer(string: "test2")] + let row3: DataRow = [ByteBuffer(string: "test3")] + let row4: DataRow = [ByteBuffer(string: "test4")] XCTAssertEqual(state.dataRowReceived(row2), .wait) XCTAssertEqual(state.dataRowReceived(row3), .wait) XCTAssertEqual(state.dataRowReceived(row4), .wait) @@ -69,8 +69,8 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.readEventCaught(), .read) - let row5: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test5")] - let row6: PSQLBackendMessage.DataRow = [ByteBuffer(string: "test6")] + let row5: DataRow = [ByteBuffer(string: "test5")] + let row6: DataRow = [ByteBuffer(string: "test6")] XCTAssertEqual(state.dataRowReceived(row5), .wait) XCTAssertEqual(state.dataRowReceived(row6), .wait) diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift index 9b88af9a..6cff280e 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PrepareStatementStateMachineTests.swift @@ -7,7 +7,7 @@ class PrepareStatementStateMachineTests: XCTestCase { func testCreatePreparedStatementReturningRowDescription() { var state = ConnectionStateMachine.readyForQuery() - let promise = EmbeddedEventLoop().makePromise(of: PSQLBackendMessage.RowDescription?.self) + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. 
let name = "haha" @@ -20,7 +20,7 @@ class PrepareStatementStateMachineTests: XCTestCase { XCTAssertEqual(state.parseCompleteReceived(), .wait) XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) - let columns: [PSQLBackendMessage.RowDescription.Column] = [ + let columns: [RowDescription.Column] = [ .init(name: "id", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: -1, format: .binary) ] @@ -32,7 +32,7 @@ class PrepareStatementStateMachineTests: XCTestCase { func testCreatePreparedStatementReturningNoData() { var state = ConnectionStateMachine.readyForQuery() - let promise = EmbeddedEventLoop().makePromise(of: PSQLBackendMessage.RowDescription?.self) + let promise = EmbeddedEventLoop().makePromise(of: RowDescription?.self) promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. let name = "haha" diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 304bb7d6..10cef9cb 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -42,7 +42,7 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual((error as? PSQLCastingError)?.file, #file) XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, loopBuffer) + XCTAssertEqual((error as? PSQLCastingError)?.cellData, loopBuffer) } } } @@ -57,7 +57,7 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual((error as? PSQLCastingError)?.file, #file) XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, nil) + XCTAssertEqual((error as? PSQLCastingError)?.cellData, nil) } } } @@ -84,7 +84,7 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual((error as? PSQLCastingError)?.file, #file) XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, buffer) + XCTAssertEqual((error as? PSQLCastingError)?.cellData, buffer) } } } diff --git a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift index 8b1be81e..adf8a516 100644 --- a/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/UUID+PSQLCodableTests.swift @@ -87,7 +87,7 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual((error as? PSQLCastingError)?.file, #file) XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, buffer) + XCTAssertEqual((error as? PSQLCastingError)?.cellData, buffer) } } @@ -107,7 +107,7 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual((error as? PSQLCastingError)?.file, #file) XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, loopBuffer) + XCTAssertEqual((error as? PSQLCastingError)?.cellData, loopBuffer) } } } @@ -127,7 +127,7 @@ class UUID_PSQLCodableTests: XCTestCase { XCTAssertEqual((error as? PSQLCastingError)?.file, #file) XCTAssertEqual((error as? PSQLCastingError)?.columnIndex, 0) - XCTAssertEqual((error as? PSQLCastingError)?.postgresData, data.bytes) + XCTAssertEqual((error as? 
PSQLCastingError)?.cellData, data.bytes) } } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift index 8434e761..3613c40b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessage+Equatable.swift @@ -1,4 +1,5 @@ @testable import PostgresNIO +import class Foundation.JSONEncoder extension PSQLBackendMessage: Equatable { @@ -48,10 +49,17 @@ extension PSQLBackendMessage: Equatable { } } -extension PSQLBackendMessage.DataRow: ExpressibleByArrayLiteral { - public typealias ArrayLiteralElement = ByteBuffer +extension DataRow: ExpressibleByArrayLiteral { + public typealias ArrayLiteralElement = PSQLEncodable - public init(arrayLiteral elements: ByteBuffer...) { - self.init(columns: elements) + public init(arrayLiteral elements: PSQLEncodable...) { + + var buffer = ByteBuffer() + let encodingContext = PSQLEncodingContext(jsonEncoder: JSONEncoder()) + elements.forEach { element in + try! element.encodeRaw(into: &buffer, context: encodingContext) + } + + self.init(columnCount: Int16(elements.count), bytes: buffer) } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index ea5323ec..3d9ede05 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -188,19 +188,10 @@ extension PSQLBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.DataRow: PSQLMessagePayloadEncodable { +extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { - buffer.writeInteger(Int16(self.columns.count)) - - for column in self.columns { - switch column { - case .none: - buffer.writeInteger(-1, as: Int32.self) - case .some(var writable): - buffer.writeInteger(Int32(writable.readableBytes)) - buffer.writeBuffer(&writable) - } - } + buffer.writeInteger(Int16(self.columnCount)) + buffer.writeBytes(self.bytes.readableBytesView) } } @@ -255,7 +246,7 @@ extension PSQLBackendMessage.TransactionState: PSQLMessagePayloadEncodable { } } -extension PSQLBackendMessage.RowDescription: PSQLMessagePayloadEncodable { +extension RowDescription: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(Int16(self.columns.count)) diff --git a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift index af9ee3f2..a5c33030 100644 --- a/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/DataRowTests.swift @@ -20,16 +20,12 @@ class DataRowTests: XCTestCase { buffer.writeBytes([UInt8](repeating: 5, count: 10)) } - let expectedColumns: [ByteBuffer?] = [ - nil, - ByteBuffer(), - ByteBuffer(bytes: [UInt8](repeating: 5, count: 10)) - ] - + let rowSlice = buffer.getSlice(at: 7, length: buffer.readableBytes - 7)! 
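+        // Offset 7 should skip the DataRow frame prefix (1-byte message ID, 4-byte length,
+        // 2-byte column count), so the slice holds only the encoded cells that back the DataRow.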
+ let expectedInOuts = [ - (buffer, [PSQLBackendMessage.dataRow(.init(columns: expectedColumns))]), + (buffer, [PSQLBackendMessage.dataRow(.init(columnCount: 3, bytes: rowSlice))]), ] - + XCTAssertNoThrow(try ByteToMessageDecoderVerifier.verifyDecoder( inputOutputPairs: expectedInOuts, decoderFactory: { PSQLBackendMessageDecoder(hasAlreadyReceivedBytes: false) })) diff --git a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift index 4452ebce..8eba059d 100644 --- a/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift +++ b/Tests/PostgresNIOTests/New/Messages/RowDescriptionTests.swift @@ -6,7 +6,7 @@ import NIOTestUtils class RowDescriptionTests: XCTestCase { func testDecode() { - let columns: [PSQLBackendMessage.RowDescription.Column] = [ + let columns: [RowDescription.Column] = [ .init(name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary), .init(name: "Second", tableOID: 123, columnAttributeNumber: 456, dataType: .uuidArray, dataTypeSize: 567, dataTypeModifier: 123, format: .text), ] @@ -42,7 +42,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseOfMissingNullTerminationInColumnName() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() @@ -65,7 +65,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseOfMissingColumnCount() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() @@ -87,7 +87,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseInvalidFormatCode() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() @@ -110,7 +110,7 @@ class RowDescriptionTests: XCTestCase { } func testDecodeFailureBecauseNegativeColumnCount() { - let column = PSQLBackendMessage.RowDescription.Column( + let column = RowDescription.Column( name: "First", tableOID: 123, columnAttributeNumber: 123, dataType: .bool, dataTypeSize: 2, dataTypeModifier: 8, format: .binary) var buffer = ByteBuffer() diff --git a/Tests/PostgresNIOTests/New/PSQLDataTests.swift b/Tests/PostgresNIOTests/New/PSQLDataTests.swift index c76b8d07..b4781cd7 100644 --- a/Tests/PostgresNIOTests/New/PSQLDataTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLDataTests.swift @@ -9,9 +9,6 @@ class PSQLDataTests: XCTestCase { let data = PSQLData(bytes: emptyBuffer, dataType: .text, format: .binary) var emptyResult: String? 
- XCTAssertNoThrow(emptyResult = try data.decodeIfPresent(as: String.self, context: .forTests())) - XCTAssertNil(emptyResult) - XCTAssertNoThrow(emptyResult = try data.decode(as: String?.self, context: .forTests())) XCTAssertNil(emptyResult) } diff --git a/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift new file mode 100644 index 00000000..e3e182cc --- /dev/null +++ b/Tests/PostgresNIOTests/New/PSQLRowSequenceTests.swift @@ -0,0 +1,177 @@ +import NIOCore +import NIOConcurrencyHelpers +import XCTest +import Logging +@testable import PostgresNIO + +#if swift(>=5.5) && canImport(_Concurrency) +@available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) +class PSQLRowSequenceTests: XCTestCase { + func testSimpleSelect() { XCTAsyncTest { + let embedded = EmbeddedEventLoop() + let rowDescription: [RowDescription.Column] = [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ] + let logger = Logger(label: "test") + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM foo", + bind: [], + logger: logger, + jsonDecoder: JSONDecoder(), + promise: embedded.makePromise(of: PSQLRowStream.self) + ) + let dataSource = CountDataSource() + let stream = PSQLRowStream( + rowDescription: rowDescription, + queryContext: queryContext, + eventLoop: embedded, + rowSource: .stream(dataSource) + ) + queryContext.promise.succeed(stream) + + let row1: DataRow = [ByteBuffer(integer: 0)] + stream.receive([row1]) + stream.receive(completion: .success("SELECT 1")) + let sequence = stream.asyncSequence() + + for try await row in sequence { + print("\(row)") + } + } } + + func testBackpressure() { XCTAsyncTest { + let embedded = EmbeddedEventLoop() + let rowDescription: [RowDescription.Column] = [ + .init(name: "test", tableOID: 0, columnAttributeNumber: 0, dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary) + ] + let logger = Logger(label: "test") + let queryContext = ExtendedQueryContext( + query: "SELECT * FROM foo", + bind: [], + logger: logger, + jsonDecoder: JSONDecoder(), + promise: embedded.makePromise(of: PSQLRowStream.self) + ) + let dataSource = BlockingDataSource() + let stream = PSQLRowStream( + rowDescription: rowDescription, + queryContext: queryContext, + eventLoop: embedded, + rowSource: .stream(dataSource) + ) + queryContext.promise.succeed(stream) + + @Sendable func workaround() { + // for the first rows the consumer doesn't signal demand + let row1Data: DataRow = [Int(0)] + stream.receive([row1Data]) + + for i in 1..<1000 { + XCTAssertNoThrow(try dataSource.waitForDemand(deadline: .now() + .seconds(10))) + + let rowData: DataRow = [Int(i)] + stream.receive([rowData]) + } + + // After 1000 rows, send end! 
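+            // The completion is only delivered after the consumer has signalled demand once more,
+            // so the end of the stream also goes through the demand/backpressure path.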
+ XCTAssertNoThrow(try dataSource.waitForDemand(deadline: .now() + .seconds(10))) + stream.receive(completion: .success("SELECT 1")) + } + + DispatchQueue(label: "source").async { workaround() } + + var consumed = 0 + for try await int in stream.asyncSequence().decode(Int.self) { + XCTAssertEqual(int, consumed) + consumed += 1 + XCTAssertEqual(dataSource.demandCounter, consumed) + } + } } +} + +final class CountDataSource: PSQLRowsDataSource { + + var hitRequestCounter: Int { + self._hitRequestCounter.load() + } + + var hitCancelCounter: Int { + self._hitCancelCounter.load() + } + + private let _hitRequestCounter = NIOAtomic.makeAtomic(value: 0) + private let _hitCancelCounter = NIOAtomic.makeAtomic(value: 0) + + init() {} + + func request(for stream: PSQLRowStream) { + self._hitRequestCounter.add(1) + } + + func cancel(for stream: PSQLRowStream) { + self._hitCancelCounter.add(1) + } +} + +final class BlockingDataSource: PSQLRowsDataSource { + + struct TimeoutError: Error {} + + private let demandLock = ConditionLock(value: false) + private var _demandCounter = 0 + + var demandCounter: Int { + self.demandLock.lock() + defer { self.demandLock.unlock() } + return self._demandCounter + } + + init() {} + + func request(for stream: PSQLRowStream) { + self.demandLock.lock() + self._demandCounter += 1 + self.demandLock.unlock(withValue: true) + } + + func waitForDemand(deadline: NIODeadline) throws { + let secondsUntilDeath = deadline - NIODeadline.now() + + guard self.demandLock.lock(whenValue: true, timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000)) else { + throw TimeoutError() + } + self.demandLock.unlock(withValue: false) + } + + func cancel(for stream: PSQLRowStream) { + preconditionFailure() + } +} +#endif + +#if swift(>=5.5) && canImport(_Concurrency) +// NOTE: workaround until we have async test support on linux +// https://github.com/apple/swift-corelibs-xctest/pull/326 +extension XCTestCase { + @available(macOS 12, iOS 15, tvOS 15, watchOS 8, *) + func XCTAsyncTest( + expectationDescription: String = "Async operation", + timeout: TimeInterval = 3, + file: StaticString = #file, + line: Int = #line, + operation: @escaping () async throws -> Void + ) { + let expectation = self.expectation(description: expectationDescription) + Task { + do { try await operation() } + catch { + XCTFail("Error thrown while executing async function @ \(file):\(line): \(error)") + Thread.callStackSymbols.forEach { print($0) } + } + expectation.fulfill() + } + self.wait(for: [expectation], timeout: timeout) + } +} +#endif + From 856fa9e2956c1d91243961817e96f15254edaab4 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 15 Dec 2021 09:59:13 +0100 Subject: [PATCH 2/2] Remove `writeNullTerminatedString`, `readNullTerminatedString` --- .../New/Extensions/ByteBuffer+PSQL.swift | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift index f364f007..35664180 100644 --- a/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift +++ b/Sources/PostgresNIO/New/Extensions/ByteBuffer+PSQL.swift @@ -1,20 +1,6 @@ import NIOCore internal extension ByteBuffer { - mutating func writeNullTerminatedString(_ string: String) { - self.writeString(string) - self.writeInteger(0, as: UInt8.self) - } - - mutating func readNullTerminatedString() -> String? 
{ - guard let nullIndex = readableBytesView.firstIndex(of: 0) else { - return nil - } - - defer { moveReaderIndex(forwardBy: 1) } - return readString(length: nullIndex - readerIndex) - } - mutating func writeBackendMessageID(_ messageID: PSQLBackendMessage.ID) { self.writeInteger(messageID.rawValue) }