Skip to content

Commit

Permalink
fix(auth): extract both query and fragment from URL (#365)
Browse files Browse the repository at this point in the history
* fix(auth): extract both query and fragment from URL

* test: query takes precedence
  • Loading branch information
grdsdev committed May 7, 2024
1 parent f1e17ee commit e9c7c8c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 32 deletions.
28 changes: 12 additions & 16 deletions Sources/Auth/AuthClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -595,32 +595,30 @@ public final class AuthClient: Sendable {
let params = extractParams(from: url)

if isPKCEFlow(url: url) {
guard let code = params.first(where: { $0.name == "code" })?.value else {
guard let code = params["code"] else {
throw AuthError.pkce(.codeVerifierNotFound)
}

let session = try await exchangeCodeForSession(authCode: code)
return session
}

if let errorDescription = params.first(where: { $0.name == "error_description" })?.value {
if let errorDescription = params["error_description"] {
throw AuthError.api(.init(errorDescription: errorDescription))
}

guard
let accessToken = params.first(where: { $0.name == "access_token" })?.value,
let expiresIn = params.first(where: { $0.name == "expires_in" }).map(\.value)
.flatMap(TimeInterval.init),
let refreshToken = params.first(where: { $0.name == "refresh_token" })?.value,
let tokenType = params.first(where: { $0.name == "token_type" })?.value
let accessToken = params["access_token"],
let expiresIn = params["expires_in"].flatMap(TimeInterval.init),
let refreshToken = params["refresh_token"],
let tokenType = params["token_type"]
else {
throw URLError(.badURL)
}

let expiresAt = params.first(where: { $0.name == "expires_at" }).map(\.value)
.flatMap(TimeInterval.init)
let providerToken = params.first(where: { $0.name == "provider_token" })?.value
let providerRefreshToken = params.first(where: { $0.name == "provider_refresh_token" })?.value
let expiresAt = params["expires_at"].flatMap(TimeInterval.init)
let providerToken = params["provider_token"]
let providerRefreshToken = params["provider_refresh_token"]

let user = try await api.execute(
.init(
Expand All @@ -644,7 +642,7 @@ public final class AuthClient: Sendable {
try await sessionManager.update(session)
eventEmitter.emit(.signedIn, session: session)

if let type = params.first(where: { $0.name == "type" })?.value, type == "recovery" {
if let type = params["type"], type == "recovery" {
eventEmitter.emit(.passwordRecovery, session: session)
}

Expand Down Expand Up @@ -1060,15 +1058,13 @@ public final class AuthClient: Sendable {

private func isImplicitGrantFlow(url: URL) -> Bool {
let fragments = extractParams(from: url)
return fragments.contains {
$0.name == "access_token" || $0.name == "error_description"
}
return fragments["access_token"] != nil || fragments["error_description"] != nil
}

private func isPKCEFlow(url: URL) -> Bool {
let fragments = extractParams(from: url)
let currentCodeVerifier = codeVerifierStorage.get()
return fragments.contains(where: { $0.name == "code" }) && currentCodeVerifier != nil
return fragments["code"] != nil && currentCodeVerifier != nil
}

private func getURLForProvider(
Expand Down
29 changes: 15 additions & 14 deletions Sources/Auth/Internal/Helpers.swift
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import Foundation

struct Params: Hashable {
var name: String
var value: String
}

func extractParams(from url: URL) -> [Params] {
/// Extracts parameters encoded in the URL both in the query and fragment.
func extractParams(from url: URL) -> [String: String] {
guard let components = URLComponents(url: url, resolvingAgainstBaseURL: false) else {
return []
return [:]
}

var result: [String: String] = [:]

if let fragment = components.fragment {
return extractParams(from: fragment)
let items = extractParams(from: fragment)
for item in items {
result[item.name] = item.value
}
}

if let queryItems = components.queryItems {
return queryItems.map {
Params(name: $0.name, value: $0.value ?? "")
if let items = components.queryItems {
for item in items {
result[item.name] = item.value
}
}

return []
return result
}

func extractParams(from fragment: String) -> [Params] {
private func extractParams(from fragment: String) -> [URLQueryItem] {
let components =
fragment
.split(separator: "&")
Expand All @@ -33,7 +34,7 @@ func extractParams(from fragment: String) -> [Params] {
components
.compactMap {
$0.count == 2
? Params(name: String($0[0]), value: String($0[1]))
? URLQueryItem(name: String($0[0]), value: String($0[1]))
: nil
}
}
Expand Down
17 changes: 15 additions & 2 deletions Tests/AuthTests/ExtractParamsTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,26 @@ final class ExtractParamsTests: XCTestCase {
let code = UUID().uuidString
let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=\(code)")!
let params = extractParams(from: url)
XCTAssertEqual(params, [Params(name: "code", value: code)])
XCTAssertEqual(params, ["code": code])
}

func testExtractParamsInFragment() {
let code = UUID().uuidString
let url = URL(string: "io.supabase.flutterquickstart://login-callback/#code=\(code)")!
let params = extractParams(from: url)
XCTAssertEqual(params, [Params(name: "code", value: code)])
XCTAssertEqual(params, ["code": code])
}

func testExtractParamsInBothFragmentAndQuery() {
let code = UUID().uuidString
let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=\(code)#message=abc")!
let params = extractParams(from: url)
XCTAssertEqual(params, ["code": code, "message": "abc"])
}

func testExtractParamsQueryTakesPrecedence() {
let url = URL(string: "io.supabase.flutterquickstart://login-callback/?code=123#code=abc")!
let params = extractParams(from: url)
XCTAssertEqual(params, ["code": "123"])
}
}

0 comments on commit e9c7c8c

Please sign in to comment.