Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions hack/test.env
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ GOTRUE_EXTERNAL_TWITTER_ENABLED=true
GOTRUE_EXTERNAL_TWITTER_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_TWITTER_SECRET=testsecret
GOTRUE_EXTERNAL_TWITTER_REDIRECT_URI=https://identity.services.netlify.com/callback
GOTRUE_EXTERNAL_X_ENABLED=true
GOTRUE_EXTERNAL_X_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_X_SECRET=testsecret
GOTRUE_EXTERNAL_X_REDIRECT_URI=https://identity.services.netlify.com/callback
GOTRUE_EXTERNAL_ZOOM_ENABLED=true
GOTRUE_EXTERNAL_ZOOM_CLIENT_ID=testclientid
GOTRUE_EXTERNAL_ZOOM_SECRET=testsecret
Expand Down
3 changes: 3 additions & 0 deletions internal/api/apierrors/errorcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ const (
ErrorCodeRefreshTokenAlreadyUsed ErrorCode = "refresh_token_already_used"
ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found"
ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired"
ErrorCodeOAuthClientStateNotFound ErrorCode = "oauth_client_state_not_found"
ErrorCodeOAuthClientStateExpired ErrorCode = "oauth_client_state_expired"
ErrorCodeOAuthInvalidState ErrorCode = "oauth_invalid_state"
ErrorCodeSignupDisabled ErrorCode = "signup_disabled"
ErrorCodeUserBanned ErrorCode = "user_banned"
ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification"
Expand Down
15 changes: 15 additions & 0 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/url"

"github.com/gofrs/uuid"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/api/shared"
"github.com/supabase/auth/internal/models"
Expand Down Expand Up @@ -33,6 +34,7 @@ const (
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
oauthClientStateKey = contextKey("oauth_client_state_id")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -137,6 +139,19 @@ func getFlowStateID(ctx context.Context) string {
return obj.(string)
}

func withOAuthClientStateID(ctx context.Context, oauthClientStateID uuid.UUID) context.Context {
return context.WithValue(ctx, oauthClientStateKey, oauthClientStateID)
}

func getOAuthClientStateID(ctx context.Context) uuid.UUID {
obj := ctx.Value(oauthClientStateKey)
if obj == nil {
return uuid.Nil
}

return obj.(uuid.UUID)
}

func getInviteToken(ctx context.Context) string {
obj := ctx.Value(inviteTokenKey)
if obj == nil {
Expand Down
78 changes: 51 additions & 27 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ import (
// ExternalProviderClaims are the JWT claims sent as the state in the external oauth provider signup flow
type ExternalProviderClaims struct {
AuthMicroserviceClaims
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
EmailOptional bool `json:"email_optional,omitempty"`
Provider string `json:"provider"`
InviteToken string `json:"invite_token,omitempty"`
Referrer string `json:"referrer,omitempty"`
FlowStateID string `json:"flow_state_id"`
OAuthClientStateID string `json:"oauth_client_state_id,omitempty"`
LinkingTargetID string `json:"linking_target_id,omitempty"`
EmailOptional bool `json:"email_optional,omitempty"`
}

// ExternalProviderRedirect redirects the request to the oauth provider
Expand Down Expand Up @@ -90,6 +91,32 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
flowStateID = flowState.ID.String()
}

authUrlParams := make([]oauth2.AuthCodeOption, 0)
query.Del("scopes")
query.Del("provider")
query.Del("code_challenge")
query.Del("code_challenge_method")
for key := range query {
if key == "workos_provider" {
// See https://workos.com/docs/reference/sso/authorize/get
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
} else {
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
}
}

oauthClientStateID := ""
if oauthProvider, ok := p.(provider.OAuthProvider); ok && oauthProvider.RequiresPKCE() {
codeVerifier := oauth2.GenerateVerifier()
oauthClientState := models.NewOAuthClientState(providerType, &codeVerifier)
err := db.Create(oauthClientState)
if err != nil {
return "", err
}
oauthClientStateID = oauthClientState.ID.String()
authUrlParams = append(authUrlParams, oauth2.S256ChallengeOption(codeVerifier))
}

claims := ExternalProviderClaims{
AuthMicroserviceClaims: AuthMicroserviceClaims{
RegisteredClaims: jwt.RegisteredClaims{
Expand All @@ -98,11 +125,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
SiteURL: config.SiteURL,
InstanceID: uuid.Nil.String(),
},
Provider: providerType,
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
EmailOptional: pConfig.EmailOptional,
Provider: providerType,
InviteToken: inviteToken,
Referrer: redirectURL,
FlowStateID: flowStateID,
OAuthClientStateID: oauthClientStateID,
EmailOptional: pConfig.EmailOptional,
}

if linkingTargetUser != nil {
Expand All @@ -115,20 +143,6 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
return "", apierrors.NewInternalServerError("Error creating state").WithInternalError(err)
}

authUrlParams := make([]oauth2.AuthCodeOption, 0)
query.Del("scopes")
query.Del("provider")
query.Del("code_challenge")
query.Del("code_challenge_method")
for key := range query {
if key == "workos_provider" {
// See https://workos.com/docs/reference/sso/authorize/get
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam("provider", query.Get(key)))
} else {
authUrlParams = append(authUrlParams, oauth2.SetAuthURLParam(key, query.Get(key)))
}
}

authURL := p.AuthCodeURL(tokenString, authUrlParams...)

return authURL, nil
Expand Down Expand Up @@ -565,6 +579,13 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request, db *storag
if claims.FlowStateID != "" {
ctx = withFlowStateID(ctx, claims.FlowStateID)
}
if claims.OAuthClientStateID != "" {
oauthClientStateID, err := uuid.FromString(claims.OAuthClientStateID)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthState, "OAuth callback with invalid state (oauth_client_state_id must be UUID)")
}
ctx = withOAuthClientStateID(ctx, oauthClientStateID)
}
if claims.LinkingTargetID != "" {
linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID)
if err != nil {
Expand Down Expand Up @@ -634,7 +655,7 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
p, err = provider.NewLinkedinProvider(pConfig, scopes)
case "linkedin_oidc":
pConfig = config.External.LinkedinOIDC
p, err = provider.NewLinkedinOIDCProvider(pConfig, scopes)
p, err = provider.NewLinkedinOIDCProvider(ctx, pConfig, scopes)
case "notion":
pConfig = config.External.Notion
p, err = provider.NewNotionProvider(pConfig)
Expand All @@ -656,9 +677,12 @@ func (a *API) Provider(ctx context.Context, name string, scopes string) (provide
case "twitter":
pConfig = config.External.Twitter
p, err = provider.NewTwitterProvider(pConfig, scopes)
case "x":
pConfig = config.External.X
p, err = provider.NewXProvider(pConfig, scopes)
case "vercel_marketplace":
pConfig = config.External.VercelMarketplace
p, err = provider.NewVercelMarketplaceProvider(pConfig, scopes)
p, err = provider.NewVercelMarketplaceProvider(ctx, pConfig, scopes)
case "workos":
pConfig = config.External.WorkOS
p, err = provider.NewWorkOSProvider(pConfig)
Expand Down
43 changes: 38 additions & 5 deletions internal/api/external_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ import (
"net/http"
"net/url"

"github.com/gofrs/uuid"
"github.com/mrjones/oauth"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
"golang.org/x/oauth2"
)

// OAuthProviderData contains the userData and token returned by the oauth provider
Expand Down Expand Up @@ -55,6 +58,8 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con
}

func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) {
db := a.db.WithContext(ctx)

var rq url.Values
if err := r.ParseForm(); r.Method == http.MethodPost && err == nil {
rq = r.Form
Expand All @@ -72,28 +77,56 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing")
}

oAuthProvider, _, err := a.OAuthProvider(ctx, providerType)
oauthProvider, _, err := a.OAuthProvider(ctx, providerType)
if err != nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err)
}

log := observability.GetLogEntry(r).Entry

var oauthClientState *models.OAuthClientState
// if there's a non-empty OAuthClientStateID we perform PKCE Flow for the external provider
if oauthClientStateID := getOAuthClientStateID(ctx); oauthClientStateID != uuid.Nil {
oauthClientState, err = models.FindAndDeleteOAuthClientStateByID(db, oauthClientStateID)
if models.IsNotFoundError(err) {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateNotFound, "OAuth state not found").WithInternalError(err)
} else if err != nil {
return nil, apierrors.NewInternalServerError("Failed to find OAuth state").WithInternalError(err)
}

if oauthClientState.ProviderType != providerType {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth provider mismatch")
}

if oauthClientState.IsExpired() {
return nil, apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeOAuthClientStateExpired, "OAuth state expired")
}
}

if oauthProvider.RequiresPKCE() && oauthClientState == nil {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeOAuthInvalidState, "OAuth PKCE code verifier missing")
}

log.WithFields(logrus.Fields{
"provider": providerType,
"code": oauthCode,
}).Debug("Exchanging oauth code")
}).Debug("Exchanging OAuth code")

token, err := oAuthProvider.GetOAuthToken(oauthCode)
var tokenOpts []oauth2.AuthCodeOption
if oauthClientState != nil {
tokenOpts = append(tokenOpts, oauth2.VerifierOption(*oauthClientState.CodeVerifier))
}
token, err := oauthProvider.GetOAuthToken(ctx, oauthCode, tokenOpts...)
if err != nil {
return nil, apierrors.NewInternalServerError("Unable to exchange external code: %s", oauthCode).WithInternalError(err)
}

userData, err := oAuthProvider.GetUserData(ctx, token)
userData, err := oauthProvider.GetUserData(ctx, token)
if err != nil {
return nil, apierrors.NewInternalServerError("Error getting user profile from external provider").WithInternalError(err)
}

switch externalProvider := oAuthProvider.(type) {
switch externalProvider := oauthProvider.(type) {
case *provider.AppleProvider:
// apple only returns user info the first time
oauthUser := rq.Get("user")
Expand Down
Loading