From ab71b93f7dc2afa64566138bc334f1a083600995 Mon Sep 17 00:00:00 2001 From: Bill Rich Date: Fri, 28 Oct 2022 08:57:55 -0700 Subject: [PATCH] Add context to handler (#877) * Add context to handler * Return rather than break out of select --- pkg/handlers/handlers.go | 19 ++++++++++++++----- pkg/sources/filesystem/filesystem.go | 2 +- pkg/sources/git/git.go | 8 ++++---- pkg/sources/s3/s3.go | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go index 2d120ebe2208..5b78aec67368 100644 --- a/pkg/handlers/handlers.go +++ b/pkg/handlers/handlers.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "io" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" @@ -18,7 +19,7 @@ type Handler interface { New() } -func HandleFile(file io.Reader, chunkSkel *sources.Chunk, chunksChan chan (*sources.Chunk)) bool { +func HandleFile(ctx context.Context, file io.Reader, chunkSkel *sources.Chunk, chunksChan chan (*sources.Chunk)) bool { for _, handler := range DefaultHandlers() { handler.New() var isType bool @@ -27,10 +28,18 @@ func HandleFile(file io.Reader, chunkSkel *sources.Chunk, chunksChan chan (*sour continue } handlerChan := handler.FromFile(file) - for data := range handlerChan { - chunk := *chunkSkel - chunk.Data = data - chunksChan <- &chunk + for { + select { + case data := <-handlerChan: + chunk := *chunkSkel + chunk.Data = data + chunksChan <- &chunk + case <-ctx.Done(): + return false + } + if handlerChan == nil { + break + } } return true } diff --git a/pkg/sources/filesystem/filesystem.go b/pkg/sources/filesystem/filesystem.go index 2a9642fc9bf5..5b71c31de6d6 100644 --- a/pkg/sources/filesystem/filesystem.go +++ b/pkg/sources/filesystem/filesystem.go @@ -131,7 +131,7 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk) err }, Verify: s.verify, } - if handlers.HandleFile(reReader, chunkSkel, chunksChan) { + if handlers.HandleFile(ctx, reReader, chunkSkel, chunksChan) { return nil } diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index fd2afbc98e0a..6839425e53ba 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -366,7 +366,7 @@ func (s *Git) ScanCommits(ctx context.Context, repo *git.Repository, path string SourceMetadata: metadata, Verify: s.verify, } - if err := handleBinary(repo, chunksChan, chunkSkel, commitHash, fileName); err != nil { + if err := handleBinary(ctx, repo, chunksChan, chunkSkel, commitHash, fileName); err != nil { log.WithError(err).WithField("file", fileName).Debug("Error handling binary file") } continue @@ -504,7 +504,7 @@ func (s *Git) ScanUnstaged(ctx context.Context, repo *git.Repository, path strin SourceMetadata: metadata, Verify: s.verify, } - if err := handleBinary(repo, chunksChan, chunkSkel, commitHash, fileName); err != nil { + if err := handleBinary(ctx, repo, chunksChan, chunkSkel, commitHash, fileName); err != nil { log.WithError(err).WithField("file", fileName).Debug("Error handling binary file") } continue @@ -782,7 +782,7 @@ func getSafeRemoteURL(repo *git.Repository, preferred string) string { return safeURL } -func handleBinary(repo *git.Repository, chunksChan chan *sources.Chunk, chunkSkel *sources.Chunk, commitHash plumbing.Hash, path string) error { +func handleBinary(ctx context.Context, repo *git.Repository, chunksChan chan *sources.Chunk, chunkSkel *sources.Chunk, commitHash plumbing.Hash, path string) error { log.WithField("path", path).Trace("Binary file found in repository.") commit, err := repo.CommitObject(commitHash) if err != nil { @@ -805,7 +805,7 @@ func handleBinary(repo *git.Repository, chunksChan chan *sources.Chunk, chunkSke return err } - if handlers.HandleFile(reader, chunkSkel, chunksChan) { + if handlers.HandleFile(ctx, reader, chunkSkel, chunksChan) { return nil } diff --git a/pkg/sources/s3/s3.go b/pkg/sources/s3/s3.go index 6c5dbb59c8c4..a50cdc0b2d9b 100644 --- a/pkg/sources/s3/s3.go +++ b/pkg/sources/s3/s3.go @@ -287,7 +287,7 @@ func (s *Source) pageChunker(ctx context.Context, client *s3.S3, chunksChan chan }, Verify: s.verify, } - if handlers.HandleFile(reader, chunkSkel, chunksChan) { + if handlers.HandleFile(ctx, reader, chunkSkel, chunksChan) { return }