Skip to content

Commit

Permalink
Fix prepared statements
Browse files Browse the repository at this point in the history
  • Loading branch information
fabianfett committed Feb 21, 2024
1 parent 85d189c commit 738d0b7
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 34 deletions.
9 changes: 6 additions & 3 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ public final class PostgresConnection: @unchecked Sendable {
let context = ExtendedQueryContext(
name: name,
query: query,
bindingDataTypes: [],
logger: logger,
promise: promise
)
Expand Down Expand Up @@ -472,9 +473,10 @@ extension PostgresConnection {
let bindings = try preparedStatement.makeBindings()
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let task = HandlerTask.executePreparedStatement(.init(
name: String(reflecting: Statement.self),
name: Statement.name,
sql: Statement.sql,
bindings: bindings,
bindingDataTypes: Statement.bindingDataTypes,
logger: logger,
promise: promise
))
Expand All @@ -493,10 +495,10 @@ extension PostgresConnection {
)
throw error // rethrow with more metadata
}

}

/// Execute a prepared statement, taking care of the preparation when necessary
@_disfavoredOverload
public func execute<Statement: PostgresPreparedStatement>(
_ preparedStatement: Statement,
logger: Logger,
Expand All @@ -506,9 +508,10 @@ extension PostgresConnection {
let bindings = try preparedStatement.makeBindings()
let promise = self.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let task = HandlerTask.executePreparedStatement(.init(
name: String(reflecting: Statement.self),
name: Statement.name,
sql: Statement.sql,
bindings: bindings,
bindingDataTypes: Statement.bindingDataTypes,
logger: logger,
promise: promise
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct ConnectionStateMachine {
case forwardStreamError(PSQLError, read: Bool, cleanupContext: CleanUpContext?)

// Prepare statement actions
case sendParseDescribeSync(name: String, query: String)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case succeedPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: RowDescription?)
case failPreparedStatementCreation(EventLoopPromise<RowDescription?>, with: PSQLError, cleanupContext: CleanUpContext?)

Expand Down Expand Up @@ -587,7 +587,7 @@ struct ConnectionStateMachine {
switch queryContext.query {
case .executeStatement(_, let promise), .unnamed(_, let promise):
return .failQuery(promise, with: psqlErrror, cleanupContext: nil)
case .prepareStatement(_, _, let promise):
case .prepareStatement(_, _, _, let promise):
return .failPreparedStatementCreation(promise, with: psqlErrror, cleanupContext: nil)
}
case .closeCommand(let closeContext):
Expand Down Expand Up @@ -1057,8 +1057,8 @@ extension ConnectionStateMachine {
return .read
case .wait:
return .wait
case .sendParseDescribeSync(name: let name, query: let query):
return .sendParseDescribeSync(name: name, query: query)
case .sendParseDescribeSync(name: let name, query: let query, bindingDataTypes: let bindingDataTypes):
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
case .succeedPreparedStatementCreation(let promise, with: let rowDescription):
return .succeedPreparedStatementCreation(promise, with: rowDescription)
case .failPreparedStatementCreation(let promise, with: let error):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ struct ExtendedQueryStateMachine {

enum Action {
case sendParseDescribeBindExecuteSync(PostgresQuery)
case sendParseDescribeSync(name: String, query: String)
case sendParseDescribeSync(name: String, query: String, bindingDataTypes: [PostgresDataType])
case sendBindExecuteSync(PSQLExecuteStatement)

// --- general actions
Expand Down Expand Up @@ -79,10 +79,10 @@ struct ExtendedQueryStateMachine {
return .sendBindExecuteSync(prepared)
}

case .prepareStatement(let name, let query, _):
case .prepareStatement(let name, let query, let bindingDataTypes, _):
return self.avoidingStateMachineCoW { state -> Action in
state = .messagesSent(queryContext)
return .sendParseDescribeSync(name: name, query: query)
return .sendParseDescribeSync(name: name, query: query, bindingDataTypes: bindingDataTypes)
}
}
}
Expand All @@ -107,7 +107,7 @@ struct ExtendedQueryStateMachine {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
return .failQuery(eventLoopPromise, with: .queryCancelled)

case .prepareStatement(_, _, let eventLoopPromise):
case .prepareStatement(_, _, _, let eventLoopPromise):
return .failPreparedStatementCreation(eventLoopPromise, with: .queryCancelled)
}

Expand Down Expand Up @@ -165,7 +165,7 @@ struct ExtendedQueryStateMachine {
return .wait
}

case .prepareStatement(_, _, let promise):
case .prepareStatement(_, _, _, let promise):
return self.avoidingStateMachineCoW { state -> Action in
state = .noDataMessageReceived(queryContext)
return .succeedPreparedStatementCreation(promise, with: nil)
Expand Down Expand Up @@ -200,7 +200,7 @@ struct ExtendedQueryStateMachine {
case .unnamed, .executeStatement:
return .wait

case .prepareStatement(_, _, let eventLoopPromise):
case .prepareStatement(_, _, _, let eventLoopPromise):
return .succeedPreparedStatementCreation(eventLoopPromise, with: rowDescription)
}
}
Expand Down Expand Up @@ -477,7 +477,7 @@ struct ExtendedQueryStateMachine {
switch context.query {
case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise):
return .failQuery(eventLoopPromise, with: error)
case .prepareStatement(_, _, let eventLoopPromise):
case .prepareStatement(_, _, _, let eventLoopPromise):
return .failPreparedStatementCreation(eventLoopPromise, with: error)
}
}
Expand Down
14 changes: 11 additions & 3 deletions Sources/PostgresNIO/New/PSQLTask.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ enum PSQLTask {
eventLoopPromise.fail(error)
case .executeStatement(_, let eventLoopPromise):
eventLoopPromise.fail(error)
case .prepareStatement(_, _, let eventLoopPromise):
case .prepareStatement(_, _, _, let eventLoopPromise):
eventLoopPromise.fail(error)
}

Expand All @@ -35,7 +35,7 @@ final class ExtendedQueryContext {
enum Query {
case unnamed(PostgresQuery, EventLoopPromise<PSQLRowStream>)
case executeStatement(PSQLExecuteStatement, EventLoopPromise<PSQLRowStream>)
case prepareStatement(name: String, query: String, EventLoopPromise<RowDescription?>)
case prepareStatement(name: String, query: String, bindingDataTypes: [PostgresDataType], EventLoopPromise<RowDescription?>)
}

let query: Query
Expand All @@ -62,17 +62,19 @@ final class ExtendedQueryContext {
init(
name: String,
query: String,
bindingDataTypes: [PostgresDataType],
logger: Logger,
promise: EventLoopPromise<RowDescription?>
) {
self.query = .prepareStatement(name: name, query: query, promise)
self.query = .prepareStatement(name: name, query: query, bindingDataTypes: bindingDataTypes, promise)
self.logger = logger
}
}

final class PreparedStatementContext: Sendable {
let name: String
let sql: String
let bindingDataTypes: [PostgresDataType]
let bindings: PostgresBindings
let logger: Logger
let promise: EventLoopPromise<PSQLRowStream>
Expand All @@ -81,12 +83,18 @@ final class PreparedStatementContext: Sendable {
name: String,
sql: String,
bindings: PostgresBindings,
bindingDataTypes: [PostgresDataType],
logger: Logger,
promise: EventLoopPromise<PSQLRowStream>
) {
self.name = name
self.sql = sql
self.bindings = bindings
if bindingDataTypes.isEmpty {
self.bindingDataTypes = bindings.metadata.map(\.dataType)
} else {
self.bindingDataTypes = bindingDataTypes
}
self.logger = logger
self.promise = promise
}
Expand Down
10 changes: 6 additions & 4 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
self.closeConnectionAndCleanup(cleanupContext, context: context)
case .fireChannelInactive:
context.fireChannelInactive()
case .sendParseDescribeSync(let name, let query):
self.sendParseDecribeAndSyncMessage(statementName: name, query: query, context: context)
case .sendParseDescribeSync(let name, let query, let bindingDataTypes):
self.sendParseDescribeAndSyncMessage(statementName: name, query: query, bindingDataTypes: bindingDataTypes, context: context)
case .sendBindExecuteSync(let executeStatement):
self.sendBindExecuteAndSyncMessage(executeStatement: executeStatement, context: context)
case .sendParseDescribeBindExecuteSync(let query):
Expand Down Expand Up @@ -489,13 +489,14 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
}
}

private func sendParseDecribeAndSyncMessage(
private func sendParseDescribeAndSyncMessage(
statementName: String,
query: String,
bindingDataTypes: [PostgresDataType],
context: ChannelHandlerContext
) {
precondition(self.rowStream == nil, "Expected to not have an open stream at this point")
self.encoder.parse(preparedStatementName: statementName, query: query, parameters: [])
self.encoder.parse(preparedStatementName: statementName, query: query, parameters: bindingDataTypes)
self.encoder.describePreparedStatement(statementName)
self.encoder.sync()
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
Expand Down Expand Up @@ -724,6 +725,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
return .extendedQuery(.init(
name: preparedStatement.name,
query: preparedStatement.sql,
bindingDataTypes: preparedStatement.bindingDataTypes,
logger: preparedStatement.logger,
promise: promise
))
Expand Down
23 changes: 22 additions & 1 deletion Sources/PostgresNIO/New/PreparedStatement.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,36 @@
/// Structs conforming to this protocol can then be used with `PostgresConnection.execute(_ preparedStatement:, logger:)`,
/// which will take care of preparing the statement on the server side and executing it.
public protocol PostgresPreparedStatement: Sendable {
/// The prepared statements name.
///
/// > Note: There is a default implementation that returns the implementor's name.
static var name: String { get }

/// The type rows returned by the statement will be decoded into
associatedtype Row

/// The SQL statement to prepare on the database server.
static var sql: String { get }

/// Make the bindings to provided concrete values to use when executing the prepared SQL statement
/// The postgres data types of the values that are bind when this statement is executed.
///
/// If an empty array is returned the datatypes are inferred from the ``PostgresBindings`` returned
/// from ``PostgresPreparedStatement/makeBindings()``.
///
/// > Note: There is a default implementation that returns an empty array, which will lead to
/// automatic inference.
static var bindingDataTypes: [PostgresDataType] { get }

/// Make the bindings to provided concrete values to use when executing the prepared SQL statement.
/// The order must match ``PostgresPreparedStatement/bindingDataTypes-4b6tx``.
func makeBindings() throws -> PostgresBindings

/// Decode a row returned by the database into an instance of `Row`
func decodeRow(_ row: PostgresRow) throws -> Row
}

extension PostgresPreparedStatement {
public static var name: String { String(reflecting: self) }

public static var bindingDataTypes: [PostgresDataType] { [] }
}
81 changes: 81 additions & 0 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,87 @@ final class AsyncPostgresConnectionTests: XCTestCase {
}
}
}

static let preparedStatementTestTable = "AsyncTestPreparedStatementTestTable"
func testPreparedStatementWithIntegerBinding() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

struct InsertPreparedStatement: PostgresPreparedStatement {
static let name = "INSERT-AsyncTestPreparedStatementTestTable"

static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" (uuid) VALUES ($1);"#
typealias Row = ()

var uuid: UUID

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.uuid)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
()
}
}

struct SelectPreparedStatement: PostgresPreparedStatement {
static let name = "SELECT-AsyncTestPreparedStatementTestTable"

static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementTestTable)" WHERE id <= $1;"#
typealias Row = (Int, UUID)

var id: Int

func makeBindings() -> PostgresBindings {
var bindings = PostgresBindings()
bindings.append(self.id)
return bindings
}

func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row {
try row.decode((Int, UUID).self)
}
}

do {
try await withTestConnection(on: eventLoop) { connection in
try await connection.query("""
CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementTestTable)" (
id SERIAL PRIMARY KEY,
uuid UUID NOT NULL
)
""",
logger: .psqlTest
)

_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)
_ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest)

let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest)
var counter = 0
for try await (id, uuid) in rows {
Logger.psqlTest.info("Received row", metadata: [
"id": "\(id)", "uuid": "\(uuid)"
])
counter += 1
}

try await connection.query("""
DROP TABLE "\(unescaped: Self.preparedStatementTestTable)";
""",
logger: .psqlTest
)
}
} catch {
XCTFail("Unexpected error: \(String(describing: error))")
}
}
}

extension XCTestCase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ class PrepareStatementStateMachineTests: XCTestCase {
let name = "haha"
let query = #"SELECT id FROM users WHERE id = $1 "#
let prepareStatementContext = ExtendedQueryContext(
name: name, query: query, logger: .psqlTest, promise: promise
name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise
)

XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)),
.sendParseDescribeSync(name: name, query: query))
.sendParseDescribeSync(name: name, query: query, bindingDataTypes: []))
XCTAssertEqual(state.parseCompleteReceived(), .wait)
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)

Expand All @@ -38,11 +38,11 @@ class PrepareStatementStateMachineTests: XCTestCase {
let name = "haha"
let query = #"DELETE FROM users WHERE id = $1 "#
let prepareStatementContext = ExtendedQueryContext(
name: name, query: query, logger: .psqlTest, promise: promise
name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise
)

XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)),
.sendParseDescribeSync(name: name, query: query))
.sendParseDescribeSync(name: name, query: query, bindingDataTypes: []))
XCTAssertEqual(state.parseCompleteReceived(), .wait)
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)

Expand All @@ -60,11 +60,11 @@ class PrepareStatementStateMachineTests: XCTestCase {
let name = "haha"
let query = #"DELETE FROM users WHERE id = $1 "#
let prepareStatementContext = ExtendedQueryContext(
name: name, query: query, logger: .psqlTest, promise: promise
name: name, query: query, bindingDataTypes: [], logger: .psqlTest, promise: promise
)

XCTAssertEqual(state.enqueue(task: .extendedQuery(prepareStatementContext)),
.sendParseDescribeSync(name: name, query: query))
.sendParseDescribeSync(name: name, query: query, bindingDataTypes: []))
XCTAssertEqual(state.parseCompleteReceived(), .wait)
XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait)

Expand Down
Loading

0 comments on commit 738d0b7

Please sign in to comment.