forked from ava-labs/avalanchego
-
Notifications
You must be signed in to change notification settings - Fork 0
/
codec.go
348 lines (320 loc) · 10.5 KB
/
codec.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
// (c) 2019-2020, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.
package codec
import (
"errors"
"fmt"
"reflect"
"unicode"
"github.com/ava-labs/gecko/utils/wrappers"
)
// Limits applied by NewDefault.
const (
	// defaultMaxSize is the largest encoding, in bytes, that Marshal will
	// produce and that Unmarshal will accept (256 KiB).
	defaultMaxSize = 256 * 1024
	// defaultMaxSliceLength is the largest number of elements a single slice
	// may hold when it is marshaled or unmarshaled (256 Ki elements).
	defaultMaxSliceLength = 256 * 1024
)
// Errors returned by the codec. They are unexported, so callers outside this
// package can only distinguish them by their messages.
var (
	// errBadCodec reports use of a wrong or unknown codec.
	errBadCodec = errors.New("wrong or unknown codec used")
	// errNil is returned when asked to marshal a nil value (nil interface,
	// pointer or slice).
	errNil = errors.New("can't marshal nil value")
	// errUnmarshalNil is returned when the unmarshal destination is nil.
	errUnmarshalNil = errors.New("can't unmarshal into nil")
	// errNeedPointer is returned when Unmarshal's destination is not a pointer.
	errNeedPointer = errors.New("must unmarshal into a pointer")
	// errMarshalUnregisteredType reports marshaling of an unregistered type.
	errMarshalUnregisteredType = errors.New("can't marshal an unregistered type")
	// errUnmarshalUnregisteredType is returned when the encoded type ID of an
	// interface value was never assigned by RegisterType.
	errUnmarshalUnregisteredType = errors.New("can't unmarshal an unregistered type")
	// errUnknownType is returned for kinds the codec does not support.
	errUnknownType = errors.New("don't know how to marshal/unmarshal this type")
	// errMarshalUnexportedField is returned when a field tagged for
	// serialization is unexported.
	errMarshalUnexportedField = errors.New("can't serialize an unexported field")
	// errUnmarshalUnexportedField is the unmarshal-side counterpart of
	// errMarshalUnexportedField.
	errUnmarshalUnexportedField = errors.New("can't deserialize into an unexported field")
	// errOutOfMemory reports an allocation that would be too large.
	errOutOfMemory = errors.New("out of memory")
	// errSliceTooLarge is returned when the input or a slice length exceeds
	// the codec's configured limits.
	errSliceTooLarge = errors.New("slice too large")
)
// codec is the Codec implementation returned by New. It handles marshaling
// and unmarshaling of structs and the other supported kinds.
type codec struct {
	// maxSize bounds, in bytes, both the output of Marshal and the input
	// accepted by Unmarshal.
	maxSize int
	// maxSliceLen bounds the number of elements in any single slice.
	maxSliceLen int
	// typeIDToType maps a type ID written before an interface value to the
	// concrete type registered under it.
	typeIDToType map[uint32]reflect.Type
	// typeToTypeID is the inverse of typeIDToType.
	typeToTypeID map[reflect.Type]uint32
}
// Codec converts Go values to and from their byte representation.
type Codec interface {
	// RegisterType assigns the next type ID to the concrete type of val so
	// that values of that type can be unmarshaled into interface-typed fields.
	RegisterType(val interface{}) error
	// Marshal returns the byte encoding of value.
	Marshal(value interface{}) ([]byte, error)
	// Unmarshal decodes bytes into dest, which must be a pointer.
	Unmarshal(bytes []byte, dest interface{}) error
}
// New returns a Codec that refuses encodings larger than maxSize bytes and
// slices with more than maxSliceLen elements. No types are registered yet.
func New(maxSize, maxSliceLen int) Codec {
	c := codec{maxSize: maxSize, maxSliceLen: maxSliceLen}
	c.typeIDToType = make(map[uint32]reflect.Type)
	c.typeToTypeID = make(map[reflect.Type]uint32)
	return c
}
// NewDefault returns a Codec configured with defaultMaxSize and
// defaultMaxSliceLength.
func NewDefault() Codec {
	return New(defaultMaxSize, defaultMaxSliceLength)
}
// RegisterType registers the concrete type of [val] so that values of that
// type may be marshaled from, and unmarshaled into, interface-typed values.
// [val] is a value of the type being registered. Type IDs are assigned
// densely in registration order (0, 1, 2, ...), so every codec that must
// interoperate has to register the same types in the same order.
//
// Returns an error if [val] is an untyped nil or its type is already
// registered.
func (c codec) RegisterType(val interface{}) error {
	// An untyped nil has no reflect.Type. Registering it would bind a type ID
	// to a nil Type, and unmarshaling that ID would panic on Implements.
	if val == nil {
		return errors.New("can't register the type of an untyped nil")
	}
	valType := reflect.TypeOf(val)
	if _, exists := c.typeToTypeID[valType]; exists {
		return fmt.Errorf("type %v has already been registered", valType)
	}
	// IDs are dense, so the next free ID is the number registered so far.
	typeID := uint32(len(c.typeIDToType))
	c.typeIDToType[typeID] = valType
	c.typeToTypeID[valType] = typeID
	return nil
}
// A few notes:
// 1) See codec_test.go for examples of usage
// 2) We use "marshal" and "serialize" interchangeably, and "unmarshal" and "deserialize" interchangeably
// 3) To include a field of a struct in the serialized form, add the tag `serialize:"true"` to it
// 4) Values of these types may be serialized:
//    bool, string, uint[8,16,32,64], int[8,16,32,64],
//    structs, slices, arrays, pointers, interfaces.
//    Structs, slices, arrays and pointers can only be serialized if their constituent parts can be.
// 5) To marshal an interface typed value, you must pass a _pointer_ to the value
// 6) If you want to be able to unmarshal into an interface typed value,
//    you must call codec.RegisterType([instance of the type that fulfills the interface]).
// 7) A nil slice stored in a struct field is marshaled as an empty slice, so it is unmarshaled
//    as an empty (non-nil) slice of the appropriate type
// 8) Serialized fields must be exported
// Marshal returns the byte representation of [value].
// To marshal an interface-typed value, pass a pointer to the interface so
// that its registered type ID is written along with the concrete value.
func (c codec) Marshal(value interface{}) ([]byte, error) {
	// Only an untyped nil is caught here; typed nil pointers, interfaces and
	// slices are rejected further down by marshal.
	if value == nil {
		return nil, errNil
	}
	rv := reflect.ValueOf(value)
	return c.marshal(rv)
}
// marshal serializes [value] into a freshly allocated byte slice.
//
// Encoding by kind:
//   - fixed-width integers and bool use the matching Packer method; signed
//     integers are written as their two's-complement bit pattern
//   - strings use PackStr
//   - pointers are transparent: only the pointee is written
//   - interface values are prefixed with the uint32 type ID assigned by
//     RegisterType, followed by the encoded concrete value
//   - slices are prefixed with their uint32 element count; arrays are not,
//     since their length is part of their type
//   - structs encode, in declaration order, only fields tagged
//     `serialize:"true"`; a nil slice field is written as length 0
//
// Errors:
//   - errNil for a nil pointer, interface or slice, or an invalid Value
//   - errSliceTooLarge for a slice longer than c.maxSliceLen
//   - errMarshalUnexportedField for a tagged unexported struct field
//   - errUnknownType for unsupported kinds (uintptr, int, uint, floats, maps, ...)
//   - any error recorded by the Packer, which is created with MaxSize c.maxSize
//
// On error the returned byte slice is nil.
func (c codec) marshal(value reflect.Value) ([]byte, error) {
	valueKind := value.Kind()
	switch valueKind {
	case reflect.Interface, reflect.Ptr, reflect.Slice:
		if value.IsNil() {
			return nil, errNil
		}
	}

	p := wrappers.Packer{MaxSize: c.maxSize, Bytes: []byte{}}
	switch valueKind {
	case reflect.Uint8:
		p.PackByte(uint8(value.Uint()))
	case reflect.Int8:
		p.PackByte(uint8(value.Int()))
	case reflect.Uint16:
		p.PackShort(uint16(value.Uint()))
	case reflect.Int16:
		p.PackShort(uint16(value.Int()))
	case reflect.Uint32:
		p.PackInt(uint32(value.Uint()))
	case reflect.Int32:
		p.PackInt(uint32(value.Int()))
	case reflect.Uint64:
		p.PackLong(value.Uint())
	case reflect.Int64:
		p.PackLong(uint64(value.Int()))
	case reflect.String:
		p.PackStr(value.String())
	case reflect.Bool:
		p.PackBool(value.Bool())
	case reflect.Ptr:
		// Uintptr is deliberately not handled here: Value.Elem panics on it,
		// so it falls through to errUnknownType like in unmarshal.
		return c.marshal(value.Elem())
	case reflect.Interface:
		concrete := value.Elem()
		concreteType := concrete.Type()
		typeID, ok := c.typeToTypeID[concreteType]
		if !ok {
			return nil, fmt.Errorf("can't marshal unregistered type '%v'", concreteType.String())
		}
		p.PackInt(typeID)
		inner, err := c.marshal(concrete)
		if err != nil {
			return nil, err
		}
		p.PackFixedBytes(inner)
	case reflect.Array, reflect.Slice:
		numElts := value.Len()
		if valueKind == reflect.Slice {
			// unmarshal rejects slices longer than maxSliceLen; refuse to
			// produce bytes that could never be read back. This also keeps
			// the count within uint32 range.
			if numElts > c.maxSliceLen {
				return nil, errSliceTooLarge
			}
			p.PackInt(uint32(numElts))
		}
		for i := 0; i < numElts; i++ {
			eltBytes, err := c.marshal(value.Index(i))
			if err != nil {
				return nil, err
			}
			p.PackFixedBytes(eltBytes)
			if p.Errored() { // e.g. MaxSize exceeded; stop doing useless work
				return nil, p.Err
			}
		}
	case reflect.Struct:
		t := value.Type()
		for i := 0; i < t.NumField(); i++ {
			field := t.Field(i)
			if !shouldSerialize(field) {
				continue
			}
			// Go exports an identifier iff its first character is a Unicode
			// upper-case letter. Decode the whole first rune (not its first
			// byte) so multi-byte and '_'-prefixed names are classified
			// correctly.
			if !unicode.IsUpper([]rune(field.Name)[0]) {
				return nil, errMarshalUnexportedField
			}
			fieldVal := value.Field(i)
			if fieldVal.Kind() == reflect.Slice && fieldVal.IsNil() {
				// A nil slice field is encoded as an empty slice.
				p.PackInt(0)
				continue
			}
			fieldBytes, err := c.marshal(fieldVal)
			if err != nil {
				return nil, err
			}
			p.PackFixedBytes(fieldBytes)
			if p.Errored() {
				return nil, p.Err
			}
		}
	case reflect.Invalid:
		return nil, errNil
	default:
		return nil, errUnknownType
	}

	if p.Errored() {
		return nil, p.Err
	}
	return p.Bytes, nil
}
// Unmarshal decodes [bytes] into [dest], which must be a non-nil pointer.
//
// Returns:
//   - errSliceTooLarge if [bytes] is longer than the codec's maxSize
//   - errUnmarshalNil if [dest] is nil or a nil pointer
//   - errNeedPointer if [dest] is not a pointer
//   - any decoding error
//   - an error if bytes remain after [dest] has been fully decoded
func (c codec) Unmarshal(bytes []byte, dest interface{}) error {
	if len(bytes) > c.maxSize {
		return errSliceTooLarge
	}
	if dest == nil {
		return errUnmarshalNil
	}
	destPtr := reflect.ValueOf(dest)
	if destPtr.Kind() != reflect.Ptr {
		return errNeedPointer
	}
	if destPtr.IsNil() {
		// A typed nil pointer has nowhere to decode into.
		return errUnmarshalNil
	}

	p := &wrappers.Packer{Bytes: bytes}
	if err := c.unmarshal(p, destPtr.Elem()); err != nil {
		return err
	}
	// Trailing data means the input does not encode exactly one [dest].
	if p.Offset != len(p.Bytes) {
		return fmt.Errorf("has %d leftover bytes after unmarshalling", len(p.Bytes)-p.Offset)
	}
	return nil
}
// unmarshal reads bytes from [p] and decodes them into [field].
// [field] must be settable: addressable and not reached through an
// unexported struct field.
//
// The format mirrors marshal:
//   - fixed-width integers, bools and strings are unpacked directly
//   - a slice is preceded by its uint32 length, which must not exceed
//     c.maxSliceLen; an array has no length prefix
//   - an interface is preceded by the uint32 type ID assigned by
//     RegisterType
//   - a pointer is decoded into a newly allocated pointee
//   - a struct decodes only its fields tagged `serialize:"true"`, in
//     declaration order
//
// The returned error is either a codec error or any error recorded by [p].
func (c codec) unmarshal(p *wrappers.Packer, field reflect.Value) error {
	switch field.Kind() {
	case reflect.Uint8:
		field.SetUint(uint64(p.UnpackByte()))
	case reflect.Int8:
		// Reinterpret the unsigned wire value as signed before widening so
		// negative values sign-extend.
		field.SetInt(int64(int8(p.UnpackByte())))
	case reflect.Uint16:
		field.SetUint(uint64(p.UnpackShort()))
	case reflect.Int16:
		field.SetInt(int64(int16(p.UnpackShort())))
	case reflect.Uint32:
		field.SetUint(uint64(p.UnpackInt()))
	case reflect.Int32:
		field.SetInt(int64(int32(p.UnpackInt())))
	case reflect.Uint64:
		field.SetUint(p.UnpackLong())
	case reflect.Int64:
		field.SetInt(int64(p.UnpackLong()))
	case reflect.Bool:
		field.SetBool(p.UnpackBool())
	case reflect.String:
		field.SetString(p.UnpackStr())
	case reflect.Slice:
		sliceLen := int(p.UnpackInt())
		if p.Errored() { // the length itself could not be read
			return p.Err
		}
		// int(uint32) can be negative on 32-bit platforms.
		if sliceLen < 0 || sliceLen > c.maxSliceLen {
			return errSliceTooLarge
		}
		field.Set(reflect.MakeSlice(field.Type(), sliceLen, sliceLen))
		for i := 0; i < sliceLen; i++ {
			if err := c.unmarshal(p, field.Index(i)); err != nil {
				return err
			}
		}
	case reflect.Array:
		for i := 0; i < field.Len(); i++ {
			if err := c.unmarshal(p, field.Index(i)); err != nil {
				return err
			}
		}
	case reflect.Interface:
		typeID := p.UnpackInt()
		if p.Errored() { // report the short read, not a bogus type ID
			return p.Err
		}
		concreteType, ok := c.typeIDToType[typeID]
		if !ok {
			return errUnmarshalUnregisteredType
		}
		// The registered type must actually satisfy the destination interface.
		ifaceType := field.Type()
		if !concreteType.Implements(ifaceType) {
			return fmt.Errorf("%s does not implement interface %s", concreteType, ifaceType)
		}
		concrete := reflect.New(concreteType).Elem()
		if err := c.unmarshal(p, concrete); err != nil {
			return err
		}
		field.Set(concrete)
	case reflect.Struct:
		structType := field.Type()
		for i := 0; i < structType.NumField(); i++ {
			structField := structType.Field(i)
			if !shouldSerialize(structField) {
				continue
			}
			// Go exports an identifier iff its first character is a Unicode
			// upper-case letter. Decode the whole first rune; an unexported
			// field cannot be Set and would panic below.
			if !unicode.IsUpper([]rune(structField.Name)[0]) {
				return errUnmarshalUnexportedField
			}
			if err := c.unmarshal(p, field.Field(i)); err != nil {
				return err
			}
			if p.Errored() {
				return p.Err
			}
		}
	case reflect.Ptr:
		// Allocate a fresh pointee, fill it, then point [field] at it.
		pointee := reflect.New(field.Type().Elem())
		if err := c.unmarshal(p, pointee.Elem()); err != nil {
			return err
		}
		field.Set(pointee)
	case reflect.Invalid:
		return errUnmarshalNil
	default:
		return errUnknownType
	}
	return p.Err
}
// shouldSerialize reports whether [field] carries the tag `serialize:"true"`
// and therefore takes part in marshaling and unmarshaling.
func shouldSerialize(field reflect.StructField) bool {
	tag, present := field.Tag.Lookup("serialize")
	return present && tag == "true"
}