Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ require (
golang.org/x/net v0.38.0 // indirect
golang.org/x/sync v0.12.0
golang.org/x/sys v0.31.0
golang.org/x/text v0.23.0 // indirect
golang.org/x/text v0.23.0
golang.org/x/time v0.9.0
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/grpc v1.63.2 // indirect
Expand Down
60 changes: 40 additions & 20 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,14 @@ func (a *API) deprecationNotices() {
// NewAPIWithVersion creates a new REST API using the specified version
func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API {
api := &API{
config: globalConfig,
db: db,
version: version,
oauthServer: oauthserver.NewServer(globalConfig, db),
config: globalConfig,
db: db,
version: version,
}

// Only initialize OAuth server if enabled
if globalConfig.OAuthServer.Enabled {
api.oauthServer = oauthserver.NewServer(globalConfig, db)
}

for _, o := range opt {
Expand Down Expand Up @@ -171,6 +175,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.Get("/health", api.HealthCheck)
r.Get("/.well-known/jwks.json", api.Jwks)

if globalConfig.OAuthServer.Enabled {
r.Get("/.well-known/oauth-authorization-server", api.oauthServer.OAuthServerMetadata)
}

r.Route("/callback", func(r *router) {
r.Use(api.isValidExternalHost)
r.Use(api.loadFlowState)
Expand All @@ -185,6 +193,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/settings", api.Settings)

// `/authorize` to initiate OAuth2 authorization flow with the external providers
// where Supabase Auth is an OAuth2 Client
r.Get("/authorize", api.ExternalProviderRedirect)

r.With(api.requireAdminCredentials).Post("/invite", api.Invite)
Expand Down Expand Up @@ -325,27 +335,37 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
})

// Admin only oauth client management endpoints
r.Route("/oauth", func(r *router) {
r.Route("/clients", func(r *router) {
// Manual client registration
r.Post("/", api.oauthServer.AdminOAuthServerClientRegister)

r.Get("/", api.oauthServer.OAuthServerClientList)

r.Route("/{client_id}", func(r *router) {
r.Use(api.oauthServer.LoadOAuthServerClient)
r.Get("/", api.oauthServer.OAuthServerClientGet)
r.Delete("/", api.oauthServer.OAuthServerClientDelete)
if globalConfig.OAuthServer.Enabled {
r.Route("/oauth", func(r *router) {
r.Route("/clients", func(r *router) {
// Manual client registration
r.Post("/", api.oauthServer.AdminOAuthServerClientRegister)

r.Get("/", api.oauthServer.OAuthServerClientList)

r.Route("/{client_id}", func(r *router) {
r.Use(api.oauthServer.LoadOAuthServerClient)
r.Get("/", api.oauthServer.OAuthServerClientGet)
r.Delete("/", api.oauthServer.OAuthServerClientDelete)
})
})
})
})
}
})

// OAuth Dynamic Client Registration endpoint (public, rate limited)
r.Route("/oauth", func(r *router) {
r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)).
Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister)
})
if globalConfig.OAuthServer.Enabled {
r.Route("/oauth", func(r *router) {
r.With(api.limitHandler(api.limiterOpts.OAuthClientRegister)).
Post("/clients/register", api.oauthServer.OAuthServerClientDynamicRegister)

// OAuth 2.1 Authorization endpoints
// `/authorize` to initiate OAuth2 authorization code flow where Supabase Auth is the OAuth2 provider
r.Get("/authorize", api.oauthServer.OAuthServerAuthorize)
r.With(api.requireAuthentication).Get("/authorizations/{authorization_id}", api.oauthServer.OAuthServerGetAuthorization)
r.With(api.requireAuthentication).Post("/authorizations/{authorization_id}/consent", api.oauthServer.OAuthServerConsent)
})
}
})

corsHandler := cors.New(cors.Options{
Expand Down
27 changes: 27 additions & 0 deletions internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,30 @@ func TestEmailEnabledByDefault(t *testing.T) {

require.True(t, api.config.External.Email.Enabled)
}

func TestOAuthServerDisabledByDefault(t *testing.T) {
api, _, err := setupAPIForTest()
require.NoError(t, err)

// OAuth server should be disabled by default
require.False(t, api.config.OAuthServer.Enabled)

// OAuth server instance should not be initialized when disabled
require.Nil(t, api.oauthServer)
}

func TestOAuthServerCanBeEnabled(t *testing.T) {
api, _, err := setupAPIForTestWithCallback(func(config *conf.GlobalConfiguration, conn *storage.Connection) {
if config != nil {
// Enable OAuth server
config.OAuthServer.Enabled = true
}
})
require.NoError(t, err)

// OAuth server should be enabled
require.True(t, api.config.OAuthServer.Enabled)

// OAuth server instance should be initialized when enabled
require.NotNil(t, api.oauthServer)
}
3 changes: 3 additions & 0 deletions internal/api/apierrors/errorcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,7 @@ const (
ErrorCodeWeb3UnsupportedChain ErrorCode = "web3_unsupported_chain"
ErrorCodeOAuthDynamicClientRegistrationDisabled ErrorCode = "oauth_dynamic_client_registration_disabled"
ErrorCodeEmailAddressNotProvided ErrorCode = "email_address_not_provided"

ErrorCodeOAuthClientNotFound ErrorCode = "oauth_client_not_found"
ErrorCodeOAuthAuthorizationNotFound ErrorCode = "oauth_authorization_not_found"
)
13 changes: 3 additions & 10 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/url"

jwt "github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
)

Expand All @@ -21,7 +22,6 @@ const (
tokenKey = contextKey("jwt")
inviteTokenKey = contextKey("invite_token")
signatureKey = contextKey("signature")
userKey = contextKey("user")
targetUserKey = contextKey("target_user")
factorKey = contextKey("factor")
sessionKey = contextKey("session")
Expand Down Expand Up @@ -60,7 +60,7 @@ func getClaims(ctx context.Context) *AccessTokenClaims {

// withUser adds the user to the context.
func withUser(ctx context.Context, u *models.User) context.Context {
return context.WithValue(ctx, userKey, u)
return shared.WithUser(ctx, u)
}

// withTargetUser adds the target user for linking to the context.
Expand All @@ -75,14 +75,7 @@ func withFactor(ctx context.Context, f *models.Factor) context.Context {

// getUser reads the user from the context.
func getUser(ctx context.Context) *models.User {
if ctx == nil {
return nil
}
obj := ctx.Value(userKey)
if obj == nil {
return nil
}
return obj.(*models.User)
return shared.GetUser(ctx)
}

// getTargetUser reads the user from the context.
Expand Down
Loading
Loading