diff --git a/hack/snifftest/main.go b/hack/snifftest/main.go index 219e7c57d7d7..8f76c7a361ea 100644 --- a/hack/snifftest/main.go +++ b/hack/snifftest/main.go @@ -188,8 +188,16 @@ func main() { logger.Info("cloned repo", "repo", r) - s := git.NewGit(sourcespb.SourceType_SOURCE_TYPE_GIT, 0, 0, "snifftest", false, runtime.NumCPU(), - func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { + cfg := &git.Config{ + SourceName: "snifftest", + JobID: 0, + SourceID: 0, + SourceType: sourcespb.SourceType_SOURCE_TYPE_GIT, + Verify: false, + SkipBinaries: true, + SkipArchives: false, + Concurrency: runtime.NumCPU(), + SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { return &source_metadatapb.MetaData{ Data: &source_metadatapb.MetaData_Git{ Git: &source_metadatapb.Git{ @@ -202,9 +210,8 @@ func main() { }, } }, - true, - false, - ) + } + s := git.NewGit(cfg) logger.Info("scanning repo", "repo", r) err = s.ScanRepo(ctx, repo, path, git.NewScanOptions(), sources.ChanReporter{Ch: chunksChan}) diff --git a/pkg/gitparse/gitparse.go b/pkg/gitparse/gitparse.go index 14aecc6bafa5..14792efb08c1 100644 --- a/pkg/gitparse/gitparse.go +++ b/pkg/gitparse/gitparse.go @@ -250,6 +250,13 @@ func (state ParseState) String() string { }[state] } +// WithContentWriter sets the ContentWriter for the Parser. +func WithContentWriter(writer contentWriter) Option { + return func(parser *Parser) { + parser.contentWriter = writer + } +} + // WithMaxDiffSize sets maxDiffSize option. Diffs larger than maxDiffSize will // be truncated. 
func WithMaxDiffSize(maxDiffSize int) Option { diff --git a/pkg/gitparse/gitparse_test.go b/pkg/gitparse/gitparse_test.go index 9f1e03abc3fe..e5e8ad40a145 100644 --- a/pkg/gitparse/gitparse_test.go +++ b/pkg/gitparse/gitparse_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/trufflesecurity/trufflehog/v3/pkg/context" + bufferedfilewriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffered_file_writer" ) type testCaseLine struct { @@ -746,6 +747,15 @@ func TestIndividualCommitParsing(t *testing.T) { } } +func newBufferedFileWriterWithContent(content []byte) *bufferedfilewriter.BufferedFileWriter { + b := bufferedfilewriter.New() + _, err := b.Write(context.Background(), content) // Using Write method to add content + if err != nil { + panic(err) + } + return b +} + func newBufferWithContent(content []byte) *buffer { var b buffer _, _ = b.Write(context.Background(), content) // Using Write method to add content @@ -825,6 +835,79 @@ func TestStagedDiffParsing(t *testing.T) { } } +func TestStagedDiffParsingBufferedFileWriter(t *testing.T) { + expected := []Commit{ + { + Hash: "", + Author: "", + Date: newTime("0001-01-01 00:00:00 +0000 UTC"), + Message: strings.Builder{}, + Diffs: []Diff{ + { + PathB: "aws", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("[default]\naws_access_key_id = AKIAXYZDQCEN4B6JSJQI\naws_secret_access_key = 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\noutput = json\nregion = us-east-2\n")), + IsBinary: false, + }, + { + PathB: "aws2", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("\n\nthis is the secret: [Default]\nAccess key Id: AKIAILE3JG6KMS3HZGCA\nSecret Access Key: 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\n\nokay thank you bye\n")), + IsBinary: false, + }, + { + PathB: "core/runtime/src/main/java/io/quarkus/runtime/QuarkusApplication.java", + LineStart: 3, + contentWriter: newBufferedFileWriterWithContent([]byte("/**\n * This is usually used for 
command mode applications with a startup logic. The logic is executed inside\n * {@link QuarkusApplication#run} method before the main application exits.\n */\n")), + IsBinary: false, + }, + { + PathB: "trufflehog_3.42.0_linux_arm64.tar.gz", + IsBinary: true, + contentWriter: newBufferedFileWriterWithContent(nil), + }, + { + PathB: "tzu", + LineStart: 11, + contentWriter: newBufferedFileWriterWithContent([]byte("\n\n\n\nSource: https://www.gnu.org/software/diffutils/manual/diffutils.html#An-Example-of-Unified-Format\n")), + IsBinary: false, + }, + { + PathB: "lao", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("The Way that can be told of is not the eternal Way;\nThe name that can be named is not the eternal name.\nThe Nameless is the origin of Heaven and Earth;\nThe Named is the mother of all things.\nTherefore let there always be non-being,\n so we may see their subtlety,\nAnd let there always be being,\n so we may see their outcome.\nThe two are the same,\nBut after they are produced,\n they have different names.\n")), + IsBinary: false, + }, + { + PathB: "tzu", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("The Nameless is the origin of Heaven and Earth;\nThe named is the mother of all things.\n\nTherefore let there always be non-being,\n so we may see their subtlety,\nAnd let there always be being,\n so we may see their outcome.\nThe two are the same,\nBut after they are produced,\n they have different names.\nThey both may be called deep and profound.\nDeeper and more profound,\nThe door of all subtleties!\n")), + IsBinary: false, + }, + }, + }, + } + + r := bytes.NewReader([]byte(stagedDiffs)) + commitChan := make(chan Commit) + parser := NewParser() + go func() { + parser.FromReader(context.Background(), r, commitChan, true) + }() + i := 0 + for commit := range commitChan { + if len(expected) <= i { + t.Errorf("Missing expected case for commit: %+v", commit) + break + } + + if 
!commit.Equal(context.Background(), &expected[i]) { + t.Errorf("Commit does not match.\nexpected:\n%+v\n\nactual:\n%+v\n", expected[i], commit) + } + i++ + } +} + func TestCommitParseFailureRecovery(t *testing.T) { expected := []Commit{ { @@ -884,6 +967,65 @@ func TestCommitParseFailureRecovery(t *testing.T) { } } +func TestCommitParseFailureRecoveryBufferedFileWriter(t *testing.T) { + expected := []Commit{ + { + Hash: "df393b4125c2aa217211b2429b8963d0cefcee27", + Author: "Stephen ", + Date: newTime("Wed Dec 06 14:44:41 2017 -0800"), + Message: newStringBuilderValue("Add travis testing\n"), + Diffs: []Diff{ + { + PathB: ".travis.yml", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("language: python\npython:\n - \"2.6\"\n - \"2.7\"\n - \"3.2\"\n - \"3.3\"\n - \"3.4\"\n - \"3.5\"\n - \"3.5-dev\" # 3.5 development branch\n - \"3.6\"\n - \"3.6-dev\" # 3.6 development branch\n - \"3.7-dev\" # 3.7 development branch\n - \"nightly\"\n")), + IsBinary: false, + }, + }, + }, + { + Hash: "3d76a97faad96e0f326afb61c232b9c2a18dca35", + Author: "John Smith ", + Date: newTime("Tue Jul 11 18:03:54 2023 -0400"), + Message: strings.Builder{}, + Diffs: []Diff{}, + }, + { + Hash: "7bd16429f1f708746dabf970e54b05d2b4734997", + Author: "John Smith ", + Date: newTime("Tue Jul 11 18:10:49 2023 -0400"), + Message: newStringBuilderValue("Change file\n"), + Diffs: []Diff{ + { + PathB: "tzu", + LineStart: 11, + contentWriter: newBufferedFileWriterWithContent([]byte("\n\n\n\nSource: https://www.gnu.org/software/diffutils/manual/diffutils.html#An-Example-of-Unified-Format\n")), + IsBinary: false, + }, + }, + }, + } + + r := bytes.NewReader([]byte(recoverableCommits)) + commitChan := make(chan Commit) + parser := NewParser() + go func() { + parser.FromReader(context.Background(), r, commitChan, false) + }() + i := 0 + for commit := range commitChan { + if len(expected) <= i { + t.Errorf("Missing expected case for commit: %+v", commit) + break + } + + if 
!commit.Equal(context.Background(), &expected[i]) { + t.Errorf("Commit does not match.\nexpected: %+v\n\nactual : %+v\n", expected[i], commit) + } + i++ + } +} + const recoverableCommits = `commit df393b4125c2aa217211b2429b8963d0cefcee27 Author: Stephen Date: Wed Dec 06 14:44:41 2017 -0800 @@ -1004,6 +1146,56 @@ func TestDiffParseFailureRecovery(t *testing.T) { } } +func TestDiffParseFailureRecoveryBufferedFileWriter(t *testing.T) { + expected := []Commit{ + { + Hash: "", + Author: "", + Date: newTime("0001-01-01 00:00:00 +0000 UTC"), + Message: strings.Builder{}, + Diffs: []Diff{ + { + PathB: "aws", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("[default]\naws_access_key_id = AKIAXYZDQCEN4B6JSJQI\naws_secret_access_key = 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\noutput = json\nregion = us-east-2\n")), + IsBinary: false, + }, + { + PathB: "tzu", + LineStart: 11, + contentWriter: newBufferedFileWriterWithContent([]byte("\n\n\n\nSource: https://www.gnu.org/software/diffutils/manual/diffutils.html#An-Example-of-Unified-Format\n")), + IsBinary: false, + }, + { + PathB: "tzu", + LineStart: 1, + contentWriter: newBufferedFileWriterWithContent([]byte("The Nameless is the origin of Heaven and Earth;\nThe named is the mother of all things.\n\nTherefore let there always be non-being,\n so we may see their subtlety,\nAnd let there always be being,\n so we may see their outcome.\nThe two are the same,\nBut after they are produced,\n they have different names.\nThey both may be called deep and profound.\nDeeper and more profound,\nThe door of all subtleties!\n")), + IsBinary: false, + }, + }, + }, + } + + r := bytes.NewReader([]byte(recoverableDiffs)) + commitChan := make(chan Commit) + parser := NewParser() + go func() { + parser.FromReader(context.Background(), r, commitChan, true) + }() + i := 0 + for commit := range commitChan { + if len(expected) <= i { + t.Errorf("Missing expected case for commit: %+v", commit) + break + } + + if 
!commit.Equal(context.Background(), &expected[i]) { + t.Errorf("Commit does not match.\nexpected: %+v\n\nactual : %+v\n", expected[i], commit) + } + i++ + } +} + const recoverableDiffs = `diff --git a/aws b/aws index 2ee133b..12b4843 100644 --- a/aws diff --git a/pkg/sources/git/git.go b/pkg/sources/git/git.go index b2b8ffb6c9c6..8159d5be2f54 100644 --- a/pkg/sources/git/git.go +++ b/pkg/sources/git/git.go @@ -34,21 +34,28 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sanitizer" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" + bufferedfilewriter "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffered_file_writer" ) const SourceType = sourcespb.SourceType_SOURCE_TYPE_GIT type Source struct { name string - sourceId sources.SourceID - jobId sources.JobID + sourceID sources.SourceID + jobID sources.JobID verify bool - git *Git + + useCustomContentWriter bool + git *Git + scanOptions *ScanOptions + sources.Progress - conn *sourcespb.Git - scanOptions *ScanOptions + conn *sourcespb.Git } +// WithCustomContentWriter sets the useCustomContentWriter flag on the source. +func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true } + type Git struct { sourceType sourcespb.SourceType sourceName string @@ -60,26 +67,54 @@ type Git struct { concurrency *semaphore.Weighted skipBinaries bool skipArchives bool + + parser *gitparse.Parser } type metrics struct { commitsScanned uint64 } -func NewGit(sourceType sourcespb.SourceType, jobID sources.JobID, sourceID sources.SourceID, sourceName string, verify bool, concurrency int, - sourceMetadataFunc func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData, skipBinaries bool, - skipArchives bool, -) *Git { +// Config for a Git source. 
+type Config struct { + Concurrency int + SourceMetadataFunc func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData + + SourceName string + JobID sources.JobID + SourceID sources.SourceID + SourceType sourcespb.SourceType + Verify bool + SkipBinaries bool + SkipArchives bool + + // UseCustomContentWriter indicates whether to use a custom contentWriter. + // When set to true, the parser will use a custom contentWriter provided through the WithContentWriter option. + // When false, the parser will use the default buffer (in-memory) contentWriter. + UseCustomContentWriter bool +} + +// NewGit creates a new Git instance with the provided configuration. The Git instance is used to interact with +// Git repositories. +func NewGit(config *Config) *Git { + var parser *gitparse.Parser + if config.UseCustomContentWriter { + parser = gitparse.NewParser(gitparse.WithContentWriter(bufferedfilewriter.New())) + } else { + parser = gitparse.NewParser() + } + return &Git{ - sourceType: sourceType, - sourceName: sourceName, - sourceID: sourceID, - jobID: jobID, - sourceMetadataFunc: sourceMetadataFunc, - verify: verify, - concurrency: semaphore.NewWeighted(int64(concurrency)), - skipBinaries: skipBinaries, - skipArchives: skipArchives, + sourceType: config.SourceType, + sourceName: config.SourceName, + sourceID: config.SourceID, + jobID: config.JobID, + sourceMetadataFunc: config.SourceMetadataFunc, + verify: config.Verify, + concurrency: semaphore.NewWeighted(int64(config.Concurrency)), + skipBinaries: config.SkipBinaries, + skipArchives: config.SkipArchives, + parser: parser, } } @@ -97,11 +132,11 @@ func (s *Source) Type() sourcespb.SourceType { } func (s *Source) SourceID() sources.SourceID { - return s.sourceId + return s.sourceID } func (s *Source) JobID() sources.JobID { - return s.jobId + return s.jobID } // withScanOptions sets the scan options. 
@@ -112,8 +147,8 @@ func (s *Source) withScanOptions(scanOptions *ScanOptions) { // Init returns an initialized GitHub source. func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { s.name = name - s.sourceId = sourceId - s.jobId = jobId + s.sourceID = sourceId + s.jobID = jobId s.verify = verify if s.scanOptions == nil { s.scanOptions = &ScanOptions{} @@ -166,8 +201,16 @@ func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, so return err } - s.git = NewGit(s.Type(), s.jobId, s.sourceId, s.name, s.verify, concurrency, - func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { + cfg := &Config{ + SourceName: s.name, + JobID: s.jobID, + SourceID: s.sourceID, + SourceType: s.Type(), + Verify: s.verify, + SkipBinaries: conn.GetSkipBinaries(), + SkipArchives: conn.GetSkipArchives(), + Concurrency: concurrency, + SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { return &source_metadatapb.MetaData{ Data: &source_metadatapb.MetaData_Git{ Git: &source_metadatapb.Git{ @@ -181,9 +224,9 @@ func (s *Source) Init(aCtx context.Context, name string, jobId sources.JobID, so }, } }, - conn.GetSkipBinaries(), - conn.GetSkipArchives(), - ) + UseCustomContentWriter: s.useCustomContentWriter, + } + s.git = NewGit(cfg) return nil } diff --git a/pkg/sources/github/github.go b/pkg/sources/github/github.go index d44b5e6ab9fe..35b64e8ca272 100644 --- a/pkg/sources/github/github.go +++ b/pkg/sources/github/github.go @@ -7,7 +7,6 @@ import ( "net/url" "os" "regexp" - "runtime" "sort" "strconv" "strings" @@ -62,7 +61,9 @@ type Source struct { memberCache map[string]struct{} repoSizes repoSize totalRepoSize int // total size of all repos in kb - git *git.Git + + useCustomContentWriter bool + git *git.Git scanOptMu sync.Mutex // protects the scanOptions 
scanOptions *git.ScanOptions @@ -85,6 +86,9 @@ type Source struct { sources.CommonSourceUnitUnmarshaller } +// WithCustomContentWriter sets the useCustomContentWriter flag on the source. +func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true } + func (s *Source) WithScanOptions(scanOptions *git.ScanOptions) { s.scanOptions = scanOptions } @@ -259,8 +263,16 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so s.publicMap = map[string]source_metadatapb.Visibility{} - s.git = git.NewGit(s.Type(), s.JobID(), s.SourceID(), s.name, s.verify, runtime.NumCPU(), - func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { + cfg := &git.Config{ + SourceName: s.name, + JobID: s.jobID, + SourceID: s.sourceID, + SourceType: s.Type(), + Verify: s.verify, + SkipBinaries: conn.GetSkipBinaries(), + SkipArchives: conn.GetSkipArchives(), + Concurrency: concurrency, + SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { return &source_metadatapb.MetaData{ Data: &source_metadatapb.MetaData_Github{ Github: &source_metadatapb.Github{ @@ -276,9 +288,9 @@ func (s *Source) Init(aCtx context.Context, name string, jobID sources.JobID, so }, } }, - conn.GetSkipBinaries(), - conn.GetSkipArchives(), - ) + UseCustomContentWriter: s.useCustomContentWriter, + } + s.git = git.NewGit(cfg) return nil } diff --git a/pkg/sources/gitlab/gitlab.go b/pkg/sources/gitlab/gitlab.go index 91a97b3a8508..bfbc058b8061 100644 --- a/pkg/sources/gitlab/gitlab.go +++ b/pkg/sources/gitlab/gitlab.go @@ -4,7 +4,6 @@ import ( "fmt" "net/url" "os" - "runtime" "sort" "strings" "sync" @@ -32,26 +31,34 @@ import ( const SourceType = sourcespb.SourceType_SOURCE_TYPE_GITLAB type Source struct { - name string - sourceId sources.SourceID - jobId sources.JobID - verify bool - authMethod string - user string - password string - token string - url string - repos []string - 
ignoreRepos []string - git *git.Git - scanOptions *git.ScanOptions + name string + sourceID sources.SourceID + jobID sources.JobID + verify bool + + authMethod string + user string + password string + token string + url string + repos []string + ignoreRepos []string + + useCustomContentWriter bool + git *git.Git + scanOptions *git.ScanOptions + resumeInfoSlice []string resumeInfoMutex sync.Mutex sources.Progress + jobPool *errgroup.Group sources.CommonSourceUnitUnmarshaller } +// WithCustomContentWriter sets the useCustomContentWriter flag on the source. +func (s *Source) WithCustomContentWriter() { s.useCustomContentWriter = true } + // Ensure the Source satisfies the interfaces at compile time. var _ sources.Source = (*Source)(nil) var _ sources.SourceUnitUnmarshaller = (*Source)(nil) @@ -64,18 +71,18 @@ func (s *Source) Type() sourcespb.SourceType { } func (s *Source) SourceID() sources.SourceID { - return s.sourceId + return s.sourceID } func (s *Source) JobID() sources.JobID { - return s.jobId + return s.jobID } // Init returns an initialized Gitlab source. 
func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourceId sources.SourceID, verify bool, connection *anypb.Any, concurrency int) error { s.name = name - s.sourceId = sourceId - s.jobId = jobId + s.sourceID = sourceId + s.jobID = jobId s.verify = verify s.jobPool = &errgroup.Group{} s.jobPool.SetLimit(concurrency) @@ -121,8 +128,16 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc return err } - s.git = git.NewGit(s.Type(), s.JobID(), s.SourceID(), s.name, s.verify, runtime.NumCPU(), - func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { + cfg := &git.Config{ + SourceName: s.name, + JobID: s.jobID, + SourceID: s.sourceID, + SourceType: s.Type(), + Verify: s.verify, + SkipBinaries: conn.GetSkipBinaries(), + SkipArchives: conn.GetSkipArchives(), + Concurrency: concurrency, + SourceMetadataFunc: func(file, email, commit, timestamp, repository string, line int64) *source_metadatapb.MetaData { return &source_metadatapb.MetaData{ Data: &source_metadatapb.MetaData_Gitlab{ Gitlab: &source_metadatapb.Gitlab{ @@ -137,9 +152,9 @@ func (s *Source) Init(_ context.Context, name string, jobId sources.JobID, sourc }, } }, - conn.GetSkipBinaries(), - conn.GetSkipArchives(), - ) + UseCustomContentWriter: s.useCustomContentWriter, + } + s.git = git.NewGit(cfg) return nil } diff --git a/pkg/writers/buffered_file_writer/bufferedfilewriter.go b/pkg/writers/buffered_file_writer/bufferedfilewriter.go new file mode 100644 index 000000000000..e35c1ac3edb3 --- /dev/null +++ b/pkg/writers/buffered_file_writer/bufferedfilewriter.go @@ -0,0 +1,238 @@ +// Package bufferedfilewriter provides a writer that buffers data in memory until a threshold is exceeded at +// which point it switches to writing to a temporary file. 
+package bufferedfilewriter
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"os"
+	"sync"
+
+	"github.com/trufflesecurity/trufflehog/v3/pkg/cleantemp"
+	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
+)
+
+// bufferPool is used to store buffers for reuse.
+// It hands out *bytes.Buffer values; Write takes one from the pool whenever the writer's in-memory
+// buffer is empty.
+var bufferPool = sync.Pool{
+	// TODO: Consider growing the buffer before returning it if we can find an optimal size.
+	// Ideally the size would cover the majority of cases without being too large.
+	// This would avoid the need to grow the buffer when writing to it, reducing allocations.
+	New: func() any { return new(bytes.Buffer) },
+}
+
+// state represents the current mode of BufferedFileWriter.
+type state uint8
+
+const (
+	// writeOnly indicates the BufferedFileWriter is in write-only mode.
+	// It is the state New assigns to a freshly created writer.
+	writeOnly state = iota
+	// readOnly indicates the BufferedFileWriter has been closed and is in read-only mode.
+	// CloseForWriting moves the writer into this state.
+	readOnly
+)
+
+// BufferedFileWriter manages a buffer for writing data, flushing to a file when a threshold is exceeded.
+// It supports either write-only or read-only mode, indicated by its state: Write is rejected unless the
+// writer is write-only, and ReadCloser is rejected unless the writer is read-only.
+type BufferedFileWriter struct {
+	threshold uint64 // Threshold, in bytes, for switching to file writing.
+	size      uint64 // Total size of the data written.
+
+	state state // Current state of the writer. (writeOnly or readOnly)
+
+	buf      bytes.Buffer   // Buffer for storing data under the threshold in memory.
+	filename string         // Name of the temporary file; set when the file is created.
+	file     io.WriteCloser // File for storing data over the threshold; nil until Write creates it.
+}
+
+// Option is a function that modifies a BufferedFileWriter.
+// Options are applied by New, in the order they are passed.
+type Option func(*BufferedFileWriter)
+
+// WithThreshold sets the threshold, in bytes, for switching to file writing.
+func WithThreshold(threshold uint64) Option {
+	return func(w *BufferedFileWriter) { w.threshold = threshold }
+}
+
+// New creates a new BufferedFileWriter with the given options.
+func New(opts ...Option) *BufferedFileWriter {
+	const defaultThreshold = 10 * 1024 * 1024 // 10MB
+	// Start in write-only mode with the default threshold; options may override the threshold.
+	w := &BufferedFileWriter{threshold: defaultThreshold, state: writeOnly}
+	for _, opt := range opts {
+		opt(w)
+	}
+	return w
+}
+
+// Len returns the number of bytes written to the buffer or file.
+func (w *BufferedFileWriter) Len() int { return int(w.size) }
+
+// String returns all the data written to the buffer or file as a string or an empty string if there is an error.
+// The error is non-nil when the temporary file cannot be opened or read.
+func (w *BufferedFileWriter) String() (string, error) {
+	if w.file == nil {
+		// Nothing has spilled to disk; the in-memory buffer holds everything.
+		return w.buf.String(), nil
+	}
+
+	// Data is in a file, read from the file.
+	file, err := os.Open(w.filename)
+	if err != nil {
+		return "", fmt.Errorf("failed to open file: %w", err)
+	}
+	defer file.Close()
+
+	// Create a buffer large enough to hold file data and additional buffer data, if any.
+	fileSize := w.size
+	buf := bytes.NewBuffer(make([]byte, 0, fileSize))
+
+	// Read the file contents into the buffer.
+	if _, err := io.Copy(buf, file); err != nil {
+		return "", fmt.Errorf("failed to read file contents: %w", err)
+	}
+
+	// While the writer is still write-only, bytes written after the file was created may be waiting in
+	// the in-memory buffer; they belong after the file contents. CloseForWriting appends that same tail
+	// to the file without clearing the buffer, so once the writer is read-only the tail is already part
+	// of the file and appending it again would duplicate it.
+	if w.state == writeOnly {
+		buf.Write(w.buf.Bytes())
+	}
+
+	return buf.String(), nil
+}
+
+// Write writes data to the buffer or a file, depending on the size.
+// Data is kept in memory while the buffered bytes plus data fit within the threshold. Otherwise a
+// temporary file is created on first use, anything still buffered is flushed to it, and data is
+// appended to the file, so the file always holds the bytes in the order they were written.
+// It returns the number of bytes accepted from data and fails if the writer is not in write-only mode.
+func (w *BufferedFileWriter) Write(ctx context.Context, data []byte) (n int, err error) {
+	if w.state != writeOnly {
+		return 0, fmt.Errorf("BufferedFileWriter must be in write-only mode to write")
+	}
+
+	size := uint64(len(data))
+	defer func() {
+		// Count only the bytes that were actually accepted so Len stays accurate after a failed write.
+		w.size += uint64(n)
+		ctx.Logger().V(4).Info(
+			"write complete",
+			"data_size", size,
+			"content_size", w.buf.Len(),
+			"total_size", w.size,
+		)
+	}()
+
+	if w.buf.Len() == 0 {
+		bufPtr, ok := bufferPool.Get().(*bytes.Buffer)
+		if !ok {
+			ctx.Logger().Error(fmt.Errorf("buffer pool returned unexpected type"), "using new buffer")
+			bufPtr = new(bytes.Buffer)
+		}
+		bufPtr.Reset() // Reset the buffer to clear any existing data
+		w.buf = *bufPtr
+	}
+
+	if uint64(w.buf.Len())+size <= w.threshold {
+		// If the total size is within the threshold, write to the buffer.
+		ctx.Logger().V(4).Info(
+			"writing to buffer",
+			"data_size", size,
+			"content_size", w.buf.Len(),
+		)
+		return w.buf.Write(data)
+	}
+
+	// Switch to file writing if threshold is exceeded.
+	// This helps in managing memory efficiently for large content.
+	if w.file == nil {
+		file, createErr := os.CreateTemp(os.TempDir(), cleantemp.MkFilename())
+		if createErr != nil {
+			return 0, createErr
+		}
+
+		w.filename = file.Name()
+		w.file = file
+	}
+
+	// Flush anything still buffered before appending data. This covers both the first spill to disk and
+	// the case where small writes were buffered after the file already existed; skipping the flush in
+	// the latter case would put data in the file ahead of bytes that were written before it.
+	if w.buf.Len() > 0 {
+		ctx.Logger().V(4).Info("writing buffer to file", "content_size", w.buf.Len())
+		if _, flushErr := w.file.Write(w.buf.Bytes()); flushErr != nil {
+			return 0, flushErr
+		}
+		// Return the storage to the pool through its own pointer and detach it from the writer.
+		// Putting &w.buf itself into the pool would let another goroutine reset and copy this
+		// writer's field while the writer is still in use.
+		released := new(bytes.Buffer)
+		*released = w.buf
+		released.Reset()
+		w.buf = bytes.Buffer{}
+		bufferPool.Put(released)
+	}
+	ctx.Logger().V(4).Info("writing to file", "data_size", size)
+
+	return w.file.Write(data)
+}
+
+// CloseForWriting flushes any remaining data in the buffer to the file, closes the file if created,
+// and transitions the BufferedFileWriter to read-only mode.
+func (w *BufferedFileWriter) CloseForWriting() error { + defer func() { w.state = readOnly }() + if w.file == nil { + return nil + } + + if w.buf.Len() > 0 { + _, err := w.file.Write(w.buf.Bytes()) + if err != nil { + return err + } + } + return w.file.Close() +} + +// ReadCloser returns an io.ReadCloser to read the written content. It provides a reader +// based on the current storage medium of the data (in-memory buffer or file). +// If the total content size exceeds the predefined threshold, it is stored in a temporary file and a file +// reader is returned. For in-memory data, it returns a custom reader that handles returning +// the buffer to the pool. +// The caller should call Close() on the returned io.Reader when done to ensure files are cleaned up. +// It can only be used when the BufferedFileWriter is in read-only mode. +func (w *BufferedFileWriter) ReadCloser() (io.ReadCloser, error) { + if w.state != readOnly { + return nil, fmt.Errorf("BufferedFileWriter must be in read-only mode to read") + } + + if w.file != nil { + // Data is in a file, read from the file. + file, err := os.Open(w.filename) + if err != nil { + return nil, err + } + return newAutoDeletingFileReader(file), nil + } + + // Data is in memory. + return &bufferReadCloser{ + Reader: bytes.NewReader(w.buf.Bytes()), + onClose: func() { bufferPool.Put(&w.buf) }, + }, nil +} + +// autoDeletingFileReader wraps an *os.File and deletes the file on Close. +type autoDeletingFileReader struct{ *os.File } + +// newAutoDeletingFileReader creates a new autoDeletingFileReader. +func newAutoDeletingFileReader(file *os.File) *autoDeletingFileReader { + return &autoDeletingFileReader{File: file} +} + +// Close implements the io.Closer interface, deletes the file after closing. +func (r *autoDeletingFileReader) Close() error { + defer os.Remove(r.Name()) // Delete the file after closing + return r.File.Close() +} + +// bufferReadCloser is a custom implementation of io.ReadCloser. 
It wraps a bytes.Reader +// for reading data from an in-memory buffer and includes an onClose callback. +// The onClose callback is used to return the buffer to the pool, ensuring buffer re-usability. +type bufferReadCloser struct { + *bytes.Reader + onClose func() +} + +// Close implements the io.Closer interface. It calls the onClose callback to return the buffer +// to the pool, enabling buffer reuse. This method should be called by the consumers of ReadCloser +// once they have finished reading the data to ensure proper resource management. +func (brc *bufferReadCloser) Close() error { + if brc.onClose == nil { + return nil + } + + brc.onClose() // Return the buffer to the pool + return nil +} diff --git a/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go b/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go new file mode 100644 index 000000000000..471b9389dd16 --- /dev/null +++ b/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go @@ -0,0 +1,308 @@ +package bufferedfilewriter + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/trufflesecurity/trufflehog/v3/pkg/context" +) + +func TestBufferedFileWriterNewThreshold(t *testing.T) { + t.Parallel() + + const ( + defaultThreshold = 10 * 1024 * 1024 // 10MB + customThreshold = 20 * 1024 * 1024 // 20MB + ) + + tests := []struct { + name string + options []Option + expectedThreshold uint64 + }{ + {name: "Default Threshold", expectedThreshold: defaultThreshold}, + {name: "Custom Threshold", options: []Option{WithThreshold(customThreshold)}, expectedThreshold: customThreshold}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + writer := New(tc.options...) + assert.Equal(t, tc.expectedThreshold, writer.threshold) + // The state should always be writeOnly when created. 
+ assert.Equal(t, writeOnly, writer.state) + }) + } +} + +func TestBufferedFileWriterString(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []byte + expectedStr string + additionalInput []byte + threshold uint64 + }{ + {name: "Empty", input: []byte(""), expectedStr: ""}, + {name: "Nil", input: nil, expectedStr: ""}, + {name: "Small content, buffer only", input: []byte("hello"), expectedStr: "hello"}, + { + name: "Large content, buffer only", + input: []byte("longer string with more characters"), + expectedStr: "longer string with more characters", + }, + { + name: "Large content, file only", + input: []byte("longer string with more characters"), + expectedStr: "longer string with more characters", + threshold: 5, + }, + { + name: "Content in both file and buffer", + input: []byte("initial content exceeding threshold"), + additionalInput: []byte(" more content in buffer"), + expectedStr: "initial content exceeding threshold more content in buffer", + threshold: 10, // Set a threshold that the initial content exceeds + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + writer := New(WithThreshold(tc.threshold)) + // First write, should go to file if it exceeds the threshold. 
+ _, err := writer.Write(ctx, tc.input) + assert.NoError(t, err) + + // Second write, should go to buffer + if tc.additionalInput != nil { + _, err = writer.Write(ctx, tc.additionalInput) + assert.NoError(t, err) + } + + got, err := writer.String() + assert.NoError(t, err) + + assert.Equal(t, tc.expectedStr, got, "String content mismatch") + }) + } +} + +func TestBufferedFileWriterLen(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []byte + expectedLen int + }{ + {name: "Empty", input: []byte(""), expectedLen: 0}, + {name: "Nil", input: nil, expectedLen: 0}, + {name: "Small content", input: []byte("hello"), expectedLen: 5}, + {name: "Large content", input: []byte("longer string with more characters"), expectedLen: 34}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + writer := New() + _, err := writer.Write(context.Background(), tc.input) + assert.NoError(t, err) + + length := writer.Len() + assert.Equal(t, tc.expectedLen, length) + }) + } +} + +// TestBufferedFileWriterWriteWithinThreshold tests that data is written to the buffer when the threshold +// is not exceeded. +func TestBufferedFileWriterWriteWithinThreshold(t *testing.T) { + t.Parallel() + + ctx := context.Background() + data := []byte("hello world") + + writer := New(WithThreshold(64)) + _, err := writer.Write(ctx, data) + assert.NoError(t, err) + + assert.Equal(t, data, writer.buf.Bytes()) +} + +// TestBufferedFileWriterWriteExceedsThreshold tests that data is written to a file when the threshold +// is exceeded. 
+func TestBufferedFileWriterWriteExceedsThreshold(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	data := []byte("hello world")
+
+	writer := New(WithThreshold(5))
+	_, err := writer.Write(ctx, data)
+	assert.NoError(t, err)
+
+	defer func() {
+		err := writer.CloseForWriting()
+		assert.NoError(t, err)
+		// The file is still on disk after CloseForWriting; remove it so the test does not leak it.
+		removeTempFile(t, writer.filename)
+	}()
+
+	assert.NotNil(t, writer.file)
+	assert.Len(t, writer.buf.Bytes(), 0)
+	fileContents, err := os.ReadFile(writer.filename)
+	assert.NoError(t, err)
+	assert.Equal(t, data, fileContents)
+}
+
+// TestBufferedFileWriterWriteAfterFlush tests that data is written to a file when the threshold
+// is exceeded, and subsequent writes are to the buffer until the threshold is exceeded again.
+func TestBufferedFileWriterWriteAfterFlush(t *testing.T) {
+	t.Parallel()
+
+	ctx := context.Background()
+	initialData := []byte("initial data is longer than subsequent data")
+	subsequentData := []byte("subsequent data")
+
+	// Initialize writer with a threshold that initialData will exceed.
+	writer := New(WithThreshold(uint64(len(initialData) - 1)))
+	_, err := writer.Write(ctx, initialData)
+	assert.NoError(t, err)
+
+	defer func() {
+		err := writer.CloseForWriting()
+		assert.NoError(t, err)
+		// The file is still on disk after CloseForWriting; remove it so the test does not leak it.
+		removeTempFile(t, writer.filename)
+	}()
+
+	// Get the file modification time after the initial write.
+	initialModTime, err := getFileModTime(t, writer.filename)
+	assert.NoError(t, err)
+	fileContents, err := os.ReadFile(writer.filename)
+	assert.NoError(t, err)
+	assert.Equal(t, initialData, fileContents)
+
+	// Perform a subsequent write with data under the threshold.
+	_, err = writer.Write(ctx, subsequentData)
+	assert.NoError(t, err)
+
+	assert.Equal(t, subsequentData, writer.buf.Bytes()) // Check buffer contents
+	finalModTime, err := getFileModTime(t, writer.filename)
+	assert.NoError(t, err)
+	// Compare instants with Equal rather than struct equality.
+	assert.True(t, initialModTime.Equal(finalModTime), "file should not be modified again")
+
+	// Modification times can be too coarse to notice a write made moments later, so also
+	// verify that the file still holds exactly the initial data.
+	fileContents, err = os.ReadFile(writer.filename)
+	assert.NoError(t, err)
+	assert.Equal(t, initialData, fileContents)
+}
+
+// getFileModTime returns the modification time of fileName as reported by os.Stat, or the
+// zero time and the Stat error if the file cannot be inspected.
+func getFileModTime(t *testing.T, fileName string) (time.Time, error) {
+	t.Helper()
+
+	fileInfo, err := os.Stat(fileName)
+	if err != nil {
+		return time.Time{}, err
+	}
+	return fileInfo.ModTime(), nil
+}
+
+// removeTempFile deletes the file a writer spilled to, if any. Removal is best-effort:
+// a failure is logged rather than failing the test, because it does not affect the
+// behavior under test.
+func removeTempFile(t *testing.T, fileName string) {
+	t.Helper()
+
+	if fileName == "" {
+		return
+	}
+	if err := os.Remove(fileName); err != nil && !os.IsNotExist(err) {
+		t.Logf("removing temp file %q: %v", fileName, err)
+	}
+}
+
+// TestBufferedFileWriterClose checks what ends up in the backing file once CloseForWriting
+// has been called, for writers that hold data in the buffer only, in the file only, or in both.
+func TestBufferedFileWriterClose(t *testing.T) {
+	t.Parallel()
+
+	const threshold = 10
+	ctx := context.Background()
+
+	// write fails the subtest if a preparatory write errors, instead of silently discarding
+	// the error and letting a later assertion fail for an unrelated-looking reason.
+	write := func(t *testing.T, w *BufferedFileWriter, data string) {
+		t.Helper()
+		_, err := w.Write(ctx, []byte(data))
+		assert.NoError(t, err)
+	}
+
+	tests := []struct {
+		name              string
+		prepareWriter     func(*testing.T, *BufferedFileWriter) // Function to prepare the writer for the test
+		expectFileContent string
+	}{
+		{
+			name: "No File Created, Only Buffer Data",
+			prepareWriter: func(t *testing.T, w *BufferedFileWriter) {
+				// Write data under the threshold
+				write(t, w, "small data")
+			},
+			expectFileContent: "",
+		},
+		{
+			name: "File Created, No Data in Buffer",
+			prepareWriter: func(t *testing.T, w *BufferedFileWriter) {
+				// Write data over the threshold to create a file
+				write(t, w, "large data is more than the threshold")
+			},
+			expectFileContent: "large data is more than the threshold",
+		},
+		{
+			name: "File Created, Data in Buffer",
+			prepareWriter: func(t *testing.T, w *BufferedFileWriter) {
+				// Write data over the threshold to create a file, then write more data
+				write(t, w, "large data is more than the threshold")
+				write(t, w, " more data")
+			},
+			expectFileContent: "large data is more than the threshold more data",
+		},
+		{
+			name: "File Created, Buffer Cleared",
+			prepareWriter: func(t *testing.T, w *BufferedFileWriter) {
+				// Write data over the threshold to create a file, then clear the buffer.
+				write(t, w, "large data is more than the threshold")
+				w.buf.Reset()
+			},
+			expectFileContent: "large data is more than the threshold",
+		},
+	}
+
+	for _, tc := range tests {
+		// Shadow the loop variable so each parallel subtest captures its own case.
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+			writer := New(WithThreshold(threshold))
+
+			tc.prepareWriter(t, writer)
+
+			err := writer.CloseForWriting()
+			assert.NoError(t, err)
+
+			if writer.file == nil {
+				// No file was created, so the case must not expect any file content.
+				assert.Empty(t, tc.expectFileContent)
+				return
+			}
+			// The file is still on disk after CloseForWriting; remove it once it has been checked.
+			defer removeTempFile(t, writer.filename)
+
+			fileContents, err := os.ReadFile(writer.filename)
+			assert.NoError(t, err)
+			assert.Equal(t, tc.expectFileContent, string(fileContents))
+		})
+	}
+}
+
+// TestBufferedFileWriterStateTransitionOnClose checks that a new writer starts in the
+// writeOnly state and is in the readOnly state after CloseForWriting.
+func TestBufferedFileWriterStateTransitionOnClose(t *testing.T) {
+	t.Parallel()
+	writer := New()
+
+	// Initially, the writer should be in write-only mode.
+	assert.Equal(t, writeOnly, writer.state)
+
+	// Perform some write operation.
+	_, err := writer.Write(context.Background(), []byte("test data"))
+	assert.NoError(t, err)
+
+	// Close the writer.
+	err = writer.CloseForWriting()
+	assert.NoError(t, err)
+
+	// After closing, the writer should be in read-only mode.
+	assert.Equal(t, readOnly, writer.state)
+}
+
+// TestBufferedFileWriterWriteInReadOnlyState checks that Write returns an error once the
+// writer has been closed for writing.
+func TestBufferedFileWriterWriteInReadOnlyState(t *testing.T) {
+	t.Parallel()
+	writer := New()
+
+	// Transition to read-only mode. A failed close would make the write error below
+	// meaningless, so it is asserted rather than ignored.
+	err := writer.CloseForWriting()
+	assert.NoError(t, err)
+
+	// Attempt to write in read-only mode.
+	_, err = writer.Write(context.Background(), []byte("should fail"))
+	assert.Error(t, err)
+}