diff --git a/pkg/writers/buffered_file_writer/bufferedfilewriter.go b/pkg/writers/buffered_file_writer/bufferedfilewriter.go index 85dcfbc53f1b..a0409aa51665 100644 --- a/pkg/writers/buffered_file_writer/bufferedfilewriter.go +++ b/pkg/writers/buffered_file_writer/bufferedfilewriter.go @@ -172,6 +172,10 @@ func (w *BufferedFileWriter) CloseForWriting() error { return nil } + // Return the buffer to the pool since the contents have been written to the file and + // the writer is transitioning to read-only mode. + defer w.bufPool.Put(w.buf) + if w.buf.Len() > 0 { _, err := w.buf.WriteTo(w.file) if err != nil { diff --git a/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go b/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go index b9937816506c..d79306fd61b5 100644 --- a/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go +++ b/pkg/writers/buffered_file_writer/bufferedfilewriter_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/trufflesecurity/trufflehog/v3/pkg/context" + "github.com/trufflesecurity/trufflehog/v3/pkg/writers/buffer" ) func TestBufferedFileWriterNewThreshold(t *testing.T) { @@ -506,3 +507,33 @@ func BenchmarkBufferedFileWriterWriteSmall(b *testing.B) { rc.Close() } } + +func TestBufferWriterCloseForWritingWithFile(t *testing.T) { + bufPool := buffer.NewBufferPool() + + ctx := context.Background() + buf := bufPool.Get(ctx) + writer := &BufferedFileWriter{ + threshold: 10, + bufPool: bufPool, + buf: buf, + } + + // Write data exceeding the threshold to ensure a file is created. + data := []byte("this is a longer string exceeding the threshold") + _, err := writer.Write(data) + assert.NoError(t, err) + + err = writer.CloseForWriting() + assert.NoError(t, err) + assert.Equal(t, readOnly, writer.state) + + rdr, err := writer.ReadCloser() + assert.NoError(t, err) + defer rdr.Close() + + // Get a buffer from the pool and check if it is the same buffer used in the writer. + bufFromPool := bufPool.Get(ctx) + assert.Same(t, buf, bufFromPool, "Buffer should be returned to the pool") + bufPool.Put(bufFromPool) +}