From 323c093818842d3e05d92d46a3695b651cbf4629 Mon Sep 17 00:00:00 2001 From: ahrav Date: Wed, 3 May 2023 08:35:53 -0700 Subject: [PATCH] Normalize GitHub repos during enumeration (#1269) * Normalize repos during enumeration. * fix test. * Add benchmark. * Add benchmark. * Add more realistic benchmark values. * add gist mocks. * Remove old normalize fxn. * abstract away the repo cache. * update test. * increase repo count. * increase page limnit to 100. * move callee fxns below caller for Chunks. * Add context to normalize. * remove extra logic in normalize repo. * Delete new.txt * Delete old.txt * Handle errors in a thread safe manner. * fix test.' * fix test. * handle repos that are included by users. * Abstract include ignore logic within repoCache. * Add better comment around repoCache. * Rename params. * remove commented out code. * use repos instead of items. * remove commented out code. * Use ++ instead of atomic increment. * update to use logger var. * use cache pkg. * Address comments. * fix test. * make less sucky test. * Update test. --- pkg/sources/github/github.go | 408 +++++++++--------- pkg/sources/github/github_integration_test.go | 15 +- pkg/sources/github/github_test.go | 334 ++++++++++---- 3 files changed, 468 insertions(+), 289 deletions(-) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 882f31f03770..770f90ef9842 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -25,6 +25,8 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/cache" + "github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/giturl" @@ -43,28 +45,29 @@ const ( ) type Source struct { - name string - githubUser string - githubToken string - sourceID int64 - jobID int64 - verify bool - repos, - orgs, - members, - includeRepos, - ignoreRepos []string - git *git.Git - scanOptions *git.ScanOptions - httpClient *http.Client - log logr.Logger - conn *sourcespb.GitHub - jobPool *errgroup.Group - resumeInfoMutex sync.Mutex - resumeInfoSlice []string - apiClient *github.Client - mu sync.Mutex - publicMap map[string]source_metadatapb.Visibility + name string + githubUser string + githubToken string + sourceID int64 + jobID int64 + verify bool + repos []string + orgs []string + members []string + orgsCache cache.Cache + filteredRepoCache *filteredRepoCache + memberCache map[string]struct{} + git *git.Git + scanOptions *git.ScanOptions + httpClient *http.Client + log logr.Logger + conn *sourcespb.GitHub + jobPool *errgroup.Group + resumeInfoMutex sync.Mutex + resumeInfoSlice []string + apiClient *github.Client + mu sync.Mutex + publicMap map[string]source_metadatapb.Visibility sources.Progress } @@ -128,6 +131,67 @@ func (s *Source) UserAndToken(ctx context.Context, installationClient *github.Cl return "", "", errors.New("unhandled credential type for token fetch") } +// filteredRepoCache is a wrapper around cache.Cache that filters out repos +// based on include and exclude globs. +type filteredRepoCache struct { + cache.Cache + include, exclude []glob.Glob +} + +func (s *Source) newFilteredRepoCache(c cache.Cache, include, exclude []string) *filteredRepoCache { + includeGlobs := make([]glob.Glob, 0, len(include)) + excludeGlobs := make([]glob.Glob, 0, len(exclude)) + for _, ig := range include { + g, err := glob.Compile(ig) + if err != nil { + s.log.V(1).Info("invalid include glob", "glob", g, "err", err) + } + includeGlobs = append(includeGlobs, g) + } + for _, eg := range exclude { + g, err := glob.Compile(eg) + if err != nil { + s.log.V(1).Info("invalid exclude glob", "glob", g, "err", err) + } + excludeGlobs = append(excludeGlobs, g) + } + return &filteredRepoCache{Cache: c, include: includeGlobs, exclude: excludeGlobs} +} + +// Set overrides the cache.Cache Set method to filter out repos based on +// include and exclude globs. +func (c *filteredRepoCache) Set(key, val string) { + if c.ignoreRepo(key) { + return + } + if !c.includeRepo(key) { + return + } + c.Cache.Set(key, val) +} + +func (c *filteredRepoCache) ignoreRepo(s string) bool { + for _, g := range c.exclude { + if g.Match(s) { + return true + } + } + return false +} + +func (c *filteredRepoCache) includeRepo(s string) bool { + if len(c.include) == 0 { + return true + } + + for _, g := range c.include { + if g.Match(s) { + return true + } + } + return false +} + // Init returns an initialized GitHub source. func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64, verify bool, connection *anypb.Any, concurrency int) error { s.log = aCtx.Logger() @@ -149,10 +213,24 @@ func (s *Source) Init(aCtx context.Context, name string, jobID, sourceID int64, } s.conn = &conn + s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), s.conn.IncludeRepos, s.conn.IgnoreRepos) + s.memberCache = make(map[string]struct{}) + s.repos = s.conn.Repositories + for _, repo := range s.repos { + r, err := s.normalizeRepo(repo) + if err != nil { + aCtx.Logger().Error(err, "invalid repository", "repo", repo) + continue + } + s.filteredRepoCache.Set(r, r) + } + s.orgs = s.conn.Organizations - s.includeRepos = s.conn.IncludeRepos - s.ignoreRepos = s.conn.IgnoreRepos + s.orgsCache = memory.New() + for _, org := range s.orgs { + s.orgsCache.Set(org, org) + } // Head or base should only be used with incoming webhooks if (len(s.conn.Head) > 0 || len(s.conn.Base) > 0) && len(s.repos) != 1 { @@ -259,6 +337,48 @@ func (s *Source) visibilityOf(repoURL string) (visibility source_metadatapb.Visi return } +// Chunks emits chunks of bytes over a channel. +func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error { + apiEndpoint := s.conn.Endpoint + if len(apiEndpoint) == 0 || endsWithGithub.MatchString(apiEndpoint) { + apiEndpoint = "https://api.github.com" + } + + installationClient, err := s.enumerate(ctx, apiEndpoint) + if err != nil { + return err + } + + return s.scan(ctx, installationClient, chunksChan) +} + +func (s *Source) enumerate(ctx context.Context, apiEndpoint string) (*github.Client, error) { + var installationClient *github.Client + var err error + + switch cred := s.conn.GetCredential().(type) { + case *sourcespb.GitHub_Unauthenticated: + s.enumerateUnauthenticated(ctx) + case *sourcespb.GitHub_Token: + if err = s.enumerateWithToken(ctx, apiEndpoint, cred.Token); err != nil { + return nil, err + } + case *sourcespb.GitHub_GithubApp: + if installationClient, err = s.enumerateWithApp(ctx, apiEndpoint, cred.GithubApp); err != nil { + return nil, err + } + default: + // TODO: move this error to Init + return nil, errors.Errorf("Invalid configuration given for source. Name: %s, Type: %s", s.name, s.Type()) + } + + s.repos = s.filteredRepoCache.Values() + + // We must sort the repos so we can resume later if necessary. + sort.Strings(s.repos) + return installationClient, nil +} + func (s *Source) enumerateUnauthenticated(ctx context.Context) { s.apiClient = github.NewClient(s.httpClient) if len(s.orgs) > unauthGithubOrgRateLimt { @@ -266,17 +386,17 @@ func (s *Source) enumerateUnauthenticated(ctx context.Context) { } for _, org := range s.orgs { - logger := s.log.WithValues("org", org) - if err := s.addRepos(ctx, org, s.getReposByOrg); err != nil { - logger.Error(err, "error fetching repos for org or user") + if err := s.getReposByOrg(ctx, org); err != nil { + s.log.Error(err, "error fetching repos for org or user") } + // We probably don't need to do this, since getting repos by org makes more sense? - if err := s.addRepos(ctx, org, s.getReposByUser); err != nil { - logger.Error(err, "error fetching repos for org or user") + if err := s.getReposByUser(ctx, org); err != nil { + s.log.Error(err, "error fetching repos for org or user") } if s.conn.ScanUsers { - logger.Info("Enumerating unauthenticated does not support scanning organization members") + s.log.Info("Enumerating unauthenticated does not support scanning organization members") } } } @@ -337,7 +457,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri specificScope = true for _, org := range s.orgs { logger := s.log.WithValues("org", org) - if err := s.addRepos(ctx, org, s.getReposByOrg); err != nil { + if err := s.getReposByOrg(ctx, org); err != nil { logger.Error(err, "error fetching repos for org") } @@ -353,7 +473,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri // If no scope was provided, enumerate them. if !specificScope { - if err := s.addRepos(ctx, ghUser.GetLogin(), s.getReposByUser); err != nil { + if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil { s.log.Error(err, "error fetching repos by user") } @@ -365,14 +485,14 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri s.addOrgsByUser(ctx, ghUser.GetLogin()) } - for _, org := range s.orgs { + for _, org := range s.orgsCache.Keys() { logger := s.log.WithValues("org", org) - if err := s.addRepos(ctx, org, s.getReposByOrg); err != nil { + if err := s.getReposByOrg(ctx, org); err != nil { logger.Error(err, "error fetching repos by org") } - if err := s.addRepos(ctx, ghUser.GetLogin(), s.getReposByUser); err != nil { - logger.Error(err, "error fetching repos for user", "user", ghUser.GetLogin()) + if err := s.getReposByUser(ctx, ghUser.GetLogin()); err != nil { + logger.Error(err, "error fetching repos by user") } if s.conn.ScanUsers { @@ -385,7 +505,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri // If we enabled ScanUsers above, we've already added the gists for the current user and users from the orgs. // So if we don't have ScanUsers enabled, add the user gists as normal. - if err := s.addGistsByUser(ctx, ghUser.GetLogin()); err != nil { + if err := s.addUserGistsToCache(ctx, ghUser.GetLogin()); err != nil { s.log.Error(err, "error fetching gists", "user", ghUser.GetLogin()) } @@ -458,10 +578,10 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app * s.log.Info("Scanning repos", "org_members", len(s.members)) for _, member := range s.members { logger := s.log.WithValues("member", member) - if err = s.addGistsByUser(ctx, member); err != nil { + if err := s.getReposByUser(ctx, member); err != nil { logger.Error(err, "error fetching gists by user") } - if err := s.addRepos(ctx, member, s.getReposByUser); err != nil { + if err := s.getReposByUser(ctx, member); err != nil { logger.Error(err, "error fetching repos by user") } } @@ -471,40 +591,6 @@ func (s *Source) enumerateWithApp(ctx context.Context, apiEndpoint string, app * return installationClient, nil } -// Chunks emits chunks of bytes over a channel. -func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) error { - apiEndpoint := s.conn.Endpoint - if len(apiEndpoint) == 0 || endsWithGithub.MatchString(apiEndpoint) { - apiEndpoint = "https://api.github.com" - } - - var installationClient *github.Client - var err error - - switch cred := s.conn.GetCredential().(type) { - case *sourcespb.GitHub_Unauthenticated: - s.enumerateUnauthenticated(ctx) - case *sourcespb.GitHub_Token: - if err = s.enumerateWithToken(ctx, apiEndpoint, cred.Token); err != nil { - return err - } - case *sourcespb.GitHub_GithubApp: - if installationClient, err = s.enumerateWithApp(ctx, apiEndpoint, cred.GithubApp); err != nil { - return err - } - default: - // TODO: move this error to Init - return errors.Errorf("Invalid configuration given for source. Name: %s, Type: %s", s.name, s.Type()) - } - - s.normalizeRepos(ctx) - - // We must sort the repos so we can resume later if necessary. - sort.Strings(s.repos) - - return s.scan(ctx, installationClient, chunksChan) -} - func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error { var scanned uint64 @@ -514,12 +600,12 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch reposToScan, progressIndexOffset := sources.FilterReposToResume(s.repos, s.GetProgress().EncodedResumeInfo) s.repos = reposToScan + scanErrs := sources.NewScanErrors() // Setup scan options if it wasn't provided. if s.scanOptions == nil { s.scanOptions = &git.ScanOptions{} } - scanErrs := sources.NewScanErrors() for i, repoURL := range s.repos { i, repoURL := i, repoURL s.jobPool.Go(func() error { @@ -548,6 +634,10 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch var err error path, repo, err = s.cloneRepo(ctx, repoURL, installationClient) + if err != nil { + scanErrs.Add(err) + } + defer os.RemoveAll(path) if err != nil { scanErrs.Add(fmt.Errorf("error cloning repo %s: %w", repoURL, err)) @@ -653,10 +743,9 @@ func (s *Source) handleRateLimit(errIn error, res *github.Response) bool { return true } -func (s *Source) getReposByOrg(ctx context.Context, org string) ([]string, error) { +func (s *Source) getReposByOrg(ctx context.Context, org string) error { logger := s.log.WithValues("org", org) - var repos []string opts := &github.RepositoryListByOrgOptions{ ListOptions: github.ListOptions{ PerPage: defaultPagination, @@ -673,7 +762,7 @@ func (s *Source) getReposByOrg(ctx context.Context, org string) ([]string, error continue } if err != nil { - return nil, fmt.Errorf("could not list repos for org %s: %w", org, err) + return fmt.Errorf("could not list repos for org %s: %w", org, err) } if len(someRepos) == 0 || res == nil { break @@ -681,48 +770,29 @@ func (s *Source) getReposByOrg(ctx context.Context, org string) ([]string, error logger.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage) for _, r := range someRepos { - if s.ignoreRepo(r.GetFullName()) { - continue - } - if !s.includeRepo(r.GetFullName()) { - continue - } - - numRepos++ if r.GetFork() { - numForks++ if !s.conn.IncludeForks { continue } + numForks++ } - repos = append(repos, r.GetCloneURL()) + s.filteredRepoCache.Set(r.GetFullName(), r.GetCloneURL()) + numRepos++ } if res.NextPage == 0 { break } opts.Page = res.NextPage } - logger.V(2).Info("found repos", "total", numRepos, "forks", numForks) - return repos, nil -} -func (s *Source) addRepos(ctx context.Context, entity string, getRepos func(context.Context, string) ([]string, error)) error { - repos, err := getRepos(ctx, entity) - if err != nil { - return err - } - // Add the repos to the set of repos. - for _, repo := range repos { - common.AddStringSliceItem(repo, &s.repos) - } + logger.V(2).Info("found repos", "total", numRepos, "forks", numForks) return nil } -func (s *Source) getReposByUser(ctx context.Context, user string) ([]string, error) { - var repos []string +func (s *Source) getReposByUser(ctx context.Context, user string) error { opts := &github.RepositoryListOptions{ ListOptions: github.ListOptions{ - PerPage: 50, + PerPage: defaultPagination, }, } @@ -736,7 +806,7 @@ func (s *Source) getReposByUser(ctx context.Context, user string) ([]string, err continue } if err != nil { - return nil, fmt.Errorf("could not list repos for user %s: %w", user, err) + return fmt.Errorf("could not list repos for user %s: %w", user, err) } if res == nil { break @@ -744,64 +814,24 @@ func (s *Source) getReposByUser(ctx context.Context, user string) ([]string, err logger.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage) for _, r := range someRepos { - if s.ignoreRepo(r.GetFullName()) { - continue - } - if !s.includeRepo(r.GetFullName()) { - continue - } - if r.GetFork() && !s.conn.IncludeForks { continue } - repos = append(repos, r.GetCloneURL()) + + s.filteredRepoCache.Set(r.GetFullName(), r.GetCloneURL()) } if res.NextPage == 0 { break } opts.Page = res.NextPage } - return repos, nil -} - -func (s *Source) includeRepo(r string) bool { - if len(s.includeRepos) == 0 { - return true - } - logger := s.log.WithValues("repo", r) - for _, include := range s.includeRepos { - g, err := glob.Compile(include) - if err != nil { - logger.V(2).Info("invalid glob", "glob", include, "error", err) - continue - } - if g.Match(r) { - logger.V(2).Info("including repo") - return true - } - } - return false -} - -func (s *Source) ignoreRepo(r string) bool { - logger := s.log.WithValues("repo", r) - for _, ignore := range s.ignoreRepos { - g, err := glob.Compile(ignore) - if err != nil { - logger.V(2).Info("invalid glob", "glob", ignore, "error", err) - continue - } - if g.Match(r) { - logger.V(2).Info("ignoring repo") - return true - } - } - return false + return nil } -func (s *Source) getGistsByUser(ctx context.Context, user string) ([]string, error) { - var gistURLs []string +// addUserGistsToCache collects all the gist urls for a given user, +// and adds them to the filteredRepoCache. +func (s *Source) addUserGistsToCache(ctx context.Context, user string) error { gistOpts := &github.GistListOptions{} logger := s.log.WithValues("user", user) for { @@ -813,10 +843,10 @@ func (s *Source) getGistsByUser(ctx context.Context, user string) ([]string, err continue } if err != nil { - return nil, fmt.Errorf("could not list gists for user %s: %w", user, err) + return fmt.Errorf("could not list gists for user %s: %w", user, err) } for _, gist := range gists { - gistURLs = append(gistURLs, gist.GetGitPullURL()) + s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL()) } if res == nil || res.NextPage == 0 { break @@ -824,18 +854,6 @@ func (s *Source) getGistsByUser(ctx context.Context, user string) ([]string, err logger.V(2).Info("Listed gists", "page", gistOpts.Page, "last_page", res.LastPage) gistOpts.Page = res.NextPage } - return gistURLs, nil -} - -func (s *Source) addGistsByUser(ctx context.Context, user string) error { - gists, err := s.getGistsByUser(ctx, user) - if err != nil { - return err - } - // add the gists to the set of repos - for _, gist := range gists { - common.AddStringSliceItem(gist, &s.repos) - } return nil } @@ -863,10 +881,11 @@ func (s *Source) addMembersByApp(ctx context.Context, installationClient *github } func (s *Source) addReposByApp(ctx context.Context) error { - // Authenticated enumeration of repos + // Authenticated enumeration of repos. opts := &github.ListOptions{ PerPage: defaultPagination, } + for { someRepos, res, err := s.apiClient.Apps.ListRepos(ctx, opts) if err == nil { @@ -886,9 +905,10 @@ func (s *Source) addReposByApp(ctx context.Context) error { if r.GetFork() && !s.conn.IncludeForks { continue } - common.AddStringSliceItem(r.GetCloneURL(), &s.repos) - s.log.V(2).Info("Enumerated repo", "repo", r.GetCloneURL()) + s.filteredRepoCache.Set(r.GetFullName(), r.GetCloneURL()) + s.log.V(2).Info("Enumerated repo", "repo", r.GetFullName()) } + if res.NextPage == 0 { break } @@ -936,8 +956,8 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { } else { continue } - s.log.V(2).Info("adding organization for repository enumeration", "id", org.ID, "org", name) - common.AddStringSliceItem(name, &s.orgs) + s.orgsCache.Set(name, name) + s.log.V(2).Info("adding organization for repository enumeration", "id", org.ID, "name", name) } } } @@ -964,13 +984,10 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { } logger.V(2).Info("Listed orgs", "page", orgOpts.Page, "last_page", resp.LastPage) for _, org := range orgs { - var name string - if org.Login != nil { - name = *org.Login - } else { + if org.Login == nil { continue } - common.AddStringSliceItem(name, &s.orgs) + s.orgsCache.Set(*org.Login, *org.Login) } if resp.NextPage == 0 { break @@ -1008,7 +1025,9 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error { if usr == nil || *usr == "" { continue } - common.AddStringSliceItem(*usr, &s.members) + if _, ok := s.memberCache[*usr]; !ok { + s.memberCache[*usr] = struct{}{} + } } if res.NextPage == 0 { break @@ -1021,48 +1040,23 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error { func (s *Source) addReposForMembers(ctx context.Context) { s.log.Info("Fetching repos from members", "members", len(s.members)) - for _, member := range s.members { - if err := s.addGistsByUser(ctx, member); err != nil { + for member := range s.memberCache { + if err := s.addUserGistsToCache(ctx, member); err != nil { s.log.Info("Unable to fetch gists by user", "user", member, "error", err) } - if err := s.addRepos(ctx, member, s.getReposByUser); err != nil { + if err := s.getReposByUser(ctx, member); err != nil { s.log.Info("Unable to fetch repos by user", "user", member, "error", err) } } } -func (s *Source) normalizeRepos(ctx context.Context) { - // TODO: Add check/fix for repos that are missing scheme - normalizedRepos := map[string]struct{}{} - for _, repo := range s.repos { - // If there's a '/', assume it's a URL and try to normalize it. - if strings.ContainsRune(repo, '/') { - repoNormalized, err := giturl.NormalizeGithubRepo(repo) - if err != nil { - s.log.Info("Repo not in expected format", "repo", repo, "error", err) - continue - } - normalizedRepos[repoNormalized] = struct{}{} - continue - } - // Otherwise, assume it's a user and enumerate repositories and gists. - if repos, err := s.getReposByUser(ctx, repo); err == nil { - for _, repo := range repos { - normalizedRepos[repo] = struct{}{} - } - } - if gists, err := s.getGistsByUser(ctx, repo); err == nil { - for _, gist := range gists { - normalizedRepos[gist] = struct{}{} - } - } +func (s *Source) normalizeRepo(repo string) (string, error) { + // If there's a '/', assume it's a URL and try to normalize it. + if strings.ContainsRune(repo, '/') { + return giturl.NormalizeGithubRepo(repo) } - // Replace s.repos. - s.repos = s.repos[:0] - for key := range normalizedRepos { - s.repos = append(s.repos, key) - } + return "", fmt.Errorf("no repositories found for %s", repo) } // setProgressCompleteWithRepo calls the s.SetProgressComplete after safely setting up the encoded resume info string. diff --git a/pkg/sources/github/github_integration_test.go b/pkg/sources/github/github_integration_test.go index c08538e23d42..341de98fe8df 100644 --- a/pkg/sources/github/github_integration_test.go +++ b/pkg/sources/github/github_integration_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/cache/memory" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" @@ -52,10 +53,12 @@ func TestSource_Token(t *testing.T) { } s := Source{ - conn: conn, - httpClient: common.SaneHttpClient(), - log: logr.Discard(), + conn: conn, + httpClient: common.SaneHttpClient(), + log: logr.Discard(), + memberCache: map[string]struct{}{}, } + s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), nil, nil) installationClient, err := s.enumerateWithApp(ctx, "https://api.github.com", conn.GetGithubApp()) assert.NoError(t, err) @@ -526,7 +529,7 @@ func TestSource_paginateGists(t *testing.T) { } chunksCh := make(chan *sources.Chunk, 5) go func() { - s.addGistsByUser(ctx, tt.user) + s.addUserGistsToCache(ctx, tt.user) chunksCh <- &sources.Chunk{} }() var wantedRepo string @@ -542,11 +545,11 @@ func TestSource_paginateGists(t *testing.T) { func gistsCheckFunc(expected string, minRepos int, s *Source) sources.ChunkFunc { return func(chunk *sources.Chunk) error { - if minRepos != 0 && minRepos > len(s.repos) { + if minRepos != 0 && minRepos > s.filteredRepoCache.Count() { return fmt.Errorf("didn't find enough repos. expected: %d, got :%d", minRepos, len(s.repos)) } if expected != "" { - for _, repo := range s.repos { + for _, repo := range s.filteredRepoCache.Values() { if repo == expected { return nil } diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index afbeb2c0d351..5b1ff486cf4c 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -6,14 +6,15 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "fmt" "net/http" "reflect" - "sort" "strconv" "testing" "time" "github.com/go-logr/logr" + "github.com/google/go-cmp/cmp" "github.com/google/go-github/v42/github" "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" @@ -65,17 +66,23 @@ func TestAddReposByOrg(t *testing.T) { Get("/orgs/super-secret-org/repos"). Reply(200). JSON([]map[string]string{ - {"clone_url": "super-secret-repo", "name": "super-secret-repo"}, - {"clone_url": "super-secret-repo2", "full_name": "secret/super-secret-repo2"}, + {"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}, + {"clone_url": "https://github.com/super-secret-repo2.git", "full_name": "secret/super-secret-repo2"}, }) - s := initTestSource(nil) - s.ignoreRepos = []string{"secret/super-*-repo2"} + s := initTestSource(&sourcespb.GitHub{ + Credential: &sourcespb.GitHub_Token{ + Token: "super secret token", + }, + IncludeRepos: nil, + IgnoreRepos: []string{"secret/super-*-repo2"}, + }) // gock works here because github.NewClient is using the default HTTP Transport - err := s.addRepos(context.TODO(), "super-secret-org", s.getReposByOrg) + err := s.getReposByOrg(context.TODO(), "super-secret-org") assert.Nil(t, err) - assert.Equal(t, 1, len(s.repos)) - assert.Equal(t, []string{"super-secret-repo"}, s.repos) + assert.Equal(t, 1, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("super-secret-repo") + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -86,21 +93,26 @@ func TestAddReposByOrg_IncludeRepos(t *testing.T) { Get("/orgs/super-secret-org/repos"). Reply(200). JSON([]map[string]string{ - {"clone_url": "super-secret-repo", "full_name": "secret/super-secret-repo"}, - {"clone_url": "super-secret-repo2", "full_name": "secret/super-secret-repo2"}, - {"clone_url": "super-secret-repo2", "full_name": "secret/not-super-secret-repo"}, + {"clone_url": "https://github.com/super-secret-repo.git", "full_name": "secret/super-secret-repo"}, + {"clone_url": "https://github.com/super-secret-repo2.git", "full_name": "secret/super-secret-repo2"}, + {"clone_url": "https://github.com/super-secret-repo2.git", "full_name": "secret/not-super-secret-repo"}, }) - src := &sourcespb.GitHub{ + s := initTestSource(&sourcespb.GitHub{ + Credential: &sourcespb.GitHub_Token{ + Token: "super secret token", + }, + IncludeRepos: []string{"secret/super*"}, Organizations: []string{"super-secret-org"}, - } - s := initTestSource(src) - s.includeRepos = []string{"secret/super*"} + }) // gock works here because github.NewClient is using the default HTTP Transport - err := s.addRepos(context.TODO(), "super-secret-org", s.getReposByOrg) + err := s.getReposByOrg(context.TODO(), "super-secret-org") assert.Nil(t, err) - assert.Equal(t, 2, len(s.repos)) - assert.Equal(t, []string{"super-secret-repo", "super-secret-repo2"}, s.repos) + assert.Equal(t, 2, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("secret/super-secret-repo") + assert.True(t, ok) + ok = s.filteredRepoCache.Exists("secret/super-secret-repo2") + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -111,16 +123,21 @@ func TestAddReposByUser(t *testing.T) { Get("/users/super-secret-user/repos"). Reply(200). JSON([]map[string]string{ - {"clone_url": "super-secret-repo", "name": "super-secret-repo"}, - {"clone_url": "super-secret-repo2", "full_name": "secret/super-secret-repo2"}, + {"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}, + {"clone_url": "https://github.com/super-secret-repo2.git", "full_name": "secret/super-secret-repo2"}, }) - s := initTestSource(nil) - s.ignoreRepos = []string{"secret/super-secret-repo2"} - err := s.addRepos(context.TODO(), "super-secret-user", s.getReposByUser) + s := initTestSource(&sourcespb.GitHub{ + Credential: &sourcespb.GitHub_Token{ + Token: "super secret token", + }, + IgnoreRepos: []string{"secret/super-secret-repo2"}, + }) + err := s.getReposByUser(context.TODO(), "super-secret-user") assert.Nil(t, err) - assert.Equal(t, 1, len(s.repos)) - assert.Equal(t, []string{"super-secret-repo"}, s.repos) + assert.Equal(t, 1, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("super-secret-repo") + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -130,13 +147,36 @@ func TestAddGistsByUser(t *testing.T) { gock.New("https://api.github.com"). Get("/users/super-secret-user/gists"). Reply(200). - JSON([]map[string]string{{"git_pull_url": "super-secret-gist"}}) + JSON([]map[string]string{{"git_pull_url": "https://githug.com/super-secret-gist.git", "id": "super-secret-gist"}}) s := initTestSource(nil) - err := s.addGistsByUser(context.TODO(), "super-secret-user") + err := s.addUserGistsToCache(context.TODO(), "super-secret-user") assert.Nil(t, err) - assert.Equal(t, 1, len(s.repos)) - assert.Equal(t, []string{"super-secret-gist"}, s.repos) + assert.Equal(t, 1, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("super-secret-gist") + assert.True(t, ok) + assert.True(t, gock.IsDone()) +} + +func TestAddMembersByOrg(t *testing.T) { + defer gock.Off() + + gock.New("https://api.github.com"). + Get("/orgs/org1/members"). + Reply(200). + JSON([]map[string]string{ + {"login": "testman1"}, + {"login": "testman2"}, + }) + + s := initTestSource(nil) + err := s.addMembersByOrg(context.TODO(), "org1") + assert.Nil(t, err) + assert.Equal(t, 2, len(s.memberCache)) + _, ok := s.memberCache["testman1"] + assert.True(t, ok) + _, ok = s.memberCache["testman2"] + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -161,8 +201,13 @@ func TestAddMembersByApp(t *testing.T) { s := initTestSource(nil) err := s.addMembersByApp(context.TODO(), github.NewClient(nil)) assert.Nil(t, err) - assert.Equal(t, 3, len(s.members)) - assert.Equal(t, []string{"ssm1", "ssm2", "ssm3"}, s.members) + assert.Equal(t, 3, len(s.memberCache)) + _, ok := s.memberCache["ssm1"] + assert.True(t, ok) + _, ok = s.memberCache["ssm2"] + assert.True(t, ok) + _, ok = s.memberCache["ssm3"] + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -174,16 +219,19 @@ func TestAddReposByApp(t *testing.T) { Reply(200). JSON(map[string]interface{}{ "repositories": []map[string]string{ - {"clone_url": "ssr1"}, - {"clone_url": "ssr2"}, + {"clone_url": "https://github/ssr1.git", "full_name": "ssr1"}, + {"clone_url": "https://github/ssr2.git", "full_name": "ssr2"}, }, }) s := initTestSource(nil) err := s.addReposByApp(context.TODO()) assert.Nil(t, err) - assert.Equal(t, 2, len(s.repos)) - assert.Equal(t, []string{"ssr1", "ssr2"}, s.repos) + assert.Equal(t, 2, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("ssr1") + assert.True(t, ok) + ok = s.filteredRepoCache.Exists("ssr2") + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -201,8 +249,9 @@ func TestAddOrgsByUser(t *testing.T) { s := initTestSource(nil) s.addOrgsByUser(context.TODO(), "super-secret-user") - assert.Equal(t, 1, len(s.orgs)) - assert.Equal(t, []string{"sso2"}, s.orgs) + assert.Equal(t, 1, s.orgsCache.Count()) + ok := s.orgsCache.Exists("sso2") + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -213,30 +262,15 @@ func TestNormalizeRepos(t *testing.T) { name string setup func() repos []string - expected []string + expected map[string]struct{} + wantErr bool }{ { - name: "repo url", - setup: func() {}, - repos: []string{"https://github.com/super-secret-user/super-secret-repo"}, - expected: []string{"https://github.com/super-secret-user/super-secret-repo.git"}, - }, - { - name: "username with gists", - setup: func() { - gock.New("https://api.github.com"). - Get("/users/super-secret-user/gists"). - Reply(200). - JSON([]map[string]string{{"git_pull_url": "https://github.com/super-secret-user/super-secret-gist.git"}}) - gock.New("https://api.github.com"). - Get("/users/super-secret-user/repos"). - Reply(200). - JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-user/super-secret-repo.git"}}) - }, - repos: []string{"super-secret-user"}, - expected: []string{ - "https://github.com/super-secret-user/super-secret-repo.git", - "https://github.com/super-secret-user/super-secret-gist.git", + name: "repo url", + setup: func() {}, + repos: []string{"https://github.com/super-secret-user/super-secret-repo"}, + expected: map[string]struct{}{ + "https://github.com/super-secret-user/super-secret-repo.git": {}, }, }, { @@ -250,13 +284,15 @@ func TestNormalizeRepos(t *testing.T) { Reply(404) }, repos: []string{"not-found"}, - expected: []string{}, + expected: map[string]struct{}{}, + wantErr: true, }, { name: "unexpected format", setup: func() {}, repos: []string{"/foo/"}, - expected: []string{}, + expected: map[string]struct{}{}, + wantErr: true, }, } @@ -265,16 +301,25 @@ func TestNormalizeRepos(t *testing.T) { defer gock.Off() tt.setup() s := initTestSource(nil) - s.repos = tt.repos - s.normalizeRepos(context.TODO()) - assert.Equal(t, len(tt.expected), len(s.repos)) - // sort and compare - sort.Slice(tt.expected, func(i, j int) bool { return tt.expected[i] < tt.expected[j] }) - sort.Slice(s.repos, func(i, j int) bool { return s.repos[i] < s.repos[j] }) - assert.Equal(t, tt.expected, s.repos) + got, err := s.normalizeRepo(tt.repos[0]) + if (err != nil) != tt.wantErr { + t.Errorf("normalizeRepo() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != "" { + for k := range tt.expected { + assert.Equal(t, got, k) + } + } + res := make(map[string]struct{}, s.filteredRepoCache.Count()) + for _, v := range s.filteredRepoCache.Keys() { + res[v] = struct{}{} + } - assert.True(t, gock.IsDone()) + if got == "" && !cmp.Equal(res, tt.expected) { + t.Errorf("normalizeRepo() got = %v, want %v", s.repos, tt.expected) + } }) } } @@ -296,13 +341,14 @@ func TestEnumerateUnauthenticated(t *testing.T) { gock.New("https://api.github.com"). Get("/orgs/super-secret-org/repos"). Reply(200). - JSON([]map[string]string{{"clone_url": "super-secret-repo"}}) + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}}) s := initTestSource(nil) s.orgs = []string{"super-secret-org"} s.enumerateUnauthenticated(context.TODO()) - assert.Equal(t, 1, len(s.repos)) - assert.Equal(t, []string{"super-secret-repo"}, s.repos) + assert.Equal(t, 1, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("super-secret-repo") + assert.True(t, ok) assert.True(t, gock.IsDone()) } @@ -317,21 +363,157 @@ func TestEnumerateWithToken(t *testing.T) { gock.New("https://api.github.com"). Get("/users/super-secret-user/repos"). Reply(200). - JSON([]map[string]string{{"clone_url": "super-secret-repo"}}) + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}}) + + gock.New("https://api.github.com"). + Get("/user/orgs"). + MatchParam("per_page", "100"). + Reply(200). + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}}) gock.New("https://api.github.com"). Get("/users/super-secret-user/gists"). Reply(200). - JSON([]map[string]string{{"clone_url": ""}}) + JSON([]map[string]string{{"git_pull_url": "https://github.com/super-secret-gist.git", "id": "super-secret-gist"}}) s := initTestSource(nil) err := s.enumerateWithToken(context.TODO(), "https://api.github.com", "token") assert.Nil(t, err) - assert.Equal(t, 2, len(s.repos)) - assert.Equal(t, []string{"super-secret-repo", ""}, s.repos) + assert.Equal(t, 2, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("super-secret-repo") + assert.True(t, ok) + ok = s.filteredRepoCache.Exists("super-secret-gist") + assert.True(t, ok) assert.True(t, gock.IsDone()) } +func BenchmarkEnumerateWithToken(b *testing.B) { + defer gock.Off() + + gock.New("https://api.github.com"). + Get("/user"). + Reply(200). + JSON(map[string]string{"login": "super-secret-user"}) + + gock.New("https://api.github.com"). + Get("/users/super-secret-user/repos"). + Reply(200). + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git"}}) + + gock.New("https://api.github.com"). + Get("/user/orgs"). + MatchParam("per_page", "100"). + Reply(200). + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git"}}) + + gock.New("https://api.github.com"). + Get("/users/super-secret-user/gists"). + Reply(200). + JSON([]map[string]string{{"git_pull_url": "https://github.com/super-secret-gist.git"}}) + + s := initTestSource(nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = s.enumerateWithToken(context.TODO(), "https://api.github.com", "token") + } +} + +func TestEnumerate(t *testing.T) { + defer gock.Off() + + gock.New("https://api.github.com"). + Get("/user"). + Reply(200). + JSON(map[string]string{"login": "super-secret-user"}) + + gock.New("https://api.github.com"). + Get("/users/super-secret-user/repos"). + Reply(200). + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}}) + + gock.New("https://api.github.com"). + Get("/user/orgs"). + MatchParam("per_page", "100"). + Reply(200). + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git", "full_name": "super-secret-repo"}}) + + gock.New("https://api.github.com"). + Get("/users/super-secret-user/gists"). + Reply(200). + JSON([]map[string]string{{"git_pull_url": "https://github.com/super-secret-gist.git", "id": "super-secret-gist"}}) + + s := initTestSource(&sourcespb.GitHub{ + Credential: &sourcespb.GitHub_Token{ + Token: "super secret token", + }, + }) + + _, err := s.enumerate(context.TODO(), "https://api.github.com") + assert.Nil(t, err) + assert.Equal(t, 2, s.filteredRepoCache.Count()) + ok := s.filteredRepoCache.Exists("super-secret-repo") + assert.True(t, ok) + ok = s.filteredRepoCache.Exists("super-secret-gist") + assert.True(t, ok) + assert.True(t, gock.IsDone()) +} + +func setupMocks(b *testing.B) { + b.Helper() + + gock.New("https://api.github.com"). + Get("/user"). + Reply(200). + JSON(map[string]string{"login": "super-secret-user"}) + + gock.New("https://api.github.com"). + Get("/users/super-secret-user/repos"). + Reply(200). + JSON(mockRepos()) + + gock.New("https://api.github.com"). + Get("/user/orgs"). + MatchParam("per_page", "100"). + Reply(200). + JSON([]map[string]string{{"clone_url": "https://github.com/super-secret-repo.git"}}) + + gock.New("https://api.github.com"). + Get("/users/super-secret-user/gists"). + Reply(200). + JSON(mockGists()) +} + +func mockRepos() []map[string]string { + res := make([]map[string]string, 0, 10000) + for i := 0; i < 10000; i++ { + res = append(res, map[string]string{"clone_url": fmt.Sprintf("https://githu/super-secret-repo-%d.git", i)}) + } + return res +} + +func mockGists() []map[string]string { + res := make([]map[string]string, 0, 100) + for i := 0; i < 100; i++ { + res = append(res, map[string]string{"git_pull_url": fmt.Sprintf("https://githu/super-secret-gist-%d.git", i)}) + } + return res +} + +func BenchmarkEnumerate(b *testing.B) { + for i := 0; i < b.N; i++ { + s := initTestSource(&sourcespb.GitHub{ + Credential: &sourcespb.GitHub_Token{ + Token: "super secret token", + }, + }) + setupMocks(b) + + b.StartTimer() + _, _ = s.enumerate(context.TODO(), "https://api.github.com") + } +} + func TestEnumerateWithToken_IncludeRepos(t *testing.T) { defer gock.Off()