/
block_decompressor.go
145 lines (130 loc) · 4.01 KB
/
block_decompressor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
package nsz
import (
"io"
"math"
"github.com/klauspost/compress/zstd"
)
// blockInfo records where a single block of the compressed payload lives
// inside the source file.
type blockInfo struct {
	// fileoffset is the position in the source file at which this block starts.
	fileoffset int64

	// blockCompressedLength is the maximum number of bytes that may be read
	// from the compressed source for this block before running into the
	// next block.
	blockCompressedLength int64
}

// GetReader returns a reader over reader that yields at most
// blockCompressedLength bytes. It performs no seek itself: reading starts at
// whatever position reader currently has.
func (b *blockInfo) GetReader(reader io.ReadSeeker) io.Reader {
	return &io.LimitedReader{R: reader, N: b.blockCompressedLength}
}
// Decompressor exposes a block-compressed stream as one continuous
// io.Reader. Blocks are opened one at a time as the read position advances;
// a block is read either through a zstd decoder or directly, depending on
// whether its recorded length equals its expected decompressed length.
type Decompressor struct {
	// Embedded reader. It is not assigned by NewBlockDecompressor; the Read
	// method declared on *Decompressor takes precedence over the promoted one.
	io.Reader

	// source is the underlying stream holding the header and the block data.
	source io.ReadSeeker

	// initalOffset is the position of source immediately after the block
	// header has been parsed.
	initalOffset int64

	// header is the parsed block header.
	header *BlockHeader

	// blockSize is the decompressed size of a full block
	// (2^header.BlockSizeExponent).
	blockSize int64

	// compressionBlocks holds one entry per block, in file order, giving its
	// start offset and recorded length.
	compressionBlocks []blockInfo

	// currentVirtualPos counts the decompressed bytes returned by Read so far.
	currentVirtualPos int64

	// currentDecompressor is the zstd decoder for the block being read, or
	// nil when no compressed block is open.
	currentDecompressor *zstd.Decoder

	// currentNotCompressedReader is the raw reader for the block being read,
	// or nil when no uncompressed block is open.
	currentNotCompressedReader io.Reader

	// trace is a byte counter maintained by Read.
	// NOTE(review): looks like a debugging aid; nothing in this file reads it.
	trace int64
}
// NewBlockDecompressor parses a block header from reader and returns a
// Decompressor positioned at the start of the decompressed stream.
//
// The position of reader right after the header is taken as the start of the
// first block; every following block starts where the previous one ends,
// according to the lengths recorded in the header. Any error from parsing the
// header or from querying the reader position is returned unchanged.
func NewBlockDecompressor(reader io.ReadSeeker) (*Decompressor, error) {
	hdr, err := NewBlockHeader(reader)
	if err != nil {
		return nil, err
	}

	// A zero-distance seek reports where the header ended.
	dataStart, err := reader.Seek(0, io.SeekCurrent)
	if err != nil {
		return nil, err
	}

	// Turn the list of recorded block lengths into absolute file offsets.
	blocks := make([]blockInfo, 0, len(hdr.CompressedBlockSizeList))
	offset := dataStart
	for _, size := range hdr.CompressedBlockSizeList {
		length := int64(size)
		blocks = append(blocks, blockInfo{
			fileoffset:            offset,
			blockCompressedLength: length,
		})
		offset += length
	}

	return &Decompressor{
		source:       reader,
		header:       hdr,
		initalOffset: dataStart,
		// Each block represents 2^BlockSizeExponent decompressed bytes.
		blockSize:         int64(math.Pow(2, float64(hdr.BlockSizeExponent))),
		compressionBlocks: blocks,
		currentVirtualPos: 0,
	}, nil
}
// Read implements io.Reader over the decompressed stream.
//
// If a block is already open, Read continues from it; otherwise it opens the
// block containing the current decompressed position first. A single call
// never crosses a block boundary, so n may be smaller than len(p).
//
// Reaching the end of a block is not reported to the caller: the block is
// closed and the next call moves on to the following block, so Read may
// return (0, nil) at a block boundary. io.EOF is returned only once every
// block has been consumed or header.DecompressedSize bytes have been
// produced. io.ErrUnexpectedEOF is returned when a newly opened block yields
// no data at all, since that means the source is truncated or corrupt.
func (d *Decompressor) Read(p []byte) (n int, err error) {
	freshBlock := false
	if d.currentDecompressor == nil && d.currentNotCompressedReader == nil {
		if err = d.openNextBlock(); err != nil {
			return 0, err
		}
		freshBlock = true
	}

	if d.currentDecompressor != nil {
		n, err = d.currentDecompressor.Read(p)
	} else {
		n, err = d.currentNotCompressedReader.Read(p)
	}

	d.currentVirtualPos += int64(n)
	if freshBlock {
		// trace restarts whenever a new block is opened.
		d.trace = int64(n)
	} else {
		d.trace += int64(n)
	}

	if err != io.EOF {
		return n, err
	}

	// End of this block only; whether the whole stream is finished is
	// decided by openNextBlock on the next call.
	d.closeCurrentBlock()
	if freshBlock && n == 0 && len(p) > 0 {
		// A block is only opened when more output is still expected, so an
		// empty one would otherwise be reopened on every call without progress.
		return 0, io.ErrUnexpectedEOF
	}
	return n, nil
}

// openNextBlock positions source at the block containing currentVirtualPos
// and installs the matching reader. It returns io.EOF when no data is left.
func (d *Decompressor) openNextBlock() error {
	blockIndex := d.currentVirtualPos / d.blockSize
	// ">=" matters here: when DecompressedSize is not a multiple of
	// blockSize, a position equal to DecompressedSize still maps onto the
	// last block index, and that block must not be decoded a second time.
	if int(blockIndex) >= len(d.compressionBlocks) || d.currentVirtualPos >= d.header.DecompressedSize {
		return io.EOF
	}
	block := d.compressionBlocks[blockIndex]

	if _, err := d.source.Seek(block.fileoffset, io.SeekStart); err != nil {
		return err
	}

	// Every block decompresses to blockSize bytes except the last one, which
	// holds whatever remains of the stream.
	expectedSize := d.blockSize
	if int(blockIndex) == len(d.compressionBlocks)-1 {
		expectedSize = int64(d.header.DecompressedSize - d.currentVirtualPos)
	}

	// A block whose recorded length equals its decompressed length was
	// stored without compression.
	if expectedSize == block.blockCompressedLength {
		d.currentNotCompressedReader = block.GetReader(d.source)
		d.currentDecompressor = nil
		return nil
	}

	zstdReader, err := zstd.NewReader(block.GetReader(d.source))
	if err != nil {
		return err
	}
	d.currentDecompressor = zstdReader
	d.currentNotCompressedReader = nil
	return nil
}

// closeCurrentBlock releases the reader of the block that just ended so the
// next Read opens the following block.
func (d *Decompressor) closeCurrentBlock() {
	if d.currentDecompressor != nil {
		d.currentDecompressor.Close()
	}
	d.currentDecompressor = nil
	d.currentNotCompressedReader = nil
}
// Close releases the zstd decoder of the block currently being read, if one
// is open. It does not close the underlying source.
func (d *Decompressor) Close() {
	if d.currentDecompressor == nil {
		return
	}
	d.currentDecompressor.Close()
}