diff --git a/internal/api/middleware.go b/internal/api/middleware.go index b14d1202f..875b3cada 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -13,6 +13,7 @@ import ( "time" chimiddleware "github.com/go-chi/chi/v5/middleware" + "github.com/gofrs/uuid" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/oauthserver" @@ -98,9 +99,15 @@ func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.C return ctx, nil } + // Parse client_id as UUID + clientUUID, err := uuid.FromString(clientID) + if err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client_id format") + } + // Validate client credentials db := a.db.WithContext(ctx) - client, err := models.FindOAuthServerClientByClientID(db, clientID) + client, err := models.FindOAuthServerClientByID(db, clientUUID) if err != nil { if models.IsNotFoundError(err) { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials") diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 599a710ac..565e94bad 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" @@ -103,7 +104,13 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er return err } - client, err := s.getOAuthServerClient(ctx, params.ClientID) + // Parse client_id as UUID + clientID, err := uuid.FromString(params.ClientID) + if err != nil { + return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id format") + } + + client, err := s.getOAuthServerClient(ctx, clientID) if err != nil { if models.IsNotFoundError(err) { return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id") @@ -144,7 +151,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er } observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) - observability.LogEntrySetField(r, "client_id", client.ClientID) + observability.LogEntrySetField(r, "client_id", client.ID.String()) // Redirect to authorization path with authorization_id if config.OAuthServer.AuthorizationPath == "" { @@ -228,7 +235,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ response := AuthorizationDetailsResponse{ AuthorizationID: authorization.AuthorizationID, Client: ClientDetailsResponse{ - ClientID: authorization.Client.ClientID, + ClientID: authorization.Client.ID.String(), ClientName: utilities.StringValue(authorization.Client.ClientName), ClientURI: utilities.StringValue(authorization.Client.ClientURI), LogoURI: utilities.StringValue(authorization.Client.LogoURI), @@ -241,7 +248,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ } observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) - observability.LogEntrySetField(r, "client_id", authorization.Client.ClientID) + observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String()) return shared.SendJSON(w, http.StatusOK, response) } diff --git a/internal/api/oauthserver/client_auth_test.go b/internal/api/oauthserver/client_auth_test.go index 4e21c6650..40fec9c40 100644 --- a/internal/api/oauthserver/client_auth_test.go +++ b/internal/api/oauthserver/client_auth_test.go @@ -298,7 +298,6 @@ func TestValidateClientAuthentication(t *testing.T) { // Create test clients publicClient := &models.OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), - ClientID: "public_client", ClientType: models.OAuthServerClientTypePublic, // No client secret hash for public clients } @@ -307,7 +306,6 @@ func TestValidateClientAuthentication(t *testing.T) { secretHash, _ := hashClientSecret("test_secret") confidentialClient := &models.OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), - ClientID: "confidential_client", ClientType: models.OAuthServerClientTypeConfidential, ClientSecretHash: secretHash, } diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index f6edc1e31..efacfd423 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -8,6 +8,7 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/gofrs/uuid" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/models" @@ -53,7 +54,7 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret } response := &OAuthServerClientResponse{ - ClientID: client.ClientID, + ClientID: client.ID.String(), ClientType: client.ClientType, // OAuth 2.1 DCR fields @@ -83,13 +84,19 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret // LoadOAuthServerClient is middleware that loads an OAuth server client from the URL parameter func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) (context.Context, error) { ctx := r.Context() - clientID := chi.URLParam(r, "client_id") + clientIDStr := chi.URLParam(r, "client_id") - if clientID == "" { + if clientIDStr == "" { return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required") } - observability.LogEntrySetField(r, "oauth_client_id", clientID) + // Parse client_id as UUID + clientID, err := uuid.FromString(clientIDStr) + if err != nil { + return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid client_id format") + } + + observability.LogEntrySetField(r, "oauth_client_id", clientIDStr) client, err := s.getOAuthServerClient(ctx, clientID) if err != nil { @@ -171,7 +178,7 @@ func (s *Server) OAuthServerClientDelete(w http.ResponseWriter, r *http.Request) ctx := r.Context() client := GetOAuthServerClient(ctx) - if err := s.deleteOAuthServerClient(ctx, client.ClientID); err != nil { + if err := s.deleteOAuthServerClient(ctx, client.ID); err != nil { return apierrors.NewInternalServerError("Error deleting OAuth client").WithInternalError(err) } diff --git a/internal/api/oauthserver/handlers_test.go b/internal/api/oauthserver/handlers_test.go index 0d74eec83..e75a68a97 100644 --- a/internal/api/oauthserver/handlers_test.go +++ b/internal/api/oauthserver/handlers_test.go @@ -167,7 +167,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDynamicRegisterDisabled() { func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() { client, _ := ts.createTestOAuthClient() - req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ClientID, nil) + req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ID.String(), nil) ctx := WithOAuthServerClient(req.Context(), client) req = req.WithContext(ctx) @@ -183,7 +183,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() { err = json.Unmarshal(w.Body.Bytes(), &response) require.NoError(ts.T(), err) - assert.Equal(ts.T(), client.ClientID, response.ClientID) + assert.Equal(ts.T(), client.ID.String(), response.ClientID) assert.Empty(ts.T(), response.ClientSecret) // Should NOT be included in get response assert.Equal(ts.T(), "Test Client", response.ClientName) } @@ -193,7 +193,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() { client, _ := ts.createTestOAuthClient() // Create HTTP request with client in context - req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ClientID, nil) + req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ID.String(), nil) // Add client to context (normally done by LoadOAuthServerClient middleware) ctx := WithOAuthServerClient(req.Context(), client) @@ -208,7 +208,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() { assert.Empty(ts.T(), w.Body.String()) // Verify client was soft-deleted - deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ClientID) + deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ID) assert.Error(ts.T(), err) // it was soft-deleted assert.Nil(ts.T(), deletedClient) } @@ -235,8 +235,8 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() { // Check that both clients are in the response (order might vary) clientIDs := []string{response.Clients[0].ClientID, response.Clients[1].ClientID} - assert.Contains(ts.T(), clientIDs, client1.ClientID) - assert.Contains(ts.T(), clientIDs, client2.ClientID) + assert.Contains(ts.T(), clientIDs, client1.ID.String()) + assert.Contains(ts.T(), clientIDs, client2.ID.String()) // Verify client secrets are not included in list response for _, client := range response.Clients { diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index 95a8e8b46..f940d5f01 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -11,9 +11,9 @@ import ( "slices" "time" + "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/supabase/auth/internal/api/apierrors" - "github.com/supabase/auth/internal/crypto" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/utilities" ) @@ -141,11 +141,6 @@ func validateRedirectURI(uri string) error { return nil } -// generateClientID generates a URL-safe random client ID -func generateClientID() string { - // Generate a 32-character alphanumeric client ID - return crypto.SecureAlphanumeric(32) -} // generateClientSecret generates a secure random client secret func generateClientSecret() string { @@ -193,7 +188,7 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer db := s.db.WithContext(ctx) client := &models.OAuthServerClient{ - ClientID: generateClientID(), + ID: uuid.Must(uuid.NewV4()), RegistrationType: params.RegistrationType, ClientType: clientType, ClientName: utilities.StringPtr(params.ClientName), @@ -222,11 +217,11 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer return client, plaintextSecret, nil } -// getOAuthServerClient retrieves an OAuth client by client_id -func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*models.OAuthServerClient, error) { +// getOAuthServerClient retrieves an OAuth client by ID +func (s *Server) getOAuthServerClient(ctx context.Context, clientID uuid.UUID) (*models.OAuthServerClient, error) { db := s.db.WithContext(ctx) - client, err := models.FindOAuthServerClientByClientID(db, clientID) + client, err := models.FindOAuthServerClientByID(db, clientID) if err != nil { return nil, err } @@ -235,10 +230,10 @@ func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*mo } // deleteOAuthServerClient soft-deletes an OAuth client -func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID string) error { +func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID uuid.UUID) error { db := s.db.WithContext(ctx) - client, err := models.FindOAuthServerClientByClientID(db, clientID) + client, err := models.FindOAuthServerClientByID(db, clientID) if err != nil { return err } diff --git a/internal/api/oauthserver/service_test.go b/internal/api/oauthserver/service_test.go index 60798a49e..cbad9d692 100644 --- a/internal/api/oauthserver/service_test.go +++ b/internal/api/oauthserver/service_test.go @@ -92,9 +92,9 @@ func (ts *OAuthServiceTestSuite) TestOAuthServerClientServiceMethods() { assert.Equal(ts.T(), "dynamic", client.RegistrationType) // Test getOAuthServerClient - retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID) + retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ID) require.NoError(ts.T(), err) - assert.Equal(ts.T(), client.ClientID, retrievedClient.ClientID) + assert.Equal(ts.T(), client.ID, retrievedClient.ID) } @@ -133,11 +133,11 @@ func (ts *OAuthServiceTestSuite) TestDeleteOAuthServerClient() { // Delete the client ctx := context.Background() - err := ts.Server.deleteOAuthServerClient(ctx, client.ClientID) + err := ts.Server.deleteOAuthServerClient(ctx, client.ID) require.NoError(ts.T(), err) // Verify client was soft-deleted - deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID) + deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ID) assert.Error(ts.T(), err) // it was soft-deleted assert.Nil(ts.T(), deletedClient) } diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index bdb3bdbd3..71ffa833a 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -28,8 +28,7 @@ const ( // OAuthServerClient represents an OAuth client application registered with this OAuth server type OAuthServerClient struct { - ID uuid.UUID `json:"-" db:"id"` - ClientID string `json:"client_id" db:"client_id"` + ID uuid.UUID `json:"client_id" db:"id"` ClientSecretHash string `json:"-" db:"client_secret_hash"` RegistrationType string `json:"registration_type" db:"registration_type"` ClientType string `json:"client_type" db:"client_type"` @@ -57,8 +56,8 @@ func (c *OAuthServerClient) BeforeSave(tx *pop.Connection) error { // Validate performs basic validation on the OAuth client func (c *OAuthServerClient) Validate() error { - if c.ClientID == "" { - return fmt.Errorf("client_id is required") + if c.ID == uuid.Nil { + return fmt.Errorf("id is required") } if c.RegistrationType != "dynamic" && c.RegistrationType != "manual" { @@ -182,17 +181,6 @@ func FindOAuthServerClientByID(tx *storage.Connection, id uuid.UUID) (*OAuthServ return client, nil } -// FindOAuthServerClientByClientID finds an OAuth client by client_id -func FindOAuthServerClientByClientID(tx *storage.Connection, clientID string) (*OAuthServerClient, error) { - client := &OAuthServerClient{} - if err := tx.Q().Where("client_id = ? AND deleted_at IS NULL", clientID).First(client); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return nil, OAuthServerClientNotFoundError{} - } - return nil, errors.Wrap(err, "error finding OAuth client") - } - return client, nil -} // CreateOAuthServerClient creates a new OAuth client in the database func CreateOAuthServerClient(tx *storage.Connection, client *OAuthServerClient) error { @@ -200,10 +188,6 @@ func CreateOAuthServerClient(tx *storage.Connection, client *OAuthServerClient) return err } - if client.ID == uuid.Nil { - client.ID = uuid.Must(uuid.NewV4()) - } - now := time.Now() client.CreatedAt = now client.UpdatedAt = now diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go index 7f21f4630..6cee43611 100644 --- a/internal/models/oauth_client_test.go +++ b/internal/models/oauth_client_test.go @@ -50,7 +50,6 @@ func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { testSecretHash, _ := testHashClientSecret("test_secret") validClient := &OAuthServerClient{ ID: uuid.Must(uuid.NewV4()), - ClientID: "test_client_id", ClientName: &testClientName, RegistrationType: "dynamic", ClientType: OAuthServerClientTypeConfidential, @@ -63,19 +62,12 @@ func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { err := validClient.Validate() assert.NoError(ts.T(), err) - // Test missing client_id + // Test missing id invalidClient := *validClient - invalidClient.ClientID = "" + invalidClient.ID = uuid.Nil err = invalidClient.Validate() assert.Error(ts.T(), err) - assert.Contains(ts.T(), err.Error(), "client_id is required") - - // Test missing client_id - invalidClient = *validClient - invalidClient.ClientID = "" - err = invalidClient.Validate() - assert.Error(ts.T(), err) - assert.Contains(ts.T(), err.Error(), "client_id is required") + assert.Contains(ts.T(), err.Error(), "id is required") // Test invalid registration type invalidClient = *validClient @@ -155,7 +147,7 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { testAppName := "Test Application" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ClientID: "test_client_create_" + uuid.Must(uuid.NewV4()).String()[:8], + ID: uuid.Must(uuid.NewV4()), ClientName: &testAppName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", @@ -175,12 +167,13 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClientValidation() { invalidClient := &OAuthServerClient{ - ClientID: "", // Missing required field + ID: uuid.Must(uuid.NewV4()), // Provide ID so validation gets to other fields + // Missing required fields like RegistrationType } err := CreateOAuthServerClient(ts.db, invalidClient) assert.Error(ts.T(), err) - assert.Contains(ts.T(), err.Error(), "client_id is required") + assert.Contains(ts.T(), err.Error(), "registration_type must be 'dynamic' or 'manual'") } func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { @@ -188,7 +181,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { testName := "Find By ID Test" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ClientID: "test_client_find_by_id_" + uuid.Must(uuid.NewV4()).String()[:8], + ID: uuid.Must(uuid.NewV4()), ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", @@ -203,7 +196,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { // Find by ID foundClient, err := FindOAuthServerClientByID(ts.db, client.ID) require.NoError(ts.T(), err) - assert.Equal(ts.T(), client.ClientID, foundClient.ClientID) + assert.Equal(ts.T(), client.ID, foundClient.ID) assert.Equal(ts.T(), *client.ClientName, *foundClient.ClientName) // Test not found @@ -217,7 +210,7 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { testName := "Find By Client ID Test" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ClientID: "test_client_find_by_client_id_" + uuid.Must(uuid.NewV4()).String()[:8], + ID: uuid.Must(uuid.NewV4()), ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "manual", @@ -229,14 +222,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { err := CreateOAuthServerClient(ts.db, client) require.NoError(ts.T(), err) - // Find by client_id - foundClient, err := FindOAuthServerClientByClientID(ts.db, client.ClientID) + // Find by ID (which is now the client_id) + foundClient, err := FindOAuthServerClientByID(ts.db, client.ID) require.NoError(ts.T(), err) assert.Equal(ts.T(), client.ID, foundClient.ID) assert.Equal(ts.T(), *client.ClientName, *foundClient.ClientName) // Test not found - _, err = FindOAuthServerClientByClientID(ts.db, "nonexistent_client_id") + _, err = FindOAuthServerClientByID(ts.db, uuid.Must(uuid.NewV4())) assert.Error(ts.T(), err) assert.True(ts.T(), IsNotFoundError(err)) } @@ -246,7 +239,7 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { originalName := "Original Name" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ClientID: "test_client_update_" + uuid.Must(uuid.NewV4()).String()[:8], + ID: uuid.Must(uuid.NewV4()), ClientName: &originalName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", @@ -300,7 +293,7 @@ func (ts *OAuthServerClientTestSuite) TestSoftDelete() { testName := "Soft Delete Test" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ClientID: "test_client_soft_delete_" + uuid.Must(uuid.NewV4()).String()[:8], + ID: uuid.Must(uuid.NewV4()), ClientName: &testName, GrantTypes: "authorization_code,refresh_token", RegistrationType: "dynamic", @@ -320,7 +313,7 @@ func (ts *OAuthServerClientTestSuite) TestSoftDelete() { require.NoError(ts.T(), err) // Verify client is not found in normal queries (which filter out deleted) - _, err = FindOAuthServerClientByClientID(ts.db, client.ClientID) + _, err = FindOAuthServerClientByID(ts.db, client.ID) assert.Error(ts.T(), err) assert.True(ts.T(), IsNotFoundError(err)) } diff --git a/migrations/20250903112500_remove_oauth_client_id_column.up.sql b/migrations/20250903112500_remove_oauth_client_id_column.up.sql new file mode 100644 index 000000000..273ff9e25 --- /dev/null +++ b/migrations/20250903112500_remove_oauth_client_id_column.up.sql @@ -0,0 +1,13 @@ +-- Drop the client_id column and related constraints/indexes from oauth_clients table +-- The id (uuid) field will serve as the public client_id + +-- Drop the unique constraint on client_id +alter table {{ index .Options "Namespace" }}.oauth_clients + drop constraint if exists oauth_clients_client_id_key; + +-- Drop the index on client_id +drop index if exists {{ index .Options "Namespace" }}.oauth_clients_client_id_idx; + +-- Drop the client_id column +alter table {{ index .Options "Namespace" }}.oauth_clients + drop column if exists client_id;