Skip to content

Commit 86b7de4

Browse files
authored
feat(oauth2): use id field as the public client_id (#2154)
## Summary Simplified OAuth client model by removing redundant `client_id` column and using the primary `id` UUID field as the public client identifier.
1 parent 0fd4bb4 commit 86b7de4

File tree

10 files changed

+80
-76
lines changed

10 files changed

+80
-76
lines changed

internal/api/middleware.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
chimiddleware "github.com/go-chi/chi/v5/middleware"
16+
"github.com/gofrs/uuid"
1617
"github.com/sirupsen/logrus"
1718
"github.com/supabase/auth/internal/api/apierrors"
1819
"github.com/supabase/auth/internal/api/oauthserver"
@@ -98,9 +99,15 @@ func (a *API) oauthClientAuth(w http.ResponseWriter, r *http.Request) (context.C
9899
return ctx, nil
99100
}
100101

102+
// Parse client_id as UUID
103+
clientUUID, err := uuid.FromString(clientID)
104+
if err != nil {
105+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client_id format")
106+
}
107+
101108
// Validate client credentials
102109
db := a.db.WithContext(ctx)
103-
client, err := models.FindOAuthServerClientByClientID(db, clientID)
110+
client, err := models.FindOAuthServerClientByID(db, clientUUID)
104111
if err != nil {
105112
if models.IsNotFoundError(err) {
106113
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeInvalidCredentials, "Invalid client credentials")

internal/api/oauthserver/authorize.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010

1111
"github.com/go-chi/chi/v5"
12+
"github.com/gofrs/uuid"
1213
"github.com/supabase/auth/internal/api/apierrors"
1314
"github.com/supabase/auth/internal/api/shared"
1415
"github.com/supabase/auth/internal/models"
@@ -103,7 +104,13 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
103104
return err
104105
}
105106

106-
client, err := s.getOAuthServerClient(ctx, params.ClientID)
107+
// Parse client_id as UUID
108+
clientID, err := uuid.FromString(params.ClientID)
109+
if err != nil {
110+
return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id format")
111+
}
112+
113+
client, err := s.getOAuthServerClient(ctx, clientID)
107114
if err != nil {
108115
if models.IsNotFoundError(err) {
109116
return apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthClientNotFound, "invalid client_id")
@@ -144,7 +151,7 @@ func (s *Server) OAuthServerAuthorize(w http.ResponseWriter, r *http.Request) er
144151
}
145152

146153
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
147-
observability.LogEntrySetField(r, "client_id", client.ClientID)
154+
observability.LogEntrySetField(r, "client_id", client.ID.String())
148155

149156
// Redirect to authorization path with authorization_id
150157
if config.OAuthServer.AuthorizationPath == "" {
@@ -228,7 +235,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
228235
response := AuthorizationDetailsResponse{
229236
AuthorizationID: authorization.AuthorizationID,
230237
Client: ClientDetailsResponse{
231-
ClientID: authorization.Client.ClientID,
238+
ClientID: authorization.Client.ID.String(),
232239
ClientName: utilities.StringValue(authorization.Client.ClientName),
233240
ClientURI: utilities.StringValue(authorization.Client.ClientURI),
234241
LogoURI: utilities.StringValue(authorization.Client.LogoURI),
@@ -241,7 +248,7 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
241248
}
242249

243250
observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
244-
observability.LogEntrySetField(r, "client_id", authorization.Client.ClientID)
251+
observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String())
245252

246253
return shared.SendJSON(w, http.StatusOK, response)
247254
}

internal/api/oauthserver/client_auth_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ func TestValidateClientAuthentication(t *testing.T) {
298298
// Create test clients
299299
publicClient := &models.OAuthServerClient{
300300
ID: uuid.Must(uuid.NewV4()),
301-
ClientID: "public_client",
302301
ClientType: models.OAuthServerClientTypePublic,
303302
// No client secret hash for public clients
304303
}
@@ -307,7 +306,6 @@ func TestValidateClientAuthentication(t *testing.T) {
307306
secretHash, _ := hashClientSecret("test_secret")
308307
confidentialClient := &models.OAuthServerClient{
309308
ID: uuid.Must(uuid.NewV4()),
310-
ClientID: "confidential_client",
311309
ClientType: models.OAuthServerClientTypeConfidential,
312310
ClientSecretHash: secretHash,
313311
}

internal/api/oauthserver/handlers.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"time"
99

1010
"github.com/go-chi/chi/v5"
11+
"github.com/gofrs/uuid"
1112
"github.com/supabase/auth/internal/api/apierrors"
1213
"github.com/supabase/auth/internal/api/shared"
1314
"github.com/supabase/auth/internal/models"
@@ -53,7 +54,7 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret
5354
}
5455

5556
response := &OAuthServerClientResponse{
56-
ClientID: client.ClientID,
57+
ClientID: client.ID.String(),
5758
ClientType: client.ClientType,
5859

5960
// OAuth 2.1 DCR fields
@@ -83,13 +84,19 @@ func oauthServerClientToResponse(client *models.OAuthServerClient, includeSecret
8384
// LoadOAuthServerClient is middleware that loads an OAuth server client from the URL parameter
8485
func (s *Server) LoadOAuthServerClient(w http.ResponseWriter, r *http.Request) (context.Context, error) {
8586
ctx := r.Context()
86-
clientID := chi.URLParam(r, "client_id")
87+
clientIDStr := chi.URLParam(r, "client_id")
8788

88-
if clientID == "" {
89+
if clientIDStr == "" {
8990
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "client_id is required")
9091
}
9192

92-
observability.LogEntrySetField(r, "oauth_client_id", clientID)
93+
// Parse client_id as UUID
94+
clientID, err := uuid.FromString(clientIDStr)
95+
if err != nil {
96+
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "invalid client_id format")
97+
}
98+
99+
observability.LogEntrySetField(r, "oauth_client_id", clientIDStr)
93100

94101
client, err := s.getOAuthServerClient(ctx, clientID)
95102
if err != nil {
@@ -171,7 +178,7 @@ func (s *Server) OAuthServerClientDelete(w http.ResponseWriter, r *http.Request)
171178
ctx := r.Context()
172179
client := GetOAuthServerClient(ctx)
173180

174-
if err := s.deleteOAuthServerClient(ctx, client.ClientID); err != nil {
181+
if err := s.deleteOAuthServerClient(ctx, client.ID); err != nil {
175182
return apierrors.NewInternalServerError("Error deleting OAuth client").WithInternalError(err)
176183
}
177184

internal/api/oauthserver/handlers_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDynamicRegisterDisabled() {
167167
func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() {
168168
client, _ := ts.createTestOAuthClient()
169169

170-
req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ClientID, nil)
170+
req := httptest.NewRequest(http.MethodGet, "/admin/oauth/clients/"+client.ID.String(), nil)
171171

172172
ctx := WithOAuthServerClient(req.Context(), client)
173173
req = req.WithContext(ctx)
@@ -183,7 +183,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientGetHandler() {
183183
err = json.Unmarshal(w.Body.Bytes(), &response)
184184
require.NoError(ts.T(), err)
185185

186-
assert.Equal(ts.T(), client.ClientID, response.ClientID)
186+
assert.Equal(ts.T(), client.ID.String(), response.ClientID)
187187
assert.Empty(ts.T(), response.ClientSecret) // Should NOT be included in get response
188188
assert.Equal(ts.T(), "Test Client", response.ClientName)
189189
}
@@ -193,7 +193,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() {
193193
client, _ := ts.createTestOAuthClient()
194194

195195
// Create HTTP request with client in context
196-
req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ClientID, nil)
196+
req := httptest.NewRequest(http.MethodDelete, "/admin/oauth/clients/"+client.ID.String(), nil)
197197

198198
// Add client to context (normally done by LoadOAuthServerClient middleware)
199199
ctx := WithOAuthServerClient(req.Context(), client)
@@ -208,7 +208,7 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientDeleteHandler() {
208208
assert.Empty(ts.T(), w.Body.String())
209209

210210
// Verify client was soft-deleted
211-
deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ClientID)
211+
deletedClient, err := ts.Server.getOAuthServerClient(context.Background(), client.ID)
212212
assert.Error(ts.T(), err) // it was soft-deleted
213213
assert.Nil(ts.T(), deletedClient)
214214
}
@@ -235,8 +235,8 @@ func (ts *OAuthClientTestSuite) TestOAuthServerClientListHandler() {
235235

236236
// Check that both clients are in the response (order might vary)
237237
clientIDs := []string{response.Clients[0].ClientID, response.Clients[1].ClientID}
238-
assert.Contains(ts.T(), clientIDs, client1.ClientID)
239-
assert.Contains(ts.T(), clientIDs, client2.ClientID)
238+
assert.Contains(ts.T(), clientIDs, client1.ID.String())
239+
assert.Contains(ts.T(), clientIDs, client2.ID.String())
240240

241241
// Verify client secrets are not included in list response
242242
for _, client := range response.Clients {

internal/api/oauthserver/service.go

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import (
1111
"slices"
1212
"time"
1313

14+
"github.com/gofrs/uuid"
1415
"github.com/pkg/errors"
1516
"github.com/supabase/auth/internal/api/apierrors"
16-
"github.com/supabase/auth/internal/crypto"
1717
"github.com/supabase/auth/internal/models"
1818
"github.com/supabase/auth/internal/utilities"
1919
)
@@ -141,11 +141,6 @@ func validateRedirectURI(uri string) error {
141141
return nil
142142
}
143143

144-
// generateClientID generates a URL-safe random client ID
145-
func generateClientID() string {
146-
// Generate a 32-character alphanumeric client ID
147-
return crypto.SecureAlphanumeric(32)
148-
}
149144

150145
// generateClientSecret generates a secure random client secret
151146
func generateClientSecret() string {
@@ -193,7 +188,7 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer
193188
db := s.db.WithContext(ctx)
194189

195190
client := &models.OAuthServerClient{
196-
ClientID: generateClientID(),
191+
ID: uuid.Must(uuid.NewV4()),
197192
RegistrationType: params.RegistrationType,
198193
ClientType: clientType,
199194
ClientName: utilities.StringPtr(params.ClientName),
@@ -222,11 +217,11 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer
222217
return client, plaintextSecret, nil
223218
}
224219

225-
// getOAuthServerClient retrieves an OAuth client by client_id
226-
func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*models.OAuthServerClient, error) {
220+
// getOAuthServerClient retrieves an OAuth client by ID
221+
func (s *Server) getOAuthServerClient(ctx context.Context, clientID uuid.UUID) (*models.OAuthServerClient, error) {
227222
db := s.db.WithContext(ctx)
228223

229-
client, err := models.FindOAuthServerClientByClientID(db, clientID)
224+
client, err := models.FindOAuthServerClientByID(db, clientID)
230225
if err != nil {
231226
return nil, err
232227
}
@@ -235,10 +230,10 @@ func (s *Server) getOAuthServerClient(ctx context.Context, clientID string) (*mo
235230
}
236231

237232
// deleteOAuthServerClient soft-deletes an OAuth client
238-
func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID string) error {
233+
func (s *Server) deleteOAuthServerClient(ctx context.Context, clientID uuid.UUID) error {
239234
db := s.db.WithContext(ctx)
240235

241-
client, err := models.FindOAuthServerClientByClientID(db, clientID)
236+
client, err := models.FindOAuthServerClientByID(db, clientID)
242237
if err != nil {
243238
return err
244239
}

internal/api/oauthserver/service_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ func (ts *OAuthServiceTestSuite) TestOAuthServerClientServiceMethods() {
9292
assert.Equal(ts.T(), "dynamic", client.RegistrationType)
9393

9494
// Test getOAuthServerClient
95-
retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID)
95+
retrievedClient, err := ts.Server.getOAuthServerClient(ctx, client.ID)
9696
require.NoError(ts.T(), err)
97-
assert.Equal(ts.T(), client.ClientID, retrievedClient.ClientID)
97+
assert.Equal(ts.T(), client.ID, retrievedClient.ID)
9898

9999
}
100100

@@ -133,11 +133,11 @@ func (ts *OAuthServiceTestSuite) TestDeleteOAuthServerClient() {
133133

134134
// Delete the client
135135
ctx := context.Background()
136-
err := ts.Server.deleteOAuthServerClient(ctx, client.ClientID)
136+
err := ts.Server.deleteOAuthServerClient(ctx, client.ID)
137137
require.NoError(ts.T(), err)
138138

139139
// Verify client was soft-deleted
140-
deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ClientID)
140+
deletedClient, err := ts.Server.getOAuthServerClient(ctx, client.ID)
141141
assert.Error(ts.T(), err) // it was soft-deleted
142142
assert.Nil(ts.T(), deletedClient)
143143
}

internal/models/oauth_client.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ const (
2828

2929
// OAuthServerClient represents an OAuth client application registered with this OAuth server
3030
type OAuthServerClient struct {
31-
ID uuid.UUID `json:"-" db:"id"`
32-
ClientID string `json:"client_id" db:"client_id"`
31+
ID uuid.UUID `json:"client_id" db:"id"`
3332
ClientSecretHash string `json:"-" db:"client_secret_hash"`
3433
RegistrationType string `json:"registration_type" db:"registration_type"`
3534
ClientType string `json:"client_type" db:"client_type"`
@@ -57,8 +56,8 @@ func (c *OAuthServerClient) BeforeSave(tx *pop.Connection) error {
5756

5857
// Validate performs basic validation on the OAuth client
5958
func (c *OAuthServerClient) Validate() error {
60-
if c.ClientID == "" {
61-
return fmt.Errorf("client_id is required")
59+
if c.ID == uuid.Nil {
60+
return fmt.Errorf("id is required")
6261
}
6362

6463
if c.RegistrationType != "dynamic" && c.RegistrationType != "manual" {
@@ -182,28 +181,13 @@ func FindOAuthServerClientByID(tx *storage.Connection, id uuid.UUID) (*OAuthServ
182181
return client, nil
183182
}
184183

185-
// FindOAuthServerClientByClientID finds an OAuth client by client_id
186-
func FindOAuthServerClientByClientID(tx *storage.Connection, clientID string) (*OAuthServerClient, error) {
187-
client := &OAuthServerClient{}
188-
if err := tx.Q().Where("client_id = ? AND deleted_at IS NULL", clientID).First(client); err != nil {
189-
if errors.Cause(err) == sql.ErrNoRows {
190-
return nil, OAuthServerClientNotFoundError{}
191-
}
192-
return nil, errors.Wrap(err, "error finding OAuth client")
193-
}
194-
return client, nil
195-
}
196184

197185
// CreateOAuthServerClient creates a new OAuth client in the database
198186
func CreateOAuthServerClient(tx *storage.Connection, client *OAuthServerClient) error {
199187
if err := client.Validate(); err != nil {
200188
return err
201189
}
202190

203-
if client.ID == uuid.Nil {
204-
client.ID = uuid.Must(uuid.NewV4())
205-
}
206-
207191
now := time.Now()
208192
client.CreatedAt = now
209193
client.UpdatedAt = now

0 commit comments

Comments
 (0)