Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration of SpecializedHandler for Enhanced Archive Processing #1625

Merged
merged 13 commits into from
Aug 15, 2023
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ require (
github.com/google/go-containerregistry v0.15.2
github.com/google/go-github/v42 v42.0.0
github.com/googleapis/gax-go/v2 v2.12.0
github.com/h2non/filetype v1.1.3
github.com/hashicorp/go-retryablehttp v0.7.4
github.com/hashicorp/golang-lru v0.5.1
github.com/jlaffaye/ftp v0.2.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg=
github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
Expand Down
198 changes: 182 additions & 16 deletions pkg/handlers/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,13 @@ import (
"errors"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strings"
"time"

"github.com/h2non/filetype"
"github.com/mholt/archiver/v4"

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
Expand All @@ -32,8 +37,8 @@ type Archive struct {
}

// New sets a default maximum size and current size counter.
func (d *Archive) New() {
d.size = 0
func (a *Archive) New() {
a.size = 0
}

// SetArchiveMaxSize sets the maximum size of the archive.
Expand All @@ -52,14 +57,14 @@ func SetArchiveMaxTimeout(timeout time.Duration) {
}

// FromFile extracts the files from an archive.
func (d *Archive) FromFile(originalCtx context.Context, data io.Reader) chan ([]byte) {
func (a *Archive) FromFile(originalCtx context.Context, data io.Reader) chan ([]byte) {
archiveChan := make(chan ([]byte), 512)
go func() {
ctx, cancel := context.WithTimeout(originalCtx, maxTimeout)
logger := logContext.AddLogger(ctx).Logger()
defer cancel()
defer close(archiveChan)
err := d.openArchive(ctx, 0, data, archiveChan)
err := a.openArchive(ctx, 0, data, archiveChan)
if err != nil {
if errors.Is(err, archiver.ErrNoMatch) {
return
Expand All @@ -71,7 +76,7 @@ func (d *Archive) FromFile(originalCtx context.Context, data io.Reader) chan ([]
}

// openArchive takes a reader and extracts the contents up to the maximum depth.
func (d *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, archiveChan chan ([]byte)) error {
func (a *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, archiveChan chan ([]byte)) error {
if depth >= maxDepth {
return fmt.Errorf("max archive depth reached")
}
Expand All @@ -97,14 +102,14 @@ func (d *Archive) openArchive(ctx context.Context, depth int, reader io.Reader,
if err != nil {
return err
}
fileBytes, err := d.ReadToMax(ctx, compReader)
fileBytes, err := a.ReadToMax(ctx, compReader)
if err != nil {
return err
}
newReader := bytes.NewReader(fileBytes)
return d.openArchive(ctx, depth+1, newReader, archiveChan)
return a.openArchive(ctx, depth+1, newReader, archiveChan)
case archiver.Extractor:
err := archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, d.extractorHandler(archiveChan))
err := archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan))
if err != nil {
return err
}
Expand All @@ -114,7 +119,7 @@ func (d *Archive) openArchive(ctx context.Context, depth int, reader io.Reader,
}

// IsFiletype returns true if the provided reader is an archive.
func (d *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader, bool) {
func (a *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader, bool) {
format, readerB, err := archiver.Identify("", reader)
if err != nil {
return readerB, false
Expand All @@ -129,7 +134,7 @@ func (d *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader,
}

// extractorHandler is applied to each file in an archiver.Extractor file.
func (d *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Context, archiver.File) error {
func (a *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Context, archiver.File) error {
return func(ctx context.Context, f archiver.File) error {
logger := logContext.AddLogger(ctx).Logger()
logger.V(5).Info("Handling extracted file.", "filename", f.Name())
Expand All @@ -142,13 +147,13 @@ func (d *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Conte
if err != nil {
return err
}
fileBytes, err := d.ReadToMax(ctx, fReader)
fileBytes, err := a.ReadToMax(ctx, fReader)
if err != nil {
return err
}
fileContent := bytes.NewReader(fileBytes)

err = d.openArchive(ctx, depth, fileContent, archiveChan)
err = a.openArchive(ctx, depth, fileContent, archiveChan)
if err != nil {
return err
}
Expand All @@ -157,7 +162,7 @@ func (d *Archive) extractorHandler(archiveChan chan ([]byte)) func(context.Conte
}

// ReadToMax reads up to the max size.
func (d *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, err error) {
func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte, err error) {
// Archiver v4 is in alpha and using an experimental version of
// rardecode. There is a bug somewhere with rar decoder format 29
// that can lead to a panic. An issue is open in rardecode repo
Expand All @@ -175,7 +180,7 @@ func (d *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte,
}
}()
fileContent := bytes.Buffer{}
logger.V(5).Info("Remaining buffer capacity", "bytes", maxSize-d.size)
logger.V(5).Info("Remaining buffer capacity", "bytes", maxSize-a.size)
for i := 0; i <= maxSize/512; i++ {
if common.IsDone(ctx) {
return nil, ctx.Err()
Expand All @@ -185,17 +190,178 @@ func (d *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte,
if err != nil && !errors.Is(err, io.EOF) {
return []byte{}, err
}
d.size += bRead
a.size += bRead
if len(fileChunk) > 0 {
fileContent.Write(fileChunk[0:bRead])
}
if bRead < 512 {
return fileContent.Bytes(), nil
}
if d.size >= maxSize && bRead == 512 {
if a.size >= maxSize && bRead == 512 {
logger.V(2).Info("Max archive size reached.")
return fileContent.Bytes(), nil
}
}
return fileContent.Bytes(), nil
}

const (
arMimeType = "application/x-unix-archive"
rpmMimeType = "application/x-rpm"
)

// HandleSpecialized takes a file path and an io.Reader representing the input file,
// and processes it based on its extension, such as handling Debian (.deb) and RPM (.rpm) packages.
// It returns an io.Reader that can be used to read the processed content of the file,
// and an error if any issues occurred during processing.
// The caller is responsible for closing the returned reader.
func (a *Archive) HandleSpecialized(ctx context.Context, reader io.Reader) (io.Reader, bool, error) {
buffer := make([]byte, 512)
n, err := reader.Read(buffer)
if err != nil {
return nil, false, fmt.Errorf("unable to read file for MIME type detection: %w", err)
}

// Create a new reader that starts with the buffer we just read
// and continues with the rest of the original reader.
reader = io.MultiReader(bytes.NewReader(buffer[:n]), reader)

kind, err := filetype.Match(buffer)
if err != nil {
return nil, false, fmt.Errorf("unable to determine file type: %w", err)
}

switch mimeType := kind.MIME.Value; mimeType {
case arMimeType: // includes .deb files
reader, err = extractDebContent(ctx, reader)
case rpmMimeType:
reader, err = extractRpmContent(ctx, reader)
default:
return reader, false, nil
}

if err != nil {
return nil, false, fmt.Errorf("unable to extract file with MIME type %s: %w", kind.MIME.Value, err)
}
return reader, true, nil
}

// extractDebContent takes a .deb file as an io.Reader, extracts its contents
// into a temporary directory, and returns a Reader for the extracted data archive.
// It handles the extraction process by using the 'ar' command and manages temporary
// files and directories for the operation.
// The caller is responsible for closing the returned reader.
func extractDebContent(_ context.Context, file io.Reader) (io.ReadCloser, error) {
tempEnv, err := createTempEnv(file)
if err != nil {
return nil, err
}
defer os.Remove(tempEnv.tempFileName)
defer os.RemoveAll(tempEnv.extractPath)

cmd := exec.Command("ar", "x", tempEnv.tempFile.Name())
cmd.Dir = tempEnv.extractPath
if err := executeCommand(cmd); err != nil {
return nil, err
}

// List the content of the extraction directory.
extractedFiles, err := os.ReadDir(tempEnv.extractPath)
if err != nil {
return nil, fmt.Errorf("unable to read extracted directory: %w", err)
}

// Determine the correct data archive name. (e.g., data.tar.gz, data.tar.xz)
var dataArchiveName string
for _, file := range extractedFiles {
if strings.HasPrefix(file.Name(), "data.tar.") {
dataArchiveName = file.Name() // Use the actual name if different
break
}
}

return openDataArchive(tempEnv.extractPath, dataArchiveName)
}

// extractRpmContent takes an .rpm file as an io.Reader, extracts its contents
// into a temporary directory, and returns a Reader for the extracted data archive.
// It handles the extraction process by using the 'rpm2cpio' and 'cpio' commands and manages temporary
// files and directories for the operation.
// The caller is responsible for closing the returned reader.
func extractRpmContent(_ context.Context, file io.Reader) (io.ReadCloser, error) {
tempEnv, err := createTempEnv(file)
if err != nil {
return nil, err
}
defer os.Remove(tempEnv.tempFileName)
defer os.RemoveAll(tempEnv.extractPath)

// Use rpm2cpio to convert the RPM file to a cpio archive and then extract it using cpio command.
cmd := exec.Command("sh", "-c", "rpm2cpio "+tempEnv.tempFile.Name()+" | cpio -id")
cmd.Dir = tempEnv.extractPath
if err := executeCommand(cmd); err != nil {
return nil, err
}

// List the content of the extraction directory.
extractedFiles, err := os.ReadDir(tempEnv.extractPath)
if err != nil {
return nil, fmt.Errorf("unable to read extracted directory: %w", err)
}

var dataArchiveName string
// Determine the correct data archive name.
for _, file := range extractedFiles {
if strings.HasSuffix(file.Name(), ".tar.gz") {
dataArchiveName = file.Name()
break
}
}

return openDataArchive(tempEnv.extractPath, dataArchiveName)
}

type tempEnv struct {
tempFile *os.File
tempFileName string
extractPath string
}

// createTempEnv creates a temporary file and a temporary directory for extracting archives.
// The caller is responsible for removing these temporary resources
// (both the file and directory) when they are no longer needed.
func createTempEnv(file io.Reader) (tempEnv, error) {
tempFile, err := os.CreateTemp("", "tmp")
if err != nil {
return tempEnv{}, fmt.Errorf("unable to create temporary file: %w", err)
}

extractPath, err := os.MkdirTemp("", "tmp_archive")
if err != nil {
return tempEnv{}, fmt.Errorf("unable to create temporary directory: %w", err)
}

if _, err = io.Copy(tempFile, file); err != nil {
return tempEnv{}, fmt.Errorf("unable to copy content to temporary file: %w", err)
}

return tempEnv{tempFile: tempFile, tempFileName: tempFile.Name(), extractPath: extractPath}, nil
}

func executeCommand(cmd *exec.Cmd) error {
var stderr bytes.Buffer
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
return fmt.Errorf("unable to execute command: %w; error: %s", err, stderr.String())
}
return nil
}

func openDataArchive(extractPath string, dataArchiveName string) (io.ReadCloser, error) {
dataArchivePath := filepath.Join(extractPath, dataArchiveName)
dataFile, err := os.Open(dataArchivePath)
if err != nil {
return nil, fmt.Errorf("unable to open file: %w", err)
}
return dataFile, nil
}
36 changes: 36 additions & 0 deletions pkg/handlers/archive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package handlers

import (
"context"
"io"
"net/http"
"os"
"regexp"
"strings"
"testing"
Expand Down Expand Up @@ -122,3 +124,37 @@ func TestHandleFile(t *testing.T) {
assert.True(t, HandleFile(context.Background(), reader, &sources.Chunk{}, ch))
assert.Equal(t, 1, len(ch))
}

func TestExtractDebContent(t *testing.T) {
// Open the sample .deb file from the testdata folder.
file, err := os.Open("testdata/test.deb")
assert.Nil(t, err)
defer file.Close()

ctx := context.Background()

reader, err := extractDebContent(ctx, file)
assert.Nil(t, err)

content, err := io.ReadAll(reader)
assert.Nil(t, err)
expectedLength := 1015582
assert.Equal(t, expectedLength, len(string(content)))
}

func TestExtractRPMContent(t *testing.T) {
// Open the sample .rpm file from the testdata folder.
file, err := os.Open("testdata/test.rpm")
assert.Nil(t, err)
defer file.Close()

ctx := context.Background()

reader, err := extractRpmContent(ctx, file)
assert.Nil(t, err)

content, err := io.ReadAll(reader)
assert.Nil(t, err)
expectedLength := 1822720
assert.Equal(t, expectedLength, len(string(content)))
}