From 2d047dc36930f219ce0a7d1e1cfcb1948df2e3f7 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Thu, 7 Aug 2025 16:04:44 +0200 Subject: [PATCH 01/15] feat: implement OAuth2 authorization endpoint --- go.mod | 2 +- internal/api/api.go | 8 + internal/api/apierrors/errorcode.go | 3 + internal/api/context.go | 13 +- internal/api/oauthserver/authorize.go | 484 ++++++++++++++++++ internal/api/oauthserver/handlers.go | 2 +- internal/api/shared/context.go | 36 ++ internal/conf/configuration.go | 6 +- internal/models/flow_state.go | 22 +- internal/models/oauth_authorization.go | 279 ++++++++++ internal/models/oauth_authorization_test.go | 331 ++++++++++++ internal/models/oauth_consent.go | 189 +++++++ internal/models/oauth_consent_test.go | 105 ++++ internal/models/oauth_scope.go | 47 ++ internal/models/oauth_scope_test.go | 199 +++++++ internal/models/pkce.go | 32 ++ internal/models/pkce_test.go | 123 +++++ ...0_add_oauth_authorizations_consents.up.sql | 86 ++++ 18 files changed, 1933 insertions(+), 34 deletions(-) create mode 100644 internal/api/oauthserver/authorize.go create mode 100644 internal/api/shared/context.go create mode 100644 internal/models/oauth_authorization.go create mode 100644 internal/models/oauth_authorization_test.go create mode 100644 internal/models/oauth_consent.go create mode 100644 internal/models/oauth_consent_test.go create mode 100644 internal/models/oauth_scope.go create mode 100644 internal/models/oauth_scope_test.go create mode 100644 internal/models/pkce.go create mode 100644 internal/models/pkce_test.go create mode 100644 migrations/20250804100000_add_oauth_authorizations_consents.up.sql diff --git a/go.mod b/go.mod index 6c06d4b84..e3db84713 100644 --- a/go.mod +++ b/go.mod @@ -173,7 +173,7 @@ require ( golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.12.0 golang.org/x/sys v0.31.0 - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.23.0 golang.org/x/time v0.9.0 google.golang.org/appengine v1.6.8 // indirect google.golang.org/grpc v1.63.2 // indirect diff --git a/internal/api/api.go b/internal/api/api.go index bdf11265f..488714ead 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -185,6 +185,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/settings", api.Settings) + // `/authorize` to initiate OAuth2 authorization flow with the external providers + // where Supabase Auth is an OAuth2 Client r.Get("/authorize", api.ExternalProviderRedirect) r.With(api.requireAdminCredentials).Post("/invite", api.Invite) @@ -345,6 +347,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Route("/oauth", func(r *router) { r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)). Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister) + + // OAuth 2.1 Authorization endpoints + // `/authorize` to initiate OAuth2 authorization code flow where Supabase Auth is the OAuth2 provider + r.Get("/authorize", api.oauthServer.OAuthServerAuthorize) + r.With(api.requireAuthentication).Get("/authorizations/{authorization_id}", api.oauthServer.OAuthServerGetAuthorization) + r.With(api.requireAuthentication).Post("/authorizations/{authorization_id}/consent", api.oauthServer.OAuthServerConsent) }) }) diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index 710797548..7658c1225 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -97,4 +97,7 @@ const ( ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain" ErrorCodeOAuthDynamicClientRegistrationDisabled ErrorCode = "oauth_dynamic_client_registration_disabled" ErrorCodeEmailAddressNotProvided ErrorCode = "email_address_not_provided" + + ErrorCodeOAuthClientNotFound ErrorCode = "oauth_client_not_found" + ErrorCodeOAuthAuthorizationNotFound ErrorCode = "oauth_authorization_not_found" ) diff --git a/internal/api/context.go b/internal/api/context.go index 5f0f744c4..e1d285cb1 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -5,6 +5,7 @@ import ( "net/url" jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" ) @@ -21,7 +22,6 @@ const ( tokenKey = contextKey("jwt") inviteTokenKey = contextKey("invite_token") signatureKey = contextKey("signature") - userKey = contextKey("user") targetUserKey = contextKey("target_user") factorKey = contextKey("factor") sessionKey = contextKey("session") @@ -60,7 +60,7 @@ func getClaims(ctx context.Context) *AccessTokenClaims { // withUser adds the user to the context. func withUser(ctx context.Context, u *models.User) context.Context { - return context.WithValue(ctx, userKey, u) + return shared.WithUser(ctx, u) } // withTargetUser adds the target user for linking to the context. @@ -75,14 +75,7 @@ func withFactor(ctx context.Context, f *models.Factor) context.Context { // getUser reads the user from the context. func getUser(ctx context.Context) *models.User { - if ctx == nil { - return nil - } - obj := ctx.Value(userKey) - if obj == nil { - return nil - } - return obj.(*models.User) + return shared.GetUser(ctx) } // getTargetUser reads the user from the context. diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go new file mode 100644 index 000000000..15d849391 --- /dev/null +++ b/internal/api/oauthserver/authorize.go @@ -0,0 +1,484 @@ +package oauthserver + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/shared" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" +) + +// AuthorizeParams represents the parameters for an OAuth authorization request +type AuthorizeParams struct { + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + ResponseType string `json:"response_type"` + Scope string `json:"scope"` + State string `json:"state"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +// AuthorizationDetailsResponse represents the response for getting authorization details +type AuthorizationDetailsResponse struct { + AuthorizationID string `json:"authorization_id"` + Client ClientDetailsResponse `json:"client"` + User UserDetailsResponse `json:"user"` + Scope string `json:"scope"` +} + +// ClientDetailsResponse represents client details in authorization response +type ClientDetailsResponse struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name"` + ClientURI string `json:"client_uri"` + LogoURI string `json:"logo_uri"` +} + +// UserDetailsResponse represents user details in authorization response +type UserDetailsResponse struct { + ID string `json:"id"` + Email string `json:"email"` +} + +// ConsentRequest represents a consent decision request +type ConsentRequest struct { + Action OAuthServerConsentAction `json:"action"` +} + +// ConsentResponse represents the response after processing consent +type ConsentResponse struct { + RedirectURL string `json:"redirect_url"` +} + +type OAuthServerConsentAction string + +const ( + OAuthServerConsentActionApprove OAuthServerConsentAction = "approve" + OAuthServerConsentActionDeny OAuthServerConsentAction = "deny" +) + +// OAuthServerAuthorize handles GET /oauth/authorize +func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + config := s.config + + // Validate OAuth2 authorize parameters + params, err := s.validateAuthorizeParams(r) + if err != nil { + return err + } + + // Validate client exists and redirect_uri matches + client, err := s.getOAuthServerClient(ctx, params.ClientID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id") + } + return apierrors.NewInternalServerError("error validating client").WithInternalError(err) + } + + if !s.isValidRedirectURI(client, params.RedirectURI) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri") + } + + // Store authorization request in database (without user initially) + authorization := models.NewOAuthServerAuthorization( + params.ClientID, + params.RedirectURI, + params.Scope, + params.State, + params.CodeChallenge, + params.CodeChallengeMethod, + ) + + if err := models.CreateOAuthServerAuthorization(db, authorization); err != nil { + return apierrors.NewInternalServerError("error creating authorization").WithInternalError(err) + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "client_id", authorization.ClientID) + + // Redirect to configured authorization URL with authorization_id + // TODO(cemal): should this be really a different config or something set on the config.SiteURL? + if config.OAuthServer.AuthorizationURL == "" { + return apierrors.NewInternalServerError("oauth authorization URL not configured") + } + + redirectURL := fmt.Sprintf("%s?authorization_id=%s", + config.OAuthServer.AuthorizationURL, + authorization.AuthorizationID) + + http.Redirect(w, r, redirectURL, http.StatusFound) + return nil +} + +// OAuthServerGetAuthorization handles GET /oauth/authorizations/{authorization_id} +func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + // Get authenticated user + user := shared.GetUser(ctx) + if user == nil { + return apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "authentication required") + } + + authorizationID := chi.URLParam(r, "authorization_id") + authorization, err := s.validateAndFindAuthorization(r, db, authorizationID) + if err != nil { + return err + } + + // Set user_id if not already set + if authorization.UserID == nil { + // Use transaction to atomically set user and check for auto-approve + var shouldAutoApprove bool + var existingConsent *models.OAuthServerConsent + + err := db.Transaction(func(tx *storage.Connection) error { + if err := authorization.SetUser(tx, user.ID); err != nil { + return err + } + + // Check for existing consent and auto-approve if available + var err error + existingConsent, err = models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, authorization.ClientID) + if err != nil { + return err + } + + // Check if consent covers requested scopes + if existingConsent != nil && s.consentCoversScopes(existingConsent, authorization.Scope) { + shouldAutoApprove = true + } + + return nil + }) + + if err != nil { + return apierrors.NewInternalServerError("error setting user and checking consent").WithInternalError(err) + } + + // If we should auto-approve, do it now + if shouldAutoApprove { + return s.autoApproveAndRedirect(w, r, authorization) + } + } else { + // Authorization already has user_id set, validate ownership + if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil { + return err + } + } + + // Build response with client and user details + response := AuthorizationDetailsResponse{ + AuthorizationID: authorization.AuthorizationID, + Client: ClientDetailsResponse{ + ClientID: authorization.Client.ClientID, + ClientName: authorization.Client.ClientName.String(), + ClientURI: authorization.Client.ClientURI.String(), + LogoURI: authorization.Client.LogoURI.String(), + }, + User: UserDetailsResponse{ + ID: user.ID.String(), + Email: user.Email.String(), + }, + Scope: authorization.Scope, + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "client_id", authorization.ClientID) + + return shared.SendJSON(w, http.StatusOK, response) +} + +// OAuthServerConsent handles POST /oauth/authorizations/{authorization_id}/consent +func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + // Get authenticated user + user := shared.GetUser(ctx) + if user == nil { + return apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "authentication required") + } + + var body ConsentRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "invalid JSON body") + } + + if body.Action != OAuthServerConsentActionApprove && body.Action != OAuthServerConsentActionDeny { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "action must be 'approve' or 'deny'") + } + + // Validate and find authorization outside transaction first + authorizationID := chi.URLParam(r, "authorization_id") + observability.LogEntrySetField(r, "authorization_id", authorizationID) + authorization, err := s.validateAndFindAuthorization(r, db, authorizationID) + if err != nil { + return err + } + + // Ensure authorization belongs to authenticated user + if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil { + return err + } + + // Process consent in transaction + var redirectURL string + err = db.Transaction(func(tx *storage.Connection) error { + // Re-fetch in transaction to ensure consistency + authorization, err := models.FindOAuthServerAuthorizationByID(tx, authorizationID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + return apierrors.NewInternalServerError("error finding authorization").WithInternalError(err) + } + + // Re-check expiration and status in transaction (state could have changed) + if authorization.IsExpired() { + if err := authorization.MarkExpired(tx); err != nil { + observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired") + } + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + + if authorization.Status != models.OAuthServerAuthorizationPending { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request is no longer pending") + } + + if body.Action == OAuthServerConsentActionApprove { + // Approve authorization + if err := authorization.Approve(tx); err != nil { + return apierrors.NewInternalServerError("error approving authorization").WithInternalError(err) + } + + // Store consent for future use + scopes := authorization.GetScopeList() + consent := models.NewOAuthServerConsent(user.ID, authorization.ClientID, scopes) + if err := models.UpsertOAuthServerConsent(tx, consent); err != nil { + return apierrors.NewInternalServerError("error storing consent").WithInternalError(err) + } + + // Build success redirect URL + redirectURL = s.buildSuccessRedirectURL(authorization) + + observability.LogEntrySetField(r, "oauth_consent_action", string(OAuthServerConsentActionApprove)) + + } else { + // Deny authorization + if err := authorization.Deny(tx); err != nil { + return apierrors.NewInternalServerError("error denying authorization").WithInternalError(err) + } + + // Build error redirect URL + // Errors are being returned to the client in the redirect url per OAuth2 spec + redirectURL = s.buildErrorRedirectURL(authorization, "access_denied", "User denied the request") + + observability.LogEntrySetField(r, "oauth_consent_action", string(OAuthServerConsentActionDeny)) + } + + return nil + }) + + if err != nil { + return err + } + + // Return redirect URL to frontend + response := ConsentResponse{ + RedirectURL: redirectURL, + } + + return shared.SendJSON(w, http.StatusOK, response) +} + +// Helper functions + +// validateAndFindAuthorization validates the authorization_id parameter and finds the authorization, +// performing all necessary checks (existence, expiration, status) +func (s *Server) validateAndFindAuthorization(r *http.Request, db *storage.Connection, authorizationID string) (*models.OAuthServerAuthorization, error) { + if authorizationID == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required") + } + + authorization, err := models.FindOAuthServerAuthorizationByID(db, authorizationID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + return nil, apierrors.NewInternalServerError("error finding authorization").WithInternalError(err) + } + + // Check if expired first - no point processing expired authorizations + if authorization.IsExpired() { + // Mark as expired in database + if err := authorization.MarkExpired(db); err != nil { + observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired") + } + // returning not found to avoid leaking information about the existence of the authorization + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + + // Check if still pending + if authorization.Status != models.OAuthServerAuthorizationPending { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed") + } + + return authorization, nil +} + +// validateAuthorizationOwnership checks if the authorization belongs to the authenticated user +func (s *Server) validateAuthorizationOwnership(r *http.Request, authorization *models.OAuthServerAuthorization, user *models.User) error { + if authorization.UserID == nil || *authorization.UserID != user.ID { + observability.GetLogEntry(r).Entry. + WithField("request_user_id", user.ID). + WithField("authorization_id", authorization.AuthorizationID). + Warn("authorization belongs to different user") + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + return nil +} + +func (s *Server) validateAuthorizeParams(r *http.Request) (*AuthorizeParams, error) { + query := r.URL.Query() + + params := &AuthorizeParams{ + ClientID: query.Get("client_id"), + RedirectURI: query.Get("redirect_uri"), + ResponseType: query.Get("response_type"), + Scope: query.Get("scope"), + State: query.Get("state"), + CodeChallenge: query.Get("code_challenge"), + CodeChallengeMethod: query.Get("code_challenge_method"), + } + + // Required parameters + if params.ClientID == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required") + } + if params.RedirectURI == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uri is required") + } + + // Default values + if params.ResponseType == "" { + params.ResponseType = models.OAuthServerResponseTypeCode.String() + } + if params.Scope == "" { + params.Scope = s.config.OAuthServer.DefaultScope + } + + // OAuth 2.1 only supports "code" response type + if params.ResponseType != models.OAuthServerResponseTypeCode.String() { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only response_type=code is supported") + } + + // PKCE validation + if err := s.validatePKCEParams(params.CodeChallengeMethod, params.CodeChallenge); err != nil { + return nil, err + } + + return params, nil +} + +func (s *Server) validatePKCEParams(codeChallengeMethod, codeChallenge string) error { + // PKCE is mandatory for the authorization code flow OAuth2.1 + // Both code_challenge and code_challenge_method must be provided together + if codeChallenge == "" || codeChallengeMethod == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "PKCE flow requires both code_challenge and code_challenge_method") + } + + // Validate code challenge method (case-insensitive) + if strings.ToLower(codeChallengeMethod) != "s256" && strings.ToLower(codeChallengeMethod) != "plain" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, + "code_challenge_method must be 'S256' or 'plain'") + } + + // Validate code challenge format and length (per OAuth2 spec) + if len(codeChallenge) < 43 || len(codeChallenge) > 128 { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, + "code_challenge must be between 43 and 128 characters") + } + + return nil +} + +func (s *Server) isValidRedirectURI(client *models.OAuthServerClient, redirectURI string) bool { + registeredURIs := client.GetRedirectURIs() + for _, registeredURI := range registeredURIs { + // exact string matching per OAuth2 spec + if registeredURI == redirectURI { + return true + } + } + return false +} + +func (s *Server) consentCoversScopes(consent *models.OAuthServerConsent, requestedScope string) bool { + if consent.IsRevoked() { + return false + } + + requestedScopes := models.ParseScopeString(requestedScope) + return consent.HasAllScopes(requestedScopes) +} + +func (s *Server) autoApproveAndRedirect(w http.ResponseWriter, r *http.Request, authorization *models.OAuthServerAuthorization) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + // Approve the authorization in a transaction + err := db.Transaction(func(tx *storage.Connection) error { + return authorization.Approve(tx) + }) + + if err != nil { + return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err) + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "auto_approved", true) + + // Return JSON with redirect URL (same format as consent endpoint) + redirectURL := s.buildSuccessRedirectURL(authorization) + response := ConsentResponse{ + RedirectURL: redirectURL, + } + + return shared.SendJSON(w, http.StatusOK, response) +} + +func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthorization) string { + u, _ := url.Parse(authorization.RedirectURI) + q := u.Query() + q.Set("code", authorization.AuthorizationCode.String()) + if authorization.State.String() != "" { + q.Set("state", authorization.State.String()) + } + u.RawQuery = q.Encode() + return u.String() +} + +func (s *Server) buildErrorRedirectURL(authorization *models.OAuthServerAuthorization, errorCode, errorDescription string) string { + u, _ := url.Parse(authorization.RedirectURI) + q := u.Query() + q.Set("error", errorCode) + q.Set("error_description", errorDescription) + if authorization.State.String() != "" { + q.Set("state", authorization.State.String()) + } + u.RawQuery = q.Encode() + return u.String() +} diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 27093e80e..fa827240f 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -80,7 +80,7 @@ func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) ( client, err := s.getOAuthServerClient(ctx, clientID) if err != nil { if models.IsNotFoundError(err) { - return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "OAuth client not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthClientNotFound, "OAuth client not found") } return nil, apierrors.NewInternalServerError("Error loading OAuth client").WithInternalError(err) } diff --git a/internal/api/shared/context.go b/internal/api/shared/context.go new file mode 100644 index 000000000..10fcfdb23 --- /dev/null +++ b/internal/api/shared/context.go @@ -0,0 +1,36 @@ +package shared + +import ( + "context" + + "github.com/supabase/auth/internal/models" +) + +// ContextKey is the type for context keys to avoid collisions +type ContextKey string + +func (c ContextKey) String() string { + return "gotrue api context key " + string(c) +} + +// Context keys used across packages +const ( + UserKey ContextKey = "user" +) + +// GetUser reads the user from the context - shared implementation +func GetUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(UserKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + +// WithUser adds the user to the context - shared implementation +func WithUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, UserKey, u) +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 02f5aeaec..5a22cb109 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -73,7 +73,11 @@ type OAuthProviderConfiguration struct { // OAuthServerConfiguration holds OAuth server configuration type OAuthServerConfiguration struct { - AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` + AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` + AuthorizationURL string `json:"authorization_url" split_words:"true"` + AuthorizationTimeout time.Duration `json:"authorization_timeout" split_words:"true" default:"5m"` + // Placeholder for now, for (near) future extensibility + DefaultScope string `json:"default_scope" split_words:"true" default:"email"` } type AnonymousProviderConfiguration struct { diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 9a770d81d..a484f4a74 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -1,10 +1,7 @@ package models import ( - "crypto/sha256" - "crypto/subtle" "database/sql" - "encoding/base64" "fmt" "strings" "time" @@ -15,8 +12,6 @@ import ( "github.com/gofrs/uuid" ) -const InvalidCodeChallengeError = "code challenge does not match previously saved code verifier" -const InvalidCodeMethodError = "code challenge method not supported" type FlowState struct { ID uuid.UUID `json:"id" db:"id"` @@ -134,22 +129,7 @@ func FindFlowStateByUserID(tx *storage.Connection, id string, authenticationMeth } func (f *FlowState) VerifyPKCE(codeVerifier string) error { - switch f.CodeChallengeMethod { - case SHA256.String(): - hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) - encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) - if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(encodedCodeVerifier)) != 1 { - return errors.New(InvalidCodeChallengeError) - } - case Plain.String(): - if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(codeVerifier)) != 1 { - return errors.New(InvalidCodeChallengeError) - } - default: - return errors.New(InvalidCodeMethodError) - - } - return nil + return VerifyPKCEChallenge(f.CodeChallenge, f.CodeChallengeMethod, codeVerifier) } func (f *FlowState) IsExpired(expiryDuration time.Duration) bool { diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go new file mode 100644 index 000000000..9cedae35d --- /dev/null +++ b/internal/models/oauth_authorization.go @@ -0,0 +1,279 @@ +package models + +import ( + "database/sql" + "fmt" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/storage" +) + +// OAuthServerAuthorizationStatus represents the status of an OAuth server authorization request +type OAuthServerAuthorizationStatus string + +const ( + OAuthServerAuthorizationPending OAuthServerAuthorizationStatus = "pending" + OAuthServerAuthorizationApproved OAuthServerAuthorizationStatus = "approved" + OAuthServerAuthorizationDenied OAuthServerAuthorizationStatus = "denied" + OAuthServerAuthorizationExpired OAuthServerAuthorizationStatus = "expired" +) + +func (s OAuthServerAuthorizationStatus) String() string { + return string(s) +} + +// OAuthServerResponseType represents the OAuth server response type +type OAuthServerResponseType string + +const ( + OAuthServerResponseTypeCode OAuthServerResponseType = "code" +) + +func (rt OAuthServerResponseType) String() string { + return string(rt) +} + +// OAuthServerAuthorization represents an OAuth 2.1 server authorization request +type OAuthServerAuthorization struct { + ID uuid.UUID `json:"-" db:"id"` + AuthorizationID string `json:"authorization_id" db:"authorization_id"` + ClientID string `json:"client_id" db:"client_id"` + UserID *uuid.UUID `json:"user_id" db:"user_id"` + RedirectURI string `json:"redirect_uri" db:"redirect_uri"` + Scope string `json:"scope" db:"scope"` + State storage.NullString `json:"state" db:"state"` + CodeChallenge storage.NullString `json:"code_challenge" db:"code_challenge"` + CodeChallengeMethod storage.NullString `json:"code_challenge_method" db:"code_challenge_method"` + ResponseType OAuthServerResponseType `json:"response_type" db:"response_type"` + Status OAuthServerAuthorizationStatus `json:"status" db:"status"` + AuthorizationCode storage.NullString `json:"-" db:"authorization_code"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + ExpiresAt time.Time `json:"expires_at" db:"expires_at"` + ApprovedAt *time.Time `json:"approved_at" db:"approved_at"` + + // Relations with OAuth clients + Client *OAuthServerClient `json:"client,omitempty" db:"-"` +} + +// TableName returns the table name for the OAuthServerAuthorization model +func (OAuthServerAuthorization) TableName() string { + return "oauth_authorizations" +} + +// NewOAuthServerAuthorization creates a new OAuth server authorization request without user (for initial flow) +func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChallenge, codeChallengeMethod string) *OAuthServerAuthorization { + id := uuid.Must(uuid.NewV4()) + authorizationID := crypto.SecureAlphanumeric(32) // Generate random ID for frontend + + now := time.Now() + expiresAt := now.Add(10 * time.Minute) // 10 minute expiration + + auth := &OAuthServerAuthorization{ + ID: id, + AuthorizationID: authorizationID, + ClientID: clientID, + UserID: nil, // No user yet + RedirectURI: redirectURI, + Scope: scope, + ResponseType: OAuthServerResponseTypeCode, + Status: OAuthServerAuthorizationPending, + CreatedAt: now, + ExpiresAt: expiresAt, + } + + if state != "" { + auth.State = storage.NullString(state) + } + if codeChallenge != "" { + auth.CodeChallenge = storage.NullString(codeChallenge) + } + if codeChallengeMethod != "" { + auth.CodeChallengeMethod = storage.NullString(codeChallengeMethod) + } + + return auth +} + +// IsExpired checks if the authorization request has expired +func (auth *OAuthServerAuthorization) IsExpired() bool { + return time.Now().After(auth.ExpiresAt) +} + +// SetUser sets the user ID for the authorization request (after login) +func (auth *OAuthServerAuthorization) SetUser(tx *storage.Connection, userID uuid.UUID) error { + auth.UserID = &userID + return tx.UpdateOnly(auth, "user_id") +} + +// GetScopeList returns the scopes as a slice +func (auth *OAuthServerAuthorization) GetScopeList() []string { + return ParseScopeString(auth.Scope) +} + +// GenerateAuthorizationCode generates a new authorization code if not already set +func (auth *OAuthServerAuthorization) GenerateAuthorizationCode() string { + if auth.AuthorizationCode.String() != "" { + return auth.AuthorizationCode.String() + } + + code := uuid.Must(uuid.NewV4()).String() + auth.AuthorizationCode = storage.NullString(code) + return code +} + +// Approve approves the authorization request and generates an authorization code +func (auth *OAuthServerAuthorization) Approve(tx *storage.Connection) error { + if auth.IsExpired() { + return fmt.Errorf("authorization request has expired") + } + + if auth.Status != OAuthServerAuthorizationPending { + return fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + } + + now := time.Now() + auth.Status = OAuthServerAuthorizationApproved + auth.ApprovedAt = &now + auth.GenerateAuthorizationCode() + + return tx.UpdateOnly(auth, "status", "approved_at", "authorization_code") +} + +// Deny denies the authorization request +func (auth *OAuthServerAuthorization) Deny(tx *storage.Connection) error { + if auth.Status != OAuthServerAuthorizationPending { + return fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + } + + auth.Status = OAuthServerAuthorizationDenied + return tx.UpdateOnly(auth, "status") +} + +// MarkExpired marks the authorization request as expired +func (auth *OAuthServerAuthorization) MarkExpired(tx *storage.Connection) error { + if auth.Status != OAuthServerAuthorizationPending { + return fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + } + + auth.Status = OAuthServerAuthorizationExpired + return tx.UpdateOnly(auth, "status") +} + +// Validate performs basic validation on the OAuth authorization +func (auth *OAuthServerAuthorization) Validate() error { + if auth.ClientID == "" { + return fmt.Errorf("client_id is required") + } + // UserID can be nil initially for unauthenticated authorization requests + // It will be set when user authenticates + if auth.RedirectURI == "" { + return fmt.Errorf("redirect_uri is required") + } + if auth.Scope == "" { + return fmt.Errorf("scope is required") + } + if auth.ResponseType != OAuthServerResponseTypeCode { + return fmt.Errorf("only response_type=code is supported") + } + if auth.ExpiresAt.Before(auth.CreatedAt) { + return fmt.Errorf("expires_at must be after created_at") + } + + return nil +} + +// VerifyPKCE verifies the PKCE code verifier against the stored challenge +func (auth *OAuthServerAuthorization) VerifyPKCE(codeVerifier string) error { + if auth.CodeChallenge.String() == "" { + // No PKCE challenge stored, verification passes + return nil + } + + if codeVerifier == "" { + return fmt.Errorf("code_verifier is required when PKCE challenge is present") + } + + // Use the shared PKCE verification function + return VerifyPKCEChallenge(auth.CodeChallenge.String(), auth.CodeChallengeMethod.String(), codeVerifier) +} + +// Query functions for OAuth authorizations + +// FindOAuthServerAuthorizationByID finds an OAuth authorization by authorization_id +func FindOAuthServerAuthorizationByID(tx *storage.Connection, authorizationID string) (*OAuthServerAuthorization, error) { + auth := &OAuthServerAuthorization{} + if err := tx.Q().Where("authorization_id = ?", authorizationID).First(auth); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerAuthorizationNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth authorization") + } + + if auth.ClientID != "" { + client := &OAuthServerClient{} + if err := tx.Q().Where("client_id = ?", auth.ClientID).First(client); err == nil { + auth.Client = client + } + } + + return auth, nil +} + +// FindOAuthServerAuthorizationByCode finds an OAuth authorization by authorization code +func FindOAuthServerAuthorizationByCode(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) { + auth := &OAuthServerAuthorization{} + if err := tx.Q().Where("authorization_code = ? AND status = ?", code, OAuthServerAuthorizationApproved).First(auth); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerAuthorizationNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth authorization by code") + } + + // Load client relationship (always present) + if auth.ClientID != "" { + client := &OAuthServerClient{} + if err := tx.Q().Where("client_id = ?", auth.ClientID).First(client); err == nil { + auth.Client = client + } + } + + return auth, nil +} + +// CreateOAuthServerAuthorization creates a new OAuth authorization in the database +func CreateOAuthServerAuthorization(tx *storage.Connection, auth *OAuthServerAuthorization) error { + if err := auth.Validate(); err != nil { + return err + } + + if auth.ID == uuid.Nil { + auth.ID = uuid.Must(uuid.NewV4()) + } + + if auth.AuthorizationID == "" { + auth.AuthorizationID = crypto.SecureAlphanumeric(32) + } + + return tx.Create(auth) +} + +// CleanupExpiredOAuthServerAuthorizations marks expired authorizations as expired +func CleanupExpiredOAuthServerAuthorizations(tx *storage.Connection) error { + query := ` + UPDATE ` + (&OAuthServerAuthorization{}).TableName() + ` + SET status = ? + WHERE status = ? AND expires_at < now() + ` + return tx.RawQuery(query, OAuthServerAuthorizationExpired, OAuthServerAuthorizationPending).Exec() +} + +// Error types for OAuth authorization operations + +type OAuthServerAuthorizationNotFoundError struct{} + +func (e OAuthServerAuthorizationNotFoundError) Error() string { + return "OAuth authorization not found" +} diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go new file mode 100644 index 000000000..8021ce6bd --- /dev/null +++ b/internal/models/oauth_authorization_test.go @@ -0,0 +1,331 @@ +package models + +import ( + "fmt" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewOAuthServerAuthorization(t *testing.T) { + clientID := "test-client-id" + redirectURI := "https://example.com/callback" + scope := "openid profile" + state := "random-state" + codeChallenge := "test-challenge" + codeChallengeMethod := "S256" + + auth := NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChallenge, codeChallengeMethod) + + assert.NotEmpty(t, auth.ID) + assert.NotEmpty(t, auth.AuthorizationID) + assert.Equal(t, clientID, auth.ClientID) + assert.Nil(t, auth.UserID) + assert.Equal(t, redirectURI, auth.RedirectURI) + assert.Equal(t, scope, auth.Scope) + assert.Equal(t, state, auth.State.String()) + assert.Equal(t, codeChallenge, auth.CodeChallenge.String()) + assert.Equal(t, codeChallengeMethod, auth.CodeChallengeMethod.String()) + assert.Equal(t, OAuthServerResponseTypeCode, auth.ResponseType) + assert.Equal(t, OAuthServerAuthorizationPending, auth.Status) + assert.True(t, auth.ExpiresAt.After(auth.CreatedAt)) + assert.Nil(t, auth.ApprovedAt) +} + +func TestOAuthServerAuthorization_IsExpired(t *testing.T) { + auth := &OAuthServerAuthorization{ + CreatedAt: time.Now().Add(-1 * time.Hour), + ExpiresAt: time.Now().Add(-30 * time.Minute), // Expired 30 minutes ago + } + + assert.True(t, auth.IsExpired()) + + auth.ExpiresAt = time.Now().Add(30 * time.Minute) // Expires in 30 minutes + assert.False(t, auth.IsExpired()) +} + +func TestOAuthServerAuthorization_GenerateAuthorizationCode(t *testing.T) { + auth := &OAuthServerAuthorization{} + + // First call should generate a code + code1 := auth.GenerateAuthorizationCode() + assert.NotEmpty(t, code1) + assert.Equal(t, code1, auth.AuthorizationCode.String()) + + // Second call should return the same code + code2 := auth.GenerateAuthorizationCode() + assert.Equal(t, code1, code2) +} + +func TestOAuthServerAuthorization_ApproveErrCases(t *testing.T) { + tests := []struct { + name string + auth *OAuthServerAuthorization + wantErr bool + errMsg string + }{ + { + name: "approval of expired authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationPending, + CreatedAt: time.Now().Add(-1 * time.Hour), + ExpiresAt: time.Now().Add(-30 * time.Minute), // Expired + }, + wantErr: true, + errMsg: "authorization request has expired", + }, + { + name: "approval of already approved authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationApproved, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "approval of denied authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationDenied, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "approval of expired status authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationExpired, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + wantErr: true, + errMsg: "authorization request is not pending", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test validation logic before database operations + if tt.auth.IsExpired() { + err := fmt.Errorf("authorization request has expired") + assert.Error(t, err) + assert.Contains(t, err.Error(), "authorization request has expired") + return + } + + if tt.auth.Status != OAuthServerAuthorizationPending { + err := fmt.Errorf("authorization request is not pending (current status: %s)", tt.auth.Status) + assert.Error(t, err) + assert.Contains(t, err.Error(), "authorization request is not pending") + return + } + + // If we get here, it should be valid for approval + assert.False(t, tt.wantErr, "Expected error but validation passed") + }) + } +} + +func TestOAuthServerAuthorization_ApproveSuccess(t *testing.T) { + auth := &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationPending, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + } + + // Test the in-memory state changes that happen during approval + beforeTime := time.Now() + auth.Status = OAuthServerAuthorizationApproved + now := time.Now() + auth.ApprovedAt = &now + auth.GenerateAuthorizationCode() + + assert.Equal(t, OAuthServerAuthorizationApproved, auth.Status) + assert.NotNil(t, auth.ApprovedAt) + assert.True(t, auth.ApprovedAt.After(beforeTime)) + assert.NotEmpty(t, auth.AuthorizationCode.String()) +} + +func TestOAuthServerAuthorization_Deny(t *testing.T) { + tests := []struct { + name string + status OAuthServerAuthorizationStatus + wantErr bool + errMsg string + }{ + { + name: "deny pending authorization", + status: OAuthServerAuthorizationPending, + wantErr: false, + }, + { + name: "deny already approved authorization", + status: OAuthServerAuthorizationApproved, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "deny already denied authorization", + status: OAuthServerAuthorizationDenied, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "deny expired authorization", + status: OAuthServerAuthorizationExpired, + wantErr: true, + errMsg: "authorization request is not pending", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &OAuthServerAuthorization{Status: tt.status} + + if auth.Status != OAuthServerAuthorizationPending { + err := fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } + return + } + + // Test successful denial state change + auth.Status = OAuthServerAuthorizationDenied + assert.Equal(t, OAuthServerAuthorizationDenied, auth.Status) + }) + } +} + +func TestOAuthServerAuthorization_MarkExpiredLogic(t *testing.T) { + tests := []struct { + name string + status OAuthServerAuthorizationStatus + wantErr bool + errMsg string + }{ + { + name: "mark pending authorization as expired", + status: OAuthServerAuthorizationPending, + wantErr: false, + }, + { + name: "mark approved authorization as expired", + status: OAuthServerAuthorizationApproved, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "mark denied authorization as expired", + status: OAuthServerAuthorizationDenied, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "mark already expired authorization as expired", + status: OAuthServerAuthorizationExpired, + wantErr: true, + errMsg: "authorization request is not pending", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &OAuthServerAuthorization{Status: tt.status} + + if auth.Status != OAuthServerAuthorizationPending { + err := fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } + return + } + + // Test successful expiration state change + auth.Status = OAuthServerAuthorizationExpired + assert.Equal(t, OAuthServerAuthorizationExpired, auth.Status) + }) + } +} + +func TestOAuthServerAuthorization_Validate(t *testing.T) { + userID := uuid.Must(uuid.NewV4()) + validAuth := &OAuthServerAuthorization{ + ClientID: "test-client", + UserID: &userID, + RedirectURI: "https://example.com/callback", + Scope: "openid", + ResponseType: OAuthServerResponseTypeCode, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + } + + // Valid authorization should pass + assert.NoError(t, validAuth.Validate()) + + // Test UserID can be nil (for unauthenticated requests) + validAuthNoUser := *validAuth + validAuthNoUser.UserID = nil + assert.NoError(t, validAuthNoUser.Validate()) + + // Test invalid cases + tests := []struct { + name string + modify func(*OAuthServerAuthorization) + wantErr bool + errMsg string + }{ + { + name: "missing client_id", + modify: func(a *OAuthServerAuthorization) { a.ClientID = "" }, + wantErr: true, + errMsg: "client_id is required", + }, + { + name: "missing redirect_uri", + modify: func(a *OAuthServerAuthorization) { a.RedirectURI = "" }, + wantErr: true, + errMsg: "redirect_uri is required", + }, + { + name: "missing scope", + modify: func(a *OAuthServerAuthorization) { a.Scope = "" }, + wantErr: true, + errMsg: "scope is required", + }, + { + name: "invalid response_type", + modify: func(a *OAuthServerAuthorization) { a.ResponseType = "token" }, + wantErr: true, + errMsg: "only response_type=code is supported", + }, + { + name: "expires_at before created_at", + modify: func(a *OAuthServerAuthorization) { a.ExpiresAt = a.CreatedAt.Add(-1 * time.Minute) }, + wantErr: true, + errMsg: "expires_at must be after created_at", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := *validAuth // Copy + tt.modify(&auth) + + err := auth.Validate() + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/models/oauth_consent.go b/internal/models/oauth_consent.go new file mode 100644 index 000000000..5b076dcca --- /dev/null +++ b/internal/models/oauth_consent.go @@ -0,0 +1,189 @@ +package models + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +// OAuthServerConsent represents user consent for an OAuth server client's access to specific scopes +type OAuthServerConsent struct { + ID uuid.UUID `json:"id" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + ClientID string `json:"client_id" db:"client_id"` + Scopes string `json:"scopes" db:"scopes"` + GrantedAt time.Time `json:"granted_at" db:"granted_at"` + RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"` +} + +// TableName returns the table name for the OAuthConsent model +func (OAuthServerConsent) TableName() string { + return "oauth_consents" +} + +// NewOAuthConsent creates a new OAuth consent record +func NewOAuthServerConsent(userID uuid.UUID, clientID string, scopes []string) *OAuthServerConsent { + return &OAuthServerConsent{ + ID: uuid.Must(uuid.NewV4()), + UserID: userID, + ClientID: clientID, + Scopes: strings.Join(scopes, " "), + GrantedAt: time.Now(), + } +} + +// GetScopeList returns the granted scopes as a slice +func (consent *OAuthServerConsent) GetScopeList() []string { + return ParseScopeString(consent.Scopes) +} + +// HasScope checks if the consent includes a specific scope +func (consent *OAuthServerConsent) HasScope(scope string) bool { + return HasScope(consent.GetScopeList(), scope) +} + +// HasAllScopes checks if the consent includes all of the requested scopes +func (consent *OAuthServerConsent) HasAllScopes(requestedScopes []string) bool { + return HasAllScopes(consent.GetScopeList(), requestedScopes) +} + +// IsRevoked checks if the consent has been revoked +func (consent *OAuthServerConsent) IsRevoked() bool { + return consent.RevokedAt != nil +} + +// Revoke revokes the consent +func (consent *OAuthServerConsent) Revoke(tx *storage.Connection) error { + if consent.IsRevoked() { + return fmt.Errorf("consent is already revoked") + } + + now := time.Now() + consent.RevokedAt = &now + return tx.UpdateOnly(consent, "revoked_at") +} + +// UpdateScopes updates the granted scopes for this consent +func (consent *OAuthServerConsent) UpdateScopes(tx *storage.Connection, scopes []string) error { + if consent.IsRevoked() { + return fmt.Errorf("cannot update scopes for revoked consent") + } + + consent.Scopes = strings.Join(scopes, " ") + consent.GrantedAt = time.Now() // Update granted time to reflect the change + return tx.UpdateOnly(consent, "scopes", "granted_at") +} + +// Validate performs basic validation on the OAuth consent +func (consent *OAuthServerConsent) Validate() error { + if consent.UserID == uuid.Nil { + return fmt.Errorf("user_id is required") + } + if consent.ClientID == "" { + return fmt.Errorf("client_id is required") + } + if strings.TrimSpace(consent.Scopes) == "" { + return fmt.Errorf("scopes cannot be empty") + } + if consent.RevokedAt != nil && consent.RevokedAt.Before(consent.GrantedAt) { + return fmt.Errorf("revoked_at cannot be before granted_at") + } + + return nil +} + +// Query functions for OAuth consents + +// FindOAuthServerConsentByUserAndClient finds an OAuth consent by user and client +func FindOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID string) (*OAuthServerConsent, error) { + consent := &OAuthServerConsent{} + if err := tx.Eager().Q().Where("user_id = ? AND client_id = ?", userID, clientID).First(consent); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil // No consent found (not an error) + } + return nil, errors.Wrap(err, "error finding OAuth consent") + } + return consent, nil +} + +// FindActiveOAuthServerConsentByUserAndClient finds an active (non-revoked) OAuth consent +func FindActiveOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID string) (*OAuthServerConsent, error) { + consent := &OAuthServerConsent{} + if err := tx.Q().Where("user_id = ? AND client_id = ? AND revoked_at IS NULL", userID, clientID).First(consent); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil // No active consent found (not an error) + } + return nil, errors.Wrap(err, "error finding active OAuth consent") + } + + return consent, nil +} + +// FindOAuthServerConsentsByUser finds all OAuth consents for a user +func FindOAuthServerConsentsByUser(tx *storage.Connection, userID uuid.UUID, includeRevoked bool) ([]*OAuthServerConsent, error) { + var consents []*OAuthServerConsent + query := tx.Q().Where("user_id = ?", userID) + + if !includeRevoked { + query = query.Where("revoked_at IS NULL") + } + + if err := query.Order("granted_at desc").All(&consents); err != nil { + return nil, errors.Wrap(err, "error finding OAuth consents by user") + } + + return consents, nil +} + +// UpsertOAuthServerConsent creates or updates an OAuth consent +func UpsertOAuthServerConsent(tx *storage.Connection, consent *OAuthServerConsent) error { + if err := consent.Validate(); err != nil { + return err + } + + existing, err := FindOAuthServerConsentByUserAndClient(tx, consent.UserID, consent.ClientID) + if err != nil { + return err + } + + if existing != nil { + // Update existing consent + existing.Scopes = consent.Scopes + existing.GrantedAt = time.Now() + existing.RevokedAt = nil // Un-revoke if previously revoked + return tx.Update(existing) + } + + // Create new consent + if consent.ID == uuid.Nil { + consent.ID = uuid.Must(uuid.NewV4()) + } + return tx.Create(consent) +} + +// RevokeOAuthServerConsentsByClient revokes all consents for a specific client +func RevokeOAuthServerConsentsByClient(tx *storage.Connection, clientID string) error { + now := time.Now() + query := ` + UPDATE ` + (&OAuthServerConsent{}).TableName() + ` + SET revoked_at = ? + WHERE client_id = ? AND revoked_at IS NULL + ` + return tx.RawQuery(query, now, clientID).Exec() +} + +// RevokeOAuthServerConsentsByUser revokes all consents for a specific user +func RevokeOAuthServerConsentsByUser(tx *storage.Connection, userID uuid.UUID) error { + now := time.Now() + query := ` + UPDATE ` + (&OAuthServerConsent{}).TableName() + ` + SET revoked_at = ? + WHERE user_id = ? AND revoked_at IS NULL + ` + return tx.RawQuery(query, now, userID).Exec() +} diff --git a/internal/models/oauth_consent_test.go b/internal/models/oauth_consent_test.go new file mode 100644 index 000000000..51c5c2e43 --- /dev/null +++ b/internal/models/oauth_consent_test.go @@ -0,0 +1,105 @@ +package models + +import ( + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewOAuthServerConsent(t *testing.T) { + userID := uuid.Must(uuid.NewV4()) + clientID := "test-client-id" + scopes := []string{"openid", "profile", "email"} + + consent := NewOAuthServerConsent(userID, clientID, scopes) + + assert.NotEmpty(t, consent.ID) + assert.Equal(t, userID, consent.UserID) + assert.Equal(t, clientID, consent.ClientID) + assert.Equal(t, "openid profile email", consent.Scopes) + assert.False(t, consent.GrantedAt.IsZero()) + assert.Nil(t, consent.RevokedAt) +} + +func TestOAuthServerConsent_IsRevoked(t *testing.T) { + consent := &OAuthServerConsent{} + + // Initially not revoked + assert.False(t, consent.IsRevoked()) + + // After revocation + now := time.Now() + consent.RevokedAt = &now + assert.True(t, consent.IsRevoked()) +} + +func TestOAuthServerConsent_Validate(t *testing.T) { + validConsent := &OAuthServerConsent{ + UserID: uuid.Must(uuid.NewV4()), + ClientID: "test-client", + Scopes: "openid profile", + GrantedAt: time.Now(), + } + + // Valid consent should pass + assert.NoError(t, validConsent.Validate()) + + // Test invalid cases + tests := []struct { + name string + modify func(*OAuthServerConsent) + wantErr bool + errMsg string + }{ + { + name: "missing user_id", + modify: func(c *OAuthServerConsent) { c.UserID = uuid.Nil }, + wantErr: true, + errMsg: "user_id is required", + }, + { + name: "missing client_id", + modify: func(c *OAuthServerConsent) { c.ClientID = "" }, + wantErr: true, + errMsg: "client_id is required", + }, + { + name: "empty scopes", + modify: func(c *OAuthServerConsent) { c.Scopes = "" }, + wantErr: true, + errMsg: "scopes cannot be empty", + }, + { + name: "whitespace only scopes", + modify: func(c *OAuthServerConsent) { c.Scopes = " " }, + wantErr: true, + errMsg: "scopes cannot be empty", + }, + { + name: "revoked_at before granted_at", + modify: func(c *OAuthServerConsent) { + revokedAt := c.GrantedAt.Add(-1 * time.Hour) + c.RevokedAt = &revokedAt + }, + wantErr: true, + errMsg: "revoked_at cannot be before granted_at", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + consent := *validConsent // Copy + tt.modify(&consent) + + err := consent.Validate() + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/models/oauth_scope.go b/internal/models/oauth_scope.go new file mode 100644 index 000000000..6b3c437d2 --- /dev/null +++ b/internal/models/oauth_scope.go @@ -0,0 +1,47 @@ +package models + +import "strings" + +// ParseScopeString parses a space-separated scope string into a slice +func ParseScopeString(scopeString string) []string { + if scopeString == "" { + return []string{} + } + scopes := strings.Split(strings.TrimSpace(scopeString), " ") + var result []string + for _, scope := range scopes { + if strings.TrimSpace(scope) != "" { + result = append(result, strings.TrimSpace(scope)) + } + } + // Always return empty slice instead of nil for consistency + if result == nil { + return []string{} + } + return result +} + +// HasScope checks if the given scope list includes a specific scope +func HasScope(scopes []string, scope string) bool { + for _, s := range scopes { + if s == scope { + return true + } + } + return false +} + +// HasAllScopes checks if the granted scopes include all of the requested scopes +func HasAllScopes(grantedScopes, requestedScopes []string) bool { + grantedSet := make(map[string]bool) + for _, scope := range grantedScopes { + grantedSet[scope] = true + } + + for _, requestedScope := range requestedScopes { + if !grantedSet[requestedScope] { + return false + } + } + return true +} diff --git a/internal/models/oauth_scope_test.go b/internal/models/oauth_scope_test.go new file mode 100644 index 000000000..c87271500 --- /dev/null +++ b/internal/models/oauth_scope_test.go @@ -0,0 +1,199 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseScopeString(t *testing.T) { + tests := []struct { + name string + scope string + expected []string + }{ + { + name: "single scope", + scope: "openid", + expected: []string{"openid"}, + }, + { + name: "multiple scopes", + scope: "openid profile email", + expected: []string{"openid", "profile", "email"}, + }, + { + name: "empty scope", + scope: "", + expected: []string{}, + }, + { + name: "scope with extra spaces", + scope: " openid profile ", + expected: []string{"openid", "profile"}, + }, + { + name: "scope with empty segments", + scope: "openid profile email", + expected: []string{"openid", "profile", "email"}, + }, + { + name: "only spaces", + scope: " ", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseScopeString(tt.scope) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHasScope(t *testing.T) { + scopes := []string{"openid", "profile", "email"} + + tests := []struct { + name string + scopes []string + scope string + expected bool + }{ + { + name: "scope exists", + scopes: scopes, + scope: "openid", + expected: true, + }, + { + name: "scope exists middle", + scopes: scopes, + scope: "profile", + expected: true, + }, + { + name: "scope exists last", + scopes: scopes, + scope: "email", + expected: true, + }, + { + name: "scope does not exist", + scopes: scopes, + scope: "phone", + expected: false, + }, + { + name: "empty scope", + scopes: scopes, + scope: "", + expected: false, + }, + { + name: "empty scopes list", + scopes: []string{}, + scope: "openid", + expected: false, + }, + { + name: "nil scopes list", + scopes: nil, + scope: "openid", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasScope(tt.scopes, tt.scope) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHasAllScopes(t *testing.T) { + grantedScopes := []string{"openid", "profile", "email"} + + tests := []struct { + name string + grantedScopes []string + requestedScopes []string + expected bool + }{ + { + name: "requested scopes are a subset of granted scopes", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid", "profile"}, + expected: true, + }, + { + name: "requested scopes are an exact match of granted scopes", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid", "profile", "email"}, + expected: true, + }, + { + name: "requested scopes are a single scope granted", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid"}, + expected: true, + }, + { + name: "granted scopes are missing a scope", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid", "phone"}, + expected: false, + }, + { + name: "granted scopes are missing multiple scopes", + grantedScopes: grantedScopes, + requestedScopes: []string{"phone", "address"}, + expected: false, + }, + { + name: "requested scopes are empty", + grantedScopes: grantedScopes, + requestedScopes: []string{}, + expected: true, + }, + { + name: "requested scopes are nil", + grantedScopes: grantedScopes, + requestedScopes: nil, + expected: true, + }, + { + name: "granted scopes are empty with requested scopes", + grantedScopes: []string{}, + requestedScopes: []string{"openid"}, + expected: false, + }, + { + name: "granted scopes are empty with requested scopes", + grantedScopes: []string{}, + requestedScopes: []string{}, + expected: true, + }, + { + name: "granted scopes are nil with requested scopes", + grantedScopes: nil, + requestedScopes: []string{"openid"}, + expected: false, + }, + { + name: "granted scopes are nil with requested scopes", + grantedScopes: nil, + requestedScopes: []string{}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasAllScopes(tt.grantedScopes, tt.requestedScopes) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/models/pkce.go b/internal/models/pkce.go new file mode 100644 index 000000000..1fd788dd5 --- /dev/null +++ b/internal/models/pkce.go @@ -0,0 +1,32 @@ +package models + +import ( + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "errors" + "strings" +) + +const PKCEInvalidCodeChallengeError = "code challenge does not match previously saved code verifier" +const PKCEInvalidCodeMethodError = "code challenge method not supported" + +// VerifyPKCEChallenge performs PKCE verification using the provided challenge, method, and verifier +// This is a shared utility function used by both FlowState and OAuthServerAuthorization +func VerifyPKCEChallenge(codeChallenge, codeChallengeMethod, codeVerifier string) error { + switch strings.ToLower(codeChallengeMethod) { + case "s256": + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(encodedCodeVerifier)) != 1 { + return errors.New(PKCEInvalidCodeChallengeError) + } + case "plain": + if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(codeVerifier)) != 1 { + return errors.New(PKCEInvalidCodeChallengeError) + } + default: + return errors.New(PKCEInvalidCodeMethodError) + } + return nil +} diff --git a/internal/models/pkce_test.go b/internal/models/pkce_test.go new file mode 100644 index 000000000..9f984be4a --- /dev/null +++ b/internal/models/pkce_test.go @@ -0,0 +1,123 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVerifyPKCEChallenge(t *testing.T) { + tests := []struct { + name string + codeChallenge string + codeChallengeMethod string + codeVerifier string + wantErr bool + errMsg string + }{ + { + name: "valid S256 PKCE", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", // S256 of "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + codeChallengeMethod: "S256", + codeVerifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + wantErr: false, + }, + { + name: "valid plain PKCE", + codeChallenge: "test-challenge", + codeChallengeMethod: "plain", + codeVerifier: "test-challenge", + wantErr: false, + }, + { + name: "invalid S256 verifier", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + codeChallengeMethod: "S256", + codeVerifier: "wrong-verifier", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "invalid plain verifier", + codeChallenge: "test-challenge", + codeChallengeMethod: "plain", + codeVerifier: "wrong-challenge", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "invalid challenge method", + codeChallenge: "test-challenge", + codeChallengeMethod: "invalid", + codeVerifier: "test-challenge", + wantErr: true, + errMsg: "code challenge method not supported", + }, + { + name: "case insensitive S256 method", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + codeChallengeMethod: "s256", + codeVerifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + wantErr: false, + }, + { + name: "case insensitive plain method", + codeChallenge: "test-challenge", + codeChallengeMethod: "PLAIN", + codeVerifier: "test-challenge", + wantErr: false, + }, + { + name: "empty verifier with S256", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + codeChallengeMethod: "S256", + codeVerifier: "", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "empty verifier with plain", + codeChallenge: "test-challenge", + codeChallengeMethod: "plain", + codeVerifier: "", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "empty challenge with S256", + codeChallenge: "", + codeChallengeMethod: "S256", + codeVerifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "empty challenge with plain", + codeChallenge: "", + codeChallengeMethod: "plain", + codeVerifier: "test-challenge", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "both empty with plain", + codeChallenge: "", + codeChallengeMethod: "plain", + codeVerifier: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := VerifyPKCEChallenge(tt.codeChallenge, tt.codeChallengeMethod, tt.codeVerifier) + + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/migrations/20250804100000_add_oauth_authorizations_consents.up.sql b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql new file mode 100644 index 000000000..2287bbd42 --- /dev/null +++ b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql @@ -0,0 +1,86 @@ +-- Create OAuth 2.1 support with enums, authorization, and consent tables + +-- Create enums for OAuth authorization management +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_authorization_status as enum('pending', 'approved', 'denied', 'expired'); +exception + when duplicate_object then null; +end $$; + +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_response_type as enum('code'); +exception + when duplicate_object then null; +end $$; + +-- Create oauth_authorizations table for OAuth 2.1 authorization requests +create table if not exists {{ index .Options "Namespace" }}.oauth_authorizations ( + id uuid not null, + authorization_id text not null, + client_id text not null references {{ index .Options "Namespace" }}.oauth_clients(client_id) on delete cascade, + user_id uuid null references {{ index .Options "Namespace" }}.users(id) on delete cascade, + redirect_uri text not null, + scope text not null, + state text null, + code_challenge text null, + code_challenge_method {{ index .Options "Namespace" }}.code_challenge_method null, + response_type {{ index .Options "Namespace" }}.oauth_response_type not null default 'code', + + -- Flow control + status {{ index .Options "Namespace" }}.oauth_authorization_status not null default 'pending', + authorization_code text null, + + -- Timestamps + created_at timestamptz not null default now(), + expires_at timestamptz not null default (now() + interval '3 minutes'), + approved_at timestamptz null, + + constraint oauth_authorizations_pkey primary key (id), + constraint oauth_authorizations_authorization_id_key unique (authorization_id), + constraint oauth_authorizations_authorization_code_key unique (authorization_code), + constraint oauth_authorizations_redirect_uri_length check (char_length(redirect_uri) <= 2048), + constraint oauth_authorizations_scope_length check (char_length(scope) <= 4096), + constraint oauth_authorizations_state_length check (char_length(state) <= 4096), + constraint oauth_authorizations_code_challenge_length check (char_length(code_challenge) <= 128), + constraint oauth_authorizations_authorization_code_length check (char_length(authorization_code) <= 255), + constraint oauth_authorizations_expires_at_future check (expires_at > created_at) +); + +-- Create indexes for oauth_authorizations +-- for CleanupExpiredOAuthServerAuthorizations +create index if not exists oauth_auth_pending_exp_idx + on {{ index .Options "Namespace" }}.oauth_authorizations (expires_at) + where status = 'pending'; + + + +-- Create oauth_consents table for user consent management +create table if not exists {{ index .Options "Namespace" }}.oauth_consents ( + id uuid not null, + user_id uuid not null references {{ index .Options "Namespace" }}.users(id) on delete cascade, + client_id text not null references {{ index .Options "Namespace" }}.oauth_clients(client_id) on delete cascade, + scopes text not null, + granted_at timestamptz not null default now(), + revoked_at timestamptz null, + + constraint oauth_consents_pkey primary key (id), + constraint oauth_consents_user_client_unique unique (user_id, client_id), + constraint oauth_consents_scopes_length check (char_length(scopes) <= 2048), + constraint oauth_consents_scopes_not_empty check (char_length(trim(scopes)) > 0), + constraint oauth_consents_revoked_after_granted check (revoked_at is null or revoked_at >= granted_at) +); + +-- Create indexes for oauth_consents +-- Active consent look-up (user + client, only non-revoked rows) +create index if not exists oauth_consents_active_user_client_idx + on {{ index .Options "Namespace" }}.oauth_consents (user_id, client_id) + where revoked_at is null; + +-- "Show me all consents for this user, newest first" +create index if not exists oauth_consents_user_order_idx + on {{ index .Options "Namespace" }}.oauth_consents (user_id, granted_at desc); + +-- Bulk revoke for an entire client (only non-revoked rows) +create index if not exists oauth_consents_active_client_idx + on {{ index .Options "Namespace" }}.oauth_consents (client_id) + where revoked_at is null; From 01115f456de903f4ca5138e109a81e4af2457cd2 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Thu, 7 Aug 2025 17:38:56 +0200 Subject: [PATCH 02/15] fix: gofmt --- internal/models/flow_state.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index a484f4a74..1d9178127 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -12,7 +12,6 @@ import ( "github.com/gofrs/uuid" ) - type FlowState struct { ID uuid.UUID `json:"id" db:"id"` UserID *uuid.UUID `json:"user_id,omitempty" db:"user_id"` From d9e957f6288a40fdf840eb4bd7a9a070b21b1d73 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 13:27:46 +0200 Subject: [PATCH 03/15] chore: add `omitempty` for the response jsontags --- internal/api/oauthserver/authorize.go | 18 +++++++++--------- internal/api/oauthserver/handlers.go | 16 ++++++++-------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 15d849391..6e05aef87 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -29,23 +29,23 @@ type AuthorizeParams struct { // AuthorizationDetailsResponse represents the response for getting authorization details type AuthorizationDetailsResponse struct { AuthorizationID string `json:"authorization_id"` - Client ClientDetailsResponse `json:"client"` - User UserDetailsResponse `json:"user"` - Scope string `json:"scope"` + Client ClientDetailsResponse `json:"client,omitempty"` + User UserDetailsResponse `json:"user,omitempty"` + Scope string `json:"scope,omitempty"` } // ClientDetailsResponse represents client details in authorization response type ClientDetailsResponse struct { ClientID string `json:"client_id"` - ClientName string `json:"client_name"` - ClientURI string `json:"client_uri"` - LogoURI string `json:"logo_uri"` + ClientName string `json:"client_name,omitempty"` + ClientURI string `json:"client_uri,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` } // UserDetailsResponse represents user details in authorization response type UserDetailsResponse struct { - ID string `json:"id"` - Email string `json:"email"` + ID string `json:"id,omitempty"` + Email string `json:"email,omitempty"` } // ConsentRequest represents a consent decision request @@ -55,7 +55,7 @@ type ConsentRequest struct { // ConsentResponse represents the response after processing consent type ConsentResponse struct { - RedirectURL string `json:"redirect_url"` + RedirectURL string `json:"redirect_url,omitempty"` } type OAuthServerConsentAction string diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index fa827240f..b8b6719e5 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -18,23 +18,23 @@ type OAuthServerClientResponse struct { ClientID string `json:"client_id"` ClientSecret string `json:"client_secret,omitempty"` // only returned on registration - RedirectURIs []string `json:"redirect_uris"` - TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method"` - GrantTypes []string `json:"grant_types"` - ResponseTypes []string `json:"response_types"` + RedirectURIs []string `json:"redirect_uris,omitempty"` + TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` ClientName string `json:"client_name,omitempty"` ClientURI string `json:"client_uri,omitempty"` LogoURI string `json:"logo_uri,omitempty"` // Metadata fields - RegistrationType string `json:"registration_type"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + RegistrationType string `json:"registration_type,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` } // OAuthServerClientListResponse represents the response for listing OAuth clients type OAuthServerClientListResponse struct { - Clients []OAuthServerClientResponse `json:"clients"` + Clients []OAuthServerClientResponse `json:"clients,omitempty"` } // oauthServerClientToResponse converts a model to response format From 1167bea0d71c154c4c09eb92d43cc5d76cba777d Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 13:28:32 +0200 Subject: [PATCH 04/15] chore: single line SQLs --- internal/models/oauth_authorization.go | 6 +----- internal/models/oauth_consent.go | 12 ++---------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 9cedae35d..6a6109322 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -262,11 +262,7 @@ func CreateOAuthServerAuthorization(tx *storage.Connection, auth *OAuthServerAut // CleanupExpiredOAuthServerAuthorizations marks expired authorizations as expired func CleanupExpiredOAuthServerAuthorizations(tx *storage.Connection) error { - query := ` - UPDATE ` + (&OAuthServerAuthorization{}).TableName() + ` - SET status = ? - WHERE status = ? AND expires_at < now() - ` + query := "UPDATE " + (&OAuthServerAuthorization{}).TableName() + " SET status = ? WHERE status = ? AND expires_at < now()" return tx.RawQuery(query, OAuthServerAuthorizationExpired, OAuthServerAuthorizationPending).Exec() } diff --git a/internal/models/oauth_consent.go b/internal/models/oauth_consent.go index 5b076dcca..cef2dd294 100644 --- a/internal/models/oauth_consent.go +++ b/internal/models/oauth_consent.go @@ -169,21 +169,13 @@ func UpsertOAuthServerConsent(tx *storage.Connection, consent *OAuthServerConsen // RevokeOAuthServerConsentsByClient revokes all consents for a specific client func RevokeOAuthServerConsentsByClient(tx *storage.Connection, clientID string) error { now := time.Now() - query := ` - UPDATE ` + (&OAuthServerConsent{}).TableName() + ` - SET revoked_at = ? - WHERE client_id = ? AND revoked_at IS NULL - ` + query := "UPDATE " + (&OAuthServerConsent{}).TableName() + " SET revoked_at = ? WHERE client_id = ? AND revoked_at IS NULL" return tx.RawQuery(query, now, clientID).Exec() } // RevokeOAuthServerConsentsByUser revokes all consents for a specific user func RevokeOAuthServerConsentsByUser(tx *storage.Connection, userID uuid.UUID) error { now := time.Now() - query := ` - UPDATE ` + (&OAuthServerConsent{}).TableName() + ` - SET revoked_at = ? - WHERE user_id = ? AND revoked_at IS NULL - ` + query := "UPDATE " + (&OAuthServerConsent{}).TableName() + " SET revoked_at = ? WHERE user_id = ? AND revoked_at IS NULL" return tx.RawQuery(query, now, userID).Exec() } From aa17148fe9399c3718c097c3122e5996deeb1690 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 13:40:59 +0200 Subject: [PATCH 05/15] feat: append `AuthorizationPath` to site URL --- internal/api/oauthserver/authorize.go | 25 ++++++++++++++++++------- internal/conf/configuration.go | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 6e05aef87..a5fdb5112 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -107,15 +107,13 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) observability.LogEntrySetField(r, "client_id", authorization.ClientID) - // Redirect to configured authorization URL with authorization_id - // TODO(cemal): should this be really a different config or something set on the config.SiteURL? - if config.OAuthServer.AuthorizationURL == "" { - return apierrors.NewInternalServerError("oauth authorization URL not configured") + // Redirect to authorization path with authorization_id + if config.OAuthServer.AuthorizationPath == "" { + return apierrors.NewInternalServerError("oauth authorization path not configured") } - redirectURL := fmt.Sprintf("%s?authorization_id=%s", - config.OAuthServer.AuthorizationURL, - authorization.AuthorizationID) + baseURL := s.buildAuthorizationURL(config.SiteURL, config.OAuthServer.AuthorizationPath) + redirectURL := fmt.Sprintf("%s?authorization_id=%s", baseURL, authorization.AuthorizationID) http.Redirect(w, r, redirectURL, http.StatusFound) return nil @@ -482,3 +480,16 @@ func (s *Server) buildErrorRedirectURL(authorization *models.OAuthServerAuthoriz u.RawQuery = q.Encode() return u.String() } + +// buildAuthorizationURL safely joins a base URL with a path, handling slashes correctly +func (s *Server) buildAuthorizationURL(baseURL, pathToJoin string) string { + // Trim trailing slash from baseURL + baseURL = strings.TrimRight(baseURL, "/") + + // Ensure pathToJoin starts with a slash + if !strings.HasPrefix(pathToJoin, "/") { + pathToJoin = "/" + pathToJoin + } + + return baseURL + pathToJoin +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 5a22cb109..2c57d8ece 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -74,7 +74,7 @@ type OAuthProviderConfiguration struct { // OAuthServerConfiguration holds OAuth server configuration type OAuthServerConfiguration struct { AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` - AuthorizationURL string `json:"authorization_url" split_words:"true"` + AuthorizationPath string `json:"authorization_path" split_words:"true"` AuthorizationTimeout time.Duration `json:"authorization_timeout" split_words:"true" default:"5m"` // Placeholder for now, for (near) future extensibility DefaultScope string `json:"default_scope" split_words:"true" default:"email"` From 76d13bd1de7d802973a714552bdde1537e13d489 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 13:43:44 +0200 Subject: [PATCH 06/15] chore: mv pkce logic to `security` package --- internal/models/flow_state.go | 3 ++- internal/models/oauth_authorization.go | 3 ++- internal/{models => security}/pkce.go | 2 +- internal/{models => security}/pkce_test.go | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) rename internal/{models => security}/pkce.go (98%) rename internal/{models => security}/pkce_test.go (99%) diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 1d9178127..7ce4c940a 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/supabase/auth/internal/security" "github.com/supabase/auth/internal/storage" "github.com/gofrs/uuid" @@ -128,7 +129,7 @@ func FindFlowStateByUserID(tx *storage.Connection, id string, authenticationMeth } func (f *FlowState) VerifyPKCE(codeVerifier string) error { - return VerifyPKCEChallenge(f.CodeChallenge, f.CodeChallengeMethod, codeVerifier) + return security.VerifyPKCEChallenge(f.CodeChallenge, f.CodeChallengeMethod, codeVerifier) } func (f *FlowState) IsExpired(expiryDuration time.Duration) bool { diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 6a6109322..2903f130b 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -8,6 +8,7 @@ import ( "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/security" "github.com/supabase/auth/internal/storage" ) @@ -197,7 +198,7 @@ func (auth *OAuthServerAuthorization) VerifyPKCE(codeVerifier string) error { } // Use the shared PKCE verification function - return VerifyPKCEChallenge(auth.CodeChallenge.String(), auth.CodeChallengeMethod.String(), codeVerifier) + return security.VerifyPKCEChallenge(auth.CodeChallenge.String(), auth.CodeChallengeMethod.String(), codeVerifier) } // Query functions for OAuth authorizations diff --git a/internal/models/pkce.go b/internal/security/pkce.go similarity index 98% rename from internal/models/pkce.go rename to internal/security/pkce.go index 1fd788dd5..7e4ccf41d 100644 --- a/internal/models/pkce.go +++ b/internal/security/pkce.go @@ -1,4 +1,4 @@ -package models +package security import ( "crypto/sha256" diff --git a/internal/models/pkce_test.go b/internal/security/pkce_test.go similarity index 99% rename from internal/models/pkce_test.go rename to internal/security/pkce_test.go index 9f984be4a..795b06778 100644 --- a/internal/models/pkce_test.go +++ b/internal/security/pkce_test.go @@ -1,4 +1,4 @@ -package models +package security import ( "testing" From d67a7753376e439ac671e9132e0536cd26f488e7 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 14:09:41 +0200 Subject: [PATCH 07/15] chore: `NullString` -> string pointer --- internal/api/oauthserver/authorize.go | 19 ++++++++------ internal/api/oauthserver/handlers.go | 7 ++--- internal/api/oauthserver/service.go | 8 +++--- internal/api/oauthserver/service_test.go | 2 +- internal/models/oauth_authorization.go | 28 +++++++++++--------- internal/models/oauth_authorization_test.go | 10 +++---- internal/models/oauth_client.go | 6 ++--- internal/models/oauth_client_test.go | 29 +++++++++++++-------- internal/utilities/strings.go | 17 ++++++++++++ 9 files changed, 79 insertions(+), 47 deletions(-) create mode 100644 internal/utilities/strings.go diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index a5fdb5112..56a633ed3 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -13,6 +13,7 @@ import ( "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" ) // AuthorizeParams represents the parameters for an OAuth authorization request @@ -182,9 +183,9 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ AuthorizationID: authorization.AuthorizationID, Client: ClientDetailsResponse{ ClientID: authorization.Client.ClientID, - ClientName: authorization.Client.ClientName.String(), - ClientURI: authorization.Client.ClientURI.String(), - LogoURI: authorization.Client.LogoURI.String(), + ClientName: utilities.StringValue(authorization.Client.ClientName), + ClientURI: utilities.StringValue(authorization.Client.ClientURI), + LogoURI: utilities.StringValue(authorization.Client.LogoURI), }, User: UserDetailsResponse{ ID: user.ID.String(), @@ -461,9 +462,11 @@ func (s *Server) autoApproveAndRedirect(w http.ResponseWriter, r *http.Request, func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthorization) string { u, _ := url.Parse(authorization.RedirectURI) q := u.Query() - q.Set("code", authorization.AuthorizationCode.String()) - if authorization.State.String() != "" { - q.Set("state", authorization.State.String()) + if authorization.AuthorizationCode != nil { + q.Set("code", *authorization.AuthorizationCode) + } + if authorization.State != nil && *authorization.State != "" { + q.Set("state", *authorization.State) } u.RawQuery = q.Encode() return u.String() @@ -474,8 +477,8 @@ func (s *Server) buildErrorRedirectURL(authorization *models.OAuthServerAuthoriz q := u.Query() q.Set("error", errorCode) q.Set("error_description", errorDescription) - if authorization.State.String() != "" { - q.Set("state", authorization.State.String()) + if authorization.State != nil && *authorization.State != "" { + q.Set("state", *authorization.State) } u.RawQuery = q.Encode() return u.String() diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index b8b6719e5..2135929d1 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -11,6 +11,7 @@ import ( "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" ) // OAuthServerClientResponse represents the response format for OAuth client operations @@ -47,9 +48,9 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret TokenEndpointAuthMethod: []string{"client_secret_basic", "client_secret_post"}, // Both methods are supported GrantTypes: client.GetGrantTypes(), ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1 - ClientName: client.ClientName.String(), - ClientURI: client.ClientURI.String(), - LogoURI: client.LogoURI.String(), + ClientName: utilities.StringValue(client.ClientName), + ClientURI: utilities.StringValue(client.ClientURI), + LogoURI: utilities.StringValue(client.LogoURI), // Metadata fields RegistrationType: client.RegistrationType, diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index 2537f5e1d..02c3911bd 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -10,7 +10,7 @@ import ( "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" - "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" "golang.org/x/crypto/bcrypt" ) @@ -160,9 +160,9 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer client := &models.OAuthServerClient{ ClientID: generateClientID(), RegistrationType: params.RegistrationType, - ClientName: storage.NullString(params.ClientName), - ClientURI: storage.NullString(params.ClientURI), - LogoURI: storage.NullString(params.LogoURI), + ClientName: utilities.StringPtr(params.ClientName), + ClientURI: utilities.StringPtr(params.ClientURI), + LogoURI: utilities.StringPtr(params.LogoURI), } client.SetRedirectURIs(params.RedirectURIs) diff --git a/internal/api/oauthserver/service_test.go b/internal/api/oauthserver/service_test.go index 3c7f1f365..a1220f893 100644 --- a/internal/api/oauthserver/service_test.go +++ b/internal/api/oauthserver/service_test.go @@ -86,7 +86,7 @@ func (ts *OAuthServiceTestSuite) TestOAuthServerClientServiceMethods() { require.NoError(ts.T(), err) require.NotNil(ts.T(), client) require.NotEmpty(ts.T(), secret) - assert.Equal(ts.T(), "Test Client", client.ClientName.String()) + assert.Equal(ts.T(), "Test Client", *client.ClientName) assert.Equal(ts.T(), "dynamic", client.RegistrationType) // Test getOAuthServerClient diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 2903f130b..61fc7e83d 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -45,12 +45,12 @@ type OAuthServerAuthorization struct { UserID *uuid.UUID `json:"user_id" db:"user_id"` RedirectURI string `json:"redirect_uri" db:"redirect_uri"` Scope string `json:"scope" db:"scope"` - State storage.NullString `json:"state" db:"state"` - CodeChallenge storage.NullString `json:"code_challenge" db:"code_challenge"` - CodeChallengeMethod storage.NullString `json:"code_challenge_method" db:"code_challenge_method"` + State *string `json:"state,omitempty" db:"state"` + CodeChallenge *string `json:"code_challenge,omitempty" db:"code_challenge"` + CodeChallengeMethod *string `json:"code_challenge_method,omitempty" db:"code_challenge_method"` ResponseType OAuthServerResponseType `json:"response_type" db:"response_type"` Status OAuthServerAuthorizationStatus `json:"status" db:"status"` - AuthorizationCode storage.NullString `json:"-" db:"authorization_code"` + AuthorizationCode *string `json:"-" db:"authorization_code"` CreatedAt time.Time `json:"created_at" db:"created_at"` ExpiresAt time.Time `json:"expires_at" db:"expires_at"` ApprovedAt *time.Time `json:"approved_at" db:"approved_at"` @@ -86,13 +86,13 @@ func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChalle } if state != "" { - auth.State = storage.NullString(state) + auth.State = &state } if codeChallenge != "" { - auth.CodeChallenge = storage.NullString(codeChallenge) + auth.CodeChallenge = &codeChallenge } if codeChallengeMethod != "" { - auth.CodeChallengeMethod = storage.NullString(codeChallengeMethod) + auth.CodeChallengeMethod = &codeChallengeMethod } return auth @@ -116,12 +116,12 @@ func (auth *OAuthServerAuthorization) GetScopeList() []string { // GenerateAuthorizationCode generates a new authorization code if not already set func (auth *OAuthServerAuthorization) GenerateAuthorizationCode() string { - if auth.AuthorizationCode.String() != "" { - return auth.AuthorizationCode.String() + if auth.AuthorizationCode != nil && *auth.AuthorizationCode != "" { + return *auth.AuthorizationCode } code := uuid.Must(uuid.NewV4()).String() - auth.AuthorizationCode = storage.NullString(code) + auth.AuthorizationCode = &code return code } @@ -188,7 +188,7 @@ func (auth *OAuthServerAuthorization) Validate() error { // VerifyPKCE verifies the PKCE code verifier against the stored challenge func (auth *OAuthServerAuthorization) VerifyPKCE(codeVerifier string) error { - if auth.CodeChallenge.String() == "" { + if auth.CodeChallenge == nil || *auth.CodeChallenge == "" { // No PKCE challenge stored, verification passes return nil } @@ -198,7 +198,11 @@ func (auth *OAuthServerAuthorization) VerifyPKCE(codeVerifier string) error { } // Use the shared PKCE verification function - return security.VerifyPKCEChallenge(auth.CodeChallenge.String(), auth.CodeChallengeMethod.String(), codeVerifier) + var codeChallengeMethod string + if auth.CodeChallengeMethod != nil { + codeChallengeMethod = *auth.CodeChallengeMethod + } + return security.VerifyPKCEChallenge(*auth.CodeChallenge, codeChallengeMethod, codeVerifier) } // Query functions for OAuth authorizations diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go index 8021ce6bd..85412e2c0 100644 --- a/internal/models/oauth_authorization_test.go +++ b/internal/models/oauth_authorization_test.go @@ -25,9 +25,9 @@ func TestNewOAuthServerAuthorization(t *testing.T) { assert.Nil(t, auth.UserID) assert.Equal(t, redirectURI, auth.RedirectURI) assert.Equal(t, scope, auth.Scope) - assert.Equal(t, state, auth.State.String()) - assert.Equal(t, codeChallenge, auth.CodeChallenge.String()) - assert.Equal(t, codeChallengeMethod, auth.CodeChallengeMethod.String()) + assert.Equal(t, state, *auth.State) + assert.Equal(t, codeChallenge, *auth.CodeChallenge) + assert.Equal(t, codeChallengeMethod, *auth.CodeChallengeMethod) assert.Equal(t, OAuthServerResponseTypeCode, auth.ResponseType) assert.Equal(t, OAuthServerAuthorizationPending, auth.Status) assert.True(t, auth.ExpiresAt.After(auth.CreatedAt)) @@ -52,7 +52,7 @@ func TestOAuthServerAuthorization_GenerateAuthorizationCode(t *testing.T) { // First call should generate a code code1 := auth.GenerateAuthorizationCode() assert.NotEmpty(t, code1) - assert.Equal(t, code1, auth.AuthorizationCode.String()) + assert.Equal(t, code1, *auth.AuthorizationCode) // Second call should return the same code code2 := auth.GenerateAuthorizationCode() @@ -148,7 +148,7 @@ func TestOAuthServerAuthorization_ApproveSuccess(t *testing.T) { assert.Equal(t, OAuthServerAuthorizationApproved, auth.Status) assert.NotNil(t, auth.ApprovedAt) assert.True(t, auth.ApprovedAt.After(beforeTime)) - assert.NotEmpty(t, auth.AuthorizationCode.String()) + assert.NotEmpty(t, *auth.AuthorizationCode) } func TestOAuthServerAuthorization_Deny(t *testing.T) { diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index 7c8776377..93805a0f7 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -22,9 +22,9 @@ type OAuthServerClient struct { RedirectURIs string `json:"-" db:"redirect_uris"` GrantTypes string `json:"grant_types" db:"grant_types"` - ClientName storage.NullString `json:"client_name" db:"client_name"` - ClientURI storage.NullString `json:"client_uri" db:"client_uri"` - LogoURI storage.NullString `json:"logo_uri" db:"logo_uri"` + ClientName *string `json:"client_name,omitempty" db:"client_name"` + ClientURI *string `json:"client_uri,omitempty" db:"client_uri"` + LogoURI *string `json:"logo_uri,omitempty" db:"logo_uri"` CreatedAt time.Time `json:"created_at" db:"created_at"` UpdatedAt time.Time `json:"updated_at" db:"updated_at"` DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go index 1a2607f6c..233da83fb 100644 --- a/internal/models/oauth_client_test.go +++ b/internal/models/oauth_client_test.go @@ -20,7 +20,7 @@ type OAuthServerClientTestSuite struct { } func (ts *OAuthServerClientTestSuite) SetupTest() { - TruncateAll(ts.db) + _ = TruncateAll(ts.db) } func TestOAuthServerClient(t *testing.T) { @@ -39,10 +39,11 @@ func TestOAuthServerClient(t *testing.T) { } func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { + testClientName := "Test Client" validClient := &OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), ClientID: "test_client_id", - ClientName: storage.NullString("Test Client"), + ClientName: &testClientName, RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", GrantTypes: "authorization_code,refresh_token", @@ -141,9 +142,10 @@ func (ts *OAuthServerClientTestSuite) TestRedirectURIHelpers() { } func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { + testAppName := "Test Application" client := &OAuthServerClient{ ClientID: "test_client_create_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Test Application"), + ClientName: &testAppName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", @@ -170,9 +172,10 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClientValidation() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { // Create a test client + testName := "Find By ID Test" client := &OAuthServerClient{ ClientID: "test_client_find_by_id_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Find By ID Test"), + ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", @@ -185,7 +188,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { foundClient, err := FindOAuthServerClientByID(ts.db, client.ID) require.NoError(ts.T(), err) assert.Equal(ts.T(), client.ClientID, foundClient.ClientID) - assert.Equal(ts.T(), client.ClientName.String(), foundClient.ClientName.String()) + assert.Equal(ts.T(), *client.ClientName, *foundClient.ClientName) // Test not found _, err = FindOAuthServerClientByID(ts.db, uuid.Must(uuid.NewV4())) @@ -195,9 +198,10 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { // Create a test client + testName := "Find By Client ID Test" client := &OAuthServerClient{ ClientID: "test_client_find_by_client_id_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Find By Client ID Test"), + ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "manual", RedirectURIs: "https://example.com/callback", @@ -210,7 +214,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { foundClient, err := FindOAuthServerClientByClientID(ts.db, client.ClientID) require.NoError(ts.T(), err) assert.Equal(ts.T(), client.ID, foundClient.ID) - assert.Equal(ts.T(), client.ClientName.String(), foundClient.ClientName.String()) + assert.Equal(ts.T(), *client.ClientName, *foundClient.ClientName) // Test not found _, err = FindOAuthServerClientByClientID(ts.db, "nonexistent_client_id") @@ -220,9 +224,10 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { // Create a test client + originalName := "Original Name" client := &OAuthServerClient{ ClientID: "test_client_update_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Original Name"), + ClientName: &originalName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", @@ -233,7 +238,8 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { originalUpdatedAt := client.UpdatedAt // Update the client - client.ClientName = storage.NullString("Updated Name") + updatedName := "Updated Name" + client.ClientName = &updatedName // client.Description removed - no longer exists client.SetRedirectURIs([]string{"https://updated.example.com/callback"}) @@ -243,7 +249,7 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { // Verify updates updatedClient, err := FindOAuthServerClientByID(ts.db, client.ID) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "Updated Name", updatedClient.ClientName.String()) + assert.Equal(ts.T(), "Updated Name", *updatedClient.ClientName) // assert.Equal(ts.T(), "Updated description", updatedClient.Description.String()) // Description field removed assert.Equal(ts.T(), []string{"https://updated.example.com/callback"}, updatedClient.GetRedirectURIs()) assert.True(ts.T(), updatedClient.UpdatedAt.After(originalUpdatedAt)) @@ -267,9 +273,10 @@ func (ts *OAuthServerClientTestSuite) TestClientSecretHashing() { func (ts *OAuthServerClientTestSuite) TestSoftDelete() { // Create a test client + testName := "Soft Delete Test" client := &OAuthServerClient{ ClientID: "test_client_soft_delete_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Soft Delete Test"), + ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", diff --git a/internal/utilities/strings.go b/internal/utilities/strings.go new file mode 100644 index 000000000..471a4552d --- /dev/null +++ b/internal/utilities/strings.go @@ -0,0 +1,17 @@ +package utilities + +// StringValue safely extracts a string from a *string, returning empty string if nil +func StringValue(s *string) string { + if s == nil { + return "" + } + return *s +} + +// StringPtr returns a pointer to a string if non-empty, nil otherwise +func StringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} From b9c26ea4f9e759a29fb423b447e46e8a8d3c676e Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 14:27:18 +0200 Subject: [PATCH 08/15] feat: enable oauth server based on config --- internal/api/api.go | 60 +++++++++++++---------- internal/api/api_test.go | 27 ++++++++++ internal/api/oauthserver/handlers_test.go | 3 +- internal/api/oauthserver/service_test.go | 6 ++- internal/conf/configuration.go | 1 + 5 files changed, 68 insertions(+), 29 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 488714ead..bda0fc0de 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -88,10 +88,14 @@ func (a *API) deprecationNotices() { // NewAPIWithVersion creates a new REST API using the specified version func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API { api := &API{ - config: globalConfig, - db: db, - version: version, - oauthServer: oauthserver.NewServer(globalConfig, db), + config: globalConfig, + db: db, + version: version, + } + + // Only initialize OAuth server if enabled + if globalConfig.OAuthServer.Enabled { + api.oauthServer = oauthserver.NewServer(globalConfig, db) } for _, o := range opt { @@ -327,33 +331,37 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }) // Admin only oauth client management endpoints - r.Route("/oauth", func(r *router) { - r.Route("/clients", func(r *router) { - // Manual client registration - r.Post("/", api.oauthServer.AdminOAuthServerClientRegister) - - r.Get("/", api.oauthServer.OAuthServerClientList) - - r.Route("/{client_id}", func(r *router) { - r.Use(api.oauthServer.LoadOAuthServerClient) - r.Get("/", api.oauthServer.OAuthServerClientGet) - r.Delete("/", api.oauthServer.OAuthServerClientDelete) + if globalConfig.OAuthServer.Enabled { + r.Route("/oauth", func(r *router) { + r.Route("/clients", func(r *router) { + // Manual client registration + r.Post("/", api.oauthServer.AdminOAuthServerClientRegister) + + r.Get("/", api.oauthServer.OAuthServerClientList) + + r.Route("/{client_id}", func(r *router) { + r.Use(api.oauthServer.LoadOAuthServerClient) + r.Get("/", api.oauthServer.OAuthServerClientGet) + r.Delete("/", api.oauthServer.OAuthServerClientDelete) + }) }) }) - }) + } }) // OAuth Dynamic Client Registration endpoint (public, rate limited) - r.Route("/oauth", func(r *router) { - r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)). - Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister) - - // OAuth 2.1 Authorization endpoints - // `/authorize` to initiate OAuth2 authorization code flow where Supabase Auth is the OAuth2 provider - r.Get("/authorize", api.oauthServer.OAuthServerAuthorize) - r.With(api.requireAuthentication).Get("/authorizations/{authorization_id}", api.oauthServer.OAuthServerGetAuthorization) - r.With(api.requireAuthentication).Post("/authorizations/{authorization_id}/consent", api.oauthServer.OAuthServerConsent) - }) + if globalConfig.OAuthServer.Enabled { + r.Route("/oauth", func(r *router) { + r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)). + Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister) + + // OAuth 2.1 Authorization endpoints + // `/authorize` to initiate OAuth2 authorization code flow where Supabase Auth is the OAuth2 provider + r.Get("/authorize", api.oauthServer.OAuthServerAuthorize) + r.With(api.requireAuthentication).Get("/authorizations/{authorization_id}", api.oauthServer.OAuthServerGetAuthorization) + r.With(api.requireAuthentication).Post("/authorizations/{authorization_id}/consent", api.oauthServer.OAuthServerConsent) + }) + } }) corsHandler := cors.New(cors.Options{ diff --git a/internal/api/api_test.go b/internal/api/api_test.go index a472be737..c205503bf 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -55,3 +55,30 @@ func TestEmailEnabledByDefault(t *testing.T) { require.True(t, api.config.External.Email.Enabled) } + +func TestOAuthServerDisabledByDefault(t *testing.T) { + api, _, err := setupAPIForTest() + require.NoError(t, err) + + // OAuth server should be disabled by default + require.False(t, api.config.OAuthServer.Enabled) + + // OAuth server instance should not be initialized when disabled + require.Nil(t, api.oauthServer) +} + +func TestOAuthServerCanBeEnabled(t *testing.T) { + api, _, err := setupAPIForTestWithCallback(func(config *conf.GlobalConfiguration, conn *storage.Connection) { + if config != nil { + // Enable OAuth server + config.OAuthServer.Enabled = true + } + }) + require.NoError(t, err) + + // OAuth server should be enabled + require.True(t, api.config.OAuthServer.Enabled) + + // OAuth server instance should be initialized when enabled + require.NotNil(t, api.oauthServer) +} diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go index c72cdb534..0d74eec83 100644 --- a/internal/api/oauthserver/handlers_test.go +++ b/internal/api/oauthserver/handlers_test.go @@ -33,7 +33,8 @@ func TestOAuthClientHandler(t *testing.T) { conn, err := test.SetupDBConnection(globalConfig) require.NoError(t, err) - // Enable OAuth dynamic client registration for tests + // Enable OAuth server and dynamic client registration for tests + globalConfig.OAuthServer.Enabled = true globalConfig.OAuthServer.AllowDynamicRegistration = true server := NewServer(globalConfig, conn) diff --git a/internal/api/oauthserver/service_test.go b/internal/api/oauthserver/service_test.go index a1220f893..60798a49e 100644 --- a/internal/api/oauthserver/service_test.go +++ b/internal/api/oauthserver/service_test.go @@ -30,7 +30,8 @@ func TestOAuthService(t *testing.T) { conn, err := test.SetupDBConnection(globalConfig) require.NoError(t, err) - // Enable OAuth dynamic client registration for tests + // Enable OAuth server and dynamic client registration for tests + globalConfig.OAuthServer.Enabled = true globalConfig.OAuthServer.AllowDynamicRegistration = true server := NewServer(globalConfig, conn) @@ -49,7 +50,8 @@ func (ts *OAuthServiceTestSuite) SetupTest() { if ts.DB != nil { models.TruncateAll(ts.DB) } - // Enable OAuth dynamic client registration for tests + // Enable OAuth server and dynamic client registration for tests + ts.Config.OAuthServer.Enabled = true ts.Config.OAuthServer.AllowDynamicRegistration = true } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 2c57d8ece..e59460c0e 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -73,6 +73,7 @@ type OAuthProviderConfiguration struct { // OAuthServerConfiguration holds OAuth server configuration type OAuthServerConfiguration struct { + Enabled bool `json:"enabled" default:"false"` AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` AuthorizationPath string `json:"authorization_path" split_words:"true"` AuthorizationTimeout time.Duration `json:"authorization_timeout" split_words:"true" default:"5m"` From 4ae8fe819ea6fdd9abd68961c974247e7b2fb2be Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 13 Aug 2025 14:32:32 +0200 Subject: [PATCH 09/15] fix: normalize code challenge method to lowercase MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit as we’re using the existing database enum, and it’s defined with `s256` and `plain` --- internal/models/oauth_authorization.go | 6 +++- internal/models/oauth_authorization_test.go | 33 ++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 61fc7e83d..cfdf2675d 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -3,6 +3,7 @@ package models import ( "database/sql" "fmt" + "strings" "time" "github.com/gofrs/uuid" @@ -92,7 +93,10 @@ func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChalle auth.CodeChallenge = &codeChallenge } if codeChallengeMethod != "" { - auth.CodeChallengeMethod = &codeChallengeMethod + // Normalize code challenge method to lowercase for database storage + // Database enum expects 's256' and 'plain' (lowercase) + normalizedMethod := strings.ToLower(codeChallengeMethod) + auth.CodeChallengeMethod = &normalizedMethod } return auth diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go index 85412e2c0..19f95ae5d 100644 --- a/internal/models/oauth_authorization_test.go +++ b/internal/models/oauth_authorization_test.go @@ -27,13 +27,44 @@ func TestNewOAuthServerAuthorization(t *testing.T) { assert.Equal(t, scope, auth.Scope) assert.Equal(t, state, *auth.State) assert.Equal(t, codeChallenge, *auth.CodeChallenge) - assert.Equal(t, codeChallengeMethod, *auth.CodeChallengeMethod) + assert.Equal(t, "s256", *auth.CodeChallengeMethod) // Should be normalized to lowercase assert.Equal(t, OAuthServerResponseTypeCode, auth.ResponseType) assert.Equal(t, OAuthServerAuthorizationPending, auth.Status) assert.True(t, auth.ExpiresAt.After(auth.CreatedAt)) assert.Nil(t, auth.ApprovedAt) } +func TestNewOAuthServerAuthorization_CodeChallengeMethodNormalization(t *testing.T) { + testCases := []struct { + name string + input string + expected string + }{ + {"uppercase S256", "S256", "s256"}, + {"lowercase s256", "s256", "s256"}, + {"mixed case S256", "s256", "s256"}, + {"uppercase PLAIN", "PLAIN", "plain"}, + {"lowercase plain", "plain", "plain"}, + {"mixed case Plain", "Plain", "plain"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + auth := NewOAuthServerAuthorization( + "client-id", + "https://example.com/callback", + "openid", + "state", + "challenge", + tc.input, + ) + + assert.Equal(t, tc.expected, *auth.CodeChallengeMethod, + "Expected code_challenge_method to be normalized to %s, got %s", tc.expected, *auth.CodeChallengeMethod) + }) + } +} + func TestOAuthServerAuthorization_IsExpired(t *testing.T) { auth := &OAuthServerAuthorization{ CreatedAt: time.Now().Add(-1 * time.Hour), From c071b125d791f096f22c1a2dcec555fea45675c4 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Thu, 14 Aug 2025 15:37:18 +0200 Subject: [PATCH 10/15] feat: redirect on errors if possible --- internal/api/oauthserver/authorize.go | 101 +++++++++++++++++--------- 1 file changed, 66 insertions(+), 35 deletions(-) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 56a633ed3..6794b8421 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -2,6 +2,7 @@ package oauthserver import ( "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -66,19 +67,38 @@ const ( OAuthServerConsentActionDeny OAuthServerConsentAction = "deny" ) +// OAuth2 error codes per RFC 6749 +const ( + oAuth2ErrorInvalidRequest = "invalid_request" + oAuth2ErrorServerError = "server_error" + oAuth2ErrorAccessDenied = "access_denied" +) + // OAuthServerAuthorize handles GET /oauth/authorize func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() db := s.db.WithContext(ctx) config := s.config - // Validate OAuth2 authorize parameters - params, err := s.validateAuthorizeParams(r) + query := r.URL.Query() + + params := &AuthorizeParams{ + ClientID: query.Get("client_id"), + RedirectURI: query.Get("redirect_uri"), + ResponseType: query.Get("response_type"), + Scope: query.Get("scope"), + State: query.Get("state"), + CodeChallenge: query.Get("code_challenge"), + CodeChallengeMethod: query.Get("code_challenge_method"), + } + + // validate basic required parameters (client_id, redirect_uri) + // this errors wont be redirected, just returned in the json + params, err := s.validateBasicAuthorizeParams(params) if err != nil { return err } - // Validate client exists and redirect_uri matches client, err := s.getOAuthServerClient(ctx, params.ClientID) if err != nil { if models.IsNotFoundError(err) { @@ -87,10 +107,20 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er return apierrors.NewInternalServerError("error validating client").WithInternalError(err) } + // validate redirect_uri matches client's registered URIs if !s.isValidRedirectURI(client, params.RedirectURI) { + // Invalid redirect_uri should NOT redirect per OAuth2 spec since we can't trust it return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri") } + // From this point on, we have valid client + redirect_uri + all params, so we can redirect errors + // validate all other parameters - now we can redirect errors + if err := s.validateRemainingAuthorizeParams(params); err != nil { + errorRedirectURL := s.buildErrorRedirectURL(params.RedirectURI, oAuth2ErrorInvalidRequest, err.Error(), params.State) + http.Redirect(w, r, errorRedirectURL, http.StatusFound) + return nil + } + // Store authorization request in database (without user initially) authorization := models.NewOAuthServerAuthorization( params.ClientID, @@ -102,7 +132,10 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er ) if err := models.CreateOAuthServerAuthorization(db, authorization); err != nil { - return apierrors.NewInternalServerError("error creating authorization").WithInternalError(err) + // Error creating authorization - redirect with server_error + errorRedirectURL := s.buildErrorRedirectURL(params.RedirectURI, oAuth2ErrorServerError, "error creating authorization", params.State) + http.Redirect(w, r, errorRedirectURL, http.StatusFound) + return nil } observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) @@ -110,7 +143,10 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er // Redirect to authorization path with authorization_id if config.OAuthServer.AuthorizationPath == "" { - return apierrors.NewInternalServerError("oauth authorization path not configured") + // OAuth authorization path not configured - redirect with server_error + errorRedirectURL := s.buildErrorRedirectURL(params.RedirectURI, oAuth2ErrorServerError, "oauth authorization path not configured", params.State) + http.Redirect(w, r, errorRedirectURL, http.StatusFound) + return nil } baseURL := s.buildAuthorizationURL(config.SiteURL, config.OAuthServer.AuthorizationPath) @@ -283,7 +319,11 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro // Build error redirect URL // Errors are being returned to the client in the redirect url per OAuth2 spec - redirectURL = s.buildErrorRedirectURL(authorization, "access_denied", "User denied the request") + var state string + if authorization.State != nil { + state = *authorization.State + } + redirectURL = s.buildErrorRedirectURL(authorization.RedirectURI, oAuth2ErrorAccessDenied, "User denied the request", state) observability.LogEntrySetField(r, "oauth_consent_action", string(OAuthServerConsentActionDeny)) } @@ -350,20 +390,8 @@ func (s *Server) validateAuthorizationOwnership(r *http.Request, authorization * return nil } -func (s *Server) validateAuthorizeParams(r *http.Request) (*AuthorizeParams, error) { - query := r.URL.Query() - - params := &AuthorizeParams{ - ClientID: query.Get("client_id"), - RedirectURI: query.Get("redirect_uri"), - ResponseType: query.Get("response_type"), - Scope: query.Get("scope"), - State: query.Get("state"), - CodeChallenge: query.Get("code_challenge"), - CodeChallengeMethod: query.Get("code_challenge_method"), - } - - // Required parameters +// validateBasicAuthorizeParams validates only client_id and redirect_uri (needed before we can redirect errors) +func (s *Server) validateBasicAuthorizeParams(params *AuthorizeParams) (*AuthorizeParams, error) { if params.ClientID == "" { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required") } @@ -371,7 +399,11 @@ func (s *Server) validateAuthorizeParams(r *http.Request) (*AuthorizeParams, err return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uri is required") } - // Default values + return params, nil +} + +// validateRemainingAuthorizeParams validates all other parameters (can redirect errors since we have valid client + redirect_uri) +func (s *Server) validateRemainingAuthorizeParams(params *AuthorizeParams) error { if params.ResponseType == "" { params.ResponseType = models.OAuthServerResponseTypeCode.String() } @@ -381,34 +413,32 @@ func (s *Server) validateAuthorizeParams(r *http.Request) (*AuthorizeParams, err // OAuth 2.1 only supports "code" response type if params.ResponseType != models.OAuthServerResponseTypeCode.String() { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Only response_type=code is supported") + return errors.New("Only response_type=code is supported") } // PKCE validation if err := s.validatePKCEParams(params.CodeChallengeMethod, params.CodeChallenge); err != nil { - return nil, err + return err } - return params, nil + return nil } func (s *Server) validatePKCEParams(codeChallengeMethod, codeChallenge string) error { // PKCE is mandatory for the authorization code flow OAuth2.1 // Both code_challenge and code_challenge_method must be provided together if codeChallenge == "" || codeChallengeMethod == "" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "PKCE flow requires both code_challenge and code_challenge_method") + return errors.New("PKCE flow requires both code_challenge and code_challenge_method") } // Validate code challenge method (case-insensitive) if strings.ToLower(codeChallengeMethod) != "s256" && strings.ToLower(codeChallengeMethod) != "plain" { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, - "code_challenge_method must be 'S256' or 'plain'") + return errors.New("code_challenge_method must be 'S256' or 'plain'") } // Validate code challenge format and length (per OAuth2 spec) if len(codeChallenge) < 43 || len(codeChallenge) > 128 { - return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, - "code_challenge must be between 43 and 128 characters") + return errors.New("code_challenge must be between 43 and 128 characters") } return nil @@ -472,13 +502,14 @@ func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthor return u.String() } -func (s *Server) buildErrorRedirectURL(authorization *models.OAuthServerAuthorization, errorCode, errorDescription string) string { - u, _ := url.Parse(authorization.RedirectURI) +// buildErrorRedirectURL builds an error redirect URL with the given parameters +func (s *Server) buildErrorRedirectURL(redirectURI, errorCode, errorDescription, state string) string { + u, _ := url.Parse(redirectURI) q := u.Query() q.Set("error", errorCode) q.Set("error_description", errorDescription) - if authorization.State != nil && *authorization.State != "" { - q.Set("state", *authorization.State) + if state != "" { + q.Set("state", state) } u.RawQuery = q.Encode() return u.String() @@ -488,11 +519,11 @@ func (s *Server) buildErrorRedirectURL(authorization *models.OAuthServerAuthoriz func (s *Server) buildAuthorizationURL(baseURL, pathToJoin string) string { // Trim trailing slash from baseURL baseURL = strings.TrimRight(baseURL, "/") - + // Ensure pathToJoin starts with a slash if !strings.HasPrefix(pathToJoin, "/") { pathToJoin = "/" + pathToJoin } - + return baseURL + pathToJoin } From 8f520ed565c186aa7fa9932af2248639bc0e095d Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Thu, 14 Aug 2025 15:46:53 +0200 Subject: [PATCH 11/15] feat: add origin check for the GET/POST authorization --- internal/api/oauthserver/authorize.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 6794b8421..bc8c7a263 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -161,6 +161,11 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ ctx := r.Context() db := s.db.WithContext(ctx) + // Validate request origin - the request must come from the site URL as we redirected there at the first place + if err := s.validateRequestOrigin(r); err != nil { + return err + } + // Get authenticated user user := shared.GetUser(ctx) if user == nil { @@ -241,6 +246,11 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro ctx := r.Context() db := s.db.WithContext(ctx) + // Validate request origin - the request must come from the site URL as we redirected there at the first place + if err := s.validateRequestOrigin(r); err != nil { + return err + } + // Get authenticated user user := shared.GetUser(ctx) if user == nil { @@ -345,6 +355,21 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro // Helper functions +// validateRequestOrigin checks if the request is coming from an authorized origin +func (s *Server) validateRequestOrigin(r *http.Request) error { + // Check referer header + referer := r.Referer() + if referer == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "request must originate from authorized domain") + } + + if !utilities.IsRedirectURLValid(s.config, referer) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "unauthorized request origin") + } + + return nil +} + // validateAndFindAuthorization validates the authorization_id parameter and finds the authorization, // performing all necessary checks (existence, expiration, status) func (s *Server) validateAndFindAuthorization(r *http.Request, db *storage.Connection, authorizationID string) (*models.OAuthServerAuthorization, error) { From 0c128e1837e6d37626953e6185333aa2217d1f17 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Thu, 14 Aug 2025 22:13:39 +0200 Subject: [PATCH 12/15] fix: error strings should not be capitalized --- internal/api/oauthserver/authorize.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index bc8c7a263..9fcffb481 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -438,7 +438,7 @@ func (s *Server) validateRemainingAuthorizeParams(params *AuthorizeParams) error // OAuth 2.1 only supports "code" response type if params.ResponseType != models.OAuthServerResponseTypeCode.String() { - return errors.New("Only response_type=code is supported") + return errors.New("only response_type=code is supported") } // PKCE validation From 7ca2846eccb758f32965f49a48a9ef220e1907eb Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Tue, 19 Aug 2025 13:49:56 +0200 Subject: [PATCH 13/15] feat: add resource param in authorize call `resource` param is being used in MCP auth to indicate which MCP server is being accessed we can provide the configuration of valid resources in the future. for now the plan is just to store during `/authorize` and validate if the same resource is being requested in the `/token` call --- internal/api/oauthserver/authorize.go | 51 +++++++++++++++++-- internal/models/oauth_authorization.go | 14 +++-- internal/models/oauth_authorization_test.go | 19 ++++--- ...0_add_oauth_authorizations_consents.up.sql | 2 + 4 files changed, 68 insertions(+), 18 deletions(-) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 9fcffb481..32133d2dc 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -19,11 +19,14 @@ import ( // AuthorizeParams represents the parameters for an OAuth authorization request type AuthorizeParams struct { - ClientID string `json:"client_id"` - RedirectURI string `json:"redirect_uri"` - ResponseType string `json:"response_type"` - Scope string `json:"scope"` - State string `json:"state"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + ResponseType string `json:"response_type"` + Scope string `json:"scope"` + State string `json:"state"` + + // Resource Resource Indicator per RFC8707 + Resource string `json:"resource"` CodeChallenge string `json:"code_challenge"` CodeChallengeMethod string `json:"code_challenge_method"` } @@ -88,6 +91,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er ResponseType: query.Get("response_type"), Scope: query.Get("scope"), State: query.Get("state"), + Resource: query.Get("resource"), CodeChallenge: query.Get("code_challenge"), CodeChallengeMethod: query.Get("code_challenge_method"), } @@ -127,6 +131,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er params.RedirectURI, params.Scope, params.State, + params.Resource, params.CodeChallenge, params.CodeChallengeMethod, ) @@ -441,6 +446,11 @@ func (s *Server) validateRemainingAuthorizeParams(params *AuthorizeParams) error return errors.New("only response_type=code is supported") } + // Resource parameter validation (per RFC 8707) + if err := s.validateResourceParam(params.Resource); err != nil { + return err + } + // PKCE validation if err := s.validatePKCEParams(params.CodeChallengeMethod, params.CodeChallenge); err != nil { return err @@ -469,6 +479,37 @@ func (s *Server) validatePKCEParams(codeChallengeMethod, codeChallenge string) e return nil } +// validateResourceParam validates the resource parameter per RFC 8707 +func (s *Server) validateResourceParam(resource string) error { + // Resource parameter is optional + if resource == "" { + return nil + } + + // Parse URL to validate it's an absolute URI + parsedURL, err := url.Parse(resource) + if err != nil { + return errors.New("resource must be a valid URI") + } + + // Must be an absolute URI (have scheme) + if !parsedURL.IsAbs() { + return errors.New("resource must be an absolute URI") + } + + // Must not include a fragment component + if parsedURL.Fragment != "" { + return errors.New("resource must not include a fragment component") + } + + // Should not include a query component + if parsedURL.RawQuery != "" { + return errors.New("resource must not include a query component") + } + + return nil +} + func (s *Server) isValidRedirectURI(client *models.OAuthServerClient, redirectURI string) bool { registeredURIs := client.GetRedirectURIs() for _, registeredURI := range registeredURIs { diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index cfdf2675d..1b974c91f 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -46,12 +46,13 @@ type OAuthServerAuthorization struct { UserID *uuid.UUID `json:"user_id" db:"user_id"` RedirectURI string `json:"redirect_uri" db:"redirect_uri"` Scope string `json:"scope" db:"scope"` - State *string `json:"state,omitempty" db:"state"` - CodeChallenge *string `json:"code_challenge,omitempty" db:"code_challenge"` - CodeChallengeMethod *string `json:"code_challenge_method,omitempty" db:"code_challenge_method"` + State *string `json:"state,omitempty" db:"state"` + Resource *string `json:"resource,omitempty" db:"resource"` + CodeChallenge *string `json:"code_challenge,omitempty" db:"code_challenge"` + CodeChallengeMethod *string `json:"code_challenge_method,omitempty" db:"code_challenge_method"` ResponseType OAuthServerResponseType `json:"response_type" db:"response_type"` Status OAuthServerAuthorizationStatus `json:"status" db:"status"` - AuthorizationCode *string `json:"-" db:"authorization_code"` + AuthorizationCode *string `json:"-" db:"authorization_code"` CreatedAt time.Time `json:"created_at" db:"created_at"` ExpiresAt time.Time `json:"expires_at" db:"expires_at"` ApprovedAt *time.Time `json:"approved_at" db:"approved_at"` @@ -66,7 +67,7 @@ func (OAuthServerAuthorization) TableName() string { } // NewOAuthServerAuthorization creates a new OAuth server authorization request without user (for initial flow) -func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChallenge, codeChallengeMethod string) *OAuthServerAuthorization { +func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, resource, codeChallenge, codeChallengeMethod string) *OAuthServerAuthorization { id := uuid.Must(uuid.NewV4()) authorizationID := crypto.SecureAlphanumeric(32) // Generate random ID for frontend @@ -89,6 +90,9 @@ func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChalle if state != "" { auth.State = &state } + if resource != "" { + auth.Resource = &resource + } if codeChallenge != "" { auth.CodeChallenge = &codeChallenge } diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go index 19f95ae5d..cae404f3b 100644 --- a/internal/models/oauth_authorization_test.go +++ b/internal/models/oauth_authorization_test.go @@ -16,8 +16,9 @@ func TestNewOAuthServerAuthorization(t *testing.T) { state := "random-state" codeChallenge := "test-challenge" codeChallengeMethod := "S256" + resource := "https://api.example.com/" - auth := NewOAuthServerAuthorization(clientID, redirectURI, scope, state, codeChallenge, codeChallengeMethod) + auth := NewOAuthServerAuthorization(clientID, redirectURI, scope, state, resource, codeChallenge, codeChallengeMethod) assert.NotEmpty(t, auth.ID) assert.NotEmpty(t, auth.AuthorizationID) @@ -26,6 +27,7 @@ func TestNewOAuthServerAuthorization(t *testing.T) { assert.Equal(t, redirectURI, auth.RedirectURI) assert.Equal(t, scope, auth.Scope) assert.Equal(t, state, *auth.State) + assert.Equal(t, resource, *auth.Resource) assert.Equal(t, codeChallenge, *auth.CodeChallenge) assert.Equal(t, "s256", *auth.CodeChallengeMethod) // Should be normalized to lowercase assert.Equal(t, OAuthServerResponseTypeCode, auth.ResponseType) @@ -51,15 +53,16 @@ func TestNewOAuthServerAuthorization_CodeChallengeMethodNormalization(t *testing for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { auth := NewOAuthServerAuthorization( - "client-id", - "https://example.com/callback", - "openid", - "state", - "challenge", + "client-id", + "https://example.com/callback", + "openid", + "state", + "", + "challenge", tc.input, ) - - assert.Equal(t, tc.expected, *auth.CodeChallengeMethod, + + assert.Equal(t, tc.expected, *auth.CodeChallengeMethod, "Expected code_challenge_method to be normalized to %s, got %s", tc.expected, *auth.CodeChallengeMethod) }) } diff --git a/migrations/20250804100000_add_oauth_authorizations_consents.up.sql b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql index 2287bbd42..96ae97130 100644 --- a/migrations/20250804100000_add_oauth_authorizations_consents.up.sql +++ b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql @@ -22,6 +22,7 @@ create table if not exists {{ index .Options "Namespace" }}.oauth_authorizations redirect_uri text not null, scope text not null, state text null, + resource text null, code_challenge text null, code_challenge_method {{ index .Options "Namespace" }}.code_challenge_method null, response_type {{ index .Options "Namespace" }}.oauth_response_type not null default 'code', @@ -41,6 +42,7 @@ create table if not exists {{ index .Options "Namespace" }}.oauth_authorizations constraint oauth_authorizations_redirect_uri_length check (char_length(redirect_uri) <= 2048), constraint oauth_authorizations_scope_length check (char_length(scope) <= 4096), constraint oauth_authorizations_state_length check (char_length(state) <= 4096), + constraint oauth_authorizations_resource_length check (char_length(resource) <= 2048), constraint oauth_authorizations_code_challenge_length check (char_length(code_challenge) <= 128), constraint oauth_authorizations_authorization_code_length check (char_length(authorization_code) <= 255), constraint oauth_authorizations_expires_at_future check (expires_at > created_at) From 1a313d8e49a1523d3b30cc086b1457f56dd7931d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cemal=20K=C4=B1l=C4=B1=C3=A7?= Date: Thu, 21 Aug 2025 15:06:43 +0200 Subject: [PATCH 14/15] feat: add `.well-known/oauth-authorization-server` endpoint (#2124) ## What kind of change does this PR introduce? Adding `.well-known/oauth-authorization-server` endpoint per [RFC 8414](https://datatracker.ietf.org/doc/html/rfc8414) --- internal/api/api.go | 6 +++- internal/api/oauthserver/handlers.go | 48 ++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/internal/api/api.go b/internal/api/api.go index bda0fc0de..24cecc40f 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -92,7 +92,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne db: db, version: version, } - + // Only initialize OAuth server if enabled if globalConfig.OAuthServer.Enabled { api.oauthServer = oauthserver.NewServer(globalConfig, db) @@ -175,6 +175,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/health", api.HealthCheck) r.Get("/.well-known/jwks.json", api.Jwks) + if globalConfig.OAuthServer.Enabled { + r.Get("/.well-known/oauth-authorization-server", api.oauthServer.OAuthServerMetadata) + } + r.Route("/callback", func(r *router) { r.Use(api.isValidExternalHost) r.Use(api.loadFlowState) diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 2135929d1..8a247a054 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -3,6 +3,7 @@ package oauthserver import ( "context" "encoding/json" + "fmt" "net/http" "time" @@ -183,3 +184,50 @@ func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) e return shared.SendJSON(w, http.StatusOK, response) } + +// OAuthServerMetadataResponse represents the OAuth 2.1 Authorization Server Metadata per RFC 8414 +type OAuthServerMetadataResponse struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKSetURI string `json:"jwks_uri"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported"` + ResponseModesSupported []string `json:"response_modes_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + + // TODO(cemal) :: Append the scopes supported when scope management is clarified! + // ScopesSupported []string `json:"scopes_supported"` +} + +// OAuthServerMetadata handles GET /.well-known/oauth-authorization-server +func (s *Server) OAuthServerMetadata(w http.ResponseWriter, r *http.Request) error { + issuer := s.config.JWT.Issuer + + // TODO(cemal) :: Remove this check when we have the config validation in place + if issuer == "" { + return apierrors.NewInternalServerError("Issuer is not set") + } + + response := OAuthServerMetadataResponse{ + Issuer: issuer, + AuthorizationEndpoint: fmt.Sprintf("%s/oauth/authorize", issuer), + TokenEndpoint: fmt.Sprintf("%s/oauth/token", issuer), + JWKSetURI: fmt.Sprintf("%s/.well-known/jwks.json", issuer), + ResponseTypesSupported: []string{"code"}, + ResponseModesSupported: []string{"query"}, + GrantTypesSupported: []string{"authorization_code", "refresh_token"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, + CodeChallengeMethodsSupported: []string{"S256", "plain"}, + } + + // Include registration endpoint if dynamic registration is enabled + if s.config.OAuthServer.AllowDynamicRegistration { + response.RegistrationEndpoint = fmt.Sprintf("%s/oauth/clients/register", issuer) + } + + // TODO: Cache response for 10 minutes, but consider dynamic registration toggle changes + return shared.SendJSON(w, http.StatusOK, response) +} From 563b0338862684bbd4abd3625f1ea0f39ac5b6fd Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Mon, 1 Sep 2025 17:54:00 +0200 Subject: [PATCH 15/15] fix: use the primary key of oauth_clients for FK --- internal/api/oauthserver/authorize.go | 6 +++--- internal/models/oauth_authorization.go | 14 +++++++------- internal/models/oauth_authorization_test.go | 9 +++++---- internal/models/oauth_consent.go | 12 ++++++------ internal/models/oauth_consent_test.go | 6 +++--- ...100000_add_oauth_authorizations_consents.up.sql | 4 ++-- 6 files changed, 26 insertions(+), 25 deletions(-) diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 32133d2dc..599a710ac 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -127,7 +127,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er // Store authorization request in database (without user initially) authorization := models.NewOAuthServerAuthorization( - params.ClientID, + client.ID, // Use the client's UUID instead of the public client_id string params.RedirectURI, params.Scope, params.State, @@ -144,7 +144,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er } observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) - observability.LogEntrySetField(r, "client_id", authorization.ClientID) + observability.LogEntrySetField(r, "client_id", client.ClientID) // Redirect to authorization path with authorization_id if config.OAuthServer.AuthorizationPath == "" { @@ -241,7 +241,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ } observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) - observability.LogEntrySetField(r, "client_id", authorization.ClientID) + observability.LogEntrySetField(r, "client_id", authorization.Client.ClientID) return shared.SendJSON(w, http.StatusOK, response) } diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index 1b974c91f..409e03ccd 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -42,7 +42,7 @@ func (rt OAuthServerResponseType) String() string { type OAuthServerAuthorization struct { ID uuid.UUID `json:"-" db:"id"` AuthorizationID string `json:"authorization_id" db:"authorization_id"` - ClientID string `json:"client_id" db:"client_id"` + ClientID uuid.UUID `json:"-" db:"client_id"` UserID *uuid.UUID `json:"user_id" db:"user_id"` RedirectURI string `json:"redirect_uri" db:"redirect_uri"` Scope string `json:"scope" db:"scope"` @@ -67,7 +67,7 @@ func (OAuthServerAuthorization) TableName() string { } // NewOAuthServerAuthorization creates a new OAuth server authorization request without user (for initial flow) -func NewOAuthServerAuthorization(clientID, redirectURI, scope, state, resource, codeChallenge, codeChallengeMethod string) *OAuthServerAuthorization { +func NewOAuthServerAuthorization(clientID uuid.UUID, redirectURI, scope, state, resource, codeChallenge, codeChallengeMethod string) *OAuthServerAuthorization { id := uuid.Must(uuid.NewV4()) authorizationID := crypto.SecureAlphanumeric(32) // Generate random ID for frontend @@ -173,7 +173,7 @@ func (auth *OAuthServerAuthorization) MarkExpired(tx *storage.Connection) error // Validate performs basic validation on the OAuth authorization func (auth *OAuthServerAuthorization) Validate() error { - if auth.ClientID == "" { + if auth.ClientID == uuid.Nil { return fmt.Errorf("client_id is required") } // UserID can be nil initially for unauthenticated authorization requests @@ -225,9 +225,9 @@ func FindOAuthServerAuthorizationByID(tx *storage.Connection, authorizationID st return nil, errors.Wrap(err, "error finding OAuth authorization") } - if auth.ClientID != "" { + if auth.ClientID != uuid.Nil { client := &OAuthServerClient{} - if err := tx.Q().Where("client_id = ?", auth.ClientID).First(client); err == nil { + if err := tx.Q().Where("id = ?", auth.ClientID).First(client); err == nil { auth.Client = client } } @@ -246,9 +246,9 @@ func FindOAuthServerAuthorizationByCode(tx *storage.Connection, code string) (*O } // Load client relationship (always present) - if auth.ClientID != "" { + if auth.ClientID != uuid.Nil { client := &OAuthServerClient{} - if err := tx.Q().Where("client_id = ?", auth.ClientID).First(client); err == nil { + if err := tx.Q().Where("id = ?", auth.ClientID).First(client); err == nil { auth.Client = client } } diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go index cae404f3b..d039efdcb 100644 --- a/internal/models/oauth_authorization_test.go +++ b/internal/models/oauth_authorization_test.go @@ -10,7 +10,7 @@ import ( ) func TestNewOAuthServerAuthorization(t *testing.T) { - clientID := "test-client-id" + clientID := uuid.Must(uuid.NewV4()) redirectURI := "https://example.com/callback" scope := "openid profile" state := "random-state" @@ -53,7 +53,7 @@ func TestNewOAuthServerAuthorization_CodeChallengeMethodNormalization(t *testing for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { auth := NewOAuthServerAuthorization( - "client-id", + uuid.Must(uuid.NewV4()), "https://example.com/callback", "openid", "state", @@ -291,8 +291,9 @@ func TestOAuthServerAuthorization_MarkExpiredLogic(t *testing.T) { func TestOAuthServerAuthorization_Validate(t *testing.T) { userID := uuid.Must(uuid.NewV4()) + clientID := uuid.Must(uuid.NewV4()) validAuth := &OAuthServerAuthorization{ - ClientID: "test-client", + ClientID: clientID, UserID: &userID, RedirectURI: "https://example.com/callback", Scope: "openid", @@ -318,7 +319,7 @@ func TestOAuthServerAuthorization_Validate(t *testing.T) { }{ { name: "missing client_id", - modify: func(a *OAuthServerAuthorization) { a.ClientID = "" }, + modify: func(a *OAuthServerAuthorization) { a.ClientID = uuid.Nil }, wantErr: true, errMsg: "client_id is required", }, diff --git a/internal/models/oauth_consent.go b/internal/models/oauth_consent.go index cef2dd294..4aeeaa4bf 100644 --- a/internal/models/oauth_consent.go +++ b/internal/models/oauth_consent.go @@ -15,7 +15,7 @@ import ( type OAuthServerConsent struct { ID uuid.UUID `json:"id" db:"id"` UserID uuid.UUID `json:"user_id" db:"user_id"` - ClientID string `json:"client_id" db:"client_id"` + ClientID uuid.UUID `json:"-" db:"client_id"` Scopes string `json:"scopes" db:"scopes"` GrantedAt time.Time `json:"granted_at" db:"granted_at"` RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"` @@ -27,7 +27,7 @@ func (OAuthServerConsent) TableName() string { } // NewOAuthConsent creates a new OAuth consent record -func NewOAuthServerConsent(userID uuid.UUID, clientID string, scopes []string) *OAuthServerConsent { +func NewOAuthServerConsent(userID uuid.UUID, clientID uuid.UUID, scopes []string) *OAuthServerConsent { return &OAuthServerConsent{ ID: uuid.Must(uuid.NewV4()), UserID: userID, @@ -84,7 +84,7 @@ func (consent *OAuthServerConsent) Validate() error { if consent.UserID == uuid.Nil { return fmt.Errorf("user_id is required") } - if consent.ClientID == "" { + if consent.ClientID == uuid.Nil { return fmt.Errorf("client_id is required") } if strings.TrimSpace(consent.Scopes) == "" { @@ -100,7 +100,7 @@ func (consent *OAuthServerConsent) Validate() error { // Query functions for OAuth consents // FindOAuthServerConsentByUserAndClient finds an OAuth consent by user and client -func FindOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID string) (*OAuthServerConsent, error) { +func FindOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID uuid.UUID) (*OAuthServerConsent, error) { consent := &OAuthServerConsent{} if err := tx.Eager().Q().Where("user_id = ? AND client_id = ?", userID, clientID).First(consent); err != nil { if errors.Cause(err) == sql.ErrNoRows { @@ -112,7 +112,7 @@ func FindOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.U } // FindActiveOAuthServerConsentByUserAndClient finds an active (non-revoked) OAuth consent -func FindActiveOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID string) (*OAuthServerConsent, error) { +func FindActiveOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID uuid.UUID) (*OAuthServerConsent, error) { consent := &OAuthServerConsent{} if err := tx.Q().Where("user_id = ? AND client_id = ? AND revoked_at IS NULL", userID, clientID).First(consent); err != nil { if errors.Cause(err) == sql.ErrNoRows { @@ -167,7 +167,7 @@ func UpsertOAuthServerConsent(tx *storage.Connection, consent *OAuthServerConsen } // RevokeOAuthServerConsentsByClient revokes all consents for a specific client -func RevokeOAuthServerConsentsByClient(tx *storage.Connection, clientID string) error { +func RevokeOAuthServerConsentsByClient(tx *storage.Connection, clientID uuid.UUID) error { now := time.Now() query := "UPDATE " + (&OAuthServerConsent{}).TableName() + " SET revoked_at = ? WHERE client_id = ? AND revoked_at IS NULL" return tx.RawQuery(query, now, clientID).Exec() diff --git a/internal/models/oauth_consent_test.go b/internal/models/oauth_consent_test.go index 51c5c2e43..4548004f3 100644 --- a/internal/models/oauth_consent_test.go +++ b/internal/models/oauth_consent_test.go @@ -10,7 +10,7 @@ import ( func TestNewOAuthServerConsent(t *testing.T) { userID := uuid.Must(uuid.NewV4()) - clientID := "test-client-id" + clientID := uuid.Must(uuid.NewV4()) scopes := []string{"openid", "profile", "email"} consent := NewOAuthServerConsent(userID, clientID, scopes) @@ -38,7 +38,7 @@ func TestOAuthServerConsent_IsRevoked(t *testing.T) { func TestOAuthServerConsent_Validate(t *testing.T) { validConsent := &OAuthServerConsent{ UserID: uuid.Must(uuid.NewV4()), - ClientID: "test-client", + ClientID: uuid.Must(uuid.NewV4()), Scopes: "openid profile", GrantedAt: time.Now(), } @@ -61,7 +61,7 @@ func TestOAuthServerConsent_Validate(t *testing.T) { }, { name: "missing client_id", - modify: func(c *OAuthServerConsent) { c.ClientID = "" }, + modify: func(c *OAuthServerConsent) { c.ClientID = uuid.Nil }, wantErr: true, errMsg: "client_id is required", }, diff --git a/migrations/20250804100000_add_oauth_authorizations_consents.up.sql b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql index 96ae97130..a74b4305c 100644 --- a/migrations/20250804100000_add_oauth_authorizations_consents.up.sql +++ b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql @@ -17,7 +17,7 @@ end $$; create table if not exists {{ index .Options "Namespace" }}.oauth_authorizations ( id uuid not null, authorization_id text not null, - client_id text not null references {{ index .Options "Namespace" }}.oauth_clients(client_id) on delete cascade, + client_id uuid not null references {{ index .Options "Namespace" }}.oauth_clients(id) on delete cascade, user_id uuid null references {{ index .Options "Namespace" }}.users(id) on delete cascade, redirect_uri text not null, scope text not null, @@ -60,7 +60,7 @@ create index if not exists oauth_auth_pending_exp_idx create table if not exists {{ index .Options "Namespace" }}.oauth_consents ( id uuid not null, user_id uuid not null references {{ index .Options "Namespace" }}.users(id) on delete cascade, - client_id text not null references {{ index .Options "Namespace" }}.oauth_clients(client_id) on delete cascade, + client_id uuid not null references {{ index .Options "Namespace" }}.oauth_clients(id) on delete cascade, scopes text not null, granted_at timestamptz not null default now(), revoked_at timestamptz null,