From 31baa4f92dfcc2e8d44e2225d5f775df3493296f Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:08:22 -0300 Subject: [PATCH 1/7] Move functions-swift to repo --- Examples/Package.swift | 13 ++ Package.resolved | 18 +-- Package.swift | 17 +-- Sources/Functions/FunctionsClient.swift | 114 ++++++++++++++++++ Sources/Functions/Types.swift | 74 ++++++++++++ Sources/Functions/Version.swift | 1 + .../FunctionInvokeOptionsTests.swift | 26 ++++ .../FunctionsTests/FunctionsClientTests.swift | 93 ++++++++++++++ .../xcshareddata/swiftpm/Package.resolved | 40 ++---- 9 files changed, 351 insertions(+), 45 deletions(-) create mode 100644 Examples/Package.swift create mode 100644 Sources/Functions/FunctionsClient.swift create mode 100644 Sources/Functions/Types.swift create mode 100644 Sources/Functions/Version.swift create mode 100644 Tests/FunctionsTests/FunctionInvokeOptionsTests.swift create mode 100644 Tests/FunctionsTests/FunctionsClientTests.swift diff --git a/Examples/Package.swift b/Examples/Package.swift new file mode 100644 index 00000000..e75600e0 --- /dev/null +++ b/Examples/Package.swift @@ -0,0 +1,13 @@ +// swift-tools-version:5.7 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import Foundation +import PackageDescription + +var package = Package( + name: "Examples", + platforms: [], + products: [], + dependencies: [], + targets: [] +) diff --git a/Package.resolved b/Package.resolved index b5381ff9..72f624f2 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,14 +1,5 @@ { "pins" : [ - { - "identity" : "functions-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/functions-swift", - "state" : { - "branch" : "dependency-free", - "revision" : "5e12b7c4a206a29910dd64f6f23faed437955089" - } - }, { "identity" : "gotrue-swift", "kind" : "remoteSourceControl", @@ -27,6 +18,15 @@ "version" : "4.2.2" } }, + { + "identity" : "mocker", + "kind" : "remoteSourceControl", + "location" : "https://github.com/WeTransfer/Mocker", + "state" : { + "revision" : "4384e015cae4916a6828252467a4437173c7ae17", + "version" : "3.0.1" + } + }, { "identity" : "postgrest-swift", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index e2859669..0a1cfd38 100644 --- a/Package.swift +++ b/Package.swift @@ -17,9 +17,15 @@ var package = Package( .library( name: "Supabase", targets: ["Supabase"] + ), + .library( + name: "Functions", + targets: ["Functions"] ) ], - dependencies: [], + dependencies: [ + .package(url: "https://github.com/WeTransfer/Mocker", from: "3.0.1"), + ], targets: [ .target( name: "Supabase", @@ -28,10 +34,12 @@ var package = Package( .product(name: "SupabaseStorage", package: "storage-swift"), .product(name: "Realtime", package: "realtime-swift"), .product(name: "PostgREST", package: "postgrest-swift"), - .product(name: "Functions", package: "functions-swift"), + "Functions", ] ), .testTarget(name: "SupabaseTests", dependencies: ["Supabase"]), + .target(name: "Functions"), + .testTarget(name: "FunctionsTests", dependencies: ["Functions", "Mocker"]), ] ) @@ -42,7 +50,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { .package(path: "../storage-swift"), .package(path: "../realtime-swift"), .package(path: "../postgrest-swift"), - .package(path: "../functions-swift"), ] ) } else { @@ -61,10 +68,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { url: "https://github.com/supabase-community/postgrest-swift", branch: "dependency-free" ), - .package( - url: "https://github.com/supabase-community/functions-swift", - branch: "dependency-free" - ), ] ) } diff --git a/Sources/Functions/FunctionsClient.swift b/Sources/Functions/FunctionsClient.swift new file mode 100644 index 00000000..962eb127 --- /dev/null +++ b/Sources/Functions/FunctionsClient.swift @@ -0,0 +1,114 @@ +import Foundation + +/// An actor representing a client for invoking functions. +public actor FunctionsClient { + /// Typealias for the fetch handler used to make requests. + public typealias FetchHandler = @Sendable (_ request: URLRequest) async throws -> ( + Data, URLResponse + ) + + /// The base URL for the functions. + let url: URL + /// Headers to be included in the requests. + var headers: [String: String] + /// The fetch handler used to make requests. + let fetch: FetchHandler + + /// Initializes a new instance of `FunctionsClient`. + /// + /// - Parameters: + /// - url: The base URL for the functions. + /// - headers: Headers to be included in the requests. (Default: empty dictionary) + /// - fetch: The fetch handler used to make requests. (Default: URLSession.shared.data(for:)) + public init( + url: URL, + headers: [String: String] = [:], + fetch: @escaping FetchHandler = { try await URLSession.shared.data(for: $0) } + ) { + self.url = url + self.headers = headers + self.headers["X-Client-Info"] = "functions-swift/\(version)" + self.fetch = fetch + } + + /// Updates the authorization header. + /// + /// - Parameter token: The new JWT token sent in the authorization header. + public func setAuth(token: String) { + headers["Authorization"] = "Bearer \(token)" + } + + /// Invokes a function and decodes the response. + /// + /// - Parameters: + /// - functionName: The name of the function to invoke. + /// - invokeOptions: Options for invoking the function. (Default: empty `FunctionInvokeOptions`) + /// - decode: A closure to decode the response data and HTTPURLResponse into a `Response` object. + /// - Returns: The decoded `Response` object. + public func invoke( + functionName: String, + invokeOptions: FunctionInvokeOptions = .init(), + decode: (Data, HTTPURLResponse) throws -> Response + ) async throws -> Response { + let (data, response) = try await rawInvoke( + functionName: functionName, invokeOptions: invokeOptions) + return try decode(data, response) + } + + /// Invokes a function and decodes the response as a specific type. + /// + /// - Parameters: + /// - functionName: The name of the function to invoke. + /// - invokeOptions: Options for invoking the function. (Default: empty `FunctionInvokeOptions`) + /// - decoder: The JSON decoder to use for decoding the response. (Default: `JSONDecoder()`) + /// - Returns: The decoded object of type `T`. + public func invoke( + functionName: String, + invokeOptions: FunctionInvokeOptions = .init(), + decoder: JSONDecoder = JSONDecoder() + ) async throws -> T { + try await invoke(functionName: functionName, invokeOptions: invokeOptions) { data, _ in + try decoder.decode(T.self, from: data) + } + } + + /// Invokes a function without expecting a response. + /// + /// - Parameters: + /// - functionName: The name of the function to invoke. + /// - invokeOptions: Options for invoking the function. (Default: empty `FunctionInvokeOptions`) + public func invoke( + functionName: String, + invokeOptions: FunctionInvokeOptions = .init() + ) async throws { + try await invoke(functionName: functionName, invokeOptions: invokeOptions) { _, _ in () } + } + + private func rawInvoke( + functionName: String, + invokeOptions: FunctionInvokeOptions + ) async throws -> (Data, HTTPURLResponse) { + let url = self.url.appendingPathComponent(functionName) + var urlRequest = URLRequest(url: url) + urlRequest.allHTTPHeaderFields = invokeOptions.headers.merging(headers) { first, _ in first } + urlRequest.httpMethod = (invokeOptions.method ?? .post).rawValue + urlRequest.httpBody = invokeOptions.body + + let (data, response) = try await fetch(urlRequest) + + guard let httpResponse = response as? HTTPURLResponse else { + throw URLError(.badServerResponse) + } + + guard 200..<300 ~= httpResponse.statusCode else { + throw FunctionsError.httpError(code: httpResponse.statusCode, data: data) + } + + let isRelayError = httpResponse.value(forHTTPHeaderField: "x-relay-error") == "true" + if isRelayError { + throw FunctionsError.relayError + } + + return (data, httpResponse) + } +} diff --git a/Sources/Functions/Types.swift b/Sources/Functions/Types.swift new file mode 100644 index 00000000..6ac72b15 --- /dev/null +++ b/Sources/Functions/Types.swift @@ -0,0 +1,74 @@ +import Foundation + +/// An error type representing various errors that can occur while invoking functions. +public enum FunctionsError: Error, LocalizedError { + /// Error indicating a relay error while invoking the Edge Function. + case relayError + /// Error indicating a non-2xx status code returned by the Edge Function. + case httpError(code: Int, data: Data) + + /// A localized description of the error. + public var errorDescription: String? { + switch self { + case .relayError: + return "Relay Error invoking the Edge Function" + case let .httpError(code, _): + return "Edge Function returned a non-2xx status code: \(code)" + } + } +} + +/// Options for invoking a function. +public struct FunctionInvokeOptions { + /// Method to use in the function invocation. + let method: Method? + /// Headers to be included in the function invocation. + let headers: [String: String] + /// Body data to be sent with the function invocation. + let body: Data? + + /// Initializes the `FunctionInvokeOptions` structure. + /// + /// - Parameters: + /// - method: Method to use in the function invocation. + /// - headers: Headers to be included in the function invocation. (Default: empty dictionary) + /// - body: The body data to be sent with the function invocation. (Default: nil) + public init(method: Method? = nil, headers: [String: String] = [:], body: some Encodable) { + var headers = headers + + switch body { + case let string as String: + headers["Content-Type"] = "text/plain" + self.body = string.data(using: .utf8) + case let data as Data: + headers["Content-Type"] = "application/octet-stream" + self.body = data + default: + // default, assume this is JSON + headers["Content-Type"] = "application/json" + self.body = try? JSONEncoder().encode(body) + } + + self.method = method + self.headers = headers + } + + /// Initializes the `FunctionInvokeOptions` structure. + /// + /// - Parameters: + /// - method: Method to use in the function invocation. + /// - headers: Headers to be included in the function invocation. (Default: empty dictionary) + public init(method: Method? = nil, headers: [String: String] = [:]) { + self.method = method + self.headers = headers + body = nil + } + + public enum Method: String { + case get = "GET" + case post = "POST" + case put = "PUT" + case patch = "PATCH" + case delete = "DELETE" + } +} diff --git a/Sources/Functions/Version.swift b/Sources/Functions/Version.swift new file mode 100644 index 00000000..5e629e8e --- /dev/null +++ b/Sources/Functions/Version.swift @@ -0,0 +1 @@ +let version = "1.0.0" diff --git a/Tests/FunctionsTests/FunctionInvokeOptionsTests.swift b/Tests/FunctionsTests/FunctionInvokeOptionsTests.swift new file mode 100644 index 00000000..ef5ae218 --- /dev/null +++ b/Tests/FunctionsTests/FunctionInvokeOptionsTests.swift @@ -0,0 +1,26 @@ +import XCTest + +@testable import Functions + +final class FunctionInvokeOptionsTests: XCTestCase { + func testStringBody() { + let options = FunctionInvokeOptions(body: "string value") + XCTAssertEqual(options.headers["Content-Type"], "text/plain") + XCTAssertNotNil(options.body) + } + + func testDataBody() { + let options = FunctionInvokeOptions(body: "binary value".data(using: .utf8)!) + XCTAssertEqual(options.headers["Content-Type"], "application/octet-stream") + XCTAssertNotNil(options.body) + } + + func testEncodableBody() { + struct Body: Encodable { + let value: String + } + let options = FunctionInvokeOptions(body: Body(value: "value")) + XCTAssertEqual(options.headers["Content-Type"], "application/json") + XCTAssertNotNil(options.body) + } +} diff --git a/Tests/FunctionsTests/FunctionsClientTests.swift b/Tests/FunctionsTests/FunctionsClientTests.swift new file mode 100644 index 00000000..fc45b70f --- /dev/null +++ b/Tests/FunctionsTests/FunctionsClientTests.swift @@ -0,0 +1,93 @@ +import Mocker +import XCTest + +@testable import Functions + +final class FunctionsClientTests: XCTestCase { + let url = URL(string: "http://localhost:5432/functions/v1")! + let apiKey = "supabase.anon.key" + + lazy var sut = FunctionsClient(url: url, headers: ["apikey": apiKey]) + + func testInvoke() async throws { + let url = URL(string: "http://localhost:5432/functions/v1/hello_world")! + var _request: URLRequest? + + var mock = Mock(url: url, dataType: .json, statusCode: 200, data: [.post: Data()]) + mock.onRequestHandler = .init { _request = $0 } + mock.register() + + let body = ["name": "Supabase"] + + try await sut.invoke( + functionName: "hello_world", + invokeOptions: .init(headers: ["X-Custom-Key": "value"], body: body) + ) + + let request = try XCTUnwrap(_request) + + XCTAssertEqual(request.url, url) + XCTAssertEqual(request.httpMethod, "POST") + XCTAssertEqual(request.value(forHTTPHeaderField: "apikey"), apiKey) + XCTAssertEqual(request.value(forHTTPHeaderField: "X-Custom-Key"), "value") + XCTAssertEqual( + request.value(forHTTPHeaderField: "X-Client-Info"), + "functions-swift/\(Functions.version)" + ) + } + + func testInvoke_shouldThrow_URLError_badServerResponse() async { + let url = URL(string: "http://localhost:5432/functions/v1/hello_world")! + let mock = Mock( + url: url, dataType: .json, statusCode: 200, data: [.post: Data()], + requestError: URLError(.badServerResponse)) + mock.register() + + do { + try await sut.invoke(functionName: "hello_world") + } catch let urlError as URLError { + XCTAssertEqual(urlError.code, .badServerResponse) + } catch { + XCTFail("Unexpected error thrown \(error)") + } + } + + func testInvoke_shouldThrow_FunctionsError_httpError() async { + let url = URL(string: "http://localhost:5432/functions/v1/hello_world")! + let mock = Mock( + url: url, dataType: .json, statusCode: 300, data: [.post: "error".data(using: .utf8)!]) + mock.register() + + do { + try await sut.invoke(functionName: "hello_world") + XCTFail("Invoke should fail.") + } catch let FunctionsError.httpError(code, data) { + XCTAssertEqual(code, 300) + XCTAssertEqual(data, "error".data(using: .utf8)) + } catch { + XCTFail("Unexpected error thrown \(error)") + } + } + + func testInvoke_shouldThrow_FunctionsError_relayError() async { + let url = URL(string: "http://localhost:5432/functions/v1/hello_world")! + let mock = Mock( + url: url, dataType: .json, statusCode: 200, data: [.post: Data()], + additionalHeaders: ["x-relay-error": "true"]) + mock.register() + + do { + try await sut.invoke(functionName: "hello_world") + XCTFail("Invoke should fail.") + } catch FunctionsError.relayError { + } catch { + XCTFail("Unexpected error thrown \(error)") + } + } + + func test_setAuth() async { + await sut.setAuth(token: "access.token") + let headers = await sut.headers + XCTAssertEqual(headers["Authorization"], "Bearer access.token") + } +} diff --git a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved index 47415a64..293e536e 100644 --- a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,30 +1,12 @@ { "pins" : [ - { - "identity" : "functions-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/functions-swift", - "state" : { - "branch" : "dependency-free", - "revision" : "479a3112f5ffdbeaec05f9b19124ab621c69d44a" - } - }, - { - "identity" : "get", - "kind" : "remoteSourceControl", - "location" : "https://github.com/kean/Get", - "state" : { - "revision" : "12830cc64f31789ae6f4352d2d51d03a25fc3741", - "version" : "2.1.6" - } - }, { "identity" : "gotrue-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/supabase-community/gotrue-swift", "state" : { - "revision" : "6c7d119bf236fe0071ff05f3639fdde6f05e759a", - "version" : "1.2.0" + "branch" : "dependency-free", + "revision" : "6dc6d577ce88613cd1ae17b2367228e4d684b101" } }, { @@ -36,6 +18,15 @@ "version" : "4.2.2" } }, + { + "identity" : "mocker", + "kind" : "remoteSourceControl", + "location" : "https://github.com/WeTransfer/Mocker", + "state" : { + "revision" : "4384e015cae4916a6828252467a4437173c7ae17", + "version" : "3.0.1" + } + }, { "identity" : "postgrest-swift", "kind" : "remoteSourceControl", @@ -108,15 +99,6 @@ "version" : "0.8.0" } }, - { - "identity" : "urlqueryencoder", - "kind" : "remoteSourceControl", - "location" : "https://github.com/kean/URLQueryEncoder", - "state" : { - "revision" : "4ce950479707ea109f229d7230ec074a133b15d7", - "version" : "0.2.1" - } - }, { "identity" : "xctest-dynamic-overlay", "kind" : "remoteSourceControl", From af339e392697c866290bc9884974e8e01acca7ef Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:13:41 -0300 Subject: [PATCH 2/7] Move postgrest to repo --- Package.resolved | 27 +- Package.swift | 33 ++- Sources/PostgREST/PostgrestBuilder.swift | 196 +++++++++++++++ Sources/PostgREST/PostgrestClient.swift | 188 ++++++++++++++ .../PostgREST/PostgrestFilterBuilder.swift | 233 ++++++++++++++++++ Sources/PostgREST/PostgrestQueryBuilder.swift | 164 ++++++++++++ Sources/PostgREST/PostgrestRpcBuilder.swift | 39 +++ .../PostgREST/PostgrestTransformBuilder.swift | 104 ++++++++ Sources/PostgREST/Types.swift | 96 ++++++++ Sources/PostgREST/URLQueryRepresentable.swift | 56 +++++ Sources/PostgREST/Version.swift | 1 + .../IntegrationTests.swift | 134 ++++++++++ .../PostgRESTTests/BuildURLRequestTests.swift | 118 +++++++++ .../PostgRESTTests/Helpers/LockIsolated.swift | 39 +++ .../PostgrestResponseTests.swift | 47 ++++ .../URLQueryRepresentableTests.swift | 16 ++ ...uildRequest.call-rpc-without-parameter.txt | 6 + .../testBuildRequest.call-rpc.txt | 7 + .../testBuildRequest.insert-new-user.txt | 7 + .../testBuildRequest.query-with-character.txt | 5 + ...testBuildRequest.query-with-timestampz.txt | 5 + ...sers-where-email-ends-with-supabase-co.txt | 5 + ...uildRequest.test-all-filters-and-count.txt | 5 + ...equest.test-contains-filter-with-array.txt | 5 + ...t.test-contains-filter-with-dictionary.txt | 5 + .../testBuildRequest.test-in-filter.txt | 5 + ...equest.test-upsert-ignoring-duplicates.txt | 8 + ...st.test-upsert-not-ignoring-duplicates.txt | 8 + .../xcshareddata/swiftpm/Package.resolved | 27 +- 29 files changed, 1564 insertions(+), 25 deletions(-) create mode 100644 Sources/PostgREST/PostgrestBuilder.swift create mode 100644 Sources/PostgREST/PostgrestClient.swift create mode 100644 Sources/PostgREST/PostgrestFilterBuilder.swift create mode 100644 Sources/PostgREST/PostgrestQueryBuilder.swift create mode 100644 Sources/PostgREST/PostgrestRpcBuilder.swift create mode 100644 Sources/PostgREST/PostgrestTransformBuilder.swift create mode 100644 Sources/PostgREST/Types.swift create mode 100644 Sources/PostgREST/URLQueryRepresentable.swift create mode 100644 Sources/PostgREST/Version.swift create mode 100644 Tests/PostgRESTIntegrationTests/IntegrationTests.swift create mode 100644 Tests/PostgRESTTests/BuildURLRequestTests.swift create mode 100644 Tests/PostgRESTTests/Helpers/LockIsolated.swift create mode 100644 Tests/PostgRESTTests/PostgrestResponseTests.swift create mode 100644 Tests/PostgRESTTests/URLQueryRepresentableTests.swift create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc-without-parameter.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.insert-new-user.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-character.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-timestampz.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.select-all-users-where-email-ends-with-supabase-co.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-all-filters-and-count.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-array.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-dictionary.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-in-filter.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-ignoring-duplicates.txt create mode 100644 Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-not-ignoring-duplicates.txt diff --git a/Package.resolved b/Package.resolved index 72f624f2..8f39d598 100644 --- a/Package.resolved +++ b/Package.resolved @@ -27,15 +27,6 @@ "version" : "3.0.1" } }, - { - "identity" : "postgrest-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/postgrest-swift", - "state" : { - "branch" : "dependency-free", - "revision" : "ef3faa36a6987d34c8d32bab1cf545d40bf1a182" - } - }, { "identity" : "realtime-swift", "kind" : "remoteSourceControl", @@ -53,6 +44,24 @@ "branch" : "dependency-free", "revision" : "62bf80cc46e22088ca390e506b1a712f4774a018" } + }, + { + "identity" : "swift-snapshot-testing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-snapshot-testing", + "state" : { + "revision" : "bb0ea08db8e73324fe6c3727f755ca41a23ff2f4", + "version" : "1.14.2" + } + }, + { + "identity" : "swift-syntax", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-syntax.git", + "state" : { + "revision" : "74203046135342e4a4a627476dd6caf8b28fe11b", + "version" : "509.0.0" + } } ], "version" : 2 diff --git a/Package.swift b/Package.swift index 0a1cfd38..6286a034 100644 --- a/Package.swift +++ b/Package.swift @@ -16,15 +16,20 @@ var package = Package( products: [ .library( name: "Supabase", - targets: ["Supabase"] + targets: ["Supabase", "Functions", "PostgREST"] ), .library( name: "Functions", targets: ["Functions"] + ), + .library( + name: "PostgREST", + targets: ["PostgREST"] ) ], dependencies: [ .package(url: "https://github.com/WeTransfer/Mocker", from: "3.0.1"), + .package(url: "https://github.com/pointfreeco/swift-snapshot-testing", from: "1.8.1"), ], targets: [ .target( @@ -33,13 +38,32 @@ var package = Package( .product(name: "GoTrue", package: "gotrue-swift"), .product(name: "SupabaseStorage", package: "storage-swift"), .product(name: "Realtime", package: "realtime-swift"), - .product(name: "PostgREST", package: "postgrest-swift"), + "PostgREST", "Functions", ] ), .testTarget(name: "SupabaseTests", dependencies: ["Supabase"]), .target(name: "Functions"), .testTarget(name: "FunctionsTests", dependencies: ["Functions", "Mocker"]), + .target( + name: "PostgREST", + dependencies: [] + ), + .testTarget( + name: "PostgRESTTests", + dependencies: [ + "PostgREST", + .product( + name: "SnapshotTesting", + package: "swift-snapshot-testing", + condition: .when(platforms: [.iOS, .macOS, .tvOS]) + ), + ], + exclude: [ + "__Snapshots__" + ] + ), + .testTarget(name: "PostgRESTIntegrationTests", dependencies: ["PostgREST"]), ] ) @@ -49,7 +73,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { .package(path: "../gotrue-swift"), .package(path: "../storage-swift"), .package(path: "../realtime-swift"), - .package(path: "../postgrest-swift"), ] ) } else { @@ -64,10 +87,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { branch: "dependency-free" ), .package(url: "https://github.com/supabase-community/realtime-swift.git", from: "0.0.2"), - .package( - url: "https://github.com/supabase-community/postgrest-swift", - branch: "dependency-free" - ), ] ) } diff --git a/Sources/PostgREST/PostgrestBuilder.swift b/Sources/PostgREST/PostgrestBuilder.swift new file mode 100644 index 00000000..c512e5f6 --- /dev/null +++ b/Sources/PostgREST/PostgrestBuilder.swift @@ -0,0 +1,196 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// The builder class for creating and executing requests to a PostgREST server. +public class PostgrestBuilder { + /// The configuration for the PostgREST client. + let configuration: PostgrestClient.Configuration + /// The URL for the request. + let url: URL + /// The query parameters for the request. + var queryParams: [(name: String, value: String?)] + /// The headers for the request. + var headers: [String: String] + /// The HTTP method for the request. + var method: String + /// The body data for the request. + var body: Data? + + /// The options for fetching data from the PostgREST server. + var fetchOptions = FetchOptions() + + init( + configuration: PostgrestClient.Configuration, + url: URL, + queryParams: [(name: String, value: String?)], + headers: [String: String], + method: String, + body: Data? + ) { + self.configuration = configuration + self.url = url + self.queryParams = queryParams + self.headers = headers + self.method = method + self.body = body + } + + convenience init(_ other: PostgrestBuilder) { + self.init( + configuration: other.configuration, + url: other.url, + queryParams: other.queryParams, + headers: other.headers, + method: other.method, + body: other.body + ) + } + + /// Executes the request and returns a response of type Void. + /// - Parameters: + /// - options: Options for querying Supabase. + /// - Returns: A `PostgrestResponse` instance representing the response. + @discardableResult + public func execute( + options: FetchOptions = FetchOptions() + ) async throws -> PostgrestResponse { + self.fetchOptions = options + return try await execute { _ in () } + } + + /// Executes the request and returns a response of the specified type. + /// - Parameters: + /// - options: Options for querying Supabase. + /// - Returns: A `PostgrestResponse` instance representing the response. + @discardableResult + public func execute( + options: FetchOptions = FetchOptions() + ) async throws -> PostgrestResponse { + self.fetchOptions = options + return try await execute { [configuration] data in + try configuration.decoder.decode(T.self, from: data) + } + } + + func appendSearchParams(name: String, value: String) { + queryParams.append((name, value)) + } + + private func execute(decode: (Data) throws -> T) async throws -> PostgrestResponse { + if fetchOptions.head { + method = "HEAD" + } + + if let count = fetchOptions.count { + if let prefer = headers["Prefer"] { + headers["Prefer"] = "\(prefer),count=\(count.rawValue)" + } else { + headers["Prefer"] = "count=\(count.rawValue)" + } + } + + if headers["Accept"] == nil { + headers["Accept"] = "application/json" + } + headers["Content-Type"] = "application/json" + + if let schema = configuration.schema { + if method == "GET" || method == "HEAD" { + headers["Accept-Profile"] = schema + } else { + headers["Content-Profile"] = schema + } + } + + let urlRequest = try makeURLRequest() + + let (data, response) = try await configuration.fetch(urlRequest) + guard let httpResponse = response as? HTTPURLResponse else { + throw URLError(.badServerResponse) + } + + guard 200..<300 ~= httpResponse.statusCode else { + let error = try configuration.decoder.decode(PostgrestError.self, from: data) + throw error + } + + let value = try decode(data) + return PostgrestResponse(data: data, response: httpResponse, value: value) + } + + private func makeURLRequest() throws -> URLRequest { + guard var components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + throw URLError(.badURL) + } + + if !queryParams.isEmpty { + let percentEncodedQuery = + (components.percentEncodedQuery.map { $0 + "&" } ?? "") + self.query(queryParams) + components.percentEncodedQuery = percentEncodedQuery + } + + guard let url = components.url else { + throw URLError(.badURL) + } + + var urlRequest = URLRequest(url: url) + + for (key, value) in headers { + urlRequest.setValue(value, forHTTPHeaderField: key) + } + + urlRequest.httpMethod = method + + if let body { + urlRequest.httpBody = body + } + + return urlRequest + } + + private func escape(_ string: String) -> String { + string.addingPercentEncoding(withAllowedCharacters: .postgrestURLQueryAllowed) ?? string + } + + private func query(_ parameters: [(String, String?)]) -> String { + parameters.compactMap { key, value in + if let value { + return (key, value) + } + return nil + } + .map { key, value in + let escapedKey = escape(key) + let escapedValue = escape(value) + return "\(escapedKey)=\(escapedValue)" + } + .joined(separator: "&") + } +} + +extension CharacterSet { + /// Creates a CharacterSet from RFC 3986 allowed characters. + /// + /// RFC 3986 states that the following characters are "reserved" characters. + /// + /// - General Delimiters: ":", "#", "[", "]", "@", "?", "/" + /// - Sub-Delimiters: "!", "$", "&", "'", "(", ")", "*", "+", ",", ";", "=" + /// + /// In RFC 3986 - Section 3.4, it states that the "?" and "/" characters should not be escaped to + /// allow + /// query strings to include a URL. Therefore, all "reserved" characters with the exception of "?" + /// and "/" + /// should be percent-escaped in the query string. + static let postgrestURLQueryAllowed: CharacterSet = { + let generalDelimitersToEncode = + ":#[]@" // does not include "?" or "/" due to RFC 3986 - Section 3.4 + let subDelimitersToEncode = "!$&'()*+,;=" + let encodableDelimiters = + CharacterSet(charactersIn: "\(generalDelimitersToEncode)\(subDelimitersToEncode)") + + return CharacterSet.urlQueryAllowed.subtracting(encodableDelimiters) + }() +} diff --git a/Sources/PostgREST/PostgrestClient.swift b/Sources/PostgREST/PostgrestClient.swift new file mode 100644 index 00000000..1da7e7cf --- /dev/null +++ b/Sources/PostgREST/PostgrestClient.swift @@ -0,0 +1,188 @@ +import Foundation + +/// PostgREST client. +public actor PostgrestClient { + public typealias FetchHandler = @Sendable (_ request: URLRequest) async throws -> ( + Data, URLResponse + ) + + /// The configuration struct for the PostgREST client. + public struct Configuration { + public var url: URL + public var schema: String? + public var headers: [String: String] + public var fetch: FetchHandler + public var encoder: JSONEncoder + public var decoder: JSONDecoder + + /// Initializes a new configuration for the PostgREST client. + /// - Parameters: + /// - url: The URL of the PostgREST server. + /// - schema: The schema to use. + /// - headers: The headers to include in requests. + /// - fetch: The fetch handler to use for requests. + /// - encoder: The JSONEncoder to use for encoding. + /// - decoder: The JSONDecoder to use for decoding. + public init( + url: URL, + schema: String? = nil, + headers: [String: String] = [:], + fetch: @escaping FetchHandler = { try await URLSession.shared.data(for: $0) }, + encoder: JSONEncoder = .postgrest, + decoder: JSONDecoder = .postgrest + ) { + self.url = url + self.schema = schema + self.headers = headers + self.fetch = fetch + self.encoder = encoder + self.decoder = decoder + } + } + + public private(set) var configuration: Configuration + + /// Creates a PostgREST client with the specified configuration. + /// - Parameter configuration: The configuration for the client. + public init(configuration: Configuration) { + var configuration = configuration + configuration.headers["X-Client-Info"] = "postgrest-swift/\(version)" + self.configuration = configuration + } + + /// Creates a PostgREST client with the specified parameters. + /// - Parameters: + /// - url: The URL of the PostgREST server. + /// - schema: The schema to use. + /// - headers: The headers to include in requests. + /// - session: The URLSession to use for requests. + /// - encoder: The JSONEncoder to use for encoding. + /// - decoder: The JSONDecoder to use for decoding. + public init( + url: URL, + schema: String? = nil, + headers: [String: String] = [:], + fetch: @escaping FetchHandler = { try await URLSession.shared.data(for: $0) }, + encoder: JSONEncoder = .postgrest, + decoder: JSONDecoder = .postgrest + ) { + self.init( + configuration: Configuration( + url: url, + schema: schema, + headers: headers, + fetch: fetch, + encoder: encoder, + decoder: decoder + ) + ) + } + + /// Sets the authorization token for the client. + /// - Parameter token: The authorization token. + /// - Returns: The PostgrestClient instance. + @discardableResult + public func setAuth(_ token: String?) -> PostgrestClient { + if let token { + configuration.headers["Authorization"] = "Bearer \(token)" + } else { + configuration.headers.removeValue(forKey: "Authorization") + } + return self + } + + /// Performs a query on a table or a view. + /// - Parameter table: The table or view name to query. + /// - Returns: A PostgrestQueryBuilder instance. + public func from(_ table: String) -> PostgrestQueryBuilder { + PostgrestQueryBuilder( + configuration: configuration, + url: configuration.url.appendingPathComponent(table), + queryParams: [], + headers: configuration.headers, + method: "GET", + body: nil + ) + } + + /// Performs a function call. + /// - Parameters: + /// - fn: The function name to call. + /// - params: The parameters to pass to the function call. + /// - count: Count algorithm to use to count rows returned by the function. + /// Only applicable for set-returning functions. + /// - Returns: A PostgrestTransformBuilder instance. + /// - Throws: An error if the function call fails. + public func rpc( + fn: String, + params: U, + count: CountOption? = nil + ) throws -> PostgrestTransformBuilder { + try PostgrestRpcBuilder( + configuration: configuration, + url: configuration.url.appendingPathComponent("rpc").appendingPathComponent(fn), + queryParams: [], + headers: configuration.headers, + method: "POST", + body: nil + ).rpc(params: params, count: count) + } + + /// Performs a function call. + /// - Parameters: + /// - fn: The function name to call. + /// - count: Count algorithm to use to count rows returned by the function. + /// Only applicable for set-returning functions. + /// - Returns: A PostgrestTransformBuilder instance. + /// - Throws: An error if the function call fails. + public func rpc( + fn: String, + count: CountOption? = nil + ) throws -> PostgrestTransformBuilder { + try rpc(fn: fn, params: NoParams(), count: count) + } +} + +private let supportedDateFormatters: [ISO8601DateFormatter] = [ + { () -> ISO8601DateFormatter in + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + return formatter + }(), + { () -> ISO8601DateFormatter in + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime] + return formatter + }(), +] + +extension JSONDecoder { + /// The JSONDecoder instance for PostgREST responses. + public static let postgrest = { () -> JSONDecoder in + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .custom { decoder in + let container = try decoder.singleValueContainer() + let string = try container.decode(String.self) + + for formatter in supportedDateFormatters { + if let date = formatter.date(from: string) { + return date + } + } + + throw DecodingError.dataCorruptedError( + in: container, debugDescription: "Invalid date format: \(string)" + ) + } + return decoder + }() +} + +extension JSONEncoder { + /// The JSONEncoder instance for PostgREST requests. + public static let postgrest = { () -> JSONEncoder in + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .iso8601 + return encoder + }() +} diff --git a/Sources/PostgREST/PostgrestFilterBuilder.swift b/Sources/PostgREST/PostgrestFilterBuilder.swift new file mode 100644 index 00000000..926d4a8e --- /dev/null +++ b/Sources/PostgREST/PostgrestFilterBuilder.swift @@ -0,0 +1,233 @@ +import Foundation + +public class PostgrestFilterBuilder: PostgrestTransformBuilder { + public enum Operator: String, CaseIterable { + case eq, neq, gt, gte, lt, lte, like, ilike, `is`, `in`, cs, cd, sl, sr, nxl, nxr, adj, ov, fts, + plfts, phfts, wfts + } + + // MARK: - Filters + + public func not(column: String, operator op: Operator, value: URLQueryRepresentable) + -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "not.\(op.rawValue).\(value.queryValue)") + return self + } + + public func or(filters: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: "or", value: "(\(filters.queryValue.queryValue))") + return self + } + + public func eq(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "eq.\(value.queryValue)") + return self + } + + public func neq(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "neq.\(value.queryValue)") + return self + } + + public func gt(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "gt.\(value.queryValue)") + return self + } + + public func gte(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "gte.\(value.queryValue)") + return self + } + + public func lt(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "lt.\(value.queryValue)") + return self + } + + public func lte(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "lte.\(value.queryValue)") + return self + } + + public func like(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "like.\(value.queryValue)") + return self + } + + public func ilike(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "ilike.\(value.queryValue)") + return self + } + + public func `is`(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "is.\(value.queryValue)") + return self + } + + public func `in`(column: String, value: [URLQueryRepresentable]) -> PostgrestFilterBuilder { + appendSearchParams( + name: column, + value: "in.(\(value.map(\.queryValue).joined(separator: ",")))" + ) + return self + } + + public func contains(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "cs.\(value.queryValue)") + return self + } + + public func rangeLt(column: String, range: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "sl.\(range.queryValue)") + return self + } + + public func rangeGt(column: String, range: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "sr.\(range.queryValue)") + return self + } + + public func rangeGte(column: String, range: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "nxl.\(range.queryValue)") + return self + } + + public func rangeLte(column: String, range: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "nxr.\(range.queryValue)") + return self + } + + public func rangeAdjacent(column: String, range: URLQueryRepresentable) -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "adj.\(range.queryValue)") + return self + } + + public func overlaps(column: String, value: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "ov.\(value.queryValue)") + return self + } + + public func textSearch(column: String, range: URLQueryRepresentable) -> PostgrestFilterBuilder { + appendSearchParams(name: column, value: "adj.\(range.queryValue)") + return self + } + + public func textSearch( + column: String, query: URLQueryRepresentable, config: String? = nil, type: TextSearchType? = nil + ) -> PostgrestFilterBuilder { + appendSearchParams( + name: column, value: "\(type?.rawValue ?? "")fts\(config ?? "").\(query.queryValue)" + ) + return self + } + + public func fts(column: String, query: URLQueryRepresentable, config: String? = nil) + -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "fts\(config ?? "").\(query.queryValue)") + return self + } + + public func plfts(column: String, query: URLQueryRepresentable, config: String? = nil) + -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "plfts\(config ?? "").\(query.queryValue)") + return self + } + + public func phfts(column: String, query: URLQueryRepresentable, config: String? = nil) + -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "phfts\(config ?? "").\(query.queryValue)") + return self + } + + public func wfts(column: String, query: URLQueryRepresentable, config: String? = nil) + -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "wfts\(config ?? "").\(query.queryValue)") + return self + } + + public func filter(column: String, operator: Operator, value: URLQueryRepresentable) + -> PostgrestFilterBuilder + { + appendSearchParams(name: column, value: "\(`operator`.rawValue).\(value.queryValue)") + return self + } + + public func match(query: [String: URLQueryRepresentable]) -> PostgrestFilterBuilder { + query.forEach { key, value in + appendSearchParams(name: key, value: "eq.\(value.queryValue)") + } + return self + } + + // MARK: - Filter Semantic Improvements + + public func equals(column: String, value: String) -> PostgrestFilterBuilder { + eq(column: column, value: value) + } + + public func notEquals(column: String, value: String) -> PostgrestFilterBuilder { + neq(column: column, value: value) + } + + public func greaterThan(column: String, value: String) -> PostgrestFilterBuilder { + gt(column: column, value: value) + } + + public func greaterThanOrEquals(column: String, value: String) -> PostgrestFilterBuilder { + gte(column: column, value: value) + } + + public func lowerThan(column: String, value: String) -> PostgrestFilterBuilder { + lt(column: column, value: value) + } + + public func lowerThanOrEquals(column: String, value: String) -> PostgrestFilterBuilder { + lte(column: column, value: value) + } + + public func rangeLowerThan(column: String, range: String) -> PostgrestFilterBuilder { + rangeLt(column: column, range: range) + } + + public func rangeGreaterThan(column: String, value: String) -> PostgrestFilterBuilder { + rangeGt(column: column, range: value) + } + + public func rangeGreaterThanOrEquals(column: String, value: String) -> PostgrestFilterBuilder { + rangeGte(column: column, range: value) + } + + public func rangeLowerThanOrEquals(column: String, value: String) -> PostgrestFilterBuilder { + rangeLte(column: column, range: value) + } + + public func fullTextSearch(column: String, query: String, config: String? = nil) + -> PostgrestFilterBuilder + { + fts(column: column, query: query, config: config) + } + + public func plainToFullTextSearch(column: String, query: String, config: String? = nil) + -> PostgrestFilterBuilder + { + plfts(column: column, query: query, config: config) + } + + public func phraseToFullTextSearch(column: String, query: String, config: String? = nil) + -> PostgrestFilterBuilder + { + phfts(column: column, query: query, config: config) + } + + public func webFullTextSearch(column: String, query: String, config: String? = nil) + -> PostgrestFilterBuilder + { + wfts(column: column, query: query, config: config) + } +} diff --git a/Sources/PostgREST/PostgrestQueryBuilder.swift b/Sources/PostgREST/PostgrestQueryBuilder.swift new file mode 100644 index 00000000..16b8ad66 --- /dev/null +++ b/Sources/PostgREST/PostgrestQueryBuilder.swift @@ -0,0 +1,164 @@ +import Foundation + +public final class PostgrestQueryBuilder: PostgrestBuilder { + /// Performs a vertical filtering with SELECT. + /// - Parameters: + /// - columns: The columns to retrieve, separated by commas. + /// - head: When set to true, select will void data. + /// - count: Count algorithm to use to count rows in a table. + /// - Returns: A `PostgrestFilterBuilder` instance for further filtering or operations. + public func select( + columns: String = "*", + head: Bool = false, + count: CountOption? = nil + ) -> PostgrestFilterBuilder { + method = "GET" + // remove whitespaces except when quoted. + var quoted = false + let cleanedColumns = columns.compactMap { char -> String? in + if char.isWhitespace, !quoted { + return nil + } + if char == "\"" { + quoted = !quoted + } + return String(char) + } + .joined(separator: "") + appendSearchParams(name: "select", value: cleanedColumns) + if let count = count { + headers["Prefer"] = "count=\(count.rawValue)" + } + if head { + method = "HEAD" + } + return PostgrestFilterBuilder(self) + } + + /// Performs an INSERT into the table. + /// - Parameters: + /// - values: The values to insert. + /// - returning: The returning options for the query. + /// - count: Count algorithm to use to count rows in a table. + /// - Returns: A `PostgrestFilterBuilder` instance for further filtering or operations. + /// - Throws: An error if the insert fails. + public func insert( + values: U, + returning: PostgrestReturningOptions? = nil, + count: CountOption? = nil + ) throws -> PostgrestFilterBuilder { + method = "POST" + var prefersHeaders: [String] = [] + if let returning = returning { + prefersHeaders.append("return=\(returning.rawValue)") + } + body = try configuration.encoder.encode(values) + if let count = count { + prefersHeaders.append("count=\(count.rawValue)") + } + if let prefer = headers["Prefer"] { + prefersHeaders.insert(prefer, at: 0) + } + if !prefersHeaders.isEmpty { + headers["Prefer"] = prefersHeaders.joined(separator: ",") + } + + // TODO: How to do this in Swift? + // if (Array.isArray(values)) { + // const columns = values.reduce((acc, x) => acc.concat(Object.keys(x)), [] as string[]) + // if (columns.length > 0) { + // const uniqueColumns = [...new Set(columns)].map((column) => `"${column}"`) + // this.url.searchParams.set('columns', uniqueColumns.join(',')) + // } + // } + + return PostgrestFilterBuilder(self) + } + + /// Performs an UPSERT into the table. + /// - Parameters: + /// - values: The values to insert. + /// - onConflict: The column(s) with a unique constraint to perform the UPSERT. + /// - returning: The returning options for the query. + /// - count: Count algorithm to use to count rows in a table. + /// - ignoreDuplicates: Specifies if duplicate rows should be ignored and not inserted. + /// - Returns: A `PostgrestFilterBuilder` instance for further filtering or operations. + /// - Throws: An error if the upsert fails. + public func upsert( + values: U, + onConflict: String? = nil, + returning: PostgrestReturningOptions = .representation, + count: CountOption? = nil, + ignoreDuplicates: Bool = false + ) throws -> PostgrestFilterBuilder { + method = "POST" + var prefersHeaders = [ + "resolution=\(ignoreDuplicates ? "ignore" : "merge")-duplicates", + "return=\(returning.rawValue)", + ] + if let onConflict = onConflict { + appendSearchParams(name: "on_conflict", value: onConflict) + } + body = try configuration.encoder.encode(values) + if let count = count { + prefersHeaders.append("count=\(count.rawValue)") + } + if let prefer = headers["Prefer"] { + prefersHeaders.insert(prefer, at: 0) + } + if !prefersHeaders.isEmpty { + headers["Prefer"] = prefersHeaders.joined(separator: ",") + } + return PostgrestFilterBuilder(self) + } + + /// Performs an UPDATE on the table. + /// - Parameters: + /// - values: The values to update. + /// - returning: The returning options for the query. + /// - count: Count algorithm to use to count rows in a table. + /// - Returns: A `PostgrestFilterBuilder` instance for further filtering or operations. + /// - Throws: An error if the update fails. + public func update( + values: U, + returning: PostgrestReturningOptions = .representation, + count: CountOption? = nil + ) throws -> PostgrestFilterBuilder { + method = "PATCH" + var preferHeaders = ["return=\(returning.rawValue)"] + body = try configuration.encoder.encode(values) + if let count = count { + preferHeaders.append("count=\(count.rawValue)") + } + if let prefer = headers["Prefer"] { + preferHeaders.insert(prefer, at: 0) + } + if !preferHeaders.isEmpty { + headers["Prefer"] = preferHeaders.joined(separator: ",") + } + return PostgrestFilterBuilder(self) + } + + /// Performs a DELETE on the table. + /// - Parameters: + /// - returning: The returning options for the query. + /// - count: Count algorithm to use to count rows in a table. + /// - Returns: A `PostgrestFilterBuilder` instance for further filtering or operations. + public func delete( + returning: PostgrestReturningOptions = .representation, + count: CountOption? = nil + ) -> PostgrestFilterBuilder { + method = "DELETE" + var preferHeaders = ["return=\(returning.rawValue)"] + if let count = count { + preferHeaders.append("count=\(count.rawValue)") + } + if let prefer = headers["Prefer"] { + preferHeaders.insert(prefer, at: 0) + } + if !preferHeaders.isEmpty { + headers["Prefer"] = preferHeaders.joined(separator: ",") + } + return PostgrestFilterBuilder(self) + } +} diff --git a/Sources/PostgREST/PostgrestRpcBuilder.swift b/Sources/PostgREST/PostgrestRpcBuilder.swift new file mode 100644 index 00000000..480d2f9c --- /dev/null +++ b/Sources/PostgREST/PostgrestRpcBuilder.swift @@ -0,0 +1,39 @@ +import Foundation + +struct NoParams: Encodable {} + +public final class PostgrestRpcBuilder: PostgrestBuilder { + /// Performs a function call with parameters. + /// - Parameters: + /// - params: The parameters to pass to the function. + /// - head: When set to `true`, the function call will use the `HEAD` method. Default is `false`. + /// - count: Count algorithm to use to count rows in a table. Default is `nil`. + /// - Returns: The `PostgrestTransformBuilder` instance for method chaining. + /// - Throws: An error if the function call fails. + func rpc( + params: U, + head: Bool = false, + count: CountOption? = nil + ) throws -> PostgrestTransformBuilder { + // TODO: Support `HEAD` method + // https://github.com/supabase/postgrest-js/blob/master/src/lib/PostgrestRpcBuilder.ts#L38 + assert(head == false, "HEAD is not currently supported yet.") + + method = "POST" + if params is NoParams { + // noop + } else { + body = try configuration.encoder.encode(params) + } + + if let count = count { + if let prefer = headers["Prefer"] { + headers["Prefer"] = "\(prefer),count=\(count.rawValue)" + } else { + headers["Prefer"] = "count=\(count.rawValue)" + } + } + + return PostgrestTransformBuilder(self) + } +} diff --git a/Sources/PostgREST/PostgrestTransformBuilder.swift b/Sources/PostgREST/PostgrestTransformBuilder.swift new file mode 100644 index 00000000..dfa6ee68 --- /dev/null +++ b/Sources/PostgREST/PostgrestTransformBuilder.swift @@ -0,0 +1,104 @@ +public class PostgrestTransformBuilder: PostgrestBuilder { + /// Performs a vertical filtering with SELECT. + /// - Parameters: + /// - columns: The columns to retrieve, separated by commas. + public func select(columns: String = "*") -> PostgrestTransformBuilder { + // remove whitespaces except when quoted. + var quoted = false + let cleanedColumns = columns.compactMap { char -> String? in + if char.isWhitespace, !quoted { + return nil + } + if char == "\"" { + quoted = !quoted + } + return String(char) + } + .joined(separator: "") + appendSearchParams(name: "select", value: cleanedColumns) + return self + } + + /// Orders the result with the specified `column`. + /// - Parameters: + /// - column: The column to order on. + /// - ascending: If `true`, the result will be in ascending order. + /// - nullsFirst: If `true`, `null`s appear first. + /// - foreignTable: The foreign table to use (if `column` is a foreign column). + public func order( + column: String, + ascending: Bool = true, + nullsFirst: Bool = false, + foreignTable: String? = nil + ) -> PostgrestTransformBuilder { + let key = foreignTable.map { "\($0).order" } ?? "order" + let existingOrderIndex = queryParams.firstIndex(where: { $0.name == key }) + let value = "\(column).\(ascending ? "asc" : "desc").\(nullsFirst ? "nullsfirst" : "nullslast")" + + if let existingOrderIndex = existingOrderIndex, + let currentValue = queryParams[existingOrderIndex].value + { + queryParams[existingOrderIndex] = (key, "\(currentValue),\(value)") + } else { + queryParams.append((key, value)) + } + + return self + } + + /// Limits the result with the specified `count`. + /// - Parameters: + /// - count: The maximum no. of rows to limit to. + /// - foreignTable: The foreign table to use (for foreign columns). + public func limit(count: Int, foreignTable: String? = nil) -> PostgrestTransformBuilder { + let key = foreignTable.map { "\($0).limit" } ?? "limit" + if let index = queryParams.firstIndex(where: { $0.name == key }) { + queryParams[index] = (key, "\(count)") + } else { + queryParams.append((key, "\(count)")) + } + return self + } + + /// Limits the result to rows within the specified range, inclusive. + /// - Parameters: + /// - lowerBounds: The starting index from which to limit the result, inclusive. + /// - upperBounds: The last index to which to limit the result, inclusve. + /// - foreignTable: The foreign table to use (for foreign columns). + public func range( + from lowerBounds: Int, + to upperBounds: Int, + foreignTable: String? = nil + ) -> PostgrestTransformBuilder { + let keyOffset = foreignTable.map { "\($0).offset" } ?? "offset" + let keyLimit = foreignTable.map { "\($0).limit" } ?? "limit" + + if let index = queryParams.firstIndex(where: { $0.name == keyOffset }) { + queryParams[index] = (keyOffset, "\(lowerBounds)") + } else { + queryParams.append((keyOffset, "\(lowerBounds)")) + } + + // Range is inclusive, so add 1 + if let index = queryParams.firstIndex(where: { $0.name == keyLimit }) { + queryParams[index] = (keyLimit, "\(upperBounds - lowerBounds + 1)") + } else { + queryParams.append((keyLimit, "\(upperBounds - lowerBounds + 1)")) + } + + return self + } + + /// Retrieves only one row from the result. Result must be one row (e.g. using `limit`), otherwise + /// this will result in an error. + public func single() -> PostgrestTransformBuilder { + headers["Accept"] = "application/vnd.pgrst.object+json" + return self + } + + /// Set the response type to CSV. + public func csv() -> PostgrestTransformBuilder { + headers["Accept"] = "text/csv" + return self + } +} diff --git a/Sources/PostgREST/Types.swift b/Sources/PostgREST/Types.swift new file mode 100644 index 00000000..80cfa4a5 --- /dev/null +++ b/Sources/PostgREST/Types.swift @@ -0,0 +1,96 @@ +import Foundation + +public struct PostgrestError: Error, Codable { + public let details: String? + public let hint: String? + public let code: String? + public let message: String + + public init(details: String? = nil, hint: String? = nil, code: String? = nil, message: String) { + self.hint = hint + self.details = details + self.code = code + self.message = message + } +} + +extension PostgrestError: LocalizedError { + public var errorDescription: String? { + message + } +} + +public struct PostgrestResponse { + public let data: Data + public let response: HTTPURLResponse + public let count: Int? + public let value: T + + public var status: Int { + response.statusCode + } + + public init( + data: Data, + response: HTTPURLResponse, + value: T + ) { + var count: Int? + + if let contentRange = response.value(forHTTPHeaderField: "Content-Range")?.split(separator: "/") + .last + { + count = contentRange == "*" ? nil : Int(contentRange) + } + + self.data = data + self.response = response + self.count = count + self.value = value + } +} + +/// Returns count as part of the response when specified. +public enum CountOption: String { + case exact + case planned + case estimated +} + +/// Enum of options representing the ways PostgREST can return values from the server. +/// +/// https://postgrest.org/en/v9.0/api.html?highlight=PREFER#insertions-updates +public enum PostgrestReturningOptions: String { + /// Returns nothing from the server + case minimal + /// Returns a copy of the updated data. + case representation +} + +/// The type of tsquery conversion to use on query. +public enum TextSearchType: String { + /// Uses PostgreSQL's plainto_tsquery function. + case plain = "pl" + /// Uses PostgreSQL's phraseto_tsquery function. + case phrase = "ph" + /// Uses PostgreSQL's websearch_to_tsquery function. + /// This function will never raise syntax errors, which makes it possible to use raw user-supplied + /// input for search, and can be used with advanced operators. + case websearch = "w" +} + +/// Options for querying Supabase. +public struct FetchOptions { + /// Set head to true if you only want the count value and not the underlying data. + public let head: Bool + + /// count options can be used to retrieve the total number of rows that satisfies the + /// query. The value for count respects any filters (e.g. eq, gt), but ignores + /// modifiers (e.g. limit, range). + public let count: CountOption? + + public init(head: Bool = false, count: CountOption? = nil) { + self.head = head + self.count = count + } +} diff --git a/Sources/PostgREST/URLQueryRepresentable.swift b/Sources/PostgREST/URLQueryRepresentable.swift new file mode 100644 index 00000000..fcf5fcdd --- /dev/null +++ b/Sources/PostgREST/URLQueryRepresentable.swift @@ -0,0 +1,56 @@ +import Foundation + +/// A type that can fit into the query part of a URL. +public protocol URLQueryRepresentable { + /// A String representation of this instance that can fit as a query parameter's value. + var queryValue: String { get } +} + +extension String: URLQueryRepresentable { + public var queryValue: String { self } +} + +extension Int: URLQueryRepresentable { + public var queryValue: String { "\(self)" } +} + +extension Double: URLQueryRepresentable { + public var queryValue: String { "\(self)" } +} + +extension Bool: URLQueryRepresentable { + public var queryValue: String { "\(self)" } +} + +extension UUID: URLQueryRepresentable { + public var queryValue: String { uuidString } +} + +extension Array: URLQueryRepresentable where Element: URLQueryRepresentable { + public var queryValue: String { + "{\(map(\.queryValue).joined(separator: ","))}" + } +} + +extension Dictionary: URLQueryRepresentable +where + Key: URLQueryRepresentable, + Value: URLQueryRepresentable +{ + public var queryValue: String { + JSONSerialization.stringfy(self) + } +} + +extension JSONSerialization { + static func stringfy(_ object: Any) -> String { + guard + let data = try? data( + withJSONObject: object, options: [.withoutEscapingSlashes, .sortedKeys]), + let string = String(data: data, encoding: .utf8) + else { + return "{}" + } + return string + } +} diff --git a/Sources/PostgREST/Version.swift b/Sources/PostgREST/Version.swift new file mode 100644 index 00000000..0951025c --- /dev/null +++ b/Sources/PostgREST/Version.swift @@ -0,0 +1 @@ +let version = "1.0.2" diff --git a/Tests/PostgRESTIntegrationTests/IntegrationTests.swift b/Tests/PostgRESTIntegrationTests/IntegrationTests.swift new file mode 100644 index 00000000..abb10857 --- /dev/null +++ b/Tests/PostgRESTIntegrationTests/IntegrationTests.swift @@ -0,0 +1,134 @@ +import PostgREST +import XCTest + +struct Todo: Codable, Hashable { + let id: UUID + var description: String + var isComplete: Bool + var tags: [String] + let createdAt: Date + + enum CodingKeys: String, CodingKey { + case id + case description + case isComplete = "is_complete" + case tags + case createdAt = "created_at" + } +} + +struct NewTodo: Codable, Hashable { + var description: String + var isComplete: Bool = false + var tags: [String] + + enum CodingKeys: String, CodingKey { + case description + case isComplete = "is_complete" + case tags + } +} + +struct User: Codable, Hashable { + let email: String +} + +@available(iOS 15.0.0, macOS 12.0.0, tvOS 13.0, *) +final class IntegrationTests: XCTestCase { + let client = PostgrestClient( + url: URL(string: "http://localhost:54321/rest/v1")!, + headers: [ + "apikey": + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6ImFub24iLCJleHAiOjE5ODM4MTI5OTZ9.CRXP1A7WOeoJeXxjNni43kdQwgnWNReilDMblYTn_I0" + ] + ) + + override func setUp() async throws { + try await super.setUp() + + try XCTSkipUnless( + ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] != nil, + "INTEGRATION_TESTS not defined." + ) + + // Run fresh test by deleting all data. Delete without a where clause isn't supported, so have + // to do this `neq` trick to delete all data. + try await client.from("todos").delete().neq(column: "id", value: UUID().uuidString).execute() + try await client.from("users").delete().neq(column: "id", value: UUID().uuidString).execute() + } + + func testIntegration() async throws { + var todos: [Todo] = try await client.from("todos").select().execute().value + XCTAssertEqual(todos, []) + + let insertedTodo: Todo = try await client.from("todos") + .insert( + values: NewTodo( + description: "Implement integration tests for postgrest-swift", + tags: ["tag 01", "tag 02"] + ), + returning: .representation + ) + .single() + .execute() + .value + + todos = try await client.from("todos").select().execute().value + XCTAssertEqual(todos, [insertedTodo]) + + let insertedTodos: [Todo] = try await client.from("todos") + .insert( + values: [ + NewTodo(description: "Make supabase swift libraries production ready", tags: ["tag 01"]), + NewTodo(description: "Drink some coffee", tags: ["tag 02"]), + ], + returning: .representation + ) + .execute() + .value + + todos = try await client.from("todos").select().execute().value + XCTAssertEqual(todos, [insertedTodo] + insertedTodos) + + let drinkCoffeeTodo = insertedTodos[1] + let updatedTodo: Todo = try await client.from("todos") + .update(values: ["is_complete": true]) + .eq(column: "id", value: drinkCoffeeTodo.id.uuidString) + .single() + .execute() + .value + XCTAssertTrue(updatedTodo.isComplete) + + let completedTodos: [Todo] = try await client.from("todos") + .select() + .eq(column: "is_complete", value: true) + .execute() + .value + XCTAssertEqual(completedTodos, [updatedTodo]) + + try await client.from("todos").delete().eq(column: "is_complete", value: true).execute() + todos = try await client.from("todos").select().execute().value + XCTAssertTrue(completedTodos.allSatisfy { todo in !todos.contains(todo) }) + + let todosWithSpecificTag: [Todo] = try await client.from("todos").select() + .contains(column: "tags", value: ["tag 01"]).execute().value + XCTAssertEqual(todosWithSpecificTag, [insertedTodo, insertedTodos[0]]) + } + + func testQueryWithPlusSign() async throws { + let users = [ + User(email: "johndoe@mail.com"), + User(email: "johndoe+test1@mail.com"), + User(email: "johndoe+test2@mail.com"), + ] + + try await client.from("users").insert(values: users).execute() + + let fetchedUsers: [User] = try await client.from("users").select() + .ilike(column: "email", value: "johndoe+test%").execute().value + XCTAssertEqual( + fetchedUsers[...], + users[1...2] + ) + } +} diff --git a/Tests/PostgRESTTests/BuildURLRequestTests.swift b/Tests/PostgRESTTests/BuildURLRequestTests.swift new file mode 100644 index 00000000..0acfc22a --- /dev/null +++ b/Tests/PostgRESTTests/BuildURLRequestTests.swift @@ -0,0 +1,118 @@ +#if !os(watchOS) + import Foundation + import SnapshotTesting + import XCTest + + @testable import PostgREST + + #if canImport(FoundationNetworking) + import FoundationNetworking + #endif + + @MainActor + final class BuildURLRequestTests: XCTestCase { + let url = URL(string: "https://example.supabase.co")! + + struct TestCase: Sendable { + let name: String + var record = false + let build: @Sendable (PostgrestClient) async throws -> PostgrestBuilder + } + + func testBuildRequest() async throws { + let runningTestCase = LockIsolated(Optional.none) + + let client = PostgrestClient( + url: url, schema: nil, + fetch: { @MainActor request in + runningTestCase.withValue { runningTestCase in + guard let runningTestCase = runningTestCase else { + XCTFail("execute called without a runningTestCase set.") + return (Data(), URLResponse()) + } + + assertSnapshot( + matching: request, + as: .curl, + named: runningTestCase.name, + record: runningTestCase.record, + testName: "testBuildRequest()" + ) + + return (Data(), URLResponse()) + } + }) + + let testCases: [TestCase] = [ + TestCase(name: "select all users where email ends with '@supabase.co'") { client in + await client.from("users") + .select() + .like(column: "email", value: "%@supabase.co") + }, + TestCase(name: "insert new user") { client in + try await client.from("users") + .insert(values: ["email": "johndoe@supabase.io"]) + }, + TestCase(name: "call rpc") { client in + try await client.rpc(fn: "test_fcn", params: ["KEY": "VALUE"]) + }, + TestCase(name: "call rpc without parameter") { client in + try await client.rpc(fn: "test_fcn") + }, + TestCase(name: "test all filters and count") { client in + var query = await client.from("todos").select() + + for op in PostgrestFilterBuilder.Operator.allCases { + query = query.filter(column: "column", operator: op, value: "Some value") + } + + return query + }, + TestCase(name: "test in filter") { client in + await client.from("todos").select().in(column: "id", value: [1, 2, 3]) + }, + TestCase(name: "test contains filter with dictionary") { client in + await client.from("users").select(columns: "name") + .contains(column: "address", value: ["postcode": 90210]) + }, + TestCase(name: "test contains filter with array") { client in + await client.from("users") + .select() + .contains(column: "name", value: ["is:online", "faction:red"]) + }, + TestCase(name: "test upsert not ignoring duplicates") { client in + try await client.from("users") + .upsert(values: ["email": "johndoe@supabase.io"]) + }, + TestCase(name: "test upsert ignoring duplicates") { client in + try await client.from("users") + .upsert(values: ["email": "johndoe@supabase.io"], ignoreDuplicates: true) + }, + TestCase(name: "query with + character") { client in + await client.from("users") + .select() + .eq(column: "id", value: "Cigányka-ér (0+400 cskm) vízrajzi állomás") + }, + TestCase(name: "query with timestampz") { client in + await client.from("tasks") + .select() + .gt(column: "received_at", value: "2023-03-23T15:50:30.511743+00:00") + .order(column: "received_at") + }, + ] + + for testCase in testCases { + runningTestCase.withValue { $0 = testCase } + let builder = try await testCase.build(client) + _ = try? await builder.execute() + } + } + + func testSessionConfiguration() async { + let client = PostgrestClient(url: url, schema: nil) + let clientInfoHeader = await client.configuration.headers["X-Client-Info"] + XCTAssertNotNil(clientInfoHeader) + } + } + +#endif diff --git a/Tests/PostgRESTTests/Helpers/LockIsolated.swift b/Tests/PostgRESTTests/Helpers/LockIsolated.swift new file mode 100644 index 00000000..7afd7aa3 --- /dev/null +++ b/Tests/PostgRESTTests/Helpers/LockIsolated.swift @@ -0,0 +1,39 @@ +// +// File.swift +// +// +// Created by Guilherme Souza on 07/10/23. +// + +import Foundation + +final class LockIsolated: @unchecked Sendable { + private let lock = NSRecursiveLock() + private var _value: Value + + init(_ value: Value) { + self._value = value + } + + @discardableResult + func withValue(_ block: (inout Value) throws -> T) rethrows -> T { + try lock.sync { + var value = self._value + defer { self._value = value } + return try block(&value) + } + } + + var value: Value { + lock.sync { self._value } + } +} + +extension NSRecursiveLock { + @discardableResult + func sync(work: () throws -> R) rethrows -> R { + lock() + defer { unlock() } + return try work() + } +} diff --git a/Tests/PostgRESTTests/PostgrestResponseTests.swift b/Tests/PostgRESTTests/PostgrestResponseTests.swift new file mode 100644 index 00000000..38e20925 --- /dev/null +++ b/Tests/PostgRESTTests/PostgrestResponseTests.swift @@ -0,0 +1,47 @@ +import XCTest + +@testable import PostgREST + +class PostgrestResponseTests: XCTestCase { + func testInit() { + // Prepare data and response + let data = Data() + let response = HTTPURLResponse( + url: URL(string: "http://example.com")!, + statusCode: 200, + httpVersion: nil, + headerFields: ["Content-Range": "bytes 0-100/200"])! + let value = "Test Value" + + // Create the PostgrestResponse instance + let postgrestResponse = PostgrestResponse(data: data, response: response, value: value) + + // Assert the properties + XCTAssertEqual(postgrestResponse.data, data) + XCTAssertEqual(postgrestResponse.response, response) + XCTAssertEqual(postgrestResponse.value, value) + XCTAssertEqual(postgrestResponse.status, 200) + XCTAssertEqual(postgrestResponse.count, 200) + } + + func testInitWithNoCount() { + // Prepare data and response + let data = Data() + let response = HTTPURLResponse( + url: URL(string: "http://example.com")!, + statusCode: 200, + httpVersion: nil, + headerFields: ["Content-Range": "*"])! + let value = "Test Value" + + // Create the PostgrestResponse instance + let postgrestResponse = PostgrestResponse(data: data, response: response, value: value) + + // Assert the properties + XCTAssertEqual(postgrestResponse.data, data) + XCTAssertEqual(postgrestResponse.response, response) + XCTAssertEqual(postgrestResponse.value, value) + XCTAssertEqual(postgrestResponse.status, 200) + XCTAssertNil(postgrestResponse.count) + } +} diff --git a/Tests/PostgRESTTests/URLQueryRepresentableTests.swift b/Tests/PostgRESTTests/URLQueryRepresentableTests.swift new file mode 100644 index 00000000..0bdc4c4f --- /dev/null +++ b/Tests/PostgRESTTests/URLQueryRepresentableTests.swift @@ -0,0 +1,16 @@ +import PostgREST +import XCTest + +final class URLQueryRepresentableTests: XCTestCase { + func testArray() { + let array = ["is:online", "faction:red"] + let queryValue = array.queryValue + XCTAssertEqual(queryValue, "{is:online,faction:red}") + } + + func testDictionary() { + let dictionary = ["postalcode": 90210] + let queryValue = dictionary.queryValue + XCTAssertEqual(queryValue, "{\"postalcode\":90210}") + } +} diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc-without-parameter.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc-without-parameter.txt new file mode 100644 index 00000000..9a29a9fa --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc-without-parameter.txt @@ -0,0 +1,6 @@ +curl \ + --request POST \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/rpc/test_fcn" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc.txt new file mode 100644 index 00000000..32e8b909 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.call-rpc.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + --data "{\"KEY\":\"VALUE\"}" \ + "https://example.supabase.co/rpc/test_fcn" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.insert-new-user.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.insert-new-user.txt new file mode 100644 index 00000000..ec591838 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.insert-new-user.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + --data "{\"email\":\"johndoe@supabase.io\"}" \ + "https://example.supabase.co/users" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-character.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-character.txt new file mode 100644 index 00000000..2ccf8315 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-character.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/users?id=eq.Cig%C3%A1nyka-%C3%A9r%20(0+400%20cskm)%20v%C3%ADzrajzi%20%C3%A1llom%C3%A1s&select=*" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-timestampz.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-timestampz.txt new file mode 100644 index 00000000..5dcd2bf5 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.query-with-timestampz.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/tasks?order=received_at.asc.nullslast&received_at=gt.2023-03-23T15:50:30.511743+00:00&select=*" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.select-all-users-where-email-ends-with-supabase-co.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.select-all-users-where-email-ends-with-supabase-co.txt new file mode 100644 index 00000000..cb49b21b --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.select-all-users-where-email-ends-with-supabase-co.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/users?email=like.%25@supabase.co&select=*" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-all-filters-and-count.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-all-filters-and-count.txt new file mode 100644 index 00000000..5c02d0de --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-all-filters-and-count.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/todos?column=eq.Some%20value&column=neq.Some%20value&column=gt.Some%20value&column=gte.Some%20value&column=lt.Some%20value&column=lte.Some%20value&column=like.Some%20value&column=ilike.Some%20value&column=is.Some%20value&column=in.Some%20value&column=cs.Some%20value&column=cd.Some%20value&column=sl.Some%20value&column=sr.Some%20value&column=nxl.Some%20value&column=nxr.Some%20value&column=adj.Some%20value&column=ov.Some%20value&column=fts.Some%20value&column=plfts.Some%20value&column=phfts.Some%20value&column=wfts.Some%20value&select=*" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-array.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-array.txt new file mode 100644 index 00000000..862d5d5a --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-array.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/users?name=cs.%7Bis:online,faction:red%7D&select=*" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-dictionary.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-dictionary.txt new file mode 100644 index 00000000..b9c6e1b5 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-contains-filter-with-dictionary.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/users?address=cs.%7B%22postcode%22:90210%7D&select=name" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-in-filter.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-in-filter.txt new file mode 100644 index 00000000..15a06494 --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-in-filter.txt @@ -0,0 +1,5 @@ +curl \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + "https://example.supabase.co/todos?id=in.(1,2,3)&select=*" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-ignoring-duplicates.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-ignoring-duplicates.txt new file mode 100644 index 00000000..a018f58f --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-ignoring-duplicates.txt @@ -0,0 +1,8 @@ +curl \ + --request POST \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "Prefer: resolution=ignore-duplicates,return=representation" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + --data "{\"email\":\"johndoe@supabase.io\"}" \ + "https://example.supabase.co/users" \ No newline at end of file diff --git a/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-not-ignoring-duplicates.txt b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-not-ignoring-duplicates.txt new file mode 100644 index 00000000..4fde2bba --- /dev/null +++ b/Tests/PostgRESTTests/__Snapshots__/BuildURLRequestTests/testBuildRequest.test-upsert-not-ignoring-duplicates.txt @@ -0,0 +1,8 @@ +curl \ + --request POST \ + --header "Accept: application/json" \ + --header "Content-Type: application/json" \ + --header "Prefer: resolution=merge-duplicates,return=representation" \ + --header "X-Client-Info: postgrest-swift/1.0.2" \ + --data "{\"email\":\"johndoe@supabase.io\"}" \ + "https://example.supabase.co/users" \ No newline at end of file diff --git a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved index 293e536e..f690ce9a 100644 --- a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -27,15 +27,6 @@ "version" : "3.0.1" } }, - { - "identity" : "postgrest-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/postgrest-swift", - "state" : { - "branch" : "dependency-free", - "revision" : "ef3faa36a6987d34c8d32bab1cf545d40bf1a182" - } - }, { "identity" : "realtime-swift", "kind" : "remoteSourceControl", @@ -90,6 +81,24 @@ "version" : "0.8.0" } }, + { + "identity" : "swift-snapshot-testing", + "kind" : "remoteSourceControl", + "location" : "https://github.com/pointfreeco/swift-snapshot-testing", + "state" : { + "revision" : "bb0ea08db8e73324fe6c3727f755ca41a23ff2f4", + "version" : "1.14.2" + } + }, + { + "identity" : "swift-syntax", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-syntax.git", + "state" : { + "revision" : "74203046135342e4a4a627476dd6caf8b28fe11b", + "version" : "509.0.0" + } + }, { "identity" : "swiftui-navigation", "kind" : "remoteSourceControl", From 5649c9bc98115bb1f08ca55d94cc43d4780ab543 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:20:06 -0300 Subject: [PATCH 3/7] Move gotrue to repo --- Package.resolved | 9 - Package.swift | 57 +- Sources/GoTrue/Deprecated.swift | 14 + Sources/GoTrue/Extensions.swift | 11 + Sources/GoTrue/GoTrueClient.swift | 622 ++++++++++++++++++ Sources/GoTrue/GoTrueError.swift | 33 + Sources/GoTrue/GoTrueLocalStorage.swift | 32 + Sources/GoTrue/Internal/Helpers.swift | 46 ++ Sources/GoTrue/Internal/Request.swift | 50 ++ Sources/GoTrue/Internal/SessionManager.swift | 77 +++ Sources/GoTrue/Internal/ShareReplay.swift | 96 +++ Sources/GoTrue/Types.swift | 587 +++++++++++++++++ Sources/GoTrue/Version.swift | 1 + Tests/GoTrueTests/DecoderTests.swift | 42 ++ Tests/GoTrueTests/JWTTests.swift | 13 + Tests/GoTrueTests/MockHelpers.swift | 15 + Tests/GoTrueTests/RequestsTests.swift | 305 +++++++++ Tests/GoTrueTests/Resources/session.json | 37 ++ .../Resources/signup-response.json | 30 + Tests/GoTrueTests/Resources/user.json | 32 + .../RequestsTests/testRefreshSession.1.txt | 7 + .../testResetPasswordForEmail.1.txt | 7 + .../RequestsTests/testSessionFromURL.1.txt | 5 + .../testSetSessionWithAExpiredToken.1.txt | 7 + ...tSetSessionWithAFutureExpirationDate.1.txt | 5 + .../testSignInWithEmailAndPassword.1.txt | 7 + .../RequestsTests/testSignInWithIdToken.1.txt | 7 + .../testSignInWithOTPUsingEmail.1.txt | 7 + .../testSignInWithOTPUsingPhone.1.txt | 7 + .../testSignInWithPhoneAndPassword.1.txt | 7 + .../testSignUpWithEmailAndPassword.1.txt | 7 + .../testSignUpWithPhoneAndPassword.1.txt | 7 + .../RequestsTests/testUpdateUser.1.txt | 8 + .../testVerifyOTPUsingEmail.1.txt | 7 + .../testVerifyOTPUsingPhone.1.txt | 7 + .../xcshareddata/swiftpm/Package.resolved | 9 - 36 files changed, 2176 insertions(+), 44 deletions(-) create mode 100644 Sources/GoTrue/Deprecated.swift create mode 100644 Sources/GoTrue/Extensions.swift create mode 100644 Sources/GoTrue/GoTrueClient.swift create mode 100644 Sources/GoTrue/GoTrueError.swift create mode 100644 Sources/GoTrue/GoTrueLocalStorage.swift create mode 100644 Sources/GoTrue/Internal/Helpers.swift create mode 100644 Sources/GoTrue/Internal/Request.swift create mode 100644 Sources/GoTrue/Internal/SessionManager.swift create mode 100644 Sources/GoTrue/Internal/ShareReplay.swift create mode 100644 Sources/GoTrue/Types.swift create mode 100644 Sources/GoTrue/Version.swift create mode 100644 Tests/GoTrueTests/DecoderTests.swift create mode 100644 Tests/GoTrueTests/JWTTests.swift create mode 100644 Tests/GoTrueTests/MockHelpers.swift create mode 100644 Tests/GoTrueTests/RequestsTests.swift create mode 100644 Tests/GoTrueTests/Resources/session.json create mode 100644 Tests/GoTrueTests/Resources/signup-response.json create mode 100644 Tests/GoTrueTests/Resources/user.json create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testRefreshSession.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testResetPasswordForEmail.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSessionFromURL.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAExpiredToken.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAFutureExpirationDate.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithEmailAndPassword.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithIdToken.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingEmail.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingPhone.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithPhoneAndPassword.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithEmailAndPassword.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithPhoneAndPassword.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testUpdateUser.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingEmail.1.txt create mode 100644 Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingPhone.1.txt diff --git a/Package.resolved b/Package.resolved index 8f39d598..4cea5e65 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,14 +1,5 @@ { "pins" : [ - { - "identity" : "gotrue-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/gotrue-swift", - "state" : { - "branch" : "dependency-free", - "revision" : "4eece4fe9d8a6596ec5dedd2ffc14a9594cd2134" - } - }, { "identity" : "keychainaccess", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 6286a034..117bc42f 100644 --- a/Package.swift +++ b/Package.swift @@ -14,37 +14,36 @@ var package = Package( .tvOS(.v13), ], products: [ - .library( - name: "Supabase", - targets: ["Supabase", "Functions", "PostgREST"] - ), - .library( - name: "Functions", - targets: ["Functions"] - ), - .library( - name: "PostgREST", - targets: ["PostgREST"] - ) + .library(name: "Functions", targets: ["Functions"]), + .library(name: "GoTrue", targets: ["GoTrue"]), + .library(name: "PostgREST", targets: ["PostgREST"]), + .library(name: "Supabase", targets: ["Supabase", "Functions", "PostgREST", "GoTrue"]), ], dependencies: [ + .package(url: "https://github.com/kishikawakatsumi/KeychainAccess", from: "4.2.2"), .package(url: "https://github.com/WeTransfer/Mocker", from: "3.0.1"), .package(url: "https://github.com/pointfreeco/swift-snapshot-testing", from: "1.8.1"), ], targets: [ + .target(name: "Functions"), + .testTarget(name: "FunctionsTests", dependencies: ["Functions", "Mocker"]), .target( - name: "Supabase", + name: "GoTrue", dependencies: [ - .product(name: "GoTrue", package: "gotrue-swift"), - .product(name: "SupabaseStorage", package: "storage-swift"), - .product(name: "Realtime", package: "realtime-swift"), - "PostgREST", - "Functions", + .product(name: "KeychainAccess", package: "KeychainAccess") + ] + ), + .testTarget( + name: "GoTrueTests", + dependencies: [ + "GoTrue", + "Mocker", + .product(name: "SnapshotTesting", package: "swift-snapshot-testing"), + ], + resources: [ + .process("Resources") ] ), - .testTarget(name: "SupabaseTests", dependencies: ["Supabase"]), - .target(name: "Functions"), - .testTarget(name: "FunctionsTests", dependencies: ["Functions", "Mocker"]), .target( name: "PostgREST", dependencies: [] @@ -64,13 +63,23 @@ var package = Package( ] ), .testTarget(name: "PostgRESTIntegrationTests", dependencies: ["PostgREST"]), + .target( + name: "Supabase", + dependencies: [ + "GoTrue", + .product(name: "SupabaseStorage", package: "storage-swift"), + .product(name: "Realtime", package: "realtime-swift"), + "PostgREST", + "Functions", + ] + ), + .testTarget(name: "SupabaseTests", dependencies: ["Supabase"]), ] ) if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { package.dependencies.append( contentsOf: [ - .package(path: "../gotrue-swift"), .package(path: "../storage-swift"), .package(path: "../realtime-swift"), ] @@ -78,10 +87,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { } else { package.dependencies.append( contentsOf: [ - .package( - url: "https://github.com/supabase-community/gotrue-swift", - branch: "dependency-free" - ), .package( url: "https://github.com/supabase-community/storage-swift.git", branch: "dependency-free" diff --git a/Sources/GoTrue/Deprecated.swift b/Sources/GoTrue/Deprecated.swift new file mode 100644 index 00000000..3db80c61 --- /dev/null +++ b/Sources/GoTrue/Deprecated.swift @@ -0,0 +1,14 @@ +import Foundation + +extension GoTrueMetaSecurity { + @available(*, deprecated, renamed: "captchaToken") + public var hcaptchaToken: String { + get { captchaToken } + set { captchaToken = newValue } + } + + @available(*, deprecated, renamed: "init(captchaToken:)") + public init(hcaptchaToken: String) { + self.init(captchaToken: hcaptchaToken) + } +} diff --git a/Sources/GoTrue/Extensions.swift b/Sources/GoTrue/Extensions.swift new file mode 100644 index 00000000..f0a50fa4 --- /dev/null +++ b/Sources/GoTrue/Extensions.swift @@ -0,0 +1,11 @@ +extension AuthResponse { + public var user: User? { + if case let .user(user) = self { return user } + return nil + } + + public var session: Session? { + if case let .session(session) = self { return session } + return nil + } +} diff --git a/Sources/GoTrue/GoTrueClient.swift b/Sources/GoTrue/GoTrueClient.swift new file mode 100644 index 00000000..9e1de0e6 --- /dev/null +++ b/Sources/GoTrue/GoTrueClient.swift @@ -0,0 +1,622 @@ +import Combine +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +public final class GoTrueClient { + public typealias FetchHandler = @Sendable (_ request: URLRequest) async throws -> ( + Data, + URLResponse + ) + + public struct Configuration { + public let url: URL + public var headers: [String: String] + public let localStorage: GoTrueLocalStorage + public let encoder: JSONEncoder + public let decoder: JSONDecoder + public let fetch: FetchHandler + + public init( + url: URL, + headers: [String: String] = [:], + localStorage: GoTrueLocalStorage? = nil, + encoder: JSONEncoder = .goTrue, + decoder: JSONDecoder = .goTrue, + fetch: @escaping FetchHandler = { try await URLSession.shared.data(for: $0) } + ) { + self.url = url + self.headers = headers + self.localStorage = + localStorage + ?? KeychainLocalStorage( + service: "supabase.gotrue.swift", + accessGroup: nil + ) + self.encoder = encoder + self.decoder = decoder + self.fetch = fetch + } + } + + private let configuration: Configuration + private lazy var sessionManager = SessionManager( + localStorage: self.configuration.localStorage, + sessionRefresher: { try await self.refreshSession(refreshToken: $0) } + ) + + private let authEventChangeSubject = PassthroughSubject() + /// Asynchronous sequence of authentication change events emitted during life of `GoTrueClient`. + public var authEventChange: AnyPublisher { + authEventChangeSubject.shareReplay(1).eraseToAnyPublisher() + } + + // private let initializationTask: Task + + /// Returns the session, refreshing it if necessary. + public var session: Session { + get async throws { + try await sessionManager.session() + } + } + + public convenience init( + url: URL, + headers: [String: String] = [:], + localStorage: GoTrueLocalStorage? = nil, + encoder: JSONEncoder = .goTrue, + decoder: JSONDecoder = .goTrue, + fetch: @escaping FetchHandler = { try await URLSession.shared.data(for: $0) } + ) { + self.init( + configuration: Configuration( + url: url, + headers: headers, + localStorage: localStorage, + encoder: encoder, + decoder: decoder, + fetch: fetch + )) + } + + public init(configuration: Configuration) { + var configuration = configuration + configuration.headers["X-Client-Info"] = "gotrue-swift/\(version)" + self.configuration = configuration + + Task { + do { + _ = try await sessionManager.session() + authEventChangeSubject.send(.signedIn) + } catch { + authEventChangeSubject.send(.signedOut) + } + } + } + + /// Initialize the client session from storage. + /// + /// This method should be called on the app startup, for making sure that the client is fully + /// initialized + /// before proceeding. + // public func initialize() async { + // await initializationTask.value + // } + + /// Creates a new user. + /// - Parameters: + /// - email: User's email address. + /// - password: Password for the user. + /// - data: User's metadata. + @discardableResult + public func signUp( + email: String, + password: String, + data: [String: AnyJSON]? = nil, + redirectTo: URL? = nil, + captchaToken: String? = nil + ) async throws -> AuthResponse { + try await _signUp( + request: .init( + path: "/signup", + method: "POST", + query: [ + redirectTo.map { URLQueryItem(name: "redirect_to", value: $0.absoluteString) } + ].compactMap { $0 }, + body: configuration.encoder.encode( + SignUpRequest( + email: email, + password: password, + data: data, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + /// Creates a new user. + /// - Parameters: + /// - phone: User's phone number with international prefix. + /// - password: Password for the user. + /// - data: User's metadata. + @discardableResult + public func signUp( + phone: String, + password: String, + data: [String: AnyJSON]? = nil, + captchaToken: String? = nil + ) async throws -> AuthResponse { + try await _signUp( + request: .init( + path: "/signup", + method: "POST", + body: configuration.encoder.encode( + SignUpRequest( + password: password, + phone: phone, + data: data, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + private func _signUp(request: Request) async throws -> AuthResponse { + await sessionManager.remove() + let response = try await execute(request).decoded( + as: AuthResponse.self, + decoder: configuration.decoder + ) + + if let session = response.session { + try await sessionManager.update(session) + authEventChangeSubject.send(.signedIn) + } + + return response + } + + /// Log in an existing user with an email and password. + @discardableResult + public func signIn(email: String, password: String) async throws -> Session { + try await _signIn( + request: .init( + path: "/token", + method: "POST", + query: [URLQueryItem(name: "grant_type", value: "password")], + body: configuration.encoder.encode( + UserCredentials(email: email, password: password) + ) + ) + ) + } + + /// Log in an existing user with a phone and password. + @discardableResult + public func signIn(phone: String, password: String) async throws -> Session { + try await _signIn( + request: .init( + path: "/token", + method: "POST", + query: [URLQueryItem(name: "grant_type", value: "password")], + body: configuration.encoder.encode( + UserCredentials(password: password, phone: phone) + ) + ) + ) + } + + /// Allows signing in with an ID token issued by certain supported providers. + /// The ID token is verified for validity and a new session is established. + @discardableResult + public func signInWithIdToken(credentials: OpenIDConnectCredentials) async throws -> Session { + try await _signIn( + request: .init( + path: "/token", + method: "POST", + query: [URLQueryItem(name: "grant_type", value: "id_token")], + body: configuration.encoder.encode(credentials) + ) + ) + } + + private func _signIn(request: Request) async throws -> Session { + await sessionManager.remove() + + let session = try await execute(request).decoded( + as: Session.self, + decoder: configuration.decoder + ) + + if session.user.emailConfirmedAt != nil || session.user.confirmedAt != nil { + try await sessionManager.update(session) + authEventChangeSubject.send(.signedIn) + } + + return session + } + + /// Log in user using magic link. + /// + /// If the `{{ .ConfirmationURL }}` variable is specified in the email template, a magic link will + /// be sent. + /// If the `{{ .Token }}` variable is specified in the email template, an OTP will be sent. + /// - Parameters: + /// - email: User's email address. + /// - redirectTo: Redirect URL embedded in the email link. + /// - shouldCreateUser: Creates a new user, defaults to `true`. + /// - data: User's metadata. + /// - captchaToken: Captcha verification token. + public func signInWithOTP( + email: String, + redirectTo _: URL? = nil, + shouldCreateUser: Bool? = nil, + data: [String: AnyJSON]? = nil, + captchaToken: String? = nil + ) async throws { + await sessionManager.remove() + try await execute( + .init( + path: "/otp", + method: "POST", + body: configuration.encoder.encode( + OTPParams( + email: email, + createUser: shouldCreateUser, + data: data, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + /// Log in user using a one-time password (OTP).. + /// + /// - Parameters: + /// - phone: User's phone with international prefix. + /// - shouldCreateUser: Creates a new user, defaults to `true`. + /// - data: User's metadata. + /// - captchaToken: Captcha verification token. + public func signInWithOTP( + phone: String, + shouldCreateUser: Bool? = nil, + data: [String: AnyJSON]? = nil, + captchaToken: String? = nil + ) async throws { + await sessionManager.remove() + try await execute( + .init( + path: "/otp", + method: "POST", + body: configuration.encoder.encode( + OTPParams( + phone: phone, + createUser: shouldCreateUser, + data: data, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + /// Log in an existing user via a third-party provider. + public func getOAuthSignInURL( + provider: Provider, + scopes: String? = nil, + redirectTo: URL? = nil, + queryParams: [(name: String, value: String?)] = [] + ) throws -> URL { + guard + var components = URLComponents( + url: configuration.url.appendingPathComponent("authorize"), resolvingAgainstBaseURL: false + ) + else { + throw URLError(.badURL) + } + + var queryItems: [URLQueryItem] = [ + URLQueryItem(name: "provider", value: provider.rawValue) + ] + + if let scopes { + queryItems.append(URLQueryItem(name: "scopes", value: scopes)) + } + + if let redirectTo { + queryItems.append(URLQueryItem(name: "redirect_to", value: redirectTo.absoluteString)) + } + + queryItems.append(contentsOf: queryParams.map(URLQueryItem.init)) + + components.queryItems = queryItems + + guard let url = components.url else { + throw URLError(.badURL) + } + + return url + } + + @discardableResult + public func refreshSession(refreshToken: String) async throws -> Session { + do { + let session = try await execute( + .init( + path: "/token", + method: "POST", + query: [URLQueryItem(name: "grant_type", value: "refresh_token")], + body: configuration.encoder.encode(UserCredentials(refreshToken: refreshToken)) + ) + ).decoded(as: Session.self, decoder: configuration.decoder) + + if session.user.phoneConfirmedAt != nil || session.user.emailConfirmedAt != nil + || session + .user.confirmedAt != nil + { + try await sessionManager.update(session) + authEventChangeSubject.send(.signedIn) + } + + return session + } catch { + throw error + } + } + + /// Gets the session data from a OAuth2 callback URL. + @discardableResult + public func session(from url: URL, storeSession: Bool = true) async throws -> Session { + guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + throw URLError(.badURL) + } + + let params = extractParams(from: components.fragment ?? "") + + if let errorDescription = params.first(where: { $0.name == "error_description" })?.value { + throw GoTrueError.api(.init(errorDescription: errorDescription)) + } + + guard + let accessToken = params.first(where: { $0.name == "access_token" })?.value, + let expiresIn = params.first(where: { $0.name == "expires_in" })?.value, + let refreshToken = params.first(where: { $0.name == "refresh_token" })?.value, + let tokenType = params.first(where: { $0.name == "token_type" })?.value + else { + throw URLError(.badURL) + } + + let providerToken = params.first(where: { $0.name == "provider_token" })?.value + let providerRefreshToken = params.first(where: { $0.name == "provider_refresh_token" })?.value + + let user = try await execute( + .init( + path: "/user", + method: "GET", + headers: ["Authorization": "\(tokenType) \(accessToken)"] + ) + ).decoded(as: User.self, decoder: configuration.decoder) + + let session = Session( + providerToken: providerToken, + providerRefreshToken: providerRefreshToken, + accessToken: accessToken, + tokenType: tokenType, + expiresIn: Double(expiresIn) ?? 0, + refreshToken: refreshToken, + user: user + ) + + if storeSession { + try await sessionManager.update(session) + authEventChangeSubject.send(.signedIn) + + if let type = params.first(where: { $0.name == "type" })?.value, type == "recovery" { + authEventChangeSubject.send(.passwordRecovery) + } + } + + return session + } + + /// Sets the session data from the current session. If the current session is expired, setSession + /// will take care of refreshing it to obtain a new session. + /// + /// If the refresh token is invalid and the current session has expired, an error will be thrown. + /// This method will use the exp claim defined in the access token. + /// - Parameters: + /// - accessToken: The current access token. + /// - refreshToken: The current refresh token. + /// - Returns: A new valid session. + @discardableResult + public func setSession(accessToken: String, refreshToken: String) async throws -> Session { + let now = Date() + var expiresAt = now + var hasExpired = true + var session: Session? + + let jwt = try decode(jwt: accessToken) + if let exp = jwt["exp"] as? TimeInterval { + expiresAt = Date(timeIntervalSince1970: exp) + hasExpired = expiresAt <= now + } else { + throw GoTrueError.missingExpClaim + } + + if hasExpired { + session = try await refreshSession(refreshToken: refreshToken) + } else { + let user = try await authorizedExecute(.init(path: "/user", method: "GET")) + .decoded(as: User.self, decoder: configuration.decoder) + session = Session( + accessToken: accessToken, + tokenType: "bearer", + expiresIn: expiresAt.timeIntervalSince(now), + refreshToken: refreshToken, + user: user + ) + } + + guard let session else { + throw GoTrueError.sessionNotFound + } + + try await sessionManager.update(session) + authEventChangeSubject.send(.tokenRefreshed) + return session + } + + /// Signs out the current user, if there is a logged in user. + public func signOut() async throws { + defer { authEventChangeSubject.send(.signedOut) } + + let session = try? await sessionManager.session() + if session != nil { + try await authorizedExecute( + .init( + path: "/logout", + method: "POST" + ) + ) + await sessionManager.remove() + } + } + + /// Log in an user given a User supplied OTP received via email. + @discardableResult + public func verifyOTP( + email: String, + token: String, + type: OTPType, + redirectTo: URL? = nil, + captchaToken: String? = nil + ) async throws -> AuthResponse { + try await _verifyOTP( + request: .init( + path: "/verify", + method: "POST", + query: [ + redirectTo.map { URLQueryItem(name: "redirect_to", value: $0.absoluteString) } + ].compactMap { $0 }, + body: configuration.encoder.encode( + VerifyOTPParams( + email: email, + token: token, + type: type, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + /// Log in an user given a User supplied OTP received via mobile. + @discardableResult + public func verifyOTP( + phone: String, + token: String, + type: OTPType, + captchaToken: String? = nil + ) async throws -> AuthResponse { + try await _verifyOTP( + request: .init( + path: "/verify", + method: "POST", + body: configuration.encoder.encode( + VerifyOTPParams( + phone: phone, + token: token, + type: type, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + private func _verifyOTP(request: Request) async throws -> AuthResponse { + await sessionManager.remove() + + let response = try await execute(request).decoded( + as: AuthResponse.self, + decoder: configuration.decoder + ) + + if let session = response.session { + try await sessionManager.update(session) + authEventChangeSubject.send(.signedIn) + } + + return response + } + + /// Updates user data, if there is a logged in user. + @discardableResult + public func update(user: UserAttributes) async throws -> User { + var session = try await sessionManager.session() + let user = try await authorizedExecute( + .init(path: "/user", method: "PUT", body: configuration.encoder.encode(user)) + ).decoded(as: User.self, decoder: configuration.decoder) + session.user = user + try await sessionManager.update(session) + authEventChangeSubject.send(.userUpdated) + return user + } + + /// Sends a reset request to an email address. + public func resetPasswordForEmail( + _ email: String, + redirectTo: URL? = nil, + captchaToken: String? = nil + ) async throws { + try await execute( + .init( + path: "/recover", + method: "POST", + query: [ + redirectTo.map { URLQueryItem(name: "redirect_to", value: $0.absoluteString) } + ].compactMap { $0 }, + body: configuration.encoder.encode( + RecoverParams( + email: email, + gotrueMetaSecurity: captchaToken.map(GoTrueMetaSecurity.init(captchaToken:)) + ) + ) + ) + ) + } + + @discardableResult + private func authorizedExecute(_ request: Request) async throws -> Response { + let session = try await sessionManager.session() + + var request = request + request.headers["Authorization"] = "Bearer \(session.accessToken)" + + return try await execute(request) + } + + @discardableResult + private func execute(_ request: Request) async throws -> Response { + var request = request + request.headers.merge(configuration.headers) { r, _ in r } + let urlRequest = try request.urlRequest(withBaseURL: configuration.url) + + let (data, response) = try await configuration.fetch(urlRequest) + guard let httpResponse = response as? HTTPURLResponse else { + throw URLError(.badServerResponse) + } + + guard (200..<300).contains(httpResponse.statusCode) else { + let apiError = try configuration.decoder.decode(GoTrueError.APIError.self, from: data) + throw GoTrueError.api(apiError) + } + + return Response(data: data, response: httpResponse) + } +} diff --git a/Sources/GoTrue/GoTrueError.swift b/Sources/GoTrue/GoTrueError.swift new file mode 100644 index 00000000..338c5be4 --- /dev/null +++ b/Sources/GoTrue/GoTrueError.swift @@ -0,0 +1,33 @@ +import Foundation + +public enum GoTrueError: LocalizedError, Sendable { + case missingExpClaim + case malformedJWT + case sessionNotFound + case api(APIError) + + public struct APIError: Error, Decodable, Sendable { + public var message: String? + public var msg: String? + public var code: Int? + public var error: String? + public var errorDescription: String? + + private enum CodingKeys: String, CodingKey { + case message + case msg + case code + case error + case errorDescription = "error_description" + } + } + + public var errorDescription: String? { + switch self { + case .missingExpClaim: return "Missing expiration claim on access token." + case .malformedJWT: return "A malformed JWT received." + case .sessionNotFound: return "Unable to get a valid session." + case let .api(error): return error.errorDescription ?? error.message ?? error.msg + } + } +} diff --git a/Sources/GoTrue/GoTrueLocalStorage.swift b/Sources/GoTrue/GoTrueLocalStorage.swift new file mode 100644 index 00000000..cebd49b6 --- /dev/null +++ b/Sources/GoTrue/GoTrueLocalStorage.swift @@ -0,0 +1,32 @@ +import Foundation +@preconcurrency import KeychainAccess + +public protocol GoTrueLocalStorage: Sendable { + func store(key: String, value: Data) throws + func retrieve(key: String) throws -> Data? + func remove(key: String) throws +} + +struct KeychainLocalStorage: GoTrueLocalStorage { + private let keychain: Keychain + + init(service: String, accessGroup: String?) { + if let accessGroup { + keychain = Keychain(service: service, accessGroup: accessGroup) + } else { + keychain = Keychain(service: service) + } + } + + func store(key: String, value: Data) throws { + try keychain.set(value, key: key) + } + + func retrieve(key: String) throws -> Data? { + try keychain.getData(key) + } + + func remove(key: String) throws { + try keychain.remove(key) + } +} diff --git a/Sources/GoTrue/Internal/Helpers.swift b/Sources/GoTrue/Internal/Helpers.swift new file mode 100644 index 00000000..77191618 --- /dev/null +++ b/Sources/GoTrue/Internal/Helpers.swift @@ -0,0 +1,46 @@ +import Foundation + +func extractParams(from fragment: String) -> [(name: String, value: String)] { + let components = + fragment + .split(separator: "&") + .map { $0.split(separator: "=") } + + return + components + .compactMap { + $0.count == 2 + ? (name: String($0[0]), value: String($0[1])) + : nil + } +} + +func decode(jwt: String) throws -> [String: Any] { + let parts = jwt.split(separator: ".") + guard parts.count == 3 else { + throw GoTrueError.malformedJWT + } + + let payload = String(parts[1]) + guard let data = base64URLDecode(payload) else { + throw GoTrueError.malformedJWT + } + let json = try JSONSerialization.jsonObject(with: data, options: []) + guard let decodedPayload = json as? [String: Any] else { + throw GoTrueError.malformedJWT + } + return decodedPayload +} + +private func base64URLDecode(_ value: String) -> Data? { + var base64 = value.replacingOccurrences(of: "-", with: "+") + .replacingOccurrences(of: "_", with: "/") + let length = Double(base64.lengthOfBytes(using: .utf8)) + let requiredLength = 4 * ceil(length / 4.0) + let paddingLength = requiredLength - length + if paddingLength > 0 { + let padding = "".padding(toLength: Int(paddingLength), withPad: "=", startingAt: 0) + base64 = base64 + padding + } + return Data(base64Encoded: base64, options: .ignoreUnknownCharacters) +} diff --git a/Sources/GoTrue/Internal/Request.swift b/Sources/GoTrue/Internal/Request.swift new file mode 100644 index 00000000..c1583e9e --- /dev/null +++ b/Sources/GoTrue/Internal/Request.swift @@ -0,0 +1,50 @@ +import Foundation + +struct Request { + var path: String + var method: String + var query: [URLQueryItem] = [] + var headers: [String: String] = [:] + var body: Data? + + func urlRequest(withBaseURL baseURL: URL) throws -> URLRequest { + var url = baseURL.appendingPathComponent(path) + if !query.isEmpty { + guard var components = URLComponents(url: url, resolvingAgainstBaseURL: false) else { + throw URLError(.badURL) + } + + components.queryItems = query + + if let newURL = components.url { + url = newURL + } else { + throw URLError(.badURL) + } + } + + var request = URLRequest(url: url) + request.httpMethod = method + + if body != nil, headers["Content-Type"] == nil { + request.setValue("application/json", forHTTPHeaderField: "Content-Type") + } + + for (name, value) in headers { + request.setValue(value, forHTTPHeaderField: name) + } + + request.httpBody = body + + return request + } +} + +struct Response { + let data: Data + let response: HTTPURLResponse + + func decoded(as _: T.Type, decoder: JSONDecoder) throws -> T { + try decoder.decode(T.self, from: data) + } +} diff --git a/Sources/GoTrue/Internal/SessionManager.swift b/Sources/GoTrue/Internal/SessionManager.swift new file mode 100644 index 00000000..66f46424 --- /dev/null +++ b/Sources/GoTrue/Internal/SessionManager.swift @@ -0,0 +1,77 @@ +import Foundation +import KeychainAccess + +struct StoredSession: Codable { + var session: Session + var expirationDate: Date + + var isValid: Bool { + expirationDate > Date().addingTimeInterval(60) + } + + init(session: Session, expirationDate: Date? = nil) { + self.session = session + self.expirationDate = expirationDate ?? Date().addingTimeInterval(session.expiresIn) + } +} + +actor SessionManager { + typealias SessionRefresher = @Sendable (_ refreshToken: String) async throws -> Session + + private var task: Task? + private let localStorage: GoTrueLocalStorage + private let sessionRefresher: SessionRefresher + + init(localStorage: GoTrueLocalStorage, sessionRefresher: @escaping SessionRefresher) { + self.localStorage = localStorage + self.sessionRefresher = sessionRefresher + } + + func session() async throws -> Session { + if let task { + return try await task.value + } + + guard let currentSession = try localStorage.getSession() else { + throw GoTrueError.sessionNotFound + } + + if currentSession.isValid { + return currentSession.session + } + + task = Task { + defer { self.task = nil } + + let session = try await sessionRefresher(currentSession.session.refreshToken) + try update(session) + return session + } + + return try await task!.value + } + + func update(_ session: Session) throws { + try localStorage.storeSession(StoredSession(session: session)) + } + + func remove() { + localStorage.deleteSession() + } +} + +extension GoTrueLocalStorage { + func getSession() throws -> StoredSession? { + try retrieve(key: "supabase.session").flatMap { + try JSONDecoder.goTrue.decode(StoredSession.self, from: $0) + } + } + + func storeSession(_ session: StoredSession) throws { + try store(key: "supabase.session", value: JSONEncoder.goTrue.encode(session)) + } + + func deleteSession() { + try? remove(key: "supabase.session") + } +} diff --git a/Sources/GoTrue/Internal/ShareReplay.swift b/Sources/GoTrue/Internal/ShareReplay.swift new file mode 100644 index 00000000..a935c21f --- /dev/null +++ b/Sources/GoTrue/Internal/ShareReplay.swift @@ -0,0 +1,96 @@ +import Combine +import Foundation + +extension Publisher { + /// Provides a subject that shares a single subscription to the upstream publisher and + /// replays at most `bufferSize` items emitted by that publisher + /// - Parameter bufferSize: limits the number of items that can be replayed + func shareReplay(_ bufferSize: Int) -> AnyPublisher { + multicast(subject: ReplaySubject(bufferSize)).autoconnect().eraseToAnyPublisher() + } +} + +final class ReplaySubject: Subject { + private var buffer = [Output]() + private let bufferSize: Int + private let lock = NSRecursiveLock() + + init(_ bufferSize: Int = 0) { + self.bufferSize = bufferSize + } + + private var subscriptions = [ReplaySubjectSubscription]() + private var completion: Subscribers.Completion? + + func receive(subscriber: Downstream) + where Downstream.Failure == Failure, Downstream.Input == Output { + lock.lock() + defer { lock.unlock() } + let subscription = ReplaySubjectSubscription( + downstream: AnySubscriber(subscriber)) + subscriber.receive(subscription: subscription) + subscriptions.append(subscription) + subscription.replay(buffer, completion: completion) + } + + /// Establishes demand for a new upstream subscriptions + func send(subscription: Subscription) { + lock.lock() + defer { lock.unlock() } + subscription.request(.unlimited) + } + + /// Sends a value to the subscriber. + func send(_ value: Output) { + lock.lock() + defer { lock.unlock() } + buffer.append(value) + buffer = buffer.suffix(bufferSize) + subscriptions.forEach { $0.receive(value) } + } + + /// Sends a completion event to the subscriber. + func send(completion: Subscribers.Completion) { + lock.lock() + defer { lock.unlock() } + self.completion = completion + subscriptions.forEach { subscription in subscription.receive(completion: completion) } + } +} + +final class ReplaySubjectSubscription: Subscription { + private let downstream: AnySubscriber + private var isCompleted = false + private var demand: Subscribers.Demand = .none + + init(downstream: AnySubscriber) { + self.downstream = downstream + } + + func request(_ newDemand: Subscribers.Demand) { + demand += newDemand + } + + func cancel() { + isCompleted = true + } + + func receive(_ value: Output) { + guard !isCompleted, demand > 0 else { return } + + demand += downstream.receive(value) + demand -= 1 + } + + func receive(completion: Subscribers.Completion) { + guard !isCompleted else { return } + isCompleted = true + downstream.receive(completion: completion) + } + + func replay(_ values: [Output], completion: Subscribers.Completion?) { + guard !isCompleted else { return } + values.forEach { value in receive(value) } + if let completion = completion { receive(completion: completion) } + } +} diff --git a/Sources/GoTrue/Types.swift b/Sources/GoTrue/Types.swift new file mode 100644 index 00000000..9c567c14 --- /dev/null +++ b/Sources/GoTrue/Types.swift @@ -0,0 +1,587 @@ +import Foundation + +public enum AuthChangeEvent: String, Sendable { + case passwordRecovery = "PASSWORD_RECOVERY" + case signedIn = "SIGNED_IN" + case signedOut = "SIGNED_OUT" + case tokenRefreshed = "TOKEN_REFRESHED" + case userUpdated = "USER_UPDATED" + case userDeleted = "USER_DELETED" +} + +public enum AnyJSON: Hashable, Codable, Sendable { + case string(String) + case number(Double) + case object([String: AnyJSON]) + case array([AnyJSON]) + case bool(Bool) + case null + + public var value: Any? { + switch self { + case let .string(string): return string + case let .number(double): return double + case let .object(dictionary): return dictionary + case let .array(array): return array + case let .bool(bool): return bool + case .null: return nil + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .array(array): try container.encode(array) + case let .object(object): try container.encode(object) + case let .string(string): try container.encode(string) + case let .number(number): try container.encode(number) + case let .bool(bool): try container.encode(bool) + case .null: try container.encodeNil() + } + } + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let object = try? container.decode([String: AnyJSON].self) { + self = .object(object) + } else if let array = try? container.decode([AnyJSON].self) { + self = .array(array) + } else if let string = try? container.decode(String.self) { + self = .string(string) + } else if let bool = try? container.decode(Bool.self) { + self = .bool(bool) + } else if let number = try? container.decode(Double.self) { + self = .number(number) + } else if container.decodeNil() { + self = .null + } else { + throw DecodingError.dataCorrupted( + .init(codingPath: decoder.codingPath, debugDescription: "Invalid JSON value.") + ) + } + } +} + +public struct UserCredentials: Codable, Hashable, Sendable { + public var email: String? + public var password: String? + public var phone: String? + public var refreshToken: String? + + public init( + email: String? = nil, + password: String? = nil, + phone: String? = nil, + refreshToken: String? = nil + ) { + self.email = email + self.password = password + self.phone = phone + self.refreshToken = refreshToken + } + + public enum CodingKeys: String, CodingKey { + case email + case password + case phone + case refreshToken = "refresh_token" + } +} + +public struct SignUpRequest: Codable, Hashable, Sendable { + public var email: String? + public var password: String? + public var phone: String? + public var data: [String: AnyJSON]? + public var gotrueMetaSecurity: GoTrueMetaSecurity? + + public init( + email: String? = nil, + password: String? = nil, + phone: String? = nil, + data: [String: AnyJSON]? = nil, + gotrueMetaSecurity: GoTrueMetaSecurity? = nil + ) { + self.email = email + self.password = password + self.phone = phone + self.data = data + self.gotrueMetaSecurity = gotrueMetaSecurity + } + + public enum CodingKeys: String, CodingKey { + case email + case password + case phone + case data + case gotrueMetaSecurity = "gotrue_meta_security" + } +} + +public struct Session: Codable, Hashable, Sendable { + /// The oauth provider token. If present, this can be used to make external API requests to the + /// oauth provider used. + public var providerToken: String? + /// The oauth provider refresh token. If present, this can be used to refresh the provider_token + /// via the oauth provider's API. Not all oauth providers return a provider refresh token. If the + /// provider_refresh_token is missing, please refer to the oauth provider's documentation for + /// information on how to obtain the provider refresh token. + public var providerRefreshToken: String? + /// The access token jwt. It is recommended to set the JWT_EXPIRY to a shorter expiry value. + public var accessToken: String + public var tokenType: String + /// The number of seconds until the token expires (since it was issued). Returned when a login is + /// confirmed. + public var expiresIn: Double + /// A one-time used refresh token that never expires. + public var refreshToken: String + public var user: User + + public init( + providerToken: String? = nil, + providerRefreshToken: String? = nil, + accessToken: String, + tokenType: String, + expiresIn: Double, + refreshToken: String, + user: User + ) { + self.providerToken = providerToken + self.providerRefreshToken = providerRefreshToken + self.accessToken = accessToken + self.tokenType = tokenType + self.expiresIn = expiresIn + self.refreshToken = refreshToken + self.user = user + } + + public enum CodingKeys: String, CodingKey { + case providerToken = "provider_token" + case providerRefreshToken = "provider_refresh_token" + case accessToken = "access_token" + case tokenType = "token_type" + case expiresIn = "expires_in" + case refreshToken = "refresh_token" + case user + } +} + +public struct User: Codable, Hashable, Identifiable, Sendable { + public var id: UUID + public var appMetadata: [String: AnyJSON] + public var userMetadata: [String: AnyJSON] + public var aud: String + public var confirmationSentAt: Date? + public var recoverySentAt: Date? + public var emailChangeSentAt: Date? + public var newEmail: String? + public var invitedAt: Date? + public var actionLink: String? + public var email: String? + public var phone: String? + public var createdAt: Date + public var confirmedAt: Date? + public var emailConfirmedAt: Date? + public var phoneConfirmedAt: Date? + public var lastSignInAt: Date? + public var role: String? + public var updatedAt: Date + public var identities: [UserIdentity]? + + public init( + id: UUID, + appMetadata: [String: AnyJSON], + userMetadata: [String: AnyJSON], + aud: String, + confirmationSentAt: Date? = nil, + recoverySentAt: Date? = nil, + emailChangeSentAt: Date? = nil, + newEmail: String? = nil, + invitedAt: Date? = nil, + actionLink: String? = nil, + email: String? = nil, + phone: String? = nil, + createdAt: Date, + confirmedAt: Date? = nil, + emailConfirmedAt: Date? = nil, + phoneConfirmedAt: Date? = nil, + lastSignInAt: Date? = nil, + role: String? = nil, + updatedAt: Date, + identities: [UserIdentity]? = nil + ) { + self.id = id + self.appMetadata = appMetadata + self.userMetadata = userMetadata + self.aud = aud + self.confirmationSentAt = confirmationSentAt + self.recoverySentAt = recoverySentAt + self.emailChangeSentAt = emailChangeSentAt + self.newEmail = newEmail + self.invitedAt = invitedAt + self.actionLink = actionLink + self.email = email + self.phone = phone + self.createdAt = createdAt + self.confirmedAt = confirmedAt + self.emailConfirmedAt = emailConfirmedAt + self.phoneConfirmedAt = phoneConfirmedAt + self.lastSignInAt = lastSignInAt + self.role = role + self.updatedAt = updatedAt + self.identities = identities + } + + public enum CodingKeys: String, CodingKey { + case id + case appMetadata = "app_metadata" + case userMetadata = "user_metadata" + case aud + case confirmationSentAt = "confirmation_sent_at" + case recoverySentAt = "recovery_sent_at" + case emailChangeSentAt = "email_change_sent_at" + case newEmail = "new_email" + case invitedAt = "invited_at" + case actionLink = "action_link" + case email + case phone + case createdAt = "created_at" + case confirmedAt = "confirmed_at" + case emailConfirmedAt = "email_confirmed_at" + case phoneConfirmedAt = "phone_confirmed_at" + case lastSignInAt = "last_sign_in_at" + case role + case updatedAt = "updated_at" + case identities + } +} + +public struct UserIdentity: Codable, Hashable, Identifiable, Sendable { + public var id: String + public var userID: UUID + public var identityData: [String: AnyJSON] + public var provider: String + public var createdAt: Date + public var lastSignInAt: Date + public var updatedAt: Date + + public init( + id: String, + userID: UUID, + identityData: [String: AnyJSON], + provider: String, + createdAt: Date, + lastSignInAt: Date, + updatedAt: Date + ) { + self.id = id + self.userID = userID + self.identityData = identityData + self.provider = provider + self.createdAt = createdAt + self.lastSignInAt = lastSignInAt + self.updatedAt = updatedAt + } + + public enum CodingKeys: String, CodingKey { + case id + case userID = "user_id" + case identityData = "identity_data" + case provider + case createdAt = "created_at" + case lastSignInAt = "last_sign_in_at" + case updatedAt = "updated_at" + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + id = try container.decode(String.self, forKey: .id) + userID = try container.decode(UUID.self, forKey: .userID) + identityData = + try container + .decodeIfPresent([String: AnyJSON].self, forKey: .identityData) ?? [:] + provider = try container.decode(String.self, forKey: .provider) + createdAt = try container.decode(Date.self, forKey: .createdAt) + lastSignInAt = try container.decode(Date.self, forKey: .lastSignInAt) + updatedAt = try container.decode(Date.self, forKey: .updatedAt) + } +} + +public enum Provider: String, Codable, CaseIterable, Sendable { + case apple + case azure + case bitbucket + case discord + case email + case facebook + case github + case gitlab + case google + case keycloak + case linkedin + case notion + case slack + case spotify + case twitch + case twitter + case workos +} + +public struct OpenIDConnectCredentials: Codable, Hashable, Sendable { + /// Provider name or OIDC `iss` value identifying which provider should be used to verify the + /// provided token. Supported names: `google`, `apple`, `azure`, `facebook`. + public var provider: Provider? + + /// OIDC ID token issued by the specified provider. The `iss` claim in the ID token must match the + /// supplied provider. Some ID tokens contain an `at_hash` which require that you provide an + /// `access_token` value to be accepted properly. If the token contains a `nonce` claim you must + /// supply the nonce used to obtain the ID token. + public var idToken: String + + /// If the ID token contains an `at_hash` claim, then the hash of this value is compared to the + /// value in the ID token. + public var accessToken: String? + + /// If the ID token contains a `nonce` claim, then the hash of this value is compared to the value + /// in the ID token. + public var nonce: String? + + /// Verification token received when the user completes the captcha on the site. + public var gotrueMetaSecurity: GoTrueMetaSecurity? + + public init( + provider: Provider? = nil, + idToken: String, + accessToken: String? = nil, + nonce: String? = nil, + gotrueMetaSecurity: GoTrueMetaSecurity? = nil + ) { + self.provider = provider + self.idToken = idToken + self.accessToken = accessToken + self.nonce = nonce + self.gotrueMetaSecurity = gotrueMetaSecurity + } + + public enum CodingKeys: String, CodingKey { + case provider + case idToken = "id_token" + case accessToken = "access_token" + case nonce + case gotrueMetaSecurity = "gotrue_meta_security" + } + + public enum Provider: String, Codable, Hashable, Sendable { + case google, apple, azure, facebook + } +} + +public struct GoTrueMetaSecurity: Codable, Hashable, Sendable { + public var captchaToken: String + + public init(captchaToken: String) { + self.captchaToken = captchaToken + } + + public enum CodingKeys: String, CodingKey { + case captchaToken = "captcha_token" + } +} + +public struct OTPParams: Codable, Hashable, Sendable { + public var email: String? + public var phone: String? + public var createUser: Bool + public var data: [String: AnyJSON]? + public var gotrueMetaSecurity: GoTrueMetaSecurity? + + public init( + email: String? = nil, + phone: String? = nil, + createUser: Bool? = nil, + data: [String: AnyJSON]? = nil, + gotrueMetaSecurity: GoTrueMetaSecurity? = nil + ) { + self.email = email + self.phone = phone + self.createUser = createUser ?? true + self.data = data + self.gotrueMetaSecurity = gotrueMetaSecurity + } + + public enum CodingKeys: String, CodingKey { + case email + case phone + case createUser = "create_user" + case data + case gotrueMetaSecurity = "gotrue_meta_security" + } +} + +public struct VerifyOTPParams: Codable, Hashable, Sendable { + public var email: String? + public var phone: String? + public var token: String + public var type: OTPType + public var gotrueMetaSecurity: GoTrueMetaSecurity? + + public init( + email: String? = nil, + phone: String? = nil, + token: String, + type: OTPType, + gotrueMetaSecurity: GoTrueMetaSecurity? = nil + ) { + self.email = email + self.phone = phone + self.token = token + self.type = type + self.gotrueMetaSecurity = gotrueMetaSecurity + } + + public enum CodingKeys: String, CodingKey { + case email + case phone + case token + case type + case gotrueMetaSecurity = "gotrue_meta_security" + } +} + +public enum OTPType: String, Codable, CaseIterable, Sendable { + case sms + case phoneChange = "phone_change" + case signup + case invite + case magiclink + case recovery + case emailChange = "email_change" +} + +public enum AuthResponse: Codable, Hashable, Sendable { + case session(Session) + case user(User) + + public init(from decoder: Decoder) throws { + let container = try decoder.singleValueContainer() + if let value = try? container.decode(Session.self) { + self = .session(value) + } else if let value = try? container.decode(User.self) { + self = .user(value) + } else { + throw DecodingError.dataCorruptedError( + in: container, + debugDescription: "Data could not be decoded as any of the expected types (Session, User)." + ) + } + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.singleValueContainer() + switch self { + case let .session(value): try container.encode(value) + case let .user(value): try container.encode(value) + } + } +} + +public struct UserAttributes: Codable, Hashable, Sendable { + /// The user's email. + public var email: String? + /// The user's phone. + public var phone: String? + /// The user's password. + public var password: String? + /// An email change token. + public var emailChangeToken: String? + /// A custom data object to store the user's metadata. This maps to the `auth.users.user_metadata` + /// column. The `data` should be a JSON object that includes user-specific info, such as their + /// first and last name. + public var data: [String: AnyJSON]? + + public init( + email: String? = nil, + phone: String? = nil, + password: String? = nil, + emailChangeToken: String? = nil, + data: [String: AnyJSON]? = nil + ) { + self.email = email + self.phone = phone + self.password = password + self.emailChangeToken = emailChangeToken + self.data = data + } + + public enum CodingKeys: String, CodingKey { + case email + case phone + case password + case emailChangeToken = "email_change_token" + case data + } +} + +public struct RecoverParams: Codable, Hashable, Sendable { + public var email: String + public var gotrueMetaSecurity: GoTrueMetaSecurity? + + public init(email: String, gotrueMetaSecurity: GoTrueMetaSecurity? = nil) { + self.email = email + self.gotrueMetaSecurity = gotrueMetaSecurity + } + + public enum CodingKeys: String, CodingKey { + case email + case gotrueMetaSecurity = "gotrue_meta_security" + } +} + +// MARK: - Encodable & Decodable + +private let dateFormatterWithFractionalSeconds = { () -> ISO8601DateFormatter in + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime, .withFractionalSeconds] + return formatter +}() + +private let dateFormatter = { () -> ISO8601DateFormatter in + let formatter = ISO8601DateFormatter() + formatter.formatOptions = [.withInternetDateTime] + return formatter +}() + +extension JSONDecoder { + public static let goTrue = { () -> JSONDecoder in + let decoder = JSONDecoder() + decoder.dateDecodingStrategy = .custom { decoder in + let container = try decoder.singleValueContainer() + let string = try container.decode(String.self) + + let supportedFormatters = [dateFormatterWithFractionalSeconds, dateFormatter] + + for formatter in supportedFormatters { + if let date = formatter.date(from: string) { + return date + } + } + + throw DecodingError.dataCorruptedError( + in: container, debugDescription: "Invalid date format: \(string)" + ) + } + return decoder + }() +} + +extension JSONEncoder { + public static let goTrue = { () -> JSONEncoder in + let encoder = JSONEncoder() + encoder.dateEncodingStrategy = .custom { date, encoder in + var container = encoder.singleValueContainer() + let string = dateFormatter.string(from: date) + try container.encode(string) + } + return encoder + }() +} diff --git a/Sources/GoTrue/Version.swift b/Sources/GoTrue/Version.swift new file mode 100644 index 00000000..21b131bc --- /dev/null +++ b/Sources/GoTrue/Version.swift @@ -0,0 +1 @@ +let version = "1.3.0" diff --git a/Tests/GoTrueTests/DecoderTests.swift b/Tests/GoTrueTests/DecoderTests.swift new file mode 100644 index 00000000..72941f83 --- /dev/null +++ b/Tests/GoTrueTests/DecoderTests.swift @@ -0,0 +1,42 @@ +import GoTrue +import SnapshotTesting +import XCTest + +final class InMemoryLocalStorage: GoTrueLocalStorage, @unchecked Sendable { + private let queue = DispatchQueue(label: "InMemoryLocalStorage") + private var storage: [String: Data] = [:] + + func store(key: String, value: Data) throws { + queue.sync { + storage[key] = value + } + } + + func retrieve(key: String) throws -> Data? { + queue.sync { + storage[key] + } + } + + func remove(key: String) throws { + queue.sync { + storage[key] = nil + } + } +} + +final class DecoderTests: XCTestCase { + func testDecodeUser() { + XCTAssertNoThrow( + try JSONDecoder.goTrue.decode(User.self, from: json(named: "user")) + ) + } + + func testDecodeSessionOrUser() { + XCTAssertNoThrow( + try JSONDecoder.goTrue.decode( + AuthResponse.self, from: json(named: "session") + ) + ) + } +} diff --git a/Tests/GoTrueTests/JWTTests.swift b/Tests/GoTrueTests/JWTTests.swift new file mode 100644 index 00000000..d6084725 --- /dev/null +++ b/Tests/GoTrueTests/JWTTests.swift @@ -0,0 +1,13 @@ +import XCTest + +@testable import GoTrue + +final class JWTTests: XCTestCase { + func testDecodeJWT() throws { + let token = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjoxNjQ4NjQwMDIxLCJzdWIiOiJmMzNkM2VjOS1hMmVlLTQ3YzQtODBlMS01YmQ5MTlmM2Q4YjgiLCJlbWFpbCI6ImhpQGJpbmFyeXNjcmFwaW5nLmNvIiwicGhvbmUiOiIiLCJhcHBfbWV0YWRhdGEiOnsicHJvdmlkZXIiOiJlbWFpbCIsInByb3ZpZGVycyI6WyJlbWFpbCJdfSwidXNlcl9tZXRhZGF0YSI6e30sInJvbGUiOiJhdXRoZW50aWNhdGVkIn0.CGr5zNE5Yltlbn_3Ms2cjSLs_AW9RKM3lxh7cTQrg0w" + let jwt = try decode(jwt: token) + let exp = try XCTUnwrap(jwt["exp"] as? TimeInterval) + XCTAssertEqual(exp, 1_648_640_021) + } +} diff --git a/Tests/GoTrueTests/MockHelpers.swift b/Tests/GoTrueTests/MockHelpers.swift new file mode 100644 index 00000000..94c294c3 --- /dev/null +++ b/Tests/GoTrueTests/MockHelpers.swift @@ -0,0 +1,15 @@ +import Foundation +import Mocker + +@testable import GoTrue + +func json(named name: String) -> Data { + let url = Bundle.module.url(forResource: name, withExtension: "json") + return try! Data(contentsOf: url!) +} + +extension Decodable { + init(fromMockNamed name: String) { + self = try! JSONDecoder.goTrue.decode(Self.self, from: json(named: name)) + } +} diff --git a/Tests/GoTrueTests/RequestsTests.swift b/Tests/GoTrueTests/RequestsTests.swift new file mode 100644 index 00000000..0e883623 --- /dev/null +++ b/Tests/GoTrueTests/RequestsTests.swift @@ -0,0 +1,305 @@ +// +// File.swift +// +// +// Created by Guilherme Souza on 07/10/23. +// + +import SnapshotTesting +import XCTest + +@testable import GoTrue + +struct UnimplementedError: Error {} + +final class RequestsTests: XCTestCase { + + var localStorage: InMemoryLocalStorage! + + func testSignUpWithEmailAndPassword() async { + let sut = makeSUT() + + await assert { + try await sut.signUp( + email: "example@mail.com", + password: "the.pass", + data: ["custom_key": .string("custom_value")], + redirectTo: URL(string: "https://supabase.com"), + captchaToken: "dummy-captcha" + ) + } + } + + func testSignUpWithPhoneAndPassword() async { + let sut = makeSUT() + await assert { + try await sut.signUp( + phone: "+1 202-918-2132", + password: "the.pass", + data: ["custom_key": .string("custom_value")], + captchaToken: "dummy-captcha" + ) + } + } + + func testSignInWithEmailAndPassword() async { + let sut = makeSUT() + await assert { + try await sut.signIn( + email: "example@mail.com", + password: "the.pass" + ) + } + } + + func testSignInWithPhoneAndPassword() async { + let sut = makeSUT() + await assert { + try await sut.signIn( + phone: "+1 202-918-2132", + password: "the.pass" + ) + } + } + + func testSignInWithIdToken() async { + let sut = makeSUT() + await assert { + try await sut.signInWithIdToken( + credentials: OpenIDConnectCredentials( + provider: .apple, + idToken: "id-token", + accessToken: "access-token", + nonce: "nonce", + gotrueMetaSecurity: GoTrueMetaSecurity( + captchaToken: "captcha-token" + ) + ) + ) + } + } + + func testSignInWithOTPUsingEmail() async { + let sut = makeSUT() + await assert { + try await sut.signInWithOTP( + email: "example@mail.com", + redirectTo: URL(string: "https://supabase.com"), + shouldCreateUser: true, + data: ["custom_key": .string("custom_value")], + captchaToken: "dummy-captcha" + ) + } + } + + func testSignInWithOTPUsingPhone() async { + let sut = makeSUT() + await assert { + try await sut.signInWithOTP( + phone: "+1 202-918-2132", + shouldCreateUser: true, + data: ["custom_key": .string("custom_value")], + captchaToken: "dummy-captcha" + ) + } + } + + func testGetOAuthSignInURL() throws { + let sut = makeSUT() + let url = try sut.getOAuthSignInURL( + provider: .github, scopes: "read,write", + redirectTo: URL(string: "https://dummy-url.com/redirect")!, + queryParams: [("extra_key", "extra_value")] + ) + XCTAssertEqual( + url, + URL( + string: + "http://localhost:54321/auth/v1/authorize?provider=github&scopes=read,write&redirect_to=https://dummy-url.com/redirect&extra_key=extra_value" + )! + ) + } + + func testRefreshSession() async { + let sut = makeSUT() + await assert { + try await sut.refreshSession(refreshToken: "refresh-token") + } + } + + #if !os(watchOS) + // Not working on watchOS. + func testSessionFromURL() async throws { + let sut = makeSUT(fetch: { request in + let authorizationHeader = request.allHTTPHeaderFields?["Authorization"] + XCTAssertEqual(authorizationHeader, "bearer accesstoken") + return (json(named: "user"), HTTPURLResponse()) + }) + + let url = URL( + string: + "https://dummy-url.com/callback#access_token=accesstoken&expires_in=60&refresh_token=refreshtoken&token_type=bearer" + )! + + let session = try await sut.session(from: url) + let expectedSession = Session( + accessToken: "accesstoken", + tokenType: "bearer", + expiresIn: 60, + refreshToken: "refreshtoken", + user: User(fromMockNamed: "user") + ) + XCTAssertEqual(session, expectedSession) + } + #endif + + func testSessionFromURLWithMissingComponent() async { + let sut = makeSUT() + let url = URL( + string: + "https://dummy-url.com/callback#access_token=accesstoken&expires_in=60&refresh_token=refreshtoken" + )! + + do { + _ = try await sut.session(from: url) + } catch let error as URLError { + XCTAssertEqual(error.code, .badURL) + } catch { + XCTFail("Unexpected error thrown: \(error.localizedDescription)") + } + } + + func testSetSessionWithAFutureExpirationDate() async throws { + let sut = makeSUT() + try localStorage.storeSession(.init(session: .validSession)) + + let accessToken = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjo0ODUyMTYzNTkzLCJzdWIiOiJmMzNkM2VjOS1hMmVlLTQ3YzQtODBlMS01YmQ5MTlmM2Q4YjgiLCJlbWFpbCI6ImhpQGJpbmFyeXNjcmFwaW5nLmNvIiwicGhvbmUiOiIiLCJhcHBfbWV0YWRhdGEiOnsicHJvdmlkZXIiOiJlbWFpbCIsInByb3ZpZGVycyI6WyJlbWFpbCJdfSwidXNlcl9tZXRhZGF0YSI6e30sInJvbGUiOiJhdXRoZW50aWNhdGVkIn0.UiEhoahP9GNrBKw_OHBWyqYudtoIlZGkrjs7Qa8hU7I" + + await assert { + try await sut.setSession(accessToken: accessToken, refreshToken: "dummy-refresh-token") + } + } + + func testSetSessionWithAExpiredToken() async throws { + let sut = makeSUT() + + let accessToken = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjoxNjQ4NjQwMDIxLCJzdWIiOiJmMzNkM2VjOS1hMmVlLTQ3YzQtODBlMS01YmQ5MTlmM2Q4YjgiLCJlbWFpbCI6ImhpQGJpbmFyeXNjcmFwaW5nLmNvIiwicGhvbmUiOiIiLCJhcHBfbWV0YWRhdGEiOnsicHJvdmlkZXIiOiJlbWFpbCIsInByb3ZpZGVycyI6WyJlbWFpbCJdfSwidXNlcl9tZXRhZGF0YSI6e30sInJvbGUiOiJhdXRoZW50aWNhdGVkIn0.CGr5zNE5Yltlbn_3Ms2cjSLs_AW9RKM3lxh7cTQrg0w" + + await assert { + try await sut.setSession(accessToken: accessToken, refreshToken: "dummy-refresh-token") + } + } + + func testSignOut() async { + let sut = makeSUT() + await assert { + try await sut.signOut() + } + } + + func testVerifyOTPUsingEmail() async { + let sut = makeSUT() + await assert { + try await sut.verifyOTP( + email: "example@mail.com", + token: "123456", + type: .magiclink, + redirectTo: URL(string: "https://supabase.com"), + captchaToken: "captcha-token" + ) + } + } + + func testVerifyOTPUsingPhone() async { + let sut = makeSUT() + await assert { + try await sut.verifyOTP( + phone: "+1 202-918-2132", + token: "123456", + type: .sms, + captchaToken: "captcha-token" + ) + } + } + + func testUpdateUser() async throws { + let sut = makeSUT() + try localStorage.storeSession(StoredSession(session: .validSession)) + await assert { + try await sut.update( + user: UserAttributes( + email: "example@mail.com", + phone: "+1 202-918-2132", + password: "another.pass", + emailChangeToken: "123456", + data: ["custom_key": .string("custom_value")] + ) + ) + } + } + + func testResetPasswordForEmail() async { + let sut = makeSUT() + await assert { + try await sut.resetPasswordForEmail( + "example@mail.com", + redirectTo: URL(string: "https://supabase.com"), + captchaToken: "captcha-token" + ) + } + } + + private func assert(_ block: () async throws -> Void) async { + do { + try await block() + } catch is UnimplementedError { + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + private func makeSUT( + record: Bool = false, + fetch: GoTrueClient.FetchHandler? = nil, + file: StaticString = #file, + testName: String = #function, + line: UInt = #line + ) -> GoTrueClient { + localStorage = InMemoryLocalStorage() + let encoder = JSONEncoder.goTrue + encoder.outputFormatting = .sortedKeys + + return GoTrueClient( + url: clientURL, + headers: ["apikey": "dummy.api.key"], + localStorage: localStorage, + encoder: encoder, + fetch: { request in + DispatchQueue.main.sync { + assertSnapshot( + of: request, as: .curl, record: record, file: file, testName: testName, line: line) + } + + if let fetch { + return try await fetch(request) + } + + throw UnimplementedError() + } + ) + } +} + +let clientURL = URL(string: "http://localhost:54321/auth/v1")! + +extension Session { + static let validSession = Session( + accessToken: "accesstoken", + tokenType: "bearer", + expiresIn: 120, + refreshToken: "refreshtoken", + user: User(fromMockNamed: "user") + ) +} diff --git a/Tests/GoTrueTests/Resources/session.json b/Tests/GoTrueTests/Resources/session.json new file mode 100644 index 00000000..24eeff1b --- /dev/null +++ b/Tests/GoTrueTests/Resources/session.json @@ -0,0 +1,37 @@ +{ + "access_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJhdXRoZW50aWNhdGVkIiwiZXhwIjoxNjQ4NjQwMDIxLCJzdWIiOiJmMzNkM2VjOS1hMmVlLTQ3YzQtODBlMS01YmQ5MTlmM2Q4YjgiLCJlbWFpbCI6Imd1aWxoZXJtZTJAZ3Jkcy5kZXYiLCJwaG9uZSI6IiIsImFwcF9tZXRhZGF0YSI6eyJwcm92aWRlciI6ImVtYWlsIiwicHJvdmlkZXJzIjpbImVtYWlsIl19LCJ1c2VyX21ldGFkYXRhIjp7fSwicm9sZSI6ImF1dGhlbnRpY2F0ZWQifQ.4lMvmz2pJkWu1hMsBgXP98Fwz4rbvFYl4VA9joRv6kY", + "token_type": "bearer", + "expires_in": 3600, + "refresh_token": "GGduTeu95GraIXQ56jppkw", + "user": { + "id": "f33d3ec9-a2ee-47c4-80e1-5bd919f3d8b8", + "aud": "authenticated", + "role": "authenticated", + "email": "guilherme@binaryscraping.co", + "email_confirmed_at": "2022-03-30T10:33:41.018575157Z", + "phone": "", + "last_sign_in_at": "2022-03-30T10:33:41.021531328Z", + "app_metadata": { + "provider": "email", + "providers": [ + "email" + ] + }, + "user_metadata": {}, + "identities": [ + { + "id": "f33d3ec9-a2ee-47c4-80e1-5bd919f3d8b8", + "user_id": "f33d3ec9-a2ee-47c4-80e1-5bd919f3d8b8", + "identity_data": { + "sub": "f33d3ec9-a2ee-47c4-80e1-5bd919f3d8b8" + }, + "provider": "email", + "last_sign_in_at": "2022-03-30T10:33:41.015557063Z", + "created_at": "2022-03-30T10:33:41.015612Z", + "updated_at": "2022-03-30T10:33:41.015616Z" + } + ], + "created_at": "2022-03-30T10:33:41.005433Z", + "updated_at": "2022-03-30T10:33:41.022688Z" + } +} diff --git a/Tests/GoTrueTests/Resources/signup-response.json b/Tests/GoTrueTests/Resources/signup-response.json new file mode 100644 index 00000000..7cbd6883 --- /dev/null +++ b/Tests/GoTrueTests/Resources/signup-response.json @@ -0,0 +1,30 @@ +{ + "id": "859f402d-b3de-4105-a1b9-932836d9193b", + "aud": "authenticated", + "role": "authenticated", + "email": "guilherme@grds.dev", + "phone": "", + "confirmation_sent_at": "2022-04-09T11:57:01.710600634Z", + "app_metadata": { + "provider": "email", + "providers": [ + "email" + ] + }, + "user_metadata": {}, + "identities": [ + { + "id": "859f402d-b3de-4105-a1b9-932836d9193b", + "user_id": "859f402d-b3de-4105-a1b9-932836d9193b", + "identity_data": { + "sub": "859f402d-b3de-4105-a1b9-932836d9193b" + }, + "provider": "email", + "last_sign_in_at": "2022-04-09T11:23:45.899902Z", + "created_at": "2022-04-09T11:23:45.899924Z", + "updated_at": "2022-04-09T11:23:45.899926Z" + } + ], + "created_at": "2022-04-09T11:23:45.874827Z", + "updated_at": "2022-04-09T11:57:01.720803Z" +} diff --git a/Tests/GoTrueTests/Resources/user.json b/Tests/GoTrueTests/Resources/user.json new file mode 100644 index 00000000..a9e81cdc --- /dev/null +++ b/Tests/GoTrueTests/Resources/user.json @@ -0,0 +1,32 @@ +{ + "id": "859f402d-b3de-4105-a1b9-932836d9193b", + "aud": "authenticated", + "role": "authenticated", + "email": "guilherme@grds.dev", + "phone": "", + "confirmation_sent_at": "2022-04-09T11:57:01.710600634Z", + "app_metadata": { + "provider": "email", + "providers": [ + "email" + ] + }, + "user_metadata": { + "referrer_id": null + }, + "identities": [ + { + "id": "859f402d-b3de-4105-a1b9-932836d9193b", + "user_id": "859f402d-b3de-4105-a1b9-932836d9193b", + "identity_data": { + "sub": "859f402d-b3de-4105-a1b9-932836d9193b" + }, + "provider": "email", + "last_sign_in_at": "2022-04-09T11:23:45Z", + "created_at": "2022-04-09T11:23:45.899924Z", + "updated_at": "2022-04-09T11:23:45.899926Z" + } + ], + "created_at": "2022-04-09T11:23:45.874827Z", + "updated_at": "2022-04-09T11:57:01.720803Z" +} diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testRefreshSession.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testRefreshSession.1.txt new file mode 100644 index 00000000..42eccab2 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testRefreshSession.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"refresh_token\":\"refresh-token\"}" \ + "http://localhost:54321/auth/v1/token?grant_type=refresh_token" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testResetPasswordForEmail.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testResetPasswordForEmail.1.txt new file mode 100644 index 00000000..924deff5 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testResetPasswordForEmail.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"email\":\"example@mail.com\",\"gotrue_meta_security\":{\"captcha_token\":\"captcha-token\"}}" \ + "http://localhost:54321/auth/v1/recover?redirect_to=https://supabase.com" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSessionFromURL.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSessionFromURL.1.txt new file mode 100644 index 00000000..1b620315 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSessionFromURL.1.txt @@ -0,0 +1,5 @@ +curl \ + --header "Authorization: bearer accesstoken" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + "http://localhost:54321/auth/v1/user" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAExpiredToken.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAExpiredToken.1.txt new file mode 100644 index 00000000..ba0cd1c1 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAExpiredToken.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"refresh_token\":\"dummy-refresh-token\"}" \ + "http://localhost:54321/auth/v1/token?grant_type=refresh_token" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAFutureExpirationDate.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAFutureExpirationDate.1.txt new file mode 100644 index 00000000..e903025f --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSetSessionWithAFutureExpirationDate.1.txt @@ -0,0 +1,5 @@ +curl \ + --header "Authorization: Bearer accesstoken" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + "http://localhost:54321/auth/v1/user" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithEmailAndPassword.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithEmailAndPassword.1.txt new file mode 100644 index 00000000..c588fa7e --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithEmailAndPassword.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"email\":\"example@mail.com\",\"password\":\"the.pass\"}" \ + "http://localhost:54321/auth/v1/token?grant_type=password" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithIdToken.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithIdToken.1.txt new file mode 100644 index 00000000..e97d4bfe --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithIdToken.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"access_token\":\"access-token\",\"gotrue_meta_security\":{\"captcha_token\":\"captcha-token\"},\"id_token\":\"id-token\",\"nonce\":\"nonce\",\"provider\":\"apple\"}" \ + "http://localhost:54321/auth/v1/token?grant_type=id_token" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingEmail.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingEmail.1.txt new file mode 100644 index 00000000..2c7a3957 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingEmail.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"create_user\":true,\"data\":{\"custom_key\":\"custom_value\"},\"email\":\"example@mail.com\",\"gotrue_meta_security\":{\"captcha_token\":\"dummy-captcha\"}}" \ + "http://localhost:54321/auth/v1/otp" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingPhone.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingPhone.1.txt new file mode 100644 index 00000000..105daf56 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithOTPUsingPhone.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"create_user\":true,\"data\":{\"custom_key\":\"custom_value\"},\"gotrue_meta_security\":{\"captcha_token\":\"dummy-captcha\"},\"phone\":\"+1 202-918-2132\"}" \ + "http://localhost:54321/auth/v1/otp" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithPhoneAndPassword.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithPhoneAndPassword.1.txt new file mode 100644 index 00000000..4cd170ac --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignInWithPhoneAndPassword.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"password\":\"the.pass\",\"phone\":\"+1 202-918-2132\"}" \ + "http://localhost:54321/auth/v1/token?grant_type=password" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithEmailAndPassword.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithEmailAndPassword.1.txt new file mode 100644 index 00000000..98ebb920 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithEmailAndPassword.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"data\":{\"custom_key\":\"custom_value\"},\"email\":\"example@mail.com\",\"gotrue_meta_security\":{\"captcha_token\":\"dummy-captcha\"},\"password\":\"the.pass\"}" \ + "http://localhost:54321/auth/v1/signup?redirect_to=https://supabase.com" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithPhoneAndPassword.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithPhoneAndPassword.1.txt new file mode 100644 index 00000000..0353c874 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testSignUpWithPhoneAndPassword.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"data\":{\"custom_key\":\"custom_value\"},\"gotrue_meta_security\":{\"captcha_token\":\"dummy-captcha\"},\"password\":\"the.pass\",\"phone\":\"+1 202-918-2132\"}" \ + "http://localhost:54321/auth/v1/signup" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testUpdateUser.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testUpdateUser.1.txt new file mode 100644 index 00000000..d6f31a72 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testUpdateUser.1.txt @@ -0,0 +1,8 @@ +curl \ + --request PUT \ + --header "Authorization: Bearer accesstoken" \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"data\":{\"custom_key\":\"custom_value\"},\"email\":\"example@mail.com\",\"email_change_token\":\"123456\",\"password\":\"another.pass\",\"phone\":\"+1 202-918-2132\"}" \ + "http://localhost:54321/auth/v1/user" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingEmail.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingEmail.1.txt new file mode 100644 index 00000000..0a64bf73 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingEmail.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"email\":\"example@mail.com\",\"gotrue_meta_security\":{\"captcha_token\":\"captcha-token\"},\"token\":\"123456\",\"type\":\"magiclink\"}" \ + "http://localhost:54321/auth/v1/verify?redirect_to=https://supabase.com" \ No newline at end of file diff --git a/Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingPhone.1.txt b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingPhone.1.txt new file mode 100644 index 00000000..41fd2ad7 --- /dev/null +++ b/Tests/GoTrueTests/__Snapshots__/RequestsTests/testVerifyOTPUsingPhone.1.txt @@ -0,0 +1,7 @@ +curl \ + --request POST \ + --header "Content-Type: application/json" \ + --header "X-Client-Info: gotrue-swift/1.3.0" \ + --header "apikey: dummy.api.key" \ + --data "{\"gotrue_meta_security\":{\"captcha_token\":\"captcha-token\"},\"phone\":\"+1 202-918-2132\",\"token\":\"123456\",\"type\":\"sms\"}" \ + "http://localhost:54321/auth/v1/verify" \ No newline at end of file diff --git a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved index f690ce9a..e2b4dc23 100644 --- a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,14 +1,5 @@ { "pins" : [ - { - "identity" : "gotrue-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/gotrue-swift", - "state" : { - "branch" : "dependency-free", - "revision" : "6dc6d577ce88613cd1ae17b2367228e4d684b101" - } - }, { "identity" : "keychainaccess", "kind" : "remoteSourceControl", From 093c73384ceacbc25fe09d35d701584651c9c146 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:25:32 -0300 Subject: [PATCH 4/7] Add realtime to repo --- Package.resolved | 9 - Package.swift | 22 +- Sources/Realtime/Channel.swift | 822 +++++++++++++ Sources/Realtime/Defaults.swift | 261 +++++ Sources/Realtime/Delegated.swift | 102 ++ Sources/Realtime/HeartbeatTimer.swift | 133 +++ Sources/Realtime/Message.swift | 89 ++ Sources/Realtime/Presence.swift | 443 +++++++ Sources/Realtime/Push.swift | 265 +++++ Sources/Realtime/RealtimeClient.swift | 1015 +++++++++++++++++ Sources/Realtime/SynchronizedArray.swift | 33 + Sources/Realtime/TimeoutTimer.swift | 108 ++ Sources/Realtime/Transport.swift | 300 +++++ Tests/RealtimeTests/ChannelTopicTests.swift | 19 + Tests/RealtimeTests/RealtimeTests.swift | 128 +++ .../xcshareddata/swiftpm/Package.resolved | 9 - 16 files changed, 3726 insertions(+), 32 deletions(-) create mode 100644 Sources/Realtime/Channel.swift create mode 100644 Sources/Realtime/Defaults.swift create mode 100644 Sources/Realtime/Delegated.swift create mode 100644 Sources/Realtime/HeartbeatTimer.swift create mode 100644 Sources/Realtime/Message.swift create mode 100644 Sources/Realtime/Presence.swift create mode 100644 Sources/Realtime/Push.swift create mode 100644 Sources/Realtime/RealtimeClient.swift create mode 100644 Sources/Realtime/SynchronizedArray.swift create mode 100644 Sources/Realtime/TimeoutTimer.swift create mode 100644 Sources/Realtime/Transport.swift create mode 100644 Tests/RealtimeTests/ChannelTopicTests.swift create mode 100644 Tests/RealtimeTests/RealtimeTests.swift diff --git a/Package.resolved b/Package.resolved index 4cea5e65..0e82514d 100644 --- a/Package.resolved +++ b/Package.resolved @@ -18,15 +18,6 @@ "version" : "3.0.1" } }, - { - "identity" : "realtime-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/realtime-swift.git", - "state" : { - "revision" : "0b985c687fe963f6bd818ff77a35c27247b98bb4", - "version" : "0.0.2" - } - }, { "identity" : "storage-swift", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 117bc42f..afe27098 100644 --- a/Package.swift +++ b/Package.swift @@ -17,7 +17,8 @@ var package = Package( .library(name: "Functions", targets: ["Functions"]), .library(name: "GoTrue", targets: ["GoTrue"]), .library(name: "PostgREST", targets: ["PostgREST"]), - .library(name: "Supabase", targets: ["Supabase", "Functions", "PostgREST", "GoTrue"]), + .library(name: "Realtime", targets: ["Realtime"]), + .library(name: "Supabase", targets: ["Supabase", "Functions", "PostgREST", "GoTrue", "Realtime"]), ], dependencies: [ .package(url: "https://github.com/kishikawakatsumi/KeychainAccess", from: "4.2.2"), @@ -40,14 +41,9 @@ var package = Package( "Mocker", .product(name: "SnapshotTesting", package: "swift-snapshot-testing"), ], - resources: [ - .process("Resources") - ] - ), - .target( - name: "PostgREST", - dependencies: [] + resources: [.process("Resources")] ), + .target(name: "PostgREST"), .testTarget( name: "PostgRESTTests", dependencies: [ @@ -58,17 +54,17 @@ var package = Package( condition: .when(platforms: [.iOS, .macOS, .tvOS]) ), ], - exclude: [ - "__Snapshots__" - ] + exclude: ["__Snapshots__"] ), .testTarget(name: "PostgRESTIntegrationTests", dependencies: ["PostgREST"]), + .target(name: "Realtime"), + .testTarget(name: "RealtimeTests", dependencies: ["Realtime"]), .target( name: "Supabase", dependencies: [ "GoTrue", .product(name: "SupabaseStorage", package: "storage-swift"), - .product(name: "Realtime", package: "realtime-swift"), + "Realtime", "PostgREST", "Functions", ] @@ -81,7 +77,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { package.dependencies.append( contentsOf: [ .package(path: "../storage-swift"), - .package(path: "../realtime-swift"), ] ) } else { @@ -91,7 +86,6 @@ if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { url: "https://github.com/supabase-community/storage-swift.git", branch: "dependency-free" ), - .package(url: "https://github.com/supabase-community/realtime-swift.git", from: "0.0.2"), ] ) } diff --git a/Sources/Realtime/Channel.swift b/Sources/Realtime/Channel.swift new file mode 100644 index 00000000..c22a5215 --- /dev/null +++ b/Sources/Realtime/Channel.swift @@ -0,0 +1,822 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation +import Swift + +/// Container class of bindings to the channel +struct Binding { + // The event that the Binding is bound to + let event: ChannelEvent + + // The reference number of the Binding + let ref: Int + + // The callback to be triggered + let callback: Delegated +} + +/// +/// Represents a Channel which is bound to a topic +/// +/// A Channel can bind to multiple events on a given topic and +/// be informed when those events occur within a topic. +/// +/// ### Example: +/// +/// let channel = socket.channel("room:123", params: ["token": "Room Token"]) +/// channel.on("new_msg") { payload in print("Got message", payload") } +/// channel.push("new_msg, payload: ["body": "This is a message"]) +/// .receive("ok") { payload in print("Sent message", payload) } +/// .receive("error") { payload in print("Send failed", payload) } +/// .receive("timeout") { payload in print("Networking issue...", payload) } +/// +/// channel.join() +/// .receive("ok") { payload in print("Channel Joined", payload) } +/// .receive("error") { payload in print("Failed ot join", payload) } +/// .receive("timeout") { payload in print("Networking issue...", payload) } +/// + +public class Channel { + /// The topic of the Channel. e.g. "rooms:friends" + public let topic: ChannelTopic + + /// The params sent when joining the channel + public var params: Payload { + didSet { self.joinPush.payload = params } + } + + /// The Socket that the channel belongs to + weak var socket: RealtimeClient? + + /// Current state of the Channel + var state: ChannelState + + /// Collection of event bindings + var syncBindingsDel: SynchronizedArray + + /// Tracks event binding ref counters + var bindingRef: Int + + /// Timout when attempting to join a Channel + var timeout: TimeInterval + + /// Set to true once the channel calls .join() + var joinedOnce: Bool + + /// Push to send when the channel calls .join() + var joinPush: Push! + + /// Buffer of Pushes that will be sent once the Channel's socket connects + var pushBuffer: [Push] + + /// Timer to attempt to rejoin + var rejoinTimer: TimeoutTimer + + /// Refs of stateChange hooks + var stateChangeRefs: [String] + + /// Initialize a Channel + /// - parameter topic: Topic of the Channel + /// - parameter options: Optional. Options to configure channel broadcast and presence. Leave nil for postgres channel. + /// - parameter socket: Socket that the channel is a part of + convenience init(topic: ChannelTopic, options: ChannelOptions? = nil, socket: RealtimeClient) { + self.init(topic: topic, params: options?.params ?? [:], socket: socket) + } + + /// Initialize a Channel + /// + /// - parameter topic: Topic of the Channel + /// - parameter params: Optional. Parameters to send when joining. + /// - parameter socket: Socket that the channel is a part of + init(topic: ChannelTopic, params: [String: Any], socket: RealtimeClient) { + state = ChannelState.closed + self.topic = topic + self.params = params + self.socket = socket + syncBindingsDel = SynchronizedArray() + bindingRef = 0 + timeout = socket.timeout + joinedOnce = false + pushBuffer = [] + stateChangeRefs = [] + rejoinTimer = TimeoutTimer() + + // Setup Timer delgation + rejoinTimer.callback + .delegate(to: self) { (self) in + if self.socket?.isConnected == true { self.rejoin() } + } + + rejoinTimer.timerCalculation + .delegate(to: self) { (self, tries) -> TimeInterval in + self.socket?.rejoinAfter(tries) ?? 5.0 + } + + // Respond to socket events + let onErrorRef = self.socket?.delegateOnError( + to: self, + callback: { (self, _) in + self.rejoinTimer.reset() + } + ) + if let ref = onErrorRef { stateChangeRefs.append(ref) } + + let onOpenRef = self.socket?.delegateOnOpen( + to: self, + callback: { (self) in + self.rejoinTimer.reset() + if self.isErrored { self.rejoin() } + } + ) + if let ref = onOpenRef { stateChangeRefs.append(ref) } + + // Setup Push Event to be sent when joining + joinPush = Push( + channel: self, + event: ChannelEvent.join, + payload: self.params, + timeout: timeout + ) + + /// Handle when a response is received after join() + joinPush.delegateReceive(.ok, to: self) { (self, _) in + // Mark the Channel as joined + self.state = ChannelState.joined + + // Reset the timer, preventing it from attempting to join again + self.rejoinTimer.reset() + + // Send and buffered messages and clear the buffer + self.pushBuffer.forEach { $0.send() } + self.pushBuffer = [] + } + + // Perform if Channel errors while attempting to joi + joinPush.delegateReceive(.error, to: self) { (self, _) in + self.state = .errored + if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + } + + // Handle when the join push times out when sending after join() + joinPush.delegateReceive(.timeout, to: self) { (self, _) in + // log that the channel timed out + self.socket?.logItems( + "channel", "timeout \(self.topic) \(self.joinRef ?? "") after \(self.timeout)s" + ) + + // Send a Push to the server to leave the channel + let leavePush = Push( + channel: self, + event: ChannelEvent.leave, + timeout: self.timeout + ) + leavePush.send() + + // Mark the Channel as in an error and attempt to rejoin if socket is connected + self.state = ChannelState.errored + self.joinPush.reset() + + if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + } + + /// Perfom when the Channel has been closed + delegateOnClose(to: self) { (self, _) in + // Reset any timer that may be on-going + self.rejoinTimer.reset() + + // Log that the channel was left + self.socket?.logItems( + "channel", "close topic: \(self.topic) joinRef: \(self.joinRef ?? "nil")" + ) + + // Mark the channel as closed and remove it from the socket + self.state = ChannelState.closed + self.socket?.remove(self) + } + + /// Perfom when the Channel errors + delegateOnError(to: self) { (self, message) in + // Log that the channel received an error + self.socket?.logItems( + "channel", "error topic: \(self.topic) joinRef: \(self.joinRef ?? "nil") mesage: \(message)" + ) + + // If error was received while joining, then reset the Push + if self.isJoining { + // Make sure that the "phx_join" isn't buffered to send once the socket + // reconnects. The channel will send a new join event when the socket connects. + if let safeJoinRef = self.joinRef { + self.socket?.removeFromSendBuffer(ref: safeJoinRef) + } + + // Reset the push to be used again later + self.joinPush.reset() + } + + // Mark the channel as errored and attempt to rejoin if socket is currently connected + self.state = ChannelState.errored + if self.socket?.isConnected == true { self.rejoinTimer.scheduleTimeout() } + } + + // Perform when the join reply is received + delegateOn(ChannelEvent.reply, to: self) { (self, message) in + // Trigger bindings + self.trigger( + event: ChannelEvent.channelReply(message.ref), + payload: message.rawPayload, + ref: message.ref, + joinRef: message.joinRef + ) + } + } + + deinit { + rejoinTimer.reset() + } + + /// Overridable message hook. Receives all events for specialized message + /// handling before dispatching to the channel callbacks. + /// + /// - parameter msg: The Message received by the client from the server + /// - return: Must return the message, modified or unmodified + public var onMessage: (_ message: Message) -> Message = { message in + message + } + + /// Joins the channel + /// + /// - parameter timeout: Optional. Defaults to Channel's timeout + /// - return: Push event + @discardableResult + public func join(timeout: TimeInterval? = nil) -> Push { + guard !joinedOnce else { + fatalError( + "tried to join multiple times. 'join' " + + "can only be called a single time per channel instance") + } + + // Join the Channel + if let safeTimeout = timeout { self.timeout = safeTimeout } + + joinedOnce = true + rejoin() + return joinPush + } + + /// Hook into when the Channel is closed. Does not handle retain cycles. + /// Use `delegateOnClose(to:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// channel.onClose() { [weak self] message in + /// self?.print("Channel \(message.topic) has closed" + /// } + /// + /// - parameter callback: Called when the Channel closes + /// - return: Ref counter of the subscription. See `func off()` + @discardableResult + public func onClose(_ callback: @escaping ((Message) -> Void)) -> Int { + return on(ChannelEvent.close, callback: callback) + } + + /// Hook into when the Channel is closed. Automatically handles retain + /// cycles. Use `onClose()` to handle yourself. + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// channel.delegateOnClose(to: self) { (self, message) in + /// self.print("Channel \(message.topic) has closed" + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Channel closes + /// - return: Ref counter of the subscription. See `func off()` + @discardableResult + public func delegateOnClose( + to owner: Target, + callback: @escaping ((Target, Message) -> Void) + ) -> Int { + return delegateOn(ChannelEvent.close, to: owner, callback: callback) + } + + /// Hook into when the Channel receives an Error. Does not handle retain + /// cycles. Use `delegateOnError(to:)` for automatic handling of retain + /// cycles. + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// channel.onError() { [weak self] (message) in + /// self?.print("Channel \(message.topic) has errored" + /// } + /// + /// - parameter callback: Called when the Channel closes + /// - return: Ref counter of the subscription. See `func off()` + @discardableResult + public func onError(_ callback: @escaping ((_ message: Message) -> Void)) -> Int { + return on(ChannelEvent.error, callback: callback) + } + + /// Hook into when the Channel receives an Error. Automatically handles + /// retain cycles. Use `onError()` to handle yourself. + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// channel.delegateOnError(to: self) { (self, message) in + /// self.print("Channel \(message.topic) has closed" + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Channel closes + /// - return: Ref counter of the subscription. See `func off()` + @discardableResult + public func delegateOnError( + to owner: Target, + callback: @escaping ((Target, Message) -> Void) + ) -> Int { + return delegateOn(ChannelEvent.error, to: owner, callback: callback) + } + + /// Subscribes on channel events. Does not handle retain cycles. Use + /// `delegateOn(_:, to:)` for automatic handling of retain cycles. + /// + /// Subscription returns a ref counter, which can be used later to + /// unsubscribe the exact event listener + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// let ref1 = channel.on("event") { [weak self] (message) in + /// self?.print("do stuff") + /// } + /// let ref2 = channel.on("event") { [weak self] (message) in + /// self?.print("do other stuff") + /// } + /// channel.off("event", ref1) + /// + /// Since unsubscription of ref1, "do stuff" won't print, but "do other + /// stuff" will keep on printing on the "event" + /// + /// - parameter event: Event to receive + /// - parameter callback: Called with the event's message + /// - return: Ref counter of the subscription. See `func off()` + @discardableResult + public func on(_ event: ChannelEvent, callback: @escaping ((Message) -> Void)) -> Int { + var delegated = Delegated() + delegated.manuallyDelegate(with: callback) + + return on(event, delegated: delegated) + } + + /// Subscribes on channel events. Automatically handles retain cycles. Use + /// `on()` to handle yourself. + /// + /// Subscription returns a ref counter, which can be used later to + /// unsubscribe the exact event listener + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// let ref1 = channel.delegateOn("event", to: self) { (self, message) in + /// self?.print("do stuff") + /// } + /// let ref2 = channel.delegateOn("event", to: self) { (self, message) in + /// self?.print("do other stuff") + /// } + /// channel.off("event", ref1) + /// + /// Since unsubscription of ref1, "do stuff" won't print, but "do other + /// stuff" will keep on printing on the "event" + /// + /// - parameter event: Event to receive + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called with the event's message + /// - return: Ref counter of the subscription. See `func off()` + @discardableResult + public func delegateOn( + _ event: ChannelEvent, + to owner: Target, + callback: @escaping ((Target, Message) -> Void) + ) -> Int { + var delegated = Delegated() + delegated.delegate(to: owner, with: callback) + + return on(event, delegated: delegated) + } + + /// Shared method between `on` and `manualOn` + @discardableResult + private func on(_ event: ChannelEvent, delegated: Delegated) -> Int { + let ref = bindingRef + bindingRef = ref + 1 + + syncBindingsDel.append(Binding(event: event, ref: ref, callback: delegated)) + return ref + } + + /// Unsubscribes from a channel event. If a `ref` is given, only the exact + /// listener will be removed. Else all listeners for the `event` will be + /// removed. + /// + /// Example: + /// + /// let channel = socket.channel("topic") + /// let ref1 = channel.on("event") { _ in print("ref1 event" } + /// let ref2 = channel.on("event") { _ in print("ref2 event" } + /// let ref3 = channel.on("other_event") { _ in print("ref3 other" } + /// let ref4 = channel.on("other_event") { _ in print("ref4 other" } + /// channel.off("event", ref1) + /// channel.off("other_event") + /// + /// After this, only "ref2 event" will be printed if the channel receives + /// "event" and nothing is printed if the channel receives "other_event". + /// + /// - parameter event: Event to unsubscribe from + /// - paramter ref: Ref counter returned when subscribing. Can be omitted + public func off(_ event: ChannelEvent, ref: Int? = nil) { + syncBindingsDel.removeAll { bind -> Bool in + bind.event == event && (ref == nil || ref == bind.ref) + } + } + + /// Push a payload to the Channel + /// + /// Example: + /// + /// channel + /// .push("event", payload: ["message": "hello") + /// .receive("ok") { _ in { print("message sent") } + /// + /// - parameter event: Event to push + /// - parameter payload: Payload to push + /// - parameter timeout: Optional timeout + @discardableResult + public func push( + _ event: ChannelEvent, + payload: Payload, + timeout: TimeInterval = Defaults.timeoutInterval + ) -> Push { + guard joinedOnce else { + fatalError( + "Tried to push \(event) to \(topic) before joining. Use channel.join() before pushing events" + ) + } + + let pushEvent = Push( + channel: self, + event: event, + payload: payload, + timeout: timeout + ) + if canPush { + pushEvent.send() + } else { + pushEvent.startTimeout() + pushBuffer.append(pushEvent) + } + + return pushEvent + } + + /// Leaves the channel + /// + /// Unsubscribes from server events, and instructs channel to terminate on + /// server + /// + /// Triggers onClose() hooks + /// + /// To receive leave acknowledgements, use the a `receive` + /// hook to bind to the server ack, ie: + /// + /// Example: + //// + /// channel.leave().receive("ok") { _ in { print("left") } + /// + /// - parameter timeout: Optional timeout + /// - return: Push that can add receive hooks + @discardableResult + public func leave(timeout: TimeInterval = Defaults.timeoutInterval) -> Push { + // If attempting a rejoin during a leave, then reset, cancelling the rejoin + rejoinTimer.reset() + + // Now set the state to leaving + state = .leaving + + /// Delegated callback for a successful or a failed channel leave + var onCloseDelegate = Delegated() + onCloseDelegate.delegate(to: self) { (self, _) in + self.socket?.logItems("channel", "leave \(self.topic)") + + // Triggers onClose() hooks + self.trigger(event: ChannelEvent.close, payload: ["reason": "leave"]) + } + + // Push event to send to the server + let leavePush = Push( + channel: self, + event: ChannelEvent.leave, + timeout: timeout + ) + + // Perform the same behavior if successfully left the channel + // or if sending the event timed out + leavePush + .receive(.ok, delegated: onCloseDelegate) + .receive(.timeout, delegated: onCloseDelegate) + leavePush.send() + + // If the Channel cannot send push events, trigger a success locally + if !canPush { leavePush.trigger(.ok, payload: [:]) } + + // Return the push so it can be bound to + return leavePush + } + + /// Overridable message hook. Receives all events for specialized message + /// handling before dispatching to the channel callbacks. + /// + /// - parameter event: The event the message was for + /// - parameter payload: The payload for the message + /// - parameter ref: The reference of the message + /// - return: Must return the payload, modified or unmodified + public func onMessage(callback: @escaping (Message) -> Message) { + onMessage = callback + } + + // ---------------------------------------------------------------------- + + // MARK: - Internal + + // ---------------------------------------------------------------------- + /// Checks if an event received by the Socket belongs to this Channel + func isMember(_ message: Message) -> Bool { + // Return false if the message's topic does not match the Channel's topic + guard message.topic == topic else { return false } + + guard + let safeJoinRef = message.joinRef, + safeJoinRef != joinRef, + message.event.isLifecyleEvent + else { return true } + + socket?.logItems( + "channel", "dropping outdated message", message.topic, message.event, message.rawPayload, + safeJoinRef + ) + return false + } + + /// Sends the payload to join the Channel + func sendJoin(_ timeout: TimeInterval) { + state = ChannelState.joining + joinPush.resend(timeout) + } + + /// Rejoins the channel + func rejoin(_ timeout: TimeInterval? = nil) { + // Do not attempt to rejoin if the channel is in the process of leaving + guard !isLeaving else { return } + + // Leave potentially duplicate channels + socket?.leaveOpenTopic(topic: topic) + + // Send the joinPush + sendJoin(timeout ?? self.timeout) + } + + /// Triggers an event to the correct event bindings created by + /// `channel.on("event")`. + /// + /// - parameter message: Message to pass to the event bindings + func trigger(_ message: Message) { + let handledMessage = onMessage(message) + + syncBindingsDel + .filter { $0.event == message.event } + .forEach { $0.callback.call(handledMessage) } + } + + /// Triggers an event to the correct event bindings created by + //// `channel.on("event")`. + /// + /// - parameter event: Event to trigger + /// - parameter payload: Payload of the event + /// - parameter ref: Ref of the event. Defaults to empty + /// - parameter joinRef: Ref of the join event. Defaults to nil + func trigger( + event: ChannelEvent, + payload: Payload = [:], + ref: String = "", + joinRef: String? = nil + ) { + let message = Message( + ref: ref, + topic: topic, + event: event, + payload: payload, + joinRef: joinRef ?? self.joinRef + ) + trigger(message) + } + + /// The Ref send during the join message. + var joinRef: String? { + return joinPush.ref + } + + /// - return: True if the Channel can push messages, meaning the socket + /// is connected and the channel is joined + var canPush: Bool { + return socket?.isConnected == true && isJoined + } +} + +// ---------------------------------------------------------------------- + +// MARK: - Public API + +// ---------------------------------------------------------------------- +extension Channel { + /// - return: True if the Channel has been closed + public var isClosed: Bool { + return state == .closed + } + + /// - return: True if the Channel experienced an error + public var isErrored: Bool { + return state == .errored + } + + /// - return: True if the channel has joined + public var isJoined: Bool { + return state == .joined + } + + /// - return: True if the channel has requested to join + public var isJoining: Bool { + return state == .joining + } + + /// - return: True if the channel has requested to leave + public var isLeaving: Bool { + return state == .leaving + } +} +// ---------------------------------------------------------------------- + +// MARK: - Codable Payload + +// ---------------------------------------------------------------------- + +extension Payload { + + /// Initializes a payload from a given value + /// - parameter value: The value to encode + /// - parameter encoder: The encoder to use to encode the payload + /// - throws: Throws an error if the payload cannot be encoded + init(_ value: T, encoder: JSONEncoder = Defaults.encoder) throws { + let data = try encoder.encode(value) + self = try JSONSerialization.jsonObject(with: data, options: .allowFragments) as! Payload + } + + /// Decodes the payload to a given type + /// - parameter type: The type to decode to + /// - parameter decoder: The decoder to use to decode the payload + /// - returns: The decoded payload + /// - throws: Throws an error if the payload cannot be decoded + public func decode( + to type: T.Type = T.self, decoder: JSONDecoder = Defaults.decoder + ) throws -> T { + let data = try JSONSerialization.data(withJSONObject: self) + return try decoder.decode(type, from: data) + } + +} + +// ---------------------------------------------------------------------- + +// MARK: - Broadcast API + +// ---------------------------------------------------------------------- + +/// Represents the payload of a broadcast message +public struct BroadcastPayload { + public let type: String + public let event: String + public let payload: Payload +} + +extension Channel { + /// Broadcasts the payload to all other members of the channel + /// - parameter event: The event to broadcast + /// - parameter payload: The payload to broadcast + @discardableResult + public func broadcast(event: String, payload: Payload) -> Push { + self.push( + .broadcast, + payload: [ + "type": "broadcast", + "event": event, + "payload": payload, + ]) + } + + /// Broadcasts the encodable payload to all other members of the channel + /// - parameter event: The event to broadcast + /// - parameter payload: The payload to broadcast + /// - parameter encoder: The encoder to use to encode the payload + /// - throws: Throws an error if the payload cannot be encoded + @discardableResult + public func broadcast(event: String, payload: Encodable, encoder: JSONEncoder = Defaults.encoder) + throws -> Push + { + self.broadcast(event: event, payload: try Payload(payload)) + } + + /// Subscribes to broadcast events. Does not handle retain cycles. + /// + /// Example: + /// + /// let ref = channel.onBroadcast { [weak self] (message,broadcast) in + /// print(broadcast.event, broadcast.payload) + /// } + /// channel.off(.broadcast, ref1) + /// + /// Subscription returns a ref counter, which can be used later to + /// unsubscribe the exact event listener + /// - parameter callback: Called with the broadcast payload + /// - returns: Ref counter of the subscription. See `func off()` + @discardableResult + public func onBroadcast(callback: @escaping (Message, BroadcastPayload) -> Void) -> Int { + self.on( + .broadcast, + callback: { message in + let payload = BroadcastPayload( + type: message.payload["type"] as! String, event: message.payload["event"] as! String, + payload: message.payload["payload"] as! Payload) + callback(message, payload) + }) + } + +} +// ---------------------------------------------------------------------- + +// MARK: - Presence API + +// ---------------------------------------------------------------------- + +extension Channel { + /// Share presence state, available to all channel members via sync + /// - parameter payload: The payload to broadcast + @discardableResult + public func track(payload: Payload) -> Push { + self.push( + .presence, + payload: [ + "type": "presence", + "event": "track", + "payload": payload, + ]) + } + + /// Share presence state, available to all channel members via sync + /// - parameter payload: The payload to broadcast + /// - parameter encoder: The encoder to use to encode the payload + /// - throws: Throws an error if the payload cannot be encoded + @discardableResult + public func track(payload: Encodable, encoder: JSONEncoder = Defaults.encoder) throws -> Push { + self.track(payload: try Payload(payload)) + } + + /// Remove presence state for given channel + @discardableResult + public func untrack() -> Push { + self.push( + .presence, + payload: [ + "type": "presence", + "event": "untrack", + ]) + } +} diff --git a/Sources/Realtime/Defaults.swift b/Sources/Realtime/Defaults.swift new file mode 100644 index 00000000..bfc4e55d --- /dev/null +++ b/Sources/Realtime/Defaults.swift @@ -0,0 +1,261 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +/// A collection of default values and behaviors used across the Client +public enum Defaults { + /// Default timeout when sending messages + public static let timeoutInterval: TimeInterval = 10.0 + + /// Default interval to send heartbeats on + public static let heartbeatInterval: TimeInterval = 30.0 + + /// Default maximum amount of time which the system may delay heartbeat events in order to minimize power usage + public static let heartbeatLeeway: DispatchTimeInterval = .milliseconds(10) + + /// Default reconnect algorithm for the socket + public static let reconnectSteppedBackOff: (Int) -> TimeInterval = { tries in + tries > 9 ? 5.0 : [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0, 2.0][tries - 1] + } + + /** Default rejoin algorithm for individual channels */ + public static let rejoinSteppedBackOff: (Int) -> TimeInterval = { tries in + tries > 3 ? 10 : [1, 2, 5][tries - 1] + } + + public static let vsn = "2.0.0" + + /// Default encoder + public static let encoder: JSONEncoder = JSONEncoder() + + /// Default encode function, utilizing JSONSerialization.data + public static let encode: (Any) -> Data = { json in + assert(JSONSerialization.isValidJSONObject(json), "Invalid JSON object") + return + try! JSONSerialization + .data( + withJSONObject: json, + options: JSONSerialization.WritingOptions() + ) + } + + /// Default decoder + public static let decoder: JSONDecoder = JSONDecoder() + + /// Default decode function, utilizing JSONSerialization.jsonObject + public static let decode: (Data) -> Any? = { data in + guard + let json = + try? JSONSerialization + .jsonObject( + with: data, + options: JSONSerialization.ReadingOptions() + ) + else { return nil } + return json + } + + public static let heartbeatQueue: DispatchQueue = .init( + label: "com.phoenix.socket.heartbeat") +} + +/// Represents the multiple states that a Channel can be in +/// throughout it's lifecycle. +public enum ChannelState: String { + case closed + case errored + case joined + case joining + case leaving +} + +/// Represents the different events that can be sent through +/// a channel regarding a Channel's lifecycle or +/// that can be registered to be notified of. +public enum ChannelEvent: RawRepresentable { + public enum Presence: String { + case state + case diff + } + + case heartbeat + case join + case leave + case reply + case error + case close + + case all + case insert + case update + case delete + + case channelReply(String) + + case broadcast + + case presence + case presenceState + case presenceDiff + + public var rawValue: String { + switch self { + case .heartbeat: return "heartbeat" + case .join: return "phx_join" + case .leave: return "phx_leave" + case .reply: return "phx_reply" + case .error: return "phx_error" + case .close: return "phx_close" + + case .all: return "*" + case .insert: return "insert" + case .update: return "update" + case .delete: return "delete" + + case let .channelReply(reference): return "chan_reply_\(reference)" + + case .broadcast: return "broadcast" + + case .presence: return "presence" + case .presenceState: return "presence_state" + case .presenceDiff: return "presence_diff" + } + } + + public init?(rawValue: String) { + switch rawValue.lowercased() { + case "heartbeat": self = .heartbeat + case "phx_join": self = .join + case "phx_leave": self = .leave + case "phx_reply": self = .reply + case "phx_error": self = .error + case "phx_close": self = .close + case "*": self = .all + case "insert": self = .insert + case "update": self = .update + case "delete": self = .delete + case "broadcast": self = .broadcast + case "presence": self = .presence + case "presence_state": self = .presenceState + case "presence_diff": self = .presenceDiff + default: return nil + } + } + + var isLifecyleEvent: Bool { + switch self { + case .join, .leave, .reply, .error, .close: return true + default: return false + } + } +} + +/// Represents the different topic a channel can subscribe to. +public enum ChannelTopic: RawRepresentable, Equatable { + case all + case schema(_ schema: String) + case table(_ table: String, schema: String) + case column(_ column: String, value: String, table: String, schema: String) + + case heartbeat + + public var rawValue: String { + switch self { + case .all: return "realtime:*" + case let .schema(name): return "realtime:\(name)" + case let .table(tableName, schema): return "realtime:\(schema):\(tableName)" + case let .column(columnName, value, table, schema): + return "realtime:\(schema):\(table):\(columnName)=eq.\(value)" + case .heartbeat: return "phoenix" + } + } + + public init?(rawValue: String) { + if rawValue == "realtime:*" || rawValue == "*" { + self = .all + } else if rawValue == "phoenix" { + self = .heartbeat + } else { + let parts = rawValue.replacingOccurrences(of: "realtime:", with: "").split(separator: ":") + switch parts.count { + case 1: + self = .schema(String(parts[0])) + case 2: + self = .table(String(parts[1]), schema: String(parts[0])) + case 3: + let condition = parts[2].split(separator: "=") + if condition.count == 2, + condition[1].hasPrefix("eq.") + { + self = .column( + String(condition[0]), value: String(condition[1].dropFirst(3)), table: String(parts[1]), + schema: String(parts[0]) + ) + } else { + return nil + } + default: + return nil + } + } + } +} + +/// Represents the broadcast and presence options for a channel. +public struct ChannelOptions { + /// Used to track presence payload across clients. Must be unique per client. If `nil`, the server will generate one. + var presenceKey: String? + /// Enables the client to receieve their own`broadcast` messages + var broadcastSelf: Bool + /// Instructs the server to acknoledge the client's `broadcast` messages + var broadcastAcknowledge: Bool + + public init( + presenceKey: String? = nil, broadcastSelf: Bool = false, broadcastAcknowledge: Bool = false + ) { + self.presenceKey = presenceKey + self.broadcastSelf = broadcastSelf + self.broadcastAcknowledge = broadcastAcknowledge + } + + /// Parameters used to configure the channel + var params: [String: [String: Any]] { + [ + "config": [ + "presence": [ + "key": presenceKey ?? "" + ], + "broadcast": [ + "ack": broadcastAcknowledge, + "self": broadcastSelf, + ], + ] + ] + } + +} + +/// Represents the different status of a push +public enum PushStatus: String { + case ok + case error + case timeout +} diff --git a/Sources/Realtime/Delegated.swift b/Sources/Realtime/Delegated.swift new file mode 100644 index 00000000..a3388d10 --- /dev/null +++ b/Sources/Realtime/Delegated.swift @@ -0,0 +1,102 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +/// Provides a memory-safe way of passing callbacks around while not creating +/// retain cycles. This file was copied from https://github.com/dreymonde/Delegated +/// instead of added as a dependency to reduce the number of packages that +/// ship with SwiftPhoenixClient +public struct Delegated { + private(set) var callback: ((Input) -> Output?)? + + public init() {} + + public mutating func delegate( + to target: Target, + with callback: @escaping (Target, Input) -> Output + ) { + self.callback = { [weak target] input in + guard let target = target else { + return nil + } + return callback(target, input) + } + } + + public func call(_ input: Input) -> Output? { + return callback?(input) + } + + public var isDelegateSet: Bool { + return callback != nil + } +} + +extension Delegated { + public mutating func stronglyDelegate( + to target: Target, + with callback: @escaping (Target, Input) -> Output + ) { + self.callback = { input in + callback(target, input) + } + } + + public mutating func manuallyDelegate(with callback: @escaping (Input) -> Output) { + self.callback = callback + } + + public mutating func removeDelegate() { + callback = nil + } +} + +extension Delegated where Input == Void { + public mutating func delegate( + to target: Target, + with callback: @escaping (Target) -> Output + ) { + delegate(to: target, with: { target, _ in callback(target) }) + } + + public mutating func stronglyDelegate( + to target: Target, + with callback: @escaping (Target) -> Output + ) { + stronglyDelegate(to: target, with: { target, _ in callback(target) }) + } +} + +extension Delegated where Input == Void { + public func call() -> Output? { + return call(()) + } +} + +extension Delegated where Output == Void { + public func call(_ input: Input) { + callback?(input) + } +} + +extension Delegated where Input == Void, Output == Void { + public func call() { + call(()) + } +} diff --git a/Sources/Realtime/HeartbeatTimer.swift b/Sources/Realtime/HeartbeatTimer.swift new file mode 100644 index 00000000..d8de6c52 --- /dev/null +++ b/Sources/Realtime/HeartbeatTimer.swift @@ -0,0 +1,133 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +/// Heartbeat Timer class which manages the lifecycle of the underlying +/// timer which triggers when a heartbeat should be fired. This heartbeat +/// runs on it's own Queue so that it does not interfere with the main +/// queue but guarantees thread safety. +class HeartbeatTimer { + // ---------------------------------------------------------------------- + + // MARK: - Dependencies + + // ---------------------------------------------------------------------- + // The interval to wait before firing the Timer + let timeInterval: TimeInterval + + /// The maximum amount of time which the system may delay the delivery of the timer events + let leeway: DispatchTimeInterval + + // The DispatchQueue to schedule the timers on + let queue: DispatchQueue + + // UUID which specifies the Timer instance. Verifies that timers are different + let uuid: String = UUID().uuidString + + // ---------------------------------------------------------------------- + + // MARK: - Properties + + // ---------------------------------------------------------------------- + // The underlying, cancelable, resettable, timer. + private var temporaryTimer: DispatchSourceTimer? + // The event handler that is called by the timer when it fires. + private var temporaryEventHandler: (() -> Void)? + + /** + Create a new HeartbeatTimer + + - Parameters: + - timeInterval: Interval to fire the timer. Repeats + - queue: Queue to schedule the timer on + - leeway: The maximum amount of time which the system may delay the delivery of the timer events + */ + init( + timeInterval: TimeInterval, queue: DispatchQueue = Defaults.heartbeatQueue, + leeway: DispatchTimeInterval = Defaults.heartbeatLeeway + ) { + self.timeInterval = timeInterval + self.queue = queue + self.leeway = leeway + } + + /** + Create a new HeartbeatTimer + + - Parameter timeInterval: Interval to fire the timer. Repeats + */ + convenience init(timeInterval: TimeInterval) { + self.init(timeInterval: timeInterval, queue: Defaults.heartbeatQueue) + } + + func start(eventHandler: @escaping () -> Void) { + queue.sync { + // Create a new DispatchSourceTimer, passing the event handler + let timer = DispatchSource.makeTimerSource(flags: [], queue: queue) + timer.setEventHandler(handler: eventHandler) + + // Schedule the timer to first fire in `timeInterval` and then + // repeat every `timeInterval` + timer.schedule( + deadline: DispatchTime.now() + self.timeInterval, + repeating: self.timeInterval, + leeway: self.leeway + ) + + // Start the timer + timer.resume() + self.temporaryEventHandler = eventHandler + self.temporaryTimer = timer + } + } + + func stop() { + // Must be queued synchronously to prevent threading issues. + queue.sync { + // DispatchSourceTimer will automatically cancel when released + temporaryTimer = nil + temporaryEventHandler = nil + } + } + + /** + True if the Timer exists and has not been cancelled. False otherwise + */ + var isValid: Bool { + guard let timer = temporaryTimer else { return false } + return !timer.isCancelled + } + + /** + Calls the Timer's event handler immediately. This method + is primarily used in tests (not ideal) + */ + func fire() { + guard isValid else { return } + temporaryEventHandler?() + } +} + +extension HeartbeatTimer: Equatable { + static func == (lhs: HeartbeatTimer, rhs: HeartbeatTimer) -> Bool { + return lhs.uuid == rhs.uuid + } +} diff --git a/Sources/Realtime/Message.swift b/Sources/Realtime/Message.swift new file mode 100644 index 00000000..50472032 --- /dev/null +++ b/Sources/Realtime/Message.swift @@ -0,0 +1,89 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +/// Data that is received from the Server. +public class Message { + /// Reference number. Empty if missing + public let ref: String + + /// Join Reference number + internal let joinRef: String? + + /// Message topic + public let topic: ChannelTopic + + /// Message event + public let event: ChannelEvent + + /// The raw payload from the Message, including a nested response from + /// phx_reply events. It is recommended to use `payload` instead. + internal let rawPayload: Payload + + /// Message payload + public var payload: Payload { + guard let response = rawPayload["response"] as? Payload + else { return rawPayload } + return response + } + + /// Convenience accessor. Equivalent to getting the status as such: + /// ```swift + /// message.payload["status"] + /// ``` + public var status: PushStatus? { + guard let status = rawPayload["status"] as? String else { + return nil + } + return PushStatus(rawValue: status) + } + + init( + ref: String = "", + topic: ChannelTopic = .all, + event: ChannelEvent = .all, + payload: Payload = [:], + joinRef: String? = nil + ) { + self.ref = ref + self.topic = topic + self.event = event + rawPayload = payload + self.joinRef = joinRef + } + + init?(json: [Any?]) { + guard json.count > 4 else { return nil } + joinRef = json[0] as? String + ref = json[1] as? String ?? "" + + if let topic = (json[2] as? String).flatMap(ChannelTopic.init(rawValue:)), + let event = (json[3] as? String).flatMap(ChannelEvent.init(rawValue:)), + let payload = json[4] as? Payload + { + self.topic = topic + self.event = event + rawPayload = payload + } else { + return nil + } + } +} diff --git a/Sources/Realtime/Presence.swift b/Sources/Realtime/Presence.swift new file mode 100644 index 00000000..67bf4c84 --- /dev/null +++ b/Sources/Realtime/Presence.swift @@ -0,0 +1,443 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +/// The Presence object provides features for syncing presence information from +/// the server with the client and handling presences joining and leaving. +/// +/// ## Syncing state from the server +/// +/// To sync presence state from the server, first instantiate an object and pass +/// your channel in to track lifecycle events: +/// +/// let channel = socket.channel("some:topic") +/// let presence = Presence(channel) +/// +/// If you have custom syncing state events, you can configure the `Presence` +/// object to use those instead. +/// +/// let options = Options(events: [.state: "my_state", .diff: "my_diff"]) +/// let presence = Presence(channel, opts: options) +/// +/// Next, use the presence.onSync callback to react to state changes from the +/// server. For example, to render the list of users every time the list +/// changes, you could write: +/// +/// presence.onSync { renderUsers(presence.list()) } +/// +/// ## Listing Presences +/// +/// presence.list is used to return a list of presence information based on the +/// local state of metadata. By default, all presence metadata is returned, but +/// a listBy function can be supplied to allow the client to select which +/// metadata to use for a given presence. For example, you may have a user +/// online from different devices with a metadata status of "online", but they +/// have set themselves to "away" on another device. In this case, the app may +/// choose to use the "away" status for what appears on the UI. The example +/// below defines a listBy function which prioritizes the first metadata which +/// was registered for each user. This could be the first tab they opened, or +/// the first device they came online from: +/// +/// let listBy: (String, Presence.Map) -> Presence.Meta = { id, pres in +/// let first = pres["metas"]!.first! +/// first["count"] = pres["metas"]!.count +/// first["id"] = id +/// return first +/// } +/// let onlineUsers = presence.list(by: listBy) +/// +/// (NOTE: The underlying behavior is a `map` on the `presence.state`. You are +/// mapping the `state` dictionary into whatever datastructure suites your needs) +/// +/// ## Handling individual presence join and leave events +/// +/// The presence.onJoin and presence.onLeave callbacks can be used to react to +/// individual presences joining and leaving the app. For example: +/// +/// let presence = Presence(channel) +/// presence.onJoin { [weak self] (key, current, newPres) in +/// if let cur = current { +/// print("user additional presence", cur) +/// } else { +/// print("user entered for the first time", newPres) +/// } +/// } +/// +/// presence.onLeave { [weak self] (key, current, leftPres) in +/// if current["metas"]?.isEmpty == true { +/// print("user has left from all devices", leftPres) +/// } else { +/// print("user left from a device", current) +/// } +/// } +/// +/// presence.onSync { renderUsers(presence.list()) } +public final class Presence { + // ---------------------------------------------------------------------- + + // MARK: - Enums and Structs + + // ---------------------------------------------------------------------- + /// Custom options that can be provided when creating Presence + /// + /// ### Example: + /// + /// let options = Options(events: [.state: "my_state", .diff: "my_diff"]) + /// let presence = Presence(channel, opts: options) + public struct Options { + let events: [Events: ChannelEvent] + + /// Default set of Options used when creating Presence. Uses the + /// phoenix events "presence_state" and "presence_diff" + public static let defaults = Options(events: [ + .state: .presenceState, + .diff: .presenceDiff, + ]) + + public init(events: [Events: ChannelEvent]) { + self.events = events + } + } + + /// Presense Events + public enum Events: String { + case state + case diff + } + + // ---------------------------------------------------------------------- + + // MARK: - Typaliases + + // ---------------------------------------------------------------------- + /// Meta details of a Presence. Just a dictionary of properties + public typealias Meta = [String: Any] + + /// A mapping of a String to an array of Metas. e.g. {"metas": [{id: 1}]} + public typealias Map = [String: [Meta]] + + /// A mapping of a Presence state to a mapping of Metas + public typealias State = [String: Map] + + // Diff has keys "joins" and "leaves", pointing to a Presence.State each + // containing the users that joined and left. + public typealias Diff = [String: State] + + /// Closure signature of OnJoin callbacks + public typealias OnJoin = (_ key: String, _ current: Map?, _ new: Map) -> Void + + /// Closure signature for OnLeave callbacks + public typealias OnLeave = (_ key: String, _ current: Map, _ left: Map) -> Void + + //// Closure signature for OnSync callbacks + public typealias OnSync = () -> Void + + /// Collection of callbacks with default values + struct Caller { + var onJoin: OnJoin = { _, _, _ in } + var onLeave: OnLeave = { _, _, _ in } + var onSync: OnSync = {} + } + + // ---------------------------------------------------------------------- + + // MARK: - Properties + + // ---------------------------------------------------------------------- + /// The channel the Presence belongs to + weak var channel: Channel? + + /// Caller to callback hooks + var caller: Caller + + /// The state of the Presence + public private(set) var state: State + + /// Pending `join` and `leave` diffs that need to be synced + public private(set) var pendingDiffs: [Diff] + + /// The channel's joinRef, set when state events occur + public private(set) var joinRef: String? + + public var isPendingSyncState: Bool { + guard let safeJoinRef = joinRef else { return true } + return safeJoinRef != channel?.joinRef + } + + /// Callback to be informed of joins + public var onJoin: OnJoin { + get { return caller.onJoin } + set { caller.onJoin = newValue } + } + + /// Set the OnJoin callback + public func onJoin(_ callback: @escaping OnJoin) { + onJoin = callback + } + + /// Callback to be informed of leaves + public var onLeave: OnLeave { + get { return caller.onLeave } + set { caller.onLeave = newValue } + } + + /// Set the OnLeave callback + public func onLeave(_ callback: @escaping OnLeave) { + onLeave = callback + } + + /// Callback to be informed of synces + public var onSync: OnSync { + get { return caller.onSync } + set { caller.onSync = newValue } + } + + /// Set the OnSync callback + public func onSync(_ callback: @escaping OnSync) { + onSync = callback + } + + public init(channel: Channel, opts: Options = Options.defaults) { + state = [:] + pendingDiffs = [] + self.channel = channel + joinRef = nil + caller = Caller() + + guard // Do not subscribe to events if they were not provided + let stateEvent = opts.events[.state], + let diffEvent = opts.events[.diff] + else { return } + + self.channel?.delegateOn(stateEvent, to: self) { (self, message) in + guard let newState = message.rawPayload as? State else { return } + + self.joinRef = self.channel?.joinRef + self.state = Presence.syncState( + self.state, + newState: newState, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + + self.pendingDiffs.forEach { diff in + self.state = Presence.syncDiff( + self.state, + diff: diff, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + } + + self.pendingDiffs = [] + self.caller.onSync() + } + + self.channel?.delegateOn(diffEvent, to: self) { (self, message) in + guard let diff = message.rawPayload as? Diff else { return } + if self.isPendingSyncState { + self.pendingDiffs.append(diff) + } else { + self.state = Presence.syncDiff( + self.state, + diff: diff, + onJoin: self.caller.onJoin, + onLeave: self.caller.onLeave + ) + self.caller.onSync() + } + } + } + + /// Returns the array of presences, with deault selected metadata. + public func list() -> [Map] { + return list(by: { _, pres in pres }) + } + + /// Returns the array of presences, with selected metadata + public func list(by transformer: (String, Map) -> T) -> [T] { + return Presence.listBy(state, transformer: transformer) + } + + /// Filter the Presence state with a given function + public func filter(by filter: ((String, Map) -> Bool)?) -> State { + return Presence.filter(state, by: filter) + } + + // ---------------------------------------------------------------------- + + // MARK: - Static + + // ---------------------------------------------------------------------- + + // Used to sync the list of presences on the server + // with the client's state. An optional `onJoin` and `onLeave` callback can + // be provided to react to changes in the client's local presences across + // disconnects and reconnects with the server. + // + // - returns: Presence.State + @discardableResult + public static func syncState( + _ currentState: State, + newState: State, + onJoin: OnJoin = { _, _, _ in }, + onLeave: OnLeave = { _, _, _ in } + ) -> State { + let state = currentState + var leaves: Presence.State = [:] + var joins: Presence.State = [:] + + state.forEach { key, presence in + if newState[key] == nil { + leaves[key] = presence + } + } + + newState.forEach { key, newPresence in + if let currentPresence = state[key] { + let newRefs = newPresence["metas"]!.map { $0["phx_ref"] as! String } + let curRefs = currentPresence["metas"]!.map { $0["phx_ref"] as! String } + + let joinedMetas = newPresence["metas"]!.filter { (meta: Meta) -> Bool in + !curRefs.contains { $0 == meta["phx_ref"] as! String } + } + let leftMetas = currentPresence["metas"]!.filter { (meta: Meta) -> Bool in + !newRefs.contains { $0 == meta["phx_ref"] as! String } + } + + if joinedMetas.count > 0 { + joins[key] = newPresence + joins[key]!["metas"] = joinedMetas + } + + if leftMetas.count > 0 { + leaves[key] = currentPresence + leaves[key]!["metas"] = leftMetas + } + } else { + joins[key] = newPresence + } + } + + return Presence.syncDiff( + state, + diff: ["joins": joins, "leaves": leaves], + onJoin: onJoin, + onLeave: onLeave + ) + } + + // Used to sync a diff of presence join and leave + // events from the server, as they happen. Like `syncState`, `syncDiff` + // accepts optional `onJoin` and `onLeave` callbacks to react to a user + // joining or leaving from a device. + // + // - returns: Presence.State + @discardableResult + public static func syncDiff( + _ currentState: State, + diff: Diff, + onJoin: OnJoin = { _, _, _ in }, + onLeave: OnLeave = { _, _, _ in } + ) -> State { + var state = currentState + diff["joins"]?.forEach { key, newPresence in + let currentPresence = state[key] + state[key] = newPresence + + if let curPresence = currentPresence { + let joinedRefs = state[key]!["metas"]!.map { $0["phx_ref"] as! String } + let curMetas = curPresence["metas"]!.filter { (meta: Meta) -> Bool in + !joinedRefs.contains { $0 == meta["phx_ref"] as! String } + } + state[key]!["metas"]!.insert(contentsOf: curMetas, at: 0) + } + + onJoin(key, currentPresence, newPresence) + } + + diff["leaves"]?.forEach { key, leftPresence in + guard var curPresence = state[key] else { return } + let refsToRemove = leftPresence["metas"]!.map { $0["phx_ref"] as! String } + let keepMetas = curPresence["metas"]!.filter { (meta: Meta) -> Bool in + !refsToRemove.contains { $0 == meta["phx_ref"] as! String } + } + + curPresence["metas"] = keepMetas + onLeave(key, curPresence, leftPresence) + + if keepMetas.count > 0 { + state[key]!["metas"] = keepMetas + } else { + state.removeValue(forKey: key) + } + } + + return state + } + + public static func filter( + _ presences: State, + by filter: ((String, Map) -> Bool)? + ) -> State { + let safeFilter = filter ?? { _, _ in true } + return presences.filter(safeFilter) + } + + public static func listBy( + _ presences: State, + transformer: (String, Map) -> T + ) -> [T] { + return presences.map(transformer) + } +} + +extension Presence.Map { + + /// Decodes the presence metadata to an array of the specified type. + /// - parameter type: The type to decode to. + /// - parameter decoder: The decoder to use. + /// - returns: The decoded values. + /// - throws: Any error that occurs during decoding. + public func decode( + to type: T.Type = T.self, decoder: JSONDecoder = Defaults.decoder + ) throws -> [T] { + let metas: [Presence.Meta] = self["metas"]! + let data = try JSONSerialization.data(withJSONObject: metas) + return try decoder.decode([T].self, from: data) + } + +} + +extension Presence.State { + + /// Decodes the presence metadata to a dictionary of arrays of the specified type. + /// - parameter type: The type to decode to. + /// - parameter decoder: The decoder to use. + /// - returns: The dictionary of decoded values. + /// - throws: Any error that occurs during decoding. + public func decode( + to type: T.Type = T.self, decoder: JSONDecoder = Defaults.decoder + ) throws -> [String: [T]] { + return try mapValues { try $0.decode(decoder: decoder) } + } + +} diff --git a/Sources/Realtime/Push.swift b/Sources/Realtime/Push.swift new file mode 100644 index 00000000..0eb5e8bc --- /dev/null +++ b/Sources/Realtime/Push.swift @@ -0,0 +1,265 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +/// Represnts pushing data to a `Channel` through the `Socket` +public class Push { + /// The channel sending the Push + public weak var channel: Channel? + + /// The event, for example `ChannelEvent.join` + public let event: ChannelEvent + + /// The payload, for example ["user_id": "abc123"] + public var payload: Payload + + /// The push timeout. Default is 10.0 seconds + public var timeout: TimeInterval + + /// The server's response to the Push + var receivedMessage: Message? + + /// Timer which triggers a timeout event + var timeoutTimer: TimerQueue + + /// WorkItem to be performed when the timeout timer fires + var timeoutWorkItem: DispatchWorkItem? + + /// Hooks into a Push. Where .receive("ok", callback(Payload)) are stored + var receiveHooks: [PushStatus: [Delegated]] + + /// True if the Push has been sent + var sent: Bool + + /// The reference ID of the Push + var ref: String? + + /// The event that is associated with the reference ID of the Push + var refEvent: ChannelEvent? + + /// Initializes a Push + /// + /// - parameter channel: The Channel + /// - parameter event: The event, for example ChannelEvent.join + /// - parameter payload: Optional. The Payload to send, e.g. ["user_id": "abc123"] + /// - parameter timeout: Optional. The push timeout. Default is 10.0s + init( + channel: Channel, + event: ChannelEvent, + payload: Payload = [:], + timeout: TimeInterval = Defaults.timeoutInterval + ) { + self.channel = channel + self.event = event + self.payload = payload + self.timeout = timeout + receivedMessage = nil + timeoutTimer = TimerQueue.main + receiveHooks = [:] + sent = false + ref = nil + } + + /// Resets and sends the Push + /// - parameter timeout: Optional. The push timeout. Default is 10.0s + public func resend(_ timeout: TimeInterval = Defaults.timeoutInterval) { + self.timeout = timeout + reset() + send() + } + + /// Sends the Push. If it has already timed out, then the call will + /// be ignored and return early. Use `resend` in this case. + public func send() { + guard !hasReceived(status: .timeout) else { return } + + startTimeout() + sent = true + channel?.socket?.push( + topic: channel?.topic ?? .all, + event: event, + payload: payload, + ref: ref, + joinRef: channel?.joinRef + ) + } + + /// Receive a specific event when sending an Outbound message. Subscribing + /// to status events with this method does not guarantees no retain cycles. + /// You should pass `weak self` in the capture list of the callback. You + /// can call `.delegateReceive(status:, to:, callback:) and the library will + /// handle it for you. + /// + /// Example: + /// + /// channel + /// .send(event:"custom", payload: ["body": "example"]) + /// .receive("error") { [weak self] payload in + /// print("Error: ", payload) + /// } + /// + /// - parameter status: Status to receive + /// - parameter callback: Callback to fire when the status is recevied + @discardableResult + public func receive( + _ status: PushStatus, + callback: @escaping ((Message) -> Void) + ) -> Push { + var delegated = Delegated() + delegated.manuallyDelegate(with: callback) + + return receive(status, delegated: delegated) + } + + /// Receive a specific event when sending an Outbound message. Automatically + /// prevents retain cycles. See `manualReceive(status:, callback:)` if you + /// want to handle this yourself. + /// + /// Example: + /// + /// channel + /// .send(event:"custom", payload: ["body": "example"]) + /// .delegateReceive("error", to: self) { payload in + /// print("Error: ", payload) + /// } + /// + /// - parameter status: Status to receive + /// - parameter owner: The class that is calling .receive. Usually `self` + /// - parameter callback: Callback to fire when the status is recevied + @discardableResult + public func delegateReceive( + _ status: PushStatus, + to owner: Target, + callback: @escaping ((Target, Message) -> Void) + ) -> Push { + var delegated = Delegated() + delegated.delegate(to: owner, with: callback) + + return receive(status, delegated: delegated) + } + + /// Shared behavior between `receive` calls + @discardableResult + internal func receive(_ status: PushStatus, delegated: Delegated) -> Push { + // If the message has already been received, pass it to the callback immediately + if hasReceived(status: status), let receivedMessage = receivedMessage { + delegated.call(receivedMessage) + } + + if receiveHooks[status] == nil { + /// Create a new array of hooks if no previous hook is associated with status + receiveHooks[status] = [delegated] + } else { + /// A previous hook for this status already exists. Just append the new hook + receiveHooks[status]?.append(delegated) + } + + return self + } + + /// Resets the Push as it was after it was first tnitialized. + internal func reset() { + cancelRefEvent() + ref = nil + refEvent = nil + receivedMessage = nil + sent = false + } + + /// Finds the receiveHook which needs to be informed of a status response + /// + /// - parameter status: Status which was received, e.g. "ok", "error", "timeout" + /// - parameter response: Response that was received + private func matchReceive(_ status: PushStatus, message: Message) { + receiveHooks[status]?.forEach { $0.call(message) } + } + + /// Reverses the result on channel.on(ChannelEvent, callback) that spawned the Push + private func cancelRefEvent() { + guard let refEvent = refEvent else { return } + channel?.off(refEvent) + } + + /// Cancel any ongoing Timeout Timer + internal func cancelTimeout() { + timeoutWorkItem?.cancel() + timeoutWorkItem = nil + } + + /// Starts the Timer which will trigger a timeout after a specific _timeout_ + /// time, in milliseconds, is reached. + internal func startTimeout() { + // Cancel any existing timeout before starting a new one + if let safeWorkItem = timeoutWorkItem, !safeWorkItem.isCancelled { + cancelTimeout() + } + + guard + let channel = channel, + let socket = channel.socket + else { return } + + let ref = socket.makeRef() + let refEvent = ChannelEvent.channelReply(ref) + + self.ref = ref + self.refEvent = refEvent + + /// If a response is received before the Timer triggers, cancel timer + /// and match the recevied event to it's corresponding hook + channel.delegateOn(refEvent, to: self) { (self, message) in + self.cancelRefEvent() + self.cancelTimeout() + self.receivedMessage = message + + /// Check if there is event a status available + guard let status = message.status else { return } + self.matchReceive(status, message: message) + } + + /// Setup and start the Timeout timer. + let workItem = DispatchWorkItem { + self.trigger(.timeout, payload: [:]) + } + + timeoutWorkItem = workItem + timeoutTimer.queue(timeInterval: timeout, execute: workItem) + } + + /// Checks if a status has already been received by the Push. + /// + /// - parameter status: Status to check + /// - return: True if given status has been received by the Push. + internal func hasReceived(status: PushStatus) -> Bool { + return receivedMessage?.status == status + } + + /// Triggers an event to be sent though the Channel + internal func trigger(_ status: PushStatus, payload: Payload) { + /// If there is no ref event, then there is nothing to trigger on the channel + guard let refEvent = refEvent else { return } + + var mutPayload = payload + mutPayload["status"] = status + + channel?.trigger(event: refEvent, payload: mutPayload) + } +} diff --git a/Sources/Realtime/RealtimeClient.swift b/Sources/Realtime/RealtimeClient.swift new file mode 100644 index 00000000..3133a2d1 --- /dev/null +++ b/Sources/Realtime/RealtimeClient.swift @@ -0,0 +1,1015 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +public enum SocketError: Error { + case abnormalClosureError +} + +/// Alias for a JSON dictionary [String: Any] +public typealias Payload = [String: Any] + +/// Alias for a function returning an optional JSON dictionary (`Payload?`) +public typealias PayloadClosure = () -> Payload? + +/// Struct that gathers callbacks assigned to the Socket +struct StateChangeCallbacks { + var open: [(ref: String, callback: Delegated)] = [] + var close: [(ref: String, callback: Delegated<(Int, String?), Void>)] = [] + var error: [(ref: String, callback: Delegated<(Error, URLResponse?), Void>)] = [] + var message: [(ref: String, callback: Delegated)] = [] +} + +/// ## Socket Connection +/// A single connection is established to the server and +/// channels are multiplexed over the connection. +/// Connect to the server using the `Socket` class: +/// +/// ```swift +/// let socket = new Socket("/socket", paramsClosure: { ["userToken": "123" ] }) +/// socket.connect() +/// ``` +/// +/// The `Socket` constructor takes the mount point of the socket, +/// the authentication params, as well as options that can be found in +/// the Socket docs, such as configuring the heartbeat. +public class RealtimeClient: TransportDelegate { + // ---------------------------------------------------------------------- + + // MARK: - Public Attributes + + // ---------------------------------------------------------------------- + /// The string WebSocket endpoint (ie `"ws://example.com/socket"`, + /// `"wss://example.com"`, etc.) That was passed to the Socket during + /// initialization. The URL endpoint will be modified by the Socket to + /// include `"/websocket"` if missing. + public let endPoint: String + + /// The fully qualified socket URL + public private(set) var endPointUrl: URL + + /// Resolves to return the `paramsClosure` result at the time of calling. + /// If the `Socket` was created with static params, then those will be + /// returned every time. + public var params: Payload? { + return paramsClosure?() + } + + /// The optional params closure used to get params when connecting. Must + /// be set when initializing the Socket. + public let paramsClosure: PayloadClosure? + + /// The WebSocket transport. Default behavior is to provide a + /// URLSessionWebsocketTask. See README for alternatives. + private let transport: (URL) -> Transport + + /// Phoenix serializer version, defaults to "2.0.0" + public let vsn: String + + /// Override to provide custom encoding of data before writing to the socket + public var encode: (Any) -> Data = Defaults.encode + + /// Override to provide custom decoding of data read from the socket + public var decode: (Data) -> Any? = Defaults.decode + + /// Timeout to use when opening connections + public var timeout: TimeInterval = Defaults.timeoutInterval + + /// Interval between sending a heartbeat + public var heartbeatInterval: TimeInterval = Defaults.heartbeatInterval + + /// The maximum amount of time which the system may delay heartbeats in order to optimize power usage + public var heartbeatLeeway: DispatchTimeInterval = Defaults.heartbeatLeeway + + /// Interval between socket reconnect attempts, in seconds + public var reconnectAfter: (Int) -> TimeInterval = Defaults.reconnectSteppedBackOff + + /// Interval between channel rejoin attempts, in seconds + public var rejoinAfter: (Int) -> TimeInterval = Defaults.rejoinSteppedBackOff + + /// The optional function to receive logs + public var logger: ((String) -> Void)? + + /// Disables heartbeats from being sent. Default is false. + public var skipHeartbeat: Bool = false + + /// Enable/Disable SSL certificate validation. Default is false. This + /// must be set before calling `socket.connect()` in order to be applied + public var disableSSLCertValidation: Bool = false + + #if os(Linux) + #else + /// Configure custom SSL validation logic, eg. SSL pinning. This + /// must be set before calling `socket.connect()` in order to apply. + // public var security: SSLTrustValidator? + + /// Configure the encryption used by your client by setting the + /// allowed cipher suites supported by your server. This must be + /// set before calling `socket.connect()` in order to apply. + public var enabledSSLCipherSuites: [SSLCipherSuite]? + #endif + + // ---------------------------------------------------------------------- + + // MARK: - Private Attributes + + // ---------------------------------------------------------------------- + /// Callbacks for socket state changes + var stateChangeCallbacks: StateChangeCallbacks = .init() + + /// Collection on channels created for the Socket + public internal(set) var channels: [Channel] = [] + + /// Buffers messages that need to be sent once the socket has connected. It is an array + /// of tuples, with the ref of the message to send and the callback that will send the message. + var sendBuffer: [(ref: String?, callback: () throws -> Void)] = [] + + /// Ref counter for messages + var ref: UInt64 = .min // 0 (max: 18,446,744,073,709,551,615) + + /// Timer that triggers sending new Heartbeat messages + var heartbeatTimer: HeartbeatTimer? + + /// Ref counter for the last heartbeat that was sent + var pendingHeartbeatRef: String? + + /// Timer to use when attempting to reconnect + var reconnectTimer: TimeoutTimer + + /// Close status + var closeStatus: CloseStatus = .unknown + + /// The connection to the server + var connection: Transport? = nil + + // ---------------------------------------------------------------------- + + // MARK: - Initialization + + // ---------------------------------------------------------------------- + @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) + public convenience init( + _ endPoint: String, + params: Payload? = nil, + vsn: String = Defaults.vsn + ) { + self.init( + endPoint: endPoint, + transport: { url in URLSessionTransport(url: url) }, + paramsClosure: { params }, + vsn: vsn + ) + } + + @available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) + public convenience init( + _ endPoint: String, + paramsClosure: PayloadClosure?, + vsn: String = Defaults.vsn + ) { + self.init( + endPoint: endPoint, + transport: { url in URLSessionTransport(url: url) }, + paramsClosure: paramsClosure, + vsn: vsn + ) + } + + @available(*, deprecated, renamed: "init(_:params:vsn:)") + public convenience init( + endPoint: String, + params: Payload? = nil, + vsn: String = Defaults.vsn + ) { + self.init( + endPoint: endPoint, + transport: { url in URLSessionTransport(url: url) }, + paramsClosure: { params }, + vsn: vsn + ) + } + + public init( + endPoint: String, + transport: @escaping ((URL) -> Transport), + paramsClosure: PayloadClosure? = nil, + vsn: String = Defaults.vsn + ) { + self.transport = transport + self.paramsClosure = paramsClosure + self.endPoint = endPoint + self.vsn = vsn + endPointUrl = RealtimeClient.buildEndpointUrl( + endpoint: endPoint, + paramsClosure: paramsClosure, + vsn: vsn + ) + + reconnectTimer = TimeoutTimer() + reconnectTimer.callback.delegate(to: self) { (self) in + self.logItems("Socket attempting to reconnect") + self.teardown(reason: "reconnection") { self.connect() } + } + reconnectTimer.timerCalculation + .delegate(to: self) { (self, tries) -> TimeInterval in + let interval = self.reconnectAfter(tries) + self.logItems("Socket reconnecting in \(interval)s") + return interval + } + } + + deinit { + reconnectTimer.reset() + } + + // ---------------------------------------------------------------------- + + // MARK: - Public + + // ---------------------------------------------------------------------- + /// - return: The socket protocol, wss or ws + public var websocketProtocol: String { + switch endPointUrl.scheme { + case "https": return "wss" + case "http": return "ws" + default: return endPointUrl.scheme ?? "" + } + } + + /// - return: True if the socket is connected + public var isConnected: Bool { + return connectionState == .open + } + + /// - return: The state of the connect. [.connecting, .open, .closing, .closed] + public var connectionState: TransportReadyState { + return connection?.readyState ?? .closed + } + + /// Connects the Socket. The params passed to the Socket on initialization + /// will be sent through the connection. If the Socket is already connected, + /// then this call will be ignored. + public func connect() { + // Do not attempt to reconnect if the socket is currently connected + guard !isConnected else { return } + + // Reset the close status when attempting to connect + closeStatus = .unknown + + // We need to build this right before attempting to connect as the + // parameters could be built upon demand and change over time + endPointUrl = RealtimeClient.buildEndpointUrl( + endpoint: endPoint, + paramsClosure: paramsClosure, + vsn: vsn + ) + + connection = transport(endPointUrl) + connection?.delegate = self + // self.connection?.disableSSLCertValidation = disableSSLCertValidation + // + // #if os(Linux) + // #else + // self.connection?.security = security + // self.connection?.enabledSSLCipherSuites = enabledSSLCipherSuites + // #endif + + connection?.connect() + } + + /// Disconnects the socket + /// + /// - parameter code: Optional. Closing status code + /// - parameter callback: Optional. Called when disconnected + public func disconnect( + code: CloseCode = CloseCode.normal, + reason: String? = nil, + callback: (() -> Void)? = nil + ) { + // The socket was closed cleanly by the User + closeStatus = CloseStatus(closeCode: code.rawValue) + + // Reset any reconnects and teardown the socket connection + reconnectTimer.reset() + teardown(code: code, reason: reason, callback: callback) + } + + internal func teardown( + code: CloseCode = CloseCode.normal, reason: String? = nil, callback: (() -> Void)? = nil + ) { + connection?.delegate = nil + connection?.disconnect(code: code.rawValue, reason: reason) + connection = nil + + // The socket connection has been torndown, heartbeats are not needed + heartbeatTimer?.stop() + + // Since the connection's delegate was nil'd out, inform all state + // callbacks that the connection has closed + stateChangeCallbacks.close.forEach { $0.callback.call((code.rawValue, reason)) } + callback?() + } + + // ---------------------------------------------------------------------- + + // MARK: - Register Socket State Callbacks + + // ---------------------------------------------------------------------- + + /// Registers callbacks for connection open events. Does not handle retain + /// cycles. Use `delegateOnOpen(to:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onOpen() { [weak self] in + /// self?.print("Socket Connection Open") + /// } + /// + /// - parameter callback: Called when the Socket is opened + @discardableResult + public func onOpen(callback: @escaping () -> Void) -> String { + return onOpen { _ in callback() } + } + + /// Registers callbacks for connection open events. Does not handle retain + /// cycles. Use `delegateOnOpen(to:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onOpen() { [weak self] response in + /// self?.print("Socket Connection Open") + /// } + /// + /// - parameter callback: Called when the Socket is opened + @discardableResult + public func onOpen(callback: @escaping (URLResponse?) -> Void) -> String { + var delegated = Delegated() + delegated.manuallyDelegate(with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.open) + } + + /// Registers callbacks for connection open events. Automatically handles + /// retain cycles. Use `onOpen()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnOpen(to: self) { self in + /// self.print("Socket Connection Open") + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket is opened + @discardableResult + public func delegateOnOpen( + to owner: T, + callback: @escaping ((T) -> Void) + ) -> String { + return delegateOnOpen(to: owner) { owner, _ in callback(owner) } + } + + /// Registers callbacks for connection open events. Automatically handles + /// retain cycles. Use `onOpen()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnOpen(to: self) { self, response in + /// self.print("Socket Connection Open") + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket is opened + @discardableResult + public func delegateOnOpen( + to owner: T, + callback: @escaping ((T, URLResponse?) -> Void) + ) -> String { + var delegated = Delegated() + delegated.delegate(to: owner, with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.open) + } + + /// Registers callbacks for connection close events. Does not handle retain + /// cycles. Use `delegateOnClose(_:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onClose() { [weak self] in + /// self?.print("Socket Connection Close") + /// } + /// + /// - parameter callback: Called when the Socket is closed + @discardableResult + public func onClose(callback: @escaping () -> Void) -> String { + return onClose { _, _ in callback() } + } + + /// Registers callbacks for connection close events. Does not handle retain + /// cycles. Use `delegateOnClose(_:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onClose() { [weak self] code, reason in + /// self?.print("Socket Connection Close") + /// } + /// + /// - parameter callback: Called when the Socket is closed + @discardableResult + public func onClose(callback: @escaping (Int, String?) -> Void) -> String { + var delegated = Delegated<(Int, String?), Void>() + delegated.manuallyDelegate(with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.close) + } + + /// Registers callbacks for connection close events. Automatically handles + /// retain cycles. Use `onClose()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnClose(self) { self in + /// self.print("Socket Connection Close") + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket is closed + @discardableResult + public func delegateOnClose( + to owner: T, + callback: @escaping ((T) -> Void) + ) -> String { + return delegateOnClose(to: owner) { owner, _ in callback(owner) } + } + + /// Registers callbacks for connection close events. Automatically handles + /// retain cycles. Use `onClose()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnClose(self) { self, code, reason in + /// self.print("Socket Connection Close") + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket is closed + @discardableResult + public func delegateOnClose( + to owner: T, + callback: @escaping ((T, (Int, String?)) -> Void) + ) -> String { + var delegated = Delegated<(Int, String?), Void>() + delegated.delegate(to: owner, with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.close) + } + + /// Registers callbacks for connection error events. Does not handle retain + /// cycles. Use `delegateOnError(to:)` for automatic handling of retain cycles. + /// + /// Example: + /// + /// socket.onError() { [weak self] (error) in + /// self?.print("Socket Connection Error", error) + /// } + /// + /// - parameter callback: Called when the Socket errors + @discardableResult + public func onError(callback: @escaping ((Error, URLResponse?)) -> Void) -> String { + var delegated = Delegated<(Error, URLResponse?), Void>() + delegated.manuallyDelegate(with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.error) + } + + /// Registers callbacks for connection error events. Automatically handles + /// retain cycles. Use `manualOnError()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnError(to: self) { (self, error) in + /// self.print("Socket Connection Error", error) + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket errors + @discardableResult + public func delegateOnError( + to owner: T, + callback: @escaping ((T, (Error, URLResponse?)) -> Void) + ) -> String { + var delegated = Delegated<(Error, URLResponse?), Void>() + delegated.delegate(to: owner, with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.error) + } + + /// Registers callbacks for connection message events. Does not handle + /// retain cycles. Use `delegateOnMessage(_to:)` for automatic handling of + /// retain cycles. + /// + /// Example: + /// + /// socket.onMessage() { [weak self] (message) in + /// self?.print("Socket Connection Message", message) + /// } + /// + /// - parameter callback: Called when the Socket receives a message event + @discardableResult + public func onMessage(callback: @escaping (Message) -> Void) -> String { + var delegated = Delegated() + delegated.manuallyDelegate(with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.message) + } + + /// Registers callbacks for connection message events. Automatically handles + /// retain cycles. Use `onMessage()` to handle yourself. + /// + /// Example: + /// + /// socket.delegateOnMessage(self) { (self, message) in + /// self.print("Socket Connection Message", message) + /// } + /// + /// - parameter owner: Class registering the callback. Usually `self` + /// - parameter callback: Called when the Socket receives a message event + @discardableResult + public func delegateOnMessage( + to owner: T, + callback: @escaping ((T, Message) -> Void) + ) -> String { + var delegated = Delegated() + delegated.delegate(to: owner, with: callback) + + return append(callback: delegated, to: &stateChangeCallbacks.message) + } + + private func append(callback: T, to array: inout [(ref: String, callback: T)]) -> String { + let ref = makeRef() + array.append((ref, callback)) + return ref + } + + /// Releases all stored callback hooks (onError, onOpen, onClose, etc.) You should + /// call this method when you are finished when the Socket in order to release + /// any references held by the socket. + public func releaseCallbacks() { + stateChangeCallbacks.open.removeAll() + stateChangeCallbacks.close.removeAll() + stateChangeCallbacks.error.removeAll() + stateChangeCallbacks.message.removeAll() + } + + // ---------------------------------------------------------------------- + + // MARK: - Channel Initialization + // ---------------------------------------------------------------------- + /// Initialize a new Channel + /// + /// Example: + /// + /// let channel = socket.channel("rooms", options: ChannelOptions(presenceKey: "user123")) + /// + /// - parameter topic: Topic of the channel + /// - parameter options: Optional. Options to configure channel broadcast and presence. Leave nil for postgres channel. + /// - return: A new channel + public func channel( + _ topic: ChannelTopic, + options: ChannelOptions? = nil + ) -> Channel { + let channel = Channel(topic: topic, options: options, socket: self) + channels.append(channel) + return channel + } + // ---------------------------------------------------------------------- + /// Initialize a new Channel + /// + /// Example: + /// + /// let channel = socket.channel("rooms", params: ["user_id": "abc123"]) + /// + /// - parameter topic: Topic of the channel + /// - parameter params: Optional. Parameters for the channel + /// - return: A new channel + @available(*, deprecated, renamed: "channel(_:options:)") + public func channel( + _ topic: ChannelTopic, + params: [String: Any] + ) -> Channel { + let channel = Channel(topic: topic, params: params, socket: self) + channels.append(channel) + + return channel + } + + /// Removes the Channel from the socket. This does not cause the channel to + /// inform the server that it is leaving. You should call channel.leave() + /// prior to removing the Channel. + /// + /// Example: + /// + /// channel.leave() + /// socket.remove(channel) + /// + /// - parameter channel: Channel to remove + public func remove(_ channel: Channel) { + off(channel.stateChangeRefs) + channels.removeAll(where: { $0.joinRef == channel.joinRef }) + } + + /// Removes `onOpen`, `onClose`, `onError,` and `onMessage` registrations. + /// + /// + /// - Parameter refs: List of refs returned by calls to `onOpen`, `onClose`, etc + public func off(_ refs: [String]) { + stateChangeCallbacks.open = stateChangeCallbacks.open.filter { + !refs.contains($0.ref) + } + stateChangeCallbacks.close = stateChangeCallbacks.close.filter { + !refs.contains($0.ref) + } + stateChangeCallbacks.error = stateChangeCallbacks.error.filter { + !refs.contains($0.ref) + } + stateChangeCallbacks.message = stateChangeCallbacks.message.filter { + !refs.contains($0.ref) + } + } + + // ---------------------------------------------------------------------- + + // MARK: - Sending Data + + // ---------------------------------------------------------------------- + /// Sends data through the Socket. This method is internal. Instead, you + /// should call `push(_:, payload:, timeout:)` on the Channel you are + /// sending an event to. + /// + /// - parameter topic: + /// - parameter event: + /// - parameter payload: + /// - parameter ref: Optional. Defaults to nil + /// - parameter joinRef: Optional. Defaults to nil + internal func push( + topic: ChannelTopic, + event: ChannelEvent, + payload: Payload, + ref: String? = nil, + joinRef: String? = nil + ) { + let callback: (() throws -> Void) = { + let body: [Any?] = [joinRef, ref, topic.rawValue, event.rawValue, payload] + let data = self.encode(body) + + self.logItems("push", "Sending \(String(data: data, encoding: String.Encoding.utf8) ?? "")") + self.connection?.send(data: data) + } + + /// If the socket is connected, then execute the callback immediately. + if isConnected { + try? callback() + } else { + /// If the socket is not connected, add the push to a buffer which will + /// be sent immediately upon connection. + sendBuffer.append((ref: ref, callback: callback)) + } + } + + /// - return: the next message ref, accounting for overflows + public func makeRef() -> String { + ref = (ref == UInt64.max) ? 0 : ref + 1 + return String(ref) + } + + /// Logs the message. Override Socket.logger for specialized logging. noops by default + /// + /// - parameter items: List of items to be logged. Behaves just like debugPrint() + func logItems(_ items: Any...) { + let msg = items.map { String(describing: $0) }.joined(separator: ", ") + logger?("SwiftPhoenixClient: \(msg)") + } + + // ---------------------------------------------------------------------- + + // MARK: - Connection Events + + // ---------------------------------------------------------------------- + /// Called when the underlying Websocket connects to it's host + internal func onConnectionOpen(response: URLResponse?) { + logItems("transport", "Connected to \(endPoint)") + + // Reset the close status now that the socket has been connected + closeStatus = .unknown + + // Send any messages that were waiting for a connection + flushSendBuffer() + + // Reset how the socket tried to reconnect + reconnectTimer.reset() + + // Restart the heartbeat timer + resetHeartbeat() + + // Inform all onOpen callbacks that the Socket has opened + stateChangeCallbacks.open.forEach { $0.callback.call(response) } + } + + internal func onConnectionClosed(code: Int, reason: String?) { + logItems("transport", "close") + + // Send an error to all channels + triggerChannelError() + + // Prevent the heartbeat from triggering if the + heartbeatTimer?.stop() + + // Only attempt to reconnect if the socket did not close normally, + // or if it was closed abnormally but on client side (e.g. due to heartbeat timeout) + if closeStatus.shouldReconnect { + reconnectTimer.scheduleTimeout() + } + + stateChangeCallbacks.close.forEach { $0.callback.call((code, reason)) } + } + + internal func onConnectionError(_ error: Error, response: URLResponse?) { + logItems("transport", error, response ?? "") + + // Send an error to all channels + triggerChannelError() + + // Inform any state callbacks of the error + stateChangeCallbacks.error.forEach { $0.callback.call((error, response)) } + } + + internal func onConnectionMessage(_ rawMessage: String) { + logItems("receive ", rawMessage) + + guard + let data = rawMessage.data(using: String.Encoding.utf8), + let json = decode(data) as? [Any?], + let message = Message(json: json) + else { + logItems("receive: Unable to parse JSON: \(rawMessage)") + return + } + + // Clear heartbeat ref, preventing a heartbeat timeout disconnect + if message.ref == pendingHeartbeatRef { pendingHeartbeatRef = nil } + + if message.event == .close { + print("Close Event Received") + } + + // Dispatch the message to all channels that belong to the topic + channels + .filter { $0.isMember(message) } + .forEach { $0.trigger(message) } + + // Inform all onMessage callbacks of the message + stateChangeCallbacks.message.forEach { $0.callback.call(message) } + } + + /// Triggers an error event to all of the connected Channels + internal func triggerChannelError() { + channels.forEach { channel in + // Only trigger a channel error if it is in an "opened" state + if !(channel.isErrored || channel.isLeaving || channel.isClosed) { + channel.trigger(event: ChannelEvent.error) + } + } + } + + /// Send all messages that were buffered before the socket opened + internal func flushSendBuffer() { + guard isConnected && sendBuffer.count > 0 else { return } + sendBuffer.forEach { try? $0.callback() } + sendBuffer = [] + } + + /// Removes an item from the sendBuffer with the matching ref + internal func removeFromSendBuffer(ref: String) { + sendBuffer = sendBuffer.filter { $0.ref != ref } + } + + /// Builds a fully qualified socket `URL` from `endPoint` and `params`. + internal static func buildEndpointUrl( + endpoint: String, paramsClosure params: PayloadClosure?, vsn: String + ) -> URL { + guard + let url = URL(string: endpoint), + var urlComponents = URLComponents(url: url, resolvingAgainstBaseURL: false) + else { fatalError("Malformed URL: \(endpoint)") } + + // Ensure that the URL ends with "/websocket + if !urlComponents.path.contains("/websocket") { + // Do not duplicate '/' in the path + if urlComponents.path.last != "/" { + urlComponents.path.append("/") + } + + // append 'websocket' to the path + urlComponents.path.append("websocket") + } + + urlComponents.queryItems = [URLQueryItem(name: "vsn", value: vsn)] + + // If there are parameters, append them to the URL + if let params = params?() { + urlComponents.queryItems?.append( + contentsOf: params.map { + URLQueryItem(name: $0.key, value: String(describing: $0.value)) + }) + } + + guard let qualifiedUrl = urlComponents.url + else { fatalError("Malformed URL while adding parameters") } + return qualifiedUrl + } + + // Leaves any channel that is open that has a duplicate topic + internal func leaveOpenTopic(topic: ChannelTopic) { + guard + let dupe = channels.first(where: { $0.topic == topic && ($0.isJoined || $0.isJoining) }) + else { return } + + logItems("transport", "leaving duplicate topic: [\(topic)]") + dupe.leave() + } + + // ---------------------------------------------------------------------- + + // MARK: - Heartbeat + + // ---------------------------------------------------------------------- + internal func resetHeartbeat() { + // Clear anything related to the heartbeat + pendingHeartbeatRef = nil + heartbeatTimer?.stop() + + // Do not start up the heartbeat timer if skipHeartbeat is true + guard !skipHeartbeat else { return } + + heartbeatTimer = HeartbeatTimer(timeInterval: heartbeatInterval, leeway: heartbeatLeeway) + heartbeatTimer?.start(eventHandler: { [weak self] in + self?.sendHeartbeat() + }) + } + + /// Sends a heartbeat payload to the phoenix servers + @objc func sendHeartbeat() { + // Do not send if the connection is closed + guard isConnected else { return } + + // If there is a pending heartbeat ref, then the last heartbeat was + // never acknowledged by the server. Close the connection and attempt + // to reconnect. + if let _ = pendingHeartbeatRef { + pendingHeartbeatRef = nil + logItems( + "transport", + "heartbeat timeout. Attempting to re-establish connection" + ) + + // Close the socket manually, flagging the closure as abnormal. Do not use + // `teardown` or `disconnect` as they will nil out the websocket delegate. + abnormalClose("heartbeat timeout") + + return + } + + // The last heartbeat was acknowledged by the server. Send another one + pendingHeartbeatRef = makeRef() + push( + topic: .heartbeat, + event: ChannelEvent.heartbeat, + payload: [:], + ref: pendingHeartbeatRef + ) + } + + internal func abnormalClose(_ reason: String) { + closeStatus = .abnormal + + /* + We use NORMAL here since the client is the one determining to close the + connection. However, we set to close status to abnormal so that + the client knows that it should attempt to reconnect. + + If the server subsequently acknowledges with code 1000 (normal close), + the socket will keep the `.abnormal` close status and trigger a reconnection. + */ + connection?.disconnect(code: CloseCode.normal.rawValue, reason: reason) + } + + // ---------------------------------------------------------------------- + + // MARK: - TransportDelegate + + // ---------------------------------------------------------------------- + public func onOpen(response: URLResponse?) { + onConnectionOpen(response: response) + } + + public func onError(error: Error, response: URLResponse?) { + onConnectionError(error, response: response) + } + + public func onMessage(message: String) { + onConnectionMessage(message) + } + + public func onClose(code: Int, reason: String? = nil) { + closeStatus.update(transportCloseCode: code) + onConnectionClosed(code: code, reason: reason) + } +} + +// ---------------------------------------------------------------------- + +// MARK: - Close Codes + +// ---------------------------------------------------------------------- +extension RealtimeClient { + public enum CloseCode: Int { + case abnormal = 999 + + case normal = 1000 + + case goingAway = 1001 + } +} + +// ---------------------------------------------------------------------- + +// MARK: - Close Status + +// ---------------------------------------------------------------------- +extension RealtimeClient { + /// Indicates the different closure states a socket can be in. + enum CloseStatus { + /// Undetermined closure state + case unknown + /// A clean closure requested either by the client or the server + case clean + /// An abnormal closure requested by the client + case abnormal + + /// Temporarily close the socket, pausing reconnect attempts. Useful on mobile + /// clients when disconnecting a because the app resigned active but should + /// reconnect when app enters active state. + case temporary + + init(closeCode: Int) { + switch closeCode { + case CloseCode.abnormal.rawValue: + self = .abnormal + case CloseCode.goingAway.rawValue: + self = .temporary + default: + self = .clean + } + } + + mutating func update(transportCloseCode: Int) { + switch self { + case .unknown, .clean, .temporary: + // Allow transport layer to override these statuses. + self = .init(closeCode: transportCloseCode) + case .abnormal: + // Do not allow transport layer to override the abnormal close status. + // The socket itself should reset it on the next connection attempt. + // See `Socket.abnormalClose(_:)` for more information. + break + } + } + + var shouldReconnect: Bool { + switch self { + case .unknown, .abnormal: + return true + case .clean, .temporary: + return false + } + } + } +} diff --git a/Sources/Realtime/SynchronizedArray.swift b/Sources/Realtime/SynchronizedArray.swift new file mode 100644 index 00000000..e7345ce3 --- /dev/null +++ b/Sources/Realtime/SynchronizedArray.swift @@ -0,0 +1,33 @@ +// +// SynchronizedArray.swift +// SwiftPhoenixClient +// +// Created by Daniel Rees on 4/12/23. +// Copyright © 2023 SwiftPhoenixClient. All rights reserved. +// + +import Foundation + +/// A thread-safe array. +public class SynchronizedArray { + fileprivate let queue = DispatchQueue(label: "spc_sync_array", attributes: .concurrent) + fileprivate var array = [Element]() + + func append(_ newElement: Element) { + queue.async(flags: .barrier) { + self.array.append(newElement) + } + } + + func removeAll(where shouldBeRemoved: @escaping (Element) -> Bool) { + queue.async(flags: .barrier) { + self.array.removeAll(where: shouldBeRemoved) + } + } + + func filter(_ isIncluded: (Element) -> Bool) -> [Element] { + var result = [Element]() + queue.sync { result = self.array.filter(isIncluded) } + return result + } +} diff --git a/Sources/Realtime/TimeoutTimer.swift b/Sources/Realtime/TimeoutTimer.swift new file mode 100644 index 00000000..b6b37c4c --- /dev/null +++ b/Sources/Realtime/TimeoutTimer.swift @@ -0,0 +1,108 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +/// Creates a timer that can perform calculated reties by setting +/// `timerCalculation` , such as exponential backoff. +/// +/// ### Example +/// +/// let reconnectTimer = TimeoutTimer() +/// +/// // Receive a callbcak when the timer is fired +/// reconnectTimer.callback.delegate(to: self) { (_) in +/// print("timer was fired") +/// } +/// +/// // Provide timer interval calculation +/// reconnectTimer.timerCalculation.delegate(to: self) { (_, tries) -> TimeInterval in +/// return tries > 2 ? 1000 : [1000, 5000, 10000][tries - 1] +/// } +/// +/// reconnectTimer.scheduleTimeout() // fires after 1000ms +/// reconnectTimer.scheduleTimeout() // fires after 5000ms +/// reconnectTimer.reset() +/// reconnectTimer.scheduleTimeout() // fires after 1000ms + +import Foundation + +// sourcery: AutoMockable +class TimeoutTimer { + /// Callback to be informed when the underlying Timer fires + var callback = Delegated() + + /// Provides TimeInterval to use when scheduling the timer + var timerCalculation = Delegated() + + /// The work to be done when the queue fires + var workItem: DispatchWorkItem? + + /// The number of times the underlyingTimer hass been set off. + var tries: Int = 0 + + /// The Queue to execute on. In testing, this is overridden + var queue: TimerQueue = .main + + /// Resets the Timer, clearing the number of tries and stops + /// any scheduled timeout. + func reset() { + tries = 0 + clearTimer() + } + + /// Schedules a timeout callback to fire after a calculated timeout duration. + func scheduleTimeout() { + // Clear any ongoing timer, not resetting the number of tries + clearTimer() + + // Get the next calculated interval, in milliseconds. Do not + // start the timer if the interval is returned as nil. + guard let timeInterval = timerCalculation.call(tries + 1) else { return } + + let workItem = DispatchWorkItem { + self.tries += 1 + self.callback.call() + } + + self.workItem = workItem + queue.queue(timeInterval: timeInterval, execute: workItem) + } + + /// Invalidates any ongoing Timer. Will not clear how many tries have been made + private func clearTimer() { + workItem?.cancel() + workItem = nil + } +} + +/// Wrapper class around a DispatchQueue. Allows for providing a fake clock +/// during tests. +class TimerQueue { + // Can be overriden in tests + static var main = TimerQueue() + + func queue(timeInterval: TimeInterval, execute: DispatchWorkItem) { + // TimeInterval is always in seconds. Multiply it by 1000 to convert + // to milliseconds and round to the nearest millisecond. + let dispatchInterval = Int(round(timeInterval * 1000)) + + let dispatchTime = DispatchTime.now() + .milliseconds(dispatchInterval) + DispatchQueue.main.asyncAfter(deadline: dispatchTime, execute: execute) + } +} diff --git a/Sources/Realtime/Transport.swift b/Sources/Realtime/Transport.swift new file mode 100644 index 00000000..92f73641 --- /dev/null +++ b/Sources/Realtime/Transport.swift @@ -0,0 +1,300 @@ +// Copyright (c) 2021 David Stump +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +import Foundation + +// ---------------------------------------------------------------------- + +// MARK: - Transport Protocol + +// ---------------------------------------------------------------------- +/// Defines a `Socket`'s Transport layer. +// sourcery: AutoMockable +public protocol Transport { + /// The current `ReadyState` of the `Transport` layer + var readyState: TransportReadyState { get } + + /// Delegate for the `Transport` layer + var delegate: TransportDelegate? { get set } + + /** + Connect to the server + */ + func connect() + + /** + Disconnect from the server. + + - Parameters: + - code: Status code as defined by Section 7.4 of RFC 6455. + - reason: Reason why the connection is closing. Optional. + */ + func disconnect(code: Int, reason: String?) + + /** + Sends a message to the server. + + - Parameter data: Data to send. + */ + func send(data: Data) +} + +// ---------------------------------------------------------------------- + +// MARK: - Transport Delegate Protocol + +// ---------------------------------------------------------------------- +/// Delegate to receive notifications of events that occur in the `Transport` layer +public protocol TransportDelegate { + /** + Notified when the `Transport` opens. + + - Parameter response: Response from the server indicating that the WebSocket handshake was successful and the connection has been upgraded to webSockets + */ + func onOpen(response: URLResponse?) + + /** + Notified when the `Transport` receives an error. + + - Parameter error: Client-side error from the underlying `Transport` implementation + - Parameter response: Response from the server, if any, that occurred with the Error + + */ + func onError(error: Error, response: URLResponse?) + + /** + Notified when the `Transport` receives a message from the server. + + - Parameter message: Message received from the server + */ + func onMessage(message: String) + + /** + Notified when the `Transport` closes. + + - Parameter code: Code that was sent when the `Transport` closed + - Parameter reason: A concise human-readable prose explanation for the closure + */ + func onClose(code: Int, reason: String?) +} + +// ---------------------------------------------------------------------- + +// MARK: - Transport Ready State Enum + +// ---------------------------------------------------------------------- +/// Available `ReadyState`s of a `Transport` layer. +public enum TransportReadyState { + /// The `Transport` is opening a connection to the server. + case connecting + + /// The `Transport` is connected to the server. + case open + + /// The `Transport` is closing the connection to the server. + case closing + + /// The `Transport` has disconnected from the server. + case closed +} + +// ---------------------------------------------------------------------- + +// MARK: - Default Websocket Transport Implementation + +// ---------------------------------------------------------------------- +/// A `Transport` implementation that relies on URLSession's native WebSocket +/// implementation. +/// +/// This implementation ships default with SwiftPhoenixClient however +/// SwiftPhoenixClient supports earlier OS versions using one of the submodule +/// `Transport` implementations. Or you can create your own implementation using +/// your own WebSocket library or implementation. +@available(macOS 10.15, iOS 13, watchOS 6, tvOS 13, *) +open class URLSessionTransport: NSObject, Transport, URLSessionWebSocketDelegate { + /// The URL to connect to + internal let url: URL + + /// The URLSession configuration + internal let configuration: URLSessionConfiguration + + /// The underling URLSession. Assigned during `connect()` + private var session: URLSession? = nil + + /// The ongoing task. Assigned during `connect()` + private var task: URLSessionWebSocketTask? = nil + + /** + Initializes a `Transport` layer built using URLSession's WebSocket + + Example: + + ```swift + let url = URL("wss://example.com/socket") + let transport: Transport = URLSessionTransport(url: url) + ``` + + Using a custom `URLSessionConfiguration` + + ```swift + let url = URL("wss://example.com/socket") + let configuration = URLSessionConfiguration.default + let transport: Transport = URLSessionTransport(url: url, configuration: configuration) + ``` + + - parameter url: URL to connect to + - parameter configuration: Provide your own URLSessionConfiguration. Uses `.default` if none provided + */ + public init(url: URL, configuration: URLSessionConfiguration = .default) { + // URLSession requires that the endpoint be "wss" instead of "https". + let endpoint = url.absoluteString + let wsEndpoint = + endpoint + .replacingOccurrences(of: "http://", with: "ws://") + .replacingOccurrences(of: "https://", with: "wss://") + + // Force unwrapping should be safe here since a valid URL came in and we just + // replaced the protocol. + self.url = URL(string: wsEndpoint)! + self.configuration = configuration + + super.init() + } + + // MARK: - Transport + + public var readyState: TransportReadyState = .closed + public var delegate: TransportDelegate? = nil + + open func connect() { + // Set the transport state as connecting + readyState = .connecting + + // Create the session and websocket task + session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) + task = session?.webSocketTask(with: url) + + // Start the task + task?.resume() + } + + open func disconnect(code: Int, reason: String?) { + /* + TODO: + 1. Provide a "strict" mode that fails if an invalid close code is given + 2. If strict mode is disabled, default to CloseCode.invalid + 3. Provide default .normalClosure function + */ + guard let closeCode = URLSessionWebSocketTask.CloseCode(rawValue: code) else { + fatalError("Could not create a CloseCode with invalid code: [\(code)].") + } + + readyState = .closing + task?.cancel(with: closeCode, reason: reason?.data(using: .utf8)) + session?.finishTasksAndInvalidate() + } + + open func send(data: Data) { + task?.send(.string(String(data: data, encoding: .utf8)!)) { _ in + // TODO: What is the behavior when an error occurs? + } + } + + // MARK: - URLSessionWebSocketDelegate + + open func urlSession( + _: URLSession, + webSocketTask: URLSessionWebSocketTask, + didOpenWithProtocol _: String? + ) { + // The Websocket is connected. Set Transport state to open and inform delegate + readyState = .open + delegate?.onOpen(response: webSocketTask.response) + + // Start receiving messages + receive() + } + + open func urlSession( + _: URLSession, + webSocketTask _: URLSessionWebSocketTask, + didCloseWith closeCode: URLSessionWebSocketTask.CloseCode, + reason: Data? + ) { + // A close frame was received from the server. + readyState = .closed + delegate?.onClose( + code: closeCode.rawValue, reason: reason.flatMap { String(data: $0, encoding: .utf8) } + ) + } + + open func urlSession( + _: URLSession, + task: URLSessionTask, + didCompleteWithError error: Error? + ) { + // The task has terminated. Inform the delegate that the transport has closed abnormally + // if this was caused by an error. + guard let err = error else { return } + + abnormalErrorReceived(err, response: task.response) + } + + // MARK: - Private + + private func receive() { + task?.receive { [weak self] result in + switch result { + case let .success(message): + switch message { + case .data: + print("Data received. This method is unsupported by the Client") + case let .string(text): + self?.delegate?.onMessage(message: text) + default: + fatalError("Unknown result was received. [\(result)]") + } + + // Since `.receive()` is only good for a single message, it must + // be called again after a message is received in order to + // received the next message. + self?.receive() + case let .failure(error): + print("Error when receiving \(error)") + self?.abnormalErrorReceived(error, response: nil) + } + } + } + + private func abnormalErrorReceived(_ error: Error, response: URLResponse?) { + // Set the state of the Transport to closed + readyState = .closed + + // Inform the Transport's delegate that an error occurred. + delegate?.onError(error: error, response: response) + + // An abnormal error is results in an abnormal closure, such as internet getting dropped + // so inform the delegate that the Transport has closed abnormally. This will kick off + // the reconnect logic. + delegate?.onClose( + code: RealtimeClient.CloseCode.abnormal.rawValue, reason: error.localizedDescription + ) + } +} diff --git a/Tests/RealtimeTests/ChannelTopicTests.swift b/Tests/RealtimeTests/ChannelTopicTests.swift new file mode 100644 index 00000000..5a21bfd2 --- /dev/null +++ b/Tests/RealtimeTests/ChannelTopicTests.swift @@ -0,0 +1,19 @@ +import XCTest + +@testable import Realtime + +final class ChannelTopicTests: XCTestCase { + func testRawValue() { + XCTAssertEqual(ChannelTopic.all, ChannelTopic(rawValue: "realtime:*")) + XCTAssertEqual(ChannelTopic.all, ChannelTopic(rawValue: "*")) + XCTAssertEqual(ChannelTopic.schema("public"), ChannelTopic(rawValue: "realtime:public")) + XCTAssertEqual( + ChannelTopic.table("users", schema: "public"), ChannelTopic(rawValue: "realtime:public:users") + ) + XCTAssertEqual( + ChannelTopic.column("email", value: "mail@supabase.io", table: "users", schema: "public"), + ChannelTopic(rawValue: "realtime:public:users:email=eq.mail@supabase.io") + ) + XCTAssertEqual(ChannelTopic.heartbeat, ChannelTopic(rawValue: "phoenix")) + } +} diff --git a/Tests/RealtimeTests/RealtimeTests.swift b/Tests/RealtimeTests/RealtimeTests.swift new file mode 100644 index 00000000..b8a6868d --- /dev/null +++ b/Tests/RealtimeTests/RealtimeTests.swift @@ -0,0 +1,128 @@ +import XCTest + +@testable import Realtime + +final class RealtimeTests: XCTestCase { + var supabaseUrl: String { + guard let url = ProcessInfo.processInfo.environment["supabaseUrl"] else { + XCTFail("supabaseUrl not defined in environment.") + return "" + } + + return url + } + + var supabaseKey: String { + guard let key = ProcessInfo.processInfo.environment["supabaseKey"] else { + XCTFail("supabaseKey not defined in environment.") + return "" + } + return key + } + + func testConnection() throws { + try XCTSkipIf( + ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, + "INTEGRATION_TESTS not defined" + ) + + let socket = RealtimeClient( + "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] + ) + + let e = expectation(description: "testConnection") + socket.onOpen { + XCTAssertEqual(socket.isConnected, true) + DispatchQueue.main.asyncAfter(deadline: .now() + 1) { + socket.disconnect() + } + } + + socket.onError { error, _ in + XCTFail(error.localizedDescription) + } + + socket.onClose { + XCTAssertEqual(socket.isConnected, false) + e.fulfill() + } + + socket.connect() + + waitForExpectations(timeout: 3000) { error in + if let error = error { + XCTFail("\(self.name)) failed: \(error.localizedDescription)") + } + } + } + + func testChannelCreation() throws { + try XCTSkipIf( + ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] == nil, + "INTEGRATION_TESTS not defined" + ) + + let client = RealtimeClient( + "\(supabaseUrl)/realtime/v1", params: ["apikey": supabaseKey] + ) + let allChanges = client.channel(.all) + allChanges.on(.all) { message in + print(message) + } + allChanges.join() + allChanges.leave() + allChanges.off(.all) + + let allPublicInsertChanges = client.channel(.schema("public")) + allPublicInsertChanges.on(.insert) { message in + print(message) + } + allPublicInsertChanges.join() + allPublicInsertChanges.leave() + allPublicInsertChanges.off(.insert) + + let allUsersUpdateChanges = client.channel(.table("users", schema: "public")) + allUsersUpdateChanges.on(.update) { message in + print(message) + } + allUsersUpdateChanges.join() + allUsersUpdateChanges.leave() + allUsersUpdateChanges.off(.update) + + let allUserId99Changes = client.channel( + .column("id", value: "99", table: "users", schema: "public")) + allUserId99Changes.on(.all) { message in + print(message) + } + allUserId99Changes.join() + allUserId99Changes.leave() + allUserId99Changes.off(.all) + + XCTAssertEqual(client.isConnected, false) + + let e = expectation(description: name) + client.onOpen { + XCTAssertEqual(client.isConnected, true) + DispatchQueue.main.asyncAfter(deadline: .now() + 1) { + client.disconnect() + } + } + + client.onError { error, _ in + XCTFail(error.localizedDescription) + } + + client.onClose { + XCTAssertEqual(client.isConnected, false) + e.fulfill() + } + + client.connect() + + waitForExpectations(timeout: 3000) { error in + if let error = error { + XCTFail("\(self.name)) failed: \(error.localizedDescription)") + } + } + } +} diff --git a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved index e2b4dc23..c0af399b 100644 --- a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -18,15 +18,6 @@ "version" : "3.0.1" } }, - { - "identity" : "realtime-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/realtime-swift.git", - "state" : { - "revision" : "0b985c687fe963f6bd818ff77a35c27247b98bb4", - "version" : "0.0.2" - } - }, { "identity" : "storage-swift", "kind" : "remoteSourceControl", From e5bd945d1cfde4381891c60497e54b3102746d50 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:29:02 -0300 Subject: [PATCH 5/7] Add storage to repo --- Package.resolved | 9 - Package.swift | 24 +- Sources/Storage/Bucket.swift | 28 ++ Sources/Storage/BucketOptions.swift | 13 + Sources/Storage/FileObject.swift | 37 +++ Sources/Storage/FileOptions.swift | 7 + Sources/Storage/MultipartFile.swift | 64 +++++ Sources/Storage/SearchOptions.swift | 20 ++ Sources/Storage/SortBy.swift | 9 + Sources/Storage/StorageApi.swift | 123 +++++++++ Sources/Storage/StorageBucketApi.swift | 117 ++++++++ Sources/Storage/StorageError.swift | 17 ++ Sources/Storage/StorageFileApi.swift | 249 ++++++++++++++++++ Sources/Storage/StorageHTTPClient.swift | 28 ++ Sources/Storage/SupabaseStorage.swift | 18 ++ Sources/Storage/TransformOptions.swift | 49 ++++ Sources/Supabase/SupabaseClient.swift | 2 +- Tests/StorageTests/SupabaseStorageTests.swift | 96 +++++++ .../xcshareddata/swiftpm/Package.resolved | 9 - 19 files changed, 881 insertions(+), 38 deletions(-) create mode 100644 Sources/Storage/Bucket.swift create mode 100644 Sources/Storage/BucketOptions.swift create mode 100644 Sources/Storage/FileObject.swift create mode 100644 Sources/Storage/FileOptions.swift create mode 100644 Sources/Storage/MultipartFile.swift create mode 100644 Sources/Storage/SearchOptions.swift create mode 100644 Sources/Storage/SortBy.swift create mode 100644 Sources/Storage/StorageApi.swift create mode 100644 Sources/Storage/StorageBucketApi.swift create mode 100644 Sources/Storage/StorageError.swift create mode 100644 Sources/Storage/StorageFileApi.swift create mode 100644 Sources/Storage/StorageHTTPClient.swift create mode 100644 Sources/Storage/SupabaseStorage.swift create mode 100644 Sources/Storage/TransformOptions.swift create mode 100644 Tests/StorageTests/SupabaseStorageTests.swift diff --git a/Package.resolved b/Package.resolved index 0e82514d..02c5f37d 100644 --- a/Package.resolved +++ b/Package.resolved @@ -18,15 +18,6 @@ "version" : "3.0.1" } }, - { - "identity" : "storage-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/storage-swift.git", - "state" : { - "branch" : "dependency-free", - "revision" : "62bf80cc46e22088ca390e506b1a712f4774a018" - } - }, { "identity" : "swift-snapshot-testing", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index afe27098..eba6e054 100644 --- a/Package.swift +++ b/Package.swift @@ -18,7 +18,8 @@ var package = Package( .library(name: "GoTrue", targets: ["GoTrue"]), .library(name: "PostgREST", targets: ["PostgREST"]), .library(name: "Realtime", targets: ["Realtime"]), - .library(name: "Supabase", targets: ["Supabase", "Functions", "PostgREST", "GoTrue", "Realtime"]), + .library(name: "Storage", targets: ["Storage"]), + .library(name: "Supabase", targets: ["Supabase", "Functions", "PostgREST", "GoTrue", "Realtime", "Storage"]), ], dependencies: [ .package(url: "https://github.com/kishikawakatsumi/KeychainAccess", from: "4.2.2"), @@ -59,11 +60,13 @@ var package = Package( .testTarget(name: "PostgRESTIntegrationTests", dependencies: ["PostgREST"]), .target(name: "Realtime"), .testTarget(name: "RealtimeTests", dependencies: ["Realtime"]), + .target(name: "Storage"), + .testTarget(name: "StorageTests", dependencies: ["Storage"]), .target( name: "Supabase", dependencies: [ "GoTrue", - .product(name: "SupabaseStorage", package: "storage-swift"), + "Storage", "Realtime", "PostgREST", "Functions", @@ -72,20 +75,3 @@ var package = Package( .testTarget(name: "SupabaseTests", dependencies: ["Supabase"]), ] ) - -if ProcessInfo.processInfo.environment["USE_LOCAL_PACKAGES"] != nil { - package.dependencies.append( - contentsOf: [ - .package(path: "../storage-swift"), - ] - ) -} else { - package.dependencies.append( - contentsOf: [ - .package( - url: "https://github.com/supabase-community/storage-swift.git", - branch: "dependency-free" - ), - ] - ) -} diff --git a/Sources/Storage/Bucket.swift b/Sources/Storage/Bucket.swift new file mode 100644 index 00000000..dbce527d --- /dev/null +++ b/Sources/Storage/Bucket.swift @@ -0,0 +1,28 @@ +public struct Bucket: Hashable { + public var id: String + public var name: String + public var owner: String + public var isPublic: Bool + public var createdAt: String + public var updatedAt: String + + init?(from dictionary: [String: Any]) { + guard + let id = dictionary["id"] as? String, + let name = dictionary["name"] as? String, + let owner = dictionary["owner"] as? String, + let createdAt = dictionary["created_at"] as? String, + let updatedAt = dictionary["updated_at"] as? String, + let isPublic = dictionary["public"] as? Bool + else { + return nil + } + + self.id = id + self.name = name + self.owner = owner + self.isPublic = isPublic + self.createdAt = createdAt + self.updatedAt = updatedAt + } +} diff --git a/Sources/Storage/BucketOptions.swift b/Sources/Storage/BucketOptions.swift new file mode 100644 index 00000000..db834c3e --- /dev/null +++ b/Sources/Storage/BucketOptions.swift @@ -0,0 +1,13 @@ +import Foundation + +public struct BucketOptions { + public let `public`: Bool + public let fileSizeLimit: Int? + public let allowedMimeTypes: [String]? + + public init(public: Bool = false, fileSizeLimit: Int? = nil, allowedMimeTypes: [String]? = nil) { + self.public = `public` + self.fileSizeLimit = fileSizeLimit + self.allowedMimeTypes = allowedMimeTypes + } +} diff --git a/Sources/Storage/FileObject.swift b/Sources/Storage/FileObject.swift new file mode 100644 index 00000000..055d5776 --- /dev/null +++ b/Sources/Storage/FileObject.swift @@ -0,0 +1,37 @@ +public struct FileObject { + public var name: String + public var bucketId: String? + public var owner: String? + public var id: String + public var updatedAt: String + public var createdAt: String + public var lastAccessedAt: String + public var metadata: [String: Any] + public var buckets: Bucket? + + public init?(from dictionary: [String: Any]) { + guard + let name = dictionary["name"] as? String, + let id = dictionary["id"] as? String, + let updatedAt = dictionary["updated_at"] as? String, + let createdAt = dictionary["created_at"] as? String, + let lastAccessedAt = dictionary["last_accessed_at"] as? String, + let metadata = dictionary["metadata"] as? [String: Any] + else { + return nil + } + + self.name = name + self.bucketId = dictionary["bucket_id"] as? String + self.owner = dictionary["owner"] as? String + self.id = id + self.updatedAt = updatedAt + self.createdAt = createdAt + self.lastAccessedAt = lastAccessedAt + self.metadata = metadata + + if let buckets = dictionary["buckets"] as? [String: Any] { + self.buckets = Bucket(from: buckets) + } + } +} diff --git a/Sources/Storage/FileOptions.swift b/Sources/Storage/FileOptions.swift new file mode 100644 index 00000000..bbc51b1e --- /dev/null +++ b/Sources/Storage/FileOptions.swift @@ -0,0 +1,7 @@ +public struct FileOptions { + public var cacheControl: String + + public init(cacheControl: String) { + self.cacheControl = cacheControl + } +} diff --git a/Sources/Storage/MultipartFile.swift b/Sources/Storage/MultipartFile.swift new file mode 100644 index 00000000..23d62535 --- /dev/null +++ b/Sources/Storage/MultipartFile.swift @@ -0,0 +1,64 @@ +import Foundation + +public struct File: Hashable, Equatable { + public var name: String + public var data: Data + public var fileName: String? + public var contentType: String? + + public init(name: String, data: Data, fileName: String?, contentType: String?) { + self.name = name + self.data = data + self.fileName = fileName + self.contentType = contentType + } +} + +public class FormData { + var files: [File] = [] + var boundary: String + + public init(boundary: String = UUID().uuidString) { + self.boundary = boundary + } + + public func append(file: File) { + files.append(file) + } + + public var contentType: String { + return "multipart/form-data; boundary=\(boundary)" + } + + public var data: Data { + var data = Data() + + for file in files { + data.append("--\(boundary)\r\n") + data.append("Content-Disposition: form-data; name=\"\(file.name)\"") + if let filename = file.fileName?.replacingOccurrences(of: "\"", with: "_") { + data.append("; filename=\"\(filename)\"") + } + data.append("\r\n") + if let contentType = file.contentType { + data.append("Content-Type: \(contentType)\r\n") + } + data.append("\r\n") + data.append(file.data) + data.append("\r\n") + } + + data.append("--\(boundary)--\r\n") + return data + } +} + +extension Data { + mutating func append(_ string: String) { + let data = string.data( + using: String.Encoding.utf8, + allowLossyConversion: true + ) + append(data!) + } +} diff --git a/Sources/Storage/SearchOptions.swift b/Sources/Storage/SearchOptions.swift new file mode 100644 index 00000000..f7cf9639 --- /dev/null +++ b/Sources/Storage/SearchOptions.swift @@ -0,0 +1,20 @@ +public struct SearchOptions { + /// The number of files you want to be returned. + public var limit: Int? + + /// The starting position. + public var offset: Int? + + /// The column to sort by. Can be any column inside a ``FileObject``. + public var sortBy: SortBy? + + /// The search string to filter files by. + public var search: String? + + public init(limit: Int? = nil, offset: Int? = nil, sortBy: SortBy? = nil, search: String? = nil) { + self.limit = limit + self.offset = offset + self.sortBy = sortBy + self.search = search + } +} diff --git a/Sources/Storage/SortBy.swift b/Sources/Storage/SortBy.swift new file mode 100644 index 00000000..7495c85c --- /dev/null +++ b/Sources/Storage/SortBy.swift @@ -0,0 +1,9 @@ +public struct SortBy { + public var column: String? + public var order: String? + + public init(column: String? = nil, order: String? = nil) { + self.column = column + self.order = order + } +} diff --git a/Sources/Storage/StorageApi.swift b/Sources/Storage/StorageApi.swift new file mode 100644 index 00000000..775e41bf --- /dev/null +++ b/Sources/Storage/StorageApi.swift @@ -0,0 +1,123 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +public class StorageApi { + var url: String + var headers: [String: String] + var session: StorageHTTPSession + + init(url: String, headers: [String: String], session: StorageHTTPSession) { + self.url = url + self.headers = headers + self.session = session + } + + internal enum HTTPMethod: String { + case get = "GET" + case head = "HEAD" + case post = "POST" + case put = "PUT" + case delete = "DELETE" + case connect = "CONNECT" + case options = "OPTIONS" + case trace = "TRACE" + case patch = "PATCH" + } + + @discardableResult + internal func fetch( + url: URL, + method: HTTPMethod = .get, + parameters: [String: Any]?, + headers: [String: String]? = nil + ) async throws -> Any { + var request = URLRequest(url: url) + request.httpMethod = method.rawValue + + if var headers = headers { + headers.merge(self.headers) { $1 } + request.allHTTPHeaderFields = headers + } else { + request.allHTTPHeaderFields = self.headers + } + + if let parameters = parameters { + request.httpBody = try JSONSerialization.data(withJSONObject: parameters, options: []) + } + + let (data, response) = try await session.fetch(request) + guard let httpResonse = response as? HTTPURLResponse else { + throw URLError(.badServerResponse) + } + + if let mimeType = httpResonse.mimeType { + switch mimeType { + case "application/json": + let json = try JSONSerialization.jsonObject(with: data, options: []) + return try parse(response: json, statusCode: httpResonse.statusCode) + default: + return try parse(response: data, statusCode: httpResonse.statusCode) + } + } else { + throw StorageError(message: "failed to get response") + } + } + + internal func fetch( + url: URL, + method: HTTPMethod = .post, + formData: FormData, + headers: [String: String]? = nil, + fileOptions: FileOptions? = nil, + jsonSerialization: Bool = true + ) async throws -> Any { + var request = URLRequest(url: url) + request.httpMethod = method.rawValue + + if let fileOptions = fileOptions { + request.setValue(fileOptions.cacheControl, forHTTPHeaderField: "cacheControl") + } + + var allHTTPHeaderFields = self.headers + if let headers = headers { + allHTTPHeaderFields.merge(headers) { $1 } + } + + allHTTPHeaderFields.forEach { key, value in + request.setValue(value, forHTTPHeaderField: key) + } + + request.setValue(formData.contentType, forHTTPHeaderField: "Content-Type") + + let (data, response) = try await session.upload(request, formData.data) + guard let httpResonse = response as? HTTPURLResponse else { + throw URLError(.badServerResponse) + } + + if jsonSerialization { + let json = try JSONSerialization.jsonObject(with: data, options: []) + return try parse(response: json, statusCode: httpResonse.statusCode) + } + + if let dataString = String(data: data, encoding: .utf8) { + return dataString + } + + throw StorageError(message: "failed to get response") + } + + private func parse(response: Any, statusCode: Int) throws -> Any { + if statusCode == 200 || 200..<300 ~= statusCode { + return response + } else if let dict = response as? [String: Any], let message = dict["message"] as? String { + throw StorageError(statusCode: statusCode, message: message) + } else if let dict = response as? [String: Any], let error = dict["error"] as? String { + throw StorageError(statusCode: statusCode, message: error) + } else { + throw StorageError(statusCode: statusCode, message: "something went wrong") + } + } +} diff --git a/Sources/Storage/StorageBucketApi.swift b/Sources/Storage/StorageBucketApi.swift new file mode 100644 index 00000000..46e27c8b --- /dev/null +++ b/Sources/Storage/StorageBucketApi.swift @@ -0,0 +1,117 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// Storage Bucket API +public class StorageBucketApi: StorageApi { + /// StorageBucketApi initializer + /// - Parameters: + /// - url: Storage HTTP URL + /// - headers: HTTP headers. + override init(url: String, headers: [String: String], session: StorageHTTPSession) { + super.init(url: url, headers: headers, session: session) + self.headers.merge(["Content-Type": "application/json"]) { $1 } + } + + /// Retrieves the details of all Storage buckets within an existing product. + public func listBuckets() async throws -> [Bucket] { + guard let url = URL(string: "\(url)/bucket") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch(url: url, method: .get, parameters: nil, headers: headers) + guard let dict = response as? [[String: Any]] else { + throw StorageError(message: "failed to parse response") + } + + return dict.compactMap { Bucket(from: $0) } + } + + /// Retrieves the details of an existing Storage bucket. + /// - Parameters: + /// - id: The unique identifier of the bucket you would like to retrieve. + public func getBucket(id: String) async throws -> Bucket { + guard let url = URL(string: "\(url)/bucket/\(id)") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch(url: url, method: .get, parameters: nil, headers: headers) + guard + let dict = response as? [String: Any], + let bucket = Bucket(from: dict) + else { + throw StorageError(message: "failed to parse response") + } + + return bucket + } + + /// Creates a new Storage bucket + /// - Parameters: + /// - id: A unique identifier for the bucket you are creating. + /// - completion: newly created bucket id + public func createBucket( + id: String, + options: BucketOptions = .init() + ) async throws -> [String: Any] { + guard let url = URL(string: "\(url)/bucket") else { + throw StorageError(message: "badURL") + } + + var params: [String: Any] = [ + "id": id, + "name": id, + ] + + params["public"] = options.public + params["file_size_limit"] = options.fileSizeLimit + params["allowed_mime_types"] = options.allowedMimeTypes + + let response = try await fetch( + url: url, + method: .post, + parameters: params, + headers: headers + ) + + guard let dict = response as? [String: Any] else { + throw StorageError(message: "failed to parse response") + } + + return dict + } + + /// Removes all objects inside a single bucket. + /// - Parameters: + /// - id: The unique identifier of the bucket you would like to empty. + @discardableResult + public func emptyBucket(id: String) async throws -> [String: Any] { + guard let url = URL(string: "\(url)/bucket/\(id)/empty") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch(url: url, method: .post, parameters: [:], headers: headers) + guard let dict = response as? [String: Any] else { + throw StorageError(message: "failed to parse response") + } + return dict + } + + /// Deletes an existing bucket. A bucket can't be deleted with existing objects inside it. + /// You must first `empty()` the bucket. + /// - Parameters: + /// - id: The unique identifier of the bucket you would like to delete. + public func deleteBucket(id: String) async throws -> [String: Any] { + guard let url = URL(string: "\(url)/bucket/\(id)") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch(url: url, method: .delete, parameters: [:], headers: headers) + guard let dict = response as? [String: Any] else { + throw StorageError(message: "failed to parse response") + } + return dict + } +} diff --git a/Sources/Storage/StorageError.swift b/Sources/Storage/StorageError.swift new file mode 100644 index 00000000..3da9e188 --- /dev/null +++ b/Sources/Storage/StorageError.swift @@ -0,0 +1,17 @@ +import Foundation + +public struct StorageError: Error { + public var statusCode: Int? + public var message: String? + + public init(statusCode: Int? = nil, message: String? = nil) { + self.statusCode = statusCode + self.message = message + } +} + +extension StorageError: LocalizedError { + public var errorDescription: String? { + return message + } +} diff --git a/Sources/Storage/StorageFileApi.swift b/Sources/Storage/StorageFileApi.swift new file mode 100644 index 00000000..1e09f831 --- /dev/null +++ b/Sources/Storage/StorageFileApi.swift @@ -0,0 +1,249 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +let DEFAULT_SEARCH_OPTIONS = SearchOptions( + limit: 100, + offset: 0, + sortBy: SortBy( + column: "name", + order: "asc" + ) +) + +/// Supabase Storage File API +public class StorageFileApi: StorageApi { + /// The bucket id to operate on. + var bucketId: String + + /// StorageFileApi initializer + /// - Parameters: + /// - url: Storage HTTP URL + /// - headers: HTTP headers. + /// - bucketId: The bucket id to operate on. + init(url: String, headers: [String: String], bucketId: String, session: StorageHTTPSession) { + self.bucketId = bucketId + super.init(url: url, headers: headers, session: session) + } + + /// Uploads a file to an existing bucket. + /// - Parameters: + /// - path: The relative file path. Should be of the format `folder/subfolder/filename.png`. The + /// bucket must already exist before attempting to upload. + /// - file: The File object to be stored in the bucket. + /// - fileOptions: HTTP headers. For example `cacheControl` + public func upload(path: String, file: File, fileOptions: FileOptions?) async throws -> Any { + guard let url = URL(string: "\(url)/object/\(bucketId)/\(path)") else { + throw StorageError(message: "badURL") + } + + let formData = FormData() + formData.append(file: file) + + return try await fetch( + url: url, + method: .post, + formData: formData, + headers: headers, + fileOptions: fileOptions + ) + } + + /// Replaces an existing file at the specified path with a new one. + /// - Parameters: + /// - path: The relative file path. Should be of the format `folder/subfolder`. The bucket + /// already exist before attempting to upload. + /// - file: The file object to be stored in the bucket. + /// - fileOptions: HTTP headers. For example `cacheControl` + public func update(path: String, file: File, fileOptions: FileOptions?) async throws -> Any { + guard let url = URL(string: "\(url)/object/\(bucketId)/\(path)") else { + throw StorageError(message: "badURL") + } + + let formData = FormData() + formData.append(file: file) + + return try await fetch( + url: url, + method: .put, + formData: formData, + headers: headers, + fileOptions: fileOptions + ) + } + + /// Moves an existing file, optionally renaming it at the same time. + /// - Parameters: + /// - fromPath: The original file path, including the current file name. For example + /// `folder/image.png`. + /// - toPath: The new file path, including the new file name. For example + /// `folder/image-copy.png`. + public func move(fromPath: String, toPath: String) async throws -> [String: Any] { + guard let url = URL(string: "\(url)/object/move") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch( + url: url, method: .post, + parameters: ["bucketId": bucketId, "sourceKey": fromPath, "destinationKey": toPath], + headers: headers + ) + + guard let dict = response as? [String: Any] else { + throw StorageError(message: "failed to parse response") + } + + return dict + } + + /// Create signed url to download file without requiring permissions. This URL can be valid for a + /// set number of seconds. + /// - Parameters: + /// - path: The file path to be downloaded, including the current file name. For example + /// `folder/image.png`. + /// - expiresIn: The number of seconds until the signed URL expires. For example, `60` for a URL + /// which is valid for one minute. + public func createSignedURL(path: String, expiresIn: Int) async throws -> URL { + guard let url = URL(string: "\(url)/object/sign/\(bucketId)/\(path)") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch( + url: url, + method: .post, + parameters: ["expiresIn": expiresIn], + headers: headers + ) + guard + let dict = response as? [String: Any], + let signedURLString = dict["signedURL"] as? String, + let signedURL = URL(string: self.url.appending(signedURLString)) + else { + throw StorageError(message: "failed to parse response") + } + return signedURL + } + + /// Deletes files within the same bucket + /// - Parameters: + /// - paths: An array of files to be deletes, including the path and file name. For example + /// [`folder/image.png`]. + public func remove(paths: [String]) async throws -> [FileObject] { + guard let url = URL(string: "\(url)/object/\(bucketId)") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch( + url: url, + method: .delete, + parameters: ["prefixes": paths], + headers: headers + ) + guard let array = response as? [[String: Any]] else { + throw StorageError(message: "failed to parse response") + } + + return array.compactMap { FileObject(from: $0) } + } + + /// Lists all the files within a bucket. + /// - Parameters: + /// - path: The folder path. + /// - options: Search options, including `limit`, `offset`, and `sortBy`. + public func list( + path: String? = nil, + options: SearchOptions? = nil + ) async throws -> [FileObject] { + guard let url = URL(string: "\(url)/object/list/\(bucketId)") else { + throw StorageError(message: "badURL") + } + + var parameters: [String: Any] = ["prefix": path ?? ""] + parameters["limit"] = options?.limit ?? DEFAULT_SEARCH_OPTIONS.limit + parameters["offset"] = options?.offset ?? DEFAULT_SEARCH_OPTIONS.offset + parameters["search"] = options?.search ?? DEFAULT_SEARCH_OPTIONS.search + + if let sortBy = options?.sortBy ?? DEFAULT_SEARCH_OPTIONS.sortBy { + parameters["sortBy"] = [ + "column": sortBy.column, + "order": sortBy.order, + ] + } + + let response = try await fetch( + url: url, method: .post, parameters: parameters, headers: headers) + + guard let array = response as? [[String: Any]] else { + throw StorageError(message: "failed to parse response") + } + + return array.compactMap { FileObject(from: $0) } + } + + /// Downloads a file. + /// - Parameters: + /// - path: The file path to be downloaded, including the path and file name. For example + /// `folder/image.png`. + @discardableResult + public func download(path: String) async throws -> Data { + guard let url = URL(string: "\(url)/object/\(bucketId)/\(path)") else { + throw StorageError(message: "badURL") + } + + let response = try await fetch(url: url, parameters: nil) + guard let data = response as? Data else { + throw StorageError(message: "failed to parse response") + } + return data + } + + /// Returns a public url for an asset. + /// - Parameters: + /// - path: The file path to the asset. For example `folder/image.png`. + /// - download: Whether the asset should be downloaded. + /// - fileName: If specified, the file name for the asset that is downloaded. + /// - options: Transform the asset before retrieving it on the client. + public func getPublicURL( + path: String, + download: Bool = false, + fileName: String = "", + options: TransformOptions? = nil + ) throws -> URL { + var queryItems: [URLQueryItem] = [] + + guard var components = URLComponents(string: url) else { + throw StorageError(message: "badURL") + } + + if download { + queryItems.append(URLQueryItem(name: "download", value: fileName)) + } + + if let optionsQueryItems = options?.queryItems { + queryItems.append(contentsOf: optionsQueryItems) + } + + let renderPath = options != nil ? "render/image" : "object" + + components.path += "/\(renderPath)/public/\(bucketId)/\(path)" + components.queryItems = !queryItems.isEmpty ? queryItems : nil + + guard let generatedUrl = components.url else { + throw StorageError(message: "badUrl") + } + + return generatedUrl + } + + @available(*, deprecated, renamed: "getPublicURL") + public func getPublicUrl( + path: String, + download: Bool = false, + fileName: String = "", + options: TransformOptions? = nil + ) throws -> URL { + try getPublicURL(path: path, download: download, fileName: fileName, options: options) + } +} diff --git a/Sources/Storage/StorageHTTPClient.swift b/Sources/Storage/StorageHTTPClient.swift new file mode 100644 index 00000000..286e313b --- /dev/null +++ b/Sources/Storage/StorageHTTPClient.swift @@ -0,0 +1,28 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +public struct StorageHTTPSession: Sendable { + public let fetch: @Sendable (_ request: URLRequest) async throws -> (Data, URLResponse) + public let upload: + @Sendable (_ request: URLRequest, _ data: Data) async throws -> (Data, URLResponse) + + public init( + fetch: @escaping @Sendable (_ request: URLRequest) async throws -> (Data, URLResponse), + upload: @escaping @Sendable (_ request: URLRequest, _ data: Data) async throws -> ( + Data, URLResponse + ) + ) { + self.fetch = fetch + self.upload = upload + } + + public init() { + self.init( + fetch: { try await URLSession.shared.data(for: $0) }, + upload: { try await URLSession.shared.upload(for: $0, from: $1) } + ) + } +} diff --git a/Sources/Storage/SupabaseStorage.swift b/Sources/Storage/SupabaseStorage.swift new file mode 100644 index 00000000..c9947e22 --- /dev/null +++ b/Sources/Storage/SupabaseStorage.swift @@ -0,0 +1,18 @@ +public class SupabaseStorageClient: StorageBucketApi { + /// Storage Client initializer + /// - Parameters: + /// - url: Storage HTTP URL + /// - headers: HTTP headers. + override public init( + url: String, headers: [String: String], session: StorageHTTPSession = .init() + ) { + super.init(url: url, headers: headers, session: session) + } + + /// Perform file operation in a bucket. + /// - Parameter id: The bucket id to operate on. + /// - Returns: StorageFileApi object + public func from(id: String) -> StorageFileApi { + StorageFileApi(url: url, headers: headers, bucketId: id, session: session) + } +} diff --git a/Sources/Storage/TransformOptions.swift b/Sources/Storage/TransformOptions.swift new file mode 100644 index 00000000..d1dcee1e --- /dev/null +++ b/Sources/Storage/TransformOptions.swift @@ -0,0 +1,49 @@ +import Foundation + +public struct TransformOptions { + public var width: Int? + public var height: Int? + public var resize: String? + public var quality: Int? + public var format: String? + + public init( + width: Int? = nil, + height: Int? = nil, + resize: String? = "cover", + quality: Int? = 80, + format: String? = "origin" + ) { + self.width = width + self.height = height + self.resize = resize + self.quality = quality + self.format = format + } + + var queryItems: [URLQueryItem] { + var items = [URLQueryItem]() + + if let width = width { + items.append(URLQueryItem(name: "width", value: String(width))) + } + + if let height = height { + items.append(URLQueryItem(name: "height", value: String(height))) + } + + if let resize = resize { + items.append(URLQueryItem(name: "resize", value: resize)) + } + + if let quality = quality { + items.append(URLQueryItem(name: "quality", value: String(quality))) + } + + if let format = format { + items.append(URLQueryItem(name: "format", value: format)) + } + + return items + } +} diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index bbf80e24..472df21e 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -3,7 +3,7 @@ import Foundation @_exported import GoTrue @_exported import PostgREST @_exported import Realtime -@_exported import SupabaseStorage +@_exported import Storage /// Supabase Client. public class SupabaseClient { diff --git a/Tests/StorageTests/SupabaseStorageTests.swift b/Tests/StorageTests/SupabaseStorageTests.swift new file mode 100644 index 00000000..f74a888a --- /dev/null +++ b/Tests/StorageTests/SupabaseStorageTests.swift @@ -0,0 +1,96 @@ +import Foundation +import XCTest + +@testable import SupabaseStorage + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +final class SupabaseStorageTests: XCTestCase { + static var apiKey: String { + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6ImFub24iLCJleHAiOjE5ODM4MTI5OTZ9.CRXP1A7WOeoJeXxjNni43kdQwgnWNReilDMblYTn_I0" + } + + static var supabaseURL: String { + "http://localhost:54321/storage/v1" + } + + let bucket = "public" + + let storage = SupabaseStorageClient( + url: supabaseURL, + headers: [ + "Authorization": "Bearer \(apiKey)", + "apikey": apiKey, + ] + ) + + let uploadData = try! Data( + contentsOf: URL( + string: "https://raw.githubusercontent.com/supabase-community/storage-swift/main/README.md" + )! + ) + + override func setUp() async throws { + try await super.setUp() + _ = try? await storage.emptyBucket(id: bucket) + _ = try? await storage.deleteBucket(id: bucket) + + _ = try await storage.createBucket(id: bucket, options: BucketOptions(public: true)) + } + + func testListBuckets() async throws { + let buckets = try await storage.listBuckets() + XCTAssertEqual(buckets.map(\.name), [bucket]) + } + + func testFileIntegration() async throws { + try await uploadTestData() + + let files = try await storage.from(id: bucket).list() + XCTAssertEqual(files.map(\.name), ["README.md"]) + + let downloadedData = try await storage.from(id: bucket).download(path: "README.md") + XCTAssertEqual(downloadedData, uploadData) + + let removedFiles = try await storage.from(id: bucket).remove(paths: ["README.md"]) + XCTAssertEqual(removedFiles.map(\.name), ["README.md"]) + } + + func testGetPublicURL() async throws { + try await uploadTestData() + + let path = "README.md" + + let baseUrl = try storage.from(id: bucket).getPublicURL(path: path) + XCTAssertEqual(baseUrl.absoluteString, "\(Self.supabaseURL)/object/public/\(bucket)/\(path)") + + let baseUrlWithDownload = try storage.from(id: bucket).getPublicURL(path: path, download: true) + XCTAssertEqual( + baseUrlWithDownload.absoluteString, + "\(Self.supabaseURL)/object/public/\(bucket)/\(path)?download=") + + let baseUrlWithDownloadAndFileName = try storage.from(id: bucket).getPublicURL( + path: path, download: true, fileName: "test") + XCTAssertEqual( + baseUrlWithDownloadAndFileName.absoluteString, + "\(Self.supabaseURL)/object/public/\(bucket)/\(path)?download=test") + + let baseUrlWithAllOptions = try storage.from(id: bucket).getPublicURL( + path: path, download: true, fileName: "test", + options: TransformOptions(width: 300, height: 300)) + XCTAssertEqual( + baseUrlWithAllOptions.absoluteString, + "\(Self.supabaseURL)/render/image/public/\(bucket)/\(path)?download=test&width=300&height=300&resize=cover&quality=80&format=origin" + ) + } + + private func uploadTestData() async throws { + let file = File( + name: "README.md", data: uploadData, fileName: "README.md", contentType: "text/html") + _ = try await storage.from(id: bucket).upload( + path: "README.md", file: file, fileOptions: FileOptions(cacheControl: "3600") + ) + } +} diff --git a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved index c0af399b..74f5b279 100644 --- a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -18,15 +18,6 @@ "version" : "3.0.1" } }, - { - "identity" : "storage-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/supabase-community/storage-swift.git", - "state" : { - "branch" : "dependency-free", - "revision" : "62bf80cc46e22088ca390e506b1a712f4774a018" - } - }, { "identity" : "swift-case-paths", "kind" : "remoteSourceControl", From 7235d7ec635d92c6fbf9f9dfee1922b08573d598 Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:35:12 -0300 Subject: [PATCH 6/7] Update dependencies on release --- .../xcshareddata/xcschemes/Supabase.xcscheme | 2 +- Examples/Examples.xcodeproj/project.pbxproj | 26 ++++++++++++------- Examples/Examples/ExamplesApp.swift | 14 ---------- Examples/Examples/RootView.swift | 2 +- Package.swift | 4 ++- Sources/Functions/Types.swift | 2 +- Sources/GoTrue/GoTrueClient.swift | 11 -------- Sources/Supabase/SupabaseClient.swift | 10 ++----- .../xcshareddata/swiftpm/Package.resolved | 24 ++++++++--------- 9 files changed, 36 insertions(+), 59 deletions(-) diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/Supabase.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/Supabase.xcscheme index 2b4e3972..ab7cba6b 100644 --- a/.swiftpm/xcode/xcshareddata/xcschemes/Supabase.xcscheme +++ b/.swiftpm/xcode/xcshareddata/xcschemes/Supabase.xcscheme @@ -1,6 +1,6 @@ - /// Returns the session, refreshing it if necessary. public var session: Session { get async throws { @@ -96,15 +94,6 @@ public final class GoTrueClient { } } - /// Initialize the client session from storage. - /// - /// This method should be called on the app startup, for making sure that the client is fully - /// initialized - /// before proceeding. - // public func initialize() async { - // await initializationTask.value - // } - /// Creates a new user. /// - Parameters: /// - email: User's email address. diff --git a/Sources/Supabase/SupabaseClient.swift b/Sources/Supabase/SupabaseClient.swift index 472df21e..2bfd3aca 100644 --- a/Sources/Supabase/SupabaseClient.swift +++ b/Sources/Supabase/SupabaseClient.swift @@ -32,10 +32,7 @@ public class SupabaseClient { SupabaseStorageClient( url: storageURL.absoluteString, headers: defaultHeaders, - session: StorageHTTPSession( - fetch: fetch, - upload: upload - ) + session: StorageHTTPSession(fetch: fetch, upload: upload) ) } @@ -51,10 +48,7 @@ public class SupabaseClient { /// Realtime client for Supabase public var realtime: RealtimeClient { - RealtimeClient( - endPoint: realtimeURL.absoluteString, - params: defaultHeaders - ) + RealtimeClient(realtimeURL.absoluteString, params: defaultHeaders) } /// Supabase Functions allows you to deploy and invoke edge functions. diff --git a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved index 74f5b279..47655550 100644 --- a/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/supabase-swift.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -23,8 +23,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/pointfreeco/swift-case-paths", "state" : { - "revision" : "fc45e7b2cfece9dd80b5a45e6469ffe67fe67984", - "version" : "0.14.1" + "revision" : "5da6989aae464f324eef5c5b52bdb7974725ab81", + "version" : "1.0.0" } }, { @@ -32,8 +32,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-collections", "state" : { - "revision" : "937e904258d22af6e447a0b72c0bc67583ef64a2", - "version" : "1.0.4" + "revision" : "a902f1823a7ff3c9ab2fba0f992396b948eda307", + "version" : "1.0.5" } }, { @@ -41,8 +41,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/pointfreeco/swift-custom-dump", "state" : { - "revision" : "3a35f7892e7cf6ba28a78cd46a703c0be4e0c6dc", - "version" : "0.11.0" + "revision" : "3efbfba0e4e56c7187cc19137ee16b7c95346b79", + "version" : "1.1.0" } }, { @@ -50,8 +50,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/pointfreeco/swift-identified-collections.git", "state" : { - "revision" : "d01446a78fb768adc9a78cbb6df07767c8ccfc29", - "version" : "0.8.0" + "revision" : "d1e45f3e1eee2c9193f5369fa9d70a6ddad635e8", + "version" : "1.0.0" } }, { @@ -77,8 +77,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/pointfreeco/swiftui-navigation.git", "state" : { - "revision" : "2aa885e719087ee19df251c08a5980ad3e787f12", - "version" : "0.8.0" + "revision" : "6eb293c49505d86e9e24232cb6af6be7fff93bd5", + "version" : "1.0.2" } }, { @@ -86,8 +86,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/pointfreeco/xctest-dynamic-overlay", "state" : { - "revision" : "4af50b38daf0037cfbab15514a241224c3f62f98", - "version" : "0.8.5" + "revision" : "23cbf2294e350076ea4dbd7d5d047c1e76b03631", + "version" : "1.0.2" } } ], From 0718138105a9b9bc543efbd9b9d24a63e4c5e74f Mon Sep 17 00:00:00 2001 From: Guilherme Souza Date: Tue, 17 Oct 2023 17:48:30 -0300 Subject: [PATCH 7/7] Add Test Plan to run all tests --- .../xcschemes/Supabase-Package.xcscheme | 71 +++++++++++++++++++ Makefile | 4 +- Sources/Functions/FunctionsClient.swift | 4 +- Supabase.xctestplan | 66 +++++++++++++++++ Tests/StorageTests/SupabaseStorageTests.swift | 8 ++- Tests/SupabaseTests/SupabaseClientTests.swift | 3 +- .../contents.xcworkspacedata | 3 + 7 files changed, 152 insertions(+), 7 deletions(-) create mode 100644 .swiftpm/xcode/xcshareddata/xcschemes/Supabase-Package.xcscheme create mode 100644 Supabase.xctestplan diff --git a/.swiftpm/xcode/xcshareddata/xcschemes/Supabase-Package.xcscheme b/.swiftpm/xcode/xcshareddata/xcschemes/Supabase-Package.xcscheme new file mode 100644 index 00000000..19580b48 --- /dev/null +++ b/.swiftpm/xcode/xcshareddata/xcschemes/Supabase-Package.xcscheme @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Makefile b/Makefile index 722107f2..94da342c 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,9 @@ -PLATFORM_IOS = iOS Simulator,name=iPhone 14 Pro Max +PLATFORM_IOS = iOS Simulator,name=iPhone 15 Pro Max test-library: xcodebuild test \ -workspace supabase-swift.xcworkspace \ - -scheme Supabase \ + -scheme Supabase-Package \ -destination platform="$(PLATFORM_IOS)" || exit 1; build-example: diff --git a/Sources/Functions/FunctionsClient.swift b/Sources/Functions/FunctionsClient.swift index 962eb127..df1708e6 100644 --- a/Sources/Functions/FunctionsClient.swift +++ b/Sources/Functions/FunctionsClient.swift @@ -2,7 +2,7 @@ import Foundation /// An actor representing a client for invoking functions. public actor FunctionsClient { - /// Typealias for the fetch handler used to make requests. + /// Fetch handler used to make requests. public typealias FetchHandler = @Sendable (_ request: URLRequest) async throws -> ( Data, URLResponse ) @@ -90,7 +90,7 @@ public actor FunctionsClient { ) async throws -> (Data, HTTPURLResponse) { let url = self.url.appendingPathComponent(functionName) var urlRequest = URLRequest(url: url) - urlRequest.allHTTPHeaderFields = invokeOptions.headers.merging(headers) { first, _ in first } + urlRequest.allHTTPHeaderFields = invokeOptions.headers.merging(headers) { invoke, _ in invoke } urlRequest.httpMethod = (invokeOptions.method ?? .post).rawValue urlRequest.httpBody = invokeOptions.body diff --git a/Supabase.xctestplan b/Supabase.xctestplan new file mode 100644 index 00000000..5077254d --- /dev/null +++ b/Supabase.xctestplan @@ -0,0 +1,66 @@ +{ + "configurations" : [ + { + "id" : "8C117D21-39DB-4469-B9E3-213E030294F9", + "name" : "Test Scheme Action", + "options" : { + + } + } + ], + "defaultOptions" : { + "codeCoverage" : true + }, + "testTargets" : [ + { + "target" : { + "containerPath" : "container:", + "identifier" : "FunctionsTests", + "name" : "FunctionsTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "GoTrueTests", + "name" : "GoTrueTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PostgRESTIntegrationTests", + "name" : "PostgRESTIntegrationTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "PostgRESTTests", + "name" : "PostgRESTTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "RealtimeTests", + "name" : "RealtimeTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "StorageTests", + "name" : "StorageTests" + } + }, + { + "target" : { + "containerPath" : "container:", + "identifier" : "SupabaseTests", + "name" : "SupabaseTests" + } + } + ], + "version" : 1 +} diff --git a/Tests/StorageTests/SupabaseStorageTests.swift b/Tests/StorageTests/SupabaseStorageTests.swift index f74a888a..987e2c58 100644 --- a/Tests/StorageTests/SupabaseStorageTests.swift +++ b/Tests/StorageTests/SupabaseStorageTests.swift @@ -1,7 +1,7 @@ import Foundation import XCTest -@testable import SupabaseStorage +@testable import Storage #if canImport(FoundationNetworking) import FoundationNetworking @@ -34,6 +34,12 @@ final class SupabaseStorageTests: XCTestCase { override func setUp() async throws { try await super.setUp() + + try XCTSkipUnless( + ProcessInfo.processInfo.environment["INTEGRATION_TESTS"] != nil, + "INTEGRATION_TESTS not defined." + ) + _ = try? await storage.emptyBucket(id: bucket) _ = try? await storage.deleteBucket(id: bucket) diff --git a/Tests/SupabaseTests/SupabaseClientTests.swift b/Tests/SupabaseTests/SupabaseClientTests.swift index 0fc2c334..7469a2a2 100644 --- a/Tests/SupabaseTests/SupabaseClientTests.swift +++ b/Tests/SupabaseTests/SupabaseClientTests.swift @@ -18,7 +18,6 @@ final class SupabaseClientTests: XCTestCase { let customSchema = "custom_schema" let localStorage = GoTrueLocalStorageMock() let customHeaders = ["header_field": "header_value"] - let httpClient = SupabaseClient.HTTPClient(storage: nil) let client = SupabaseClient( supabaseURL: URL(string: "https://project-ref.supabase.co")!, @@ -28,7 +27,7 @@ final class SupabaseClientTests: XCTestCase { auth: SupabaseClientOptions.AuthOptions(storage: localStorage), global: SupabaseClientOptions.GlobalOptions( headers: customHeaders, - httpClient: httpClient + session: .shared ) ) ) diff --git a/supabase-swift.xcworkspace/contents.xcworkspacedata b/supabase-swift.xcworkspace/contents.xcworkspacedata index d6fe7ad5..89219a04 100644 --- a/supabase-swift.xcworkspace/contents.xcworkspacedata +++ b/supabase-swift.xcworkspace/contents.xcworkspacedata @@ -7,4 +7,7 @@ + +