From 596857e2e49e8ebf95a8a9bdd03295f96d05ebb3 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 5 Sep 2025 12:23:18 +0100 Subject: [PATCH] Add ValkeyConnection.triggerGracefulShutdown Signed-off-by: Adam Fowler --- .../ValkeyChannelHandler+stateMachine.swift | 9 ++++---- .../Connection/ValkeyChannelHandler.swift | 9 ++++++++ .../Valkey/Connection/ValkeyConnection.swift | 8 +++++++ ...alkeyChannelHandlerStateMachineTests.swift | 20 ++++++++--------- Tests/ValkeyTests/ValkeyConnectionTests.swift | 22 +++++++++++++++++++ 5 files changed, 53 insertions(+), 15 deletions(-) diff --git a/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift b/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift index 7be94dba..2d46c02c 100644 --- a/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift +++ b/Sources/Valkey/Connection/ValkeyChannelHandler+stateMachine.swift @@ -330,14 +330,13 @@ extension ValkeyChannelHandler { } @usableFromInline - enum GracefulShutdownAction { - case waitForPendingCommands(Context) + enum TriggerGracefulShutdownAction { case closeConnection(Context) case doNothing } /// Want to gracefully shutdown the handler @usableFromInline - mutating func gracefulShutdown() -> GracefulShutdownAction { + mutating func triggerGracefulShutdown() -> TriggerGracefulShutdownAction { switch consume self.state { case .initialized: self = .closed(nil) @@ -346,11 +345,11 @@ extension ValkeyChannelHandler { var pendingCommands = state.pendingCommands pendingCommands.prepend(state.pendingHelloCommand) self = .closing(.init(context: state.context, pendingCommands: pendingCommands)) - return .waitForPendingCommands(state.context) + return .doNothing case .active(let state): if state.pendingCommands.count > 0 { self = .closing(.init(context: state.context, pendingCommands: state.pendingCommands)) - return .waitForPendingCommands(state.context) + return .doNothing } else { self = .closed(nil) return .closeConnection(state.context) diff --git a/Sources/Valkey/Connection/ValkeyChannelHandler.swift b/Sources/Valkey/Connection/ValkeyChannelHandler.swift index 79b02d38..a6a70aea 100644 --- a/Sources/Valkey/Connection/ValkeyChannelHandler.swift +++ b/Sources/Valkey/Connection/ValkeyChannelHandler.swift @@ -517,6 +517,15 @@ final class ValkeyChannelHandler: ChannelInboundHandler { break } } + + func triggerGracefulShutdown() { + switch self.stateMachine.triggerGracefulShutdown() { + case .closeConnection(let context): + context.close(mode: .all, promise: nil) + case .doNothing: + break + } + } } @available(valkeySwift 1.0, *) diff --git a/Sources/Valkey/Connection/ValkeyConnection.swift b/Sources/Valkey/Connection/ValkeyConnection.swift index 1821d2d9..14c11fcf 100644 --- a/Sources/Valkey/Connection/ValkeyConnection.swift +++ b/Sources/Valkey/Connection/ValkeyConnection.swift @@ -164,6 +164,14 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable { try await self.channelHandler.waitOnActive().get() } + /// Trigger graceful shutdown of connection + /// + /// The connection will wait until all pending commands have been processed before + /// closing the connection. + func triggerGracefulShutdown() { + self.channelHandler.triggerGracefulShutdown() + } + /// Send RESP command to Valkey connection /// - Parameter command: ValkeyCommand structure /// - Returns: The command response as defined in the ValkeyCommand diff --git a/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift b/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift index c0433e78..09a7be39 100644 --- a/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift +++ b/Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift @@ -147,7 +147,7 @@ struct ValkeyChannelHandlerStateMachineTests { var stateMachine = ValkeyChannelHandler.StateMachine() stateMachine.setConnected(context: "testGracefulShutdown") stateMachine.receiveHelloResponse() - switch stateMachine.gracefulShutdown() { + switch stateMachine.triggerGracefulShutdown() { case .closeConnection(let context): #expect(context == "testGracefulShutdown") default: @@ -169,10 +169,10 @@ struct ValkeyChannelHandlerStateMachineTests { case .throwError: Issue.record("Invalid sendCommand action") } - switch stateMachine.gracefulShutdown() { - case .waitForPendingCommands(let context): - #expect(context == "testGracefulShutdown") - default: + switch stateMachine.triggerGracefulShutdown() { + case .doNothing: + break + case .closeConnection: Issue.record("Invalid waitForPendingCommands action") } expect( @@ -208,10 +208,10 @@ struct ValkeyChannelHandlerStateMachineTests { case .throwError: Issue.record("Invalid sendCommand action") } - switch stateMachine.gracefulShutdown() { - case .waitForPendingCommands(let context): - #expect(context == "testClosedClosingState") - default: + switch stateMachine.triggerGracefulShutdown() { + case .doNothing: + break + case .closeConnection: Issue.record("Invalid waitForPendingCommands action") } expect( @@ -334,7 +334,7 @@ struct ValkeyChannelHandlerStateMachineTests { case .throwError: Issue.record("Invalid sendCommand action") } - _ = stateMachine.gracefulShutdown() + _ = stateMachine.triggerGracefulShutdown() switch stateMachine.cancel(requestID: 23) { case .failPendingCommandsAndClose(let context, let cancel, let closeConnectionDueToCancel): #expect(context == "testCancelGracefulShutdown") diff --git a/Tests/ValkeyTests/ValkeyConnectionTests.swift b/Tests/ValkeyTests/ValkeyConnectionTests.swift index 0beb9fcb..711b45a1 100644 --- a/Tests/ValkeyTests/ValkeyConnectionTests.swift +++ b/Tests/ValkeyTests/ValkeyConnectionTests.swift @@ -587,6 +587,28 @@ struct ConnectionTests { try await channel.close() } + @Test + @available(valkeySwift 1.0, *) + func testTriggerGracefulShutdown() async throws { + let channel = NIOAsyncTestingChannel() + let logger = Logger(label: "test") + let connection = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger) + try await channel.processHello() + + async let fooResult = connection.get("foo").map { String(buffer: $0) } + + let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self) + #expect(outbound == RESPToken(.command(["GET", "foo"])).base) + + await connection.triggerGracefulShutdown() + #expect(channel.isActive) + + try await channel.writeInbound(RESPToken(.bulkString("Bar")).base) + #expect(try await fooResult == "Bar") + + try await channel.closeFuture.get() + } + #if DistributedTracingSupport @Suite struct DistributedTracingTests {