Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] - base64 decoder #3762

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
improve base64 decoder
  • Loading branch information
ahrav committed Dec 6, 2024
commit 90560a306e26913b8176443452194b42dfe624f1
162 changes: 114 additions & 48 deletions pkg/decoders/base64.go
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/base64"
"unicode"
"unsafe"

"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
@@ -16,14 +17,15 @@ type (
var (
b64Charset = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/-_=")
b64EndChars = "+/-_="
// Given characters are mostly ASCII, we can use a simple array to map.

b64CharsetMapping [128]bool
)

func init() {
// Build an array of all the characters in the base64 charset.
for _, char := range b64Charset {
b64CharsetMapping[char] = true
if char < 128 {
b64CharsetMapping[char] = true
}
}
}

@@ -33,63 +35,104 @@ func (d *Base64) Type() detectorspb.DecoderType {

func (d *Base64) FromChunk(chunk *sources.Chunk) *DecodableChunk {
decodableChunk := &DecodableChunk{Chunk: chunk, DecoderType: d.Type()}
encodedSubstrings := getSubstringsOfCharacterSet(chunk.Data, 20, b64CharsetMapping, b64EndChars)
decodedSubstrings := make(map[string][]byte)
candidates := getSubstringsOfCharacterSet(chunk.Data, 20, b64CharsetMapping, b64EndChars)

if len(candidates) == 0 {
return nil
}

for _, str := range encodedSubstrings {
dec, err := base64.StdEncoding.DecodeString(str)
if err == nil && len(dec) > 0 && isASCII(dec) {
decodedSubstrings[str] = dec
var decodedCandidates []decodedCandidate
for _, c := range candidates {
data := chunk.Data[c.start:c.end]
substring := bytesToString(data)

// Heuristics: If substring contains '=', try StdEncoding first; otherwise, RawURLEncoding.
// This avoids unnecessary decoding since:
// 1. If a string contains '=', it's likely using standard base64 padding
// 2. If a string can be decoded by both standard and URL-safe base64,
// both decodings would produce identical output (they only differ in
// how they encode '+/' vs '-_')
// 3. Therefore, if we successfully decode with our first attempt, we can
// skip trying the other encoding
var dec []byte
if bytes.Contains(data, []byte("=")) {
dec, _ = base64.StdEncoding.DecodeString(substring)
if len(dec) == 0 {
dec, _ = base64.RawURLEncoding.DecodeString(substring)
}
} else {
dec, _ = base64.RawURLEncoding.DecodeString(substring)
if len(dec) == 0 {
dec, _ = base64.StdEncoding.DecodeString(substring)
}
}

dec, err = base64.RawURLEncoding.DecodeString(str)
if err == nil && len(dec) > 0 && isASCII(dec) {
decodedSubstrings[str] = dec
if len(dec) > 0 && isASCII(dec) {
decodedCandidates = append(decodedCandidates, decodedCandidate{
start: c.start,
end: c.end,
decoded: dec,
})
}
}

if len(decodedSubstrings) > 0 {
var result bytes.Buffer
result.Grow(len(chunk.Data))

start := 0
for _, encoded := range encodedSubstrings {
if decoded, ok := decodedSubstrings[encoded]; ok {
end := bytes.Index(chunk.Data[start:], []byte(encoded))
if end != -1 {
result.Write(chunk.Data[start : start+end])
result.Write(decoded)
start += end + len(encoded)
}
}
if len(decodedCandidates) == 0 {
return nil
}

// Rebuild the chunk data
var result bytes.Buffer
result.Grow(len(chunk.Data))

lastPos := 0
for _, dc := range decodedCandidates {
if dc.start > lastPos {
result.Write(chunk.Data[lastPos:dc.start])
}
result.Write(chunk.Data[start:])
chunk.Data = result.Bytes()
return decodableChunk
result.Write(dc.decoded)
lastPos = dc.end
}

if lastPos < len(chunk.Data) {
result.Write(chunk.Data[lastPos:])
}

return nil
chunk.Data = result.Bytes()
return decodableChunk
}

func bytesToString(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}

func isASCII(b []byte) bool {
for i := 0; i < len(b); i++ {
if b[i] > unicode.MaxASCII {
for _, c := range b {
if c > unicode.MaxASCII {
return false
}
}
return true
}

func getSubstringsOfCharacterSet(data []byte, threshold int, charsetMapping [128]bool, endChars string) []string {
type candidate struct {
start int
end int
hasEq bool
}

type decodedCandidate struct {
start int
end int
decoded []byte
}

func getSubstringsOfCharacterSet(data []byte, threshold int, charsetMapping [128]bool, endChars string) []candidate {
if len(data) == 0 {
return nil
}

count := 0
substringsCount := 0

// Determine the number of substrings that will be returned.
// Pre-allocate the slice to avoid reallocations.
for _, char := range data {
if char < 128 && charsetMapping[char] {
count++
@@ -104,37 +147,60 @@ func getSubstringsOfCharacterSet(data []byte, threshold int, charsetMapping [128
substringsCount++
}

if substringsCount == 0 {
return nil
}

candidates := make([]candidate, 0, substringsCount)

count = 0
start := 0
substrings := make([]string, 0, substringsCount)

equalsFound := false
for i, char := range data {
if char < 128 && charsetMapping[char] {
if count == 0 {
start = i
equalsFound = false
}
if char == '=' {
equalsFound = true
}
count++
} else {
if count > threshold {
substrings = appendB64Substring(data, start, count, substrings, endChars)
candidates = appendB64Substring(data, start, count, candidates, endChars, equalsFound)
}
count = 0
}
}

// handle trailing substring if needed
if count > threshold {
substrings = appendB64Substring(data, start, count, substrings, endChars)
candidates = appendB64Substring(data, start, count, candidates, endChars, equalsFound)
}

return substrings
return candidates
}

func appendB64Substring(data []byte, start, count int, substrings []string, endChars string) []string {
func appendB64Substring(data []byte, start, count int, candidates []candidate, endChars string, hasEq bool) []candidate {
substring := bytes.TrimLeft(data[start:start+count], endChars)
if idx := bytes.IndexByte(bytes.TrimRight(substring, endChars), '='); idx != -1 {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for handling = within the middle of a string has been preserved.

substrings = append(substrings, string(substring[idx+1:]))
trimmedRight := bytes.TrimRight(substring, endChars)

idx := bytes.IndexByte(trimmedRight, '=')
if idx != -1 {
// substring after '='
// Note: substring and trimmedRight differ potentially on trailing chars,
// but trimming right doesn't affect the position of '=' relative to substring.
// idx is from trimmedRight start, which has the same start as substring.
candidates = append(candidates, candidate{
start: start + (count - len(substring)) + idx + 1,
end: start + (count - len(substring)) + len(substring),
hasEq: hasEq,
})
} else {
substrings = append(substrings, string(substring))
candidates = append(candidates, candidate{
start: start + (count - len(substring)),
end: start + (count - len(substring)) + len(substring),
hasEq: hasEq,
})
}
return substrings
return candidates
}
8 changes: 7 additions & 1 deletion pkg/decoders/base64_test.go
Original file line number Diff line number Diff line change
@@ -155,6 +155,8 @@ func BenchmarkFromChunkSmall(b *testing.B) {
d := Base64{}
data := detectors.MustGetBenchmarkData()["small"]

b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.FromChunk(&sources.Chunk{Data: data})
}
@@ -164,15 +166,19 @@ func BenchmarkFromChunkMedium(b *testing.B) {
d := Base64{}
data := detectors.MustGetBenchmarkData()["medium"]

b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.FromChunk(&sources.Chunk{Data: data})
}
}

func BenchmarkFromChunkLarge(b *testing.B) {
d := Base64{}
data := detectors.MustGetBenchmarkData()["big"]
data := detectors.MustGetBenchmarkData()["large"]

b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
d.FromChunk(&sources.Chunk{Data: data})
}