diff --git a/pkg/sources/filesystem/filesystem.go b/pkg/sources/filesystem/filesystem.go index 5f47ffa47c7f..9a1c4ceb9975 100644 --- a/pkg/sources/filesystem/filesystem.go +++ b/pkg/sources/filesystem/filesystem.go @@ -119,12 +119,7 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour // Skip over non-regular files. We do this check here to suppress noisy // logs for trying to scan directories and other non-regular files in // our traversal. - fileStat, err := os.Stat(fullPath) - if err != nil { - ctx.Logger().Info("unable to stat file", "path", fullPath, "error", err) - return nil - } - if !fileStat.Mode().IsRegular() { + if !d.Type().IsRegular() { return nil } if s.filter != nil && !s.filter.Pass(fullPath) { @@ -223,9 +218,35 @@ func (s *Source) scanFile(ctx context.Context, path string, chunksChan chan *sou // filepath or a directory. func (s *Source) Enumerate(ctx context.Context, reporter sources.UnitReporter) error { for _, path := range s.paths { - item := sources.CommonSourceUnit{ID: path} - if err := reporter.UnitOk(ctx, item); err != nil { - return err + fileInfo, err := os.Stat(filepath.Clean(path)) + if err != nil { + if err := reporter.UnitErr(ctx, err); err != nil { + return err + } + continue + } + if !fileInfo.IsDir() { + item := sources.CommonSourceUnit{ID: path} + if err := reporter.UnitOk(ctx, item); err != nil { + return err + } + continue + } + err = fs.WalkDir(os.DirFS(path), ".", func(relativePath string, d fs.DirEntry, err error) error { + if err != nil { + return reporter.UnitErr(ctx, err) + } + if d.IsDir() { + return nil + } + fullPath := filepath.Join(path, relativePath) + item := sources.CommonSourceUnit{ID: fullPath} + return reporter.UnitOk(ctx, item) + }) + if err != nil { + if err := reporter.UnitErr(ctx, err); err != nil { + return err + } } } return nil diff --git a/pkg/sources/filesystem/filesystem_test.go b/pkg/sources/filesystem/filesystem_test.go index e6925fafe45a..0dccb3ee5393 100644 --- a/pkg/sources/filesystem/filesystem_test.go +++ b/pkg/sources/filesystem/filesystem_test.go @@ -2,6 +2,7 @@ package filesystem import ( "os" + "path/filepath" "strings" "testing" "time" @@ -115,14 +116,36 @@ func TestScanFile(t *testing.T) { } func TestEnumerate(t *testing.T) { + // TODO: refactor to allow a virtual filesystem. t.Parallel() ctx := context.Background() // Setup the connection to test enumeration. + dir, err := os.MkdirTemp("", "trufflehog-test-enumerate") + assert.NoError(t, err) + defer os.RemoveAll(dir) + units := []string{ "/one", "/two", "/three", "/path/to/dir/", "/path/to/another/dir/", } + // Prefix the units with the tempdir and create files on disk. + for i, unit := range units { + fullPath := filepath.Join(dir, unit) + units[i] = fullPath + if i < 3 { + f, err := os.Create(fullPath) + assert.NoError(t, err) + f.Close() + } else { + assert.NoError(t, os.MkdirAll(fullPath, 0755)) + // Create a file in the directory for enumeration to find. + f, err := os.CreateTemp(fullPath, "file") + assert.NoError(t, err) + units[i] = f.Name() + f.Close() + } + } conn, err := anypb.New(&sourcespb.Filesystem{ Paths: units[0:3], Directories: units[3:],