Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.C
return nil, apierrors.NewInternalServerError("Error validating client credentials").WithInternalError(err)
}

// Validate client secret
if !oauthserver.ValidateClientSecret(clientSecret, client.ClientSecretHash) {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")
// Validate authentication using centralized logic
if err := oauthserver.ValidateClientAuthentication(client, clientSecret); err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, err.Error())
}

// Add authenticated client to context
Expand Down
9 changes: 6 additions & 3 deletions internal/api/oauthserver/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ func ExtractClientCredentials(r *http.Request) (clientID, clientSecret string, e
return "", "", nil
}

// If only one is provided, it's an error
if clientID == "" || clientSecret == "" {
return "", "", errors.New("both client_id and client_secret must be provided")
// For public clients, only client_id is required (client_secret should be empty)
// For confidential clients, both client_id and client_secret are required
// We'll validate this based on the client type in the calling handler
// TODO(cemal) :: this will be validated in detail during the `/token` endpoint implementation
if clientID == "" {
return "", "", errors.New("client_id is required")
}

return clientID, clientSecret, nil
Expand Down
110 changes: 110 additions & 0 deletions internal/api/oauthserver/client_auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package oauthserver

import (
"fmt"

"github.com/supabase/auth/internal/models"
)

// InferClientTypeFromAuthMethod infers client type from token_endpoint_auth_method
func InferClientTypeFromAuthMethod(authMethod string) string {
switch authMethod {
case models.TokenEndpointAuthMethodNone:
return models.OAuthServerClientTypePublic
case models.TokenEndpointAuthMethodClientSecretBasic, models.TokenEndpointAuthMethodClientSecretPost:
return models.OAuthServerClientTypeConfidential
default:
return models.OAuthServerClientTypeConfidential // Default to confidential
}
}

// GetValidAuthMethodsForClientType returns the valid authentication methods for a client type
func GetValidAuthMethodsForClientType(clientType string) []string {
switch clientType {
case models.OAuthServerClientTypePublic:
return []string{models.TokenEndpointAuthMethodNone}
case models.OAuthServerClientTypeConfidential:
return []string{
models.TokenEndpointAuthMethodClientSecretBasic,
models.TokenEndpointAuthMethodClientSecretPost,
}
default:
return []string{} // Unknown client type
}
}

// ValidateClientTypeConsistency validates consistency between client_type and token_endpoint_auth_method
func ValidateClientTypeConsistency(clientType, authMethod string) error {
if clientType == "" || authMethod == "" {
return nil // Skip validation if either is not provided
}

expectedClientType := InferClientTypeFromAuthMethod(authMethod)
if clientType != expectedClientType {
return fmt.Errorf("client_type '%s' is inconsistent with token_endpoint_auth_method '%s' (expected client_type '%s')",
clientType, authMethod, expectedClientType)
}

return nil
}

// IsValidAuthMethodForClientType checks if the auth method is valid for the given client type
func IsValidAuthMethodForClientType(clientType, authMethod string) bool {
validMethods := GetValidAuthMethodsForClientType(clientType)
for _, method := range validMethods {
if method == authMethod {
return true
}
}
return false
}

// DetermineClientType determines the final client type using the priority:
// 1. Explicit client_type
// 2. Inferred from token_endpoint_auth_method
// 3. Default to confidential
func DetermineClientType(explicitClientType, authMethod string) string {
// Priority 1: Explicit client_type
if explicitClientType != "" {
return explicitClientType
}

// Priority 2: Infer from token_endpoint_auth_method
if authMethod != "" {
return InferClientTypeFromAuthMethod(authMethod)
}

// Priority 3: Default to confidential
return models.OAuthServerClientTypeConfidential
}

// ValidateClientAuthentication validates client authentication based on client type
func ValidateClientAuthentication(client *models.OAuthServerClient, providedSecret string) error {
if client.IsPublic() {
// Public clients should not provide client secrets
if providedSecret != "" {
return fmt.Errorf("public clients must not provide client_secret")
}
return nil
}

// Confidential clients must provide a valid client secret
if providedSecret == "" {
return fmt.Errorf("confidential clients must provide client_secret")
}

if !ValidateClientSecret(providedSecret, client.ClientSecretHash) {
return fmt.Errorf("invalid client credentials")
}

return nil
}

// GetAllValidAuthMethods returns all supported authentication methods
func GetAllValidAuthMethods() []string {
return []string{
models.TokenEndpointAuthMethodNone,
models.TokenEndpointAuthMethodClientSecretBasic,
models.TokenEndpointAuthMethodClientSecretPost,
}
}
Loading
Loading