Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@ let package = Package(
),
.testTarget(
name: "IntegrationTests",
dependencies: ["Valkey"]
dependencies: [
"Valkey",
]
),
.testTarget(
name: "ValkeyTests",
dependencies: [
"Valkey",
.product(name: "NIOTestUtils", package: "swift-nio"),
.product(name: "Logging", package: "swift-log"),
.product(name: "NIOEmbedded", package: "swift-nio"),
]
),
]
Expand Down
40 changes: 31 additions & 9 deletions Sources/Valkey/Connection/ValkeyConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,7 @@ public final class ValkeyConnection: Sendable {

let connect = bootstrap.channelInitializer { channel in
do {
let sync = channel.pipeline.syncOperations
if case .enable(let sslContext, let tlsServerName) = configuration.tls.base {
try sync.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName))
}
let valkeyChannelHandler = ValkeyChannelHandler(
eventLoop: channel.eventLoop,
logger: logger
)
try sync.addHandler(valkeyChannelHandler)
try self._setupChannel(channel, configuration: configuration, logger: logger)
return eventLoop.makeSucceededVoidFuture()
} catch {
return eventLoop.makeFailedFuture(error)
Expand All @@ -226,6 +218,36 @@ public final class ValkeyConnection: Sendable {
}
}

package static func setupChannel(_ channel: any Channel, configuration: ValkeyClientConfiguration, logger: Logger) async throws -> ValkeyConnection {
if !channel.eventLoop.inEventLoop {
return try await channel.eventLoop.submit {
let handler = try self._setupChannel(channel, configuration: configuration, logger: logger)
return ValkeyConnection(channel: channel, channelHandler: handler, configuration: configuration, logger: logger)
}.get()
}

let handler = try self._setupChannel(channel, configuration: configuration, logger: logger)
return ValkeyConnection(channel: channel, channelHandler: handler, configuration: configuration, logger: logger)
}

@discardableResult
private static func _setupChannel(_ channel: any Channel, configuration: ValkeyClientConfiguration, logger: Logger) throws -> ValkeyChannelHandler {
channel.eventLoop.assertInEventLoop()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is the same as in _makeClient() which sets up the bootstrap. Maybe use the same function in _makeClient.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

let sync = channel.pipeline.syncOperations
switch configuration.tls.base {
case .enable(let sslContext, let tlsServerName):
try sync.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName))
case .disable:
break
}
let valkeyChannelHandler = ValkeyChannelHandler(
eventLoop: channel.eventLoop,
logger: logger
)
try sync.addHandler(valkeyChannelHandler)
return valkeyChannelHandler
}

/// create a BSD sockets based bootstrap
private static func createSocketsBootstrap(eventLoopGroup: EventLoopGroup) -> ClientBootstrap {
ClientBootstrap(group: eventLoopGroup)
Expand Down
38 changes: 38 additions & 0 deletions Tests/ValkeyTests/ValkeyConnectionTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the swift-valkey open source project
//
// Copyright (c) 2025 Apple Inc. and the swift-valkey project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of swift-valkey project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import Logging
import NIOCore
import NIOEmbedded
import Testing
import Valkey

@Suite
struct ConnectionTests {

@Test
func testConnectionCreationAndGET() async throws {
let channel = NIOAsyncTestingChannel()
let logger = Logger(label: "test")
let connection = try await ValkeyConnection.setupChannel(channel, configuration: .init(), logger: logger)

async let fooResult = connection.get(key: "foo")

let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
#expect(outbound == ByteBuffer(string: "*2\r\n$3\r\nGET\r\n$3\r\nfoo\r\n"))

try await channel.writeInbound(ByteBuffer(string: "$3\r\nBar\r\n"))
#expect(try await fooResult == "Bar")
}
}