Skip to content

Commit

Permalink
Make provider an optional filter (#2871)
Browse files Browse the repository at this point in the history
Ref #2522
  • Loading branch information
eleftherias committed Apr 1, 2024
1 parent 0bc208b commit 5c381cf
Show file tree
Hide file tree
Showing 17 changed files with 409 additions and 245 deletions.
3 changes: 1 addition & 2 deletions cmd/cli/app/artifact/artifact.go
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/spf13/cobra"

"github.com/stacklok/minder/cmd/cli/app"
ghclient "github.com/stacklok/minder/internal/providers/github/oauth"
)

// ArtifactCmd is the artifact subcommand
Expand All @@ -36,6 +35,6 @@ var ArtifactCmd = &cobra.Command{
func init() {
app.RootCmd.AddCommand(ArtifactCmd)
// Flags for all subcommands
ArtifactCmd.PersistentFlags().StringP("provider", "p", ghclient.Github, "Name of the provider, i.e. github")
ArtifactCmd.PersistentFlags().StringP("provider", "p", "", "Name of the provider, i.e. github")
ArtifactCmd.PersistentFlags().StringP("project", "j", "", "ID of the project")
}
3 changes: 1 addition & 2 deletions cmd/cli/app/repo/repo.go
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/spf13/cobra"

"github.com/stacklok/minder/cmd/cli/app"
ghclient "github.com/stacklok/minder/internal/providers/github/oauth"
)

// RepoCmd is the root command for the repo subcommands
Expand All @@ -36,6 +35,6 @@ var RepoCmd = &cobra.Command{
func init() {
app.RootCmd.AddCommand(RepoCmd)
// Flags for all subcommands
RepoCmd.PersistentFlags().StringP("provider", "p", ghclient.Github, "Name of the provider, i.e. github")
RepoCmd.PersistentFlags().StringP("provider", "p", "", "Name of the provider, i.e. github")
RepoCmd.PersistentFlags().StringP("project", "j", "", "ID of the project")
}
5 changes: 4 additions & 1 deletion cmd/server/app/webhook_update.go
Expand Up @@ -141,7 +141,10 @@ func updateGithubWebhooks(
) error {
repos, err := store.ListRegisteredRepositoriesByProjectIDAndProvider(ctx,
db.ListRegisteredRepositoriesByProjectIDAndProviderParams{
Provider: provider.Name,
Provider: sql.NullString{
String: provider.Name,
Valid: true,
},
ProjectID: provider.ProjectID,
})
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions database/query/repositories.sql
Expand Up @@ -20,7 +20,9 @@ INSERT INTO repositories (
SELECT * FROM repositories WHERE repo_id = $1;

-- name: GetRepositoryByRepoName :one
SELECT * FROM repositories WHERE provider = $1 AND repo_owner = $2 AND repo_name = $3 AND project_id = $4;
SELECT * FROM repositories
WHERE repo_owner = $1 AND repo_name = $2 AND project_id = $3
AND lower(provider) = lower(sqlc.narg('provider')::text) OR sqlc.narg('provider')::text IS NULL;

-- avoid using this, where possible use GetRepositoryByIDAndProject instead
-- name: GetRepositoryByID :one
Expand All @@ -31,14 +33,16 @@ SELECT * FROM repositories WHERE id = $1 AND project_id = $2;

-- name: ListRepositoriesByProjectID :many
SELECT * FROM repositories
WHERE provider = $1 AND project_id = $2
WHERE project_id = $1
AND (repo_id >= sqlc.narg('repo_id') OR sqlc.narg('repo_id') IS NULL)
AND lower(provider) = lower(COALESCE(sqlc.narg('provider'), provider)::text)
ORDER BY project_id, provider, repo_id
LIMIT sqlc.narg('limit')::bigint;

-- name: ListRegisteredRepositoriesByProjectIDAndProvider :many
SELECT * FROM repositories
WHERE provider = $1 AND project_id = $2 AND webhook_id IS NOT NULL
WHERE project_id = $1 AND webhook_id IS NOT NULL
AND lower(provider) = lower(sqlc.narg('provider')::text) OR sqlc.narg('provider')::text IS NULL
ORDER BY repo_name;

-- name: DeleteRepository :exec
Expand Down
15 changes: 9 additions & 6 deletions internal/controlplane/common.go
Expand Up @@ -27,10 +27,13 @@ import (
"google.golang.org/grpc/status"

"github.com/stacklok/minder/internal/db"
"github.com/stacklok/minder/internal/providers/github/oauth"
"github.com/stacklok/minder/internal/util"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

const defaultProvider = oauth.Github

// HasProtoContext is an interface that can be implemented by a request
type HasProtoContext interface {
GetContext() *pb.Context
Expand Down Expand Up @@ -60,10 +63,10 @@ func filteredResultNotFoundError(name sql.NullString, trait db.NullProviderType)
func getProviderFromRequestOrDefault(
ctx context.Context,
store db.Store,
in HasProtoContext,
providerName string,
projectId uuid.UUID,
) (db.Provider, error) {
name := getNameFilterParam(in.GetContext())
name := getNameFilterParam(providerName)
providers, err := findProvider(ctx, name, db.NullProviderType{}, projectId, store)
if err != nil {
return db.Provider{}, err
Expand All @@ -79,7 +82,7 @@ func getProvidersByTrait(
projectId uuid.UUID,
trait db.ProviderType,
) ([]db.Provider, error) {
name := getNameFilterParam(in.GetContext())
name := getNameFilterParam(in.GetContext().GetProvider())
t := db.NullProviderType{ProviderType: trait, Valid: true}
providers, err := findProvider(ctx, name, t, projectId, store)
if err != nil {
Expand Down Expand Up @@ -120,10 +123,10 @@ func findProvider(
}

// getNameFilterParam allows us to build a name filter for our provider queries
func getNameFilterParam(in *pb.Context) sql.NullString {
func getNameFilterParam(providerName string) sql.NullString {
return sql.NullString{
String: in.GetProvider(),
Valid: in.GetProvider() != "",
String: providerName,
Valid: providerName != "",
}
}

Expand Down
17 changes: 8 additions & 9 deletions internal/controlplane/handlers_artifacts.go
Expand Up @@ -39,24 +39,20 @@ import (
func (s *Server) ListArtifacts(ctx context.Context, in *pb.ListArtifactsRequest) (*pb.ListArtifactsResponse, error) {
entityCtx := engine.EntityFromContext(ctx)
projectID := entityCtx.Project.ID

provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID)
if err != nil {
return nil, providerError(err)
}
providerName := entityCtx.Provider.Name

artifactFilter, err := parseArtifactListFrom(s.store, in.From)
if err != nil {
return nil, fmt.Errorf("failed to parse artifact list from: %w", err)
}

results, err := artifactFilter.listArtifacts(ctx, provider.Name, projectID)
results, err := artifactFilter.listArtifacts(ctx, providerName, projectID)
if err != nil {
return nil, fmt.Errorf("failed to list artifacts: %w", err)
}

// Telemetry logging
logger.BusinessRecord(ctx).Provider = provider.Name
logger.BusinessRecord(ctx).Provider = providerName
logger.BusinessRecord(ctx).Project = projectID

return &pb.ListArtifactsResponse{Results: results}, nil
Expand All @@ -73,9 +69,10 @@ func (s *Server) GetArtifactByName(ctx context.Context, in *pb.GetArtifactByName

entityCtx := engine.EntityFromContext(ctx)
projectID := entityCtx.Project.ID
providerFilter := getNameFilterParam(entityCtx.Provider.Name)

repo, err := s.store.GetRepositoryByRepoName(ctx, db.GetRepositoryByRepoNameParams{
Provider: in.GetContext().GetProvider(),
Provider: providerFilter,
RepoOwner: nameParts[0],
RepoName: nameParts[1],
ProjectID: projectID,
Expand Down Expand Up @@ -241,8 +238,10 @@ func (filter *artifactListFilter) listArtifacts(ctx context.Context, provider st
func artifactListRepoFilter(
ctx context.Context, store db.Store, provider string, projectID uuid.UUID, repoSlubList []string,
) ([]*db.Repository, error) {
providerFilter := getNameFilterParam(provider)

repositories, err := store.ListRegisteredRepositoriesByProjectIDAndProvider(ctx,
db.ListRegisteredRepositoriesByProjectIDAndProviderParams{Provider: provider, ProjectID: projectID})
db.ListRegisteredRepositoriesByProjectIDAndProviderParams{Provider: providerFilter, ProjectID: projectID})
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, status.Errorf(codes.NotFound, "repositories not found")
Expand Down
25 changes: 20 additions & 5 deletions internal/controlplane/handlers_oauth.go
Expand Up @@ -41,13 +41,10 @@ import (
"github.com/stacklok/minder/internal/engine"
"github.com/stacklok/minder/internal/logger"
"github.com/stacklok/minder/internal/providers"
"github.com/stacklok/minder/internal/providers/github/oauth"
"github.com/stacklok/minder/internal/util"
pb "github.com/stacklok/minder/pkg/api/protobuf/go/minder/v1"
)

const defaultProvider = oauth.Github

// GetAuthorizationURL returns the URL to redirect the user to for authorization
// and the state to be used for the callback. It accepts a provider string
// and a boolean indicating whether the client is a CLI or web client
Expand Down Expand Up @@ -384,8 +381,17 @@ func (s *Server) StoreProviderToken(ctx context.Context,
in *pb.StoreProviderTokenRequest) (*pb.StoreProviderTokenResponse, error) {
entityCtx := engine.EntityFromContext(ctx)
projectID := entityCtx.Project.ID
providerName := entityCtx.Provider.Name

if providerName == "" {
return nil, status.Errorf(codes.InvalidArgument, "provider name is required")
}

provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID)
provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
// We don't check parent projects here because subprojects should not update the credentials of their parent
Projects: []uuid.UUID{projectID},
Name: providerName,
})
if err != nil {
return nil, providerError(err)
}
Expand Down Expand Up @@ -450,8 +456,17 @@ func (s *Server) VerifyProviderTokenFrom(ctx context.Context,
in *pb.VerifyProviderTokenFromRequest) (*pb.VerifyProviderTokenFromResponse, error) {
entityCtx := engine.EntityFromContext(ctx)
projectID := entityCtx.Project.ID
providerName := entityCtx.Provider.Name

provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, projectID)
if providerName == "" {
return nil, status.Errorf(codes.InvalidArgument, "provider name is required")
}

provider, err := s.store.GetProviderByName(ctx, db.GetProviderByNameParams{
// We don't check parent projects here because subprojects should not check the credentials of their parent
Projects: []uuid.UUID{projectID},
Name: providerName,
})
if err != nil {
return nil, providerError(err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/controlplane/handlers_profile.go
Expand Up @@ -58,7 +58,7 @@ func (s *Server) CreateProfile(ctx context.Context,
}

// TODO: This will be removed once we decouple providers from profiles
provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, entityCtx.Project.ID)
provider, err := getProviderFromRequestOrDefault(ctx, s.store, entityCtx.Provider.Name, entityCtx.Project.ID)
if err != nil {
return nil, providerError(err)
}
Expand Down Expand Up @@ -634,7 +634,7 @@ func (s *Server) UpdateProfile(ctx context.Context,
}

// TODO: This will be removed once we decouple providers from profiles
provider, err := getProviderFromRequestOrDefault(ctx, s.store, in, entityCtx.Project.ID)
provider, err := getProviderFromRequestOrDefault(ctx, s.store, entityCtx.Provider.Name, entityCtx.Project.ID)
if err != nil {
return nil, providerError(err)
}
Expand Down

0 comments on commit 5c381cf

Please sign in to comment.