diff --git a/cmd/cli/app/provider/provider_enroll.go b/cmd/cli/app/provider/provider_enroll.go index 72e34b9e96..bd9a42cac5 100644 --- a/cmd/cli/app/provider/provider_enroll.go +++ b/cmd/cli/app/provider/provider_enroll.go @@ -131,7 +131,8 @@ func EnrollProviderCommand(ctx context.Context, cmd *cobra.Command, _ []string, // This will have a different timeout enrollemntCtx := cmd.Context() - return enrollUsingOAuth2Flow(enrollemntCtx, cmd, oauthClient, providerClient, provider, project, owner, skipBrowser) + return enrollUsingOAuth2Flow( + enrollemntCtx, cmd, oauthClient, providerClient, providerName, provider, project, owner, skipBrowser, config) } func enrollUsingToken( @@ -183,10 +184,12 @@ func enrollUsingOAuth2Flow( cmd *cobra.Command, oauthClient minderv1.OAuthServiceClient, providerClient minderv1.ProvidersServiceClient, - provider string, + providerName string, + providerClass string, project string, owner string, skipBrowser bool, + providerConfig *structpb.Struct, ) error { oAuthCallbackCtx, oAuthCancel := context.WithTimeout(ctx, MAX_WAIT+5*time.Second) defer oAuthCancel() @@ -197,7 +200,7 @@ func enrollUsingOAuth2Flow( // If the user is using the legacy GitHub provider, don't let them enroll a new provider. // However, they may update the credentials for the existing provider. - if legacyProviderEnrolled && provider != legacyGitHubProvider.ToString() { + if legacyProviderEnrolled && providerName != legacyGitHubProvider.ToString() { return fmt.Errorf("it seems you are using the legacy github provider. " + "If you would like to enroll a new provider, please delete your existing provider by " + "running \"minder provider delete --name github\"") @@ -210,10 +213,12 @@ func enrollUsingOAuth2Flow( } resp, err := oauthClient.GetAuthorizationURL(ctx, &minderv1.GetAuthorizationURLRequest{ - Context: &minderv1.Context{Provider: &provider, Project: &project}, - Cli: true, - Port: int32(port), - Owner: &owner, + Context: &minderv1.Context{Provider: &providerName, Project: &project}, + Cli: true, + Port: int32(port), + Owner: &owner, + Config: providerConfig, + ProviderClass: providerClass, }) if err != nil { return cli.MessageAndError("error getting authorization URL", err) diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index f7d2157b2b..88a2926849 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -19,115 +19,16 @@ import ( "context" "fmt" "net/http" - "slices" go_github "github.com/google/go-github/v61/github" "github.com/spf13/viper" - "golang.org/x/oauth2" - "golang.org/x/oauth2/github" - - "github.com/stacklok/minder/internal/config/server" - "github.com/stacklok/minder/internal/db" ) const ( // Github OAuth2 provider Github = "github" - - // GitHubApp provider - GitHubApp = "github-app" ) -var knownProviders = []string{Github, GitHubApp} - -func getOAuthClientConfig(c *server.ProviderConfig, provider string) (*server.OAuthClientConfig, error) { - var oc *server.OAuthClientConfig - var err error - - if !slices.Contains(knownProviders, provider) { - return nil, fmt.Errorf("invalid provider: %s", provider) - } - - // first read the new provider-nested key. If it's missing, fallback to using the older - // top-level keys. - switch provider { - case Github: - if c != nil && c.GitHub != nil { - oc = &c.GitHub.OAuthClientConfig - } - fallbackOAuthClientConfigValues("github", oc) - case GitHubApp: - if c != nil && c.GitHubApp != nil { - oc = &c.GitHubApp.OAuthClientConfig - } - fallbackOAuthClientConfigValues("github-app", oc) - default: - err = fmt.Errorf("unknown provider: %s", provider) - } - - return oc, err -} - -func fallbackOAuthClientConfigValues(provider string, cfg *server.OAuthClientConfig) { - // we read the values one-by-one instead of just getting the top-level key to allow - // for environment variables to be set per-variable - cfg.ClientID = viper.GetString(fmt.Sprintf("%s.client_id", provider)) - cfg.ClientIDFile = viper.GetString(fmt.Sprintf("%s.client_id_file", provider)) - cfg.ClientSecret = viper.GetString(fmt.Sprintf("%s.client_secret", provider)) - cfg.ClientSecretFile = viper.GetString(fmt.Sprintf("%s.client_secret_file", provider)) - cfg.RedirectURI = viper.GetString(fmt.Sprintf("%s.redirect_uri", provider)) -} - -// NewOAuthConfig creates a new OAuth2 config for the given provider -// and whether the client is a CLI or web client -func NewOAuthConfig(c *server.ProviderConfig, provider string, cli bool) (*oauth2.Config, error) { - oauthConfig, err := getOAuthClientConfig(c, provider) - if err != nil { - return nil, fmt.Errorf("failed to get OAuth client config: %w", err) - } - - redirectURL := func(provider string, cli bool) string { - base := oauthConfig.RedirectURI - if provider == GitHubApp { - // GitHub App does not distinguish between CLI and web clients - return base - } - if cli { - return fmt.Sprintf("%s/cli", base) - } - return fmt.Sprintf("%s/web", base) - } - - scopes := func(provider string) []string { - if provider == GitHubApp { - return []string{} - } - return []string{"user:email", "repo", "read:packages", "write:packages", "workflow", "read:org"} - } - - endpoint := func() oauth2.Endpoint { - return github.Endpoint - } - - clientID, err := oauthConfig.GetClientID() - if err != nil { - return nil, fmt.Errorf("failed to get client ID: %w", err) - } - - clientSecret, err := oauthConfig.GetClientSecret() - if err != nil { - return nil, fmt.Errorf("failed to get client secret: %w", err) - } - - return &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: redirectURL(provider, cli), - Scopes: scopes(provider), - Endpoint: endpoint(), - }, nil -} - // NewProviderHttpClient creates a new http client for the given provider func NewProviderHttpClient(provider string) *http.Client { if provider == Github { @@ -158,24 +59,3 @@ func DeleteAccessToken(ctx context.Context, provider string, token string) error } return nil } - -// ValidateProviderToken validates the given token for the given provider -func ValidateProviderToken(_ context.Context, provider db.ProviderClass, token string) error { - // Fixme: this should really be handled by the provider. Should this be in the credentials API or the manager? - if provider == db.ProviderClassGithub { - // Create an OAuth2 token source with the PAT - tokenSource := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) - - // Create an authenticated GitHub client - oauth2Client := oauth2.NewClient(context.Background(), tokenSource) - client := go_github.NewClient(oauth2Client) - - // Make a sample API request to check token validity - _, _, err := client.Users.Get(context.Background(), "") - if err != nil { - return fmt.Errorf("invalid token: %s", err) - } - return nil - } - return fmt.Errorf("invalid provider: %s", provider) -} diff --git a/internal/config/server/oauth_client.go b/internal/config/server/oauth_client.go index 4ccc031df9..95140751f7 100644 --- a/internal/config/server/oauth_client.go +++ b/internal/config/server/oauth_client.go @@ -15,7 +15,18 @@ package server -import "fmt" +import ( + "fmt" + + "github.com/spf13/viper" +) + +// OAuthEndpoint is the configuration for the OAuth endpoint +// Only used for testing +type OAuthEndpoint struct { + // TokenURL is the OAuth token URL + TokenURL string `mapstructure:"token_url"` +} // OAuthClientConfig is the configuration for the OAuth client type OAuthClientConfig struct { @@ -29,6 +40,8 @@ type OAuthClientConfig struct { ClientSecretFile string `mapstructure:"client_secret_file"` // RedirectURI is the OAuth redirect URI RedirectURI string `mapstructure:"redirect_uri"` + // Endpoint is the OAuth endpoint. Currently only used for testing + Endpoint *OAuthEndpoint `mapstructure:"endpoint"` } // GetClientID returns the OAuth client ID from either the file or the argument @@ -46,3 +59,35 @@ func (cfg *OAuthClientConfig) GetClientSecret() (string, error) { } return fileOrArg(cfg.ClientSecretFile, cfg.ClientSecret, "client secret") } + +// FallbackOAuthClientConfigValues reads the OAuth client configuration values directly via viper +// this is a temporary hack until we migrate all the configuration to be read from the per-provider +// sections +func FallbackOAuthClientConfigValues(provider string, cfg *OAuthClientConfig) { + // we read the values one-by-one instead of just getting the top-level key to allow + // for environment variables to be set per-variable + fallbackClientID := viper.GetString(fmt.Sprintf("%s.client_id", provider)) + if fallbackClientID != "" { + cfg.ClientID = fallbackClientID + } + + fallbackClientIDFile := viper.GetString(fmt.Sprintf("%s.client_id_file", provider)) + if fallbackClientIDFile != "" { + cfg.ClientIDFile = fallbackClientIDFile + } + + fallbackClientSecret := viper.GetString(fmt.Sprintf("%s.client_secret", provider)) + if fallbackClientSecret != "" { + cfg.ClientSecret = fallbackClientSecret + } + + fallbackClientSecretFile := viper.GetString(fmt.Sprintf("%s.client_secret_file", provider)) + if fallbackClientSecretFile != "" { + cfg.ClientSecretFile = fallbackClientSecretFile + } + + fallbackRedirectURI := viper.GetString(fmt.Sprintf("%s.redirect_uri", provider)) + if fallbackRedirectURI != "" { + cfg.RedirectURI = fallbackRedirectURI + } +} diff --git a/internal/controlplane/handlers_oauth.go b/internal/controlplane/handlers_oauth.go index e1e3f235d7..e4d964d27b 100644 --- a/internal/controlplane/handlers_oauth.go +++ b/internal/controlplane/handlers_oauth.go @@ -42,7 +42,9 @@ import ( "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/logger" "github.com/stacklok/minder/internal/providers" + "github.com/stacklok/minder/internal/providers/credentials" "github.com/stacklok/minder/internal/providers/github/service" + "github.com/stacklok/minder/internal/providers/manager" "github.com/stacklok/minder/internal/util" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" ) @@ -63,11 +65,9 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, providerName = req.GetContext().GetProvider() } - var providerClass string - if req.GetProviderClass() == "" { + providerClass := req.GetProviderClass() + if providerClass == "" { providerClass = providerName - } else { - providerClass = req.GetProviderClass() } // Telemetry logging @@ -157,7 +157,7 @@ func (s *Server) GetAuthorizationURL(ctx context.Context, } // Create a new OAuth2 config for the given provider - oauthConfig, err := s.providerAuthFactory(&s.cfg.Provider, providerClass, req.Cli) + oauthConfig, err := s.providerAuthManager.NewOAuthConfig(db.ProviderClass(providerClass), req.Cli) if err != nil { return nil, err } @@ -234,36 +234,51 @@ func (s *Server) processOAuthCallback(ctx context.Context, w http.ResponseWriter provider := pathParams["provider"] - // Check the nonce to make sure it's valid - valid, err := mcrypto.IsNonceValid(state, s.cfg.Auth.NoncePeriod) + stateData, err := s.getValidSessionState(ctx, state) if err != nil { - return fmt.Errorf("error checking nonce: %w", err) - } - if !valid { - return errors.New("invalid nonce") - } - - // get projectID from session along with state nonce from the database - stateData, err := s.store.GetProjectIDBySessionState(ctx, state) - if err != nil { - return fmt.Errorf("error getting project ID by session state: %w", err) + return fmt.Errorf("error validating session state: %w", err) } // Telemetry logging logger.BusinessRecord(ctx).Project = stateData.ProjectID logger.BusinessRecord(ctx).Provider = provider - token, err := s.exchangeCodeForToken(ctx, provider, code) + token, err := s.exchangeCodeForToken(ctx, db.ProviderClass(provider), code) if err != nil { return fmt.Errorf("error exchanging code for token: %w", err) } - p, err := s.ghProviders.CreateGitHubOAuthProvider(ctx, provider, db.ProviderClassGithub, *token, stateData, state) + tokenCred := credentials.NewOAuth2TokenCredential(token.AccessToken) + + // Older enrollments may not have a RemoteUser stored; these should age out fairly quickly. + s.mt.AddTokenOpCount(ctx, "check", stateData.RemoteUser.Valid) + err = s.providerAuthManager.ValidateCredentials( + ctx, db.ProviderClass(provider), tokenCred, manager.WithRemoteUser(stateData.RemoteUser.String)) + if errors.Is(err, service.ErrInvalidTokenIdentity) { + return newHttpError(http.StatusForbidden, "User token mismatch").SetContents( + "The provided login token was associated with a different user.") + } if err != nil { - if errors.Is(err, service.ErrInvalidTokenIdentity) { - return newHttpError(http.StatusForbidden, "User token mismatch").SetContents( - "The provided login token was associated with a different GitHub user.") - } + return fmt.Errorf("error validating provider credentials: %w", err) + } + + ftoken := &oauth2.Token{ + AccessToken: token.AccessToken, + TokenType: token.TokenType, + RefreshToken: "", + } + + // encode token + encryptedToken, err := s.cryptoEngine.EncryptOAuthToken(ftoken) + if err != nil { + return fmt.Errorf("error encoding token: %w", err) + } + + p, err := s.sessionService.CreateProviderFromSessionState(ctx, db.ProviderClass(provider), &encryptedToken, state) + if db.ErrIsUniqueViolation(err) { + // todo: update config? + zerolog.Ctx(ctx).Info().Str("provider", provider).Msg("Provider already exists") + } else if err != nil { return fmt.Errorf("error creating provider: %w", err) } @@ -310,9 +325,9 @@ func (s *Server) processAppCallback(ctx context.Context, w http.ResponseWriter, return err } - token, err := s.exchangeCodeForToken(ctx, provider, code) + token, err := s.exchangeCodeForToken(ctx, db.ProviderClass(provider), code) if err != nil { - return err + return fmt.Errorf("error exchanging code for token: %w", err) } installationID, err := strconv.ParseInt(installationIDParam, 10, 64) @@ -403,9 +418,9 @@ func (s *Server) getValidSessionState(ctx context.Context, state string) (db.Get return stateData, nil } -func (s *Server) exchangeCodeForToken(ctx context.Context, providerClass string, code string) (*oauth2.Token, error) { +func (s *Server) exchangeCodeForToken(ctx context.Context, providerClass db.ProviderClass, code string) (*oauth2.Token, error) { // generate a new OAuth2 config for the given provider - oauthConfig, err := s.providerAuthFactory(&s.cfg.Provider, providerClass, true) + oauthConfig, err := s.providerAuthManager.NewOAuthConfig(providerClass, true) if err != nil { return nil, fmt.Errorf("error creating OAuth config: %w", err) } @@ -461,7 +476,7 @@ func (s *Server) StoreProviderToken(ctx context.Context, } // validate token - err = auth.ValidateProviderToken(ctx, provider.Class, in.AccessToken) + err = s.providerAuthManager.ValidateCredentials(ctx, provider.Class, in.AccessToken) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid token provided: %v", err) } diff --git a/internal/controlplane/handlers_oauth_test.go b/internal/controlplane/handlers_oauth_test.go index 9103693472..87c7fa7b73 100644 --- a/internal/controlplane/handlers_oauth_test.go +++ b/internal/controlplane/handlers_oauth_test.go @@ -46,54 +46,150 @@ import ( mockjwt "github.com/stacklok/minder/internal/auth/mock" serverconfig "github.com/stacklok/minder/internal/config/server" "github.com/stacklok/minder/internal/controlplane/metrics" + "github.com/stacklok/minder/internal/crypto" "github.com/stacklok/minder/internal/db" "github.com/stacklok/minder/internal/engine" "github.com/stacklok/minder/internal/events" "github.com/stacklok/minder/internal/providers" + "github.com/stacklok/minder/internal/providers/dockerhub" mockclients "github.com/stacklok/minder/internal/providers/github/clients/mock" + ghmanager "github.com/stacklok/minder/internal/providers/github/manager" mockgh "github.com/stacklok/minder/internal/providers/github/mock" ghService "github.com/stacklok/minder/internal/providers/github/service" mockprovsvc "github.com/stacklok/minder/internal/providers/github/service/mock" + "github.com/stacklok/minder/internal/providers/manager" + mockmanager "github.com/stacklok/minder/internal/providers/manager/mock" + "github.com/stacklok/minder/internal/providers/session" pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1" provinfv1 "github.com/stacklok/minder/pkg/providers/v1" ) -func TestNewOAuthConfig(t *testing.T) { +func Test_NewOAuthConfig(t *testing.T) { t.Parallel() - pc := &serverconfig.ProviderConfig{ - GitHub: &serverconfig.GitHubConfig{ - OAuthClientConfig: serverconfig.OAuthClientConfig{ - ClientID: "clientID", - ClientSecret: "clientSecret", - RedirectURI: "redirectURI", + const ( + githubRedirectURI = "http://github" + githubClientID = "ghClientID" + githubClientSecret = "ghClientSecret" + githubAppRedirectURI = "http://github-app" + githubAppClientID = "ghAppClientID" + githubAppClientSecret = "ghAppClientSecret" //nolint: gosec // This is a test file + ) + + scenarios := []struct { + name string + providerClass db.ProviderClass + cli bool + expected *oauth2.Config + err string + }{ + { + name: "github cli", + providerClass: db.ProviderClassGithub, + cli: true, + expected: &oauth2.Config{ + ClientID: githubClientID, + ClientSecret: githubClientSecret, + RedirectURL: githubRedirectURI + "/cli", + Endpoint: github.Endpoint, + Scopes: []string{"repo", "read:org", "workflow"}, }, }, + { + name: "github web", + providerClass: db.ProviderClassGithub, + cli: false, + expected: &oauth2.Config{ + ClientID: githubClientID, + ClientSecret: githubClientSecret, + RedirectURL: githubRedirectURI + "/web", + Endpoint: github.Endpoint, + Scopes: []string{"repo", "read:org", "workflow"}, + }, + }, + { + name: "github app cli", + providerClass: db.ProviderClassGithubApp, + cli: true, + expected: &oauth2.Config{ + ClientID: githubAppClientID, + ClientSecret: githubAppClientSecret, + RedirectURL: githubAppRedirectURI, + Endpoint: github.Endpoint, + Scopes: []string{}, + }, + }, + { + name: "github app web", + providerClass: db.ProviderClassGithubApp, + cli: true, + expected: &oauth2.Config{ + ClientID: githubAppClientID, + ClientSecret: githubAppClientSecret, + RedirectURL: githubAppRedirectURI, + Endpoint: github.Endpoint, + Scopes: []string{}, + }, + }, + { + name: "dockerhub fails as expected", + providerClass: db.ProviderClassDockerhub, + cli: true, + err: "class manager does not implement OAuthManager", + }, } - // Test with CLI set - cfg, err := auth.NewOAuthConfig(pc, "github", true) - if err != nil { - t.Errorf("Error in newOAuthConfig: %v", err) - } - - if cfg.Endpoint != github.Endpoint { - t.Errorf("Unexpected endpoint: %v", cfg.Endpoint) - } + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + t.Parallel() - // Test with CLI set - cfg, err = auth.NewOAuthConfig(pc, "github", false) - if err != nil { - t.Errorf("Error in newOAuthConfig: %v", err) - } + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) - if cfg.Endpoint != github.Endpoint { - t.Errorf("Unexpected endpoint: %v", cfg.Endpoint) - } + githubProviderManager := ghmanager.NewGitHubProviderClassManager( + nil, + nil, + &serverconfig.ProviderConfig{ + GitHub: &serverconfig.GitHubConfig{ + OAuthClientConfig: serverconfig.OAuthClientConfig{ + ClientID: githubClientID, + ClientSecret: githubClientSecret, + RedirectURI: githubRedirectURI, + }, + }, + GitHubApp: &serverconfig.GitHubAppConfig{ + OAuthClientConfig: serverconfig.OAuthClientConfig{ + ClientID: githubAppClientID, + ClientSecret: githubAppClientSecret, + RedirectURI: githubAppRedirectURI, + }, + }, + }, + nil, + nil, + nil, + nil, + nil, + ) + dockerhubProviderManager := dockerhub.NewDockerHubProviderClassManager(nil, nil) + + providerAuthManager, err := manager.NewAuthManager(githubProviderManager, dockerhubProviderManager) + require.NoError(t, err) - _, err = auth.NewOAuthConfig(pc, "invalid", true) - if err == nil { - t.Errorf("Expected error in newOAuthConfig, but got nil") + config, err := providerAuthManager.NewOAuthConfig(scenario.providerClass, scenario.cli) + if scenario.err == "" { + require.NoError(t, err) + require.NotNil(t, config) + require.Equal(t, scenario.expected.ClientID, config.ClientID) + require.Equal(t, scenario.expected.ClientSecret, config.ClientSecret) + require.Equal(t, scenario.expected.RedirectURL, config.RedirectURL) + require.Equal(t, scenario.expected.Endpoint, config.Endpoint) + require.Subsetf(t, config.Scopes, scenario.expected.Scopes, "expected: %v, got: %v", scenario.expected.Scopes, config.Scopes) + } else { + require.Error(t, err) + require.ErrorContains(t, err, scenario.err) + } + }) } } @@ -317,6 +413,8 @@ func TestGetAuthorizationURL(t *testing.T) { }, } mockJwt := mockjwt.NewMockJwtValidator(ctrl) + mockAuthManager := mockmanager.NewMockAuthManager(ctrl) + mockAuthManager.EXPECT().NewOAuthConfig(gomock.Any(), gomock.Any()).Return(&oauth2.Config{}, nil).AnyTimes() server := &Server{ store: store, @@ -324,7 +422,7 @@ func TestGetAuthorizationURL(t *testing.T) { evt: evt, cfg: c, mt: metrics.NewNoopMetrics(), - providerAuthFactory: auth.NewOAuthConfig, + providerAuthManager: mockAuthManager, } res, err := server.GetAuthorizationURL(ctx, tc.req) @@ -339,35 +437,67 @@ func TestProviderCallback(t *testing.T) { projectID := uuid.New() code := "0xefbeadde" + withProviderSearch := func(store *mockdb.MockStore) { + store.EXPECT().GetParentProjects(gomock.Any(), projectID).Return([]uuid.UUID{projectID}, nil) + store.EXPECT().FindProviders(gomock.Any(), gomock.Any()). + Return([]db.Provider{ + { + Name: "github", + Implements: []db.ProviderType{db.ProviderTypeGithub}, + Version: provinfv1.V1, + }, + }, nil) + } + + withProviderCreate := func(store *mockdb.MockStore) { + store.EXPECT().GetParentProjects(gomock.Any(), projectID).Return([]uuid.UUID{projectID}, nil) + store.EXPECT().FindProviders(gomock.Any(), gomock.Any()). + Return([]db.Provider{}, nil) + store.EXPECT().CreateProvider(gomock.Any(), gomock.Any()).Return(db.Provider{}, nil) + } + testCases := []struct { - name string - redirectUrl string - remoteUser sql.NullString - code int - existingProvider bool - err string + name string + redirectUrl string + remoteUser sql.NullString + code int + storeMockSetup func(store *mockdb.MockStore) + projectIDBySessionNumCalls int + err string }{{ - name: "Success", - redirectUrl: "http://localhost:8080", - existingProvider: true, - code: 307, + name: "Success", + redirectUrl: "http://localhost:8080", + projectIDBySessionNumCalls: 2, + storeMockSetup: func(store *mockdb.MockStore) { + withProviderSearch(store) + }, + code: 307, }, { - name: "Success with remote user", - redirectUrl: "http://localhost:8080", - remoteUser: sql.NullString{Valid: true, String: "31337"}, - existingProvider: true, - code: 307, + name: "Success with remote user", + redirectUrl: "http://localhost:8080", + remoteUser: sql.NullString{Valid: true, String: "31337"}, + projectIDBySessionNumCalls: 2, + storeMockSetup: func(store *mockdb.MockStore) { + withProviderSearch(store) + }, + code: 307, }, { - name: "Wrong remote userid", - remoteUser: sql.NullString{Valid: true, String: "1234"}, - existingProvider: true, - code: 403, - err: "The provided login token was associated with a different GitHub user.\n", + name: "Wrong remote userid", + remoteUser: sql.NullString{Valid: true, String: "1234"}, + projectIDBySessionNumCalls: 1, + storeMockSetup: func(_ *mockdb.MockStore) { + // this codepath fails before the store is called + }, + code: 403, + err: "The provided login token was associated with a different user.\n", }, { - name: "No existing provider", - redirectUrl: "http://localhost:8080", - existingProvider: false, - code: 307, + name: "No existing provider", + redirectUrl: "http://localhost:8080", + code: 307, + projectIDBySessionNumCalls: 2, + storeMockSetup: func(store *mockdb.MockStore) { + withProviderCreate(store) + }, }} for _, tt := range testCases { @@ -420,10 +550,73 @@ func TestProviderCallback(t *testing.T) { BuildOAuthClient(gomock.Any(), gomock.Any(), gomock.Any()). Return(nil, delegate, nil) } + tc.storeMockSetup(store) s, _ := newDefaultServer(t, store, clientFactory) + s.cfg.Provider = serverconfig.ProviderConfig{ + GitHub: &serverconfig.GitHubConfig{ + OAuthClientConfig: serverconfig.OAuthClientConfig{ + Endpoint: &serverconfig.OAuthEndpoint{ + TokenURL: oauthServer.URL, + }, + }, + }, + } + + tokenKeyPath := generateTokenKey(t) + eng, err := crypto.NewEngineFromConfig(&serverconfig.Config{ + Auth: serverconfig.AuthConfig{ + TokenKey: tokenKeyPath, + }, + }) + require.NoError(t, err) + + ghClientService := ghService.NewGithubProviderService( + store, + eng, + metrics.NewNoopMetrics(), + // These nil dependencies do not matter for the current tests + &serverconfig.ProviderConfig{ + GitHubApp: &serverconfig.GitHubAppConfig{ + WebhookSecret: "test", + }, + }, + nil, + clientFactory, + ) + + githubProviderManager := ghmanager.NewGitHubProviderClassManager( + nil, + nil, + &serverconfig.ProviderConfig{ + GitHub: &serverconfig.GitHubConfig{ + OAuthClientConfig: serverconfig.OAuthClientConfig{ + Endpoint: &serverconfig.OAuthEndpoint{ + TokenURL: oauthServer.URL, + }, + }, + }, + }, + nil, + nil, + nil, + nil, + ghClientService, + ) + dockerhubProviderManager := dockerhub.NewDockerHubProviderClassManager(nil, nil) + + authManager, err := manager.NewAuthManager(githubProviderManager, dockerhubProviderManager) + require.NoError(t, err) + s.providerAuthManager = authManager + + providerStore := providers.NewProviderStore(store) + providerManager, err := manager.NewProviderManager(providerStore, githubProviderManager, dockerhubProviderManager) + require.NoError(t, err) + s.providerManager = providerManager + + sessionService := session.NewProviderSessionService(providerManager, providerStore, store) + s.sessionService = sessionService - var err error encryptedUrlString, err := s.cryptoEngine.EncryptString(tc.redirectUrl) if err != nil { t.Fatalf("Failed to encrypt redirect URL: %v", err) @@ -435,12 +628,6 @@ func TestProviderCallback(t *testing.T) { serialized, err := encryptedUrlString.Serialize() require.NoError(t, err) - tx := sql.Tx{} - store.EXPECT().BeginTransaction().Return(&tx, nil) - store.EXPECT().GetQuerierWithTransaction(gomock.Any()).Return(store) - - gh.EXPECT().GetUserId(gomock.Any()).Return(int64(31337), nil).AnyTimes() - store.EXPECT().GetProjectIDBySessionState(gomock.Any(), state).Return( db.GetProjectIDBySessionStateRow{ ProjectID: projectID, @@ -450,36 +637,15 @@ func TestProviderCallback(t *testing.T) { Valid: true, }, RemoteUser: tc.remoteUser, - }, nil) - - if tc.existingProvider { - store.EXPECT().GetProviderByName(gomock.Any(), gomock.Any()).Return( - db.Provider{ - Name: "github", - Implements: []db.ProviderType{db.ProviderTypeGithub}, - Version: provinfv1.V1, - }, nil) - } else { - store.EXPECT().GetProviderByName(gomock.Any(), gomock.Any()).Return( - db.Provider{}, sql.ErrNoRows) - store.EXPECT().CreateProvider(gomock.Any(), gomock.Any()).Return(db.Provider{}, nil) - } + }, nil).Times(tc.projectIDBySessionNumCalls) if tc.code < http.StatusBadRequest { store.EXPECT().UpsertAccessToken(gomock.Any(), gomock.Any()).Return( db.ProviderAccessToken{}, nil) - store.EXPECT().Commit(gomock.Any()) } - store.EXPECT().Rollback(gomock.Any()) t.Logf("Request: %+v", req.URL) - s.providerAuthFactory = func(_ *serverconfig.ProviderConfig, _ string, _ bool) (*oauth2.Config, error) { - return &oauth2.Config{ - Endpoint: oauth2.Endpoint{ - TokenURL: oauthServer.URL, - }, - }, nil - } + s.HandleOAuthCallback()(&resp, &req, params) t.Logf("Response: %v", resp.Code) @@ -653,28 +819,38 @@ func TestHandleGitHubAppCallback(t *testing.T) { } })) defer oauthServer.Close() - providerAuthFactory := func(_ *serverconfig.ProviderConfig, _ string, _ bool) (*oauth2.Config, error) { - return &oauth2.Config{ - Endpoint: oauth2.Endpoint{ - TokenURL: oauthServer.URL, - }, - }, nil - } providerService := mockprovsvc.NewMockGitHubProviderService(ctrl) store := mockdb.NewMockStore(ctrl) gh := mockgh.NewMockClientService(ctrl) + mockAuthManager := mockmanager.NewMockAuthManager(ctrl) + mockAuthManager.EXPECT().NewOAuthConfig(gomock.Any(), gomock.Any()).Return(&oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: oauthServer.URL, + }, + }, nil).AnyTimes() + tc.buildStubs(store, providerService, gh) s := &Server{ store: store, ghProviders: providerService, evt: evt, - providerAuthFactory: providerAuthFactory, + providerAuthManager: mockAuthManager, ghClient: gh, cfg: &serverconfig.Config{ Auth: serverconfig.AuthConfig{}, + Provider: serverconfig.ProviderConfig{ + GitHub: &serverconfig.GitHubConfig{ + OAuthClientConfig: serverconfig.OAuthClientConfig{ + ClientID: "clientID", + Endpoint: &serverconfig.OAuthEndpoint{ + TokenURL: oauthServer.URL, + }, + }, + }, + }, }, } @@ -799,25 +975,25 @@ func TestVerifyProviderCredential(t *testing.T) { } })) defer oauthServer.Close() - providerAuthFactory := func(_ *serverconfig.ProviderConfig, _ string, _ bool) (*oauth2.Config, error) { - return &oauth2.Config{ - Endpoint: oauth2.Endpoint{ - TokenURL: oauthServer.URL, - }, - }, nil - } - store := mockdb.NewMockStore(ctrl) tc.buildStubs(store) s := &Server{ - store: store, - evt: evt, - providerAuthFactory: providerAuthFactory, - providerStore: providers.NewProviderStore(store), + store: store, + evt: evt, + providerStore: providers.NewProviderStore(store), cfg: &serverconfig.Config{ Auth: serverconfig.AuthConfig{}, + Provider: serverconfig.ProviderConfig{ + GitHub: &serverconfig.GitHubConfig{ + OAuthClientConfig: serverconfig.OAuthClientConfig{ + Endpoint: &serverconfig.OAuthEndpoint{ + TokenURL: oauthServer.URL, + }, + }, + }, + }, }, } diff --git a/internal/controlplane/handlers_providers_test.go b/internal/controlplane/handlers_providers_test.go index 98f862b029..96c9bd2f14 100644 --- a/internal/controlplane/handlers_providers_test.go +++ b/internal/controlplane/handlers_providers_test.go @@ -65,8 +65,9 @@ func newPbStruct(t *testing.T, data map[string]interface{}) *structpb.Struct { } type mockServer struct { - server *Server - mockStore *mockdb.MockStore + server *Server + mockStore *mockdb.MockStore + mockGhService *mockprovsvc.MockGitHubProviderService } func testServer(t *testing.T, ctrl *gomock.Controller) *mockServer { @@ -106,12 +107,14 @@ func testServer(t *testing.T, ctrl *gomock.Controller) *mockServer { cryptoEngine: mockCryptoEngine, store: mockStore, providerManager: providerManager, + ghProviders: mockProvidersSvc, cfg: &serverconfig.Config{}, } return &mockServer{ - server: &server, - mockStore: mockStore, + server: &server, + mockStore: mockStore, + mockGhService: mockProvidersSvc, } } @@ -232,6 +235,17 @@ func TestCreateProvider(t *testing.T) { jsonConfig, err := scenario.expected.Config.MarshalJSON() require.NoError(t, err) + var jsonUserConfig []byte + if scenario.userConfig != nil { + jsonUserConfig, err = scenario.userConfig.MarshalJSON() + require.NoError(t, err) + } + + if scenario.providerClass == db.ProviderClassGithubApp || scenario.providerClass == db.ProviderClassGithub { + fakeServer.mockGhService.EXPECT().GetConfig(gomock.Any(), scenario.providerClass, jsonUserConfig). + Return(jsonConfig, nil) + } + fakeServer.mockStore.EXPECT().CreateProvider(gomock.Any(), partialCreateParamsMatcher{ value: db.CreateProviderParams{ Name: scenario.name, @@ -341,6 +355,9 @@ func TestCreateProviderFailures(t *testing.T) { Provider: engine.Provider{Name: providerName}, }) + fakeServer.mockGhService.EXPECT().GetConfig(gomock.Any(), db.ProviderClassGithub, gomock.Any()). + Return(json.RawMessage(`{ "github": {} }`), nil) + fakeServer.mockStore.EXPECT().CreateProvider(gomock.Any(), gomock.Any()). Return(db.Provider{}, &pq.Error{Code: "23505"}) // unique_violation @@ -405,6 +422,9 @@ func TestCreateProviderFailures(t *testing.T) { fakeServer := testServer(t, ctrl) providerName := "bad-github-app" + fakeServer.mockGhService.EXPECT().GetConfig(gomock.Any(), db.ProviderClassGithubApp, gomock.Any()). + Return(json.RawMessage(`{ "auto_registration": { "entities": { "blah": {"enabled": true }}}, "github-app": {}}`), nil) + _, err := fakeServer.server.CreateProvider(context.Background(), &minder.CreateProviderRequest{ Context: &minder.Context{ Project: &projectIDStr, diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 511c23bbc8..39e4c69af1 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -40,7 +40,6 @@ import ( "go.opentelemetry.io/otel/sdk/resource" sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.17.0" - "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/reflection" @@ -60,6 +59,7 @@ import ( ghprov "github.com/stacklok/minder/internal/providers/github" "github.com/stacklok/minder/internal/providers/github/service" "github.com/stacklok/minder/internal/providers/manager" + "github.com/stacklok/minder/internal/providers/session" "github.com/stacklok/minder/internal/repositories/github" "github.com/stacklok/minder/internal/ruletypes" "github.com/stacklok/minder/internal/util" @@ -78,29 +78,30 @@ var ( // Server represents the controlplane server type Server struct { - store db.Store - cfg *serverconfig.Config - evt events.Publisher - mt metrics.Metrics - grpcServer *grpc.Server - jwt auth.JwtValidator - providerAuthFactory func(*serverconfig.ProviderConfig, string, bool) (*oauth2.Config, error) - authzClient authz.Client - idClient auth.Resolver - cryptoEngine crypto.Engine - featureFlags openfeature.IClient + store db.Store + cfg *serverconfig.Config + evt events.Publisher + mt metrics.Metrics + grpcServer *grpc.Server + jwt auth.JwtValidator + authzClient authz.Client + idClient auth.Resolver + cryptoEngine crypto.Engine + featureFlags openfeature.IClient // We may want to start breaking up the server struct if we use it to // inject more entity-specific interfaces. For example, we may want to // consider having a struct per grpc service - ruleTypes ruletypes.RuleTypeService - repos github.RepositoryService - profiles profiles.ProfileService - ghProviders service.GitHubProviderService - providerStore providers.ProviderStore - ghClient ghprov.ClientService - providerManager manager.ProviderManager - projectCreator projects.ProjectCreator - projectDeleter projects.ProjectDeleter + ruleTypes ruletypes.RuleTypeService + repos github.RepositoryService + profiles profiles.ProfileService + ghProviders service.GitHubProviderService + providerStore providers.ProviderStore + ghClient ghprov.ClientService + providerManager manager.ProviderManager + sessionService session.ProviderSessionService + providerAuthManager manager.AuthManager + projectCreator projects.ProjectCreator + projectDeleter projects.ProjectDeleter // Implementations for service registration pb.UnimplementedHealthServiceServer @@ -130,7 +131,9 @@ func NewServer( ruleService ruletypes.RuleTypeService, ghProviders service.GitHubProviderService, providerManager manager.ProviderManager, + providerAuthManager manager.AuthManager, providerStore providers.ProviderStore, + sessionService session.ProviderSessionService, projectDeleter projects.ProjectDeleter, projectCreator projects.ProjectCreator, ) *Server { @@ -140,7 +143,6 @@ func NewServer( evt: evt, cryptoEngine: cryptoEngine, jwt: jwt, - providerAuthFactory: auth.NewOAuthConfig, mt: serverMetrics, profiles: profileService, ruleTypes: ruleService, @@ -148,6 +150,8 @@ func NewServer( featureFlags: openfeature.NewClient(cfg.Flags.AppName), ghClient: &ghprov.ClientServiceImplementation{}, providerManager: providerManager, + providerAuthManager: providerAuthManager, + sessionService: sessionService, repos: repoService, ghProviders: ghProviders, authzClient: authzClient, diff --git a/internal/providers/github/manager/manager.go b/internal/providers/github/manager/manager.go index 12f19a152e..d5c1e078e2 100644 --- a/internal/providers/github/manager/manager.go +++ b/internal/providers/github/manager/manager.go @@ -26,6 +26,8 @@ import ( gogithub "github.com/google/go-github/v61/github" "github.com/google/uuid" "github.com/rs/zerolog" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" "github.com/stacklok/minder/internal/config/server" "github.com/stacklok/minder/internal/crypto" @@ -297,28 +299,13 @@ type credentialDetails struct { } func (g *githubProviderManager) GetConfig( - _ context.Context, class db.ProviderClass, userConfig json.RawMessage, + ctx context.Context, class db.ProviderClass, userConfig json.RawMessage, ) (json.RawMessage, error) { if !slices.Contains(g.GetSupportedClasses(), class) { return nil, fmt.Errorf("provider does not implement %s", string(class)) } - var defaultConfig string - // nolint:exhaustive // we really want handle only the two - switch class { - case db.ProviderClassGithub: - defaultConfig = `{"github": {}}` - case db.ProviderClassGithubApp: - defaultConfig = `{"github-app": {}}` - default: - return nil, fmt.Errorf("unsupported provider class %s", class) - } - - if len(userConfig) == 0 { - return json.RawMessage(defaultConfig), nil - } - - return userConfig, nil + return g.ghService.GetConfig(ctx, class, userConfig) } func (g *githubProviderManager) ValidateConfig( @@ -342,3 +329,94 @@ func (g *githubProviderManager) ValidateConfig( return err } + +func (g *githubProviderManager) NewOAuthConfig(providerClass db.ProviderClass, cli bool) (*oauth2.Config, error) { + var oauthConfig *oauth2.Config + var oauthClientConfig *server.OAuthClientConfig + var err error + + switch providerClass { // nolint:exhaustive // we really want handle only the two + case db.ProviderClassGithub: + oauthClientConfig = &g.config.GitHub.OAuthClientConfig + oauthConfig = githubOauthConfig(oauthClientConfig.RedirectURI, cli) + case db.ProviderClassGithubApp: + oauthClientConfig = &g.config.GitHubApp.OAuthClientConfig + oauthConfig = githubAppOauthConfig(oauthClientConfig.RedirectURI) + default: + err = fmt.Errorf("invalid provider class: %s", providerClass) + } + + if err != nil { + return nil, err + } + + clientId, err := oauthClientConfig.GetClientID() + if err != nil { + return nil, fmt.Errorf("failed to get client ID: %w", err) + } + + clientSecret, err := oauthClientConfig.GetClientSecret() + if err != nil { + return nil, fmt.Errorf("failed to get client secret: %w", err) + } + + // this is currently only used for testing as github uses well-known endpoints + if oauthClientConfig.Endpoint != nil { + oauthConfig.Endpoint = oauth2.Endpoint{ + TokenURL: oauthClientConfig.Endpoint.TokenURL, + } + } + + oauthConfig.ClientID = clientId + oauthConfig.ClientSecret = clientSecret + return oauthConfig, nil +} + +func githubOauthConfig(redirectUrlBase string, cli bool) *oauth2.Config { + var redirectUrl string + + if cli { + redirectUrl = fmt.Sprintf("%s/cli", redirectUrlBase) + } else { + redirectUrl = fmt.Sprintf("%s/web", redirectUrlBase) + } + + return &oauth2.Config{ + RedirectURL: redirectUrl, + Scopes: []string{"user:email", "repo", "read:packages", "write:packages", "workflow", "read:org"}, + Endpoint: github.Endpoint, + } +} + +func githubAppOauthConfig(redirectUrlBase string) *oauth2.Config { + return &oauth2.Config{ + RedirectURL: redirectUrlBase, + Scopes: []string{}, + Endpoint: github.Endpoint, + } +} + +func (g *githubProviderManager) ValidateCredentials( + ctx context.Context, cred v1.Credential, params *m.CredentialVerifyParams, +) error { + tokenCred, ok := cred.(v1.OAuth2TokenCredential) + if !ok { + return fmt.Errorf("invalid credential type: %T", cred) + } + + token, err := tokenCred.GetAsOAuth2TokenSource().Token() + if err != nil { + return fmt.Errorf("cannot get token from credential: %w", err) + } + + if params.RemoteUser != "" { + err := g.ghService.VerifyProviderTokenIdentity(ctx, params.RemoteUser, token.AccessToken) + if err != nil { + return fmt.Errorf("error verifying token identity: %w", err) + } + } else { + zerolog.Ctx(ctx).Warn().Msg("RemoteUser not found in session state") + } + + return nil +} diff --git a/internal/providers/github/service/mock/service.go b/internal/providers/github/service/mock/service.go index 5726971a9f..1623223ef8 100644 --- a/internal/providers/github/service/mock/service.go +++ b/internal/providers/github/service/mock/service.go @@ -11,6 +11,7 @@ package mock_service import ( context "context" + json "encoding/json" http "net/http" reflect "reflect" @@ -116,6 +117,21 @@ func (mr *MockGitHubProviderServiceMockRecorder) DeleteInstallation(ctx, provide return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteInstallation", reflect.TypeOf((*MockGitHubProviderService)(nil).DeleteInstallation), ctx, providerID) } +// GetConfig mocks base method. +func (m *MockGitHubProviderService) GetConfig(ctx context.Context, class db.ProviderClass, userConfig json.RawMessage) (json.RawMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConfig", ctx, class, userConfig) + ret0, _ := ret[0].(json.RawMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConfig indicates an expected call of GetConfig. +func (mr *MockGitHubProviderServiceMockRecorder) GetConfig(ctx, class, userConfig any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockGitHubProviderService)(nil).GetConfig), ctx, class, userConfig) +} + // ValidateGitHubAppWebhookPayload mocks base method. func (m *MockGitHubProviderService) ValidateGitHubAppWebhookPayload(r *http.Request) ([]byte, error) { m.ctrl.T.Helper() diff --git a/internal/providers/github/service/service.go b/internal/providers/github/service/service.go index 1e90c2c95d..9d08e2d707 100644 --- a/internal/providers/github/service/service.go +++ b/internal/providers/github/service/service.go @@ -69,6 +69,8 @@ type GitHubProviderService interface { // DeleteInstallation deletes the installation from GitHub, if the provider has an associated installation DeleteInstallation(ctx context.Context, providerID uuid.UUID) error VerifyProviderTokenIdentity(ctx context.Context, remoteUser string, accessToken string) error + // GetConfig returns the provider configuration + GetConfig(ctx context.Context, class db.ProviderClass, userConfig json.RawMessage) (json.RawMessage, error) } // TypeGitHubOrganization is the type returned from the GitHub API when the owner is an organization @@ -255,12 +257,18 @@ func (p *ghProviderService) CreateGitHubAppProvider( return nil } + finalConfig, err := p.GetConfig(ctx, db.ProviderClassGithubApp, stateData.ProviderConfig) + if err != nil { + return nil, fmt.Errorf("error getting provider config: %w", err) + } + provider, err := createGitHubApp( ctx, qtx, stateData.ProjectID, installationOwner, installationID, + finalConfig, validateOwnership, sql.NullString{ String: state, @@ -317,7 +325,8 @@ func (p *ghProviderService) CreateGitHubAppWithoutInvitation( zerolog.Ctx(ctx).Info().Str("project", project.ID.String()).Int64("owner", installationOwner.GetID()). Msg("Creating GitHub App Provider") - _, err = createGitHubApp(ctx, qtx, project.ID, installationOwner, installationID, nil, sql.NullString{}) + _, err = createGitHubApp( + ctx, qtx, project.ID, installationOwner, installationID, json.RawMessage(`{"github-app": {}}`), nil, sql.NullString{}) if err != nil { return nil, fmt.Errorf("error creating GitHub App Provider: %w", err) @@ -334,6 +343,7 @@ func createGitHubApp( projectId uuid.UUID, installationOwner *github.User, installationID int64, + providerConfig json.RawMessage, validateOwnership func(ctx context.Context) error, nonce sql.NullString, ) (db.Provider, error) { @@ -355,7 +365,7 @@ func createGitHubApp( ProjectID: projectId, Class: class, Implements: providerDef.Traits, - Definition: json.RawMessage(`{"github-app": {}}`), + Definition: providerConfig, AuthFlows: providerDef.AuthorizationFlows, }) if err != nil { @@ -527,3 +537,23 @@ func (p *ghProviderService) getInstallationOwner(ctx context.Context, installati } return installation.GetAccount(), nil } + +func (_ *ghProviderService) GetConfig( + _ context.Context, class db.ProviderClass, userConfig json.RawMessage, +) (json.RawMessage, error) { + var defaultConfig string + // nolint:exhaustive // we really want handle only the two + switch class { + case db.ProviderClassGithub: + defaultConfig = `{"github": {}}` + case db.ProviderClassGithubApp: + defaultConfig = `{"github-app": {}}` + default: + return nil, fmt.Errorf("unsupported provider class %s", class) + } + if len(userConfig) == 0 { + return json.RawMessage(defaultConfig), nil + } + + return userConfig, nil +} diff --git a/internal/providers/manager/auth_manager.go b/internal/providers/manager/auth_manager.go new file mode 100644 index 0000000000..743523a34f --- /dev/null +++ b/internal/providers/manager/auth_manager.go @@ -0,0 +1,113 @@ +// +// Copyright 2023 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:generate go run go.uber.org/mock/mockgen -package mock_$GOPACKAGE -destination=./mock/$GOFILE -source=./$GOFILE + +package manager + +import ( + "context" + "fmt" + + "golang.org/x/oauth2" + + "github.com/stacklok/minder/internal/db" + v1 "github.com/stacklok/minder/pkg/providers/v1" +) + +// CredentialVerifyParams are the currently supported parameters for credential verification +type CredentialVerifyParams struct { + RemoteUser string +} + +// CredentialVerifyOptFn is a function that sets options for credential verification +type CredentialVerifyOptFn func(*CredentialVerifyParams) + +// WithRemoteUser sets the remote user for the credential verification +func WithRemoteUser(remoteUser string) CredentialVerifyOptFn { + return func(params *CredentialVerifyParams) { + params.RemoteUser = remoteUser + } +} + +// AuthManager is the interface for managing authentication with provider classes +type AuthManager interface { + NewOAuthConfig(providerClass db.ProviderClass, cli bool) (*oauth2.Config, error) + ValidateCredentials(ctx context.Context, providerClass db.ProviderClass, cred v1.Credential, opts ...CredentialVerifyOptFn) error +} + +type providerClassAuthManager interface { +} + +type providerClassOAuthManager interface { + ProviderClassManager + + NewOAuthConfig(providerClass db.ProviderClass, cli bool) (*oauth2.Config, error) + ValidateCredentials(ctx context.Context, cred v1.Credential, params *CredentialVerifyParams) error +} + +type authManager struct { + classTracker +} + +// NewAuthManager creates a new AuthManager for managing authentication with providers classes +func NewAuthManager( + classManagers ...ProviderClassManager, +) (AuthManager, error) { + classes, err := newClassTracker(classManagers...) + if err != nil { + return nil, fmt.Errorf("error creating class tracker: %w", err) + } + + return &authManager{ + classTracker: *classes, + }, nil +} + +func (a *authManager) NewOAuthConfig(providerClass db.ProviderClass, cli bool) (*oauth2.Config, error) { + manager, err := a.getClassManager(providerClass) + if err != nil { + return nil, fmt.Errorf("error getting class manager: %w", err) + } + + oauthManager, ok := manager.(providerClassOAuthManager) + if !ok { + return nil, fmt.Errorf("class manager does not implement OAuthManager") + } + + return oauthManager.NewOAuthConfig(providerClass, cli) +} + +func (a *authManager) ValidateCredentials( + ctx context.Context, providerClass db.ProviderClass, cred v1.Credential, opts ...CredentialVerifyOptFn, +) error { + manager, err := a.getClassManager(providerClass) + if err != nil { + return fmt.Errorf("error getting class manager: %w", err) + } + + oauthManager, ok := manager.(providerClassOAuthManager) + if !ok { + return fmt.Errorf("class manager does not implement OAuthManager") + } + + var verifyParams CredentialVerifyParams + + for _, opt := range opts { + opt(&verifyParams) + } + + return oauthManager.ValidateCredentials(ctx, cred, &verifyParams) +} diff --git a/internal/providers/manager/auth_manager_test.go b/internal/providers/manager/auth_manager_test.go new file mode 100644 index 0000000000..195654e2dc --- /dev/null +++ b/internal/providers/manager/auth_manager_test.go @@ -0,0 +1,207 @@ +// Copyright 2024 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package manager_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" + + "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/providers/credentials" + "github.com/stacklok/minder/internal/providers/manager" + mockmanager "github.com/stacklok/minder/internal/providers/manager/mock" +) + +func TestAuthManager_NewAuthManager(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + scenarios := []struct { + name string + oauthMan *mockmanager.MockproviderClassOAuthManager + providerMan *mockmanager.MockProviderClassManager + setupMocks setupMockCalls + expectedErr string + }{ + { + name: "happy path", + setupMocks: func(ghClassManager *mockmanager.MockproviderClassOAuthManager, dhClassManager *mockmanager.MockProviderClassManager) { + ghClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{db.ProviderClassGithub}).MaxTimes(1) + dhClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{db.ProviderClassDockerhub}).MaxTimes(1) + }, + }, + { + name: "implementing the same classes", + setupMocks: func(ghClassManager *mockmanager.MockproviderClassOAuthManager, dhClassManager *mockmanager.MockProviderClassManager) { + ghClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{db.ProviderClassGithub}).MaxTimes(1) + dhClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{db.ProviderClassGithub}).MaxTimes(1) + }, + expectedErr: "more than once", + }, + { + name: "no registered classes", + setupMocks: func(ghClassManager *mockmanager.MockproviderClassOAuthManager, _ *mockmanager.MockProviderClassManager) { + ghClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{}).MaxTimes(1) + }, + expectedErr: "no registered classes", + }, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + t.Parallel() + + scenario.oauthMan = mockmanager.NewMockproviderClassOAuthManager(ctrl) + scenario.providerMan = mockmanager.NewMockProviderClassManager(ctrl) + scenario.setupMocks(scenario.oauthMan, scenario.providerMan) + + authManager, err := manager.NewAuthManager(scenario.oauthMan, scenario.providerMan) + if scenario.expectedErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, scenario.expectedErr) + require.Nil(t, authManager) + } else { + require.NoError(t, err) + require.NotNil(t, authManager) + } + }) + } + + dhClassManager := mockmanager.NewMockProviderClassManager(ctrl) + require.NotNil(t, dhClassManager) +} + +func newMockAuthManager(t *testing.T, ctrl *gomock.Controller) (manager.AuthManager, *mockmanager.MockproviderClassOAuthManager, *mockmanager.MockProviderClassManager) { + t.Helper() + + ghClassManager := mockmanager.NewMockproviderClassOAuthManager(ctrl) + require.NotNil(t, ghClassManager) + ghClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{db.ProviderClassGithub}).MaxTimes(1) + + dhClassManager := mockmanager.NewMockProviderClassManager(ctrl) + require.NotNil(t, dhClassManager) + dhClassManager.EXPECT().GetSupportedClasses().Return([]db.ProviderClass{db.ProviderClassDockerhub}).MaxTimes(1) + + authManager, err := manager.NewAuthManager(ghClassManager, dhClassManager) + require.NoError(t, err) + require.NotNil(t, authManager) + + return authManager, ghClassManager, dhClassManager +} + +type setupMockCalls func(*mockmanager.MockproviderClassOAuthManager, *mockmanager.MockProviderClassManager) + +func TestAuthManager_NewOAuthConfig_Validate_ClassManagerProperties(t *testing.T) { + t.Parallel() + + scenarios := []struct { + name string + providerClass db.ProviderClass + setupMocks setupMockCalls + expectedErr string + }{ + { + name: "github implements OAuthManager", + providerClass: db.ProviderClassGithub, + setupMocks: func(ghClassManager *mockmanager.MockproviderClassOAuthManager, _ *mockmanager.MockProviderClassManager) { + ghClassManager.EXPECT().NewOAuthConfig(db.ProviderClassGithub, false). + Return(&oauth2.Config{ + Endpoint: github.Endpoint, + }, nil). + MaxTimes(1) + ghClassManager.EXPECT().ValidateCredentials(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil). + MaxTimes(1) + }, + }, + { + name: "dockerhub does not implement OAuthManager", + providerClass: db.ProviderClassDockerhub, + setupMocks: func(_ *mockmanager.MockproviderClassOAuthManager, _ *mockmanager.MockProviderClassManager) { + }, + expectedErr: "class manager does not implement OAuthManager", + }, + { + name: "ghcr is not registered", + providerClass: db.ProviderClassGhcr, + setupMocks: func(_ *mockmanager.MockproviderClassOAuthManager, _ *mockmanager.MockProviderClassManager) { + }, + expectedErr: "error getting class manager", + }, + } + + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + authManager, ghClassManager, dhClassManager := newMockAuthManager(t, ctrl) + scenario.setupMocks(ghClassManager, dhClassManager) + + config, err := authManager.NewOAuthConfig(scenario.providerClass, false) + if scenario.expectedErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, scenario.expectedErr) + require.Nil(t, config) + } else { + require.NoError(t, err) + require.NotNil(t, config) + } + + err = authManager.ValidateCredentials(context.Background(), scenario.providerClass, credentials.NewOAuth2TokenCredential("token")) + if scenario.expectedErr != "" { + require.Error(t, err) + require.ErrorContains(t, err, scenario.expectedErr) + require.Nil(t, config) + } else { + require.NoError(t, err) + require.NotNil(t, config) + } + }) + + } +} + +func TestAuthManager_ValidateCredentials(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + authManager, ghClassManager, _ := newMockAuthManager(t, ctrl) + + ghClassManager.EXPECT().ValidateCredentials( + gomock.Any(), + credentials.NewGitHubTokenCredential("ghtoken"), + &manager.CredentialVerifyParams{ + RemoteUser: "remoteuser", + }) + + err := authManager.ValidateCredentials(context.Background(), + db.ProviderClassGithub, + credentials.NewGitHubTokenCredential("ghtoken"), + manager.WithRemoteUser("remoteuser"), + ) + require.NoError(t, err) +} diff --git a/internal/providers/manager/manager.go b/internal/providers/manager/manager.go index d6a0701125..272818f56e 100644 --- a/internal/providers/manager/manager.go +++ b/internal/providers/manager/manager.go @@ -68,8 +68,10 @@ type ProviderManager interface { // specific Provider class. The idea is that ProviderManager determines the // class of the Provider, and delegates to the appropraite ProviderClassManager type ProviderClassManager interface { - ValidateConfig(ctx context.Context, class db.ProviderClass, config json.RawMessage) error + providerClassAuthManager + GetConfig(ctx context.Context, class db.ProviderClass, userConfig json.RawMessage) (json.RawMessage, error) + ValidateConfig(ctx context.Context, class db.ProviderClass, config json.RawMessage) error // Build creates an instance of Provider based on the config in the DB Build(ctx context.Context, config *db.Provider) (v1.Provider, error) // Delete deletes an instance of this provider @@ -79,16 +81,21 @@ type ProviderClassManager interface { GetSupportedClasses() []db.ProviderClass } -type providerManager struct { +type classTracker struct { classManagers map[db.ProviderClass]ProviderClassManager - store providers.ProviderStore } -// NewProviderManager creates a new instance of ProviderManager -func NewProviderManager( - store providers.ProviderStore, +func (p *classTracker) getClassManager(class db.ProviderClass) (ProviderClassManager, error) { + manager, ok := p.classManagers[class] + if !ok { + return nil, fmt.Errorf("unexpected provider class: %s", class) + } + return manager, nil +} + +func newClassTracker( classManagers ...ProviderClassManager, -) (ProviderManager, error) { +) (*classTracker, error) { classes := make(map[db.ProviderClass]ProviderClassManager) for _, factory := range classManagers { @@ -108,9 +115,29 @@ func NewProviderManager( } } - return &providerManager{ + return &classTracker{ classManagers: classes, - store: store, + }, nil +} + +type providerManager struct { + classTracker + store providers.ProviderStore +} + +// NewProviderManager creates a new instance of ProviderManager +func NewProviderManager( + store providers.ProviderStore, + classManagers ...ProviderClassManager, +) (ProviderManager, error) { + classes, err := newClassTracker(classManagers...) + if err != nil { + return nil, fmt.Errorf("error creating class tracker: %w", err) + } + + return &providerManager{ + classTracker: *classes, + store: store, }, nil } @@ -222,11 +249,3 @@ func (p *providerManager) buildFromDBRecord(ctx context.Context, config *db.Prov } return manager.Build(ctx, config) } - -func (p *providerManager) getClassManager(class db.ProviderClass) (ProviderClassManager, error) { - manager, ok := p.classManagers[class] - if !ok { - return nil, fmt.Errorf("unexpected provider class: %s", class) - } - return manager, nil -} diff --git a/internal/providers/manager/mock/auth_manager.go b/internal/providers/manager/mock/auth_manager.go new file mode 100644 index 0000000000..cc0a9e1b54 --- /dev/null +++ b/internal/providers/manager/mock/auth_manager.go @@ -0,0 +1,226 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./auth_manager.go +// +// Generated by this command: +// +// mockgen -package mock_manager -destination=./mock/auth_manager.go -source=./auth_manager.go +// + +// Package mock_manager is a generated GoMock package. +package mock_manager + +import ( + context "context" + json "encoding/json" + reflect "reflect" + + db "github.com/stacklok/minder/internal/db" + manager "github.com/stacklok/minder/internal/providers/manager" + v1 "github.com/stacklok/minder/pkg/providers/v1" + gomock "go.uber.org/mock/gomock" + oauth2 "golang.org/x/oauth2" +) + +// MockAuthManager is a mock of AuthManager interface. +type MockAuthManager struct { + ctrl *gomock.Controller + recorder *MockAuthManagerMockRecorder +} + +// MockAuthManagerMockRecorder is the mock recorder for MockAuthManager. +type MockAuthManagerMockRecorder struct { + mock *MockAuthManager +} + +// NewMockAuthManager creates a new mock instance. +func NewMockAuthManager(ctrl *gomock.Controller) *MockAuthManager { + mock := &MockAuthManager{ctrl: ctrl} + mock.recorder = &MockAuthManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAuthManager) EXPECT() *MockAuthManagerMockRecorder { + return m.recorder +} + +// NewOAuthConfig mocks base method. +func (m *MockAuthManager) NewOAuthConfig(providerClass db.ProviderClass, cli bool) (*oauth2.Config, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewOAuthConfig", providerClass, cli) + ret0, _ := ret[0].(*oauth2.Config) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewOAuthConfig indicates an expected call of NewOAuthConfig. +func (mr *MockAuthManagerMockRecorder) NewOAuthConfig(providerClass, cli any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewOAuthConfig", reflect.TypeOf((*MockAuthManager)(nil).NewOAuthConfig), providerClass, cli) +} + +// ValidateCredentials mocks base method. +func (m *MockAuthManager) ValidateCredentials(ctx context.Context, providerClass db.ProviderClass, cred v1.Credential, opts ...manager.CredentialVerifyOptFn) error { + m.ctrl.T.Helper() + varargs := []any{ctx, providerClass, cred} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ValidateCredentials", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateCredentials indicates an expected call of ValidateCredentials. +func (mr *MockAuthManagerMockRecorder) ValidateCredentials(ctx, providerClass, cred any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, providerClass, cred}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateCredentials", reflect.TypeOf((*MockAuthManager)(nil).ValidateCredentials), varargs...) +} + +// MockproviderClassAuthManager is a mock of providerClassAuthManager interface. +type MockproviderClassAuthManager struct { + ctrl *gomock.Controller + recorder *MockproviderClassAuthManagerMockRecorder +} + +// MockproviderClassAuthManagerMockRecorder is the mock recorder for MockproviderClassAuthManager. +type MockproviderClassAuthManagerMockRecorder struct { + mock *MockproviderClassAuthManager +} + +// NewMockproviderClassAuthManager creates a new mock instance. +func NewMockproviderClassAuthManager(ctrl *gomock.Controller) *MockproviderClassAuthManager { + mock := &MockproviderClassAuthManager{ctrl: ctrl} + mock.recorder = &MockproviderClassAuthManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockproviderClassAuthManager) EXPECT() *MockproviderClassAuthManagerMockRecorder { + return m.recorder +} + +// MockproviderClassOAuthManager is a mock of providerClassOAuthManager interface. +type MockproviderClassOAuthManager struct { + ctrl *gomock.Controller + recorder *MockproviderClassOAuthManagerMockRecorder +} + +// MockproviderClassOAuthManagerMockRecorder is the mock recorder for MockproviderClassOAuthManager. +type MockproviderClassOAuthManagerMockRecorder struct { + mock *MockproviderClassOAuthManager +} + +// NewMockproviderClassOAuthManager creates a new mock instance. +func NewMockproviderClassOAuthManager(ctrl *gomock.Controller) *MockproviderClassOAuthManager { + mock := &MockproviderClassOAuthManager{ctrl: ctrl} + mock.recorder = &MockproviderClassOAuthManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockproviderClassOAuthManager) EXPECT() *MockproviderClassOAuthManagerMockRecorder { + return m.recorder +} + +// Build mocks base method. +func (m *MockproviderClassOAuthManager) Build(ctx context.Context, config *db.Provider) (v1.Provider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Build", ctx, config) + ret0, _ := ret[0].(v1.Provider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Build indicates an expected call of Build. +func (mr *MockproviderClassOAuthManagerMockRecorder) Build(ctx, config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Build", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).Build), ctx, config) +} + +// Delete mocks base method. +func (m *MockproviderClassOAuthManager) Delete(ctx context.Context, config *db.Provider) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, config) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockproviderClassOAuthManagerMockRecorder) Delete(ctx, config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).Delete), ctx, config) +} + +// GetConfig mocks base method. +func (m *MockproviderClassOAuthManager) GetConfig(ctx context.Context, class db.ProviderClass, userConfig json.RawMessage) (json.RawMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetConfig", ctx, class, userConfig) + ret0, _ := ret[0].(json.RawMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetConfig indicates an expected call of GetConfig. +func (mr *MockproviderClassOAuthManagerMockRecorder) GetConfig(ctx, class, userConfig any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetConfig", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).GetConfig), ctx, class, userConfig) +} + +// GetSupportedClasses mocks base method. +func (m *MockproviderClassOAuthManager) GetSupportedClasses() []db.ProviderClass { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSupportedClasses") + ret0, _ := ret[0].([]db.ProviderClass) + return ret0 +} + +// GetSupportedClasses indicates an expected call of GetSupportedClasses. +func (mr *MockproviderClassOAuthManagerMockRecorder) GetSupportedClasses() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSupportedClasses", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).GetSupportedClasses)) +} + +// NewOAuthConfig mocks base method. +func (m *MockproviderClassOAuthManager) NewOAuthConfig(providerClass db.ProviderClass, cli bool) (*oauth2.Config, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewOAuthConfig", providerClass, cli) + ret0, _ := ret[0].(*oauth2.Config) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// NewOAuthConfig indicates an expected call of NewOAuthConfig. +func (mr *MockproviderClassOAuthManagerMockRecorder) NewOAuthConfig(providerClass, cli any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewOAuthConfig", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).NewOAuthConfig), providerClass, cli) +} + +// ValidateConfig mocks base method. +func (m *MockproviderClassOAuthManager) ValidateConfig(ctx context.Context, class db.ProviderClass, config json.RawMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateConfig", ctx, class, config) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateConfig indicates an expected call of ValidateConfig. +func (mr *MockproviderClassOAuthManagerMockRecorder) ValidateConfig(ctx, class, config any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateConfig", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).ValidateConfig), ctx, class, config) +} + +// ValidateCredentials mocks base method. +func (m *MockproviderClassOAuthManager) ValidateCredentials(ctx context.Context, cred v1.Credential, params *manager.CredentialVerifyParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateCredentials", ctx, cred, params) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateCredentials indicates an expected call of ValidateCredentials. +func (mr *MockproviderClassOAuthManagerMockRecorder) ValidateCredentials(ctx, cred, params any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateCredentials", reflect.TypeOf((*MockproviderClassOAuthManager)(nil).ValidateCredentials), ctx, cred, params) +} diff --git a/internal/providers/session/mock/service.go b/internal/providers/session/mock/service.go new file mode 100644 index 0000000000..a0060a8676 --- /dev/null +++ b/internal/providers/session/mock/service.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./service.go +// +// Generated by this command: +// +// mockgen -package mock_session -destination=./mock/service.go -source=./service.go +// + +// Package mock_session is a generated GoMock package. +package mock_session + +import ( + context "context" + reflect "reflect" + + uuid "github.com/google/uuid" + crypto "github.com/stacklok/minder/internal/crypto" + db "github.com/stacklok/minder/internal/db" + gomock "go.uber.org/mock/gomock" +) + +// MockProviderSessionService is a mock of ProviderSessionService interface. +type MockProviderSessionService struct { + ctrl *gomock.Controller + recorder *MockProviderSessionServiceMockRecorder +} + +// MockProviderSessionServiceMockRecorder is the mock recorder for MockProviderSessionService. +type MockProviderSessionServiceMockRecorder struct { + mock *MockProviderSessionService +} + +// NewMockProviderSessionService creates a new mock instance. +func NewMockProviderSessionService(ctrl *gomock.Controller) *MockProviderSessionService { + mock := &MockProviderSessionService{ctrl: ctrl} + mock.recorder = &MockProviderSessionServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockProviderSessionService) EXPECT() *MockProviderSessionServiceMockRecorder { + return m.recorder +} + +// CreateProviderFromSessionState mocks base method. +func (m *MockProviderSessionService) CreateProviderFromSessionState(ctx context.Context, providerClass db.ProviderClass, encryptedCreds *crypto.EncryptedData, state string) (*db.Provider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateProviderFromSessionState", ctx, providerClass, encryptedCreds, state) + ret0, _ := ret[0].(*db.Provider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateProviderFromSessionState indicates an expected call of CreateProviderFromSessionState. +func (mr *MockProviderSessionServiceMockRecorder) CreateProviderFromSessionState(ctx, providerClass, encryptedCreds, state any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateProviderFromSessionState", reflect.TypeOf((*MockProviderSessionService)(nil).CreateProviderFromSessionState), ctx, providerClass, encryptedCreds, state) +} + +// MockproviderByNameGetter is a mock of providerByNameGetter interface. +type MockproviderByNameGetter struct { + ctrl *gomock.Controller + recorder *MockproviderByNameGetterMockRecorder +} + +// MockproviderByNameGetterMockRecorder is the mock recorder for MockproviderByNameGetter. +type MockproviderByNameGetterMockRecorder struct { + mock *MockproviderByNameGetter +} + +// NewMockproviderByNameGetter creates a new mock instance. +func NewMockproviderByNameGetter(ctrl *gomock.Controller) *MockproviderByNameGetter { + mock := &MockproviderByNameGetter{ctrl: ctrl} + mock.recorder = &MockproviderByNameGetterMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockproviderByNameGetter) EXPECT() *MockproviderByNameGetterMockRecorder { + return m.recorder +} + +// GetByName mocks base method. +func (m *MockproviderByNameGetter) GetByName(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByName", ctx, projectID, name) + ret0, _ := ret[0].(*db.Provider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetByName indicates an expected call of GetByName. +func (mr *MockproviderByNameGetterMockRecorder) GetByName(ctx, projectID, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByName", reflect.TypeOf((*MockproviderByNameGetter)(nil).GetByName), ctx, projectID, name) +} + +// MockdbSessionStore is a mock of dbSessionStore interface. +type MockdbSessionStore struct { + ctrl *gomock.Controller + recorder *MockdbSessionStoreMockRecorder +} + +// MockdbSessionStoreMockRecorder is the mock recorder for MockdbSessionStore. +type MockdbSessionStoreMockRecorder struct { + mock *MockdbSessionStore +} + +// NewMockdbSessionStore creates a new mock instance. +func NewMockdbSessionStore(ctrl *gomock.Controller) *MockdbSessionStore { + mock := &MockdbSessionStore{ctrl: ctrl} + mock.recorder = &MockdbSessionStoreMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockdbSessionStore) EXPECT() *MockdbSessionStoreMockRecorder { + return m.recorder +} + +// GetProjectIDBySessionState mocks base method. +func (m *MockdbSessionStore) GetProjectIDBySessionState(ctx context.Context, sessionState string) (db.GetProjectIDBySessionStateRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProjectIDBySessionState", ctx, sessionState) + ret0, _ := ret[0].(db.GetProjectIDBySessionStateRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProjectIDBySessionState indicates an expected call of GetProjectIDBySessionState. +func (mr *MockdbSessionStoreMockRecorder) GetProjectIDBySessionState(ctx, sessionState any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProjectIDBySessionState", reflect.TypeOf((*MockdbSessionStore)(nil).GetProjectIDBySessionState), ctx, sessionState) +} + +// UpsertAccessToken mocks base method. +func (m *MockdbSessionStore) UpsertAccessToken(ctx context.Context, arg db.UpsertAccessTokenParams) (db.ProviderAccessToken, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertAccessToken", ctx, arg) + ret0, _ := ret[0].(db.ProviderAccessToken) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertAccessToken indicates an expected call of UpsertAccessToken. +func (mr *MockdbSessionStoreMockRecorder) UpsertAccessToken(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertAccessToken", reflect.TypeOf((*MockdbSessionStore)(nil).UpsertAccessToken), ctx, arg) +} diff --git a/internal/providers/session/service.go b/internal/providers/session/service.go new file mode 100644 index 0000000000..aae09e8cdc --- /dev/null +++ b/internal/providers/session/service.go @@ -0,0 +1,118 @@ +// Copyright 2024 Stacklok, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package session contains the business logic for creating providers from session state. +package session + +import ( + "context" + "database/sql" + "errors" + "fmt" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/stacklok/minder/internal/crypto" + "github.com/stacklok/minder/internal/db" + "github.com/stacklok/minder/internal/providers" + "github.com/stacklok/minder/internal/providers/manager" +) + +//go:generate go run go.uber.org/mock/mockgen -package mock_$GOPACKAGE -destination=./mock/$GOFILE -source=./$GOFILE + +// ProviderSessionService is the interface for creating providers from session state +type ProviderSessionService interface { + CreateProviderFromSessionState( + ctx context.Context, providerClass db.ProviderClass, + encryptedCreds *crypto.EncryptedData, state string, + ) (*db.Provider, error) +} + +type providerByNameGetter interface { + GetByName(ctx context.Context, projectID uuid.UUID, name string) (*db.Provider, error) +} + +type dbSessionStore interface { + GetProjectIDBySessionState(ctx context.Context, sessionState string) (db.GetProjectIDBySessionStateRow, error) + UpsertAccessToken(ctx context.Context, arg db.UpsertAccessTokenParams) (db.ProviderAccessToken, error) +} + +type providerSessionService struct { + providerManager manager.ProviderManager + provGetter providerByNameGetter + dbStore dbSessionStore +} + +// NewProviderSessionService creates a new provider session service +func NewProviderSessionService( + providerManager manager.ProviderManager, + provGetter providerByNameGetter, + dbStore dbSessionStore, +) ProviderSessionService { + return &providerSessionService{ + providerManager: providerManager, + provGetter: provGetter, + dbStore: dbStore, + } +} + +func (pss *providerSessionService) CreateProviderFromSessionState( + ctx context.Context, providerClass db.ProviderClass, + encryptedCreds *crypto.EncryptedData, state string, +) (*db.Provider, error) { + stateData, err := pss.dbStore.GetProjectIDBySessionState(ctx, state) + if err != nil { + return nil, fmt.Errorf("error getting state data by session state: %w", err) + } + + serialized, err := encryptedCreds.Serialize() + if err != nil { + return nil, status.Errorf(codes.Internal, "error serializing secret: %s", err) + } + + accessTokenParams := db.UpsertAccessTokenParams{ + ProjectID: stateData.ProjectID, + Provider: stateData.Provider, + OwnerFilter: stateData.OwnerFilter, + EnrollmentNonce: sql.NullString{String: state, Valid: true}, + EncryptedAccessToken: pqtype.NullRawMessage{ + RawMessage: serialized, + Valid: true, + }, + } + + // Check if the provider exists + pErr := providers.ErrProviderNotFoundBy{} + provider, err := pss.provGetter.GetByName(ctx, stateData.ProjectID, stateData.Provider) + if errors.As(err, &pErr) { + createdProvider, err := pss.providerManager.CreateFromConfig( + ctx, providerClass, stateData.ProjectID, stateData.Provider, stateData.ProviderConfig) + if err != nil { + return nil, fmt.Errorf("error creating provider: %w", err) + } + provider = createdProvider + } else if err != nil { + return nil, fmt.Errorf("error getting provider from DB: %w", err) + } + + _, err = pss.dbStore.UpsertAccessToken(ctx, accessTokenParams) + if err != nil { + return nil, fmt.Errorf("error inserting access token: %w", err) + } + + return provider, nil +} diff --git a/internal/service/service.go b/internal/service/service.go index 0213c29d14..7ec41998af 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -46,6 +46,7 @@ import ( "github.com/stacklok/minder/internal/providers/github/service" "github.com/stacklok/minder/internal/providers/manager" "github.com/stacklok/minder/internal/providers/ratecache" + "github.com/stacklok/minder/internal/providers/session" provtelemetry "github.com/stacklok/minder/internal/providers/telemetry" "github.com/stacklok/minder/internal/reconcilers" "github.com/stacklok/minder/internal/repositories/github" @@ -80,6 +81,9 @@ func AllInOneServerService( return fmt.Errorf("failed to create crypto engine: %w", err) } + serverconfig.FallbackOAuthClientConfigValues("github", &cfg.Provider.GitHub.OAuthClientConfig) + serverconfig.FallbackOAuthClientConfigValues("github-app", &cfg.Provider.GitHubApp.OAuthClientConfig) + profileSvc := profiles.NewProfileService(evt) ruleSvc := ruletypes.NewRuleTypeService() marketplace, err := marketplaces.NewMarketplaceFromServiceConfig(cfg.Marketplace, profileSvc, ruleSvc) @@ -121,8 +125,13 @@ func AllInOneServerService( if err != nil { return fmt.Errorf("failed to create provider manager: %w", err) } + providerAuthManager, err := manager.NewAuthManager(githubProviderManager, dockerhubProviderManager) + if err != nil { + return fmt.Errorf("failed to create provider auth manager: %w", err) + } repos := github.NewRepositoryService(whManager, store, evt, providerManager) projectDeleter := projects.NewProjectDeleter(authzClient, providerManager) + sessionsService := session.NewProviderSessionService(providerManager, providerStore, store) s := controlplane.NewServer( store, @@ -138,7 +147,9 @@ func AllInOneServerService( ruleSvc, ghProviders, providerManager, + providerAuthManager, providerStore, + sessionsService, projectDeleter, projectCreator, )