diff --git a/Sources/PostgreSQL/Connection/PostgreSQLConnection+NotifyAndListen.swift b/Sources/PostgreSQL/Connection/PostgreSQLConnection+NotifyAndListen.swift index c1c650d7..06ab84f0 100644 --- a/Sources/PostgreSQL/Connection/PostgreSQLConnection+NotifyAndListen.swift +++ b/Sources/PostgreSQL/Connection/PostgreSQLConnection+NotifyAndListen.swift @@ -1,19 +1,26 @@ import Async + extension PostgreSQLConnection { + /// Note: after calling `listen'` on a connection, it can no longer handle other database operations. Do not try to send other SQL commands through this connection afterwards. + /// IAlso, notifications will only be sent for as long as this connection remains open; you are responsible for opening a new connection to listen on when this one closes. public func listen( _ channelName: String, handler: @escaping (String) throws -> () ) throws -> Future { - beforeClose = { conn in + closeHandlers.append({ conn in let query = PostgreSQLQuery(query: "UNLISTEN \"\(channelName)\";") return conn.send([.query(query)], onResponse: { _ in }) + }) + + notificationHandlers[channelName] = { message in + try handler(message) } let query = PostgreSQLQuery(query: "LISTEN \"\(channelName)\";") return queue.enqueue([.query(query)], onInput: { message in switch message { case let .notificationResponse(notification): - try handler(notification.message) + try self.notificationHandlers[notification.channel]?(notification.message) default: break } @@ -26,4 +33,10 @@ extension PostgreSQLConnection { let query = PostgreSQLQuery(query: "NOTIFY \"\(channelName)\", '\(message)';") return send([.query(query)]).map(to: Void.self, { _ in }) } + + public func unlisten(_ channelName: String, unlistenHandler: (() -> Void)? = nil) throws -> Future { + notificationHandlers.removeValue(forKey: channelName) + let query = PostgreSQLQuery(query: "UNLISTEN \"\(channelName)\";") + return send([.query(query)], onResponse: { _ in unlistenHandler?() }) + } } diff --git a/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift b/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift index f70b7f2d..6bb3a27e 100644 --- a/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift +++ b/Sources/PostgreSQL/Connection/PostgreSQLConnection.swift @@ -44,6 +44,15 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker { /// The current query running, if one exists. private var pipeline: Future + /// Block type to be called on close of connection + internal typealias CloseHandler = ((PostgreSQLConnection) -> Future) + /// Called on close of the connection + internal var closeHandlers = [CloseHandler]() + /// Handler type for Notifications + internal typealias NotificationHandler = (String) throws -> Void + /// Handlers to be stored by channel name + internal var notificationHandlers: [String: NotificationHandler] = [:] + /// Creates a new Redis client on the provided data source and sink. init(queue: QueueHandler, channel: Channel) { self.queue = queue @@ -184,19 +193,24 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker { } } - internal var beforeClose: ((PostgreSQLConnection) -> Future)? /// Closes this client. public func close() { - if let beforeClose = beforeClose { - _ = beforeClose(self).then { _ in - self.channel.close(mode: CloseMode.all) + _ = executeCloseHandlersThenClose() + } + + + private func executeCloseHandlersThenClose() -> Future { + if let beforeClose = closeHandlers.popLast() { + return beforeClose(self).then { _ in + self.executeCloseHandlersThenClose() } } else { - channel.close(promise: nil) + return channel.close(mode: .all) } } + /// Called when this class deinitializes. deinit { close() diff --git a/Sources/PostgreSQL/Message/PostgreSQLNotificationResponse.swift b/Sources/PostgreSQL/Message/PostgreSQLNotificationResponse.swift index c34c6a6a..cc454c27 100644 --- a/Sources/PostgreSQL/Message/PostgreSQLNotificationResponse.swift +++ b/Sources/PostgreSQL/Message/PostgreSQLNotificationResponse.swift @@ -2,13 +2,12 @@ import Foundation struct PostgreSQLNotificationResponse: Decodable { /// The message coming from PSQL + let channel: String let message: String init(from decoder: Decoder) throws { let container = try decoder.singleValueContainer() - _ = try container.decode(Int32.self) // message length - _ = try container.decode(Int32.self) // process id of message - let channelId = try container.decode(String.self) - let message = try? container.decode(String.self) - self.message = message ?? channelId + _ = try container.decode(Int32.self) + channel = try container.decode(String.self) + message = try container.decode(String.self) } } diff --git a/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift b/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift index e8cac7ec..b51c5d51 100644 --- a/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift +++ b/Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift @@ -348,7 +348,7 @@ class PostgreSQLConnectionTests: XCTestCase { let completionHandlerExpectation2 = expectation(description: "final completion handler called") let notifyConn = try PostgreSQLConnection.makeTest() let listenConn = try PostgreSQLConnection.makeTest() - let channelName = "Foo" + let channelName = "Fooze" let messageText = "Bar" let finalMessageText = "Baz" @@ -363,11 +363,65 @@ class PostgreSQLConnectionTests: XCTestCase { try notifyConn.notify(channelName, message: messageText).wait() try notifyConn.notify(channelName, message: finalMessageText).wait() + waitForExpectations(timeout: defaultTimeout) notifyConn.close() listenConn.close() + } + + func testNotifyAndListenOnMultipleChannels() throws { + let completionHandlerExpectation1 = expectation(description: "first completion handler called") + let completionHandlerExpectation2 = expectation(description: "final completion handler called") + let notifyConn = try PostgreSQLConnection.makeTest() + let listenConn = try PostgreSQLConnection.makeTest() + let channelName = "Fooze" + let channelName2 = "Foozalz" + let messageText = "Bar" + let finalMessageText = "Baz" + + try listenConn.listen(channelName) { text in + if text == messageText { + completionHandlerExpectation1.fulfill() + } + }.catch({ err in XCTFail("error \(err)") }) + + try listenConn.listen(channelName2) { text in + if text == finalMessageText { + completionHandlerExpectation2.fulfill() + } + }.catch({ err in XCTFail("error \(err)") }) + + try notifyConn.notify(channelName, message: messageText).wait() + try notifyConn.notify(channelName2, message: finalMessageText).wait() + waitForExpectations(timeout: defaultTimeout) + notifyConn.close() + listenConn.close() } + func testUnlisten() throws { + let unlistenHandlerExpectation = expectation(description: "unlisten completion handler called") + + let listenHandlerExpectation = expectation(description: "listen completion handler called") + + let notifyConn = try PostgreSQLConnection.makeTest() + let listenConn = try PostgreSQLConnection.makeTest() + let channelName = "Foozers" + let messageText = "Bar" + + try listenConn.listen(channelName) { text in + if text == messageText { + listenHandlerExpectation.fulfill() + } + }.catch({ err in XCTFail("error \(err)") }) + + try notifyConn.notify(channelName, message: messageText).wait() + try notifyConn.unlisten(channelName, unlistenHandler: { + unlistenHandlerExpectation.fulfill() + }).wait() + waitForExpectations(timeout: defaultTimeout) + notifyConn.close() + listenConn.close() + } func testURLParsing() throws { let databaseURL = "postgres://username:password@hostname.com:5432/database" @@ -389,6 +443,8 @@ class PostgreSQLConnectionTests: XCTestCase { ("testNull", testNull), ("testGH24", testGH24), ("testNotifyAndListen", testNotifyAndListen), + ("testNotifyAndListenOnMultipleChannels", testNotifyAndListenOnMultipleChannels), + ("testUnlisten", testUnlisten), ("testURLParsing", testURLParsing), ] } @@ -398,8 +454,7 @@ extension PostgreSQLConnection { static func makeTest() throws -> PostgreSQLConnection { let hostname: String #if Xcode - //hostname = (try? Process.execute("docker-machine", "ip")) ?? "192.168.99.100" - hostname = "localhost" + hostname = (try? Process.execute("docker-machine", "ip")) ?? "192.168.99.100" #else hostname = "localhost" #endif