diff --git a/go.mod b/go.mod index 6c06d4b84..e3db84713 100644 --- a/go.mod +++ b/go.mod @@ -173,7 +173,7 @@ require ( golang.org/x/net v0.38.0 // indirect golang.org/x/sync v0.12.0 golang.org/x/sys v0.31.0 - golang.org/x/text v0.23.0 // indirect + golang.org/x/text v0.23.0 golang.org/x/time v0.9.0 google.golang.org/appengine v1.6.8 // indirect google.golang.org/grpc v1.63.2 // indirect diff --git a/internal/api/api.go b/internal/api/api.go index bdf11265f..24cecc40f 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -88,10 +88,14 @@ func (a *API) deprecationNotices() { // NewAPIWithVersion creates a new REST API using the specified version func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API { api := &API{ - config: globalConfig, - db: db, - version: version, - oauthServer: oauthserver.NewServer(globalConfig, db), + config: globalConfig, + db: db, + version: version, + } + + // Only initialize OAuth server if enabled + if globalConfig.OAuthServer.Enabled { + api.oauthServer = oauthserver.NewServer(globalConfig, db) } for _, o := range opt { @@ -171,6 +175,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/health", api.HealthCheck) r.Get("/.well-known/jwks.json", api.Jwks) + if globalConfig.OAuthServer.Enabled { + r.Get("/.well-known/oauth-authorization-server", api.oauthServer.OAuthServerMetadata) + } + r.Route("/callback", func(r *router) { r.Use(api.isValidExternalHost) r.Use(api.loadFlowState) @@ -185,6 +193,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/settings", api.Settings) + // `/authorize` to initiate OAuth2 authorization flow with the external providers + // where Supabase Auth is an OAuth2 Client r.Get("/authorize", api.ExternalProviderRedirect) r.With(api.requireAdminCredentials).Post("/invite", api.Invite) @@ -325,27 +335,37 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne }) // Admin only oauth client management endpoints - r.Route("/oauth", func(r *router) { - r.Route("/clients", func(r *router) { - // Manual client registration - r.Post("/", api.oauthServer.AdminOAuthServerClientRegister) - - r.Get("/", api.oauthServer.OAuthServerClientList) - - r.Route("/{client_id}", func(r *router) { - r.Use(api.oauthServer.LoadOAuthServerClient) - r.Get("/", api.oauthServer.OAuthServerClientGet) - r.Delete("/", api.oauthServer.OAuthServerClientDelete) + if globalConfig.OAuthServer.Enabled { + r.Route("/oauth", func(r *router) { + r.Route("/clients", func(r *router) { + // Manual client registration + r.Post("/", api.oauthServer.AdminOAuthServerClientRegister) + + r.Get("/", api.oauthServer.OAuthServerClientList) + + r.Route("/{client_id}", func(r *router) { + r.Use(api.oauthServer.LoadOAuthServerClient) + r.Get("/", api.oauthServer.OAuthServerClientGet) + r.Delete("/", api.oauthServer.OAuthServerClientDelete) + }) }) }) - }) + } }) // OAuth Dynamic Client Registration endpoint (public, rate limited) - r.Route("/oauth", func(r *router) { - r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)). - Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister) - }) + if globalConfig.OAuthServer.Enabled { + r.Route("/oauth", func(r *router) { + r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)). + Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister) + + // OAuth 2.1 Authorization endpoints + // `/authorize` to initiate OAuth2 authorization code flow where Supabase Auth is the OAuth2 provider + r.Get("/authorize", api.oauthServer.OAuthServerAuthorize) + r.With(api.requireAuthentication).Get("/authorizations/{authorization_id}", api.oauthServer.OAuthServerGetAuthorization) + r.With(api.requireAuthentication).Post("/authorizations/{authorization_id}/consent", api.oauthServer.OAuthServerConsent) + }) + } }) corsHandler := cors.New(cors.Options{ diff --git a/internal/api/api_test.go b/internal/api/api_test.go index a472be737..c205503bf 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -55,3 +55,30 @@ func TestEmailEnabledByDefault(t *testing.T) { require.True(t, api.config.External.Email.Enabled) } + +func TestOAuthServerDisabledByDefault(t *testing.T) { + api, _, err := setupAPIForTest() + require.NoError(t, err) + + // OAuth server should be disabled by default + require.False(t, api.config.OAuthServer.Enabled) + + // OAuth server instance should not be initialized when disabled + require.Nil(t, api.oauthServer) +} + +func TestOAuthServerCanBeEnabled(t *testing.T) { + api, _, err := setupAPIForTestWithCallback(func(config *conf.GlobalConfiguration, conn *storage.Connection) { + if config != nil { + // Enable OAuth server + config.OAuthServer.Enabled = true + } + }) + require.NoError(t, err) + + // OAuth server should be enabled + require.True(t, api.config.OAuthServer.Enabled) + + // OAuth server instance should be initialized when enabled + require.NotNil(t, api.oauthServer) +} diff --git a/internal/api/apierrors/errorcode.go b/internal/api/apierrors/errorcode.go index 710797548..7658c1225 100644 --- a/internal/api/apierrors/errorcode.go +++ b/internal/api/apierrors/errorcode.go @@ -97,4 +97,7 @@ const ( ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain" ErrorCodeOAuthDynamicClientRegistrationDisabled ErrorCode = "oauth_dynamic_client_registration_disabled" ErrorCodeEmailAddressNotProvided ErrorCode = "email_address_not_provided" + + ErrorCodeOAuthClientNotFound ErrorCode = "oauth_client_not_found" + ErrorCodeOAuthAuthorizationNotFound ErrorCode = "oauth_authorization_not_found" ) diff --git a/internal/api/context.go b/internal/api/context.go index 5f0f744c4..e1d285cb1 100644 --- a/internal/api/context.go +++ b/internal/api/context.go @@ -5,6 +5,7 @@ import ( "net/url" jwt "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" ) @@ -21,7 +22,6 @@ const ( tokenKey = contextKey("jwt") inviteTokenKey = contextKey("invite_token") signatureKey = contextKey("signature") - userKey = contextKey("user") targetUserKey = contextKey("target_user") factorKey = contextKey("factor") sessionKey = contextKey("session") @@ -60,7 +60,7 @@ func getClaims(ctx context.Context) *AccessTokenClaims { // withUser adds the user to the context. func withUser(ctx context.Context, u *models.User) context.Context { - return context.WithValue(ctx, userKey, u) + return shared.WithUser(ctx, u) } // withTargetUser adds the target user for linking to the context. @@ -75,14 +75,7 @@ func withFactor(ctx context.Context, f *models.Factor) context.Context { // getUser reads the user from the context. func getUser(ctx context.Context) *models.User { - if ctx == nil { - return nil - } - obj := ctx.Value(userKey) - if obj == nil { - return nil - } - return obj.(*models.User) + return shared.GetUser(ctx) } // getTargetUser reads the user from the context. diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go new file mode 100644 index 000000000..599a710ac --- /dev/null +++ b/internal/api/oauthserver/authorize.go @@ -0,0 +1,595 @@ +package oauthserver + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + "github.com/go-chi/chi/v5" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/shared" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" +) + +// AuthorizeParams represents the parameters for an OAuth authorization request +type AuthorizeParams struct { + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + ResponseType string `json:"response_type"` + Scope string `json:"scope"` + State string `json:"state"` + + // Resource Resource Indicator per RFC8707 + Resource string `json:"resource"` + CodeChallenge string `json:"code_challenge"` + CodeChallengeMethod string `json:"code_challenge_method"` +} + +// AuthorizationDetailsResponse represents the response for getting authorization details +type AuthorizationDetailsResponse struct { + AuthorizationID string `json:"authorization_id"` + Client ClientDetailsResponse `json:"client,omitempty"` + User UserDetailsResponse `json:"user,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// ClientDetailsResponse represents client details in authorization response +type ClientDetailsResponse struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name,omitempty"` + ClientURI string `json:"client_uri,omitempty"` + LogoURI string `json:"logo_uri,omitempty"` +} + +// UserDetailsResponse represents user details in authorization response +type UserDetailsResponse struct { + ID string `json:"id,omitempty"` + Email string `json:"email,omitempty"` +} + +// ConsentRequest represents a consent decision request +type ConsentRequest struct { + Action OAuthServerConsentAction `json:"action"` +} + +// ConsentResponse represents the response after processing consent +type ConsentResponse struct { + RedirectURL string `json:"redirect_url,omitempty"` +} + +type OAuthServerConsentAction string + +const ( + OAuthServerConsentActionApprove OAuthServerConsentAction = "approve" + OAuthServerConsentActionDeny OAuthServerConsentAction = "deny" +) + +// OAuth2 error codes per RFC 6749 +const ( + oAuth2ErrorInvalidRequest = "invalid_request" + oAuth2ErrorServerError = "server_error" + oAuth2ErrorAccessDenied = "access_denied" +) + +// OAuthServerAuthorize handles GET /oauth/authorize +func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + config := s.config + + query := r.URL.Query() + + params := &AuthorizeParams{ + ClientID: query.Get("client_id"), + RedirectURI: query.Get("redirect_uri"), + ResponseType: query.Get("response_type"), + Scope: query.Get("scope"), + State: query.Get("state"), + Resource: query.Get("resource"), + CodeChallenge: query.Get("code_challenge"), + CodeChallengeMethod: query.Get("code_challenge_method"), + } + + // validate basic required parameters (client_id, redirect_uri) + // this errors wont be redirected, just returned in the json + params, err := s.validateBasicAuthorizeParams(params) + if err != nil { + return err + } + + client, err := s.getOAuthServerClient(ctx, params.ClientID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id") + } + return apierrors.NewInternalServerError("error validating client").WithInternalError(err) + } + + // validate redirect_uri matches client's registered URIs + if !s.isValidRedirectURI(client, params.RedirectURI) { + // Invalid redirect_uri should NOT redirect per OAuth2 spec since we can't trust it + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid redirect_uri") + } + + // From this point on, we have valid client + redirect_uri + all params, so we can redirect errors + // validate all other parameters - now we can redirect errors + if err := s.validateRemainingAuthorizeParams(params); err != nil { + errorRedirectURL := s.buildErrorRedirectURL(params.RedirectURI, oAuth2ErrorInvalidRequest, err.Error(), params.State) + http.Redirect(w, r, errorRedirectURL, http.StatusFound) + return nil + } + + // Store authorization request in database (without user initially) + authorization := models.NewOAuthServerAuthorization( + client.ID, // Use the client's UUID instead of the public client_id string + params.RedirectURI, + params.Scope, + params.State, + params.Resource, + params.CodeChallenge, + params.CodeChallengeMethod, + ) + + if err := models.CreateOAuthServerAuthorization(db, authorization); err != nil { + // Error creating authorization - redirect with server_error + errorRedirectURL := s.buildErrorRedirectURL(params.RedirectURI, oAuth2ErrorServerError, "error creating authorization", params.State) + http.Redirect(w, r, errorRedirectURL, http.StatusFound) + return nil + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "client_id", client.ClientID) + + // Redirect to authorization path with authorization_id + if config.OAuthServer.AuthorizationPath == "" { + // OAuth authorization path not configured - redirect with server_error + errorRedirectURL := s.buildErrorRedirectURL(params.RedirectURI, oAuth2ErrorServerError, "oauth authorization path not configured", params.State) + http.Redirect(w, r, errorRedirectURL, http.StatusFound) + return nil + } + + baseURL := s.buildAuthorizationURL(config.SiteURL, config.OAuthServer.AuthorizationPath) + redirectURL := fmt.Sprintf("%s?authorization_id=%s", baseURL, authorization.AuthorizationID) + + http.Redirect(w, r, redirectURL, http.StatusFound) + return nil +} + +// OAuthServerGetAuthorization handles GET /oauth/authorizations/{authorization_id} +func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + // Validate request origin - the request must come from the site URL as we redirected there at the first place + if err := s.validateRequestOrigin(r); err != nil { + return err + } + + // Get authenticated user + user := shared.GetUser(ctx) + if user == nil { + return apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "authentication required") + } + + authorizationID := chi.URLParam(r, "authorization_id") + authorization, err := s.validateAndFindAuthorization(r, db, authorizationID) + if err != nil { + return err + } + + // Set user_id if not already set + if authorization.UserID == nil { + // Use transaction to atomically set user and check for auto-approve + var shouldAutoApprove bool + var existingConsent *models.OAuthServerConsent + + err := db.Transaction(func(tx *storage.Connection) error { + if err := authorization.SetUser(tx, user.ID); err != nil { + return err + } + + // Check for existing consent and auto-approve if available + var err error + existingConsent, err = models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, authorization.ClientID) + if err != nil { + return err + } + + // Check if consent covers requested scopes + if existingConsent != nil && s.consentCoversScopes(existingConsent, authorization.Scope) { + shouldAutoApprove = true + } + + return nil + }) + + if err != nil { + return apierrors.NewInternalServerError("error setting user and checking consent").WithInternalError(err) + } + + // If we should auto-approve, do it now + if shouldAutoApprove { + return s.autoApproveAndRedirect(w, r, authorization) + } + } else { + // Authorization already has user_id set, validate ownership + if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil { + return err + } + } + + // Build response with client and user details + response := AuthorizationDetailsResponse{ + AuthorizationID: authorization.AuthorizationID, + Client: ClientDetailsResponse{ + ClientID: authorization.Client.ClientID, + ClientName: utilities.StringValue(authorization.Client.ClientName), + ClientURI: utilities.StringValue(authorization.Client.ClientURI), + LogoURI: utilities.StringValue(authorization.Client.LogoURI), + }, + User: UserDetailsResponse{ + ID: user.ID.String(), + Email: user.Email.String(), + }, + Scope: authorization.Scope, + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "client_id", authorization.Client.ClientID) + + return shared.SendJSON(w, http.StatusOK, response) +} + +// OAuthServerConsent handles POST /oauth/authorizations/{authorization_id}/consent +func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + // Validate request origin - the request must come from the site URL as we redirected there at the first place + if err := s.validateRequestOrigin(r); err != nil { + return err + } + + // Get authenticated user + user := shared.GetUser(ctx) + if user == nil { + return apierrors.NewForbiddenError(apierrors.ErrorCodeBadJWT, "authentication required") + } + + var body ConsentRequest + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeBadJSON, "invalid JSON body") + } + + if body.Action != OAuthServerConsentActionApprove && body.Action != OAuthServerConsentActionDeny { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "action must be 'approve' or 'deny'") + } + + // Validate and find authorization outside transaction first + authorizationID := chi.URLParam(r, "authorization_id") + observability.LogEntrySetField(r, "authorization_id", authorizationID) + authorization, err := s.validateAndFindAuthorization(r, db, authorizationID) + if err != nil { + return err + } + + // Ensure authorization belongs to authenticated user + if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil { + return err + } + + // Process consent in transaction + var redirectURL string + err = db.Transaction(func(tx *storage.Connection) error { + // Re-fetch in transaction to ensure consistency + authorization, err := models.FindOAuthServerAuthorizationByID(tx, authorizationID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + return apierrors.NewInternalServerError("error finding authorization").WithInternalError(err) + } + + // Re-check expiration and status in transaction (state could have changed) + if authorization.IsExpired() { + if err := authorization.MarkExpired(tx); err != nil { + observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired") + } + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + + if authorization.Status != models.OAuthServerAuthorizationPending { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request is no longer pending") + } + + if body.Action == OAuthServerConsentActionApprove { + // Approve authorization + if err := authorization.Approve(tx); err != nil { + return apierrors.NewInternalServerError("error approving authorization").WithInternalError(err) + } + + // Store consent for future use + scopes := authorization.GetScopeList() + consent := models.NewOAuthServerConsent(user.ID, authorization.ClientID, scopes) + if err := models.UpsertOAuthServerConsent(tx, consent); err != nil { + return apierrors.NewInternalServerError("error storing consent").WithInternalError(err) + } + + // Build success redirect URL + redirectURL = s.buildSuccessRedirectURL(authorization) + + observability.LogEntrySetField(r, "oauth_consent_action", string(OAuthServerConsentActionApprove)) + + } else { + // Deny authorization + if err := authorization.Deny(tx); err != nil { + return apierrors.NewInternalServerError("error denying authorization").WithInternalError(err) + } + + // Build error redirect URL + // Errors are being returned to the client in the redirect url per OAuth2 spec + var state string + if authorization.State != nil { + state = *authorization.State + } + redirectURL = s.buildErrorRedirectURL(authorization.RedirectURI, oAuth2ErrorAccessDenied, "User denied the request", state) + + observability.LogEntrySetField(r, "oauth_consent_action", string(OAuthServerConsentActionDeny)) + } + + return nil + }) + + if err != nil { + return err + } + + // Return redirect URL to frontend + response := ConsentResponse{ + RedirectURL: redirectURL, + } + + return shared.SendJSON(w, http.StatusOK, response) +} + +// Helper functions + +// validateRequestOrigin checks if the request is coming from an authorized origin +func (s *Server) validateRequestOrigin(r *http.Request) error { + // Check referer header + referer := r.Referer() + if referer == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "request must originate from authorized domain") + } + + if !utilities.IsRedirectURLValid(s.config, referer) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "unauthorized request origin") + } + + return nil +} + +// validateAndFindAuthorization validates the authorization_id parameter and finds the authorization, +// performing all necessary checks (existence, expiration, status) +func (s *Server) validateAndFindAuthorization(r *http.Request, db *storage.Connection, authorizationID string) (*models.OAuthServerAuthorization, error) { + if authorizationID == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required") + } + + authorization, err := models.FindOAuthServerAuthorizationByID(db, authorizationID) + if err != nil { + if models.IsNotFoundError(err) { + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + return nil, apierrors.NewInternalServerError("error finding authorization").WithInternalError(err) + } + + // Check if expired first - no point processing expired authorizations + if authorization.IsExpired() { + // Mark as expired in database + if err := authorization.MarkExpired(db); err != nil { + observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired") + } + // returning not found to avoid leaking information about the existence of the authorization + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + + // Check if still pending + if authorization.Status != models.OAuthServerAuthorizationPending { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed") + } + + return authorization, nil +} + +// validateAuthorizationOwnership checks if the authorization belongs to the authenticated user +func (s *Server) validateAuthorizationOwnership(r *http.Request, authorization *models.OAuthServerAuthorization, user *models.User) error { + if authorization.UserID == nil || *authorization.UserID != user.ID { + observability.GetLogEntry(r).Entry. + WithField("request_user_id", user.ID). + WithField("authorization_id", authorization.AuthorizationID). + Warn("authorization belongs to different user") + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + return nil +} + +// validateBasicAuthorizeParams validates only client_id and redirect_uri (needed before we can redirect errors) +func (s *Server) validateBasicAuthorizeParams(params *AuthorizeParams) (*AuthorizeParams, error) { + if params.ClientID == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required") + } + if params.RedirectURI == "" { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "redirect_uri is required") + } + + return params, nil +} + +// validateRemainingAuthorizeParams validates all other parameters (can redirect errors since we have valid client + redirect_uri) +func (s *Server) validateRemainingAuthorizeParams(params *AuthorizeParams) error { + if params.ResponseType == "" { + params.ResponseType = models.OAuthServerResponseTypeCode.String() + } + if params.Scope == "" { + params.Scope = s.config.OAuthServer.DefaultScope + } + + // OAuth 2.1 only supports "code" response type + if params.ResponseType != models.OAuthServerResponseTypeCode.String() { + return errors.New("only response_type=code is supported") + } + + // Resource parameter validation (per RFC 8707) + if err := s.validateResourceParam(params.Resource); err != nil { + return err + } + + // PKCE validation + if err := s.validatePKCEParams(params.CodeChallengeMethod, params.CodeChallenge); err != nil { + return err + } + + return nil +} + +func (s *Server) validatePKCEParams(codeChallengeMethod, codeChallenge string) error { + // PKCE is mandatory for the authorization code flow OAuth2.1 + // Both code_challenge and code_challenge_method must be provided together + if codeChallenge == "" || codeChallengeMethod == "" { + return errors.New("PKCE flow requires both code_challenge and code_challenge_method") + } + + // Validate code challenge method (case-insensitive) + if strings.ToLower(codeChallengeMethod) != "s256" && strings.ToLower(codeChallengeMethod) != "plain" { + return errors.New("code_challenge_method must be 'S256' or 'plain'") + } + + // Validate code challenge format and length (per OAuth2 spec) + if len(codeChallenge) < 43 || len(codeChallenge) > 128 { + return errors.New("code_challenge must be between 43 and 128 characters") + } + + return nil +} + +// validateResourceParam validates the resource parameter per RFC 8707 +func (s *Server) validateResourceParam(resource string) error { + // Resource parameter is optional + if resource == "" { + return nil + } + + // Parse URL to validate it's an absolute URI + parsedURL, err := url.Parse(resource) + if err != nil { + return errors.New("resource must be a valid URI") + } + + // Must be an absolute URI (have scheme) + if !parsedURL.IsAbs() { + return errors.New("resource must be an absolute URI") + } + + // Must not include a fragment component + if parsedURL.Fragment != "" { + return errors.New("resource must not include a fragment component") + } + + // Should not include a query component + if parsedURL.RawQuery != "" { + return errors.New("resource must not include a query component") + } + + return nil +} + +func (s *Server) isValidRedirectURI(client *models.OAuthServerClient, redirectURI string) bool { + registeredURIs := client.GetRedirectURIs() + for _, registeredURI := range registeredURIs { + // exact string matching per OAuth2 spec + if registeredURI == redirectURI { + return true + } + } + return false +} + +func (s *Server) consentCoversScopes(consent *models.OAuthServerConsent, requestedScope string) bool { + if consent.IsRevoked() { + return false + } + + requestedScopes := models.ParseScopeString(requestedScope) + return consent.HasAllScopes(requestedScopes) +} + +func (s *Server) autoApproveAndRedirect(w http.ResponseWriter, r *http.Request, authorization *models.OAuthServerAuthorization) error { + ctx := r.Context() + db := s.db.WithContext(ctx) + + // Approve the authorization in a transaction + err := db.Transaction(func(tx *storage.Connection) error { + return authorization.Approve(tx) + }) + + if err != nil { + return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err) + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "auto_approved", true) + + // Return JSON with redirect URL (same format as consent endpoint) + redirectURL := s.buildSuccessRedirectURL(authorization) + response := ConsentResponse{ + RedirectURL: redirectURL, + } + + return shared.SendJSON(w, http.StatusOK, response) +} + +func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthorization) string { + u, _ := url.Parse(authorization.RedirectURI) + q := u.Query() + if authorization.AuthorizationCode != nil { + q.Set("code", *authorization.AuthorizationCode) + } + if authorization.State != nil && *authorization.State != "" { + q.Set("state", *authorization.State) + } + u.RawQuery = q.Encode() + return u.String() +} + +// buildErrorRedirectURL builds an error redirect URL with the given parameters +func (s *Server) buildErrorRedirectURL(redirectURI, errorCode, errorDescription, state string) string { + u, _ := url.Parse(redirectURI) + q := u.Query() + q.Set("error", errorCode) + q.Set("error_description", errorDescription) + if state != "" { + q.Set("state", state) + } + u.RawQuery = q.Encode() + return u.String() +} + +// buildAuthorizationURL safely joins a base URL with a path, handling slashes correctly +func (s *Server) buildAuthorizationURL(baseURL, pathToJoin string) string { + // Trim trailing slash from baseURL + baseURL = strings.TrimRight(baseURL, "/") + + // Ensure pathToJoin starts with a slash + if !strings.HasPrefix(pathToJoin, "/") { + pathToJoin = "/" + pathToJoin + } + + return baseURL + pathToJoin +} diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 27093e80e..8a247a054 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -3,6 +3,7 @@ package oauthserver import ( "context" "encoding/json" + "fmt" "net/http" "time" @@ -11,6 +12,7 @@ import ( "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" + "github.com/supabase/auth/internal/utilities" ) // OAuthServerClientResponse represents the response format for OAuth client operations @@ -18,23 +20,23 @@ type OAuthServerClientResponse struct { ClientID string `json:"client_id"` ClientSecret string `json:"client_secret,omitempty"` // only returned on registration - RedirectURIs []string `json:"redirect_uris"` - TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method"` - GrantTypes []string `json:"grant_types"` - ResponseTypes []string `json:"response_types"` + RedirectURIs []string `json:"redirect_uris,omitempty"` + TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` ClientName string `json:"client_name,omitempty"` ClientURI string `json:"client_uri,omitempty"` LogoURI string `json:"logo_uri,omitempty"` // Metadata fields - RegistrationType string `json:"registration_type"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + RegistrationType string `json:"registration_type,omitempty"` + CreatedAt time.Time `json:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` } // OAuthServerClientListResponse represents the response for listing OAuth clients type OAuthServerClientListResponse struct { - Clients []OAuthServerClientResponse `json:"clients"` + Clients []OAuthServerClientResponse `json:"clients,omitempty"` } // oauthServerClientToResponse converts a model to response format @@ -47,9 +49,9 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret TokenEndpointAuthMethod: []string{"client_secret_basic", "client_secret_post"}, // Both methods are supported GrantTypes: client.GetGrantTypes(), ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1 - ClientName: client.ClientName.String(), - ClientURI: client.ClientURI.String(), - LogoURI: client.LogoURI.String(), + ClientName: utilities.StringValue(client.ClientName), + ClientURI: utilities.StringValue(client.ClientURI), + LogoURI: utilities.StringValue(client.LogoURI), // Metadata fields RegistrationType: client.RegistrationType, @@ -80,7 +82,7 @@ func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) ( client, err := s.getOAuthServerClient(ctx, clientID) if err != nil { if models.IsNotFoundError(err) { - return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeUserNotFound, "OAuth client not found") + return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthClientNotFound, "OAuth client not found") } return nil, apierrors.NewInternalServerError("Error loading OAuth client").WithInternalError(err) } @@ -182,3 +184,50 @@ func (s *Server) OAuthServerClientList(w http.ResponseWriter, r *http.Request) e return shared.SendJSON(w, http.StatusOK, response) } + +// OAuthServerMetadataResponse represents the OAuth 2.1 Authorization Server Metadata per RFC 8414 +type OAuthServerMetadataResponse struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKSetURI string `json:"jwks_uri"` + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported"` + ResponseModesSupported []string `json:"response_modes_supported"` + GrantTypesSupported []string `json:"grant_types_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + + // TODO(cemal) :: Append the scopes supported when scope management is clarified! + // ScopesSupported []string `json:"scopes_supported"` +} + +// OAuthServerMetadata handles GET /.well-known/oauth-authorization-server +func (s *Server) OAuthServerMetadata(w http.ResponseWriter, r *http.Request) error { + issuer := s.config.JWT.Issuer + + // TODO(cemal) :: Remove this check when we have the config validation in place + if issuer == "" { + return apierrors.NewInternalServerError("Issuer is not set") + } + + response := OAuthServerMetadataResponse{ + Issuer: issuer, + AuthorizationEndpoint: fmt.Sprintf("%s/oauth/authorize", issuer), + TokenEndpoint: fmt.Sprintf("%s/oauth/token", issuer), + JWKSetURI: fmt.Sprintf("%s/.well-known/jwks.json", issuer), + ResponseTypesSupported: []string{"code"}, + ResponseModesSupported: []string{"query"}, + GrantTypesSupported: []string{"authorization_code", "refresh_token"}, + TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, + CodeChallengeMethodsSupported: []string{"S256", "plain"}, + } + + // Include registration endpoint if dynamic registration is enabled + if s.config.OAuthServer.AllowDynamicRegistration { + response.RegistrationEndpoint = fmt.Sprintf("%s/oauth/clients/register", issuer) + } + + // TODO: Cache response for 10 minutes, but consider dynamic registration toggle changes + return shared.SendJSON(w, http.StatusOK, response) +} diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go index c72cdb534..0d74eec83 100644 --- a/internal/api/oauthserver/handlers_test.go +++ b/internal/api/oauthserver/handlers_test.go @@ -33,7 +33,8 @@ func TestOAuthClientHandler(t *testing.T) { conn, err := test.SetupDBConnection(globalConfig) require.NoError(t, err) - // Enable OAuth dynamic client registration for tests + // Enable OAuth server and dynamic client registration for tests + globalConfig.OAuthServer.Enabled = true globalConfig.OAuthServer.AllowDynamicRegistration = true server := NewServer(globalConfig, conn) diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index 2537f5e1d..02c3911bd 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -10,7 +10,7 @@ import ( "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" - "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/utilities" "golang.org/x/crypto/bcrypt" ) @@ -160,9 +160,9 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer client := &models.OAuthServerClient{ ClientID: generateClientID(), RegistrationType: params.RegistrationType, - ClientName: storage.NullString(params.ClientName), - ClientURI: storage.NullString(params.ClientURI), - LogoURI: storage.NullString(params.LogoURI), + ClientName: utilities.StringPtr(params.ClientName), + ClientURI: utilities.StringPtr(params.ClientURI), + LogoURI: utilities.StringPtr(params.LogoURI), } client.SetRedirectURIs(params.RedirectURIs) diff --git a/internal/api/oauthserver/service_test.go b/internal/api/oauthserver/service_test.go index 3c7f1f365..60798a49e 100644 --- a/internal/api/oauthserver/service_test.go +++ b/internal/api/oauthserver/service_test.go @@ -30,7 +30,8 @@ func TestOAuthService(t *testing.T) { conn, err := test.SetupDBConnection(globalConfig) require.NoError(t, err) - // Enable OAuth dynamic client registration for tests + // Enable OAuth server and dynamic client registration for tests + globalConfig.OAuthServer.Enabled = true globalConfig.OAuthServer.AllowDynamicRegistration = true server := NewServer(globalConfig, conn) @@ -49,7 +50,8 @@ func (ts *OAuthServiceTestSuite) SetupTest() { if ts.DB != nil { models.TruncateAll(ts.DB) } - // Enable OAuth dynamic client registration for tests + // Enable OAuth server and dynamic client registration for tests + ts.Config.OAuthServer.Enabled = true ts.Config.OAuthServer.AllowDynamicRegistration = true } @@ -86,7 +88,7 @@ func (ts *OAuthServiceTestSuite) TestOAuthServerClientServiceMethods() { require.NoError(ts.T(), err) require.NotNil(ts.T(), client) require.NotEmpty(ts.T(), secret) - assert.Equal(ts.T(), "Test Client", client.ClientName.String()) + assert.Equal(ts.T(), "Test Client", *client.ClientName) assert.Equal(ts.T(), "dynamic", client.RegistrationType) // Test getOAuthServerClient diff --git a/internal/api/shared/context.go b/internal/api/shared/context.go new file mode 100644 index 000000000..10fcfdb23 --- /dev/null +++ b/internal/api/shared/context.go @@ -0,0 +1,36 @@ +package shared + +import ( + "context" + + "github.com/supabase/auth/internal/models" +) + +// ContextKey is the type for context keys to avoid collisions +type ContextKey string + +func (c ContextKey) String() string { + return "gotrue api context key " + string(c) +} + +// Context keys used across packages +const ( + UserKey ContextKey = "user" +) + +// GetUser reads the user from the context - shared implementation +func GetUser(ctx context.Context) *models.User { + if ctx == nil { + return nil + } + obj := ctx.Value(UserKey) + if obj == nil { + return nil + } + return obj.(*models.User) +} + +// WithUser adds the user to the context - shared implementation +func WithUser(ctx context.Context, u *models.User) context.Context { + return context.WithValue(ctx, UserKey, u) +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 02f5aeaec..e59460c0e 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -73,7 +73,12 @@ type OAuthProviderConfiguration struct { // OAuthServerConfiguration holds OAuth server configuration type OAuthServerConfiguration struct { - AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` + Enabled bool `json:"enabled" default:"false"` + AllowDynamicRegistration bool `json:"allow_dynamic_registration" split_words:"true"` + AuthorizationPath string `json:"authorization_path" split_words:"true"` + AuthorizationTimeout time.Duration `json:"authorization_timeout" split_words:"true" default:"5m"` + // Placeholder for now, for (near) future extensibility + DefaultScope string `json:"default_scope" split_words:"true" default:"email"` } type AnonymousProviderConfiguration struct { diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 9a770d81d..7ce4c940a 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -1,23 +1,18 @@ package models import ( - "crypto/sha256" - "crypto/subtle" "database/sql" - "encoding/base64" "fmt" "strings" "time" "github.com/pkg/errors" + "github.com/supabase/auth/internal/security" "github.com/supabase/auth/internal/storage" "github.com/gofrs/uuid" ) -const InvalidCodeChallengeError = "code challenge does not match previously saved code verifier" -const InvalidCodeMethodError = "code challenge method not supported" - type FlowState struct { ID uuid.UUID `json:"id" db:"id"` UserID *uuid.UUID `json:"user_id,omitempty" db:"user_id"` @@ -134,22 +129,7 @@ func FindFlowStateByUserID(tx *storage.Connection, id string, authenticationMeth } func (f *FlowState) VerifyPKCE(codeVerifier string) error { - switch f.CodeChallengeMethod { - case SHA256.String(): - hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) - encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) - if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(encodedCodeVerifier)) != 1 { - return errors.New(InvalidCodeChallengeError) - } - case Plain.String(): - if subtle.ConstantTimeCompare([]byte(f.CodeChallenge), []byte(codeVerifier)) != 1 { - return errors.New(InvalidCodeChallengeError) - } - default: - return errors.New(InvalidCodeMethodError) - - } - return nil + return security.VerifyPKCEChallenge(f.CodeChallenge, f.CodeChallengeMethod, codeVerifier) } func (f *FlowState) IsExpired(expiryDuration time.Duration) bool { diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go new file mode 100644 index 000000000..409e03ccd --- /dev/null +++ b/internal/models/oauth_authorization.go @@ -0,0 +1,288 @@ +package models + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/crypto" + "github.com/supabase/auth/internal/security" + "github.com/supabase/auth/internal/storage" +) + +// OAuthServerAuthorizationStatus represents the status of an OAuth server authorization request +type OAuthServerAuthorizationStatus string + +const ( + OAuthServerAuthorizationPending OAuthServerAuthorizationStatus = "pending" + OAuthServerAuthorizationApproved OAuthServerAuthorizationStatus = "approved" + OAuthServerAuthorizationDenied OAuthServerAuthorizationStatus = "denied" + OAuthServerAuthorizationExpired OAuthServerAuthorizationStatus = "expired" +) + +func (s OAuthServerAuthorizationStatus) String() string { + return string(s) +} + +// OAuthServerResponseType represents the OAuth server response type +type OAuthServerResponseType string + +const ( + OAuthServerResponseTypeCode OAuthServerResponseType = "code" +) + +func (rt OAuthServerResponseType) String() string { + return string(rt) +} + +// OAuthServerAuthorization represents an OAuth 2.1 server authorization request +type OAuthServerAuthorization struct { + ID uuid.UUID `json:"-" db:"id"` + AuthorizationID string `json:"authorization_id" db:"authorization_id"` + ClientID uuid.UUID `json:"-" db:"client_id"` + UserID *uuid.UUID `json:"user_id" db:"user_id"` + RedirectURI string `json:"redirect_uri" db:"redirect_uri"` + Scope string `json:"scope" db:"scope"` + State *string `json:"state,omitempty" db:"state"` + Resource *string `json:"resource,omitempty" db:"resource"` + CodeChallenge *string `json:"code_challenge,omitempty" db:"code_challenge"` + CodeChallengeMethod *string `json:"code_challenge_method,omitempty" db:"code_challenge_method"` + ResponseType OAuthServerResponseType `json:"response_type" db:"response_type"` + Status OAuthServerAuthorizationStatus `json:"status" db:"status"` + AuthorizationCode *string `json:"-" db:"authorization_code"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + ExpiresAt time.Time `json:"expires_at" db:"expires_at"` + ApprovedAt *time.Time `json:"approved_at" db:"approved_at"` + + // Relations with OAuth clients + Client *OAuthServerClient `json:"client,omitempty" db:"-"` +} + +// TableName returns the table name for the OAuthServerAuthorization model +func (OAuthServerAuthorization) TableName() string { + return "oauth_authorizations" +} + +// NewOAuthServerAuthorization creates a new OAuth server authorization request without user (for initial flow) +func NewOAuthServerAuthorization(clientID uuid.UUID, redirectURI, scope, state, resource, codeChallenge, codeChallengeMethod string) *OAuthServerAuthorization { + id := uuid.Must(uuid.NewV4()) + authorizationID := crypto.SecureAlphanumeric(32) // Generate random ID for frontend + + now := time.Now() + expiresAt := now.Add(10 * time.Minute) // 10 minute expiration + + auth := &OAuthServerAuthorization{ + ID: id, + AuthorizationID: authorizationID, + ClientID: clientID, + UserID: nil, // No user yet + RedirectURI: redirectURI, + Scope: scope, + ResponseType: OAuthServerResponseTypeCode, + Status: OAuthServerAuthorizationPending, + CreatedAt: now, + ExpiresAt: expiresAt, + } + + if state != "" { + auth.State = &state + } + if resource != "" { + auth.Resource = &resource + } + if codeChallenge != "" { + auth.CodeChallenge = &codeChallenge + } + if codeChallengeMethod != "" { + // Normalize code challenge method to lowercase for database storage + // Database enum expects 's256' and 'plain' (lowercase) + normalizedMethod := strings.ToLower(codeChallengeMethod) + auth.CodeChallengeMethod = &normalizedMethod + } + + return auth +} + +// IsExpired checks if the authorization request has expired +func (auth *OAuthServerAuthorization) IsExpired() bool { + return time.Now().After(auth.ExpiresAt) +} + +// SetUser sets the user ID for the authorization request (after login) +func (auth *OAuthServerAuthorization) SetUser(tx *storage.Connection, userID uuid.UUID) error { + auth.UserID = &userID + return tx.UpdateOnly(auth, "user_id") +} + +// GetScopeList returns the scopes as a slice +func (auth *OAuthServerAuthorization) GetScopeList() []string { + return ParseScopeString(auth.Scope) +} + +// GenerateAuthorizationCode generates a new authorization code if not already set +func (auth *OAuthServerAuthorization) GenerateAuthorizationCode() string { + if auth.AuthorizationCode != nil && *auth.AuthorizationCode != "" { + return *auth.AuthorizationCode + } + + code := uuid.Must(uuid.NewV4()).String() + auth.AuthorizationCode = &code + return code +} + +// Approve approves the authorization request and generates an authorization code +func (auth *OAuthServerAuthorization) Approve(tx *storage.Connection) error { + if auth.IsExpired() { + return fmt.Errorf("authorization request has expired") + } + + if auth.Status != OAuthServerAuthorizationPending { + return fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + } + + now := time.Now() + auth.Status = OAuthServerAuthorizationApproved + auth.ApprovedAt = &now + auth.GenerateAuthorizationCode() + + return tx.UpdateOnly(auth, "status", "approved_at", "authorization_code") +} + +// Deny denies the authorization request +func (auth *OAuthServerAuthorization) Deny(tx *storage.Connection) error { + if auth.Status != OAuthServerAuthorizationPending { + return fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + } + + auth.Status = OAuthServerAuthorizationDenied + return tx.UpdateOnly(auth, "status") +} + +// MarkExpired marks the authorization request as expired +func (auth *OAuthServerAuthorization) MarkExpired(tx *storage.Connection) error { + if auth.Status != OAuthServerAuthorizationPending { + return fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + } + + auth.Status = OAuthServerAuthorizationExpired + return tx.UpdateOnly(auth, "status") +} + +// Validate performs basic validation on the OAuth authorization +func (auth *OAuthServerAuthorization) Validate() error { + if auth.ClientID == uuid.Nil { + return fmt.Errorf("client_id is required") + } + // UserID can be nil initially for unauthenticated authorization requests + // It will be set when user authenticates + if auth.RedirectURI == "" { + return fmt.Errorf("redirect_uri is required") + } + if auth.Scope == "" { + return fmt.Errorf("scope is required") + } + if auth.ResponseType != OAuthServerResponseTypeCode { + return fmt.Errorf("only response_type=code is supported") + } + if auth.ExpiresAt.Before(auth.CreatedAt) { + return fmt.Errorf("expires_at must be after created_at") + } + + return nil +} + +// VerifyPKCE verifies the PKCE code verifier against the stored challenge +func (auth *OAuthServerAuthorization) VerifyPKCE(codeVerifier string) error { + if auth.CodeChallenge == nil || *auth.CodeChallenge == "" { + // No PKCE challenge stored, verification passes + return nil + } + + if codeVerifier == "" { + return fmt.Errorf("code_verifier is required when PKCE challenge is present") + } + + // Use the shared PKCE verification function + var codeChallengeMethod string + if auth.CodeChallengeMethod != nil { + codeChallengeMethod = *auth.CodeChallengeMethod + } + return security.VerifyPKCEChallenge(*auth.CodeChallenge, codeChallengeMethod, codeVerifier) +} + +// Query functions for OAuth authorizations + +// FindOAuthServerAuthorizationByID finds an OAuth authorization by authorization_id +func FindOAuthServerAuthorizationByID(tx *storage.Connection, authorizationID string) (*OAuthServerAuthorization, error) { + auth := &OAuthServerAuthorization{} + if err := tx.Q().Where("authorization_id = ?", authorizationID).First(auth); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerAuthorizationNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth authorization") + } + + if auth.ClientID != uuid.Nil { + client := &OAuthServerClient{} + if err := tx.Q().Where("id = ?", auth.ClientID).First(client); err == nil { + auth.Client = client + } + } + + return auth, nil +} + +// FindOAuthServerAuthorizationByCode finds an OAuth authorization by authorization code +func FindOAuthServerAuthorizationByCode(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) { + auth := &OAuthServerAuthorization{} + if err := tx.Q().Where("authorization_code = ? AND status = ?", code, OAuthServerAuthorizationApproved).First(auth); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerAuthorizationNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth authorization by code") + } + + // Load client relationship (always present) + if auth.ClientID != uuid.Nil { + client := &OAuthServerClient{} + if err := tx.Q().Where("id = ?", auth.ClientID).First(client); err == nil { + auth.Client = client + } + } + + return auth, nil +} + +// CreateOAuthServerAuthorization creates a new OAuth authorization in the database +func CreateOAuthServerAuthorization(tx *storage.Connection, auth *OAuthServerAuthorization) error { + if err := auth.Validate(); err != nil { + return err + } + + if auth.ID == uuid.Nil { + auth.ID = uuid.Must(uuid.NewV4()) + } + + if auth.AuthorizationID == "" { + auth.AuthorizationID = crypto.SecureAlphanumeric(32) + } + + return tx.Create(auth) +} + +// CleanupExpiredOAuthServerAuthorizations marks expired authorizations as expired +func CleanupExpiredOAuthServerAuthorizations(tx *storage.Connection) error { + query := "UPDATE " + (&OAuthServerAuthorization{}).TableName() + " SET status = ? WHERE status = ? AND expires_at < now()" + return tx.RawQuery(query, OAuthServerAuthorizationExpired, OAuthServerAuthorizationPending).Exec() +} + +// Error types for OAuth authorization operations + +type OAuthServerAuthorizationNotFoundError struct{} + +func (e OAuthServerAuthorizationNotFoundError) Error() string { + return "OAuth authorization not found" +} diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go new file mode 100644 index 000000000..d039efdcb --- /dev/null +++ b/internal/models/oauth_authorization_test.go @@ -0,0 +1,366 @@ +package models + +import ( + "fmt" + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewOAuthServerAuthorization(t *testing.T) { + clientID := uuid.Must(uuid.NewV4()) + redirectURI := "https://example.com/callback" + scope := "openid profile" + state := "random-state" + codeChallenge := "test-challenge" + codeChallengeMethod := "S256" + resource := "https://api.example.com/" + + auth := NewOAuthServerAuthorization(clientID, redirectURI, scope, state, resource, codeChallenge, codeChallengeMethod) + + assert.NotEmpty(t, auth.ID) + assert.NotEmpty(t, auth.AuthorizationID) + assert.Equal(t, clientID, auth.ClientID) + assert.Nil(t, auth.UserID) + assert.Equal(t, redirectURI, auth.RedirectURI) + assert.Equal(t, scope, auth.Scope) + assert.Equal(t, state, *auth.State) + assert.Equal(t, resource, *auth.Resource) + assert.Equal(t, codeChallenge, *auth.CodeChallenge) + assert.Equal(t, "s256", *auth.CodeChallengeMethod) // Should be normalized to lowercase + assert.Equal(t, OAuthServerResponseTypeCode, auth.ResponseType) + assert.Equal(t, OAuthServerAuthorizationPending, auth.Status) + assert.True(t, auth.ExpiresAt.After(auth.CreatedAt)) + assert.Nil(t, auth.ApprovedAt) +} + +func TestNewOAuthServerAuthorization_CodeChallengeMethodNormalization(t *testing.T) { + testCases := []struct { + name string + input string + expected string + }{ + {"uppercase S256", "S256", "s256"}, + {"lowercase s256", "s256", "s256"}, + {"mixed case S256", "s256", "s256"}, + {"uppercase PLAIN", "PLAIN", "plain"}, + {"lowercase plain", "plain", "plain"}, + {"mixed case Plain", "Plain", "plain"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + auth := NewOAuthServerAuthorization( + uuid.Must(uuid.NewV4()), + "https://example.com/callback", + "openid", + "state", + "", + "challenge", + tc.input, + ) + + assert.Equal(t, tc.expected, *auth.CodeChallengeMethod, + "Expected code_challenge_method to be normalized to %s, got %s", tc.expected, *auth.CodeChallengeMethod) + }) + } +} + +func TestOAuthServerAuthorization_IsExpired(t *testing.T) { + auth := &OAuthServerAuthorization{ + CreatedAt: time.Now().Add(-1 * time.Hour), + ExpiresAt: time.Now().Add(-30 * time.Minute), // Expired 30 minutes ago + } + + assert.True(t, auth.IsExpired()) + + auth.ExpiresAt = time.Now().Add(30 * time.Minute) // Expires in 30 minutes + assert.False(t, auth.IsExpired()) +} + +func TestOAuthServerAuthorization_GenerateAuthorizationCode(t *testing.T) { + auth := &OAuthServerAuthorization{} + + // First call should generate a code + code1 := auth.GenerateAuthorizationCode() + assert.NotEmpty(t, code1) + assert.Equal(t, code1, *auth.AuthorizationCode) + + // Second call should return the same code + code2 := auth.GenerateAuthorizationCode() + assert.Equal(t, code1, code2) +} + +func TestOAuthServerAuthorization_ApproveErrCases(t *testing.T) { + tests := []struct { + name string + auth *OAuthServerAuthorization + wantErr bool + errMsg string + }{ + { + name: "approval of expired authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationPending, + CreatedAt: time.Now().Add(-1 * time.Hour), + ExpiresAt: time.Now().Add(-30 * time.Minute), // Expired + }, + wantErr: true, + errMsg: "authorization request has expired", + }, + { + name: "approval of already approved authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationApproved, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "approval of denied authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationDenied, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "approval of expired status authorization", + auth: &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationExpired, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + }, + wantErr: true, + errMsg: "authorization request is not pending", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test validation logic before database operations + if tt.auth.IsExpired() { + err := fmt.Errorf("authorization request has expired") + assert.Error(t, err) + assert.Contains(t, err.Error(), "authorization request has expired") + return + } + + if tt.auth.Status != OAuthServerAuthorizationPending { + err := fmt.Errorf("authorization request is not pending (current status: %s)", tt.auth.Status) + assert.Error(t, err) + assert.Contains(t, err.Error(), "authorization request is not pending") + return + } + + // If we get here, it should be valid for approval + assert.False(t, tt.wantErr, "Expected error but validation passed") + }) + } +} + +func TestOAuthServerAuthorization_ApproveSuccess(t *testing.T) { + auth := &OAuthServerAuthorization{ + Status: OAuthServerAuthorizationPending, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + } + + // Test the in-memory state changes that happen during approval + beforeTime := time.Now() + auth.Status = OAuthServerAuthorizationApproved + now := time.Now() + auth.ApprovedAt = &now + auth.GenerateAuthorizationCode() + + assert.Equal(t, OAuthServerAuthorizationApproved, auth.Status) + assert.NotNil(t, auth.ApprovedAt) + assert.True(t, auth.ApprovedAt.After(beforeTime)) + assert.NotEmpty(t, *auth.AuthorizationCode) +} + +func TestOAuthServerAuthorization_Deny(t *testing.T) { + tests := []struct { + name string + status OAuthServerAuthorizationStatus + wantErr bool + errMsg string + }{ + { + name: "deny pending authorization", + status: OAuthServerAuthorizationPending, + wantErr: false, + }, + { + name: "deny already approved authorization", + status: OAuthServerAuthorizationApproved, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "deny already denied authorization", + status: OAuthServerAuthorizationDenied, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "deny expired authorization", + status: OAuthServerAuthorizationExpired, + wantErr: true, + errMsg: "authorization request is not pending", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &OAuthServerAuthorization{Status: tt.status} + + if auth.Status != OAuthServerAuthorizationPending { + err := fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } + return + } + + // Test successful denial state change + auth.Status = OAuthServerAuthorizationDenied + assert.Equal(t, OAuthServerAuthorizationDenied, auth.Status) + }) + } +} + +func TestOAuthServerAuthorization_MarkExpiredLogic(t *testing.T) { + tests := []struct { + name string + status OAuthServerAuthorizationStatus + wantErr bool + errMsg string + }{ + { + name: "mark pending authorization as expired", + status: OAuthServerAuthorizationPending, + wantErr: false, + }, + { + name: "mark approved authorization as expired", + status: OAuthServerAuthorizationApproved, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "mark denied authorization as expired", + status: OAuthServerAuthorizationDenied, + wantErr: true, + errMsg: "authorization request is not pending", + }, + { + name: "mark already expired authorization as expired", + status: OAuthServerAuthorizationExpired, + wantErr: true, + errMsg: "authorization request is not pending", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &OAuthServerAuthorization{Status: tt.status} + + if auth.Status != OAuthServerAuthorizationPending { + err := fmt.Errorf("authorization request is not pending (current status: %s)", auth.Status) + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } + return + } + + // Test successful expiration state change + auth.Status = OAuthServerAuthorizationExpired + assert.Equal(t, OAuthServerAuthorizationExpired, auth.Status) + }) + } +} + +func TestOAuthServerAuthorization_Validate(t *testing.T) { + userID := uuid.Must(uuid.NewV4()) + clientID := uuid.Must(uuid.NewV4()) + validAuth := &OAuthServerAuthorization{ + ClientID: clientID, + UserID: &userID, + RedirectURI: "https://example.com/callback", + Scope: "openid", + ResponseType: OAuthServerResponseTypeCode, + CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(10 * time.Minute), + } + + // Valid authorization should pass + assert.NoError(t, validAuth.Validate()) + + // Test UserID can be nil (for unauthenticated requests) + validAuthNoUser := *validAuth + validAuthNoUser.UserID = nil + assert.NoError(t, validAuthNoUser.Validate()) + + // Test invalid cases + tests := []struct { + name string + modify func(*OAuthServerAuthorization) + wantErr bool + errMsg string + }{ + { + name: "missing client_id", + modify: func(a *OAuthServerAuthorization) { a.ClientID = uuid.Nil }, + wantErr: true, + errMsg: "client_id is required", + }, + { + name: "missing redirect_uri", + modify: func(a *OAuthServerAuthorization) { a.RedirectURI = "" }, + wantErr: true, + errMsg: "redirect_uri is required", + }, + { + name: "missing scope", + modify: func(a *OAuthServerAuthorization) { a.Scope = "" }, + wantErr: true, + errMsg: "scope is required", + }, + { + name: "invalid response_type", + modify: func(a *OAuthServerAuthorization) { a.ResponseType = "token" }, + wantErr: true, + errMsg: "only response_type=code is supported", + }, + { + name: "expires_at before created_at", + modify: func(a *OAuthServerAuthorization) { a.ExpiresAt = a.CreatedAt.Add(-1 * time.Minute) }, + wantErr: true, + errMsg: "expires_at must be after created_at", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := *validAuth // Copy + tt.modify(&auth) + + err := auth.Validate() + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index 7c8776377..93805a0f7 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -22,9 +22,9 @@ type OAuthServerClient struct { RedirectURIs string `json:"-" db:"redirect_uris"` GrantTypes string `json:"grant_types" db:"grant_types"` - ClientName storage.NullString `json:"client_name" db:"client_name"` - ClientURI storage.NullString `json:"client_uri" db:"client_uri"` - LogoURI storage.NullString `json:"logo_uri" db:"logo_uri"` + ClientName *string `json:"client_name,omitempty" db:"client_name"` + ClientURI *string `json:"client_uri,omitempty" db:"client_uri"` + LogoURI *string `json:"logo_uri,omitempty" db:"logo_uri"` CreatedAt time.Time `json:"created_at" db:"created_at"` UpdatedAt time.Time `json:"updated_at" db:"updated_at"` DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go index 1a2607f6c..233da83fb 100644 --- a/internal/models/oauth_client_test.go +++ b/internal/models/oauth_client_test.go @@ -20,7 +20,7 @@ type OAuthServerClientTestSuite struct { } func (ts *OAuthServerClientTestSuite) SetupTest() { - TruncateAll(ts.db) + _ = TruncateAll(ts.db) } func TestOAuthServerClient(t *testing.T) { @@ -39,10 +39,11 @@ func TestOAuthServerClient(t *testing.T) { } func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { + testClientName := "Test Client" validClient := &OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), ClientID: "test_client_id", - ClientName: storage.NullString("Test Client"), + ClientName: &testClientName, RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", GrantTypes: "authorization_code,refresh_token", @@ -141,9 +142,10 @@ func (ts *OAuthServerClientTestSuite) TestRedirectURIHelpers() { } func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { + testAppName := "Test Application" client := &OAuthServerClient{ ClientID: "test_client_create_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Test Application"), + ClientName: &testAppName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", @@ -170,9 +172,10 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClientValidation() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { // Create a test client + testName := "Find By ID Test" client := &OAuthServerClient{ ClientID: "test_client_find_by_id_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Find By ID Test"), + ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", @@ -185,7 +188,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { foundClient, err := FindOAuthServerClientByID(ts.db, client.ID) require.NoError(ts.T(), err) assert.Equal(ts.T(), client.ClientID, foundClient.ClientID) - assert.Equal(ts.T(), client.ClientName.String(), foundClient.ClientName.String()) + assert.Equal(ts.T(), *client.ClientName, *foundClient.ClientName) // Test not found _, err = FindOAuthServerClientByID(ts.db, uuid.Must(uuid.NewV4())) @@ -195,9 +198,10 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { // Create a test client + testName := "Find By Client ID Test" client := &OAuthServerClient{ ClientID: "test_client_find_by_client_id_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Find By Client ID Test"), + ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "manual", RedirectURIs: "https://example.com/callback", @@ -210,7 +214,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { foundClient, err := FindOAuthServerClientByClientID(ts.db, client.ClientID) require.NoError(ts.T(), err) assert.Equal(ts.T(), client.ID, foundClient.ID) - assert.Equal(ts.T(), client.ClientName.String(), foundClient.ClientName.String()) + assert.Equal(ts.T(), *client.ClientName, *foundClient.ClientName) // Test not found _, err = FindOAuthServerClientByClientID(ts.db, "nonexistent_client_id") @@ -220,9 +224,10 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { // Create a test client + originalName := "Original Name" client := &OAuthServerClient{ ClientID: "test_client_update_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Original Name"), + ClientName: &originalName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", @@ -233,7 +238,8 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { originalUpdatedAt := client.UpdatedAt // Update the client - client.ClientName = storage.NullString("Updated Name") + updatedName := "Updated Name" + client.ClientName = &updatedName // client.Description removed - no longer exists client.SetRedirectURIs([]string{"https://updated.example.com/callback"}) @@ -243,7 +249,7 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { // Verify updates updatedClient, err := FindOAuthServerClientByID(ts.db, client.ID) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "Updated Name", updatedClient.ClientName.String()) + assert.Equal(ts.T(), "Updated Name", *updatedClient.ClientName) // assert.Equal(ts.T(), "Updated description", updatedClient.Description.String()) // Description field removed assert.Equal(ts.T(), []string{"https://updated.example.com/callback"}, updatedClient.GetRedirectURIs()) assert.True(ts.T(), updatedClient.UpdatedAt.After(originalUpdatedAt)) @@ -267,9 +273,10 @@ func (ts *OAuthServerClientTestSuite) TestClientSecretHashing() { func (ts *OAuthServerClientTestSuite) TestSoftDelete() { // Create a test client + testName := "Soft Delete Test" client := &OAuthServerClient{ ClientID: "test_client_soft_delete_" + uuid.Must(uuid.NewV4()).String()[:8], - ClientName: storage.NullString("Soft Delete Test"), + ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", RedirectURIs: "https://example.com/callback", diff --git a/internal/models/oauth_consent.go b/internal/models/oauth_consent.go new file mode 100644 index 000000000..4aeeaa4bf --- /dev/null +++ b/internal/models/oauth_consent.go @@ -0,0 +1,181 @@ +package models + +import ( + "database/sql" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +// OAuthServerConsent represents user consent for an OAuth server client's access to specific scopes +type OAuthServerConsent struct { + ID uuid.UUID `json:"id" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + ClientID uuid.UUID `json:"-" db:"client_id"` + Scopes string `json:"scopes" db:"scopes"` + GrantedAt time.Time `json:"granted_at" db:"granted_at"` + RevokedAt *time.Time `json:"revoked_at" db:"revoked_at"` +} + +// TableName returns the table name for the OAuthConsent model +func (OAuthServerConsent) TableName() string { + return "oauth_consents" +} + +// NewOAuthConsent creates a new OAuth consent record +func NewOAuthServerConsent(userID uuid.UUID, clientID uuid.UUID, scopes []string) *OAuthServerConsent { + return &OAuthServerConsent{ + ID: uuid.Must(uuid.NewV4()), + UserID: userID, + ClientID: clientID, + Scopes: strings.Join(scopes, " "), + GrantedAt: time.Now(), + } +} + +// GetScopeList returns the granted scopes as a slice +func (consent *OAuthServerConsent) GetScopeList() []string { + return ParseScopeString(consent.Scopes) +} + +// HasScope checks if the consent includes a specific scope +func (consent *OAuthServerConsent) HasScope(scope string) bool { + return HasScope(consent.GetScopeList(), scope) +} + +// HasAllScopes checks if the consent includes all of the requested scopes +func (consent *OAuthServerConsent) HasAllScopes(requestedScopes []string) bool { + return HasAllScopes(consent.GetScopeList(), requestedScopes) +} + +// IsRevoked checks if the consent has been revoked +func (consent *OAuthServerConsent) IsRevoked() bool { + return consent.RevokedAt != nil +} + +// Revoke revokes the consent +func (consent *OAuthServerConsent) Revoke(tx *storage.Connection) error { + if consent.IsRevoked() { + return fmt.Errorf("consent is already revoked") + } + + now := time.Now() + consent.RevokedAt = &now + return tx.UpdateOnly(consent, "revoked_at") +} + +// UpdateScopes updates the granted scopes for this consent +func (consent *OAuthServerConsent) UpdateScopes(tx *storage.Connection, scopes []string) error { + if consent.IsRevoked() { + return fmt.Errorf("cannot update scopes for revoked consent") + } + + consent.Scopes = strings.Join(scopes, " ") + consent.GrantedAt = time.Now() // Update granted time to reflect the change + return tx.UpdateOnly(consent, "scopes", "granted_at") +} + +// Validate performs basic validation on the OAuth consent +func (consent *OAuthServerConsent) Validate() error { + if consent.UserID == uuid.Nil { + return fmt.Errorf("user_id is required") + } + if consent.ClientID == uuid.Nil { + return fmt.Errorf("client_id is required") + } + if strings.TrimSpace(consent.Scopes) == "" { + return fmt.Errorf("scopes cannot be empty") + } + if consent.RevokedAt != nil && consent.RevokedAt.Before(consent.GrantedAt) { + return fmt.Errorf("revoked_at cannot be before granted_at") + } + + return nil +} + +// Query functions for OAuth consents + +// FindOAuthServerConsentByUserAndClient finds an OAuth consent by user and client +func FindOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID uuid.UUID) (*OAuthServerConsent, error) { + consent := &OAuthServerConsent{} + if err := tx.Eager().Q().Where("user_id = ? AND client_id = ?", userID, clientID).First(consent); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil // No consent found (not an error) + } + return nil, errors.Wrap(err, "error finding OAuth consent") + } + return consent, nil +} + +// FindActiveOAuthServerConsentByUserAndClient finds an active (non-revoked) OAuth consent +func FindActiveOAuthServerConsentByUserAndClient(tx *storage.Connection, userID uuid.UUID, clientID uuid.UUID) (*OAuthServerConsent, error) { + consent := &OAuthServerConsent{} + if err := tx.Q().Where("user_id = ? AND client_id = ? AND revoked_at IS NULL", userID, clientID).First(consent); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, nil // No active consent found (not an error) + } + return nil, errors.Wrap(err, "error finding active OAuth consent") + } + + return consent, nil +} + +// FindOAuthServerConsentsByUser finds all OAuth consents for a user +func FindOAuthServerConsentsByUser(tx *storage.Connection, userID uuid.UUID, includeRevoked bool) ([]*OAuthServerConsent, error) { + var consents []*OAuthServerConsent + query := tx.Q().Where("user_id = ?", userID) + + if !includeRevoked { + query = query.Where("revoked_at IS NULL") + } + + if err := query.Order("granted_at desc").All(&consents); err != nil { + return nil, errors.Wrap(err, "error finding OAuth consents by user") + } + + return consents, nil +} + +// UpsertOAuthServerConsent creates or updates an OAuth consent +func UpsertOAuthServerConsent(tx *storage.Connection, consent *OAuthServerConsent) error { + if err := consent.Validate(); err != nil { + return err + } + + existing, err := FindOAuthServerConsentByUserAndClient(tx, consent.UserID, consent.ClientID) + if err != nil { + return err + } + + if existing != nil { + // Update existing consent + existing.Scopes = consent.Scopes + existing.GrantedAt = time.Now() + existing.RevokedAt = nil // Un-revoke if previously revoked + return tx.Update(existing) + } + + // Create new consent + if consent.ID == uuid.Nil { + consent.ID = uuid.Must(uuid.NewV4()) + } + return tx.Create(consent) +} + +// RevokeOAuthServerConsentsByClient revokes all consents for a specific client +func RevokeOAuthServerConsentsByClient(tx *storage.Connection, clientID uuid.UUID) error { + now := time.Now() + query := "UPDATE " + (&OAuthServerConsent{}).TableName() + " SET revoked_at = ? WHERE client_id = ? AND revoked_at IS NULL" + return tx.RawQuery(query, now, clientID).Exec() +} + +// RevokeOAuthServerConsentsByUser revokes all consents for a specific user +func RevokeOAuthServerConsentsByUser(tx *storage.Connection, userID uuid.UUID) error { + now := time.Now() + query := "UPDATE " + (&OAuthServerConsent{}).TableName() + " SET revoked_at = ? WHERE user_id = ? AND revoked_at IS NULL" + return tx.RawQuery(query, now, userID).Exec() +} diff --git a/internal/models/oauth_consent_test.go b/internal/models/oauth_consent_test.go new file mode 100644 index 000000000..4548004f3 --- /dev/null +++ b/internal/models/oauth_consent_test.go @@ -0,0 +1,105 @@ +package models + +import ( + "testing" + "time" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewOAuthServerConsent(t *testing.T) { + userID := uuid.Must(uuid.NewV4()) + clientID := uuid.Must(uuid.NewV4()) + scopes := []string{"openid", "profile", "email"} + + consent := NewOAuthServerConsent(userID, clientID, scopes) + + assert.NotEmpty(t, consent.ID) + assert.Equal(t, userID, consent.UserID) + assert.Equal(t, clientID, consent.ClientID) + assert.Equal(t, "openid profile email", consent.Scopes) + assert.False(t, consent.GrantedAt.IsZero()) + assert.Nil(t, consent.RevokedAt) +} + +func TestOAuthServerConsent_IsRevoked(t *testing.T) { + consent := &OAuthServerConsent{} + + // Initially not revoked + assert.False(t, consent.IsRevoked()) + + // After revocation + now := time.Now() + consent.RevokedAt = &now + assert.True(t, consent.IsRevoked()) +} + +func TestOAuthServerConsent_Validate(t *testing.T) { + validConsent := &OAuthServerConsent{ + UserID: uuid.Must(uuid.NewV4()), + ClientID: uuid.Must(uuid.NewV4()), + Scopes: "openid profile", + GrantedAt: time.Now(), + } + + // Valid consent should pass + assert.NoError(t, validConsent.Validate()) + + // Test invalid cases + tests := []struct { + name string + modify func(*OAuthServerConsent) + wantErr bool + errMsg string + }{ + { + name: "missing user_id", + modify: func(c *OAuthServerConsent) { c.UserID = uuid.Nil }, + wantErr: true, + errMsg: "user_id is required", + }, + { + name: "missing client_id", + modify: func(c *OAuthServerConsent) { c.ClientID = uuid.Nil }, + wantErr: true, + errMsg: "client_id is required", + }, + { + name: "empty scopes", + modify: func(c *OAuthServerConsent) { c.Scopes = "" }, + wantErr: true, + errMsg: "scopes cannot be empty", + }, + { + name: "whitespace only scopes", + modify: func(c *OAuthServerConsent) { c.Scopes = " " }, + wantErr: true, + errMsg: "scopes cannot be empty", + }, + { + name: "revoked_at before granted_at", + modify: func(c *OAuthServerConsent) { + revokedAt := c.GrantedAt.Add(-1 * time.Hour) + c.RevokedAt = &revokedAt + }, + wantErr: true, + errMsg: "revoked_at cannot be before granted_at", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + consent := *validConsent // Copy + tt.modify(&consent) + + err := consent.Validate() + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/models/oauth_scope.go b/internal/models/oauth_scope.go new file mode 100644 index 000000000..6b3c437d2 --- /dev/null +++ b/internal/models/oauth_scope.go @@ -0,0 +1,47 @@ +package models + +import "strings" + +// ParseScopeString parses a space-separated scope string into a slice +func ParseScopeString(scopeString string) []string { + if scopeString == "" { + return []string{} + } + scopes := strings.Split(strings.TrimSpace(scopeString), " ") + var result []string + for _, scope := range scopes { + if strings.TrimSpace(scope) != "" { + result = append(result, strings.TrimSpace(scope)) + } + } + // Always return empty slice instead of nil for consistency + if result == nil { + return []string{} + } + return result +} + +// HasScope checks if the given scope list includes a specific scope +func HasScope(scopes []string, scope string) bool { + for _, s := range scopes { + if s == scope { + return true + } + } + return false +} + +// HasAllScopes checks if the granted scopes include all of the requested scopes +func HasAllScopes(grantedScopes, requestedScopes []string) bool { + grantedSet := make(map[string]bool) + for _, scope := range grantedScopes { + grantedSet[scope] = true + } + + for _, requestedScope := range requestedScopes { + if !grantedSet[requestedScope] { + return false + } + } + return true +} diff --git a/internal/models/oauth_scope_test.go b/internal/models/oauth_scope_test.go new file mode 100644 index 000000000..c87271500 --- /dev/null +++ b/internal/models/oauth_scope_test.go @@ -0,0 +1,199 @@ +package models + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseScopeString(t *testing.T) { + tests := []struct { + name string + scope string + expected []string + }{ + { + name: "single scope", + scope: "openid", + expected: []string{"openid"}, + }, + { + name: "multiple scopes", + scope: "openid profile email", + expected: []string{"openid", "profile", "email"}, + }, + { + name: "empty scope", + scope: "", + expected: []string{}, + }, + { + name: "scope with extra spaces", + scope: " openid profile ", + expected: []string{"openid", "profile"}, + }, + { + name: "scope with empty segments", + scope: "openid profile email", + expected: []string{"openid", "profile", "email"}, + }, + { + name: "only spaces", + scope: " ", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ParseScopeString(tt.scope) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHasScope(t *testing.T) { + scopes := []string{"openid", "profile", "email"} + + tests := []struct { + name string + scopes []string + scope string + expected bool + }{ + { + name: "scope exists", + scopes: scopes, + scope: "openid", + expected: true, + }, + { + name: "scope exists middle", + scopes: scopes, + scope: "profile", + expected: true, + }, + { + name: "scope exists last", + scopes: scopes, + scope: "email", + expected: true, + }, + { + name: "scope does not exist", + scopes: scopes, + scope: "phone", + expected: false, + }, + { + name: "empty scope", + scopes: scopes, + scope: "", + expected: false, + }, + { + name: "empty scopes list", + scopes: []string{}, + scope: "openid", + expected: false, + }, + { + name: "nil scopes list", + scopes: nil, + scope: "openid", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasScope(tt.scopes, tt.scope) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestHasAllScopes(t *testing.T) { + grantedScopes := []string{"openid", "profile", "email"} + + tests := []struct { + name string + grantedScopes []string + requestedScopes []string + expected bool + }{ + { + name: "requested scopes are a subset of granted scopes", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid", "profile"}, + expected: true, + }, + { + name: "requested scopes are an exact match of granted scopes", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid", "profile", "email"}, + expected: true, + }, + { + name: "requested scopes are a single scope granted", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid"}, + expected: true, + }, + { + name: "granted scopes are missing a scope", + grantedScopes: grantedScopes, + requestedScopes: []string{"openid", "phone"}, + expected: false, + }, + { + name: "granted scopes are missing multiple scopes", + grantedScopes: grantedScopes, + requestedScopes: []string{"phone", "address"}, + expected: false, + }, + { + name: "requested scopes are empty", + grantedScopes: grantedScopes, + requestedScopes: []string{}, + expected: true, + }, + { + name: "requested scopes are nil", + grantedScopes: grantedScopes, + requestedScopes: nil, + expected: true, + }, + { + name: "granted scopes are empty with requested scopes", + grantedScopes: []string{}, + requestedScopes: []string{"openid"}, + expected: false, + }, + { + name: "granted scopes are empty with requested scopes", + grantedScopes: []string{}, + requestedScopes: []string{}, + expected: true, + }, + { + name: "granted scopes are nil with requested scopes", + grantedScopes: nil, + requestedScopes: []string{"openid"}, + expected: false, + }, + { + name: "granted scopes are nil with requested scopes", + grantedScopes: nil, + requestedScopes: []string{}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasAllScopes(tt.grantedScopes, tt.requestedScopes) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/security/pkce.go b/internal/security/pkce.go new file mode 100644 index 000000000..7e4ccf41d --- /dev/null +++ b/internal/security/pkce.go @@ -0,0 +1,32 @@ +package security + +import ( + "crypto/sha256" + "crypto/subtle" + "encoding/base64" + "errors" + "strings" +) + +const PKCEInvalidCodeChallengeError = "code challenge does not match previously saved code verifier" +const PKCEInvalidCodeMethodError = "code challenge method not supported" + +// VerifyPKCEChallenge performs PKCE verification using the provided challenge, method, and verifier +// This is a shared utility function used by both FlowState and OAuthServerAuthorization +func VerifyPKCEChallenge(codeChallenge, codeChallengeMethod, codeVerifier string) error { + switch strings.ToLower(codeChallengeMethod) { + case "s256": + hashedCodeVerifier := sha256.Sum256([]byte(codeVerifier)) + encodedCodeVerifier := base64.RawURLEncoding.EncodeToString(hashedCodeVerifier[:]) + if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(encodedCodeVerifier)) != 1 { + return errors.New(PKCEInvalidCodeChallengeError) + } + case "plain": + if subtle.ConstantTimeCompare([]byte(codeChallenge), []byte(codeVerifier)) != 1 { + return errors.New(PKCEInvalidCodeChallengeError) + } + default: + return errors.New(PKCEInvalidCodeMethodError) + } + return nil +} diff --git a/internal/security/pkce_test.go b/internal/security/pkce_test.go new file mode 100644 index 000000000..795b06778 --- /dev/null +++ b/internal/security/pkce_test.go @@ -0,0 +1,123 @@ +package security + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestVerifyPKCEChallenge(t *testing.T) { + tests := []struct { + name string + codeChallenge string + codeChallengeMethod string + codeVerifier string + wantErr bool + errMsg string + }{ + { + name: "valid S256 PKCE", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", // S256 of "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + codeChallengeMethod: "S256", + codeVerifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + wantErr: false, + }, + { + name: "valid plain PKCE", + codeChallenge: "test-challenge", + codeChallengeMethod: "plain", + codeVerifier: "test-challenge", + wantErr: false, + }, + { + name: "invalid S256 verifier", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + codeChallengeMethod: "S256", + codeVerifier: "wrong-verifier", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "invalid plain verifier", + codeChallenge: "test-challenge", + codeChallengeMethod: "plain", + codeVerifier: "wrong-challenge", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "invalid challenge method", + codeChallenge: "test-challenge", + codeChallengeMethod: "invalid", + codeVerifier: "test-challenge", + wantErr: true, + errMsg: "code challenge method not supported", + }, + { + name: "case insensitive S256 method", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + codeChallengeMethod: "s256", + codeVerifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + wantErr: false, + }, + { + name: "case insensitive plain method", + codeChallenge: "test-challenge", + codeChallengeMethod: "PLAIN", + codeVerifier: "test-challenge", + wantErr: false, + }, + { + name: "empty verifier with S256", + codeChallenge: "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + codeChallengeMethod: "S256", + codeVerifier: "", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "empty verifier with plain", + codeChallenge: "test-challenge", + codeChallengeMethod: "plain", + codeVerifier: "", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "empty challenge with S256", + codeChallenge: "", + codeChallengeMethod: "S256", + codeVerifier: "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "empty challenge with plain", + codeChallenge: "", + codeChallengeMethod: "plain", + codeVerifier: "test-challenge", + wantErr: true, + errMsg: "code challenge does not match", + }, + { + name: "both empty with plain", + codeChallenge: "", + codeChallengeMethod: "plain", + codeVerifier: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := VerifyPKCEChallenge(tt.codeChallenge, tt.codeChallengeMethod, tt.codeVerifier) + + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errMsg) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/utilities/strings.go b/internal/utilities/strings.go new file mode 100644 index 000000000..471a4552d --- /dev/null +++ b/internal/utilities/strings.go @@ -0,0 +1,17 @@ +package utilities + +// StringValue safely extracts a string from a *string, returning empty string if nil +func StringValue(s *string) string { + if s == nil { + return "" + } + return *s +} + +// StringPtr returns a pointer to a string if non-empty, nil otherwise +func StringPtr(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/migrations/20250804100000_add_oauth_authorizations_consents.up.sql b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql new file mode 100644 index 000000000..a74b4305c --- /dev/null +++ b/migrations/20250804100000_add_oauth_authorizations_consents.up.sql @@ -0,0 +1,88 @@ +-- Create OAuth 2.1 support with enums, authorization, and consent tables + +-- Create enums for OAuth authorization management +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_authorization_status as enum('pending', 'approved', 'denied', 'expired'); +exception + when duplicate_object then null; +end $$; + +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_response_type as enum('code'); +exception + when duplicate_object then null; +end $$; + +-- Create oauth_authorizations table for OAuth 2.1 authorization requests +create table if not exists {{ index .Options "Namespace" }}.oauth_authorizations ( + id uuid not null, + authorization_id text not null, + client_id uuid not null references {{ index .Options "Namespace" }}.oauth_clients(id) on delete cascade, + user_id uuid null references {{ index .Options "Namespace" }}.users(id) on delete cascade, + redirect_uri text not null, + scope text not null, + state text null, + resource text null, + code_challenge text null, + code_challenge_method {{ index .Options "Namespace" }}.code_challenge_method null, + response_type {{ index .Options "Namespace" }}.oauth_response_type not null default 'code', + + -- Flow control + status {{ index .Options "Namespace" }}.oauth_authorization_status not null default 'pending', + authorization_code text null, + + -- Timestamps + created_at timestamptz not null default now(), + expires_at timestamptz not null default (now() + interval '3 minutes'), + approved_at timestamptz null, + + constraint oauth_authorizations_pkey primary key (id), + constraint oauth_authorizations_authorization_id_key unique (authorization_id), + constraint oauth_authorizations_authorization_code_key unique (authorization_code), + constraint oauth_authorizations_redirect_uri_length check (char_length(redirect_uri) <= 2048), + constraint oauth_authorizations_scope_length check (char_length(scope) <= 4096), + constraint oauth_authorizations_state_length check (char_length(state) <= 4096), + constraint oauth_authorizations_resource_length check (char_length(resource) <= 2048), + constraint oauth_authorizations_code_challenge_length check (char_length(code_challenge) <= 128), + constraint oauth_authorizations_authorization_code_length check (char_length(authorization_code) <= 255), + constraint oauth_authorizations_expires_at_future check (expires_at > created_at) +); + +-- Create indexes for oauth_authorizations +-- for CleanupExpiredOAuthServerAuthorizations +create index if not exists oauth_auth_pending_exp_idx + on {{ index .Options "Namespace" }}.oauth_authorizations (expires_at) + where status = 'pending'; + + + +-- Create oauth_consents table for user consent management +create table if not exists {{ index .Options "Namespace" }}.oauth_consents ( + id uuid not null, + user_id uuid not null references {{ index .Options "Namespace" }}.users(id) on delete cascade, + client_id uuid not null references {{ index .Options "Namespace" }}.oauth_clients(id) on delete cascade, + scopes text not null, + granted_at timestamptz not null default now(), + revoked_at timestamptz null, + + constraint oauth_consents_pkey primary key (id), + constraint oauth_consents_user_client_unique unique (user_id, client_id), + constraint oauth_consents_scopes_length check (char_length(scopes) <= 2048), + constraint oauth_consents_scopes_not_empty check (char_length(trim(scopes)) > 0), + constraint oauth_consents_revoked_after_granted check (revoked_at is null or revoked_at >= granted_at) +); + +-- Create indexes for oauth_consents +-- Active consent look-up (user + client, only non-revoked rows) +create index if not exists oauth_consents_active_user_client_idx + on {{ index .Options "Namespace" }}.oauth_consents (user_id, client_id) + where revoked_at is null; + +-- "Show me all consents for this user, newest first" +create index if not exists oauth_consents_user_order_idx + on {{ index .Options "Namespace" }}.oauth_consents (user_id, granted_at desc); + +-- Bulk revoke for an entire client (only non-revoked rows) +create index if not exists oauth_consents_active_client_idx + on {{ index .Options "Namespace" }}.oauth_consents (client_id) + where revoked_at is null;