diff --git a/cmd/auth-rest/startcmd/parameters.go b/cmd/auth-rest/startcmd/parameters.go index 92d2357..22c80c6 100644 --- a/cmd/auth-rest/startcmd/parameters.go +++ b/cmd/auth-rest/startcmd/parameters.go @@ -44,7 +44,7 @@ type deviceCertParams struct { type oidcParams struct { hydraURL *url.URL callbackURL string - providers map[string]operation.OIDCProviderConfig + providers map[string]*operation.OIDCProviderConfig } type oidcProvidersConfig struct { @@ -52,14 +52,15 @@ type oidcProvidersConfig struct { } type oidcProviderConfig struct { - URL string `yaml:"url"` - ClientID string `yaml:"clientID"` - ClientSecret string `yaml:"clientSecret"` - Name string `yaml:"name"` - SignUpLogoURL string `yaml:"signUpLogoURL"` - SignInLogoURL string `yaml:"signInLogoURL"` - Order int `yaml:"order"` - SkipIssuerCheck bool `yaml:"skipIssuerCheck"` + URL string `yaml:"url"` + ClientID string `yaml:"clientID"` + ClientSecret string `yaml:"clientSecret"` + Name string `yaml:"name"` + SignUpLogoURL string `yaml:"signUpLogoURL"` + SignInLogoURL string `yaml:"signInLogoURL"` + Order int `yaml:"order"` + SkipIssuerCheck bool `yaml:"skipIssuerCheck"` + Scopes []string `yaml:"scopes"` } type bootstrapParams struct { diff --git a/cmd/auth-rest/startcmd/start.go b/cmd/auth-rest/startcmd/start.go index 51dad74..51b3646 100644 --- a/cmd/auth-rest/startcmd/start.go +++ b/cmd/auth-rest/startcmd/start.go @@ -543,10 +543,10 @@ func getOIDCParams(cmd *cobra.Command) (*oidcParams, error) { return nil, fmt.Errorf("failed to parse contents of %s: %w", oidcProvFile, err) } - params.providers = make(map[string]operation.OIDCProviderConfig, len(data.Providers)) + params.providers = make(map[string]*operation.OIDCProviderConfig, len(data.Providers)) for k, v := range data.Providers { - params.providers[k] = operation.OIDCProviderConfig{ + params.providers[k] = &operation.OIDCProviderConfig{ URL: v.URL, ClientID: v.ClientID, ClientSecret: v.ClientSecret, @@ -555,6 +555,7 @@ func getOIDCParams(cmd *cobra.Command) (*oidcParams, error) { SignInLogoURL: v.SignInLogoURL, Order: v.Order, SkipIssuerCheck: v.SkipIssuerCheck, + Scopes: v.Scopes, } } diff --git a/pkg/restapi/operation/operations.go b/pkg/restapi/operation/operations.go index 02e2a37..192897f 100644 --- a/pkg/restapi/operation/operations.go +++ b/pkg/restapi/operation/operations.go @@ -67,7 +67,7 @@ type Operation struct { client httpClient requestTokens map[string]string transientStore storage.Store - oidcProvidersConfig map[string]OIDCProviderConfig + oidcProvidersConfig map[string]*OIDCProviderConfig cachedOIDCProviders map[string]oidcProvider uiEndpoint string bootstrapStore storage.Store @@ -104,7 +104,7 @@ type Config struct { // OIDCConfig holds the OIDC configuration. type OIDCConfig struct { CallbackURL string - Providers map[string]OIDCProviderConfig + Providers map[string]*OIDCProviderConfig } // OIDCProviderConfig holds the configuration for a single OIDC provider. @@ -117,6 +117,7 @@ type OIDCProviderConfig struct { SignInLogoURL string Order int SkipIssuerCheck bool + Scopes []string } // CookieConfig holds cookie configuration. @@ -327,11 +328,23 @@ func (o *Operation) oidcLoginHandler(w http.ResponseWriter, r *http.Request) { return } + provConfig, ok := o.oidcProvidersConfig[providerID] + if !ok { + o.writeErrorResponse(w, http.StatusInternalServerError, "provider not supported: %s", providerID) + + return + } + + scopes := []string{oidc.ScopeOpenID} + if len(provConfig.Scopes) != 0 { + scopes = append(scopes, provConfig.Scopes...) + } else { + scopes = append(scopes, "profile", "email") + } + authOption := oauth2.SetAuthURLParam(providerQueryParam, providerID) redirectURL := provider.OAuth2Config( - oidc.ScopeOpenID, - "profile", - "email", + scopes..., ).AuthCodeURL(state, oauth2.AccessTypeOnline, authOption) http.Redirect(w, r, redirectURL, http.StatusFound) @@ -927,7 +940,7 @@ func (o *Operation) getProvider(providerID string) (oidcProvider, error) { return nil, fmt.Errorf("provider not supported: %s", providerID) } - prov, err := o.initOIDCProvider(providerID, &provider) + prov, err := o.initOIDCProvider(providerID, provider) if err != nil { return nil, fmt.Errorf("failed to init oidc provider: %w", err) } diff --git a/pkg/restapi/operation/operations_test.go b/pkg/restapi/operation/operations_test.go index 42b454e..fa20411 100644 --- a/pkg/restapi/operation/operations_test.go +++ b/pkg/restapi/operation/operations_test.go @@ -104,12 +104,27 @@ func TestOIDCLoginHandler(t *testing.T) { svc.cachedOIDCProviders = map[string]oidcProvider{ provider: &mockOIDCProvider{}, } + svc.oidcProvidersConfig = map[string]*OIDCProviderConfig{provider: {}} w := httptest.NewRecorder() svc.oidcLoginHandler(w, newOIDCLoginRequest(provider)) require.Equal(t, http.StatusFound, w.Code) require.NotEmpty(t, w.Header().Get("location")) }) + t.Run("provider not supported", func(t *testing.T) { + provider := uuid.New().String() + config := config(t) + svc, err := New(config) + require.NoError(t, err) + svc.cookies = mockCookies() + svc.cachedOIDCProviders = map[string]oidcProvider{ + provider: &mockOIDCProvider{}, + } + w := httptest.NewRecorder() + svc.oidcLoginHandler(w, newOIDCLoginRequest(provider)) + require.Equal(t, http.StatusInternalServerError, w.Code) + }) + t.Run("internal server error if cannot open cookie store", func(t *testing.T) { svc, err := New(config(t)) require.NoError(t, err) @@ -160,7 +175,7 @@ func TestOIDCLoginHandler(t *testing.T) { t.Run("error if oidc provider is invalid", func(t *testing.T) { config := config(t) - config.OIDC.Providers = map[string]OIDCProviderConfig{ + config.OIDC.Providers = map[string]*OIDCProviderConfig{ "test": { URL: "INVALID", }, @@ -1818,7 +1833,7 @@ func config(t *testing.T) *Config { return &Config{ OIDC: &OIDCConfig{ CallbackURL: "http://test.com", - Providers: map[string]OIDCProviderConfig{ + Providers: map[string]*OIDCProviderConfig{ "mock1": { URL: mockoidc.StartProvider(t), ClientID: uuid.New().String(),