Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/gofrs/uuid"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/oauthserver"
Expand Down Expand Up @@ -98,9 +99,15 @@ func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.C
return ctx, nil
}

// Parse client_id as UUID
clientUUID, err := uuid.FromString(clientID)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client_id format")
}

// Validate client credentials
db := a.db.WithContext(ctx)
client, err := models.FindOAuthServerClientByClientID(db, clientID)
client, err := models.FindOAuthServerClientByID(db, clientUUID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")
Expand Down
15 changes: 11 additions & 4 deletions internal/api/oauthserver/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/go-chi/chi/v5"
"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
Expand Down Expand Up @@ -103,7 +104,13 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
return err
}

client, err := s.getOAuthServerClient(ctx, params.ClientID)
// Parse client_id as UUID
clientID, err := uuid.FromString(params.ClientID)
if err != nil {
return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id format")
}

client, err := s.getOAuthServerClient(ctx, clientID)
if err != nil {
if models.IsNotFoundError(err) {
return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id")
Expand Down Expand Up @@ -144,7 +151,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
}

observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
observability.LogEntrySetField(r, "client_id", client.ClientID)
observability.LogEntrySetField(r, "client_id", client.ID.String())

// Redirect to authorization path with authorization_id
if config.OAuthServer.AuthorizationPath == "" {
Expand Down Expand Up @@ -228,7 +235,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
response := AuthorizationDetailsResponse{
AuthorizationID: authorization.AuthorizationID,
Client: ClientDetailsResponse{
ClientID: authorization.Client.ClientID,
ClientID: authorization.Client.ID.String(),
ClientName: utilities.StringValue(authorization.Client.ClientName),
ClientURI: utilities.StringValue(authorization.Client.ClientURI),
LogoURI: utilities.StringValue(authorization.Client.LogoURI),
Expand All @@ -241,7 +248,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
}

observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
observability.LogEntrySetField(r, "client_id", authorization.Client.ClientID)
observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String())

return shared.SendJSON(w, http.StatusOK, response)
}
Expand Down
2 changes: 0 additions & 2 deletions internal/api/oauthserver/client_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ func TestValidateClientAuthentication(t *testing.T) {
// Create test clients
publicClient := &models.OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientID: "public_client",
ClientType: models.OAuthServerClientTypePublic,
// No client secret hash for public clients
}
Expand All @@ -307,7 +306,6 @@ func TestValidateClientAuthentication(t *testing.T) {
secretHash, _ := hashClientSecret("test_secret")
confidentialClient := &models.OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientID: "confidential_client",
ClientType: models.OAuthServerClientTypeConfidential,
ClientSecretHash: secretHash,
}
Expand Down
17 changes: 12 additions & 5 deletions internal/api/oauthserver/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/go-chi/chi/v5"
"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
Expand Down Expand Up @@ -53,7 +54,7 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret
}

response := &OAuthServerClientResponse{
ClientID: client.ClientID,
ClientID: client.ID.String(),
ClientType: client.ClientType,

// OAuth 2.1 DCR fields
Expand Down Expand Up @@ -83,13 +84,19 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret
// LoadOAuthServerClient is middleware that loads an OAuth server client from the URL parameter
func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) (context.Context, error) {
ctx := r.Context()
clientID := chi.URLParam(r, "client_id")
clientIDStr := chi.URLParam(r, "client_id")

if clientID == "" {
if clientIDStr == "" {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required")
}

observability.LogEntrySetField(r, "oauth_client_id", clientID)
// Parse client_id as UUID
clientID, err := uuid.FromString(clientIDStr)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid client_id format")
}

observability.LogEntrySetField(r, "oauth_client_id", clientIDStr)

client, err := s.getOAuthServerClient(ctx, clientID)
if err != nil {
Expand Down Expand Up @@ -171,7 +178,7 @@ func (s *Server) OAuthServerClientDelete(w http.ResponseWriter, r *http.Request)
ctx := r.Context()
client := GetOAuthServerClient(ctx)

if err := s.deleteOAuthServerClient(ctx, client.ClientID); err != nil {
if err := s.deleteOAuthServerClient(ctx, client.ID); err != nil {
return apierrors.NewInternalServerError("Error deleting OAuth client").WithInternalError(err)
}

Expand Down
12 changes: 6 additions & 6 deletions internal/api/oauthserver/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDynamicRegisterDisabled() {
func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() {
client, _ := ts.createTestOAuthClient()

req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ClientID, nil)
req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ID.String(), nil)

ctx := WithOAuthServerClient(req.Context(), client)
req = req.WithContext(ctx)
Expand All @@ -183,7 +183,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() {
err = json.Unmarshal(w.Body.Bytes(), &response)
require.NoError(ts.T(), err)

assert.Equal(ts.T(), client.ClientID, response.ClientID)
assert.Equal(ts.T(), client.ID.String(), response.ClientID)
assert.Empty(ts.T(), response.ClientSecret) // Should NOT be included in get response
assert.Equal(ts.T(), "Test Client", response.ClientName)
}
Expand All @@ -193,7 +193,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() {
client, _ := ts.createTestOAuthClient()

// Create HTTP request with client in context
req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ClientID, nil)
req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ID.String(), nil)

// Add client to context (normally done by LoadOAuthServerClient middleware)
ctx := WithOAuthServerClient(req.Context(), client)
Expand All @@ -208,7 +208,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() {
assert.Empty(ts.T(), w.Body.String())

// Verify client was soft-deleted
deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ClientID)
deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ID)
assert.Error(ts.T(), err) // it was soft-deleted
assert.Nil(ts.T(), deletedClient)
}
Expand All @@ -235,8 +235,8 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() {

// Check that both clients are in the response (order might vary)
clientIDs := []string{response.Clients[0].ClientID, response.Clients[1].ClientID}
assert.Contains(ts.T(), clientIDs, client1.ClientID)
assert.Contains(ts.T(), clientIDs, client2.ClientID)
assert.Contains(ts.T(), clientIDs, client1.ID.String())
assert.Contains(ts.T(), clientIDs, client2.ID.String())

// Verify client secrets are not included in list response
for _, client := range response.Clients {
Expand Down
19 changes: 7 additions & 12 deletions internal/api/oauthserver/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"slices"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/crypto"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/utilities"
)
Expand Down Expand Up @@ -141,11 +141,6 @@ func validateRedirectURI(uri string) error {
return nil
}

// generateClientID generates a URL-safe random client ID
func generateClientID() string {
// Generate a 32-character alphanumeric client ID
return crypto.SecureAlphanumeric(32)
}

// generateClientSecret generates a secure random client secret
func generateClientSecret() string {
Expand Down Expand Up @@ -193,7 +188,7 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer
db := s.db.WithContext(ctx)

client := &models.OAuthServerClient{
ClientID: generateClientID(),
ID: uuid.Must(uuid.NewV4()),
RegistrationType: params.RegistrationType,
ClientType: clientType,
ClientName: utilities.StringPtr(params.ClientName),
Expand Down Expand Up @@ -222,11 +217,11 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer
return client, plaintextSecret, nil
}

// getOAuthServerClient retrieves an OAuth client by client_id
func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*models.OAuthServerClient, error) {
// getOAuthServerClient retrieves an OAuth client by ID
func (s *Server) getOAuthServerClient(ctx context.Context, clientID uuid.UUID) (*models.OAuthServerClient, error) {
db := s.db.WithContext(ctx)

client, err := models.FindOAuthServerClientByClientID(db, clientID)
client, err := models.FindOAuthServerClientByID(db, clientID)
if err != nil {
return nil, err
}
Expand All @@ -235,10 +230,10 @@ func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*mo
}

// deleteOAuthServerClient soft-deletes an OAuth client
func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID string) error {
func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID uuid.UUID) error {
db := s.db.WithContext(ctx)

client, err := models.FindOAuthServerClientByClientID(db, clientID)
client, err := models.FindOAuthServerClientByID(db, clientID)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions internal/api/oauthserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (ts *OAuthServiceTestSuite) TestOAuthServerClientServiceMethods() {
assert.Equal(ts.T(), "dynamic", client.RegistrationType)

// Test getOAuthServerClient
retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID)
retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ID)
require.NoError(ts.T(), err)
assert.Equal(ts.T(), client.ClientID, retrievedClient.ClientID)
assert.Equal(ts.T(), client.ID, retrievedClient.ID)

}

Expand Down Expand Up @@ -133,11 +133,11 @@ func (ts *OAuthServiceTestSuite) TestDeleteOAuthServerClient() {

// Delete the client
ctx := context.Background()
err := ts.Server.deleteOAuthServerClient(ctx, client.ClientID)
err := ts.Server.deleteOAuthServerClient(ctx, client.ID)
require.NoError(ts.T(), err)

// Verify client was soft-deleted
deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID)
deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ID)
assert.Error(ts.T(), err) // it was soft-deleted
assert.Nil(ts.T(), deletedClient)
}
Expand Down
22 changes: 3 additions & 19 deletions internal/models/oauth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ const (

// OAuthServerClient represents an OAuth client application registered with this OAuth server
type OAuthServerClient struct {
ID uuid.UUID `json:"-" db:"id"`
ClientID string `json:"client_id" db:"client_id"`
ID uuid.UUID `json:"client_id" db:"id"`
ClientSecretHash string `json:"-" db:"client_secret_hash"`
RegistrationType string `json:"registration_type" db:"registration_type"`
ClientType string `json:"client_type" db:"client_type"`
Expand Down Expand Up @@ -57,8 +56,8 @@ func (c *OAuthServerClient) BeforeSave(tx *pop.Connection) error {

// Validate performs basic validation on the OAuth client
func (c *OAuthServerClient) Validate() error {
if c.ClientID == "" {
return fmt.Errorf("client_id is required")
if c.ID == uuid.Nil {
return fmt.Errorf("id is required")
}

if c.RegistrationType != "dynamic" && c.RegistrationType != "manual" {
Expand Down Expand Up @@ -182,28 +181,13 @@ func FindOAuthServerClientByID(tx *storage.Connection, id uuid.UUID) (*OAuthServ
return client, nil
}

// FindOAuthServerClientByClientID finds an OAuth client by client_id
func FindOAuthServerClientByClientID(tx *storage.Connection, clientID string) (*OAuthServerClient, error) {
client := &OAuthServerClient{}
if err := tx.Q().Where("client_id = ? AND deleted_at IS NULL", clientID).First(client); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return nil, OAuthServerClientNotFoundError{}
}
return nil, errors.Wrap(err, "error finding OAuth client")
}
return client, nil
}

// CreateOAuthServerClient creates a new OAuth client in the database
func CreateOAuthServerClient(tx *storage.Connection, client *OAuthServerClient) error {
if err := client.Validate(); err != nil {
return err
}

if client.ID == uuid.Nil {
client.ID = uuid.Must(uuid.NewV4())
}

now := time.Now()
client.CreatedAt = now
client.UpdatedAt = now
Expand Down
Loading