From 26fb270751d27d45640bb53827525e23c4b358a6 Mon Sep 17 00:00:00 2001 From: Richard Gomez Date: Sun, 4 Feb 2024 12:14:34 -0500 Subject: [PATCH] refactor(github): cleanup logic --- pkg/sources/github/github.go | 588 +++++++----------- pkg/sources/github/github_integration_test.go | 10 +- pkg/sources/github/github_test.go | 46 +- pkg/sources/github/repo.go | 66 +- 4 files changed, 305 insertions(+), 405 deletions(-) diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index 555c69ab7d8c..140a9073a7a9 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -46,6 +46,7 @@ const ( type Source struct { name string + // Protects the user and token. userMu sync.Mutex githubUser string @@ -54,14 +55,12 @@ type Source struct { sourceID sources.SourceID jobID sources.JobID verify bool - repos []string orgsCache cache.Cache + memberCache map[string]struct{} + repos []string filteredRepoCache *filteredRepoCache - // repos that _probably_ have wikis (see the comment on hasWiki). - reposWithWikis map[string]struct{} - memberCache map[string]struct{} - repoSizes repoSize - totalRepoSize int // total size of all repos in kb + repoInfoCache *repoInfoCache + totalRepoSize int // total size of all repos in kb useCustomContentWriter bool git *git.Git @@ -77,12 +76,10 @@ type Source struct { resumeInfoSlice []string apiClient *github.Client - mu sync.Mutex // protects the visibility maps - publicMap map[string]source_metadatapb.Visibility - includePRComments bool includeIssueComments bool includeGistComments bool + sources.Progress sources.CommonSourceUnitUnmarshaller } @@ -121,27 +118,6 @@ func (s *Source) JobID() sources.JobID { return s.jobID } -type repoSize struct { - mu sync.RWMutex - repoSizes map[string]int // size in kb of each repo -} - -func (r *repoSize) addRepo(repo string, size int) { - r.mu.Lock() - defer r.mu.Unlock() - r.repoSizes[repo] = size -} - -func (r *repoSize) getRepo(repo string) int { - r.mu.RLock() - defer r.mu.RUnlock() - return r.repoSizes[repo] -} - -func newRepoSize() repoSize { - return repoSize{repoSizes: make(map[string]int)} -} - // filteredRepoCache is a wrapper around cache.Cache that filters out repos // based on include and exclude globs. type filteredRepoCache struct { @@ -207,6 +183,11 @@ func (c *filteredRepoCache) includeRepo(s string) bool { // Init returns an initialized GitHub source. func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, sourceID sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { + err := git.CmdCheck() + if err != nil { + return err + } + s.log = aCtx.Logger() s.name = name @@ -220,20 +201,22 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so s.apiClient = github.NewClient(s.httpClient) var conn sourcespb.GitHub - err := anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) + err = anypb.UnmarshalTo(connection, &conn, proto.UnmarshalOptions{}) if err != nil { return fmt.Errorf("error unmarshalling connection: %w", err) } s.conn = &conn + s.orgsCache = memory.New() + for _, org := range s.conn.Organizations { + s.orgsCache.Set(org, org) + } + s.memberCache = make(map[string]struct{}) + s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), append(s.conn.GetRepositories(), s.conn.GetIncludeRepos()...), s.conn.GetIgnoreRepos(), ) - s.reposWithWikis = make(map[string]struct{}) - s.memberCache = make(map[string]struct{}) - - s.repoSizes = newRepoSize() s.repos = s.conn.Repositories for _, repo := range s.repos { r, err := s.normalizeRepo(repo) @@ -243,28 +226,17 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so } s.filteredRepoCache.Set(repo, r) } + s.repoInfoCache = newRepoInfoCache() s.includeIssueComments = s.conn.IncludeIssueComments s.includePRComments = s.conn.IncludePullRequestComments s.includeGistComments = s.conn.IncludeGistComments - s.orgsCache = memory.New() - for _, org := range s.conn.Organizations { - s.orgsCache.Set(org, org) - } - // Head or base should only be used with incoming webhooks if (len(s.conn.Head) > 0 || len(s.conn.Base) > 0) && len(s.repos) != 1 { return fmt.Errorf("cannot specify head or base with multiple repositories") } - err = git.CmdCheck() - if err != nil { - return err - } - - s.publicMap = map[string]source_metadatapb.Visibility{} - cfg := &git.Config{ SourceName: s.name, JobID: s.jobID, @@ -356,83 +328,21 @@ func checkGitHubConnection(ctx context.Context, client *github.Client) error { return err } -func (s *Source) visibilityOf(ctx context.Context, repoURL string) (visibility source_metadatapb.Visibility) { +func (s *Source) visibilityOf(ctx context.Context, repoURL string) source_metadatapb.Visibility { // It isn't possible to get the visibility of a wiki. // We must use the visibility of the corresponding repository. if strings.HasSuffix(repoURL, ".wiki.git") { repoURL = strings.TrimSuffix(repoURL, ".wiki.git") + ".git" } - s.mu.Lock() - visibility, ok := s.publicMap[repoURL] - s.mu.Unlock() - if ok { - return visibility - } - - visibility = source_metadatapb.Visibility_public - defer func() { - s.mu.Lock() - s.publicMap[repoURL] = visibility - s.mu.Unlock() - }() - logger := s.log.WithValues("repo", repoURL) - if _, unauthenticated := s.conn.GetCredential().(*sourcespb.GitHub_Unauthenticated); unauthenticated { - logger.V(3).Info("assuming unauthenticated scan has public visibility") - return source_metadatapb.Visibility_public - } - logger.V(2).Info("Checking public status") - u, err := url.Parse(repoURL) - if err != nil { - logger.Error(err, "Could not parse repository URL.") - return - } - - var resp *github.Response - urlPathParts := strings.Split(u.Path, "/") - switch len(urlPathParts) { - case 2: - // Check if repoURL is a gist. - var gist *github.Gist - repoName := urlPathParts[1] - repoName = strings.TrimSuffix(repoName, ".git") - for { - gist, resp, err = s.apiClient.Gists.Get(ctx, repoName) - if !s.handleRateLimit(err, resp) { - break - } - } - if err != nil || gist == nil { - logger.Error(err, "Could not get Github repository") - return - } - if !(*gist.Public) { - visibility = source_metadatapb.Visibility_private - } - case 3: - var repo *github.Repository - owner := urlPathParts[1] - repoName := urlPathParts[2] - repoName = strings.TrimSuffix(repoName, ".git") - for { - repo, resp, err = s.apiClient.Repositories.Get(ctx, owner, repoName) - if !s.handleRateLimit(err, resp) { - break - } - } - if err != nil || repo == nil { - logger.Error(err, "Could not get Github repository") - return - } - if *repo.Private { - visibility = source_metadatapb.Visibility_private - } - default: - logger.Error(fmt.Errorf("unexpected number of parts"), "RepoURL should split into 2 or 3 parts", - "got", len(urlPathParts), - ) + repoInfo, ok := s.repoInfoCache.get(repoURL) + if !ok { + // This should never happen. + err := fmt.Errorf("no repoInfo for URL: %s", repoURL) + ctx.Logger().Error(err, "failed to get repository visibility") + return source_metadatapb.Visibility_unknown } - return + return repoInfo.visibility } const cloudEndpoint = "https://api.github.com" @@ -590,7 +500,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri ctx.Logger().V(1).Info("Enumerating with token", "endpoint", apiEndpoint) for { ghUser, resp, err = s.apiClient.Users.Get(ctx, "") - if handled := s.handleRateLimit(err, resp); handled { + if s.handleRateLimit(err, resp) { continue } if err != nil { @@ -660,7 +570,7 @@ func (s *Source) enumerateWithToken(ctx context.Context, apiEndpoint, token stri } if s.conn.ScanUsers { - s.log.Info("Adding repos", "members", len(s.memberCache), "orgs", s.orgsCache.Count()) + s.log.Info("Adding repos", "orgs", s.orgsCache.Count(), "members", len(s.memberCache)) s.addReposForMembers(ctx) return nil } @@ -757,7 +667,7 @@ func createGitHubClient(httpClient *http.Client, apiEndpoint string) (*github.Cl } func (s *Source) scan(ctx context.Context, installationClient *github.Client, chunksChan chan *sources.Chunk) error { - var scannedCount uint64 + var scannedCount uint64 = 1 s.log.V(2).Info("Found repos to scan", "count", len(s.repos)) @@ -793,36 +703,41 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch } // Scan the repository + repoInfo, ok := s.repoInfoCache.get(repoURL) + if !ok { + // This should never happen. + err := fmt.Errorf("no repoInfo for URL: %s", repoURL) + s.log.Error(err, "failed to scan repository") + return nil + } repoCtx := context.WithValues(ctx, "repo", repoURL) - duration, err := s.cloneAndScanRepo(repoCtx, installationClient, repoURL, chunksChan) + duration, err := s.cloneAndScanRepo(repoCtx, installationClient, repoURL, repoInfo, chunksChan) if err != nil { scanErrs.Add(err) return nil } // Scan the wiki, if enabled, and the repo has one. - if s.conn.IncludeWikis { - if _, ok := s.reposWithWikis[repoURL]; ok { - wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git" - wikiCtx := context.WithValue(ctx, "repo", wikiURL) - - _, err := s.cloneAndScanRepo(wikiCtx, installationClient, wikiURL, chunksChan) - if err != nil { - scanErrs.Add(err) - // Don't return, it still might be possible to scan comments. - } + if s.conn.IncludeWikis && repoInfo.hasWiki { + wikiURL := strings.TrimSuffix(repoURL, ".git") + ".wiki.git" + wikiCtx := context.WithValue(ctx, "repo", wikiURL) + + _, err := s.cloneAndScanRepo(wikiCtx, installationClient, wikiURL, repoInfo, chunksChan) + if err != nil { + scanErrs.Add(err) + // Don't return, it still might be possible to scan comments. } } // Scan comments, if enabled. if s.includeGistComments || s.includeIssueComments || s.includePRComments { - if err = s.scanComments(ctx, repoURL, chunksChan); err != nil { + if err = s.scanComments(repoCtx, repoURL, repoInfo, chunksChan); err != nil { scanErrs.Add(fmt.Errorf("error scanning comments in repo %s: %w", repoURL, err)) return nil } } - ctx.Logger().V(2).Info(fmt.Sprintf("scanned %d/%d repos", scannedCount, len(s.repos)), "duration_seconds", duration) + repoCtx.Logger().V(1).Info(fmt.Sprintf("scanned %d/%d repos", scannedCount, len(s.repos)), "duration_seconds", duration) githubReposScanned.WithLabelValues(s.name).Inc() atomic.AddUint64(&scannedCount, 1) return nil @@ -833,12 +748,12 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch if scanErrs.Count() > 0 { s.log.V(0).Info("failed to scan some repositories", "error_count", scanErrs.Count(), "errors", scanErrs) } - s.SetProgressComplete(len(s.repos), len(s.repos), "Completed Github scan", "") + s.SetProgressComplete(len(s.repos), len(s.repos), "Completed GitHub scan", "") return nil } -func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, repoURL string, chunksChan chan *sources.Chunk) (time.Duration, error) { +func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, repoURL string, repoInfo *repoInfo, chunksChan chan *sources.Chunk) (time.Duration, error) { var duration time.Duration ctx.Logger().V(2).Info("attempting to clone repo") @@ -853,13 +768,12 @@ func (s *Source) cloneAndScanRepo(ctx context.Context, client *github.Client, re // Repo size is not collected for wikis. var logger logr.Logger - if !strings.HasSuffix(repoURL, ".wiki.git") { - repoSize := s.repoSizes.getRepo(repoURL) - logger = ctx.Logger().WithValues("repo_size_kb", repoSize) + if !strings.HasSuffix(repoURL, ".wiki.git") && repoInfo.size > 0 { + logger = ctx.Logger().WithValues("repo_size_kb", repoInfo.size) } else { logger = ctx.Logger() } - logger.V(2).Info("scanning repo") + logger.V(1).Info("scanning repo") start := time.Now() if err = s.git.ScanRepo(ctx, repo, path, s.scanOptions, sources.ChanReporter{Ch: chunksChan}); err != nil { @@ -940,18 +854,27 @@ func (s *Source) addUserGistsToCache(ctx context.Context, user string) error { logger := s.log.WithValues("user", user) for { gists, res, err := s.apiClient.Gists.List(ctx, user, gistOpts) - if err == nil { - res.Body.Close() - } - if handled := s.handleRateLimit(err, res); handled { + if s.handleRateLimit(err, res) { continue } if err != nil { return fmt.Errorf("could not list gists for user %s: %w", user, err) } + for _, gist := range gists { s.filteredRepoCache.Set(gist.GetID(), gist.GetGitPullURL()) + + info := &repoInfo{ + owner: gist.GetOwner().GetLogin(), + } + if gist.GetPublic() { + info.visibility = source_metadatapb.Visibility_public + } else { + info.visibility = source_metadatapb.Visibility_private + } + s.repoInfoCache.put(gist.GetGitPullURL(), info) } + if res == nil || res.NextPage == 0 { break } @@ -997,19 +920,18 @@ func (s *Source) addAllVisibleOrgs(ctx context.Context) { } for { orgs, resp, err := s.apiClient.Organizations.ListAll(ctx, orgOpts) - if err == nil { - resp.Body.Close() - } - if handled := s.handleRateLimit(err, resp); handled { + if s.handleRateLimit(err, resp) { continue } if err != nil { s.log.Error(err, "could not list all organizations") return } + if len(orgs) == 0 { break } + lastOrgID := *orgs[len(orgs)-1].ID s.log.V(2).Info(fmt.Sprintf("listed organization IDs %d through %d", orgOpts.Since, lastOrgID)) orgOpts.Since = lastOrgID @@ -1037,19 +959,14 @@ func (s *Source) addOrgsByUser(ctx context.Context, user string) { logger := s.log.WithValues("user", user) for { orgs, resp, err := s.apiClient.Organizations.List(ctx, "", orgOpts) - if err == nil { - resp.Body.Close() - } - if handled := s.handleRateLimit(err, resp); handled { + if s.handleRateLimit(err, resp) { continue } if err != nil { logger.Error(err, "Could not list organizations") return } - if resp == nil { - break - } + logger.V(2).Info("Listed orgs", "page", orgOpts.Page, "last_page", resp.LastPage) for _, org := range orgs { if org.Login == nil { @@ -1075,18 +992,13 @@ func (s *Source) addMembersByOrg(ctx context.Context, org string) error { logger := s.log.WithValues("org", org) for { members, res, err := s.apiClient.Organizations.ListMembers(ctx, org, opts) - if err == nil { - defer res.Body.Close() - } - if handled := s.handleRateLimit(err, res); handled { + if s.handleRateLimit(err, res) { continue } if err != nil || len(members) == 0 { return fmt.Errorf("could not list organization members: account may not have access to list organization members %w", err) } - if res == nil { - break - } + logger.V(2).Info("Listed members", "page", opts.Page, "last_page", res.LastPage) for _, m := range members { usr := m.Login @@ -1121,29 +1033,49 @@ func (s *Source) setProgressCompleteWithRepo(index int, offset int, repoURL stri s.SetProgressComplete(index+offset, len(s.repos)+offset, fmt.Sprintf("Repo: %s", repoURL), encodedResumeInfo) } -const initialPage = 1 // page to start listing from - -func (s *Source) scanComments(ctx context.Context, repoPath string, chunksChan chan *sources.Chunk) error { - // Support ssh and https URLs +func (s *Source) scanComments(ctx context.Context, repoPath string, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { + // Support ssh and https URLs. repoURL, err := git.GitURLParse(repoPath) if err != nil { return err } + urlString := repoURL.String() - trimmedURL := removeURLAndSplit(repoURL.String()) - if repoURL.Host == "gist.github.com" && s.includeGistComments { - return s.processGistComments(ctx, repoPath, trimmedURL, repoURL, chunksChan) + urlParts := trimURLAndSplit(urlString) + if len(urlParts) < 2 || len(urlParts) > 3 { + return fmt.Errorf("invalid repository or gist URL (%s): length of URL segments should be 2 or 3", urlString) } - return s.processRepoComments(ctx, repoPath, trimmedURL, repoURL, chunksChan) + + if s.includeGistComments && urlParts[0] == "gist.github.com" { + return s.processGistComments(ctx, urlString, urlParts, repoInfo, chunksChan) + } else if s.includeIssueComments || s.includePRComments { + return s.processRepoComments(ctx, repoInfo, chunksChan) + } + return nil +} + +// trimURLAndSplit removes extraneous information from the |url| and splits it into segments. +// This is typically 3 segments: host, owner, and name/ID; however, Gists have some edge cases. +// +// Examples: +// - "https://github.com/trufflesecurity/trufflehog" => ["github.com", "trufflesecurity", "trufflehog"] +// - "https://gist.github.com/nat/5fdbb7f945d121f197fb074578e53948" => ["gist.github.com", "nat", "5fdbb7f945d121f197fb074578e53948"] +// - "https://gist.github.com/ff0e5e8dc8ec22f7a25ddfc3492d3451.git" => ["gist.github.com", "ff0e5e8dc8ec22f7a25ddfc3492d3451"] +func trimURLAndSplit(url string) []string { + trimmedURL := strings.TrimPrefix(url, "https://") + trimmedURL = strings.TrimSuffix(trimmedURL, ".git") + splitURL := strings.Split(trimmedURL, "/") + + return splitURL } -func (s *Source) processGistComments(ctx context.Context, repoPath string, trimmedURL []string, repoURL *url.URL, chunksChan chan *sources.Chunk) error { - ctx.Logger().V(2).Info("scanning github gist comments", "repository", repoPath) +const initialPage = 1 // page to start listing from + +func (s *Source) processGistComments(ctx context.Context, gistURL string, urlParts []string, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { + ctx.Logger().V(2).Info("Scanning GitHub Gist comments") + // GitHub Gist URL. - gistID, err := extractGistID(trimmedURL) - if err != nil { - return err - } + gistID := extractGistID(urlParts) options := &github.ListOptions{ PerPage: defaultPagination, @@ -1152,13 +1084,13 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm for { comments, resp, err := s.apiClient.Gists.ListComments(ctx, gistID, options) if s.handleRateLimit(err, resp) { - break + continue } if err != nil { return err } - if err = s.chunkGistComments(ctx, repoURL.String(), comments, chunksChan); err != nil { + if err = s.chunkGistComments(ctx, gistURL, repoInfo, comments, chunksChan); err != nil { return err } @@ -1170,17 +1102,47 @@ func (s *Source) processGistComments(ctx context.Context, repoPath string, trimm return nil } -func extractGistID(url []string) (string, error) { - if len(url) < 2 || len(url) > 3 { - return "", fmt.Errorf("failed to parse Gist URL: length of trimmedURL should be 2 or 3") +func extractGistID(url []string) string { + return url[len(url)-1] +} + +func (s *Source) chunkGistComments(ctx context.Context, gistURL string, gistInfo *repoInfo, comments []*github.GistComment, chunksChan chan *sources.Chunk) error { + for _, comment := range comments { + // Create chunk and send it to the channel. + chunk := &sources.Chunk{ + SourceName: s.name, + SourceID: s.SourceID(), + SourceType: s.Type(), + JobID: s.JobID(), + SourceMetadata: &source_metadatapb.MetaData{ + Data: &source_metadatapb.MetaData_Github{ + Github: &source_metadatapb.Github{ + Link: sanitizer.UTF8(comment.GetURL()), + Username: sanitizer.UTF8(comment.GetUser().GetLogin()), + Email: sanitizer.UTF8(comment.GetUser().GetEmail()), + Repository: sanitizer.UTF8(gistURL), + Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()), + Visibility: gistInfo.visibility, + }, + }, + }, + Data: []byte(sanitizer.UTF8(comment.GetBody())), + Verify: s.verify, + } + + select { + case <-ctx.Done(): + return ctx.Err() + case chunksChan <- chunk: + } } - return url[len(url)-1], nil + return nil } // Note: these can't be consts because the address is needed when using with the GitHub library. var ( // sortType defines the criteria for sorting comments. - // By default comments are sorted by their creation date. + // By default, comments are sorted by their creation date. sortType = "created" // directionType defines the direction of sorting. // "desc" means comments will be sorted in descending order, showing the latest comments first. @@ -1192,34 +1154,9 @@ var ( state = "all" ) -type repoInfo struct { - owner string - repo string - repoPath string - visibility source_metadatapb.Visibility -} - -func (s *Source) processRepoComments(ctx context.Context, repoPath string, trimmedURL []string, repoURL *url.URL, chunksChan chan *sources.Chunk) error { - // Normal repository URL (https://github.com//). - if len(trimmedURL) < 3 { - return fmt.Errorf("url missing owner and/or repo: '%s'", repoURL.String()) - } - owner := trimmedURL[1] - repo := trimmedURL[2] - - if !(s.includeIssueComments || s.includePRComments) { - return nil - } - - repoInfo := repoInfo{ - owner: owner, - repo: repo, - repoPath: repoPath, - visibility: s.visibilityOf(ctx, repoPath), - } - +func (s *Source) processRepoComments(ctx context.Context, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { if s.includeIssueComments { - ctx.Logger().V(2).Info("scanning github issues", "repository", repoInfo.repoPath) + ctx.Logger().V(2).Info("Scanning issues") if err := s.processIssues(ctx, repoInfo, chunksChan); err != nil { return err } @@ -1229,7 +1166,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoPath string, trimm } if s.includePRComments { - ctx.Logger().V(2).Info("scanning github pull requests", "repository", repoInfo.repoPath) + ctx.Logger().V(2).Info("Scanning pull requests") if err := s.processPRs(ctx, repoInfo, chunksChan); err != nil { return err } @@ -1242,7 +1179,7 @@ func (s *Source) processRepoComments(ctx context.Context, repoPath string, trimm } -func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processIssues(ctx context.Context, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { bodyTextsOpts := &github.IssueListByRepoOptions{ Sort: sortType, Direction: directionType, @@ -1254,16 +1191,16 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch } for { - issues, resp, err := s.apiClient.Issues.ListByRepo(ctx, info.owner, info.repo, bodyTextsOpts) + issues, resp, err := s.apiClient.Issues.ListByRepo(ctx, repoInfo.owner, repoInfo.name, bodyTextsOpts) if s.handleRateLimit(err, resp) { - break + continue } if err != nil { return err } - if err = s.chunkIssues(ctx, info, issues, chunksChan); err != nil { + if err = s.chunkIssues(ctx, repoInfo, issues, chunksChan); err != nil { return err } @@ -1276,7 +1213,46 @@ func (s *Source) processIssues(ctx context.Context, info repoInfo, chunksChan ch return nil } -func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkIssues(ctx context.Context, repoInfo *repoInfo, issues []*github.Issue, chunksChan chan *sources.Chunk) error { + for _, issue := range issues { + + // Skip pull requests since covered by processPRs. + if issue.IsPullRequest() { + continue + } + + // Create chunk and send it to the channel. + chunk := &sources.Chunk{ + SourceName: s.name, + SourceID: s.SourceID(), + JobID: s.JobID(), + SourceType: s.Type(), + SourceMetadata: &source_metadatapb.MetaData{ + Data: &source_metadatapb.MetaData_Github{ + Github: &source_metadatapb.Github{ + Link: sanitizer.UTF8(issue.GetHTMLURL()), + Username: sanitizer.UTF8(issue.GetUser().GetLogin()), + Email: sanitizer.UTF8(issue.GetUser().GetEmail()), + Repository: sanitizer.UTF8(repoInfo.fullName), + Timestamp: sanitizer.UTF8(issue.GetCreatedAt().String()), + Visibility: repoInfo.visibility, + }, + }, + }, + Data: []byte(sanitizer.UTF8(issue.GetTitle() + "\n" + issue.GetBody())), + Verify: s.verify, + } + + select { + case <-ctx.Done(): + return ctx.Err() + case chunksChan <- chunk: + } + } + return nil +} + +func (s *Source) processIssueComments(ctx context.Context, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { issueOpts := &github.IssueListCommentsOptions{ Sort: &sortType, Direction: &directionType, @@ -1287,16 +1263,16 @@ func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunks } for { - issueComments, resp, err := s.apiClient.Issues.ListComments(ctx, info.owner, info.repo, allComments, issueOpts) + issueComments, resp, err := s.apiClient.Issues.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, issueOpts) if s.handleRateLimit(err, resp) { - break + continue } if err != nil { return err } - if err = s.chunkIssueComments(ctx, info, issueComments, chunksChan); err != nil { + if err = s.chunkIssueComments(ctx, repoInfo, issueComments, chunksChan); err != nil { return err } @@ -1309,7 +1285,40 @@ func (s *Source) processIssueComments(ctx context.Context, info repoInfo, chunks return nil } -func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkIssueComments(ctx context.Context, repoInfo *repoInfo, comments []*github.IssueComment, chunksChan chan *sources.Chunk) error { + for _, comment := range comments { + // Create chunk and send it to the channel. + chunk := &sources.Chunk{ + SourceName: s.name, + SourceID: s.SourceID(), + JobID: s.JobID(), + SourceType: s.Type(), + SourceMetadata: &source_metadatapb.MetaData{ + Data: &source_metadatapb.MetaData_Github{ + Github: &source_metadatapb.Github{ + Link: sanitizer.UTF8(comment.GetHTMLURL()), + Username: sanitizer.UTF8(comment.GetUser().GetLogin()), + Email: sanitizer.UTF8(comment.GetUser().GetEmail()), + Repository: sanitizer.UTF8(repoInfo.fullName), + Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()), + Visibility: repoInfo.visibility, + }, + }, + }, + Data: []byte(sanitizer.UTF8(comment.GetBody())), + Verify: s.verify, + } + + select { + case <-ctx.Done(): + return ctx.Err() + case chunksChan <- chunk: + } + } + return nil +} + +func (s *Source) processPRs(ctx context.Context, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { prOpts := &github.PullRequestListOptions{ Sort: sortType, Direction: directionType, @@ -1321,16 +1330,16 @@ func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan } for { - prs, resp, err := s.apiClient.PullRequests.List(ctx, info.owner, info.repo, prOpts) + prs, resp, err := s.apiClient.PullRequests.List(ctx, repoInfo.owner, repoInfo.name, prOpts) if s.handleRateLimit(err, resp) { - break + continue } if err != nil { return err } - if err = s.chunkPullRequests(ctx, info, prs, chunksChan); err != nil { + if err = s.chunkPullRequests(ctx, repoInfo, prs, chunksChan); err != nil { return err } @@ -1343,7 +1352,7 @@ func (s *Source) processPRs(ctx context.Context, info repoInfo, chunksChan chan return nil } -func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksChan chan *sources.Chunk) error { +func (s *Source) processPRComments(ctx context.Context, repoInfo *repoInfo, chunksChan chan *sources.Chunk) error { prOpts := &github.PullRequestListCommentsOptions{ Sort: sortType, Direction: directionType, @@ -1354,16 +1363,15 @@ func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksCha } for { - prComments, resp, err := s.apiClient.PullRequests.ListComments(ctx, info.owner, info.repo, allComments, prOpts) + prComments, resp, err := s.apiClient.PullRequests.ListComments(ctx, repoInfo.owner, repoInfo.name, allComments, prOpts) if s.handleRateLimit(err, resp) { - break + continue } - if err != nil { return err } - if err = s.chunkPullRequestComments(ctx, info, prComments, chunksChan); err != nil { + if err = s.chunkPullRequestComments(ctx, repoInfo, prComments, chunksChan); err != nil { return err } @@ -1376,112 +1384,7 @@ func (s *Source) processPRComments(ctx context.Context, info repoInfo, chunksCha return nil } -func (s *Source) chunkIssues(ctx context.Context, repoInfo repoInfo, issues []*github.Issue, chunksChan chan *sources.Chunk) error { - for _, issue := range issues { - - // Skip pull requests since covered by processPRs. - if issue.IsPullRequest() { - continue - } - - // Create chunk and send it to the channel. - chunk := &sources.Chunk{ - SourceName: s.name, - SourceID: s.SourceID(), - JobID: s.JobID(), - SourceType: s.Type(), - SourceMetadata: &source_metadatapb.MetaData{ - Data: &source_metadatapb.MetaData_Github{ - Github: &source_metadatapb.Github{ - Link: sanitizer.UTF8(issue.GetHTMLURL()), - Username: sanitizer.UTF8(issue.GetUser().GetLogin()), - Email: sanitizer.UTF8(issue.GetUser().GetEmail()), - Repository: sanitizer.UTF8(repoInfo.repo), - Timestamp: sanitizer.UTF8(issue.GetCreatedAt().String()), - Visibility: repoInfo.visibility, - }, - }, - }, - Data: []byte(sanitizer.UTF8(issue.GetTitle() + "\n" + issue.GetBody())), - Verify: s.verify, - } - - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: - } - } - return nil -} - -func (s *Source) chunkIssueComments(ctx context.Context, repoInfo repoInfo, comments []*github.IssueComment, chunksChan chan *sources.Chunk) error { - for _, comment := range comments { - // Create chunk and send it to the channel. - chunk := &sources.Chunk{ - SourceName: s.name, - SourceID: s.SourceID(), - JobID: s.JobID(), - SourceType: s.Type(), - SourceMetadata: &source_metadatapb.MetaData{ - Data: &source_metadatapb.MetaData_Github{ - Github: &source_metadatapb.Github{ - Link: sanitizer.UTF8(comment.GetHTMLURL()), - Username: sanitizer.UTF8(comment.GetUser().GetLogin()), - Email: sanitizer.UTF8(comment.GetUser().GetEmail()), - Repository: sanitizer.UTF8(repoInfo.repo), - Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()), - Visibility: repoInfo.visibility, - }, - }, - }, - Data: []byte(sanitizer.UTF8(comment.GetBody())), - Verify: s.verify, - } - - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: - } - } - return nil -} - -func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo repoInfo, comments []*github.PullRequestComment, chunksChan chan *sources.Chunk) error { - for _, comment := range comments { - // Create chunk and send it to the channel. - chunk := &sources.Chunk{ - SourceName: s.name, - SourceID: s.SourceID(), - SourceType: s.Type(), - JobID: s.JobID(), - SourceMetadata: &source_metadatapb.MetaData{ - Data: &source_metadatapb.MetaData_Github{ - Github: &source_metadatapb.Github{ - Link: sanitizer.UTF8(comment.GetHTMLURL()), - Username: sanitizer.UTF8(comment.GetUser().GetLogin()), - Email: sanitizer.UTF8(comment.GetUser().GetEmail()), - Repository: sanitizer.UTF8(repoInfo.repo), - Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()), - Visibility: repoInfo.visibility, - }, - }, - }, - Data: []byte(sanitizer.UTF8(comment.GetBody())), - Verify: s.verify, - } - - select { - case <-ctx.Done(): - return ctx.Err() - case chunksChan <- chunk: - } - } - return nil -} - -func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkPullRequests(ctx context.Context, repoInfo *repoInfo, prs []*github.PullRequest, chunksChan chan *sources.Chunk) error { for _, pr := range prs { // Create chunk and send it to the channel. chunk := &sources.Chunk{ @@ -1495,7 +1398,7 @@ func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs [ Link: sanitizer.UTF8(pr.GetHTMLURL()), Username: sanitizer.UTF8(pr.GetUser().GetLogin()), Email: sanitizer.UTF8(pr.GetUser().GetEmail()), - Repository: sanitizer.UTF8(repoInfo.repo), + Repository: sanitizer.UTF8(repoInfo.fullName), Timestamp: sanitizer.UTF8(pr.GetCreatedAt().String()), Visibility: repoInfo.visibility, }, @@ -1514,7 +1417,7 @@ func (s *Source) chunkPullRequests(ctx context.Context, repoInfo repoInfo, prs [ return nil } -func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments []*github.GistComment, chunksChan chan *sources.Chunk) error { +func (s *Source) chunkPullRequestComments(ctx context.Context, repoInfo *repoInfo, comments []*github.PullRequestComment, chunksChan chan *sources.Chunk) error { for _, comment := range comments { // Create chunk and send it to the channel. chunk := &sources.Chunk{ @@ -1525,13 +1428,12 @@ func (s *Source) chunkGistComments(ctx context.Context, gistUrl string, comments SourceMetadata: &source_metadatapb.MetaData{ Data: &source_metadatapb.MetaData_Github{ Github: &source_metadatapb.Github{ - Link: sanitizer.UTF8(comment.GetURL()), + Link: sanitizer.UTF8(comment.GetHTMLURL()), Username: sanitizer.UTF8(comment.GetUser().GetLogin()), Email: sanitizer.UTF8(comment.GetUser().GetEmail()), - Repository: sanitizer.UTF8(gistUrl), + Repository: sanitizer.UTF8(repoInfo.fullName), Timestamp: sanitizer.UTF8(comment.GetCreatedAt().String()), - // TODO: Fetching this requires making an additional API call. We may want to include this in the future. - // Visibility: s.visibilityOf(ctx, repoPath), + Visibility: repoInfo.visibility, }, }, }, @@ -1602,11 +1504,3 @@ func (s *Source) scanTarget(ctx context.Context, target sources.ChunkingTarget, return common.CancellableWrite(ctx, chunksChan, chunk) } - -func removeURLAndSplit(url string) []string { - trimmedURL := strings.TrimPrefix(url, "https://") - trimmedURL = strings.TrimSuffix(trimmedURL, ".git") - splitURL := strings.Split(trimmedURL, "/") - - return splitURL -} diff --git a/pkg/sources/github/github_integration_test.go b/pkg/sources/github/github_integration_test.go index d76b5b3f922c..2b0cddb98506 100644 --- a/pkg/sources/github/github_integration_test.go +++ b/pkg/sources/github/github_integration_test.go @@ -53,11 +53,11 @@ func TestSource_Token(t *testing.T) { } s := Source{ - conn: conn, - httpClient: common.SaneHttpClient(), - log: logr.Discard(), - memberCache: map[string]struct{}{}, - repoSizes: newRepoSize(), + conn: conn, + httpClient: common.SaneHttpClient(), + log: logr.Discard(), + memberCache: map[string]struct{}{}, + repoInfoCache: newRepoInfoCache(), } s.filteredRepoCache = s.newFilteredRepoCache(memory.New(), nil, nil) diff --git a/pkg/sources/github/github_test.go b/pkg/sources/github/github_test.go index 8b5fd27e1909..a94adf82cd74 100644 --- a/pkg/sources/github/github_test.go +++ b/pkg/sources/github/github_test.go @@ -8,10 +8,8 @@ import ( "encoding/pem" "fmt" "net/http" - "net/url" "reflect" "strconv" - "strings" "testing" "time" @@ -27,7 +25,6 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" - "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) func createTestSource(src *sourcespb.GitHub) (*Source, *anypb.Any) { @@ -713,52 +710,19 @@ func Test_scan_SetProgressComplete(t *testing.T) { } } -func TestProcessRepoComments(t *testing.T) { - tests := []struct { - name string - trimmedURL []string - wantErr bool - }{ - { - name: "URL with missing owner and/or repo", - trimmedURL: []string{"https://github.com/"}, - wantErr: true, - }, - { - name: "URL with complete owner and repo", - trimmedURL: []string{"https://github.com/", "owner", "repo"}, - wantErr: false, - }, - // TODO: Add more test cases to cover other scenarios. - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - s := &Source{} - repoURL, _ := url.Parse(strings.Join(tt.trimmedURL, "/")) - chunksChan := make(chan *sources.Chunk) - - err := s.processRepoComments(context.Background(), "repoPath", tt.trimmedURL, repoURL, chunksChan) - assert.Equal(t, tt.wantErr, err != nil) - }) - } -} - func TestGetGistID(t *testing.T) { tests := []struct { trimmedURL []string expected string - err bool }{ - {[]string{"https://gist.github.com", "12345"}, "12345", false}, - {[]string{"https://gist.github.com", "owner", "12345"}, "12345", false}, - {[]string{"https://gist.github.com"}, "", true}, - {[]string{"https://gist.github.com", "owner", "12345", "extra"}, "", true}, + {[]string{"https://gist.github.com", "12345"}, "12345"}, + {[]string{"https://gist.github.com", "owner", "12345"}, "12345"}, + {[]string{"https://gist.github.com"}, ""}, + {[]string{"https://gist.github.com", "owner", "12345", "extra"}, ""}, } for _, tt := range tests { - got, err := extractGistID(tt.trimmedURL) - assert.Equal(t, tt.err, err != nil) + got := extractGistID(tt.trimmedURL) assert.Equal(t, tt.expected, got) } } diff --git a/pkg/sources/github/repo.go b/pkg/sources/github/repo.go index f76b3100643b..13b580d13609 100644 --- a/pkg/sources/github/repo.go +++ b/pkg/sources/github/repo.go @@ -2,19 +2,56 @@ package github import ( "fmt" + "io" "net/http" "strconv" "strings" + "sync" gogit "github.com/go-git/go-git/v5" "github.com/google/go-github/v42/github" "github.com/trufflesecurity/trufflehog/v3/pkg/context" "github.com/trufflesecurity/trufflehog/v3/pkg/giturl" + "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git" ) +type repoInfoCache struct { + mu sync.RWMutex + cache map[string]*repoInfo +} + +func newRepoInfoCache() *repoInfoCache { + return &repoInfoCache{ + cache: make(map[string]*repoInfo), + } +} + +func (r *repoInfoCache) put(repoURL string, info *repoInfo) { + r.mu.Lock() + defer r.mu.Unlock() + r.cache[repoURL] = info +} + +func (r *repoInfoCache) get(repoURL string) (*repoInfo, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + + info, ok := r.cache[repoURL] + return info, ok +} + +type repoInfo struct { + owner string + name string + fullName string + hasWiki bool // the repo is _likely_ to have a wiki (see the comment on hasWiki func). + size int + visibility source_metadatapb.Visibility +} + func (s *Source) cloneRepo( ctx context.Context, repoURL string, @@ -103,7 +140,7 @@ func (s *Source) userAndToken(ctx context.Context, installationClient *github.Cl ) for { ghUser, resp, err = s.apiClient.Users.Get(ctx, "") - if handled := s.handleRateLimit(err, resp); handled { + if s.handleRateLimit(err, resp) { continue } if err != nil { @@ -204,18 +241,12 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo for { someRepos, res, err := listRepos(ctx, target, listOpts) - if err == nil { - res.Body.Close() - } - if handled := s.handleRateLimit(err, res); handled { + if s.handleRateLimit(err, res) { continue } if err != nil { return err } - if res == nil { - break - } s.log.V(2).Info("Listed repos", "page", opts.Page, "last_page", res.LastPage) for _, r := range someRepos { @@ -232,12 +263,22 @@ func (s *Source) processRepos(ctx context.Context, target string, listRepos repo } repoName, repoURL := r.GetFullName(), r.GetCloneURL() - s.repoSizes.addRepo(repoURL, r.GetSize()) s.totalRepoSize += r.GetSize() s.filteredRepoCache.Set(repoName, repoURL) - if s.conn.GetIncludeWikis() && s.hasWiki(ctx, r, repoURL) { - s.reposWithWikis[repoURL] = struct{}{} + + info := &repoInfo{ + owner: r.GetOwner().GetLogin(), + name: r.GetName(), + fullName: r.GetFullName(), + hasWiki: s.conn.GetIncludeWikis() && s.hasWiki(ctx, r, repoURL), + size: r.GetSize(), + } + if r.GetPrivate() { + info.visibility = source_metadatapb.Visibility_private + } else { + info.visibility = source_metadatapb.Visibility_public } + s.repoInfoCache.put(repoURL, info) logger.V(3).Info("repo attributes", "name", repoName, "kb_size", r.GetSize(), "repo_url", repoURL) } @@ -270,6 +311,7 @@ func (s *Source) hasWiki(ctx context.Context, repo *github.Repository, repoURL s if err != nil { return false } + _, _ = io.Copy(io.Discard, res.Body) _ = res.Body.Close() // If the wiki is disabled, or is enabled but has no content, the request should be redirected. @@ -288,7 +330,7 @@ type commitQuery struct { // If the file or its diff is not found, it returns an error. func (s *Source) getDiffForFileInCommit(ctx context.Context, query commitQuery) (string, error) { commit, resp, err := s.apiClient.Repositories.GetCommit(ctx, query.owner, query.repo, query.sha, nil) - if handled := s.handleRateLimit(err, resp); handled { + if s.handleRateLimit(err, resp) { return "", fmt.Errorf("error fetching commit %s due to rate limit: %w", query.sha, err) } if err != nil {