Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1258,18 +1258,7 @@ struct AuthContext: Equatable, CustomDebugStringConvertible {

enum PasswordAuthencationMode: Equatable {
case cleartext
case md5(salt: (UInt8, UInt8, UInt8, UInt8))

static func ==(lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
case (.cleartext, .cleartext):
return true
case (.md5(let lhs), .md5(let rhs)):
return lhs == rhs
default:
return false
}
}
case md5(salt: UInt32)
}

extension ConnectionStateMachine.State: CustomDebugStringConvertible {
Expand Down
37 changes: 3 additions & 34 deletions Sources/PostgresNIO/New/Messages/Authentication.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ import NIOCore

extension PostgresBackendMessage {

enum Authentication: PayloadDecodable {
enum Authentication: PayloadDecodable, Hashable {
case ok
case kerberosV5
case md5(salt: (UInt8, UInt8, UInt8, UInt8))
case md5(salt: UInt32)
case plaintext
case scmCredential
case gss
Expand All @@ -26,7 +26,7 @@ extension PostgresBackendMessage {
case 3:
return .plaintext
case 5:
guard let salt = buffer.readMultipleIntegers(endianness: .big, as: (UInt8, UInt8, UInt8, UInt8).self) else {
guard let salt = buffer.readInteger(as: UInt32.self) else {
throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(4, actual: buffer.readableBytes)
}
return .md5(salt: salt)
Expand Down Expand Up @@ -61,37 +61,6 @@ extension PostgresBackendMessage {
}
}

extension PostgresBackendMessage.Authentication: Equatable {
static func ==(lhs: Self, rhs: Self) -> Bool {
switch (lhs, rhs) {
case (.ok, .ok):
return true
case (.kerberosV5, .kerberosV5):
return true
case (.md5(let lhs), .md5(let rhs)):
return lhs == rhs
case (.plaintext, .plaintext):
return true
case (.scmCredential, .scmCredential):
return true
case (.gss, .gss):
return true
case (.sspi, .sspi):
return true
case (.gssContinue(let lhs), .gssContinue(let rhs)):
return lhs == rhs
case (.sasl(let lhs), .sasl(let rhs)):
return lhs == rhs
case (.saslContinue(let lhs), .saslContinue(let rhs)):
return lhs == rhs
case (.saslFinal(let lhs), .saslFinal(let rhs)):
return lhs == rhs
default:
return false
}
}
}

extension PostgresBackendMessage.Authentication: CustomDebugStringConvertible {
var debugDescription: String {
switch self {
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/New/Messages/BackendKeyData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import NIOCore

extension PostgresBackendMessage {

struct BackendKeyData: PayloadDecodable, Equatable {
struct BackendKeyData: PayloadDecodable, Hashable {
let processID: Int32
let secretKey: Int32

Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/New/Messages/DataRow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import NIOCore
/// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick
/// the Swift compiler
@usableFromInline
struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Equatable {
struct DataRow: Sendable, PostgresBackendMessage.PayloadDecodable, Hashable {
@usableFromInline
var columnCount: Int16
@usableFromInline
Expand Down
4 changes: 2 additions & 2 deletions Sources/PostgresNIO/New/Messages/ErrorResponse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ extension PostgresBackendMessage {
case routine = 0x52 /// R
}

struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Equatable {
struct ErrorResponse: PSQLMessageNotice, PayloadDecodable, Hashable {
let fields: [PostgresBackendMessage.Field: String]

init(fields: [PostgresBackendMessage.Field: String]) {
self.fields = fields
}
}

struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Equatable {
struct NoticeResponse: PSQLMessageNotice, PayloadDecodable, Hashable {
let fields: [PostgresBackendMessage.Field: String]

init(fields: [PostgresBackendMessage.Field: String]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import NIOCore

extension PostgresBackendMessage {

struct NotificationResponse: PayloadDecodable, Equatable {
struct NotificationResponse: PayloadDecodable, Hashable {
let backendPID: Int32
let channel: String
let payload: String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import NIOCore

extension PostgresBackendMessage {

struct ParameterDescription: PayloadDecodable, Equatable {
struct ParameterDescription: PayloadDecodable, Hashable {
/// Specifies the object ID of the parameter data type.
var dataTypes: [PostgresDataType]

Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/New/Messages/ParameterStatus.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import NIOCore

extension PostgresBackendMessage {

struct ParameterStatus: PayloadDecodable, Equatable {
struct ParameterStatus: PayloadDecodable, Hashable {
/// The name of the run-time parameter being reported.
var parameter: String

Expand Down
34 changes: 4 additions & 30 deletions Sources/PostgresNIO/New/Messages/ReadyForQuery.swift
Original file line number Diff line number Diff line change
@@ -1,37 +1,11 @@
import NIOCore

extension PostgresBackendMessage {
enum TransactionState: PayloadDecodable, RawRepresentable {
typealias RawValue = UInt8

case idle
case inTransaction
case inFailedTransaction

init?(rawValue: UInt8) {
switch rawValue {
case UInt8(ascii: "I"):
self = .idle
case UInt8(ascii: "T"):
self = .inTransaction
case UInt8(ascii: "E"):
self = .inFailedTransaction
default:
return nil
}
}
enum TransactionState: UInt8, PayloadDecodable, Hashable {
case idle = 73 // ascii: I
case inTransaction = 84 // ascii: T
case inFailedTransaction = 69 // ascii: E

var rawValue: Self.RawValue {
switch self {
case .idle:
return UInt8(ascii: "I")
case .inTransaction:
return UInt8(ascii: "T")
case .inFailedTransaction:
return UInt8(ascii: "E")
}
}

static func decode(from buffer: inout ByteBuffer) throws -> Self {
let value = try buffer.throwingReadInteger(as: UInt8.self)
guard let state = Self.init(rawValue: value) else {
Expand Down
4 changes: 2 additions & 2 deletions Sources/PostgresNIO/New/Messages/RowDescription.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ import NIOCore
/// Not putting `DataRow` in ``PSQLBackendMessage`` is our way to trick
/// the Swift compiler.
@usableFromInline
struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Equatable {
struct RowDescription: PostgresBackendMessage.PayloadDecodable, Sendable, Hashable {
/// Specifies the object ID of the parameter data type.
@usableFromInline
var columns: [Column]

@usableFromInline
struct Column: Equatable, Sendable {
struct Column: Hashable, Sendable {
/// The field name.
@usableFromInline
var name: String
Expand Down
2 changes: 1 addition & 1 deletion Sources/PostgresNIO/New/PostgresBackendMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ protocol PSQLMessagePayloadDecodable {
///
/// All messages are defined in the official Postgres Documentation in the section
/// [Frontend/Backend Protocol – Message Formats](https://www.postgresql.org/docs/13/protocol-message-formats.html)
enum PostgresBackendMessage {
enum PostgresBackendMessage: Hashable {

typealias PayloadDecodable = PSQLMessagePayloadDecodable

Expand Down
8 changes: 4 additions & 4 deletions Sources/PostgresNIO/New/PostgresChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,10 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
var hash2 = [UInt8]()
hash2.reserveCapacity(pwdhash.count + 4)
hash2.append(contentsOf: pwdhash)
hash2.append(salt.0)
hash2.append(salt.1)
hash2.append(salt.2)
hash2.append(salt.3)
var saltNetworkOrder = salt.bigEndian
withUnsafeBytes(of: &saltNetworkOrder) { ptr in
hash2.append(contentsOf: ptr)
}
let hash = Insecure.MD5.hash(data: hash2).md5PrefixHexdigest()

self.encoder.password(hash.utf8)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class AuthenticationStateMachineTests: XCTestCase {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
let salt: UInt32 = 0x00_01_02_03

XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
Expand All @@ -30,8 +30,8 @@ class AuthenticationStateMachineTests: XCTestCase {
let authContext = AuthContext(username: "test", password: nil, database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
let salt: UInt32 = 0x00_01_02_03

XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)),
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil)))
Expand All @@ -49,8 +49,8 @@ class AuthenticationStateMachineTests: XCTestCase {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
let salt: UInt32 = 0x00_01_02_03

XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
let fields: [PostgresBackendMessage.Field: String] = [
Expand Down Expand Up @@ -107,12 +107,12 @@ class AuthenticationStateMachineTests: XCTestCase {
}

func testUnexpectedMessagesAfterPasswordSent() {
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
let salt: UInt32 = 0x00_01_02_03
var buffer = ByteBuffer()
buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8])
let unexpected: [PostgresBackendMessage.Authentication] = [
.kerberosV5,
.md5(salt: (0, 1, 2, 3)),
.md5(salt: salt),
.plaintext,
.scmCredential,
.gss,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ConnectionStateMachineTests: XCTestCase {
XCTAssertEqual(state.sslHandlerAdded(), .wait)
XCTAssertEqual(state.sslEstablished(), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
let salt: (UInt8, UInt8, UInt8, UInt8) = (0,1,2,3)
let salt: UInt32 = 0x00_01_02_03
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
}

Expand Down Expand Up @@ -154,7 +154,7 @@ class ConnectionStateMachineTests: XCTestCase {
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }

let authContext = AuthContext(username: "test", password: "abc123", database: "test")
let salt: (UInt8, UInt8, UInt8, UInt8) = (0, 1, 2, 3)
let salt: UInt32 = 0x00_01_02_03

let queryPromise = eventLoopGroup.next().makePromise(of: PSQLRowStream.self)

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder {
/// - parameters:
/// - data: The data to encode into a `ByteBuffer`.
/// - out: The `ByteBuffer` into which we want to encode.
func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) throws {
func encode(data message: PostgresBackendMessage, out buffer: inout ByteBuffer) {
switch message {
case .authentication(let authentication):
self.encode(messageID: message.id, payload: authentication, into: &buffer)
Expand Down Expand Up @@ -144,11 +144,7 @@ extension PostgresBackendMessage.Authentication: PSQLMessagePayloadEncodable {
buffer.writeInteger(Int32(3))

case .md5(salt: let salt):
buffer.writeInteger(Int32(5))
buffer.writeInteger(salt.0)
buffer.writeInteger(salt.1)
buffer.writeInteger(salt.2)
buffer.writeInteger(salt.3)
buffer.writeMultipleIntegers(Int32(5), salt)

case .scmCredential:
buffer.writeInteger(Int32(6))
Expand Down
Loading