Skip to content

Commit

Permalink
Refactor Engine to wait for workers in a Finish method (#581)
Browse files Browse the repository at this point in the history
* Refactor Engine to wait for workers in a Finish method

This should allow the engine to run multiple concurrent scans if
desired before shutting down.

Additionally, this commit moves some of the printing logic into the
output package.

* Fix tests
  • Loading branch information
mcastorina committed May 25, 2022
1 parent aff0792 commit 6fa2171
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 88 deletions.
63 changes: 6 additions & 57 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
Expand All @@ -26,9 +25,6 @@ import (
"github.com/trufflesecurity/trufflehog/v3/pkg/decoders"
"github.com/trufflesecurity/trufflehog/v3/pkg/engine"
"github.com/trufflesecurity/trufflehog/v3/pkg/output"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/git"
)

Expand Down Expand Up @@ -220,11 +216,15 @@ func run(state overseer.State) {
logrus.WithError(err).Fatal("Failed to scan syslog.")
}
}
// asynchronously wait for scanning to finish and clean up
go e.Finish()

if !*jsonLegacy && !*jsonOut {
fmt.Fprintf(os.Stderr, "🐷🔑🐷 TruffleHog. Unearth your secrets. 🐷🔑🐷\n\n")
}

// NOTE: this loop will terminate when the results channel is closed in
// e.Finish()
foundResults := false
for r := range e.ResultsChan() {
if *onlyVerified && !r.Verified {
Expand All @@ -234,60 +234,9 @@ func run(state overseer.State) {

switch {
case *jsonLegacy:
repoPath, remote, err = git.PrepareRepo(r.SourceMetadata.GetGithub().Repository)
if err != nil || repoPath == "" {
logrus.WithError(err).Fatal("error preparing git repo for scanning")
}
legacy := output.ConvertToLegacyJSON(&r, repoPath)
out, err := json.Marshal(legacy)
if err != nil {
logrus.WithError(err).Fatal("could not marshal result")
}
fmt.Println(string(out))

if remote {
os.RemoveAll(repoPath)
}
output.PrintLegacyJSON(&r)
case *jsonOut:
v := &struct {
// SourceMetadata contains source-specific contextual information.
SourceMetadata *source_metadatapb.MetaData
// SourceID is the ID of the source that the API uses to map secrets to specific sources.
SourceID int64
// SourceType is the type of Source.
SourceType sourcespb.SourceType
// SourceName is the name of the Source.
SourceName string
// DetectorType is the type of Detector.
DetectorType detectorspb.DetectorType
// DetectorName is the string name of the DetectorType.
DetectorName string
Verified bool
// Raw contains the raw secret identifier data. Prefer IDs over secrets since it is used for deduping after hashing.
Raw []byte
// Redacted contains the redacted version of the raw secret identification data for display purposes.
// A secret ID should be used if available.
Redacted string
ExtraData map[string]string
StructuredData *detectorspb.StructuredData
}{
SourceMetadata: r.SourceMetadata,
SourceID: r.SourceID,
SourceType: r.SourceType,
SourceName: r.SourceName,
DetectorType: r.DetectorType,
DetectorName: r.DetectorType.String(),
Verified: r.Verified,
Raw: r.Raw,
Redacted: r.Redacted,
ExtraData: r.ExtraData,
StructuredData: r.StructuredData,
}
out, err := json.Marshal(v)
if err != nil {
logrus.WithError(err).Fatal("could not marshal result")
}
fmt.Println(string(out))
output.PrintJSON(&r)
default:
output.PrintPlainOutput(&r)
}
Expand Down
45 changes: 28 additions & 17 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ type Engine struct {
detectors map[bool][]detectors.Detector
chunksScanned uint64
detectorAvgTime sync.Map
sourcesWg sync.WaitGroup
workersWg sync.WaitGroup
}

type EngineOption func(*Engine)
Expand Down Expand Up @@ -75,15 +77,6 @@ func Start(ctx context.Context, options ...EngineOption) *Engine {
}
logrus.Debugf("running with up to %d workers", e.concurrency)

var workerWg sync.WaitGroup
for i := 0; i < e.concurrency; i++ {
workerWg.Add(1)
go func() {
e.detectorWorker(ctx)
workerWg.Done()
}()
}

if len(e.decoders) == 0 {
e.decoders = decoders.DefaultDecoders()
}
Expand All @@ -101,18 +94,36 @@ func Start(ctx context.Context, options ...EngineOption) *Engine {
len(e.detectors[false]))

// start the workers
go func() {
// close results chan when all workers are done
workerWg.Wait()
// not entirely sure why results don't get processed without this pause
// since we've put all results on the channel at this point.
time.Sleep(time.Second)
close(e.ResultsChan())
}()
for i := 0; i < e.concurrency; i++ {
e.workersWg.Add(1)
go func() {
defer e.workersWg.Done()
e.detectorWorker(ctx)
}()
}

return e
}

// Finish shuts the engine down in a fixed order: it blocks until every source
// goroutine registered on sourcesWg has returned, closes the chunks channel,
// blocks until every worker registered on workersWg has returned, and finally
// closes the results channel. Closing the results channel is what terminates
// callers that range over ResultsChan().
//
// Once Finish is called, no more sources may be scanned by the engine, since
// the chunks channel they would send on gets closed. Finish must be called at
// most once per engine: a second call would close already-closed channels,
// which panics.
//
// Finish blocks for at least one second because of the pause before the
// results channel is closed (see the TODO below). Callers that read results
// on the calling goroutine start it with `go e.Finish()`.
func (e *Engine) Finish() {
// wait for the sources to finish putting chunks onto the chunks channel;
// closing the channel before this point would make a still-running source
// panic on send
e.sourcesWg.Wait()
// NOTE(review): presumably the workers stop once the chunks channel is closed
// and drained; the worker loop is not visible here, so confirm
close(e.chunks)
// wait for the workers to finish processing all of the chunks and putting
// results onto the results channel
e.workersWg.Wait()

// TODO: re-evaluate whether this is needed and investigate why if so
//
// not entirely sure why results don't get processed without this pause
// since we've put all results on the channel at this point.
time.Sleep(time.Second)
close(e.results)
}

// ChunksChan returns the engine's chunks channel, which the Scan* methods pass
// to their sources as the destination for chunks. Finish closes this same
// channel, so nothing may send on it after Finish has closed it.
func (e *Engine) ChunksChan() chan *sources.Chunk {
return e.chunks
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/engine/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package engine

import (
"context"
"runtime"

"github.com/go-errors/errors"
"github.com/sirupsen/logrus"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/filesystem"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"runtime"
)

func (e *Engine) ScanFileSystem(ctx context.Context, directories []string) error {
Expand All @@ -27,12 +28,13 @@ func (e *Engine) ScanFileSystem(ctx context.Context, directories []string) error
if err != nil {
return errors.WrapPrefix(err, "could not init filesystem source", 0)
}
e.sourcesWg.Add(1)
go func() {
defer e.sourcesWg.Done()
err := fileSystemSource.Chunks(ctx, e.ChunksChan())
if err != nil {
logrus.WithError(err).Error("error scanning filesystem")
}
close(e.ChunksChan())
}()
return nil
}
6 changes: 4 additions & 2 deletions pkg/engine/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package engine
import (
"context"
"fmt"
"runtime"

"github.com/go-errors/errors"
"github.com/go-git/go-git/v5/plumbing/object"
"runtime"

gogit "github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/plumbing"
Expand Down Expand Up @@ -102,12 +103,13 @@ func (e *Engine) ScanGit(ctx context.Context, repoPath, headRef, baseRef string,
}
})

e.sourcesWg.Add(1)
go func() {
defer e.sourcesWg.Done()
err := gitSource.ScanRepo(ctx, repo, repoPath, scanOptions, e.ChunksChan())
if err != nil {
logrus.WithError(err).Fatal("could not scan repo")
}
close(e.ChunksChan())
}()
return nil
}
13 changes: 10 additions & 3 deletions pkg/engine/git_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func TestGitEngine(t *testing.T) {
WithDetectors(false, DefaultDetectors()...),
)
e.ScanGit(ctx, path, tTest.branch, tTest.base, tTest.maxDepth, tTest.filter)
go e.Finish()
resultCount := 0
for result := range e.ResultsChan() {
switch meta := result.SourceMetadata.GetData().(type) {
Expand Down Expand Up @@ -93,11 +94,17 @@ func BenchmarkGitEngine(b *testing.B) {
WithDecoders(decoders.DefaultDecoders()...),
WithDetectors(false, DefaultDetectors()...),
)
for i := 0; i < b.N; i++ {
e.ScanGit(ctx, path, "", "", 0, common.FilterEmpty())
go func() {
resultCount := 0
for _ = range e.ResultsChan() {
for range e.ResultsChan() {
resultCount++
}
}()

for i := 0; i < b.N; i++ {
// TODO: this is measuring the time it takes to initialize the source
// and not to do the full scan
e.ScanGit(ctx, path, "", "", 0, common.FilterEmpty())
}
e.Finish()
}
3 changes: 2 additions & 1 deletion pkg/engine/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ func (e *Engine) ScanGitHub(ctx context.Context, endpoint string, repos, orgs []
return err
}

e.sourcesWg.Add(1)
go func() {
defer e.sourcesWg.Done()
err := source.Chunks(ctx, e.ChunksChan())
if err != nil {
logrus.WithError(err).Fatal("could not scan github")
}
close(e.ChunksChan())
}()
return nil
}
7 changes: 5 additions & 2 deletions pkg/engine/gitlab.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ package engine

import (
"fmt"
"runtime"

"github.com/go-errors/errors"
"github.com/sirupsen/logrus"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/gitlab"
"golang.org/x/net/context"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"runtime"
)

func (e *Engine) ScanGitLab(ctx context.Context, endpoint, token string, repositories []string) error {
Expand Down Expand Up @@ -44,12 +45,14 @@ func (e *Engine) ScanGitLab(ctx context.Context, endpoint, token string, reposit
if err != nil {
return errors.WrapPrefix(err, "could not init GitLab source", 0)
}

e.sourcesWg.Add(1)
go func() {
defer e.sourcesWg.Done()
err := gitlabSource.Chunks(ctx, e.ChunksChan())
if err != nil {
logrus.WithError(err).Error("error scanning GitLab")
}
close(e.ChunksChan())
}()
return nil
}
7 changes: 5 additions & 2 deletions pkg/engine/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package engine
import (
"context"
"fmt"
"runtime"

"github.com/go-errors/errors"
"github.com/sirupsen/logrus"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/s3"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"runtime"
)

func (e *Engine) ScanS3(ctx context.Context, key, secret string, cloudCred bool, buckets []string) error {
Expand Down Expand Up @@ -46,12 +47,14 @@ func (e *Engine) ScanS3(ctx context.Context, key, secret string, cloudCred bool,
if err != nil {
return errors.WrapPrefix(err, "failed to init S3 source", 0)
}

e.sourcesWg.Add(1)
go func() {
defer e.sourcesWg.Done()
err := s3Source.Chunks(ctx, e.ChunksChan())
if err != nil {
logrus.WithError(err).Error("error scanning s3")
}
close(e.ChunksChan())
}()
return nil
}
6 changes: 4 additions & 2 deletions pkg/engine/syslog.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package engine

import (
"context"
"os"

"github.com/go-errors/errors"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"os"

"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources/syslog"
Expand Down Expand Up @@ -46,12 +47,13 @@ func (e *Engine) ScanSyslog(ctx context.Context, address, protocol, certPath, ke
return err
}

e.sourcesWg.Add(1)
go func() {
defer e.sourcesWg.Done()
err := source.Chunks(ctx, e.ChunksChan())
if err != nil {
logrus.WithError(err).Fatal("could not scan syslog")
}
close(e.ChunksChan())
}()
return nil
}

0 comments on commit 6fa2171

Please sign in to comment.