Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for custom JWT de/serialisation #130

Merged
merged 6 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Sources/JWTKit/JWTError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public struct JWTError: Error, @unchecked Sendable {
case invalidBool
case noKeyProvided
case invalidX5CChain
case invalidHeaderField
case generic
}

Expand All @@ -34,6 +35,7 @@ public struct JWTError: Error, @unchecked Sendable {
public static let invalidBool = Self(.invalidBool)
public static let noKeyProvided = Self(.noKeyProvided)
public static let invalidX5CChain = Self(.invalidX5CChain)
public static let invalidHeaderField = Self(.invalidHeaderField)
public static let generic = Self(.generic)

public var description: String {
Expand All @@ -58,6 +60,8 @@ public struct JWTError: Error, @unchecked Sendable {
"noKeyProvided"
case .invalidX5CChain:
"invalidX5CChain"
case .invalidHeaderField:
"invalidHeaderField"
case .generic:
"generic"
}
Expand Down Expand Up @@ -159,6 +163,12 @@ public struct JWTError: Error, @unchecked Sendable {
new.reason = reason
return new
}

public static func invalidHeaderField(reason: String) -> Self {
var new = Self(errorType: .invalidHeaderField)
new.reason = reason
return new
}

public static func generic(identifier: String, reason: String) -> Self {
var new = Self(errorType: .generic)
Expand Down
79 changes: 13 additions & 66 deletions Sources/JWTKit/JWTHeader.swift
Original file line number Diff line number Diff line change
@@ -1,102 +1,49 @@
/// The header (details) used for signing and processing the JWT.
@dynamicMemberLookup
public struct JWTHeader: Sendable {
/// The algorithm used with the signing.
public var alg: String?
public var fields: [String: JWTHeaderField]

/// The Signature's Content Type.
public var typ: String?

/// The Payload's Content Type.
public var cty: String?

/// The JWT key identifier.
public var kid: JWKIdentifier?

/// The x5c certificate chain.
public var x5c: [String]?

/// Custom fields.
public var customFields: [String: JWTHeaderField]
public init(fields: [String: JWTHeaderField] = [:]) {
self.fields = fields
}

init(
alg: String? = nil,
typ: String? = nil,
cty: String? = nil,
kid: JWKIdentifier? = nil,
x5c: [String]? = nil,
customFields: [String: JWTHeaderField] = [:]
) {
self.alg = alg
self.typ = typ
self.cty = cty
self.kid = kid
self.x5c = x5c
self.customFields = customFields
public subscript(dynamicMember member: String) -> JWTHeaderField? {
get { fields[member] }
set { fields[member] = newValue }
}
}

extension JWTHeader: Codable {
public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encodeIfPresent(self.alg, forKey: .alg)
try container.encodeIfPresent(self.typ, forKey: .typ)
try container.encodeIfPresent(self.cty, forKey: .cty)
try container.encodeIfPresent(self.kid, forKey: .kid)
try container.encodeIfPresent(self.x5c, forKey: .x5c)
try customFields.forEach { key, value in
try fields.forEach { key, value in
try container.encode(value, forKey: .custom(name: key))
}
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.alg = try container.decodeIfPresent(String.self, forKey: .alg)
self.typ = try container.decodeIfPresent(String.self, forKey: .typ)
self.cty = try container.decodeIfPresent(String.self, forKey: .cty)
self.kid = try container.decodeIfPresent(JWKIdentifier.self, forKey: .kid)
self.x5c = try container.decodeIfPresent([String].self, forKey: .x5c)

self.customFields = try Set(container.allKeys)
.subtracting(CodingKeys.allKeys)
self.fields = try Set(container.allKeys)
.reduce(into: [String: JWTHeaderField]()) { result, key in
result[key.stringValue] = try container.decode(JWTHeaderField.self, forKey: key)
}
}

private enum CodingKeys: CodingKey, Equatable, Hashable {
case alg
case typ
case cty
case kid
case x5c
case custom(name: String)

static var allKeys: Set<CodingKeys> {
[.alg, .typ, .cty, .kid, .x5c]
}

var stringValue: String {
switch self {
case .alg: "alg"
case .typ: "typ"
case .cty: "cty"
case .kid: "kid"
case .x5c: "x5c"
case let .custom(name: name): name
case let .custom(name):
return name
}
}

var intValue: Int? { nil }

init?(stringValue: String) {
switch stringValue {
case "alg": self = .alg
case "typ": self = .typ
case "cty": self = .cty
case "kid": self = .kid
case "x5c": self = .x5c
default: self = .custom(name: stringValue)
}
self = .custom(name: stringValue)
}

init?(intValue _: Int) { nil }
Expand Down
143 changes: 107 additions & 36 deletions Sources/JWTKit/JWTHeaderField.swift
Original file line number Diff line number Diff line change
@@ -1,56 +1,39 @@
import Foundation

public indirect enum JWTHeaderField: Hashable, Sendable, Codable {
case null
case bool(Bool)
case int(Int)
case float(Double)
case decimal(Double)
case string(String)
case array([JWTHeaderField])
case object([String: JWTHeaderField])

public init(from decoder: any Decoder) throws {
let container: any SingleValueDecodingContainer

do {
container = try decoder.singleValueContainer()
} catch DecodingError.typeMismatch {
self = .null
return
}
do { container = try decoder.singleValueContainer() }
catch DecodingError.typeMismatch { self = .null; return }

if container.decodeNil() {
self = .null
return
}
if container.decodeNil() { self = .null; return }

do {
self = try .bool(container.decode(Bool.self))
return
} catch DecodingError.typeMismatch {}
do { self = try .bool(container.decode(Bool.self)); return }
catch DecodingError.typeMismatch {}

do {
self = try .int(container.decode(Int.self))
return
} catch DecodingError.typeMismatch {}
do { self = try .int(container.decode(Int.self)); return }
catch DecodingError.typeMismatch {}

do {
self = try .float(container.decode(Double.self))
return
} catch DecodingError.typeMismatch {}
do { self = try .decimal(container.decode(Double.self)); return }
catch DecodingError.typeMismatch {}

do {
self = try .string(container.decode(String.self))
return
} catch DecodingError.typeMismatch {}
do { self = try .string(container.decode(String.self)); return }
catch DecodingError.typeMismatch {}

do {
self = try .array(container.decode([Self].self))
return
} catch DecodingError.typeMismatch {}
do { self = try .array(container.decode([Self].self)); return }
catch DecodingError.typeMismatch {}

do {
self = try .object(container.decode([String: Self].self))
return
} catch DecodingError.typeMismatch {}
do { self = try .object(container.decode([String: Self].self)); return }
catch DecodingError.typeMismatch {}

throw DecodingError.dataCorruptedError(in: container, debugDescription: "No valid JSON type found.")
}
Expand All @@ -61,10 +44,98 @@ public indirect enum JWTHeaderField: Hashable, Sendable, Codable {
case .null: break
case let .bool(value): try container.encode(value)
case let .int(value): try container.encode(value)
case let .float(value): try container.encode(value)
case let .decimal(value): try container.encode(value)
case let .string(value): try container.encode(value)
case let .array(value): try container.encode(value)
case let .object(value): try container.encode(value)
}
}
}

public extension JWTHeaderField {
internal var isNull: Bool { if case .null = self { true } else { false } }
var asBool: Bool? { if case let .bool(b) = self { b } else { nil } }
var asInt: Int? { if case let .int(i) = self { i } else { nil } }
var asDecimal: Double? { if case let .decimal(d) = self { d } else { nil } }
var asString: String? { if case let .string(s) = self { s } else { nil } }
ptoffy marked this conversation as resolved.
Show resolved Hide resolved
internal var asArray: [Self]? { if case let .array(a) = self { a } else { nil } }
internal var asObject: [String: Self]? { if case let .object(o) = self { o } else { nil } }
}

public extension JWTHeaderField {
func asObject<T>(of _: T.Type) throws -> [String: T] {
guard let object = self.asObject else {
throw JWTError.invalidHeaderField(reason: "Element is not an object")
}
let values: [String: T]? = switch T.self {
case is Bool.Type: object.compactMapValues { $0.asBool } as? [String: T]
case is Int.Type: object.compactMapValues { $0.asInt } as? [String: T]
case is Double.Type: object.compactMapValues { $0.asDecimal } as? [String: T]
case is String.Type: object.compactMapValues { $0.asString } as? [String: T]
default: nil
}
guard let values, object.count == values.count else {
throw JWTError.invalidHeaderField(reason: "Object is not homogeneous")
}
return values
}

func asArray<T>(of _: T.Type) throws -> [T] {
guard let array = self.asArray else {
throw JWTError.invalidHeaderField(reason: "Element is not an array")
}
let values: [T]? = switch T.self {
case is Bool.Type: array.compactMap { $0.asBool } as? [T]
case is Int.Type: array.compactMap { $0.asInt } as? [T]
case is Double.Type: array.compactMap { $0.asDecimal } as? [T]
case is String.Type: array.compactMap { $0.asString } as? [T]
default: nil
}
guard let values, array.count == values.count else {
throw JWTError.invalidHeaderField(reason: "Array is not homogeneous")
}
return values
}
}

extension JWTHeaderField: ExpressibleByNilLiteral {
public init(nilLiteral _: ()) {
self = .null
}
}

extension JWTHeaderField: ExpressibleByStringLiteral {
public init(stringLiteral value: StringLiteralType) {
self = .string(value)
}
}

extension JWTHeaderField: ExpressibleByIntegerLiteral {
public init(integerLiteral value: IntegerLiteralType) {
self = .int(value)
}
}

extension JWTHeaderField: ExpressibleByBooleanLiteral {
public init(booleanLiteral value: BooleanLiteralType) {
self = .bool(value)
}
}

extension JWTHeaderField: ExpressibleByFloatLiteral {
public init(floatLiteral value: FloatLiteralType) {
self = .decimal(value)
}
}

extension JWTHeaderField: ExpressibleByArrayLiteral {
public init(arrayLiteral elements: JWTHeaderField...) {
self = .array(elements)
}
}

extension JWTHeaderField: ExpressibleByDictionaryLiteral {
public init(dictionaryLiteral elements: (String, JWTHeaderField)...) {
self = .object(Dictionary(uniqueKeysWithValues: elements))
}
}
Loading