Skip to content

Commit

Permalink
State machine (#135)
Browse files Browse the repository at this point in the history
* Adds PSQLFrontendMessage & PSQLBackendMessage

* State machine

* Removed Xcode headers.

* Apply suggestions from code review

Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>

* Code review

* Apply suggestions from code review

Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>

* Code review

* Apply suggestions from code review

Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>

* Code review

* Add rudementary sasl support

* Update Sources/PostgresNIO/Connection/PostgresConnection+Notifications.swift

Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>

* Code review

* Code review

* A little more error handling

* Error handling

* Better logging

* Fixes!

* Some better state handling when closing

* State machine tests

* Better cleanup in error states

* Cherry pick to be reverted.

* PreparedStatementStateMachine tests

* Code review

* Enable trace logging to better find the flaky tests

* PSQLChannelHandler logging + cleanup

* PR review

* Code review

* Update Sources/PostgresNIO/New/Connection State Machine/AuthenticationStateMachine.swift

Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>

* Last code comment

* Last code comment fix

Co-authored-by: Gwynne Raskind <gwynne@darkrainfall.org>
  • Loading branch information
fabianfett and gwynne committed Feb 24, 2021
1 parent 5876fdf commit cdb18d1
Show file tree
Hide file tree
Showing 125 changed files with 10,157 additions and 812 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ jobs:
run: swift test --enable-test-discovery --sanitize=thread
env:
POSTGRES_HOSTNAME: psql
POSTGRES_USERNAME: vapor_username
POSTGRES_USER: vapor_username
POSTGRES_DB: vapor_database
POSTGRES_PASSWORD: vapor_password
POSTGRES_DATABASE: vapor_database
POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }}

# Run package tests on macOS against supported PSQL versions
macos:
Expand Down Expand Up @@ -138,6 +139,7 @@ jobs:
run: swift test --enable-test-discovery --sanitize=thread
env:
POSTGRES_HOSTNAME: 127.0.0.1
POSTGRES_USERNAME: vapor_username
POSTGRES_USER: vapor_username
POSTGRES_DB: postgres
POSTGRES_PASSWORD: vapor_password
POSTGRES_DATABASE: postgres
POSTGRES_HOST_AUTH_METHOD: ${{ matrix.dbauth }}
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ let package = Package(
.product(name: "Logging", package: "swift-log"),
.product(name: "Metrics", package: "swift-metrics"),
.product(name: "NIO", package: "swift-nio"),
.product(name: "NIOTLS", package: "swift-nio"),
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "NIOSSL", package: "swift-nio-ssl"),
]),
.testTarget(name: "PostgresNIOTests", dependencies: [
Expand Down
158 changes: 9 additions & 149 deletions Sources/PostgresNIO/Connection/PostgresConnection+Authenticate.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import Crypto
import NIO
import Logging

extension PostgresConnection {
public func authenticate(
Expand All @@ -9,155 +7,17 @@ extension PostgresConnection {
password: String? = nil,
logger: Logger = .init(label: "codes.vapor.postgres")
) -> EventLoopFuture<Void> {
let auth = PostgresAuthenticationRequest(
let authContext = AuthContext(
username: username,
database: database,
password: password
)
return self.send(auth, logger: self.logger)
}
}

// MARK: Private
password: password,
database: database)
let outgoing = PSQLOutgoingEvent.authenticate(authContext)
self.underlying.channel.triggerUserOutboundEvent(outgoing, promise: nil)

private final class PostgresAuthenticationRequest: PostgresRequest {
enum State {
case ready
case saslInitialSent(SASLAuthenticationManager<SASLMechanism.SCRAM.SHA256>)
case saslChallengeResponse(SASLAuthenticationManager<SASLMechanism.SCRAM.SHA256>)
case saslWaitOkay
case done
}

let username: String
let database: String?
let password: String?
var state: State

init(username: String, database: String?, password: String?) {
self.state = .ready
self.username = username
self.database = database
self.password = password
}

func log(to logger: Logger) {
logger.debug("Logging into Postgres db \(self.database ?? "nil") as \(self.username)")
}

func respond(to message: PostgresMessage) throws -> [PostgresMessage]? {
if case .error = message.identifier {
// terminate immediately on error
return nil
}

switch self.state {
case .ready:
switch message.identifier {
case .authentication:
let auth = try PostgresMessage.Authentication(message: message)
switch auth {
case .md5(let salt):
let pwdhash = self.md5((self.password ?? "") + self.username).hexdigest()
let hash = "md5" + self.md5(self.bytes(pwdhash) + salt).hexdigest()
return try [PostgresMessage.Password(string: hash).message()]
case .plaintext:
return try [PostgresMessage.Password(string: self.password ?? "").message()]
case .saslMechanisms(let saslMechanisms):
if saslMechanisms.contains("SCRAM-SHA-256") && self.password != nil {
let saslManager = SASLAuthenticationManager(asClientSpeaking:
SASLMechanism.SCRAM.SHA256(username: self.username, password: { self.password! }))
var message: PostgresMessage?

if (try saslManager.handle(message: nil, sender: { bytes in
message = try PostgresMessage.SASLInitialResponse(mechanism: "SCRAM-SHA-256", initialData: bytes).message()
})) {
self.state = .saslWaitOkay
} else {
self.state = .saslInitialSent(saslManager)
}
return [message].compactMap { $0 }
} else {
throw PostgresError.protocol("Unable to authenticate with any available SASL mechanism: \(saslMechanisms)")
}
case .saslContinue, .saslFinal:
throw PostgresError.protocol("Unexpected SASL response to start message: \(message)")
case .ok:
self.state = .done
return []
}
default: throw PostgresError.protocol("Unexpected response to start message: \(message)")
}
case .saslInitialSent(let manager),
.saslChallengeResponse(let manager):
switch message.identifier {
case .authentication:
let auth = try PostgresMessage.Authentication(message: message)
switch auth {
case .saslContinue(let data), .saslFinal(let data):
var message: PostgresMessage?
if try manager.handle(message: data, sender: { bytes in
message = try PostgresMessage.SASLResponse(responseData: bytes).message()
}) {
self.state = .saslWaitOkay
} else {
self.state = .saslChallengeResponse(manager)
}
return [message].compactMap { $0 }
default: throw PostgresError.protocol("Unexpected response during SASL negotiation: \(message)")
}
default: throw PostgresError.protocol("Unexpected response during SASL negotiation: \(message)")
}
case .saslWaitOkay:
switch message.identifier {
case .authentication:
let auth = try PostgresMessage.Authentication(message: message)
switch auth {
case .ok:
self.state = .done
return []
default: throw PostgresError.protocol("Unexpected response while waiting for post-SASL ok: \(message)")
}
default: throw PostgresError.protocol("Unexpected response while waiting for post-SASL ok: \(message)")
}
case .done:
switch message.identifier {
case .parameterStatus:
// self.status[status.parameter] = status.value
return []
case .backendKeyData:
// self.processID = data.processID
// self.secretKey = data.secretKey
return []
case .readyForQuery:
return nil
default: throw PostgresError.protocol("Unexpected response to password authentication: \(message)")
}
return self.underlying.channel.pipeline.handler(type: PSQLEventsHandler.self).flatMap { handler in
handler.authenticateFuture
}.flatMapErrorThrowing { error in
throw error.asAppropriatePostgresError
}

}

func start() throws -> [PostgresMessage] {
return try [
PostgresMessage.Startup.versionThree(parameters: [
"user": self.username,
"database": self.database ?? username
]).message()
]
}

// MARK: Private

private func md5(_ string: String) -> [UInt8] {
return md5(self.bytes(string))
}

private func md5(_ message: [UInt8]) -> [UInt8] {
let digest = Insecure.MD5.hash(data: message)
return .init(digest)
}

func bytes(_ string: String) -> [UInt8] {
return Array(string.utf8)
}
}
62 changes: 20 additions & 42 deletions Sources/PostgresNIO/Connection/PostgresConnection+Connect.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import Logging
import NIO

extension PostgresConnection {
Expand All @@ -9,47 +8,26 @@ extension PostgresConnection {
logger: Logger = .init(label: "codes.vapor.postgres"),
on eventLoop: EventLoop
) -> EventLoopFuture<PostgresConnection> {
let bootstrap = ClientBootstrap(group: eventLoop)
.channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
return bootstrap.connect(to: socketAddress).flatMap { channel in
return channel.pipeline.addHandlers([
ByteToMessageHandler(PostgresMessageDecoder(logger: logger)),
MessageToByteHandler(PostgresMessageEncoder(logger: logger)),
PostgresRequestHandler(logger: logger),
PostgresErrorHandler(logger: logger)
]).map {
return PostgresConnection(channel: channel, logger: logger)
}
}.flatMap { (conn: PostgresConnection) in
if let tlsConfiguration = tlsConfiguration {
return conn.requestTLS(
using: tlsConfiguration,
serverHostname: serverHostname,
logger: logger
).flatMapError { error in
conn.close().flatMapThrowing {
throw error
}
}.map { conn }
} else {
return eventLoop.makeSucceededFuture(conn)
}

let coders = PSQLConnection.Configuration.Coders(
jsonEncoder: PostgresJSONEncoderWrapper(_defaultJSONEncoder),
jsonDecoder: PostgresJSONDecoderWrapper(_defaultJSONDecoder)
)

let configuration = PSQLConnection.Configuration(
connection: .resolved(address: socketAddress, serverName: serverHostname),
authentication: nil,
tlsConfiguration: tlsConfiguration,
coders: coders)

return PSQLConnection.connect(
configuration: configuration,
logger: logger,
on: eventLoop
).map { connection in
PostgresConnection(underlying: connection, logger: logger)
}.flatMapErrorThrowing { error in
throw error.asAppropriatePostgresError
}
}
}


private final class PostgresErrorHandler: ChannelInboundHandler {
typealias InboundIn = Never

let logger: Logger
init(logger: Logger) {
self.logger = logger
}

func errorCaught(context: ChannelHandlerContext, error: Error) {
self.logger.error("Uncaught error: \(error)")
context.close(promise: nil)
context.fireErrorCaught(error)
}
}
Loading

0 comments on commit cdb18d1

Please sign in to comment.