-
Notifications
You must be signed in to change notification settings - Fork 55
/
inject.go
141 lines (119 loc) · 2.46 KB
/
inject.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
package zdi
import (
"fmt"
"reflect"
"github.com/sohaha/zlsgo/zerror"
"github.com/sohaha/zlsgo/zreflect"
)
// InvokeWithErrorOnly calls f through Invoke and reduces the outcome to a
// single error. An error reported by Invoke itself is returned unchanged;
// otherwise the first value returned by f that holds an error is returned.
// It returns nil when none of the returned values holds an error.
func (inj *injector) InvokeWithErrorOnly(f interface{}) (err error) {
	results, invokeErr := inj.Invoke(f)
	if invokeErr != nil {
		return invokeErr
	}

	for _, result := range results {
		// The assertion only succeeds for a non-nil value that implements
		// error, so nil error results are skipped.
		if resultErr, isErr := result.Interface().(error); isErr {
			return resultErr
		}
	}
	return nil
}
// Invoke calls f and returns the values it produced. A value implementing
// PreInvoker is dispatched to inj.fast; anything else goes through inj.call.
// Both receive f's reflected type and its parameter count.
//
// The dispatch runs inside zerror.TryCatch (presumably so that a panic is
// reported as an error — confirm against the zerror package). A non-nil
// result from TryCatch takes precedence over the error set by the dispatch.
func (inj *injector) Invoke(f interface{}) (values []reflect.Value, err error) {
	caught := zerror.TryCatch(func() error {
		typ := zreflect.TypeOf(f)
		if pre, ok := f.(PreInvoker); ok {
			values, err = inj.fast(pre, typ, typ.NumIn())
		} else {
			values, err = inj.call(f, typ, typ.NumIn())
		}
		// The dispatch error travels through the named result, not through
		// the closure's own return value.
		return nil
	})
	if caught != nil {
		return values, caught
	}
	return values, err
}
// call invokes f via reflection, filling each of its numIn parameters with
// the value the injector holds for that parameter's type. t is the type the
// parameter types are read from. It returns an error, without calling f, as
// soon as a parameter type cannot be resolved through inj.Get.
func (inj *injector) call(f interface{}, t reflect.Type, numIn int) ([]reflect.Value, error) {
	// No parameters: nothing to resolve, call with a nil argument list.
	if numIn <= 0 {
		return zreflect.ValueOf(f).Call(nil), nil
	}

	args := make([]reflect.Value, 0, numIn)
	for idx := 0; idx < numIn; idx++ {
		paramType := t.In(idx)
		arg, found := inj.Get(paramType)
		if !found {
			return nil, fmt.Errorf("value not found for type %v", paramType)
		}
		args = append(args, arg)
	}
	return zreflect.ValueOf(f).Call(args), nil
}
// Map stores val in the injector. The key is the one set by the supplied
// options, or val's dynamic type (reflect.TypeOf) when no option sets one.
// If a value was already stored under that key it is replaced and the key is
// returned; otherwise the result is nil.
func (inj *injector) Map(val interface{}, opt ...Option) (override reflect.Type) {
	var cfg mapOption
	for i := range opt {
		opt[i](&cfg)
	}
	if cfg.key == nil {
		cfg.key = reflect.TypeOf(val)
	}

	// Record whether the key was occupied before overwriting it.
	_, replaced := inj.values[cfg.key]
	inj.values[cfg.key] = zreflect.ValueOf(val)
	if replaced {
		return cfg.key
	}
	return nil
}
// Maps stores each of values via Map using the default key, in order. It
// returns the keys whose previously stored value was replaced, or nil when
// nothing was replaced.
func (inj *injector) Maps(values ...interface{}) (override []reflect.Type) {
	for i := range values {
		if replaced := inj.Map(values[i]); replaced != nil {
			override = append(override, replaced)
		}
	}
	return override
}
// Set stores val under the key typ, replacing any value already stored under
// that key. Unlike Map it takes the reflect.Type and reflect.Value directly
// and does not report whether an existing entry was overwritten.
func (inj *injector) Set(typ reflect.Type, val reflect.Value) {
inj.values[typ] = val
}
// Get looks up the value for type t. The boolean reports whether one was
// found.
//
// Resolution order:
//  1. a valid value already stored under t;
//  2. a provider registered for t: it is invoked, every result is stored
//     under its own type and the provider entry for that type is removed;
//  3. when t is an interface type, a stored value whose key type implements
//     t (map iteration order is unspecified, so with several candidates the
//     pick is not deterministic);
//  4. the parent injector, when one is set.
//
// Get panics if invoking a provider returns an error. A nil t, or a nil key
// left in the store by Set or Map, is skipped by the interface fallback
// instead of causing a nil-dereference panic.
func (inj *injector) Get(t reflect.Type) (reflect.Value, bool) {
	val := inj.values[t]
	if val.IsValid() {
		return val, true
	}

	if provider, ok := inj.providers[t]; ok {
		results, err := inj.Invoke(provider.Interface())
		if err != nil {
			panic(err)
		}
		for _, result := range results {
			resultType := result.Type()
			inj.values[resultType] = result
			delete(inj.providers, resultType)
			if resultType == t {
				val = result
			}
		}
		if val.IsValid() {
			return val, true
		}
	}

	// Calling Kind on a nil t or Implements on a nil key would panic, so
	// both are guarded before the interface fallback.
	if t != nil && t.Kind() == reflect.Interface {
		for k, v := range inj.values {
			if k != nil && k.Implements(t) {
				val = v
				break
			}
		}
		if val.IsValid() {
			return val, true
		}
	}

	if inj.parent != nil {
		return inj.parent.Get(t)
	}
	return val, false
}