Skip to content

Commit

Permalink
[bug] - fix data races (#1577)
Browse files Browse the repository at this point in the history
* fix data race.

* Add test and fix additional data race.

* address comments.
  • Loading branch information
ahrav committed Jul 31, 2023
1 parent 406ce7b commit eb00d0d
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 36 deletions.
9 changes: 5 additions & 4 deletions pkg/sources/git/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"runtime"
"strconv"
"strings"
"sync/atomic"
"time"

diskbufferreader "github.com/bill-rich/disk-buffer-reader"
Expand Down Expand Up @@ -56,7 +57,7 @@ type Git struct {
}

type metrics struct {
commitsScanned int
commitsScanned uint64
}

func NewGit(sourceType sourcespb.SourceType, jobID, sourceID int64, sourceName string, verify bool, concurrency int,
Expand Down Expand Up @@ -344,8 +345,8 @@ func CloneRepoUsingSSH(ctx context.Context, gitUrl string, args ...string) (stri
return CloneRepo(ctx, userInfo, gitUrl, args...)
}

func (s *Git) CommitsScanned() int {
return s.metrics.commitsScanned
func (s *Git) CommitsScanned() uint64 {
return atomic.LoadUint64(&s.metrics.commitsScanned)
}

func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string, scanOptions *ScanOptions, chunksChan chan *sources.Chunk) error {
Expand Down Expand Up @@ -380,7 +381,7 @@ func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string
break
}
depth++
s.metrics.commitsScanned++
atomic.AddUint64(&s.metrics.commitsScanned, 1)
logger.V(5).Info("scanning commit", "commit", commit.Hash)
for _, diff := range commit.Diffs {
if !scanOptions.Filter.Pass(diff.PathB) {
Expand Down
67 changes: 41 additions & 26 deletions pkg/sources/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,38 @@ const (
)

type Source struct {
name string
githubUser string
githubToken string
sourceID int64
jobID int64
verify bool
repos []string
members []string
orgsCache cache.Cache
filteredRepoCache *filteredRepoCache
memberCache map[string]struct{}
repoSizes repoSize
totalRepoSize int // total size in bytes of all repos
git *git.Git
scanOptions *git.ScanOptions
httpClient *http.Client
log logr.Logger
conn *sourcespb.GitHub
jobPool *errgroup.Group
resumeInfoMutex sync.Mutex
resumeInfoSlice []string
apiClient *github.Client
mu sync.Mutex
publicMap map[string]source_metadatapb.Visibility
name string
// Protects the user and token.
userMu sync.Mutex
githubUser string
githubToken string

sourceID int64
jobID int64
verify bool
repos []string
members []string
orgsCache cache.Cache
filteredRepoCache *filteredRepoCache
memberCache map[string]struct{}
repoSizes repoSize
totalRepoSize int // total size in bytes of all repos
git *git.Git

scanOptMu sync.Mutex // protects the scanOptions
scanOptions *git.ScanOptions

httpClient *http.Client
log logr.Logger
conn *sourcespb.GitHub
jobPool *errgroup.Group
resumeInfoMutex sync.Mutex
resumeInfoSlice []string
apiClient *github.Client

mu sync.Mutex // protects the visibility maps
publicMap map[string]source_metadatapb.Visibility

includePRComments bool
includeIssueComments bool
includeGistComments bool
Expand All @@ -79,6 +87,13 @@ func (s *Source) WithScanOptions(scanOptions *git.ScanOptions) {
s.scanOptions = scanOptions
}

// setScanOptions stores the given base and head hashes on the source's scan
// options. The write happens while scanOptMu is held; the lock is released
// when the method returns.
func (s *Source) setScanOptions(base, head string) {
	s.scanOptMu.Lock()
	defer s.scanOptMu.Unlock()

	opts := s.scanOptions
	opts.BaseHash, opts.HeadHash = base, head
}

// Ensure the Source satisfies the interfaces at compile time
var _ sources.Source = (*Source)(nil)
var _ sources.SourceUnitUnmarshaller = (*Source)(nil)
Expand Down Expand Up @@ -683,6 +698,7 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
path, repo, err = s.cloneRepo(ctx, repoURL, installationClient)
if err != nil {
scanErrs.Add(err)
return nil
}

defer os.RemoveAll(path)
Expand All @@ -691,8 +707,7 @@ func (s *Source) scan(ctx context.Context, installationClient *github.Client, ch
return nil
}

s.scanOptions.BaseHash = s.conn.Base
s.scanOptions.HeadHash = s.conn.Head
s.setScanOptions(s.conn.Base, s.conn.Head)

repoSize := s.repoSizes.getRepo(repoURL)
logger.V(2).Info(fmt.Sprintf("scanning repo %d/%d", i, len(s.repos)), "repo_size", repoSize)
Expand Down
71 changes: 71 additions & 0 deletions pkg/sources/github/github_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,77 @@ func TestSource_ScanComments(t *testing.T) {
}
}

// TestSource_ScanChunks initialises a token-authenticated Source with several
// repositories, drains every chunk it emits, and checks that at least
// wantChunks chunks were produced. Init is called with a concurrency of 8.
func TestSource_ScanChunks(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
	defer cancel()

	secret, err := common.GetTestSecret(ctx)
	if err != nil {
		t.Fatal(fmt.Errorf("failed to access secret: %v", err))
	}

	// For the personal access token test.
	githubToken := secret.MustGetField("GITHUB_TOKEN")

	type init struct {
		name       string
		verify     bool
		connection *sourcespb.GitHub
	}
	tests := []struct {
		name       string
		init       init
		wantChunks int
	}{
		{
			// The name states the actual number of repositories listed below.
			name: "token authenticated, 7 repos",
			init: init{
				name: "test source",
				connection: &sourcespb.GitHub{
					Repositories: []string{
						"https://github.com/truffle-test-integration-org/another-test-repo.git",
						"https://github.com/trufflesecurity/trufflehog.git",
						"https://github.com/Akash-goyal-github/Inventory-Management-System.git",
						"https://github.com/R1ck404/Crypto-Exchange-Example.git",
						"https://github.com/Stability-AI/generative-models.git",
						"https://github.com/bloomberg/blazingmq.git",
						"https://github.com/Kong/kong.git",
					},
					Credential: &sourcespb.GitHub_Token{Token: githubToken},
				},
			},
			wantChunks: 20000,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := Source{}

			conn, err := anypb.New(tt.init.connection)
			if err != nil {
				t.Fatal(err)
			}

			// Stop immediately if Init fails: calling Chunks on a source that
			// did not initialise would only obscure the real failure.
			if err := s.Init(ctx, tt.init.name, 0, 0, tt.init.verify, conn, 8); err != nil {
				t.Fatal(err)
			}

			chunksCh := make(chan *sources.Chunk, 1)
			go func() {
				// Closing the channel ends the draining loop below; because it
				// is deferred, the assertion has already run by then.
				defer close(chunksCh)
				// Keep the error local to this goroutine rather than writing
				// to a variable owned by the test goroutine.
				chunksErr := s.Chunks(ctx, chunksCh)
				assert.Nil(t, chunksErr)
			}()

			i := 0
			for range chunksCh {
				i++
			}
			assert.GreaterOrEqual(t, i, tt.wantChunks)
		})
	}
}

func TestSource_Scan(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*300)
defer cancel()
Expand Down
22 changes: 16 additions & 6 deletions pkg/sources/github/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,8 @@ func (s *Source) cloneRepo(
}

case *sourcespb.GitHub_Token:
// We never refresh user provided tokens, so if we already have them, we never need to try and fetch them again.
if s.githubUser == "" || s.githubToken == "" {
s.githubUser, s.githubToken, err = s.userAndToken(ctx, installationClient)
if err != nil {
return "", nil, fmt.Errorf("error getting token for repo %s: %w", repoURL, err)
}
if err := s.getUserAndToken(ctx, repoURL, installationClient); err != nil {
return "", nil, fmt.Errorf("error getting token for repo %s: %w", repoURL, err)
}
path, repo, err = git.CloneRepoUsingToken(ctx, s.githubToken, repoURL, s.githubUser)
if err != nil {
Expand All @@ -66,6 +62,20 @@ func (s *Source) cloneRepo(
return path, repo, nil
}

// getUserAndToken populates s.githubUser and s.githubToken via userAndToken
// when either one is still empty, and does nothing otherwise. The check and
// the assignment both happen while userMu is held.
//
// It returns a wrapped error naming repoURL when userAndToken fails.
func (s *Source) getUserAndToken(ctx context.Context, repoURL string, installationClient *github.Client) error {
	// We never refresh user provided tokens, so if we already have them, we never need to try and fetch them again.
	s.userMu.Lock()
	// Release the same mutex that was acquired. Unlocking a different mutex
	// would leave userMu held forever and block every later call.
	defer s.userMu.Unlock()
	if s.githubUser == "" || s.githubToken == "" {
		var err error
		s.githubUser, s.githubToken, err = s.userAndToken(ctx, installationClient)
		if err != nil {
			return fmt.Errorf("error getting token for repo %s: %w", repoURL, err)
		}
	}
	return nil
}

func (s *Source) userAndToken(ctx context.Context, installationClient *github.Client) (string, string, error) {
switch cred := s.conn.GetCredential().(type) {
case *sourcespb.GitHub_BasicAuth:
Expand Down

0 comments on commit eb00d0d

Please sign in to comment.