Skip to content

Commit

Permalink
Creating providers with config (#3334)
Browse files Browse the repository at this point in the history
* Make the providerClassManager reusable by adding an intermediate providerTracker

In order to be able to add the providerAuthManager in the next step and
avoid code duplication, let's make the providerClassManager a little
more reusable by extracting the code that tracks registered classes.

* Add providerAuthManager

Adds the providerAuthManager interface that exposes per-provider OAuth
settings. This allows the provider creation in an OAuth callback to let
the provider specific code sit in the provider codebase and let the
creation be provider agnostic.

ALso creates a sessionService that currently allows to create a provider
from a session record that is created before the OAuth conversation.

Fixes: #3263

* Enroll providers with OAuth flow and config

Adds the CLI support for creating an OAuth provider with a config.

* Reduce code duplication in handlers_oauth.go

Two or more, use a for.

* Pass providerConfig to GH app on creation

Allows to pass provider configuration when creating a GH app. This will
be useful for creating providers with repository auto-enrollment.

Fixes: #3263

Unlike the OAuth flow which is reusable for any provider, the GH App
flow is quite GH-app specific and thus uses the original GH app handler.

* Only rewrite fallback OAuth config values if set

We added a ProviderConfig-based structure but then in a subsequent patch
we overwrite its values unconditionally. Let's only overwrite them when
set.

* Move fallback OAuth configuration to be instantiated on startup

Based on review feedback, to isolate the fallback code.

* Adjust tests after recent config validation changes
  • Loading branch information
jhrozek committed Jun 4, 2024
1 parent a896873 commit e578bf8
Show file tree
Hide file tree
Showing 17 changed files with 1,432 additions and 320 deletions.
19 changes: 12 additions & 7 deletions cmd/cli/app/provider/provider_enroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, _ []string,
// This will have a different timeout
enrollemntCtx := cmd.Context()

return enrollUsingOAuth2Flow(enrollemntCtx, cmd, oauthClient, providerClient, provider, project, owner, skipBrowser)
return enrollUsingOAuth2Flow(
enrollemntCtx, cmd, oauthClient, providerClient, providerName, provider, project, owner, skipBrowser, config)
}

func enrollUsingToken(
Expand Down Expand Up @@ -183,10 +184,12 @@ func enrollUsingOAuth2Flow(
cmd *cobra.Command,
oauthClient minderv1.OAuthServiceClient,
providerClient minderv1.ProvidersServiceClient,
provider string,
providerName string,
providerClass string,
project string,
owner string,
skipBrowser bool,
providerConfig *structpb.Struct,
) error {
oAuthCallbackCtx, oAuthCancel := context.WithTimeout(ctx, MAX_WAIT+5*time.Second)
defer oAuthCancel()
Expand All @@ -197,7 +200,7 @@ func enrollUsingOAuth2Flow(

// If the user is using the legacy GitHub provider, don't let them enroll a new provider.
// However, they may update the credentials for the existing provider.
if legacyProviderEnrolled && provider != legacyGitHubProvider.ToString() {
if legacyProviderEnrolled && providerName != legacyGitHubProvider.ToString() {
return fmt.Errorf("it seems you are using the legacy github provider. " +
"If you would like to enroll a new provider, please delete your existing provider by " +
"running \"minder provider delete --name github\"")
Expand All @@ -210,10 +213,12 @@ func enrollUsingOAuth2Flow(
}

resp, err := oauthClient.GetAuthorizationURL(ctx, &minderv1.GetAuthorizationURLRequest{
Context: &minderv1.Context{Provider: &provider, Project: &project},
Cli: true,
Port: int32(port),
Owner: &owner,
Context: &minderv1.Context{Provider: &providerName, Project: &project},
Cli: true,
Port: int32(port),
Owner: &owner,
Config: providerConfig,
ProviderClass: providerClass,
})
if err != nil {
return cli.MessageAndError("error getting authorization URL", err)
Expand Down
120 changes: 0 additions & 120 deletions internal/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,115 +19,16 @@ import (
"context"
"fmt"
"net/http"
"slices"

go_github "github.com/google/go-github/v61/github"
"github.com/spf13/viper"
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"

"github.com/stacklok/minder/internal/config/server"
"github.com/stacklok/minder/internal/db"
)

const (
// Github OAuth2 provider
Github = "github"

// GitHubApp provider
GitHubApp = "github-app"
)

var knownProviders = []string{Github, GitHubApp}

func getOAuthClientConfig(c *server.ProviderConfig, provider string) (*server.OAuthClientConfig, error) {
var oc *server.OAuthClientConfig
var err error

if !slices.Contains(knownProviders, provider) {
return nil, fmt.Errorf("invalid provider: %s", provider)
}

// first read the new provider-nested key. If it's missing, fallback to using the older
// top-level keys.
switch provider {
case Github:
if c != nil && c.GitHub != nil {
oc = &c.GitHub.OAuthClientConfig
}
fallbackOAuthClientConfigValues("github", oc)
case GitHubApp:
if c != nil && c.GitHubApp != nil {
oc = &c.GitHubApp.OAuthClientConfig
}
fallbackOAuthClientConfigValues("github-app", oc)
default:
err = fmt.Errorf("unknown provider: %s", provider)
}

return oc, err
}

func fallbackOAuthClientConfigValues(provider string, cfg *server.OAuthClientConfig) {
// we read the values one-by-one instead of just getting the top-level key to allow
// for environment variables to be set per-variable
cfg.ClientID = viper.GetString(fmt.Sprintf("%s.client_id", provider))
cfg.ClientIDFile = viper.GetString(fmt.Sprintf("%s.client_id_file", provider))
cfg.ClientSecret = viper.GetString(fmt.Sprintf("%s.client_secret", provider))
cfg.ClientSecretFile = viper.GetString(fmt.Sprintf("%s.client_secret_file", provider))
cfg.RedirectURI = viper.GetString(fmt.Sprintf("%s.redirect_uri", provider))
}

// NewOAuthConfig creates a new OAuth2 config for the given provider
// and whether the client is a CLI or web client
func NewOAuthConfig(c *server.ProviderConfig, provider string, cli bool) (*oauth2.Config, error) {
oauthConfig, err := getOAuthClientConfig(c, provider)
if err != nil {
return nil, fmt.Errorf("failed to get OAuth client config: %w", err)
}

redirectURL := func(provider string, cli bool) string {
base := oauthConfig.RedirectURI
if provider == GitHubApp {
// GitHub App does not distinguish between CLI and web clients
return base
}
if cli {
return fmt.Sprintf("%s/cli", base)
}
return fmt.Sprintf("%s/web", base)
}

scopes := func(provider string) []string {
if provider == GitHubApp {
return []string{}
}
return []string{"user:email", "repo", "read:packages", "write:packages", "workflow", "read:org"}
}

endpoint := func() oauth2.Endpoint {
return github.Endpoint
}

clientID, err := oauthConfig.GetClientID()
if err != nil {
return nil, fmt.Errorf("failed to get client ID: %w", err)
}

clientSecret, err := oauthConfig.GetClientSecret()
if err != nil {
return nil, fmt.Errorf("failed to get client secret: %w", err)
}

return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL(provider, cli),
Scopes: scopes(provider),
Endpoint: endpoint(),
}, nil
}

// NewProviderHttpClient creates a new http client for the given provider
func NewProviderHttpClient(provider string) *http.Client {
if provider == Github {
Expand Down Expand Up @@ -158,24 +59,3 @@ func DeleteAccessToken(ctx context.Context, provider string, token string) error
}
return nil
}

// ValidateProviderToken validates the given token for the given provider
func ValidateProviderToken(_ context.Context, provider db.ProviderClass, token string) error {
// Fixme: this should really be handled by the provider. Should this be in the credentials API or the manager?
if provider == db.ProviderClassGithub {
// Create an OAuth2 token source with the PAT
tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})

// Create an authenticated GitHub client
oauth2Client := oauth2.NewClient(context.Background(), tokenSource)
client := go_github.NewClient(oauth2Client)

// Make a sample API request to check token validity
_, _, err := client.Users.Get(context.Background(), "")
if err != nil {
return fmt.Errorf("invalid token: %s", err)
}
return nil
}
return fmt.Errorf("invalid provider: %s", provider)
}
47 changes: 46 additions & 1 deletion internal/config/server/oauth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,18 @@

package server

import "fmt"
import (
"fmt"

"github.com/spf13/viper"
)

// OAuthEndpoint is the configuration for the OAuth endpoint
// Only used for testing
type OAuthEndpoint struct {
// TokenURL is the OAuth token URL
TokenURL string `mapstructure:"token_url"`
}

// OAuthClientConfig is the configuration for the OAuth client
type OAuthClientConfig struct {
Expand All @@ -29,6 +40,8 @@ type OAuthClientConfig struct {
ClientSecretFile string `mapstructure:"client_secret_file"`
// RedirectURI is the OAuth redirect URI
RedirectURI string `mapstructure:"redirect_uri"`
// Endpoint is the OAuth endpoint. Currently only used for testing
Endpoint *OAuthEndpoint `mapstructure:"endpoint"`
}

// GetClientID returns the OAuth client ID from either the file or the argument
Expand All @@ -46,3 +59,35 @@ func (cfg *OAuthClientConfig) GetClientSecret() (string, error) {
}
return fileOrArg(cfg.ClientSecretFile, cfg.ClientSecret, "client secret")
}

// FallbackOAuthClientConfigValues reads the OAuth client configuration values directly via viper
// this is a temporary hack until we migrate all the configuration to be read from the per-provider
// sections
func FallbackOAuthClientConfigValues(provider string, cfg *OAuthClientConfig) {
// we read the values one-by-one instead of just getting the top-level key to allow
// for environment variables to be set per-variable
fallbackClientID := viper.GetString(fmt.Sprintf("%s.client_id", provider))
if fallbackClientID != "" {
cfg.ClientID = fallbackClientID
}

fallbackClientIDFile := viper.GetString(fmt.Sprintf("%s.client_id_file", provider))
if fallbackClientIDFile != "" {
cfg.ClientIDFile = fallbackClientIDFile
}

fallbackClientSecret := viper.GetString(fmt.Sprintf("%s.client_secret", provider))
if fallbackClientSecret != "" {
cfg.ClientSecret = fallbackClientSecret
}

fallbackClientSecretFile := viper.GetString(fmt.Sprintf("%s.client_secret_file", provider))
if fallbackClientSecretFile != "" {
cfg.ClientSecretFile = fallbackClientSecretFile
}

fallbackRedirectURI := viper.GetString(fmt.Sprintf("%s.redirect_uri", provider))
if fallbackRedirectURI != "" {
cfg.RedirectURI = fallbackRedirectURI
}
}
71 changes: 43 additions & 28 deletions internal/controlplane/handlers_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ import (
"github.com/stacklok/minder/internal/engine"
"github.com/stacklok/minder/internal/logger"
"github.com/stacklok/minder/internal/providers"
"github.com/stacklok/minder/internal/providers/credentials"
"github.com/stacklok/minder/internal/providers/github/service"
"github.com/stacklok/minder/internal/providers/manager"
"github.com/stacklok/minder/internal/util"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)
Expand All @@ -63,11 +65,9 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
providerName = req.GetContext().GetProvider()
}

var providerClass string
if req.GetProviderClass() == "" {
providerClass := req.GetProviderClass()
if providerClass == "" {
providerClass = providerName
} else {
providerClass = req.GetProviderClass()
}

// Telemetry logging
Expand Down Expand Up @@ -157,7 +157,7 @@ func (s *Server) GetAuthorizationURL(ctx context.Context,
}

// Create a new OAuth2 config for the given provider
oauthConfig, err := s.providerAuthFactory(&s.cfg.Provider, providerClass, req.Cli)
oauthConfig, err := s.providerAuthManager.NewOAuthConfig(db.ProviderClass(providerClass), req.Cli)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -234,36 +234,51 @@ func (s *Server) processOAuthCallback(ctx context.Context, w http.ResponseWriter

provider := pathParams["provider"]

// Check the nonce to make sure it's valid
valid, err := mcrypto.IsNonceValid(state, s.cfg.Auth.NoncePeriod)
stateData, err := s.getValidSessionState(ctx, state)
if err != nil {
return fmt.Errorf("error checking nonce: %w", err)
}
if !valid {
return errors.New("invalid nonce")
}

// get projectID from session along with state nonce from the database
stateData, err := s.store.GetProjectIDBySessionState(ctx, state)
if err != nil {
return fmt.Errorf("error getting project ID by session state: %w", err)
return fmt.Errorf("error validating session state: %w", err)
}

// Telemetry logging
logger.BusinessRecord(ctx).Project = stateData.ProjectID
logger.BusinessRecord(ctx).Provider = provider

token, err := s.exchangeCodeForToken(ctx, provider, code)
token, err := s.exchangeCodeForToken(ctx, db.ProviderClass(provider), code)
if err != nil {
return fmt.Errorf("error exchanging code for token: %w", err)
}

p, err := s.ghProviders.CreateGitHubOAuthProvider(ctx, provider, db.ProviderClassGithub, *token, stateData, state)
tokenCred := credentials.NewOAuth2TokenCredential(token.AccessToken)

// Older enrollments may not have a RemoteUser stored; these should age out fairly quickly.
s.mt.AddTokenOpCount(ctx, "check", stateData.RemoteUser.Valid)
err = s.providerAuthManager.ValidateCredentials(
ctx, db.ProviderClass(provider), tokenCred, manager.WithRemoteUser(stateData.RemoteUser.String))
if errors.Is(err, service.ErrInvalidTokenIdentity) {
return newHttpError(http.StatusForbidden, "User token mismatch").SetContents(
"The provided login token was associated with a different user.")
}
if err != nil {
if errors.Is(err, service.ErrInvalidTokenIdentity) {
return newHttpError(http.StatusForbidden, "User token mismatch").SetContents(
"The provided login token was associated with a different GitHub user.")
}
return fmt.Errorf("error validating provider credentials: %w", err)
}

ftoken := &oauth2.Token{
AccessToken: token.AccessToken,
TokenType: token.TokenType,
RefreshToken: "",
}

// encode token
encryptedToken, err := s.cryptoEngine.EncryptOAuthToken(ftoken)
if err != nil {
return fmt.Errorf("error encoding token: %w", err)
}

p, err := s.sessionService.CreateProviderFromSessionState(ctx, db.ProviderClass(provider), &encryptedToken, state)
if db.ErrIsUniqueViolation(err) {
// todo: update config?
zerolog.Ctx(ctx).Info().Str("provider", provider).Msg("Provider already exists")
} else if err != nil {
return fmt.Errorf("error creating provider: %w", err)
}

Expand Down Expand Up @@ -310,9 +325,9 @@ func (s *Server) processAppCallback(ctx context.Context, w http.ResponseWriter,
return err
}

token, err := s.exchangeCodeForToken(ctx, provider, code)
token, err := s.exchangeCodeForToken(ctx, db.ProviderClass(provider), code)
if err != nil {
return err
return fmt.Errorf("error exchanging code for token: %w", err)
}

installationID, err := strconv.ParseInt(installationIDParam, 10, 64)
Expand Down Expand Up @@ -403,9 +418,9 @@ func (s *Server) getValidSessionState(ctx context.Context, state string) (db.Get
return stateData, nil
}

func (s *Server) exchangeCodeForToken(ctx context.Context, providerClass string, code string) (*oauth2.Token, error) {
func (s *Server) exchangeCodeForToken(ctx context.Context, providerClass db.ProviderClass, code string) (*oauth2.Token, error) {
// generate a new OAuth2 config for the given provider
oauthConfig, err := s.providerAuthFactory(&s.cfg.Provider, providerClass, true)
oauthConfig, err := s.providerAuthManager.NewOAuthConfig(providerClass, true)
if err != nil {
return nil, fmt.Errorf("error creating OAuth config: %w", err)
}
Expand Down Expand Up @@ -461,7 +476,7 @@ func (s *Server) StoreProviderToken(ctx context.Context,
}

// validate token
err = auth.ValidateProviderToken(ctx, provider.Class, in.AccessToken)
err = s.providerAuthManager.ValidateCredentials(ctx, provider.Class, in.AccessToken)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid token provided: %v", err)
}
Expand Down
Loading

0 comments on commit e578bf8

Please sign in to comment.