From bba7ab73915760b6d3f21b4a736ebe8ead7016f4 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Mon, 1 Sep 2025 20:32:16 +0200 Subject: [PATCH 1/4] feat: add OAuth client type --- internal/api/middleware.go | 6 +- internal/api/oauthserver/auth.go | 9 +- internal/api/oauthserver/client_auth.go | 110 +++++ internal/api/oauthserver/client_auth_test.go | 413 ++++++++++++++++++ internal/api/oauthserver/handlers.go | 30 +- internal/api/oauthserver/service.go | 42 +- internal/models/oauth_client.go | 38 ++ ...0250901200500_add_oauth_client_type.up.sql | 14 + 8 files changed, 643 insertions(+), 19 deletions(-) create mode 100644 internal/api/oauthserver/client_auth.go create mode 100644 internal/api/oauthserver/client_auth_test.go create mode 100644 migrations/20250901200500_add_oauth_client_type.up.sql diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 4c18e1666..b14d1202f 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -108,9 +108,9 @@ func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.C return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err) } - // Validate client secret - if !oauthserver.ValidateClientSecret(clientSecret, client.ClientSecretHash) { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials") + // Validate authentication using centralized logic + if err := oauthserver.ValidateClientAuthentication(client, clientSecret); err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, err.Error()) } // Add authenticated client to context diff --git a/internal/api/oauthserver/auth.go b/internal/api/oauthserver/auth.go index 46f05d5bb..4a7756255 100644 --- a/internal/api/oauthserver/auth.go +++ b/internal/api/oauthserver/auth.go @@ -41,9 +41,12 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e return "", "", nil } - // If only one is provided, it's an error - if clientID == "" || clientSecret == "" { - return "", "", errors.New("both client_id and client_secret must be provided") + // For public clients, only client_id is required (client_secret should be empty) + // For confidential clients, both client_id and client_secret are required + // We'll validate this based on the client type in the calling handler + // TODO(cemal) :: this will be validated in detail during the `/token` endpoint implementation + if clientID == "" { + return "", "", errors.New("client_id is required") } return clientID, clientSecret, nil diff --git a/internal/api/oauthserver/client_auth.go b/internal/api/oauthserver/client_auth.go new file mode 100644 index 000000000..6ed505e84 --- /dev/null +++ b/internal/api/oauthserver/client_auth.go @@ -0,0 +1,110 @@ +package oauthserver + +import ( + "fmt" + + "github.com/supabase/auth/internal/models" +) + +// InferClientTypeFromAuthMethod infers client type from token_endpoint_auth_method +func InferClientTypeFromAuthMethod(authMethod string) string { + switch authMethod { + case models.TokenEndpointAuthMethodNone: + return models.OAuthServerClientTypePublic + case models.TokenEndpointAuthMethodClientSecretBasic, models.TokenEndpointAuthMethodClientSecretPost: + return models.OAuthServerClientTypeConfidential + default: + return models.OAuthServerClientTypeConfidential // Default to confidential + } +} + +// GetValidAuthMethodsForClientType returns the valid authentication methods for a client type +func GetValidAuthMethodsForClientType(clientType string) []string { + switch clientType { + case models.OAuthServerClientTypePublic: + return []string{models.TokenEndpointAuthMethodNone} + case models.OAuthServerClientTypeConfidential: + return []string{ + models.TokenEndpointAuthMethodClientSecretBasic, + models.TokenEndpointAuthMethodClientSecretPost, + } + default: + return []string{} // Unknown client type + } +} + +// ValidateClientTypeConsistency validates consistency between client_type and token_endpoint_auth_method +func ValidateClientTypeConsistency(clientType, authMethod string) error { + if clientType == "" || authMethod == "" { + return nil // Skip validation if either is not provided + } + + expectedClientType := InferClientTypeFromAuthMethod(authMethod) + if clientType != expectedClientType { + return fmt.Errorf("client_type '%s' is inconsistent with token_endpoint_auth_method '%s' (expected client_type '%s')", + clientType, authMethod, expectedClientType) + } + + return nil +} + +// IsValidAuthMethodForClientType checks if the auth method is valid for the given client type +func IsValidAuthMethodForClientType(clientType, authMethod string) bool { + validMethods := GetValidAuthMethodsForClientType(clientType) + for _, method := range validMethods { + if method == authMethod { + return true + } + } + return false +} + +// DetermineClientType determines the final client type using the priority: +// 1. Explicit client_type +// 2. Inferred from token_endpoint_auth_method +// 3. Default to confidential +func DetermineClientType(explicitClientType, authMethod string) string { + // Priority 1: Explicit client_type + if explicitClientType != "" { + return explicitClientType + } + + // Priority 2: Infer from token_endpoint_auth_method + if authMethod != "" { + return InferClientTypeFromAuthMethod(authMethod) + } + + // Priority 3: Default to confidential + return models.OAuthServerClientTypeConfidential +} + +// ValidateClientAuthentication validates client authentication based on client type +func ValidateClientAuthentication(client *models.OAuthServerClient, providedSecret string) error { + if client.IsPublic() { + // Public clients should not provide client secrets + if providedSecret != "" { + return fmt.Errorf("public clients must not provide client_secret") + } + return nil + } + + // Confidential clients must provide a valid client secret + if providedSecret == "" { + return fmt.Errorf("confidential clients must provide client_secret") + } + + if !ValidateClientSecret(providedSecret, client.ClientSecretHash) { + return fmt.Errorf("invalid client credentials") + } + + return nil +} + +// GetAllValidAuthMethods returns all supported authentication methods +func GetAllValidAuthMethods() []string { + return []string{ + models.TokenEndpointAuthMethodNone, + models.TokenEndpointAuthMethodClientSecretBasic, + models.TokenEndpointAuthMethodClientSecretPost, + } +} diff --git a/internal/api/oauthserver/client_auth_test.go b/internal/api/oauthserver/client_auth_test.go new file mode 100644 index 000000000..20f0c0514 --- /dev/null +++ b/internal/api/oauthserver/client_auth_test.go @@ -0,0 +1,413 @@ +package oauthserver + +import ( + "testing" + + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/models" + "golang.org/x/crypto/bcrypt" +) + +func TestInferClientTypeFromAuthMethod(t *testing.T) { + tests := []struct { + name string + authMethod string + expected string + }{ + { + name: "none method should return public", + authMethod: models.TokenEndpointAuthMethodNone, + expected: models.OAuthServerClientTypePublic, + }, + { + name: "client_secret_basic should return confidential", + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "client_secret_post should return confidential", + authMethod: models.TokenEndpointAuthMethodClientSecretPost, + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "unknown method should default to confidential", + authMethod: "unknown_method", + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "empty method should default to confidential", + authMethod: "", + expected: models.OAuthServerClientTypeConfidential, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := InferClientTypeFromAuthMethod(tt.authMethod) + if result != tt.expected { + t.Errorf("InferClientTypeFromAuthMethod() = %v, expected %v", result, tt.expected) + } + }) + } +} + +func TestGetValidAuthMethodsForClientType(t *testing.T) { + + tests := []struct { + name string + clientType string + expected []string + }{ + { + name: "public client should only support none", + clientType: models.OAuthServerClientTypePublic, + expected: []string{models.TokenEndpointAuthMethodNone}, + }, + { + name: "confidential client should support secret methods", + clientType: models.OAuthServerClientTypeConfidential, + expected: []string{ + models.TokenEndpointAuthMethodClientSecretBasic, + models.TokenEndpointAuthMethodClientSecretPost, + }, + }, + { + name: "unknown client type should return empty", + clientType: "unknown_type", + expected: []string{}, + }, + { + name: "empty client type should return empty", + clientType: "", + expected: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetValidAuthMethodsForClientType(tt.clientType) + if len(result) != len(tt.expected) { + t.Errorf("GetValidAuthMethodsForClientType() returned %d methods, expected %d", len(result), len(tt.expected)) + return + } + for i, method := range result { + if method != tt.expected[i] { + t.Errorf("GetValidAuthMethodsForClientType()[%d] = %v, expected %v", i, method, tt.expected[i]) + } + } + }) + } +} + +func TestValidateClientTypeConsistency(t *testing.T) { + + tests := []struct { + name string + clientType string + authMethod string + expectError bool + }{ + { + name: "consistent public client", + clientType: models.OAuthServerClientTypePublic, + authMethod: models.TokenEndpointAuthMethodNone, + expectError: false, + }, + { + name: "consistent confidential client with basic auth", + clientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expectError: false, + }, + { + name: "consistent confidential client with post auth", + clientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodClientSecretPost, + expectError: false, + }, + { + name: "inconsistent public client with secret auth", + clientType: models.OAuthServerClientTypePublic, + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expectError: true, + }, + { + name: "inconsistent confidential client with none auth", + clientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodNone, + expectError: true, + }, + { + name: "empty client type should not error", + clientType: "", + authMethod: models.TokenEndpointAuthMethodNone, + expectError: false, + }, + { + name: "empty auth method should not error", + clientType: models.OAuthServerClientTypePublic, + authMethod: "", + expectError: false, + }, + { + name: "both empty should not error", + clientType: "", + authMethod: "", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateClientTypeConsistency(tt.clientType, tt.authMethod) + if tt.expectError && err == nil { + t.Errorf("ValidateClientTypeConsistency() expected error but got nil") + } + if !tt.expectError && err != nil { + t.Errorf("ValidateClientTypeConsistency() expected no error but got: %v", err) + } + }) + } +} + +func TestDetermineClientType(t *testing.T) { + + tests := []struct { + name string + explicitClientType string + authMethod string + expected string + }{ + { + name: "explicit public overrides auth method", + explicitClientType: models.OAuthServerClientTypePublic, + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expected: models.OAuthServerClientTypePublic, + }, + { + name: "explicit confidential overrides auth method", + explicitClientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodNone, + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "infer public from none auth method", + explicitClientType: "", + authMethod: models.TokenEndpointAuthMethodNone, + expected: models.OAuthServerClientTypePublic, + }, + { + name: "infer confidential from basic auth method", + explicitClientType: "", + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "infer confidential from post auth method", + explicitClientType: "", + authMethod: models.TokenEndpointAuthMethodClientSecretPost, + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "default to confidential when both empty", + explicitClientType: "", + authMethod: "", + expected: models.OAuthServerClientTypeConfidential, + }, + { + name: "default to confidential with unknown auth method", + explicitClientType: "", + authMethod: "unknown_method", + expected: models.OAuthServerClientTypeConfidential, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DetermineClientType(tt.explicitClientType, tt.authMethod) + if result != tt.expected { + t.Errorf("DetermineClientType() = %v, expected %v", result, tt.expected) + } + }) + } +} + +func TestIsValidAuthMethodForClientType(t *testing.T) { + + tests := []struct { + name string + clientType string + authMethod string + expected bool + }{ + { + name: "none is valid for public", + clientType: models.OAuthServerClientTypePublic, + authMethod: models.TokenEndpointAuthMethodNone, + expected: true, + }, + { + name: "basic is invalid for public", + clientType: models.OAuthServerClientTypePublic, + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expected: false, + }, + { + name: "post is invalid for public", + clientType: models.OAuthServerClientTypePublic, + authMethod: models.TokenEndpointAuthMethodClientSecretPost, + expected: false, + }, + { + name: "none is invalid for confidential", + clientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodNone, + expected: false, + }, + { + name: "basic is valid for confidential", + clientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodClientSecretBasic, + expected: true, + }, + { + name: "post is valid for confidential", + clientType: models.OAuthServerClientTypeConfidential, + authMethod: models.TokenEndpointAuthMethodClientSecretPost, + expected: true, + }, + { + name: "unknown method is invalid for any type", + clientType: models.OAuthServerClientTypePublic, + authMethod: "unknown_method", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsValidAuthMethodForClientType(tt.clientType, tt.authMethod) + if result != tt.expected { + t.Errorf("IsValidAuthMethodForClientType() = %v, expected %v", result, tt.expected) + } + }) + } +} + +func TestValidateClientAuthentication(t *testing.T) { + + // Create test clients + publicClient := &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientID: "public_client", + ClientType: models.OAuthServerClientTypePublic, + // No client secret hash for public clients + } + + // Create a hashed secret for confidential client + secretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + confidentialClient := &models.OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientID: "confidential_client", + ClientType: models.OAuthServerClientTypeConfidential, + ClientSecretHash: string(secretHash), + } + + tests := []struct { + name string + client *models.OAuthServerClient + providedSecret string + expectError bool + errorContains string + }{ + { + name: "public client with no secret should pass", + client: publicClient, + providedSecret: "", + expectError: false, + }, + { + name: "public client with secret should fail", + client: publicClient, + providedSecret: "some_secret", + expectError: true, + errorContains: "public clients must not provide client_secret", + }, + { + name: "confidential client with correct secret should pass", + client: confidentialClient, + providedSecret: "test_secret", + expectError: false, + }, + { + name: "confidential client with no secret should fail", + client: confidentialClient, + providedSecret: "", + expectError: true, + errorContains: "confidential clients must provide client_secret", + }, + { + name: "confidential client with wrong secret should fail", + client: confidentialClient, + providedSecret: "wrong_secret", + expectError: true, + errorContains: "invalid client credentials", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateClientAuthentication(tt.client, tt.providedSecret) + + if tt.expectError { + if err == nil { + t.Errorf("ValidateClientAuthentication() expected error but got nil") + return + } + if tt.errorContains != "" && !containsString(err.Error(), tt.errorContains) { + t.Errorf("ValidateClientAuthentication() error = %v, expected to contain %v", err, tt.errorContains) + } + } else { + if err != nil { + t.Errorf("ValidateClientAuthentication() expected no error but got: %v", err) + } + } + }) + } +} + +func TestGetAllValidAuthMethods(t *testing.T) { + + expected := []string{ + models.TokenEndpointAuthMethodNone, + models.TokenEndpointAuthMethodClientSecretBasic, + models.TokenEndpointAuthMethodClientSecretPost, + } + + result := GetAllValidAuthMethods() + + if len(result) != len(expected) { + t.Errorf("GetAllValidAuthMethods() returned %d methods, expected %d", len(result), len(expected)) + return + } + + for i, method := range result { + if method != expected[i] { + t.Errorf("GetAllValidAuthMethods()[%d] = %v, expected %v", i, method, expected[i]) + } + } +} + +// Helper function to check if a string contains a substring +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) && + (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || + func() bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false + }()))) +} diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index 8a247a054..f6edc1e31 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -19,6 +19,7 @@ import ( type OAuthServerClientResponse struct { ClientID string `json:"client_id"` ClientSecret string `json:"client_secret,omitempty"` // only returned on registration + ClientType string `json:"client_type"` RedirectURIs []string `json:"redirect_uris,omitempty"` TokenEndpointAuthMethod []string `json:"token_endpoint_auth_method,omitempty"` @@ -41,12 +42,23 @@ type OAuthServerClientListResponse struct { // oauthServerClientToResponse converts a model to response format func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret bool) *OAuthServerClientResponse { + // Set token endpoint auth methods based on client type + var tokenEndpointAuthMethods []string + if client.IsPublic() { + // Public clients don't use client authentication + tokenEndpointAuthMethods = []string{models.TokenEndpointAuthMethodNone} + } else { + // Confidential clients use client secret authentication + tokenEndpointAuthMethods = []string{models.TokenEndpointAuthMethodClientSecretBasic, models.TokenEndpointAuthMethodClientSecretPost} + } + response := &OAuthServerClientResponse{ - ClientID: client.ClientID, + ClientID: client.ClientID, + ClientType: client.ClientType, // OAuth 2.1 DCR fields RedirectURIs: client.GetRedirectURIs(), - TokenEndpointAuthMethod: []string{"client_secret_basic", "client_secret_post"}, // Both methods are supported + TokenEndpointAuthMethod: tokenEndpointAuthMethods, GrantTypes: client.GetGrantTypes(), ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1 ClientName: utilities.StringValue(client.ClientName), @@ -59,8 +71,8 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret UpdatedAt: client.UpdatedAt, } - // Only include client_secret during registration - if includeSecret { + // Only include client_secret during registration and only for confidential clients + if includeSecret && client.IsConfidential() { // Note: This will be filled in by the handler with the plaintext secret response.ClientSecret = "" } @@ -109,7 +121,9 @@ func (s *Server) AdminOAuthServerClientRegister(w http.ResponseWriter, r *http.R } response := oauthServerClientToResponse(client, true) - response.ClientSecret = plaintextSecret + if client.IsConfidential() { + response.ClientSecret = plaintextSecret + } return shared.SendJSON(w, http.StatusCreated, response) } @@ -136,7 +150,9 @@ func (s *Server) OAuthServerClientDynamicRegister(w http.ResponseWriter, r *http } response := oauthServerClientToResponse(client, true) - response.ClientSecret = plaintextSecret + if client.IsConfidential() { + response.ClientSecret = plaintextSecret + } return shared.SendJSON(w, http.StatusCreated, response) } @@ -219,7 +235,7 @@ func (s *Server) OAuthServerMetadata(w http.ResponseWriter, r *http.Request) err ResponseTypesSupported: []string{"code"}, ResponseModesSupported: []string{"query"}, GrantTypesSupported: []string{"authorization_code", "refresh_token"}, - TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, + TokenEndpointAuthMethodsSupported: []string{models.TokenEndpointAuthMethodClientSecretBasic, models.TokenEndpointAuthMethodClientSecretPost, models.TokenEndpointAuthMethodNone}, CodeChallengeMethodsSupported: []string{"S256", "plain"}, } diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index 02c3911bd..efb1d4bb6 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/url" + "slices" "time" "github.com/pkg/errors" @@ -19,6 +20,10 @@ type OAuthServerClientRegisterParams struct { // Required fields RedirectURIs []string `json:"redirect_uris"` + // Client type can be explicitly provided or inferred from token_endpoint_auth_method + ClientType string `json:"client_type,omitempty"` // models.OAuthServerClientTypePublic or models.OAuthServerClientTypeConfidential + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` // "none", "client_secret_basic", or "client_secret_post" + GrantTypes []string `json:"grant_types,omitempty"` ClientName string `json:"client_name,omitempty"` ClientURI string `json:"client_uri,omitempty"` @@ -77,6 +82,24 @@ func (p *OAuthServerClientRegisterParams) validate() error { return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "registration_type must be 'dynamic' or 'manual'") } + // Validate client_type if provided (defaults to confidential if not specified) + if p.ClientType != "" && p.ClientType != models.OAuthServerClientTypePublic && p.ClientType != models.OAuthServerClientTypeConfidential { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_type must be '%s' or '%s'", models.OAuthServerClientTypePublic, models.OAuthServerClientTypeConfidential) + } + + // Validate token_endpoint_auth_method if provided + if p.TokenEndpointAuthMethod != "" { + validMethods := GetAllValidAuthMethods() + if !slices.Contains(validMethods, p.TokenEndpointAuthMethod) { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "token_endpoint_auth_method must be one of: %v", validMethods) + } + } + + // Validate consistency between client_type and token_endpoint_auth_method + if err := ValidateClientTypeConsistency(p.ClientType, p.TokenEndpointAuthMethod); err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, err.Error()) + } + return nil } @@ -155,11 +178,15 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer grantTypes = []string{"authorization_code", "refresh_token"} } + // Determine client type using centralized logic + clientType := DetermineClientType(params.ClientType, params.TokenEndpointAuthMethod) + db := s.db.WithContext(ctx) client := &models.OAuthServerClient{ ClientID: generateClientID(), RegistrationType: params.RegistrationType, + ClientType: clientType, ClientName: utilities.StringPtr(params.ClientName), ClientURI: utilities.StringPtr(params.ClientURI), LogoURI: utilities.StringPtr(params.LogoURI), @@ -168,13 +195,16 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer client.SetRedirectURIs(params.RedirectURIs) client.SetGrantTypes(grantTypes) - // Generate client secret for all clients - plaintextSecret := generateClientSecret() - hash, err := hashClientSecret(plaintextSecret) - if err != nil { - return nil, "", errors.Wrap(err, "failed to hash client secret") + var plaintextSecret string + // Only generate client secret for confidential clients + if client.IsConfidential() { + plaintextSecret = generateClientSecret() + hash, err := hashClientSecret(plaintextSecret) + if err != nil { + return nil, "", errors.Wrap(err, "failed to hash client secret") + } + client.ClientSecretHash = hash } - client.ClientSecretHash = hash if err := models.CreateOAuthServerClient(db, client); err != nil { return nil, "", errors.Wrap(err, "failed to create OAuth client") diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index 93805a0f7..bdb3bdbd3 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -13,12 +13,26 @@ import ( "github.com/supabase/auth/internal/storage" ) +// OAuth client type constants +const ( + OAuthServerClientTypePublic = "public" + OAuthServerClientTypeConfidential = "confidential" +) + +// OAuth token endpoint authentication method constants +const ( + TokenEndpointAuthMethodNone = "none" + TokenEndpointAuthMethodClientSecretBasic = "client_secret_basic" + TokenEndpointAuthMethodClientSecretPost = "client_secret_post" +) + // OAuthServerClient represents an OAuth client application registered with this OAuth server type OAuthServerClient struct { ID uuid.UUID `json:"-" db:"id"` ClientID string `json:"client_id" db:"client_id"` ClientSecretHash string `json:"-" db:"client_secret_hash"` RegistrationType string `json:"registration_type" db:"registration_type"` + ClientType string `json:"client_type" db:"client_type"` RedirectURIs string `json:"-" db:"redirect_uris"` GrantTypes string `json:"grant_types" db:"grant_types"` @@ -51,10 +65,24 @@ func (c *OAuthServerClient) Validate() error { return fmt.Errorf("registration_type must be 'dynamic' or 'manual'") } + if c.ClientType != OAuthServerClientTypePublic && c.ClientType != OAuthServerClientTypeConfidential { + return fmt.Errorf("client_type must be '%s' or '%s'", OAuthServerClientTypePublic, OAuthServerClientTypeConfidential) + } + if c.RedirectURIs == "" { return fmt.Errorf("at least one redirect_uri is required") } + // Confidential clients must have a client secret + if c.ClientType == OAuthServerClientTypeConfidential && c.ClientSecretHash == "" { + return fmt.Errorf("client_secret is required for confidential clients") + } + + // Public clients should not have a client secret (enforce PKCE instead) + if c.ClientType == OAuthServerClientTypePublic && c.ClientSecretHash != "" { + return fmt.Errorf("client_secret is not allowed for public clients, use PKCE instead") + } + return nil } @@ -84,6 +112,16 @@ func (c *OAuthServerClient) SetGrantTypes(types []string) { c.GrantTypes = strings.Join(types, ",") } +// IsPublic returns true if the client is a public client +func (c *OAuthServerClient) IsPublic() bool { + return c.ClientType == OAuthServerClientTypePublic +} + +// IsConfidential returns true if the client is a confidential client +func (c *OAuthServerClient) IsConfidential() bool { + return c.ClientType == OAuthServerClientTypeConfidential +} + // validateRedirectURI validates a single redirect URI according to OAuth 2.1 spec func validateRedirectURI(uri string) error { if uri == "" { diff --git a/migrations/20250901200500_add_oauth_client_type.up.sql b/migrations/20250901200500_add_oauth_client_type.up.sql new file mode 100644 index 000000000..a9712a704 --- /dev/null +++ b/migrations/20250901200500_add_oauth_client_type.up.sql @@ -0,0 +1,14 @@ +-- Make client_secret_hash nullable to support public clients +-- Public clients don't have client secrets, only confidential clients do + +alter table {{ index .Options "Namespace" }}.oauth_clients alter column client_secret_hash drop not null; + +-- Add client_type enum and column to oauth_clients table +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_client_type as enum('public', 'confidential'); +exception + when duplicate_object then null; +end $$; + +-- Add client_type column to oauth_clients table +alter table {{ index .Options "Namespace" }}.oauth_clients add column if not exists client_type {{ index .Options "Namespace" }}.oauth_client_type not null default 'confidential'; From dc241103f858f6ffab8ac8d3b112f8ac5eded518 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Mon, 1 Sep 2025 21:49:49 +0200 Subject: [PATCH 2/4] fix: provide client_type param in oauth_client_test --- internal/models/oauth_client_test.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go index 233da83fb..937686d4b 100644 --- a/internal/models/oauth_client_test.go +++ b/internal/models/oauth_client_test.go @@ -40,11 +40,14 @@ func TestOAuthServerClient(t *testing.T) { func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { testClientName := "Test Client" + testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) validClient := &OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), ClientID: "test_client_id", ClientName: &testClientName, RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: string(testSecretHash), RedirectURIs: "https://example.com/callback", GrantTypes: "authorization_code,refresh_token", } @@ -143,11 +146,14 @@ func (ts *OAuthServerClientTestSuite) TestRedirectURIHelpers() { func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { testAppName := "Test Application" + testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) client := &OAuthServerClient{ ClientID: "test_client_create_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testAppName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: string(testSecretHash), RedirectURIs: "https://example.com/callback", } @@ -173,11 +179,14 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClientValidation() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { // Create a test client testName := "Find By ID Test" + testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) client := &OAuthServerClient{ ClientID: "test_client_find_by_id_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: string(testSecretHash), RedirectURIs: "https://example.com/callback", } @@ -199,11 +208,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { // Create a test client testName := "Find By Client ID Test" + testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) client := &OAuthServerClient{ ClientID: "test_client_find_by_client_id_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "manual", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: string(testSecretHash), RedirectURIs: "https://example.com/callback", } @@ -225,11 +237,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { // Create a test client originalName := "Original Name" + testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) client := &OAuthServerClient{ ClientID: "test_client_update_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &originalName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: string(testSecretHash), RedirectURIs: "https://example.com/callback", } @@ -274,11 +289,14 @@ func (ts *OAuthServerClientTestSuite) TestClientSecretHashing() { func (ts *OAuthServerClientTestSuite) TestSoftDelete() { // Create a test client testName := "Soft Delete Test" + testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) client := &OAuthServerClient{ ClientID: "test_client_soft_delete_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: string(testSecretHash), RedirectURIs: "https://example.com/callback", } From a97ffa052e329baa71ef3d62d6a0459a8b789b21 Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 3 Sep 2025 11:19:58 +0200 Subject: [PATCH 3/4] feat: use sha256 for client secrets --- internal/api/oauthserver/client_auth_test.go | 5 ++- internal/api/oauthserver/service.go | 33 +++++++++++++------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/internal/api/oauthserver/client_auth_test.go b/internal/api/oauthserver/client_auth_test.go index 20f0c0514..4e21c6650 100644 --- a/internal/api/oauthserver/client_auth_test.go +++ b/internal/api/oauthserver/client_auth_test.go @@ -5,7 +5,6 @@ import ( "github.com/gofrs/uuid" "github.com/supabase/auth/internal/models" - "golang.org/x/crypto/bcrypt" ) func TestInferClientTypeFromAuthMethod(t *testing.T) { @@ -305,12 +304,12 @@ func TestValidateClientAuthentication(t *testing.T) { } // Create a hashed secret for confidential client - secretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + secretHash, _ := hashClientSecret("test_secret") confidentialClient := &models.OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), ClientID: "confidential_client", ClientType: models.OAuthServerClientTypeConfidential, - ClientSecretHash: string(secretHash), + ClientSecretHash: secretHash, } tests := []struct { diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index efb1d4bb6..95a8e8b46 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -2,6 +2,10 @@ package oauthserver import ( "context" + "crypto/rand" + "crypto/sha256" + "crypto/subtle" + "encoding/base64" "fmt" "net/url" "slices" @@ -12,7 +16,6 @@ import ( "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/utilities" - "golang.org/x/crypto/bcrypt" ) // OAuthServerClientRegisterParams contains parameters for registering a new OAuth client @@ -146,23 +149,29 @@ func generateClientID() string { // generateClientSecret generates a secure random client secret func generateClientSecret() string { - // Generate a 64-character secure random secret - return crypto.SecureAlphanumeric(64) + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + // This should never happen, but fallback to panic for security + panic(fmt.Sprintf("failed to generate random bytes for client secret: %v", err)) + } + return base64.RawURLEncoding.EncodeToString(b) } -// hashClientSecret hashes a client secret using bcrypt +// hashClientSecret hashes a client secret using SHA-256 func hashClientSecret(secret string) (string, error) { - hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) + sum := sha256.Sum256([]byte(secret)) + return base64.RawURLEncoding.EncodeToString(sum[:]), nil +} + +// ValidateClientSecret validates a client secret against its hash using constant-time comparison +func ValidateClientSecret(providedSecret, storedHash string) bool { + calc := sha256.Sum256([]byte(providedSecret)) + stored, err := base64.RawURLEncoding.DecodeString(storedHash) if err != nil { - return "", errors.Wrap(err, "failed to hash client secret") + return false } - return string(hash), nil -} -// ValidateClientSecret validates a client secret against its hash -func ValidateClientSecret(secret, hash string) bool { - err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(secret)) - return err == nil + return subtle.ConstantTimeCompare(calc[:], stored) == 1 } // registerOAuthServerClient creates a new OAuth server client with generated credentials From 5213f87c504d855343806e72352ffc5312d90f2a Mon Sep 17 00:00:00 2001 From: Cemal Kilic Date: Wed, 3 Sep 2025 12:00:20 +0200 Subject: [PATCH 4/4] fix: use sha256 in the tests --- internal/models/oauth_client_test.go | 47 +++++++++++++++++----------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go index 937686d4b..7f21f4630 100644 --- a/internal/models/oauth_client_test.go +++ b/internal/models/oauth_client_test.go @@ -1,6 +1,8 @@ package models import ( + "crypto/sha256" + "encoding/base64" "testing" "time" @@ -11,7 +13,6 @@ import ( "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/storage/test" - "golang.org/x/crypto/bcrypt" ) type OAuthServerClientTestSuite struct { @@ -19,6 +20,12 @@ type OAuthServerClientTestSuite struct { db *storage.Connection } +// testHashClientSecret is a test helper that hashes a client secret using the same method as the service +func testHashClientSecret(secret string) (string, error) { + sum := sha256.Sum256([]byte(secret)) + return base64.RawURLEncoding.EncodeToString(sum[:]), nil +} + func (ts *OAuthServerClientTestSuite) SetupTest() { _ = TruncateAll(ts.db) } @@ -40,14 +47,14 @@ func TestOAuthServerClient(t *testing.T) { func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { testClientName := "Test Client" - testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + testSecretHash, _ := testHashClientSecret("test_secret") validClient := &OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), ClientID: "test_client_id", ClientName: &testClientName, RegistrationType: "dynamic", ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: string(testSecretHash), + ClientSecretHash: testSecretHash, RedirectURIs: "https://example.com/callback", GrantTypes: "authorization_code,refresh_token", } @@ -146,14 +153,14 @@ func (ts *OAuthServerClientTestSuite) TestRedirectURIHelpers() { func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { testAppName := "Test Application" - testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ ClientID: "test_client_create_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testAppName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: string(testSecretHash), + ClientSecretHash: testSecretHash, RedirectURIs: "https://example.com/callback", } @@ -179,14 +186,14 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClientValidation() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { // Create a test client testName := "Find By ID Test" - testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ ClientID: "test_client_find_by_id_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: string(testSecretHash), + ClientSecretHash: testSecretHash, RedirectURIs: "https://example.com/callback", } @@ -208,14 +215,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { // Create a test client testName := "Find By Client ID Test" - testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ ClientID: "test_client_find_by_client_id_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "manual", ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: string(testSecretHash), + ClientSecretHash: testSecretHash, RedirectURIs: "https://example.com/callback", } @@ -237,14 +244,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { // Create a test client originalName := "Original Name" - testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ ClientID: "test_client_update_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &originalName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: string(testSecretHash), + ClientSecretHash: testSecretHash, RedirectURIs: "https://example.com/callback", } @@ -274,29 +281,31 @@ func (ts *OAuthServerClientTestSuite) TestClientSecretHashing() { // Test that secrets can be properly hashed and validated secret := "super_secret_client_secret" - hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) + hash, err := testHashClientSecret(secret) require.NoError(ts.T(), err) - // Test correct secret validates - err = bcrypt.CompareHashAndPassword(hash, []byte(secret)) - assert.NoError(ts.T(), err) + // Test correct secret validates - hash the provided secret and compare + calc := sha256.Sum256([]byte(secret)) + stored, err := base64.RawURLEncoding.DecodeString(hash) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), calc[:], stored) // Test incorrect secret fails - err = bcrypt.CompareHashAndPassword(hash, []byte("wrong_secret")) - assert.Error(ts.T(), err) + wrongCalc := sha256.Sum256([]byte("wrong_secret")) + assert.NotEqual(ts.T(), wrongCalc[:], stored) } func (ts *OAuthServerClientTestSuite) TestSoftDelete() { // Create a test client testName := "Soft Delete Test" - testSecretHash, _ := bcrypt.GenerateFromPassword([]byte("test_secret"), bcrypt.DefaultCost) + testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ ClientID: "test_client_soft_delete_" + uuid.Must(uuid.NewV4()).String()[:8], ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: string(testSecretHash), + ClientSecretHash: testSecretHash, RedirectURIs: "https://example.com/callback", }