diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index d0f8e2b0..b4658079 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -1,27 +1,29 @@ import NIOCore import NIOPosix import NIOEmbedded -import XCTest +import Testing import Logging @testable import PostgresNIO -class PostgresConnectionTests: XCTestCase { +@Suite struct PostgresConnectionTests { let logger = Logger(label: "PostgresConnectionTests") - func testConnectionFailure() { + @Test func testConnectionFailure() { // We start a local server and close it immediately to ensure that the port // number we try to connect to is not used by any other process. - let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } - + let eventLoopGroup = NIOSingletons.posixEventLoopGroup + var tempChannel: Channel? - XCTAssertNoThrow(tempChannel = try ServerBootstrap(group: eventLoopGroup) - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait()) + #expect(throws: Never.self) { + tempChannel = try ServerBootstrap(group: eventLoopGroup) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)).wait() + } let maybePort = tempChannel?.localAddress?.port - XCTAssertNoThrow(try tempChannel?.close().wait()) + #expect(throws: Never.self) { try tempChannel?.close().wait() } guard let port = maybePort else { - return XCTFail("Could not get port number from temp started server") + Issue.record("Could not get port number from temp started server") + return } let config = PostgresConnection.Configuration( @@ -33,12 +35,14 @@ class PostgresConnectionTests: XCTestCase { var logger = Logger.psqlTest logger.logLevel = .trace - XCTAssertThrowsError(try PostgresConnection.connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger).wait()) { - XCTAssertTrue($0 is PSQLError) + #expect(throws: PSQLError.self) { + try PostgresConnection + .connect(on: eventLoopGroup.next(), configuration: config, id: 1, logger: logger) + .wait() } } - func testOptionsAreSentOnTheWire() async throws { + @Test func testOptionsAreSentOnTheWire() async throws { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) @@ -71,7 +75,7 @@ class PostgresConnectionTests: XCTestCase { async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + #expect(message == .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) @@ -80,269 +84,275 @@ class PostgresConnectionTests: XCTestCase { try await connection.close() } - func testSimpleListen() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testSimpleListen() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + #expect(event.payload == "wooohooo") + break + } + } - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - for try await event in events { - XCTAssertEqual(event.payload, "wooohooo") + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + + let unlistenMessage = try await channel.waitForUnpreparedRequest() + #expect(unlistenMessage.parse.query == #"UNLISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + switch await taskGroup.nextResult()! { + case .success: break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") } } - - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - - let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } } } - func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testSimpleListenDoesNotUnlistenIfThereIsAnotherSubscriber() async throws { + try await self.withAsyncTestingChannel { connection, channel in - try await withThrowingTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - for try await event in events { - XCTAssertEqual(event.payload, "wooohooo") - break + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + let events = try await connection.listen("foo") + for try await event in events { + #expect(event.payload == "wooohooo") + break + } } - } - taskGroup.addTask { - let events = try await connection.listen("foo") - var counter = 0 - loop: for try await event in events { - defer { counter += 1 } - switch counter { - case 0: - XCTAssertEqual(event.payload, "wooohooo") - case 1: - XCTAssertEqual(event.payload, "wooohooo2") - break loop - default: - XCTFail("Unexpected message: \(event)") + taskGroup.addTask { + let events = try await connection.listen("foo") + var counter = 0 + loop: for try await event in events { + defer { counter += 1 } + switch counter { + case 0: + #expect(event.payload == "wooohooo") + case 1: + #expect(event.payload == "wooohooo2") + break loop + default: + Issue.record("Unexpected message: \(event)") + } } } - } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) - - let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") - } - } - } + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) - func testSimpleListenConnectionDrops() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in - taskGroup.addTask { - let events = try await connection.listen("foo") - var iterator = events.makeAsyncIterator() - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "wooohooo") - do { - _ = try await iterator.next() - XCTFail("Did not expect to not throw") - } catch { - logger.error("error", metadata: ["error": "\(error)"]) - } - } + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) - - try await channel.writeInbound(PostgresBackendMessage.parseComplete) - try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - try await channel.writeInbound(PostgresBackendMessage.noData) - try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - - try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) - struct MyWeirdError: Error {} - channel.pipeline.fireErrorCaught(MyWeirdError()) - - switch await taskGroup.nextResult()! { - case .success: - break - case .failure(let failure): - XCTFail("Unexpected error: \(failure)") + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) + + let unlistenMessage = try await channel.waitForUnpreparedRequest() + #expect(unlistenMessage.parse.query == #"UNLISTEN "foo";"#) + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + try await channel.writeInbound(PostgresBackendMessage.noData) + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.writeInbound(PostgresBackendMessage.commandComplete("UNLISTEN")) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + switch await taskGroup.nextResult()! { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") + } } } } - func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in - for _ in 1...2 { + @Test func testSimpleListenConnectionDrops() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in taskGroup.addTask { - let rows = try await connection.query("SELECT 1;", logger: logger) - var iterator = rows.decode(Int.self).makeAsyncIterator() + let events = try await connection.listen("foo") + var iterator = events.makeAsyncIterator() let first = try await iterator.next() - XCTAssertEqual(first, 1) - let second = try await iterator.next() - XCTAssertNil(second) + #expect(first?.payload == "wooohooo") + do { + _ = try await iterator.next() + Issue.record("Did not expect to not throw") + } catch { + logger.error("error", metadata: ["error": "\(error)"]) + } } - } - for i in 0...1 { let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") - - if i == 0 { - taskGroup.addTask { - try await connection.closeGracefully() - } - } + #expect(listenMessage.parse.query == #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) - let intDescription = RowDescription.Column( - name: "", - tableOID: 0, - columnAttributeNumber: 0, - dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary - ) - try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) - try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.noData) try await channel.writeInbound(PostgresBackendMessage.bindComplete) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) - try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("LISTEN")) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - } - let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(terminate, .terminate) - try await channel.closeFuture.get() - XCTAssertEqual(channel.isActive, false) + try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) + struct MyWeirdError: Error {} + channel.pipeline.fireErrorCaught(MyWeirdError()) - while let taskResult = await taskGroup.nextResult() { - switch taskResult { + switch await taskGroup.nextResult()! { case .success: break case .failure(let failure): - XCTFail("Unexpected error: \(failure)") + Issue.record("Unexpected error: \(failure)") } } } } - func testCloseClosesImmediatly() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + let rows = try await connection.query("SELECT 1;", logger: logger) + var iterator = rows.decode(Int.self).makeAsyncIterator() + let first = try await iterator.next() + #expect(first == 1) + let second = try await iterator.next() + #expect(second == nil) + } + } + + for i in 0...1 { + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in - for _ in 1...2 { - taskGroup.addTask { - try await connection.query("SELECT 1;", logger: logger) + if i == 0 { + taskGroup.addTask { + try await connection.closeGracefully() + } + } + + try await channel.writeInbound(PostgresBackendMessage.parseComplete) + try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) + let intDescription = RowDescription.Column( + name: "", + tableOID: 0, + columnAttributeNumber: 0, + dataType: .int8, dataTypeSize: 8, dataTypeModifier: 0, format: .binary + ) + try await channel.writeInbound(PostgresBackendMessage.rowDescription(.init(columns: [intDescription]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.bindComplete) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.dataRow([Int(1)])) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.commandComplete("SELECT 1 1")) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + } + + let terminate = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + #expect(terminate == .terminate) + try await channel.closeFuture.get() + #expect(!channel.isActive) + + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + break + case .failure(let failure): + Issue.record("Unexpected error: \(failure)") + } } } + } + } - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + @Test func testCloseClosesImmediatly() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in + for _ in 1...2 { + taskGroup.addTask { + try await connection.query("SELECT 1;", logger: logger) + } + } + + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - async let close: () = connection.close() + async let close: () = connection.close() - try await channel.closeFuture.get() - XCTAssertEqual(channel.isActive, false) + try await channel.closeFuture.get() + #expect(!channel.isActive) - try await close + try await close - while let taskResult = await taskGroup.nextResult() { - switch taskResult { - case .success: - XCTFail("Expected queries to fail") - case .failure(let failure): - guard let error = failure as? PSQLError else { - return XCTFail("Unexpected error type: \(failure)") + while let taskResult = await taskGroup.nextResult() { + switch taskResult { + case .success: + Issue.record("Expected queries to fail") + case .failure(let failure): + guard let error = failure as? PSQLError else { + Issue.record("Unexpected error type: \(failure)") + return + } + #expect(error.code == .clientClosedConnection) } - XCTAssertEqual(error.code, .clientClosedConnection) } } } } - func testIfServerJustClosesTheErrorReflectsThat() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - let logger = self.logger + @Test func testIfServerJustClosesTheErrorReflectsThat() async throws { + try await self.withAsyncTestingChannel { connection, channel in + let logger = self.logger - async let response = try await connection.query("SELECT 1;", logger: logger) + async let response = try await connection.query("SELECT 1;", logger: logger) - let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") + let listenMessage = try await channel.waitForUnpreparedRequest() + #expect(listenMessage.parse.query == "SELECT 1;") - try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } - try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelInactive() } + try await channel.testingEventLoop.executeInContext { channel.pipeline.fireChannelUnregistered() } - do { - _ = try await response - XCTFail("Expected to throw") - } catch { - XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) - } + do { + _ = try await response + Issue.record("Expected to throw") + } catch { + #expect((error as? PSQLError)?.code == .serverClosedConnection) + } - // retry on same connection + // retry on same connection - do { - _ = try await connection.query("SELECT 1;", logger: self.logger) - XCTFail("Expected to throw") - } catch { - XCTAssertEqual((error as? PSQLError)?.code, .serverClosedConnection) + do { + _ = try await connection.query("SELECT 1;", logger: self.logger) + Issue.record("Expected to throw") + } catch { + #expect((error as? PSQLError)?.code == .serverClosedConnection) + } } } @@ -363,282 +373,287 @@ class PostgresConnectionTests: XCTestCase { } } - func testPreparedStatement() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + @Test func testPreparedStatement() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - let preparedRequest = try await channel.waitForPreparedRequest() - XCTAssertEqual(preparedRequest.bind.preparedStatementName, String(reflecting: TestPrepareStatement.self)) - XCTAssertEqual(preparedRequest.bind.parameters.count, 1) - XCTAssertEqual(preparedRequest.bind.resultColumnFormats, [.binary]) + let preparedRequest = try await channel.waitForPreparedRequest() + #expect(preparedRequest.bind.preparedStatementName == String(reflecting: TestPrepareStatement.self)) + #expect(preparedRequest.bind.parameters.count == 1) + #expect(preparedRequest.bind.resultColumnFormats == [.binary]) - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + } } } - func testWeDontCrashOnUnexpectedChannelEvents() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + @Test func testWeDontCrashOnUnexpectedChannelEvents() async throws { + try await self.withAsyncTestingChannel { connection, channel in - enum MyEvent { - case pleaseDontCrash + enum MyEvent { + case pleaseDontCrash + } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() } - channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) - try await connection.close() } - func testSerialExecutionOfSamePreparedStatement() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Send one first and wait to send the other request until preparation is complete - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + @Test func testSerialExecutionOfSamePreparedStatement() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - let preparedRequest1 = try await channel.waitForPreparedRequest() - var buffer = preparedRequest1.bind.parameters[0]! - let parameter1 = buffer.readString(length: buffer.readableBytes)! - XCTAssertEqual(parameter1, "active") - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + #expect(parameter1 == "active") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) - // Now that the statement has been prepared and executed, send another request that will only get executed - // without preparation - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database", database) + // Now that the statement has been prepared and executed, send another request that will only get executed + // without preparation + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - let preparedRequest2 = try await channel.waitForPreparedRequest() - buffer = preparedRequest2.bind.parameters[0]! - let parameter2 = buffer.readString(length: buffer.readableBytes)! - XCTAssertEqual(parameter2, "idle") - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database"] - ], - commandTag: TestPrepareStatement.sql - ) - // Ensure we received and responded to both the requests - let parameters = [parameter1, parameter2] - XCTAssert(parameters.contains("active")) - XCTAssert(parameters.contains("idle")) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + #expect(parameter2 == "idle") + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + #expect(parameters.contains("active")) + #expect(parameters.contains("idle")) + } } } - func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Let them race to tests that requests and responses aren't mixed up - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database_active", database) + @Test func testStatementPreparationOnlyHappensOnceWithConcurrentRequests() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Let them race to tests that requests and responses aren't mixed up + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database_active" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - let result = try await connection.execute(preparedStatement, logger: .psqlTest) - var rows = 0 - for try await database in result { - rows += 1 - XCTAssertEqual("test_database_idle", database) + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + let result = try await connection.execute(preparedStatement, logger: .psqlTest) + var rows = 0 + for try await database in result { + rows += 1 + #expect("test_database_idle" == database) + } + #expect(rows == 1) } - XCTAssertEqual(rows, 1) - } - // The channel deduplicates prepare requests, we're going to see only one of them - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - try await channel.sendPrepareResponse( - parameterDescription: .init(dataTypes: [ - PostgresDataType.text - ]), - rowDescription: .init(columns: [ - .init( - name: "datname", - tableOID: 12222, - columnAttributeNumber: 2, - dataType: .name, - dataTypeSize: 64, - dataTypeModifier: -1, - format: .text - ) - ]) - ) + // The channel deduplicates prepare requests, we're going to see only one of them + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + try await channel.sendPrepareResponse( + parameterDescription: .init(dataTypes: [ + PostgresDataType.text + ]), + rowDescription: .init(columns: [ + .init( + name: "datname", + tableOID: 12222, + columnAttributeNumber: 2, + dataType: .name, + dataTypeSize: 64, + dataTypeModifier: -1, + format: .text + ) + ]) + ) - // Now both the tasks have their statements prepared. - // We should see both of their execute requests coming in, the order is nondeterministic - let preparedRequest1 = try await channel.waitForPreparedRequest() - var buffer = preparedRequest1.bind.parameters[0]! - let parameter1 = buffer.readString(length: buffer.readableBytes)! - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database_\(parameter1)"] - ], - commandTag: TestPrepareStatement.sql - ) - let preparedRequest2 = try await channel.waitForPreparedRequest() - buffer = preparedRequest2.bind.parameters[0]! - let parameter2 = buffer.readString(length: buffer.readableBytes)! - try await channel.sendPreparedResponse( - dataRows: [ - ["test_database_\(parameter2)"] - ], - commandTag: TestPrepareStatement.sql - ) - // Ensure we received and responded to both the requests - let parameters = [parameter1, parameter2] - XCTAssert(parameters.contains("active")) - XCTAssert(parameters.contains("idle")) + // Now both the tasks have their statements prepared. + // We should see both of their execute requests coming in, the order is nondeterministic + let preparedRequest1 = try await channel.waitForPreparedRequest() + var buffer = preparedRequest1.bind.parameters[0]! + let parameter1 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter1)"] + ], + commandTag: TestPrepareStatement.sql + ) + let preparedRequest2 = try await channel.waitForPreparedRequest() + buffer = preparedRequest2.bind.parameters[0]! + let parameter2 = buffer.readString(length: buffer.readableBytes)! + try await channel.sendPreparedResponse( + dataRows: [ + ["test_database_\(parameter2)"] + ], + commandTag: TestPrepareStatement.sql + ) + // Ensure we received and responded to both the requests + let parameters = [parameter1, parameter2] + #expect(parameters.contains("active")) + #expect(parameters.contains("idle")) + } } } - func testStatementPreparationFailure() async throws { - let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in - // Send the same prepared statement twice, but with different parameters. - // Send one first and wait to send the other request until preparation is complete - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "active") - do { - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Was supposed to fail") - } catch { - XCTAssert(error is PSQLError) + @Test func testStatementPreparationFailure() async throws { + try await self.withAsyncTestingChannel { connection, channel in + + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + // Send the same prepared statement twice, but with different parameters. + // Send one first and wait to send the other request until preparation is complete + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "active") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + Issue.record("Was supposed to fail") + } catch { + #expect(error is PSQLError) + } } - } - let prepareRequest = try await channel.waitForPrepareRequest() - XCTAssertEqual(prepareRequest.parse.query, "SELECT datname FROM pg_stat_activity WHERE state = $1") - XCTAssertEqual(prepareRequest.parse.parameters.first, .text) - guard case .preparedStatement(let name) = prepareRequest.describe else { - fatalError("Describe should contain a prepared statement") - } - XCTAssertEqual(name, String(reflecting: TestPrepareStatement.self)) - - // Respond with an error taking care to return a SQLSTATE that isn't - // going to get the connection closed. - try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ - .sqlState : "26000" // invalid_sql_statement_name - ]))) - try await channel.testingEventLoop.executeInContext { channel.read() } - try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) - try await channel.testingEventLoop.executeInContext { channel.read() } - - - // Send another requests with the same prepared statement, which should fail straight - // away without any interaction with the server - taskGroup.addTask { - let preparedStatement = TestPrepareStatement(state: "idle") - do { - _ = try await connection.execute(preparedStatement, logger: .psqlTest) - XCTFail("Was supposed to fail") - } catch { - XCTAssert(error is PSQLError) + let prepareRequest = try await channel.waitForPrepareRequest() + #expect(prepareRequest.parse.query == "SELECT datname FROM pg_stat_activity WHERE state = $1") + #expect(prepareRequest.parse.parameters.first == .text) + guard case .preparedStatement(let name) = prepareRequest.describe else { + fatalError("Describe should contain a prepared statement") + } + #expect(name == String(reflecting: TestPrepareStatement.self)) + + // Respond with an error taking care to return a SQLSTATE that isn't + // going to get the connection closed. + try await channel.writeInbound(PostgresBackendMessage.error(.init(fields: [ + .sqlState : "26000" // invalid_sql_statement_name + ]))) + try await channel.testingEventLoop.executeInContext { channel.read() } + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + try await channel.testingEventLoop.executeInContext { channel.read() } + + + // Send another requests with the same prepared statement, which should fail straight + // away without any interaction with the server + taskGroup.addTask { + let preparedStatement = TestPrepareStatement(state: "idle") + do { + _ = try await connection.execute(preparedStatement, logger: .psqlTest) + Issue.record("Was supposed to fail") + } catch { + #expect(error is PSQLError) + } } } } } - func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { + func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) @@ -656,18 +671,20 @@ class PostgresConnectionTests: XCTestCase { let logger = self.logger async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) - XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) + #expect(message == .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) let connection = try await connectionPromise - self.addTeardownBlock { - try await connection.close() + do { + try await body(connection, channel) + } catch { + } - return (connection, channel) + try await connection.close() } }