This repository has been archived by the owner on Nov 21, 2023. It is now read-only.
/
typewrap.go
170 lines (146 loc) · 4.04 KB
/
typewrap.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
package jlibhttp
import (
"context"
"fmt"
"github.com/mitchellh/mapstructure"
"github.com/pkg/errors"
"github.com/superisaac/jlib"
"reflect"
)
// typeIsStruct reports whether tp is a struct type, or a pointer (possibly
// through several levels of indirection) to a struct type.
func typeIsStruct(tp reflect.Type) bool {
	base := tp
	// Strip every pointer layer; only the final element kind matters.
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	return base.Kind() == reflect.Struct
}
// interfaceToValue decodes a (one element of the RPC params, presumably as
// produced by JSON decoding) into a reflect.Value of exactly outputType,
// using mapstructure with `json` struct tags.
//
// Decoding targets a freshly allocated *outputType, so the returned Value is
// always valid and always has type outputType. When mapstructure leaves the
// target untouched (e.g. for a nil input), the result is the zero value of
// outputType. For interface-typed parameters that zero value is a nil
// interface, not an invalid reflect.Value that would make reflect.Value.Call
// panic.
func interfaceToValue(a interface{}, outputType reflect.Type) (reflect.Value, error) {
	outputPtr := reflect.New(outputType)
	config := &mapstructure.DecoderConfig{
		Metadata: nil,
		TagName:  "json",
		Result:   outputPtr.Interface(),
	}
	decoder, err := mapstructure.NewDecoder(config)
	if err != nil {
		return reflect.Value{}, err
	}
	if err := decoder.Decode(a); err != nil {
		return reflect.Value{}, err
	}
	return outputPtr.Elem(), nil
}
// valueToInterface converts val, whose declared type is tp, into a plain
// interface{} using mapstructure with `json` struct tags.
//
// Struct types and pointers to structs are flattened into a
// map[string]interface{} keyed by their json tag names. Every other type is
// decoded into a value of tp itself.
//
// NOTE(review): for a nil pointer-to-struct val this presumably yields an
// empty map rather than nil, depending on how mapstructure treats nil
// input. Confirm whether callers expect null here.
func valueToInterface(tp reflect.Type, val reflect.Value) (interface{}, error) {
	var result interface{}
	switch {
	case typeIsStruct(tp):
		result = map[string]interface{}{}
	default:
		result = reflect.Zero(tp).Interface()
	}

	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		TagName: "json",
		Result:  &result,
	})
	if err != nil {
		return nil, err
	}
	if err := decoder.Decode(val.Interface()); err != nil {
		return nil, err
	}
	return result, nil
}
// FirstArgSpec describes an extra, non-RPC argument that wrapTyped injects as
// the first parameter of a typed handler function.
type FirstArgSpec interface {
	// String names the expected first-argument type; used in error messages.
	String() string
	// Check reports whether a handler's first parameter type is acceptable.
	Check(firstArgType reflect.Type) bool
	// Value produces the argument to pass for the given request.
	Value(req *RPCRequest) interface{}
}
// ReqSpec is a FirstArgSpec that injects the *RPCRequest being handled as the
// first argument of a wrapped handler.
type ReqSpec struct{}

// Check reports whether firstArgType is exactly *RPCRequest.
//
// The type is compared by identity rather than by its String() form. A string
// comparison would also accept a *RPCRequest from any other package that
// happens to be named jlibhttp, and the handler call would then panic.
func (ReqSpec) Check(firstArgType reflect.Type) bool {
	return firstArgType == reflect.TypeOf((*RPCRequest)(nil))
}

// Value returns req itself.
func (ReqSpec) Value(req *RPCRequest) interface{} {
	return req
}

// String returns the name of the type Check expects, for error messages.
func (ReqSpec) String() string {
	return "*jlibhttp.RPCRequest"
}
// ContextSpec is a FirstArgSpec that injects req.Context() as the first
// argument of a wrapped handler.
type ContextSpec struct{}

// Check reports whether firstArgType is an interface type with the same
// method set as context.Context. That makes it able to hold whatever
// context.Context value Value returns.
//
// Both directions are required. An interface that embeds context.Context and
// adds further methods implements context.Context, but the value returned by
// req.Context() is not guaranteed to implement those extra methods, so
// reflect.Value.Call would panic at request time.
func (ContextSpec) Check(firstArgType reflect.Type) bool {
	ctxType := reflect.TypeOf((*context.Context)(nil)).Elem()
	return firstArgType.Kind() == reflect.Interface &&
		firstArgType.Implements(ctxType) &&
		ctxType.Implements(firstArgType)
}

// Value returns req.Context().
func (ContextSpec) Value(req *RPCRequest) interface{} {
	return req.Context()
}

// String returns the name of the type Check expects, for error messages.
func (ContextSpec) String() string {
	return "context.Context"
}
// wrapTyped adapts an arbitrary typed Go function into a RequestCallback.
//
// tfunc must be a non-nil function returning exactly two values, the second
// of which implements error. When firstArgSpec is non-nil, tfunc's first
// parameter must satisfy firstArgSpec.Check and is filled per request from
// firstArgSpec.Value. Every remaining parameter is decoded, in order, from
// the corresponding element of params via interfaceToValue.
//
// The returned callback reports a params error in two cases: fewer params
// than required are supplied (extra params are ignored), or a param cannot
// be decoded into its argument type. If tfunc returns a non-nil error, that
// error is returned. Otherwise the first result is converted with
// valueToInterface.
func wrapTyped(tfunc interface{}, firstArgSpec FirstArgSpec) (RequestCallback, error) {
	funcType := reflect.TypeOf(tfunc)
	// reflect.TypeOf(nil) is a nil Type; check it before calling Kind().
	if funcType == nil || funcType.Kind() != reflect.Func {
		return nil, errors.New("tfunc is not func type")
	}
	funcValue := reflect.ValueOf(tfunc)
	if funcValue.IsNil() {
		// A typed nil func would otherwise only fail later, panicking in Call.
		return nil, errors.New("tfunc is nil")
	}

	numIn := funcType.NumIn()
	requireFirstArg := firstArgSpec != nil
	firstArgNum := 0
	if requireFirstArg {
		firstArgNum = 1
		// check inputs and 1st argument
		if numIn < firstArgNum {
			return nil, errors.New("func must have 1 more arguments")
		}
		if !firstArgSpec.Check(funcType.In(0)) {
			return nil, errors.Errorf("the first arg must be %s", firstArgSpec.String())
		}
	}

	// check outputs
	if funcType.NumOut() != 2 {
		return nil, errors.New("func return number must be 2")
	}
	errInterface := reflect.TypeOf((*error)(nil)).Elem()
	if !funcType.Out(1).Implements(errInterface) {
		return nil, errors.New("second output does not implement error")
	}
	resType := funcType.Out(0)
	// Number of parameters that must be supplied by the RPC params array.
	numParams := numIn - firstArgNum

	handler := func(req *RPCRequest, params []interface{}) (interface{}, error) {
		if numParams > len(params) {
			return nil, jlib.ParamsError("no enough params size")
		}

		// params -> []reflect.Value
		fnArgs := make([]reflect.Value, 0, numIn)
		if requireFirstArg {
			fnArgs = append(fnArgs, reflect.ValueOf(firstArgSpec.Value(req)))
		}
		for j := 0; j < numParams; j++ {
			argValue, err := interfaceToValue(params[j], funcType.In(firstArgNum+j))
			if err != nil {
				// Report the 1-based position within params (not within the
				// Go signature) so the number matches what the RPC client
				// sent, whether or not a first arg is injected.
				return nil, jlib.ParamsError(fmt.Sprintf("params %d %s", j+1, err))
			}
			fnArgs = append(fnArgs, argValue)
		}

		resValues := funcValue.Call(fnArgs)

		// Treat a typed nil (e.g. a nil *MyError) as "no error". Boxing it
		// via Interface() would produce a non-nil error interface.
		errValue := resValues[1]
		errIsNil := false
		switch errValue.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			errIsNil = errValue.IsNil()
		}
		if !errIsNil {
			errRes := errValue.Interface()
			if err, ok := errRes.(error); ok {
				return nil, err
			}
			// Defensive: unreachable given the Implements check above.
			return nil, errors.Errorf("error return is not error %+v", errRes)
		}

		// wrap result
		return valueToInterface(resType, resValues[0])
	}
	return handler, nil
}