Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions packages/gotrue/lib/gotrue.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ export 'src/constants.dart'
hide Constants, GenerateLinkTypeExtended, AuthChangeEventExtended;
export 'src/gotrue_admin_api.dart';
export 'src/gotrue_client.dart';
export 'src/helper.dart' show decodeJwt, validateExp;
export 'src/types/auth_exception.dart';
export 'src/types/auth_response.dart' hide ToSnakeCase;
export 'src/types/auth_state.dart';
export 'src/types/gotrue_async_storage.dart';
export 'src/types/jwt.dart';
export 'src/types/mfa.dart';
export 'src/types/types.dart';
export 'src/types/session.dart';
Expand Down
14 changes: 14 additions & 0 deletions packages/gotrue/lib/src/base64url.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import 'dart:convert';

class Base64Url {
/// Decodes a base64url string to a UTF-8 string
static String decodeToString(String input) {
final normalized = base64Url.normalize(input);
return utf8.decode(base64Url.decode(normalized));
}

static List<int> decodeToBytes(String input) {
final normalized = base64Url.normalize(input);
return base64Url.decode(normalized);
}
}
3 changes: 3 additions & 0 deletions packages/gotrue/lib/src/constants.dart
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class Constants {

/// The name of the header that contains API version.
static const apiVersionHeaderName = 'x-supabase-api-version';

/// The TTL for the JWKS cache.
static const jwksTtl = Duration(minutes: 10);
}

class ApiVersions {
Expand Down
117 changes: 117 additions & 0 deletions packages/gotrue/lib/src/gotrue_client.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import 'dart:async';
import 'dart:convert';
import 'dart:math';
import 'dart:typed_data';

import 'package:collection/collection.dart';
import 'package:gotrue/gotrue.dart';
Expand All @@ -13,6 +14,7 @@ import 'package:http/http.dart';
import 'package:jwt_decode/jwt_decode.dart';
import 'package:logging/logging.dart';
import 'package:meta/meta.dart';
import 'package:pointycastle/export.dart';
import 'package:retry/retry.dart';
import 'package:rxdart/subjects.dart';

Expand Down Expand Up @@ -58,6 +60,9 @@ class GoTrueClient {
/// Completer to combine multiple simultaneous token refresh requests.
Completer<AuthResponse>? _refreshTokenCompleter;

JWKSet? _jwks;
DateTime? _jwksCachedAt;

final _onAuthStateChangeController = BehaviorSubject<AuthState>();
final _onAuthStateChangeControllerSync =
BehaviorSubject<AuthState>(sync: true);
Expand Down Expand Up @@ -1336,4 +1341,116 @@ class GoTrueClient {
);
return exception;
}

Future<JWK?> _fetchJwk(String kid, JWKSet suppliedJwks) async {
// try fetching from the supplied jwks
final jwk = suppliedJwks.keys.firstWhereOrNull((jwk) => jwk.kid == kid);
if (jwk != null) {
return jwk;
}

final now = DateTime.now();

// try fetching from cache
final cachedJwk = _jwks?.keys.firstWhereOrNull((jwk) => jwk.kid == kid);

// jwks exists and it isn't stale
if (cachedJwk != null &&
_jwksCachedAt != null &&
_jwksCachedAt!.add(Constants.jwksTtl).isAfter(now)) {
return cachedJwk;
}

// jwk isn't cached in memory so we need to fetch it from the well-known endpoint
final jwksResponse = await _fetch.request(
'$_url/.well-known/jwks.json',
RequestMethodType.get,
options: GotrueRequestOptions(headers: _headers),
);

final jwks = JWKSet.fromJson(jwksResponse as Map<String, dynamic>);

if (jwks.keys.isEmpty) {
return null;
}

_jwks = jwks;
_jwksCachedAt = now;

// find the signing key
return jwks.keys.firstWhereOrNull((jwk) => jwk.kid == kid);
}

/// Extracts the JWT claims present in the access token by first verifying the
/// JWT against the server's JSON Web Key Set endpoint
/// `/.well-known/jwks.json` which is often cached, resulting in significantly
/// faster responses. Prefer this method over [getUser] which always
/// sends a request to the Auth server for each JWT.
///
/// If the project is not using an asymmetric JWT signing key (like ECC or
/// RSA) it always sends a request to the Auth server (similar to [getUser]) to verify the JWT.
/// [jwt] An optional specific JWT you wish to verify, not the one you
/// can obtain from [currentSession].
/// [options] Various additional options that allow you to customize the
/// behavior of this method.
///
/// Returns a [GetClaimsResponse] containing the JWT claims, or throws an [AuthException] on error.
Future<GetClaimsResponse> getClaims([
String? jwt,
GetClaimsOptions? options,
]) async {
String token = jwt ?? '';

if (token.isEmpty) {
final session = currentSession;
if (session == null) {
throw AuthSessionMissingException('No session found');
}
token = session.accessToken;
}

// Decode the JWT to get the payload
final decoded = decodeJwt(token);

// Validate expiration unless allowExpired is true
if (!(options?.allowExpired ?? false)) {
validateExp(decoded.payload.exp);
}

final signingKey =
(decoded.header.alg.startsWith('HS') || decoded.header.kid == null)
? null
: await _fetchJwk(decoded.header.kid!, _jwks!);

// If symmetric algorithm, fallback to getUser()
if (signingKey == null) {
await getUser(token);
return GetClaimsResponse(
claims: decoded.payload,
header: decoded.header,
signature: decoded.signature);
}

final publicKey = RSAPublicKey(signingKey['n'], signingKey['e']);
final signer = RSASigner(SHA256Digest(), '0609608648016503040201'); // PKCS1

// initialize with false, which means verify
signer.init(false, PublicKeyParameter<RSAPublicKey>(publicKey));

final signature = RSASignature(Uint8List.fromList(decoded.signature));
final isValidSignature = signer.verifySignature(
Uint8List.fromList(
utf8.encode('${decoded.raw.header}.${decoded.raw.payload}')),
signature,
);

if (!isValidSignature) {
throw AuthInvalidJwtException('Invalid JWT signature');
}

return GetClaimsResponse(
claims: decoded.payload,
header: decoded.header,
signature: decoded.signature);
}
}
60 changes: 60 additions & 0 deletions packages/gotrue/lib/src/helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ import 'dart:convert';
import 'dart:math';

import 'package:crypto/crypto.dart';
import 'package:gotrue/src/base64url.dart';
import 'package:gotrue/src/types/auth_exception.dart';
import 'package:gotrue/src/types/jwt.dart';

/// Converts base 10 int into String representation of base 16 int and takes the last two digets.
String dec2hex(int dec) {
Expand Down Expand Up @@ -30,3 +33,60 @@ void validateUuid(String id) {
throw ArgumentError('Invalid id: $id, must be a valid UUID');
}
}

/// Decodes a JWT token without performing validation
///
/// Returns a [DecodedJwt] containing the header, payload, signature, and raw parts.
/// Throws [AuthInvalidJwtException] if the JWT structure is invalid.
DecodedJwt decodeJwt(String token) {
final parts = token.split('.');
if (parts.length != 3) {
throw AuthInvalidJwtException('Invalid JWT structure');
}

final rawHeader = parts[0];
final rawPayload = parts[1];
final rawSignature = parts[2];

try {
// Decode header
final headerJson = Base64Url.decodeToString(rawHeader);
final header = JwtHeader.fromJson(json.decode(headerJson));

// Decode payload
final payloadJson = Base64Url.decodeToString(rawPayload);
final payload = JwtPayload.fromJson(json.decode(payloadJson));

// Decode signature
final signature = Base64Url.decodeToBytes(rawSignature);

return DecodedJwt(
header: header,
payload: payload,
signature: signature,
raw: JwtRawParts(
header: rawHeader,
payload: rawPayload,
signature: rawSignature,
),
);
} catch (e) {
if (e is AuthInvalidJwtException) {
rethrow;
}
throw AuthInvalidJwtException('Failed to decode JWT: $e');
}
}

/// Validates the expiration time of a JWT
///
/// Throws [AuthException] if the exp claim is missing or the JWT has expired.
void validateExp(int? exp) {
if (exp == null) {
throw AuthException('Missing exp claim');
}
final timeNow = DateTime.now().millisecondsSinceEpoch / 1000;
if (exp <= timeNow) {
throw AuthException('JWT has expired');
}
}
12 changes: 12 additions & 0 deletions packages/gotrue/lib/src/types/auth_exception.dart
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ class AuthWeakPasswordException extends AuthException {
String toString() =>
'AuthWeakPasswordException(message: $message, statusCode: $statusCode, reasons: $reasons)';
}

class AuthInvalidJwtException extends AuthException {
AuthInvalidJwtException(super.message)
: super(
statusCode: '400',
code: 'invalid_jwt',
);

@override
String toString() =>
'AuthInvalidJwtException(message: $message, statusCode: $statusCode, code: $code)';
}
Loading
Loading