/
store.go
211 lines (184 loc) · 4.67 KB
/
store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
package server
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"

	"github.com/wandb/wandb/core/pkg/leveldb"
	"github.com/wandb/wandb/core/pkg/observability"
	"github.com/wandb/wandb/core/pkg/service"
	"google.golang.org/protobuf/proto"
)
// HeaderOptions is the fixed-size header that Store.Open writes to a new
// store file before any records, and reads back and validates when the
// file is opened for reading.
//
// MarshalBinary/UnmarshalBinary encode it with encoding/binary in
// little-endian order, field by field in declaration order with no padding,
// so the field order and types ARE the on-disk layout (4 + 2 + 1 = 7 bytes).
// Do not reorder, add, or retype fields without versioning the format.
type HeaderOptions struct {
	// IDENT is the 4-byte format identifier; NewHeader sets it to headerIdent().
	IDENT [4]byte
	// Magic is a fixed magic number; NewHeader sets it to headerMagic.
	Magic uint16
	// Version is the header format version; NewHeader sets it to headerVersion.
	Version byte
}
// Fixed values that NewHeader stamps into every store header and that
// HeaderOptions.Valid requires when a header is read back.
const (
	// headerVersion is the current version of the header format.
	headerVersion = 0
	// headerMagic is the magic number identifying a store header.
	headerMagic = 0xBEE1
)
// headerIdent returns the 4-byte identifier ":W&B" that begins every store
// header. A fresh array is returned on each call, so callers cannot mutate
// a shared value.
func headerIdent() [4]byte {
	var ident [4]byte
	copy(ident[:], ":W&B")
	return ident
}
// NewHeader returns a header populated with the current identifier,
// magic number, and format version. It is used both as the header to
// write for a new store and as the target to decode an existing one into.
func NewHeader() *HeaderOptions {
	h := new(HeaderOptions)
	h.IDENT = headerIdent()
	h.Magic = headerMagic
	h.Version = headerVersion
	return h
}
// MarshalBinary writes the header to w in little-endian binary form, one
// field after another in declaration order.
//
// Note: despite its name, this takes an io.Writer and so does not satisfy
// encoding.BinaryMarshaler.
func (o *HeaderOptions) MarshalBinary(w io.Writer) error {
	err := binary.Write(w, binary.LittleEndian, o)
	if err == nil {
		return nil
	}
	return fmt.Errorf("error writing binary data: %w", err)
}
// UnmarshalBinary fills the header from r, expecting the little-endian
// layout produced by MarshalBinary. Any read failure (including a short
// input) is wrapped and returned; the header's contents are then
// unspecified.
//
// Note: despite its name, this takes an io.Reader and so does not satisfy
// encoding.BinaryUnmarshaler.
func (o *HeaderOptions) UnmarshalBinary(r io.Reader) error {
	err := binary.Read(r, binary.LittleEndian, o)
	if err == nil {
		return nil
	}
	return fmt.Errorf("error reading binary data: %w", err)
}
// Valid reports whether the header carries exactly the identifier, magic
// number, and version this package writes (headerIdent(), headerMagic,
// headerVersion). Any mismatch, including an older or newer version,
// makes the header invalid.
func (o *HeaderOptions) Valid() bool {
	switch {
	case o.IDENT != headerIdent():
		return false
	case o.Magic != headerMagic:
		return false
	default:
		return o.Version == headerVersion
	}
}
// Store is the persistent store for a stream: a single file holding a
// HeaderOptions header followed by records written through a leveldb
// record writer.
//
// A Store is opened either for reading or for writing (see Open), so at
// most one of reader/writer is expected to be set at a time.
type Store struct {
	// ctx is the context passed to NewStore.
	// NOTE(review): it is stored but not read by any method visible here —
	// confirm whether it is still needed.
	ctx context.Context
	// name is the path of the underlying file.
	name string
	// writer frames records into db; set only when opened with os.O_WRONLY.
	writer *leveldb.Writer
	// reader decodes records from db; set only when opened with os.O_RDONLY.
	reader *leveldb.Reader
	// db is the underlying file; nil before Open and after Close.
	db *os.File
	// logger receives captured errors from every store operation.
	logger *observability.CoreLogger
}
// NewStore returns a Store bound to fileName. It only records its
// arguments; no file is opened or created until Open is called.
func NewStore(ctx context.Context, fileName string, logger *observability.CoreLogger) *Store {
	return &Store{
		name:   fileName,
		ctx:    ctx,
		logger: logger,
	}
}
// Open opens the store's backing file in the mode selected by flag.
//
// With os.O_RDONLY the existing file is opened, a record reader is
// attached, and the header is read and validated; the store is then ready
// for Read. With os.O_WRONLY the file is created (truncating any existing
// file), a record writer is attached, and a fresh header is written; the
// store is then ready for Write. Any other flag is rejected.
//
// Every failure is logged and returned unwrapped. If the file was opened
// but its header could not be read, validated, or written, the file is
// closed and the store's file and reader/writer fields are cleared, so a
// failed Open does not leak a file descriptor or leave a half-open store.
func (sr *Store) Open(flag int) error {
	switch flag {
	case os.O_RDONLY:
		f, err := os.Open(sr.name)
		if err != nil {
			sr.logger.CaptureError("can't open file", err)
			return err
		}
		// The reader is attached first; the header is then consumed directly
		// from the file before the reader is ever asked for a record.
		sr.db = f
		sr.reader = leveldb.NewReaderExt(f, leveldb.CRCAlgoIEEE)
		release := func() {
			_ = f.Close() // best-effort: the header failure is what gets reported
			sr.db, sr.reader = nil, nil
		}
		header := NewHeader()
		if err := header.UnmarshalBinary(f); err != nil {
			sr.logger.CaptureError("can't read header", err)
			release()
			return err
		}
		if !header.Valid() {
			err := fmt.Errorf("invalid header")
			sr.logger.CaptureError("can't read header", err)
			release()
			return err
		}
		return nil
	case os.O_WRONLY:
		f, err := os.Create(sr.name)
		if err != nil {
			sr.logger.CaptureError("can't open file", err)
			return err
		}
		// The writer is attached first; the header is then written directly
		// to the file so it precedes every record.
		sr.db = f
		sr.writer = leveldb.NewWriterExt(f, leveldb.CRCAlgoIEEE)
		header := NewHeader()
		if err := header.MarshalBinary(f); err != nil {
			sr.logger.CaptureError("can't write header", err)
			_ = f.Close() // best-effort: the header failure is what gets reported
			sr.db, sr.writer = nil, nil
			return err
		}
		return nil
	default:
		// TODO: generalize this?
		err := fmt.Errorf("invalid flag %d", flag)
		sr.logger.CaptureError("can't open file", err)
		return err
	}
}
// Close releases the store's record writer (if any) and its file.
//
// The writer is closed before the file. A failure to close either one is
// logged; if both fail, the writer's error is returned, since it means
// record data may be missing from the file. Previously Close dropped the
// writer's error and could report success after losing data.
//
// Close is idempotent: after the first call (or on a store that was never
// opened) the file and reader/writer fields are nil, and Close returns nil
// without logging.
func (sr *Store) Close() error {
	var writerErr error
	if sr.writer != nil {
		// Close the record writer before the file it writes to, presumably
		// so any record data it still holds reaches the file first.
		if writerErr = sr.writer.Close(); writerErr != nil {
			sr.logger.CaptureError("can't close file", writerErr)
		}
		sr.writer = nil
	}
	sr.reader = nil

	if sr.db == nil {
		// Never opened, or already closed: nothing left to release.
		return writerErr
	}
	dbErr := sr.db.Close()
	if dbErr != nil {
		sr.logger.CaptureError("can't close file", dbErr)
	}
	sr.db = nil

	if writerErr != nil {
		return writerErr
	}
	return dbErr
}
// Write serializes msg with protobuf and appends it to the store as one
// record. The store must have been opened with os.O_WRONLY.
//
// The message is marshalled before a new record is started. Previously a
// marshalling failure could leave a started-but-empty record behind.
// Errors are logged and returned unwrapped.
func (sr *Store) Write(msg *service.Record) error {
	if sr.writer == nil {
		err := fmt.Errorf("store is not open for writing")
		sr.logger.CaptureError("can't write record", err)
		return err
	}

	out, err := proto.Marshal(msg)
	if err != nil {
		sr.logger.CaptureError("can't marshal record", err)
		return err
	}

	writer, err := sr.writer.Next()
	if err != nil {
		sr.logger.CaptureError("can't write record", err)
		return err
	}
	if _, err := writer.Write(out); err != nil {
		sr.logger.CaptureError("can't write record", err)
		return err
	}
	return nil
}
// WriteDirectlyToDB writes data straight to the underlying file, bypassing
// the record writer and its framing, and returns the file's Write result.
// It exists for tests only; production code should use Write.
func (sr *Store) WriteDirectlyToDB(data []byte) (int, error) {
	file := sr.db
	return file.Write(data)
}
// Read returns the next record from a store opened with os.O_RDONLY.
//
// When the reader reports end of input (an error matching io.EOF), that
// error is returned as-is and not logged, so callers can stop iterating.
// The reader's Recover method is called before returning any other error
// from fetching or reading a record. Presumably this lets the next Read
// skip past damaged data (TODO confirm against the leveldb package). A
// record whose bytes do not decode as a service.Record is reported without
// calling Recover. All non-EOF errors are logged.
func (sr *Store) Read() (*service.Record, error) {
	// check if db is closed
	if sr.db == nil {
		err := fmt.Errorf("db is closed")
		sr.logger.CaptureError("can't read record", err)
		return nil, err
	}
	// A store opened for writing has no reader; fail instead of calling
	// Next on a nil reader.
	if sr.reader == nil {
		err := fmt.Errorf("store is not open for reading")
		sr.logger.CaptureError("can't read record", err)
		return nil, err
	}

	reader, err := sr.reader.Next()
	if errors.Is(err, io.EOF) {
		// End of input is expected, not a failure worth logging.
		return nil, err
	}
	if err != nil {
		sr.logger.CaptureError("can't read record", err)
		sr.reader.Recover()
		return nil, err
	}

	buf, err := io.ReadAll(reader)
	if err != nil {
		sr.logger.CaptureError("can't read record", err)
		sr.reader.Recover()
		return nil, err
	}

	msg := &service.Record{}
	if err := proto.Unmarshal(buf, msg); err != nil {
		sr.logger.CaptureError("can't read record", err)
		return nil, err
	}
	return msg, nil
}