Skip to content

Commit ff5e93f

Browse files
committed
fix: barry quick fix, 2025-06-08 10:42:50
1 parent 1458dce commit ff5e93f

File tree

16 files changed

+1864
-110
lines changed

16 files changed

+1864
-110
lines changed

dix.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ func New(opts ...Option) Container {
3636
// 提供了统一的依赖注入接口,简化API设计。
3737
//
3838
// 支持的输入类型:
39-
// - 函数:func(deps...) - 解析函数参数并调用函数
39+
// - 函数:func(deps...) 或 func(deps...) error - 解析函数参数并调用函数
4040
// - 结构体指针:&struct{} - 注入到结构体字段
4141
//
4242
// 函数注入规则:
43-
// - 函数只能有入参,不能有出参
43+
// - 函数可以没有返回值,或者有一个 error 返回值
4444
// - 函数参数类型必须在容器中已注册
4545
// - 支持的参数类型:指针(*T)、接口(interface{})、结构体(struct{})、切片([]T)、映射(map[string]T)
4646
// - 不支持基本类型参数:string, int, bool 等
@@ -63,7 +63,7 @@ func New(opts ...Option) Container {
6363
//
6464
// 示例:
6565
//
66-
// // 函数注入
66+
// // 函数注入(无返回值)
6767
// _, err := dix.Inject(container, func(logger Logger, db *Database, handlers []Handler) {
6868
// // 使用注入的依赖启动服务器
6969
// startServer(logger, db, handlers)
@@ -72,6 +72,15 @@ func New(opts ...Option) Container {
7272
// log.Fatal(err)
7373
// }
7474
//
75+
// // 函数注入(error 返回值)
76+
// _, err := dix.Inject(container, func(logger Logger, db *Database) error {
77+
// // 执行可能失败的初始化操作
78+
// return initializeSystem(logger, db)
79+
// })
80+
// if err != nil {
81+
// log.Fatal(err)
82+
// }
83+
//
7584
// // 结构体注入
7685
// type Service struct {
7786
// Logger Logger

dixglobal/global.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package dixglobal
22

33
import (
4-
"github.com/pubgo/funk/errors"
54
"reflect"
65

6+
"github.com/pubgo/funk/errors"
7+
78
"github.com/pubgo/dix/dixinternal"
89
)
910

@@ -19,6 +20,14 @@ var _container = dixinternal.New(dixinternal.WithValuesNull())
1920
// server.ListenAndServe()
2021
// })
2122
//
23+
// // 或者带错误处理的函数注入
24+
// Inject(func(server *http.Server, db *Database) error {
25+
// if err := db.Connect(); err != nil {
26+
// return err
27+
// }
28+
// return server.ListenAndServe()
29+
// })
30+
//
2231
// // 或者使用结构体注入
2332
// type App struct {
2433
// Server *http.Server

dixinternal/container.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,42 @@ func (c *ContainerImpl) getAllProviders() map[reflect.Type][]Provider {
132132

133133
// removeProvider 移除提供者(用于回滚)
134134
func (c *ContainerImpl) removeProvider(provider Provider) {
135-
typ := provider.Type()
136-
providers := c.resolver.providers[typ]
137-
138-
for i, p := range providers {
139-
if p == provider {
140-
// 移除该提供者
141-
c.resolver.providers[typ] = append(providers[:i], providers[i+1:]...)
142-
break
135+
// 移除该 provider 在所有其能提供的类型下的注册
136+
providedTypes := provider.ProvidedTypes()
137+
138+
for _, typ := range providedTypes {
139+
providers := c.resolver.providers[typ]
140+
141+
for i, p := range providers {
142+
if p == provider {
143+
// 移除该提供者
144+
c.resolver.providers[typ] = append(providers[:i], providers[i+1:]...)
145+
break
146+
}
147+
}
148+
149+
// 如果该类型没有提供者了,删除整个条目
150+
if len(c.resolver.providers[typ]) == 0 {
151+
delete(c.resolver.providers, typ)
143152
}
144153
}
145154

146-
// 如果该类型没有提供者了,删除整个条目
147-
if len(c.resolver.providers[typ]) == 0 {
148-
delete(c.resolver.providers, typ)
155+
// 向后兼容:如果没有 ProvidedTypes,使用传统的 PrimaryType() 方法
156+
if len(providedTypes) == 0 {
157+
typ := provider.PrimaryType()
158+
providers := c.resolver.providers[typ]
159+
160+
for i, p := range providers {
161+
if p == provider {
162+
// 移除该提供者
163+
c.resolver.providers[typ] = append(providers[:i], providers[i+1:]...)
164+
break
165+
}
166+
}
167+
168+
// 如果该类型没有提供者了,删除整个条目
169+
if len(c.resolver.providers[typ]) == 0 {
170+
delete(c.resolver.providers, typ)
171+
}
149172
}
150173
}

dixinternal/cycle_detector.go

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,48 @@ func (cd *CycleDetectorImpl) buildDependencyGraph(providers map[reflect.Type][]P
2929
graph[outputType] = make(map[reflect.Type]bool)
3030
}
3131

32+
// 跟踪已处理的 provider,避免重复处理多类型 provider
33+
processedProviders := make(map[Provider]bool)
34+
3235
// 构建依赖关系图
33-
for outputType, providerList := range providers {
36+
for _, providerList := range providers {
3437
for _, provider := range providerList {
35-
for _, dep := range provider.Dependencies() {
36-
// 递归获取所有依赖类型
37-
depTypes := cd.getAllDependencyTypes(dep)
38-
for _, depType := range depTypes {
39-
graph[outputType][depType] = true
38+
// 跳过已处理的 provider,避免多类型 provider 被重复处理
39+
if processedProviders[provider] {
40+
continue
41+
}
42+
processedProviders[provider] = true
43+
44+
// 获取该 provider 能提供的所有类型
45+
providedTypes := provider.ProvidedTypes()
46+
if len(providedTypes) == 0 {
47+
// 向后兼容:使用 PrimaryType
48+
providedTypes = []reflect.Type{provider.PrimaryType()}
49+
}
50+
51+
// 为每个提供的类型建立依赖关系
52+
for _, providedType := range providedTypes {
53+
if graph[providedType] == nil {
54+
graph[providedType] = make(map[reflect.Type]bool)
55+
}
56+
57+
for _, dep := range provider.Dependencies() {
58+
// 递归获取所有依赖类型
59+
depTypes := cd.getAllDependencyTypes(dep)
60+
for _, depType := range depTypes {
61+
// 避免自依赖(同一个 provider 提供的类型之间不应该有依赖关系)
62+
isProvidedBySameProvider := false
63+
for _, pt := range providedTypes {
64+
if pt == depType {
65+
isProvidedBySameProvider = true
66+
break
67+
}
68+
}
69+
70+
if !isProvidedBySameProvider {
71+
graph[providedType][depType] = true
72+
}
73+
}
4074
}
4175
}
4276
}

dixinternal/injector.go

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ func (inj *InjectorImpl) InjectStruct(target reflect.Value, opts Options) (err e
5050
continue
5151
}
5252

53+
// 只对支持的字段类型进行注入
54+
if !isInjectableFieldType(field.Type) {
55+
logger.Debug().Msgf("skipping non-injectable field: %s (type: %s)", field.Name, field.Type.String())
56+
continue
57+
}
58+
5359
err := inj.injectField(fieldValue, field, opts)
5460
if err != nil {
5561
return WrapError(err, ErrorTypeInjection, "failed to inject struct field").
@@ -84,8 +90,19 @@ func (inj *InjectorImpl) InjectFunc(fn reflect.Value, opts Options) (err error)
8490
}
8591

8692
fnType := fn.Type()
87-
if fnType.NumOut() != 0 {
88-
return NewValidationError("injectable function must have no return values").
93+
94+
// 检查返回值:允许没有返回值或者只有一个 error 返回值
95+
var hasErrorReturn bool
96+
if fnType.NumOut() == 1 {
97+
// 如果有一个返回值,必须是 error 类型
98+
errorType := fnType.Out(0)
99+
if !errorType.Implements(reflect.TypeOf((*error)(nil)).Elem()) {
100+
return NewValidationError("injectable function can only return error type").
101+
WithDetail("return_type", errorType.String())
102+
}
103+
hasErrorReturn = true
104+
} else if fnType.NumOut() > 1 {
105+
return NewValidationError("injectable function can have at most one return value (error)").
89106
WithDetail("return_count", fnType.NumOut())
90107
}
91108

@@ -94,8 +111,8 @@ func (inj *InjectorImpl) InjectFunc(fn reflect.Value, opts Options) (err error)
94111
WithDetail("parameter_count", fnType.NumIn())
95112
}
96113

97-
// 解析函数参数依赖
98-
dependencies, err := parseDependencies(fnType)
114+
// 解析函数参数依赖(对于普通函数注入,不需要结构体字段展开)
115+
dependencies, err := parseBasicDependencies(fnType)
99116
if err != nil {
100117
return WrapError(err, ErrorTypeValidation, "failed to parse function dependencies")
101118
}
@@ -107,7 +124,18 @@ func (inj *InjectorImpl) InjectFunc(fn reflect.Value, opts Options) (err error)
107124
}
108125

109126
// 调用函数
110-
fn.Call(args)
127+
results := fn.Call(args)
128+
129+
// 如果函数有 error 返回值,检查并处理
130+
if hasErrorReturn && len(results) > 0 {
131+
errorValue := results[0]
132+
if !errorValue.IsNil() {
133+
if funcErr, ok := errorValue.Interface().(error); ok {
134+
return WrapError(funcErr, ErrorTypeInjection, "injected function returned error")
135+
}
136+
}
137+
}
138+
111139
return nil
112140
}
113141

@@ -178,6 +206,11 @@ func (inj *InjectorImpl) injectField(fieldValue reflect.Value, field reflect.Str
178206
// 解析并设置值
179207
val, err := inj.resolver.Resolve(field.Type, opts)
180208
if err != nil {
209+
// 如果无法解析,尝试从结构体字段依赖中获取
210+
if structFieldVal := inj.tryResolveFromStructFields(field.Type, opts); structFieldVal.IsValid() {
211+
fieldValue.Set(structFieldVal)
212+
return nil
213+
}
181214
return WrapError(err, ErrorTypeInjection, "failed to resolve field value").
182215
WithDetail("field_name", field.Name).
183216
WithDetail("field_type", field.Type.String())
@@ -220,15 +253,22 @@ func (inj *InjectorImpl) injectField(fieldValue reflect.Value, field reflect.Str
220253
}
221254

222255
default:
223-
return NewValidationError("unsupported field type for injection").
224-
WithDetail("field_name", field.Name).
225-
WithDetail("field_type", field.Type.String()).
226-
WithDetail("field_kind", field.Type.Kind().String())
256+
// 对于不支持的字段类型,直接跳过而不是报错
257+
logger.Debug().Msgf("skipping field with unsupported type: %s (type: %s)", field.Name, field.Type.String())
258+
return nil
227259
}
228260

229261
return nil
230262
}
231263

264+
// tryResolveFromStructFields 尝试从结构体字段依赖中解析类型
265+
func (inj *InjectorImpl) tryResolveFromStructFields(fieldType reflect.Type, opts Options) reflect.Value {
266+
// 这是一个辅助方法,用于处理结构体字段的递归依赖解析
267+
// 在实际实现中,这可能需要与resolver的内部实现更紧密地集成
268+
// 目前先返回零值,表示无法解析
269+
return reflect.Value{}
270+
}
271+
232272
// injectMethods 注入方法
233273
func (inj *InjectorImpl) injectMethods(target reflect.Value, opts Options) error {
234274
targetType := target.Type()

0 commit comments

Comments
 (0)