From f17d2ab4948a442e14c2a0bc375075ffee2735b3 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Mon, 3 Jun 2024 14:37:02 +0800 Subject: [PATCH 1/2] fix: improve logging for unacceptable aud error --- internal/api/token_oidc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 1c728bf86..73dc71aca 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -179,7 +179,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R } if !correctAudience { - return oauthError("invalid request", "Unacceptable audience in id_token") + return oauthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience)) } if !oauthConfig.SkipNonceCheck { From f2f26d4f655bb1dd2face65825dd24394c004968 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Mon, 3 Jun 2024 14:37:21 +0800 Subject: [PATCH 2/2] chore: refactor & add test --- internal/api/token_oidc.go | 23 ++++++----- internal/api/token_oidc_test.go | 69 +++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 9 deletions(-) create mode 100644 internal/api/token_oidc_test.go diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 73dc71aca..7b0a8155b 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -24,7 +24,7 @@ type IdTokenGrantParams struct { Issuer string `json:"issuer"` } -func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, *conf.OAuthProviderConfiguration, string, []string, error) { +func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.GlobalConfiguration, r *http.Request) (*oidc.Provider, bool, string, []string, error) { log := observability.GetLogEntry(r).Entry var cfg *conf.OAuthProviderConfiguration @@ -54,7 +54,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa if issuer == "" || !provider.IsAzureIssuer(issuer) { detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) if err != nil { - return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + return nil, false, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) } issuer = detectedIssuer } @@ -95,20 +95,25 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + return nil, false, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + } + + cfg = &conf.OAuthProviderConfiguration{ + Enabled: true, + SkipNonceCheck: false, } } - if cfg != nil && !cfg.Enabled { - return nil, nil, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + if !cfg.Enabled { + return nil, false, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcProvider, err := oidc.NewProvider(ctx, issuer) if err != nil { - return nil, nil, "", nil, err + return nil, false, "", nil, err } - return oidcProvider, cfg, providerType, acceptableClientIDs, nil + return oidcProvider, cfg.SkipNonceCheck, providerType, acceptableClientIDs, nil } // IdTokenGrant implements the id_token grant type flow @@ -131,7 +136,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R return oauthError("invalid request", "provider or client_id and issuer required") } - oidcProvider, oauthConfig, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r) + oidcProvider, skipNonceCheck, providerType, acceptableClientIDs, err := params.getProvider(ctx, config, r) if err != nil { return err } @@ -182,7 +187,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R return oauthError("invalid request", fmt.Sprintf("Unacceptable audience in id_token: %v", idToken.Audience)) } - if !oauthConfig.SkipNonceCheck { + if !skipNonceCheck { tokenHasNonce := idToken.Nonce != "" paramsHasNonce := params.Nonce != "" diff --git a/internal/api/token_oidc_test.go b/internal/api/token_oidc_test.go new file mode 100644 index 000000000..1eab99ebd --- /dev/null +++ b/internal/api/token_oidc_test.go @@ -0,0 +1,69 @@ +package api + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" +) + +type TokenOIDCTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +func TestTokenOIDC(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &TokenOIDCTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func SetupTestOIDCProvider(ts *TokenOIDCTestSuite) *httptest.Server { + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/openid-configuration": + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"issuer":"` + server.URL + `","authorization_endpoint":"` + server.URL + `/authorize","token_endpoint":"` + server.URL + `/token","jwks_uri":"` + server.URL + `/jwks"}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + return server +} + +func (ts *TokenOIDCTestSuite) TestGetProvider() { + server := SetupTestOIDCProvider(ts) + defer server.Close() + + params := &IdTokenGrantParams{ + IdToken: "test-id-token", + AccessToken: "test-access-token", + Nonce: "test-nonce", + Provider: server.URL, + ClientID: "test-client-id", + Issuer: server.URL, + } + + ts.Config.External.AllowedIdTokenIssuers = []string{server.URL} + + req := httptest.NewRequest(http.MethodPost, "http://localhost", nil) + oidcProvider, skipNonceCheck, providerType, acceptableClientIds, err := params.getProvider(context.Background(), ts.Config, req) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), oidcProvider) + require.False(ts.T(), skipNonceCheck) + require.Equal(ts.T(), params.Provider, providerType) + require.NotEmpty(ts.T(), acceptableClientIds) +}