forked from btcsuite/btcd
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cmdparse.go
556 lines (492 loc) · 17.9 KB
/
cmdparse.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
// Copyright (c) 2014 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package btcjson
import (
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
)
// makeParams returns the field values of the struct described by rt and rv as
// a slice of interface values in field order.  Trailing pointer fields that are
// nil are trimmed from the result, while a nil pointer field that is followed
// by a non-nil or non-pointer field is kept so positions stay aligned.
func makeParams(rt reflect.Type, rv reflect.Value) []interface{} {
	fieldCount := rt.NumField()
	values := make([]interface{}, fieldCount)

	// keep is the number of leading entries to return, i.e. one past the
	// index of the last field that is not a nil pointer.
	keep := 0
	for idx := 0; idx < fieldCount; idx++ {
		field := rv.Field(idx)
		values[idx] = field.Interface()

		isNilPtr := rt.Field(idx).Type.Kind() == reflect.Ptr && field.IsNil()
		if !isNilPtr {
			keep = idx + 1
		}
	}
	return values[:keep]
}
// MarshalCmd marshals the passed command to a JSON-RPC request byte slice that
// is suitable for transmission to an RPC server.  The provided command type
// must be a registered type.  All commands provided by this package are
// registered by default.
//
// An ErrUnregisteredMethod error is returned when the concrete type of cmd has
// no registered method and an ErrInvalidType error is returned when cmd is a
// nil pointer.  Errors from NewRequest and json.Marshal are returned as is.
func MarshalCmd(rpcVersion RPCVersion, id interface{}, cmd interface{}) ([]byte, error) {
	// Look up the cmd type and error out if not registered.
	rt := reflect.TypeOf(cmd)
	registerLock.RLock()
	method, ok := concreteTypeToMethod[rt]
	registerLock.RUnlock()
	if !ok {
		// The method name is empty when the lookup fails, so report the
		// offending Go type instead of an empty string.
		str := fmt.Sprintf("%q is not registered", fmt.Sprintf("%T", cmd))
		return nil, makeError(ErrUnregisteredMethod, str)
	}

	// The provided command must not be nil.
	rv := reflect.ValueOf(cmd)
	if rv.IsNil() {
		str := "the specified command is nil"
		return nil, makeError(ErrInvalidType, str)
	}

	// Create a slice of interface values in the order of the struct fields
	// while respecting pointer fields as optional params and only adding
	// them if they are non-nil.
	params := makeParams(rt.Elem(), rv.Elem())

	// Generate and marshal the final JSON-RPC request.
	rawCmd, err := NewRequest(rpcVersion, id, method, params)
	if err != nil {
		return nil, err
	}
	return json.Marshal(rawCmd)
}
// checkNumParams verifies that numParams lies within the inclusive range
// [info.numReqParams, info.maxParams].  It returns nil when the count is
// acceptable and an ErrNumParams error describing the expected count
// otherwise.
func checkNumParams(numParams int, info *methodInfo) error {
	if numParams >= info.numReqParams && numParams <= info.maxParams {
		return nil
	}

	// Use the shorter message when exactly one parameter count is valid.
	var str string
	if info.numReqParams == info.maxParams {
		str = fmt.Sprintf("wrong number of params (expected %d, received %d)",
			info.numReqParams, numParams)
	} else {
		str = fmt.Sprintf("wrong number of params (expected between %d and "+
			"%d, received %d)", info.numReqParams, info.maxParams,
			numParams)
	}
	return makeError(ErrNumParams, str)
}
// populateDefaults fills in the struct fields of rv, from index numParams up
// to (but not including) info.maxParams, with the default value recorded for
// that field index in info.defaults.  Fields without a recorded default are
// left untouched.
func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) {
	for idx := numParams; idx < info.maxParams; idx++ {
		field := rv.Field(idx)
		defaultVal, hasDefault := info.defaults[idx]
		if !hasDefault {
			continue
		}
		field.Set(defaultVal)
	}
}
// UnmarshalCmd converts a JSON-RPC request into a newly allocated instance of
// the concrete command type registered for the request's method.  Each
// supplied parameter is unmarshalled into the struct field at the same index,
// and fields beyond the supplied parameters receive their registered defaults.
//
// An ErrUnregisteredMethod error is returned for unknown methods, the error
// from checkNumParams is returned for a bad parameter count, and an
// ErrInvalidType error is returned when a parameter fails to unmarshal.
func UnmarshalCmd(r *Request) (interface{}, error) {
	registerLock.RLock()
	cmdPtrType, registered := methodToConcreteType[r.Method]
	info := methodToInfo[r.Method]
	registerLock.RUnlock()
	if !registered {
		str := fmt.Sprintf("%q is not registered", r.Method)
		return nil, makeError(ErrUnregisteredMethod, str)
	}

	structType := cmdPtrType.Elem()
	cmdPtr := reflect.New(structType)
	cmdVal := cmdPtr.Elem()

	// Reject requests with too few or too many parameters up front.
	numParams := len(r.Params)
	if err := checkNumParams(numParams, &info); err != nil {
		return nil, err
	}

	// Decode every supplied parameter into the field at the same position.
	for idx := 0; idx < numParams; idx++ {
		target := cmdVal.Field(idx).Addr().Interface()
		err := json.Unmarshal(r.Params[idx], &target)
		if err == nil {
			continue
		}

		// A type mismatch is the most common failure, so give it a
		// friendlier message than the raw decoder error.
		name := strings.ToLower(structType.Field(idx).Name)
		var str string
		if typeErr, isTypeErr := err.(*json.UnmarshalTypeError); isTypeErr {
			str = fmt.Sprintf("parameter #%d '%s' must be type %v (got %v)",
				idx+1, name, typeErr.Type, typeErr.Value)
		} else {
			str = fmt.Sprintf("parameter #%d '%s' failed to unmarshal: %v",
				idx+1, name, err)
		}
		return nil, makeError(ErrInvalidType, str)
	}

	// Any fields past the supplied parameters are optional, so give them
	// their default values where one is registered.
	if numParams < info.maxParams {
		populateDefaults(numParams, &info, cmdVal)
	}

	return cmdPtr.Interface(), nil
}
// isNumeric reports whether kind is one of the sized or unsized signed
// integer, unsigned integer, or floating point kinds.  reflect.Uintptr and the
// complex kinds are not considered numeric.
func isNumeric(kind reflect.Kind) bool {
	switch kind {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return true
	case reflect.Float32, reflect.Float64:
		return true
	default:
		return false
	}
}
// typesMaybeCompatible is a cheap pre-check that reports whether a value of
// type src could possibly be assigned or converted to type dest.  A true
// result does not guarantee the conversion succeeds; a false result means it
// cannot.
func typesMaybeCompatible(dest reflect.Type, src reflect.Type) bool {
	// Identical types always work.
	if dest == src {
		return true
	}

	destKind, srcKind := dest.Kind(), src.Kind()
	destNumeric := isNumeric(destKind)

	// Numeric to numeric may work, subject to overflow checks later.
	if destNumeric && isNumeric(srcKind) {
		return true
	}

	// Everything else requires the source to be a string.
	if srcKind != reflect.String {
		return false
	}

	// Strings may parse to numbers.
	if destNumeric {
		return true
	}

	// Strings may also parse to bools via strconv.ParseBool, convert to
	// other string-based types, or unmarshal into arrays, slices, structs,
	// and maps via json.Unmarshal.
	switch destKind {
	case reflect.Bool, reflect.String, reflect.Array, reflect.Slice,
		reflect.Struct, reflect.Map:
		return true
	}
	return false
}
// baseType strips every level of pointer indirection from arg and returns the
// resulting non-pointer type together with the number of levels removed.
func baseType(arg reflect.Type) (reflect.Type, int) {
	depth := 0
	current := arg
	for ; current.Kind() == reflect.Ptr; depth++ {
		current = current.Elem()
	}
	return current, depth
}
// assignField is the main workhorse for the NewCmd function which handles
// assigning the provided source value to the destination field.  It supports
// direct type assignments, indirection, conversion of numeric types, and
// unmarshaling of strings into arrays, slices, structs, and maps via
// json.Unmarshal.
//
// paramNum and fieldName are only used to build error messages.  dest is
// written through reflection, so it has to be settable and addressable.  Any
// pointers needed to reach the base type of dest are allocated as required.
//
// A nil source value, or the string "null", leaves a pointer (optional)
// destination untouched.  All failures are reported as ErrInvalidType errors.
func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect.Value) error {
	destBaseType, destIndirects := baseType(dest.Type())

	// A nil argument produces an invalid reflect.Value whose Type method
	// panics.  Treat it the same as "null" for optional fields and reject
	// it for required ones.
	if !src.IsValid() {
		if destIndirects > 0 {
			return nil
		}
		str := fmt.Sprintf("parameter #%d '%s' must be type %v (got "+
			"nil)", paramNum, fieldName, destBaseType)
		return makeError(ErrInvalidType, str)
	}
	srcBaseType, srcIndirects := baseType(src.Type())

	// Constructors for the errors shared by the conversion branches below.
	typeErr := func() error {
		str := fmt.Sprintf("parameter #%d '%s' must be type %v (got "+
			"%v)", paramNum, fieldName, destBaseType, srcBaseType)
		return makeError(ErrInvalidType, str)
	}
	overflowErr := func() error {
		str := fmt.Sprintf("parameter #%d '%s' overflows destination "+
			"type %v", paramNum, fieldName, destBaseType)
		return makeError(ErrInvalidType, str)
	}
	parseErr := func() error {
		str := fmt.Sprintf("parameter #%d '%s' must parse to a %v",
			paramNum, fieldName, destBaseType)
		return makeError(ErrInvalidType, str)
	}

	// Just error now when the types have no chance of being compatible.
	if !typesMaybeCompatible(destBaseType, srcBaseType) {
		return typeErr()
	}

	// Check if it's possible to simply set the dest to the provided source.
	// This is the case when the base types are the same or they are both
	// pointers that can be indirected to be the same without needing to
	// create pointers for the destination field.
	if destBaseType == srcBaseType && srcIndirects >= destIndirects {
		for i := 0; i < srcIndirects-destIndirects; i++ {
			src = src.Elem()
		}
		dest.Set(src)
		return nil
	}

	// Optional variables can be set null using "null" string.
	if destIndirects > 0 && src.String() == "null" {
		return nil
	}

	// When the destination has more indirects than the source, the extra
	// pointers have to be created.  Only create enough pointers to reach
	// the same level of indirection as the source so the dest can simply be
	// set to the provided source when the types are the same.
	destIndirectsRemaining := destIndirects
	if destIndirects > srcIndirects {
		indirectDiff := destIndirects - srcIndirects
		for i := 0; i < indirectDiff; i++ {
			dest.Set(reflect.New(dest.Type().Elem()))
			dest = dest.Elem()
			destIndirectsRemaining--
		}
	}

	if destBaseType == srcBaseType {
		dest.Set(src)
		return nil
	}

	// Make any remaining pointers needed to get to the base dest type since
	// the above direct assign was not possible and conversions are done
	// against the base types.
	for i := 0; i < destIndirectsRemaining; i++ {
		dest.Set(reflect.New(dest.Type().Elem()))
		dest = dest.Elem()
	}

	// Indirect through to the base source value.
	for src.Kind() == reflect.Ptr {
		src = src.Elem()
	}

	// Perform supported type conversions.
	switch src.Kind() {
	// Source value is a signed integer of various magnitude.
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Int64:

		switch dest.Kind() {
		// Destination is a signed integer of various magnitude.
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
			reflect.Int64:

			srcInt := src.Int()
			if dest.OverflowInt(srcInt) {
				return overflowErr()
			}
			dest.SetInt(srcInt)

		// Destination is an unsigned integer of various magnitude.
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
			reflect.Uint64:

			srcInt := src.Int()
			if srcInt < 0 || dest.OverflowUint(uint64(srcInt)) {
				return overflowErr()
			}
			dest.SetUint(uint64(srcInt))

		default:
			return typeErr()
		}

	// Source value is an unsigned integer of various magnitude.
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
		reflect.Uint64:

		switch dest.Kind() {
		// Destination is a signed integer of various magnitude.
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
			reflect.Int64:

			// Values above the largest int64 can't be represented
			// by any signed destination, so reject them before the
			// conversion to int64 wraps them negative.
			srcUint := src.Uint()
			if srcUint > uint64(1<<63)-1 || dest.OverflowInt(int64(srcUint)) {
				return overflowErr()
			}
			dest.SetInt(int64(srcUint))

		// Destination is an unsigned integer of various magnitude.
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
			reflect.Uint64:

			srcUint := src.Uint()
			if dest.OverflowUint(srcUint) {
				return overflowErr()
			}
			dest.SetUint(srcUint)

		default:
			return typeErr()
		}

	// Source value is a float.
	case reflect.Float32, reflect.Float64:
		destKind := dest.Kind()
		if destKind != reflect.Float32 && destKind != reflect.Float64 {
			return typeErr()
		}

		srcFloat := src.Float()
		if dest.OverflowFloat(srcFloat) {
			return overflowErr()
		}
		dest.SetFloat(srcFloat)

	// Source value is a string.
	case reflect.String:
		switch dest.Kind() {
		// String -> bool
		case reflect.Bool:
			b, err := strconv.ParseBool(src.String())
			if err != nil {
				return parseErr()
			}
			dest.SetBool(b)

		// String -> signed integer of varying size.
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
			reflect.Int64:

			// Parse at the full 64-bit width regardless of the
			// platform int size and let the overflow check below
			// enforce the width of the destination.
			srcInt, err := strconv.ParseInt(src.String(), 0, 64)
			if err != nil {
				return parseErr()
			}
			if dest.OverflowInt(srcInt) {
				return overflowErr()
			}
			dest.SetInt(srcInt)

		// String -> unsigned integer of varying size.
		case reflect.Uint, reflect.Uint8, reflect.Uint16,
			reflect.Uint32, reflect.Uint64:

			// Same reasoning as the signed case above.
			srcUint, err := strconv.ParseUint(src.String(), 0, 64)
			if err != nil {
				return parseErr()
			}
			if dest.OverflowUint(srcUint) {
				return overflowErr()
			}
			dest.SetUint(srcUint)

		// String -> float of varying size.
		case reflect.Float32, reflect.Float64:
			srcFloat, err := strconv.ParseFloat(src.String(), 0)
			if err != nil {
				return parseErr()
			}
			if dest.OverflowFloat(srcFloat) {
				return overflowErr()
			}
			dest.SetFloat(srcFloat)

		// String -> string (typecast).
		case reflect.String:
			dest.SetString(src.String())

		// String -> arrays, slices, structs, and maps via
		// json.Unmarshal.
		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
			// Unmarshal straight into the destination.  Going
			// through a pointer to an interface instead makes a
			// JSON null zero the interface, which leaves nothing
			// valid to copy back into the destination.
			err := json.Unmarshal([]byte(src.String()),
				dest.Addr().Interface())
			if err != nil {
				str := fmt.Sprintf("parameter #%d '%s' must "+
					"be valid JSON which unmarshals to a %v",
					paramNum, fieldName, destBaseType)
				return makeError(ErrInvalidType, str)
			}
		}
	}

	return nil
}
// NewCmd builds a new instance of the command registered for method and
// populates its fields, in order, from args.  The method must already be
// registered along with its type definition; every method associated with the
// commands exported by this package is registered by default.
//
// Passing arguments that already have the exact type of the matching struct
// field is the most efficient, but a number of conversions are performed to
// make the function more flexible, for example so command line arguments can
// be passed through as plain strings.  The supported conversions are:
//
//   - Conversion between any size signed or unsigned integer so long as the
//     value does not overflow the destination type
//   - Conversion between float32 and float64 so long as the value does not
//     overflow the destination type
//   - Conversion from string to boolean for everything strconv.ParseBool
//     recognizes
//   - Conversion from string to any size integer for everything
//     strconv.ParseInt and strconv.ParseUint recognizes
//   - Conversion from string to any size float for everything
//     strconv.ParseFloat recognizes
//   - Conversion from string to arrays, slices, structs, and maps by treating
//     the string as marshalled JSON and calling json.Unmarshal into the
//     destination field
//
// An ErrUnregisteredMethod error is returned for unknown methods.  Errors from
// checkNumParams and assignField are returned unchanged.
func NewCmd(method string, args ...interface{}) (interface{}, error) {
	// Fetch the registration details while holding the read lock.
	registerLock.RLock()
	cmdPtrType, registered := methodToConcreteType[method]
	info := methodToInfo[method]
	registerLock.RUnlock()
	if !registered {
		str := fmt.Sprintf("%q is not registered", method)
		return nil, makeError(ErrUnregisteredMethod, str)
	}

	// Reject calls with too few or too many arguments.
	numParams := len(args)
	if err := checkNumParams(numParams, &info); err != nil {
		return nil, err
	}

	// The registered type is indirected once here to obtain the struct
	// type that backs the command.
	structType := cmdPtrType.Elem()
	cmdPtr := reflect.New(structType)
	cmdVal := cmdPtr.Elem()

	// Assign each argument to the struct field at the same position,
	// converting where needed.
	for idx := 0; idx < numParams; idx++ {
		field := cmdVal.Field(idx)
		name := strings.ToLower(structType.Field(idx).Name)
		arg := reflect.ValueOf(args[idx])
		if err := assignField(idx+1, name, field, arg); err != nil {
			return nil, err
		}
	}

	return cmdPtr.Interface(), nil
}