diff --git a/Package.swift b/Package.swift index 575d3cdc..1c8d3c07 100644 --- a/Package.swift +++ b/Package.swift @@ -34,13 +34,17 @@ let package = Package( ), .testTarget( name: "IntegrationTests", - dependencies: ["Valkey"] + dependencies: [ + "Valkey", + ] ), .testTarget( name: "ValkeyTests", dependencies: [ "Valkey", .product(name: "NIOTestUtils", package: "swift-nio"), + .product(name: "Logging", package: "swift-log"), + .product(name: "NIOEmbedded", package: "swift-nio"), ] ), ] diff --git a/Sources/Valkey/Connection/ValkeyConnection.swift b/Sources/Valkey/Connection/ValkeyConnection.swift index fa0cd06b..658cbc37 100644 --- a/Sources/Valkey/Connection/ValkeyConnection.swift +++ b/Sources/Valkey/Connection/ValkeyConnection.swift @@ -191,15 +191,7 @@ public final class ValkeyConnection: Sendable { let connect = bootstrap.channelInitializer { channel in do { - let sync = channel.pipeline.syncOperations - if case .enable(let sslContext, let tlsServerName) = configuration.tls.base { - try sync.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName)) - } - let valkeyChannelHandler = ValkeyChannelHandler( - eventLoop: channel.eventLoop, - logger: logger - ) - try sync.addHandler(valkeyChannelHandler) + try self._setupChannel(channel, configuration: configuration, logger: logger) return eventLoop.makeSucceededVoidFuture() } catch { return eventLoop.makeFailedFuture(error) @@ -226,6 +218,36 @@ public final class ValkeyConnection: Sendable { } } + package static func setupChannel(_ channel: any Channel, configuration: ValkeyClientConfiguration, logger: Logger) async throws -> ValkeyConnection { + if !channel.eventLoop.inEventLoop { + return try await channel.eventLoop.submit { + let handler = try self._setupChannel(channel, configuration: configuration, logger: logger) + return ValkeyConnection(channel: channel, channelHandler: handler, configuration: configuration, logger: logger) + }.get() + } + + let handler = try self._setupChannel(channel, configuration: configuration, logger: logger) + return ValkeyConnection(channel: channel, channelHandler: handler, configuration: configuration, logger: logger) + } + + @discardableResult + private static func _setupChannel(_ channel: any Channel, configuration: ValkeyClientConfiguration, logger: Logger) throws -> ValkeyChannelHandler { + channel.eventLoop.assertInEventLoop() + let sync = channel.pipeline.syncOperations + switch configuration.tls.base { + case .enable(let sslContext, let tlsServerName): + try sync.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName)) + case .disable: + break + } + let valkeyChannelHandler = ValkeyChannelHandler( + eventLoop: channel.eventLoop, + logger: logger + ) + try sync.addHandler(valkeyChannelHandler) + return valkeyChannelHandler + } + /// create a BSD sockets based bootstrap private static func createSocketsBootstrap(eventLoopGroup: EventLoopGroup) -> ClientBootstrap { ClientBootstrap(group: eventLoopGroup) diff --git a/Tests/ValkeyTests/ValkeyConnectionTests.swift b/Tests/ValkeyTests/ValkeyConnectionTests.swift new file mode 100644 index 00000000..64b10e73 --- /dev/null +++ b/Tests/ValkeyTests/ValkeyConnectionTests.swift @@ -0,0 +1,38 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the swift-valkey open source project +// +// Copyright (c) 2025 Apple Inc. and the swift-valkey project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of swift-valkey project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIOCore +import NIOEmbedded +import Testing +import Valkey + +@Suite +struct ConnectionTests { + + @Test + func testConnectionCreationAndGET() async throws { + let channel = NIOAsyncTestingChannel() + let logger = Logger(label: "test") + let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger) + + async let fooResult = connection.get(key: "foo") + + let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self) + #expect(outbound == ByteBuffer(string: "*2\r\n$3\r\nGET\r\n$3\r\nfoo\r\n")) + + try await channel.writeInbound(ByteBuffer(string: "$3\r\nBar\r\n")) + #expect(try await fooResult == "Bar") + } +}