-
Notifications
You must be signed in to change notification settings - Fork 0
/
di_generic.go
81 lines (69 loc) · 2.21 KB
/
di_generic.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
//go:build go1.18
package diingo
import (
"fmt"
"reflect"
)
// ProviderFuncWithError is a provider function that accepts a variadic list
// of dependencies and returns a value of type T or an error.
type ProviderFuncWithError[T any] func(dependencies ...any) (T, error)

// ProviderFunc is a provider function that accepts a variadic list of
// dependencies and returns a value of type T. It has no way to report failure.
type ProviderFunc[T any] func(dependencies ...any) T

// Provider is a type-set constraint matching either provider function form.
//
// NOTE(review): the type parameter T is not referenced in the union below;
// both terms are fixed at `any`. Confirm whether
// ProviderFunc[T] | ProviderFuncWithError[T] was intended before relying on
// Provider[X] for any X other than `any`.
type Provider[T any] interface {
	ProviderFunc[any] | ProviderFuncWithError[any]
}

// provider is the package-internal, non-generic view of a provider function:
// something that can be invoked with dependencies and yields a value or an
// error.
type provider interface {
	provide(dependencies ...any) (any, error)
}

// Compile-time checks that both function types, instantiated at `any`,
// satisfy the provider interface.
var (
	_ provider = ProviderFunc[any](nil)
	_ provider = ProviderFuncWithError[any](nil)
)

// provide invokes p with the given dependencies and returns its result
// together with a nil error.
//
// The slice is expanded with `...` so that p receives each dependency as its
// own argument; passing the slice unexpanded would hand p a single argument
// (the []any itself).
func (p ProviderFunc[T]) provide(dependencies ...any) (T, error) {
	return p(dependencies...), nil
}

// provide invokes p with the given dependencies and returns its value and
// error unchanged.
//
// The slice is expanded with `...` so that p receives each dependency as its
// own argument; passing the slice unexpanded would hand p a single argument
// (the []any itself).
func (p ProviderFuncWithError[T]) provide(dependencies ...any) (T, error) {
	return p(dependencies...)
}
// LoadDependencies populates the struct pointed to by obj using the supplied
// providers.
//
// It builds a root function node from a constructor whose parameters are the
// field types of *obj, builds dependency nodes from providers, lets every
// dependency node confirm itself against the root node and against every
// other dependency node, and finally evaluates the root node.
//
// obj must be a non-nil pointer to a struct. providers may be plain values
// and/or functions; how they are matched is decided by the dependency nodes.
//
// It returns an error when obj is nil, when T is not a struct type, when
// evaluating the root node fails, or when the root node yields an invalid
// value.
func LoadDependencies[T any](obj *T, providers ...any) error {
	// Reject inputs that would otherwise panic inside reflect: a nil obj
	// cannot have its fields set, and only struct types have fields.
	if obj == nil {
		return fmt.Errorf("load dependencies: obj is nil")
	}
	rootType := reflect.TypeOf(obj)
	if elem := rootType.Elem(); elem.Kind() != reflect.Struct {
		return fmt.Errorf("load dependencies: %s is not a struct", elem)
	}

	rootNode := newFunctionNode(createConstructor(obj, rootType))
	dependencyNodes := createDependencyNodes(providers...)

	for _, node := range dependencyNodes {
		node.confirmDependencyWith(rootNode)
		for _, constructorNode := range dependencyNodes {
			// A node is never confirmed against itself.
			if constructorNode == node {
				continue
			}
			node.confirmDependencyWith(constructorNode)
		}
	}

	value, err := rootNode.Value()
	if err != nil {
		return err
	}
	if !value.IsValid() {
		return fmt.Errorf("failed to resolve dependencies")
	}
	return nil
}
// createDependencyNodes converts the given providers into dependency nodes.
//
// Every provider contributes a value node built from its reflect.Value. A
// provider whose kind is reflect.Func additionally contributes a function
// node, appended right after its value node. The resulting slice follows the
// order of providers.
func createDependencyNodes[P Provider[any] | any](providers ...P) []*dNode {
	nodes := make([]*dNode, 0, len(providers))
	for i := range providers {
		v := reflect.ValueOf(providers[i])
		nodes = append(nodes, newValueNode(v))
		if v.Kind() != reflect.Func {
			continue
		}
		nodes = append(nodes, newFunctionNode(v))
	}
	return nodes
}
// createConstructor synthesizes a function value that fills in *obj.
//
// returnType is expected to be a pointer-to-struct type (Elem and NumField
// are called on it). The synthesized function takes one parameter per struct
// field, in declaration order and with the field's type, and returns a single
// result of type returnType. When called, it assigns the i-th argument to the
// i-th field of *obj and returns obj wrapped in a reflect.Value.
func createConstructor[T any](obj *T, returnType reflect.Type) reflect.Value {
	structType := returnType.Elem()

	// One parameter per struct field, mirroring the field types.
	params := make([]reflect.Type, structType.NumField())
	for idx := range params {
		params[idx] = structType.Field(idx).Type
	}
	signature := reflect.FuncOf(params, []reflect.Type{returnType}, false)

	assignFields := func(in []reflect.Value) []reflect.Value {
		target := reflect.ValueOf(obj)
		for idx := range in {
			target.Elem().Field(idx).Set(in[idx])
		}
		return []reflect.Value{target}
	}

	return reflect.MakeFunc(signature, assignFields)
}