Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for connecting to a Postgres database via Unix sockets. #70

Merged
merged 5 commits into from
May 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -3,14 +3,24 @@ import NIO
import NIOOpenSSL

extension PostgreSQLConnection {
/// Connects to a Redis server using a TCP socket.
@available(*, deprecated, message: "Use `.connect(to:...)` instead.")
public static func connect(
hostname: String = "localhost",
port: Int = 5432,
transport: PostgreSQLTransportConfig = .cleartext,
on worker: Worker,
onError: @escaping (Error) -> ()
) throws -> Future<PostgreSQLConnection> {
) throws -> Future<PostgreSQLConnection> {
return try connect(to: .tcp(hostname: hostname, port: port), transport: transport, on: worker, onError: onError)
}

/// Connects to a PostgreSQL server using a TCP socket.
public static func connect(
to serverAddress: PostgreSQLDatabaseConfig.ServerAddress = .default,
transport: PostgreSQLTransportConfig = .cleartext,
on worker: Worker,
onError: @escaping (Error) -> ()
) throws -> Future<PostgreSQLConnection> {
let handler = QueueHandler<PostgreSQLMessage, PostgreSQLMessage>(on: worker, onError: onError)
let bootstrap = ClientBootstrap(group: worker.eventLoop)
// Enable SO_REUSEADDR.
Expand All @@ -20,8 +30,16 @@ extension PostgreSQLConnection {
channel.pipeline.add(handler: handler)
}
}

return bootstrap.connect(host: hostname, port: port).flatMap { channel in

let connectedBootstrap: Future<Channel>
switch serverAddress {
case let .tcp(hostname, port):
connectedBootstrap = bootstrap.connect(host: hostname, port: port)
case let .unixSocket(socketPath):
connectedBootstrap = bootstrap.connect(unixDomainSocketPath: socketPath)
}

return connectedBootstrap.flatMap { channel in
let connection = PostgreSQLConnection(queue: handler, channel: channel)
if case .tls(let tlsConfiguration) = transport.method {
return connection.addSSLClientHandler(using: tlsConfiguration).transform(to: connection)
Expand Down
4 changes: 2 additions & 2 deletions Sources/PostgreSQL/Connection/PostgreSQLConnection.swift
Expand Up @@ -9,7 +9,7 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
return channel.eventLoop
}

/// Handles enqueued redis commands and responses.
/// Handles enqueued PostgreSQL commands and responses.
internal let queue: QueueHandler<PostgreSQLMessage, PostgreSQLMessage>

/// The channel
Expand Down Expand Up @@ -39,7 +39,7 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
/// Handlers to be stored by channel name
internal var notificationHandlers: [String: NotificationHandler] = [:]

/// Creates a new Redis client on the provided data source and sink.
/// Creates a new PostgreSQL client on the provided data source and sink.
init(queue: QueueHandler<PostgreSQLMessage, PostgreSQLMessage>, channel: Channel) {
self.queue = queue
self.channel = channel
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgreSQL/Database/PostgreSQLDatabase.swift
Expand Up @@ -19,7 +19,7 @@ public final class PostgreSQLDatabase: Database, LogSupporting {
public func newConnection(on worker: Worker) -> Future<PostgreSQLConnection> {
let config = self.config
return Future.flatMap(on: worker) {
return try PostgreSQLConnection.connect(hostname: config.hostname, port: config.port, transport: config.transportConfig, on: worker) { error in
return try PostgreSQLConnection.connect(to: config.serverAddress, transport: config.transportConfig, on: worker) { error in
print("[PostgreSQL] \(error)")
}.flatMap(to: PostgreSQLConnection.self) { client in
return client.authenticate(
Expand Down
42 changes: 26 additions & 16 deletions Sources/PostgreSQL/Database/PostgreSQLDatabaseConfig.swift
Expand Up @@ -7,12 +7,20 @@ public struct PostgreSQLDatabaseConfig {
public static func `default`() -> PostgreSQLDatabaseConfig {
return .init(hostname: "localhost", port: 5432, username: "postgres")
}

/// Destination hostname.
public let hostname: String

/// Destination port.
public let port: Int

/// Specifies how to connect to a PostgreSQL server.
public enum ServerAddress {
/// Connect via TCP using the given hostname and port.
case tcp(hostname: String, port: Int)
/// Connect via a Unix domain socket file.
case unixSocket(path: String)

public static let `default` = ServerAddress.tcp(hostname: "localhost", port: 5432)
public static let socketDefault = ServerAddress.unixSocket(path: "/tmp/.s.PGSQL.5432")
}

/// Which server to connect to.
public let serverAddress: ServerAddress

/// Username to authenticate.
public let username: String
Expand All @@ -29,14 +37,17 @@ public struct PostgreSQLDatabaseConfig {
public let transportConfig: PostgreSQLTransportConfig

/// Creates a new `PostgreSQLDatabaseConfig`.
public init(hostname: String, port: Int = 5432, username: String, database: String? = nil, password: String? = nil, transport: PostgreSQLTransportConfig = .cleartext) {
self.hostname = hostname
self.port = port
self.username = username
self.database = database
self.password = password
self.transportConfig = transport
}
public init(serverAddress: ServerAddress, username: String, database: String? = nil, password: String? = nil, transport: PostgreSQLTransportConfig = .cleartext) {
self.serverAddress = serverAddress
self.username = username
self.database = database
self.password = password
self.transportConfig = transport
}

public init(hostname: String, port: Int = 5432, username: String, database: String? = nil, password: String? = nil, transport: PostgreSQLTransportConfig = .cleartext) {
self.init(serverAddress: .tcp(hostname: hostname, port: port), username: username, database: database, password: password, transport: transport)
}

/// Creates a `PostgreSQLDatabaseConfig` frome a connection string.
public init(url urlString: String, transport: PostgreSQLTransportConfig = .cleartext) throws {
Expand All @@ -54,8 +65,7 @@ public struct PostgreSQLDatabaseConfig {
source: .capture()
)
}
self.hostname = hostname
self.port = port
self.serverAddress = .tcp(hostname: hostname, port: port)
self.username = username
let database = url.path
if database.hasPrefix("/") {
Expand Down
18 changes: 11 additions & 7 deletions Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift
Expand Up @@ -433,8 +433,12 @@ class PostgreSQLConnectionTests: XCTestCase {
func testURLParsing() throws {
let databaseURL = "postgres://username:password@hostname.com:5432/database"
let config = try PostgreSQLDatabaseConfig(url: databaseURL)
XCTAssertEqual(config.hostname, "hostname.com")
XCTAssertEqual(config.port, 5432)
if case let .tcp(hostname, port) = config.serverAddress {
XCTAssertEqual(hostname, "hostname.com")
XCTAssertEqual(port, 5432)
} else {
XCTFail("unexpected server address \(config.serverAddress)")
}
XCTAssertEqual(config.username, "username")
XCTAssertEqual(config.password, "password")
XCTAssertEqual(config.database, "database")
Expand Down Expand Up @@ -489,18 +493,18 @@ extension PostgreSQLConnection {
/// Creates a test event loop and psql client over ssl.
static func makeTest(transport: PostgreSQLTransportConfig) throws -> PostgreSQLConnection {
#if os(macOS)
return try _makeTest(hostname: "192.168.99.100", password: "vapor_password", port: transport.isTLS ? 5433 : 5432, transport: transport)
return try _makeTest(serverAddress: .tcp(hostname: "192.168.99.100", port: transport.isTLS ? 5433 : 5432), password: "vapor_password", transport: transport)
#else
return try _makeTest(hostname: transport.isTLS ? "tls" : "cleartext", password: "vapor_password", transport: transport)
return try _makeTest(serverAddress: .tcp(hostname: transport.isTLS ? "tls" : "cleartext", port: 5432), password: "vapor_password", transport: transport)
#endif
}

/// Creates a test connection.
private static func _makeTest(hostname: String, password: String? = nil, port: Int = 5432, transport: PostgreSQLTransportConfig = .cleartext) throws -> PostgreSQLConnection {
private static func _makeTest(serverAddress: PostgreSQLDatabaseConfig.ServerAddress, password: String? = nil, transport: PostgreSQLTransportConfig = .cleartext) throws -> PostgreSQLConnection {
let group = MultiThreadedEventLoopGroup(numThreads: 1)
let client = try PostgreSQLConnection.connect(hostname: hostname, port: port, transport: transport, on: group) { error in
let client = try PostgreSQLConnection.connect(to: serverAddress, transport: transport, on: group) { error in
XCTFail("\(error)")
}.wait()
}.wait()
_ = try client.authenticate(username: "vapor_username", database: "vapor_database", password: password).wait()
return client
}
Expand Down