Skip to content

Commit

Permalink
store: properly handle snappy compression continuations
Browse files Browse the repository at this point in the history
Snappy framing works at the byte level, so a chunk boundary can fall in the
middle of a varint, splitting it across two chunks. Thus, if the Decbuf returns
an error, refill the buffer from the next chunk and retry reading the varint. Added a repro test.

Closes thanos-io#6545.

Signed-off-by: Giedrius Statkevičius <giedrius.statkevicius@vinted.com>
  • Loading branch information
GiedriusS committed Aug 10, 2023
1 parent 84567ec commit 20253bc
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 23 deletions.
7 changes: 7 additions & 0 deletions pkg/store/6545postingsrepro

Large diffs are not rendered by default.

70 changes: 47 additions & 23 deletions pkg/store/postings_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,7 @@ func (it *streamedDiffVarintPostings) At() storage.SeriesRef {
return it.curSeries
}

func (it *streamedDiffVarintPostings) readNextChunk() bool {
if len(it.db.B) > 0 {
return true
}
func (it *streamedDiffVarintPostings) readNextChunk(remainder []byte) bool {
// Normal EOF.
if len(it.input) == 0 {
return false
Expand All @@ -255,13 +252,13 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = fmt.Errorf("corrupted identifier")
return false
}
if string(it.input[:6]) != magicBody {
if string(it.input[:len(magicBody)]) != magicBody {
it.err = fmt.Errorf("got bad identifier %s", string(it.input[:6]))
return false
}
it.input = it.input[6:]
it.readSnappyIdentifier = true
return it.readNextChunk()
return it.readNextChunk(nil)
case chunkTypeCompressedData:
if !it.readSnappyIdentifier {
it.err = fmt.Errorf("missing magic snappy marker")
Expand All @@ -276,7 +273,6 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = io.ErrUnexpectedEOF
return false
}
encodedBuf := it.input[:chunkLen]

if it.buf == nil {
if it.disablePooling {
Expand All @@ -291,6 +287,15 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
}
}

encodedBuf := it.input[:chunkLen]

// NOTE(GiedriusS): we can probably optimize this better but this should be rare enough
// and not cause any problems.
if len(remainder) > 0 {
remainderCopy := make([]byte, 0, len(remainder))
remainderCopy = append(remainderCopy, remainder...)
remainder = remainderCopy
}
decoded, err := s2.Decode(it.buf, encodedBuf[checksumSize:])
if err != nil {
it.err = err
Expand All @@ -300,7 +305,11 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = fmt.Errorf("mismatched checksum (got %v, expected %v)", crc(decoded), checksum)
return false
}
it.db.B = decoded
if len(remainder) > 0 {
it.db.B = append(remainder, decoded...)
} else {
it.db.B = decoded
}
case chunkTypeUncompressedData:
if !it.readSnappyIdentifier {
it.err = fmt.Errorf("missing magic snappy marker")
Expand All @@ -315,11 +324,25 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
it.err = io.ErrUnexpectedEOF
return false
}
it.db.B = it.input[checksumSize:chunkLen]
if crc(it.db.B) != checksum {
it.err = fmt.Errorf("mismatched checksum (got %v, expected %v)", crc(it.db.B), checksum)
uncompressedData := it.input[checksumSize:chunkLen]
if crc(uncompressedData) != checksum {
it.err = fmt.Errorf("mismatched checksum (got %v, expected %v)", crc(uncompressedData), checksum)
return false
}

// NOTE(GiedriusS): we can probably optimize this better but this should be rare enough
// and not cause any problems.
if len(remainder) > 0 {
remainderCopy := make([]byte, 0, len(remainder))
remainderCopy = append(remainderCopy, remainder...)
remainder = remainderCopy
}

if len(remainder) > 0 {
it.db.B = append(remainder, uncompressedData...)
} else {
it.db.B = uncompressedData
}
default:
if chunkType <= 0x7f {
it.err = fmt.Errorf("unsupported chunk type %v", chunkType)
Expand All @@ -336,19 +359,21 @@ func (it *streamedDiffVarintPostings) readNextChunk() bool {
}

func (it *streamedDiffVarintPostings) Next() bool {
if !it.readNextChunk() {
return false
}
val := it.db.Uvarint()
if it.db.Err() != nil {
if it.db.Err() != io.EOF {
it.err = it.db.Err()
// Continue reading next chunks until there is at least binary.MaxVarintLen64.
// If we cannot add any more chunks then return false.
for {
val := it.db.Uvarint64()
if it.db.Err() != nil {
if !it.readNextChunk(it.db.B) {
return false
}
it.db.E = nil
continue
}
return false
}

it.curSeries = it.curSeries + storage.SeriesRef(val)
return true
it.curSeries = it.curSeries + storage.SeriesRef(val)
return true
}
}

func (it *streamedDiffVarintPostings) Err() error {
Expand Down Expand Up @@ -534,7 +559,6 @@ func snappyStreamedEncode(postingsLength int, diffVarintPostings []byte) ([]byte
if err != nil {
return nil, fmt.Errorf("creating snappy compressor: %w", err)
}

_, err = sw.Write(diffVarintPostings)
if err != nil {
return nil, err
Expand Down
46 changes: 46 additions & 0 deletions pkg/store/postings_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"bytes"
"context"
crand "crypto/rand"
"io"
"math"
"math/rand"
"os"
"sort"
"strconv"
"testing"
Expand Down Expand Up @@ -338,3 +340,47 @@ func FuzzSnappyStreamEncoding(f *testing.F) {
testutil.Ok(t, err)
})
}

func TestRegressionIssue6545(t *testing.T) {
diffVarintPostings, err := os.ReadFile("6545postingsrepro")
testutil.Ok(t, err)

gotPostings := 0
dvp := newDiffVarintPostings(diffVarintPostings, nil)
decodedPostings := []storage.SeriesRef{}
for dvp.Next() {
decodedPostings = append(decodedPostings, dvp.At())
gotPostings++
}
testutil.Ok(t, dvp.Err())
testutil.Equals(t, 114024, gotPostings)

dataToCache, err := snappyStreamedEncode(114024, diffVarintPostings)
testutil.Ok(t, err)

// Check that the original decompressor works well.
sr := s2.NewReader(bytes.NewBuffer(dataToCache[3:]))
readBytes, err := io.ReadAll(sr)
testutil.Ok(t, err)
testutil.Equals(t, readBytes, diffVarintPostings)

dvp = newDiffVarintPostings(readBytes, nil)
gotPostings = 0
for dvp.Next() {
gotPostings++
}
testutil.Equals(t, 114024, gotPostings)

p, err := decodePostings(dataToCache)
testutil.Ok(t, err)

i := 0
for p.Next() {
post := p.At()
testutil.Equals(t, uint64(decodedPostings[i]), uint64(post))
i++
}

testutil.Ok(t, p.Err())
testutil.Equals(t, 114024, i)
}

0 comments on commit 20253bc

Please sign in to comment.