diff --git a/main.go b/main.go index 5298be457416..d6a50dcf2801 100644 --- a/main.go +++ b/main.go @@ -5,10 +5,12 @@ import ( "net/http" _ "net/http/pprof" "os" + "os/signal" "runtime" "strconv" "strings" "syscall" + "time" "github.com/alecthomas/kingpin/v2" "github.com/felixge/fgprof" @@ -219,12 +221,36 @@ func main() { } func run(state overseer.State) { - ctx := context.Background() - go cleantemp.RunCleanupLoop(ctx) + ctx, cancel := context.WithCancelCause(context.Background()) + defer cancel(nil) + + go func() { + if err := cleantemp.CleanTempArtifacts(ctx); err != nil { + ctx.Logger().Error(err, "error cleaning temporary artifacts") + } + }() logger := ctx.Logger() logFatal := logFatalFunc(logger) + killSignal := make(chan os.Signal, 1) + signal.Notify(killSignal, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + go func() { + <-killSignal + logger.Info("Received signal, shutting down.") + cancel(fmt.Errorf("canceling context due to signal")) + + if err := cleantemp.CleanTempArtifacts(ctx); err != nil { + logger.Error(err, "error cleaning temporary artifacts") + } else { + logger.Info("cleaned temporary artifacts") + } + + time.Sleep(time.Second * 10) + logger.Info("10 seconds elapsed. Forcing shutdown.") + os.Exit(0) + }() + logger.V(2).Info(fmt.Sprintf("trufflehog %s", version.BuildVersion)) if *githubScanToken != "" { diff --git a/pkg/cleantemp/cleantemp.go b/pkg/cleantemp/cleantemp.go index 8ff4c7d980cb..0122963bd430 100644 --- a/pkg/cleantemp/cleantemp.go +++ b/pkg/cleantemp/cleantemp.go @@ -7,7 +7,6 @@ import ( "regexp" "strconv" "strings" - "time" "github.com/mitchellh/go-ps" @@ -106,27 +105,3 @@ func CleanTempArtifacts(ctx logContext.Context) error { return nil } - -// RunCleanupLoop runs a loop that cleans up orphaned directories every 15 seconds. -func RunCleanupLoop(ctx logContext.Context) { - err := CleanTempArtifacts(ctx) - if err != nil { - ctx.Logger().Error(err, "Error cleaning up orphaned directories ") - } - - const cleanupLoopInterval = 15 * time.Second - ticker := time.NewTicker(cleanupLoopInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := CleanTempArtifacts(ctx); err != nil { - ctx.Logger().Error(err, "error cleaning up orphaned directories") - } - case <-ctx.Done(): - ctx.Logger().Info("Cleanup loop exiting due to context cancellation") - return - } - } -} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index a8113d637c07..38f0da44371a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -11,6 +11,7 @@ import ( lru "github.com/hashicorp/golang-lru" "google.golang.org/protobuf/proto" + "github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/config" "github.com/trufflesecurity/trufflehog/v3/pkg/context" @@ -425,6 +426,10 @@ func (e *Engine) Finish(ctx context.Context) error { close(e.results) // Detector workers are done, close the results channel and call it a day. e.WgNotifier.Wait() // Wait for the notifier workers to finish notifying results. + if err := cleantemp.CleanTempArtifacts(ctx); err != nil { + ctx.Logger().Error(err, "error cleaning temp artifacts") + } + e.metrics.ScanDuration = time.Since(e.metrics.scanStartTime) return err diff --git a/pkg/handlers/archive.go b/pkg/handlers/archive.go index 6e7ee4efdf05..ccefdadc7624 100644 --- a/pkg/handlers/archive.go +++ b/pkg/handlers/archive.go @@ -213,12 +213,7 @@ func (a *Archive) extractorHandler(archiveChan chan []byte) func(context.Context return nil } - fileBytes, err := a.ReadToMax(lCtx, fReader) - if err != nil { - return err - } - - return a.openArchive(lCtx, depth, bytes.NewReader(fileBytes), archiveChan) + return a.openArchive(lCtx, depth, fReader, archiveChan) } } diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index f96ed08816df..4d705c0f04f7 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -5,7 +5,6 @@ import ( "bytes" "errors" "fmt" - "io" "net/url" "os" "os/exec" @@ -1015,7 +1014,6 @@ func (s *Git) handleBinary(ctx context.Context, gitDir string, reporter sources. } } - const maxSize = 1 * 1024 * 1024 * 1024 // 1GB cmd := exec.Command("git", "-C", gitDir, "cat-file", "blob", commitHash.String()+":"+path) var stderr bytes.Buffer @@ -1043,31 +1041,12 @@ func (s *Git) handleBinary(ctx context.Context, gitDir string, reporter sources. } }() - var fileContent bytes.Buffer - // Create a limited reader to ensure we don't read more than the max size. - lr := io.LimitReader(fileReader, int64(maxSize)) - - // Using io.CopyBuffer for performance advantages. Though buf is mandatory - // for the method, due to the internal implementation of io.CopyBuffer, when - // *bytes.Buffer implements io.WriterTo or io.ReaderFrom, the provided buf - // is simply ignored. Thus, we can pass nil for the buf parameter. - _, err = io.CopyBuffer(&fileContent, lr, nil) - if err != nil && !errors.Is(err, io.EOF) { - return err - } - - if fileContent.Len() == maxSize { - fileCtx.Logger().V(2).Info("Max archive size reached.") - } - bufferName := cleantemp.MkFilename() - reader, err := diskbufferreader.New(&fileContent, diskbufferreader.WithBufferName(bufferName)) - + reader, err := diskbufferreader.New(fileReader, diskbufferreader.WithBufferName(bufferName)) if err != nil { return err } - defer reader.Close() if handlers.HandleFile(fileCtx, reader, chunkSkel, reporter, handlerOpts...) {