Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,34 @@ extension PostgresConnection {
}
}

#if swift(>=5.5) && canImport(_Concurrency)
extension PostgresConnection {
func close() async throws {
try await self.close().get()
}

func query(_ query: PostgresQuery, logger: Logger, file: String = #file, line: UInt = #line) async throws -> PostgresRowSequence {
var logger = logger
logger[postgresMetadataKey: .connectionID] = "\(self.underlying.connectionID)"

do {
guard query.binds.count <= Int(Int16.max) else {
throw PSQLError.tooManyParameters
}
let promise = self.underlying.channel.eventLoop.makePromise(of: PSQLRowStream.self)
let context = ExtendedQueryContext(
query: query,
logger: logger,
promise: promise)

self.underlying.channel.write(PSQLTask.extendedQuery(context), promise: nil)

return try await promise.futureResult.map({ $0.asyncSequence() }).get()
}
}
}
#endif

// MARK: PostgresDatabase

extension PostgresConnection: PostgresDatabase {
Expand Down
67 changes: 67 additions & 0 deletions Tests/IntegrationTests/AsyncTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import Logging
import XCTest
@testable import PostgresNIO

#if swift(>=5.5.2)
final class AsyncPostgresConnectionTests: XCTestCase {

func test1kRoundTrips() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

try await withTestConnection(on: eventLoop) { connection in
for _ in 0..<1_000 {
let rows = try await connection.query("SELECT version()", logger: .psqlTest)
var iterator = rows.makeAsyncIterator()
let firstRow = try await iterator.next()
XCTAssertEqual(try firstRow?.decode(String.self, context: .default).contains("PostgreSQL"), true)
let done = try await iterator.next()
XCTAssertNil(done)
}
}
}

func testSelect10kRows() async throws {
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
let eventLoop = eventLoopGroup.next()

let start = 1
let end = 10000

try await withTestConnection(on: eventLoop) { connection in
let rows = try await connection.query("SELECT generate_series(\(start), \(end));", logger: .psqlTest)
var counter = 1
for try await row in rows {
XCTAssertEqual(try row.decode(Int.self, context: .default), counter)
counter += 1
}

XCTAssertEqual(counter, end + 1)
}
}
}

extension XCTestCase {

func withTestConnection<Result>(
on eventLoop: EventLoop,
file: StaticString = #file,
line: UInt = #line,
_ closure: (PostgresConnection) async throws -> Result
) async throws -> Result {
let connection = try await PostgresConnection.test(on: eventLoop).get()

do {
let result = try await closure(connection)
try await connection.close()
return result
} catch {
XCTFail("Unexpected error: \(error)", file: file, line: line)
try await connection.close()
throw error
}
}
}
#endif