diff --git a/dig.go b/dig.go index a69e9bcc..3db042e1 100644 --- a/dig.go +++ b/dig.go @@ -28,8 +28,10 @@ import ( ) var ( - _noValue reflect.Value - _errType = reflect.TypeOf((*error)(nil)).Elem() + _noValue reflect.Value + _errType = reflect.TypeOf((*error)(nil)).Elem() + _parameterObjectType = reflect.TypeOf((*parameterObject)(nil)).Elem() + _paramType = reflect.TypeOf(Param{}) ) // A Container is a directed, acyclic graph of dependencies. Dependencies are @@ -110,6 +112,9 @@ func (c *Container) provideInstance(val interface{}) error { if vtype == _errType { return errors.New("can't provide errors") } + if vtype.Implements(_parameterObjectType) { + return errors.New("can't provide parameter objects") + } if _, ok := c.nodes[vtype]; ok { return errors.New("already in container") } @@ -126,6 +131,9 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err // Don't register errors into the container. continue } + if rt.Implements(_parameterObjectType) { + return errors.New("can't provide parameter objects") + } if _, ok := returnTypes[rt]; ok { return fmt.Errorf("returns multiple %v", rt) } @@ -140,7 +148,10 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err nodes := make([]node, 0, len(returnTypes)) for rt := range returnTypes { - n := newNode(rt, ctor, ctype) + n, err := newNode(rt, ctor, ctype) + if err != nil { + return err + } nodes = append(nodes, n) c.nodes[rt] = n } @@ -222,7 +233,18 @@ func (c *Container) remove(nodes []node) { func (c *Container) constructorArgs(ctype reflect.Type) ([]reflect.Value, error) { args := make([]reflect.Value, 0, ctype.NumIn()) for i := 0; i < ctype.NumIn(); i++ { - arg, err := c.get(ctype.In(i)) + var ( + arg reflect.Value + err error + ) + + t := ctype.In(i) + if t.Implements(_parameterObjectType) { + arg, err = getParameterObject(c, t) + } else { + arg, err = c.get(t) + } + if err != nil { return nil, fmt.Errorf("couldn't get arguments for constructor %v: %v", ctype, err) } @@ -238,17 +260,27 @@ type node struct { deps []reflect.Type } -func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) node { - deps := make([]reflect.Type, ctype.NumIn()) - for i := range deps { - deps[i] = ctype.In(i) +func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) (node, error) { + deps := make([]reflect.Type, 0, ctype.NumIn()) + for i := 0; i < ctype.NumIn(); i++ { + t := ctype.In(i) + if t.Implements(_parameterObjectType) { + pdeps, err := getParameterDependencies(t) + if err != nil { + return node{}, err + } + deps = append(deps, pdeps...) + } else { + deps = append(deps, t) + } } + return node{ provides: provides, ctor: ctor, ctype: ctype, deps: deps, - } + }, nil } func cycleError(cycle []reflect.Type, last reflect.Type) error { @@ -278,3 +310,85 @@ func detectCycles(n node, graph map[reflect.Type]node, path []reflect.Type) erro } return nil } + +// Param is embedded inside structs to opt those structs in as Dig parameter +// objects. +// +// TODO usage docs +type Param struct{} + +var _ parameterObject = Param{} + +// Param is the only instance of parameterObject. +func (Param) parameterObject() {} + +// Users embed the Param struct to opt a struct in as a parameter object. +// Param implements this interface so the struct into which Param is embedded +// also implements this interface. This provides us an easy way to check if +// something embeds Param without iterating through all its fields. +type parameterObject interface { + parameterObject() +} + +// Returns dependencies introduced by a parameter object. +func getParameterDependencies(t reflect.Type) ([]reflect.Type, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + var deps []reflect.Type + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.PkgPath != "" { + continue // skip private fields + } + + // Skip the embedded Param type. + if f.Anonymous && f.Type == _paramType { + continue + } + + // THe user added a parameter object as a dependency. We don't recurse + // /yet/ so let's try to give an informative error message. + if f.Type.Implements(_parameterObjectType) { + return nil, fmt.Errorf( + "dig parameter objects may not be used as fields of other parameter objects: "+ + "field %v (type %v) of %v is a parameter object", f.Name, f.Type, t) + } + + deps = append(deps, f.Type) + } + return deps, nil +} + +func getParameterObject(c *Container, t reflect.Type) (reflect.Value, error) { + dest := reflect.New(t).Elem() + result := dest + for t.Kind() == reflect.Ptr { + t = t.Elem() + dest.Set(reflect.New(t)) + dest = dest.Elem() + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if f.PkgPath != "" { + continue // skip private fields + } + + // Skip the embedded Param type. + if f.Anonymous && f.Type == _paramType { + continue + } + + v, err := c.get(f.Type) + if err != nil { + return result, fmt.Errorf( + "could not get field %v (type %v) of %v: %v", f.Name, f.Type, t, err) + } + + dest.Field(i).Set(v) + } + + return result, nil +} diff --git a/dig_test.go b/dig_test.go index b9b75595..df89a64d 100644 --- a/dig_test.go +++ b/dig_test.go @@ -24,6 +24,7 @@ import ( "bytes" "errors" "io" + "io/ioutil" "testing" "github.com/stretchr/testify/assert" @@ -199,6 +200,78 @@ func TestEndToEndSuccess(t *testing.T) { }), "invoke failed") }) + t.Run("param", func(t *testing.T) { + c := New() + + type contents string + + type Args struct { + Param + + privateContents contents + Contents contents + } + + require.NoError(t, + c.Provide(func(args Args) *bytes.Buffer { + // testify's Empty doesn't work on string aliases for some + // reason + require.Len(t, args.privateContents, 0, "private contents must be empty") + + require.NotEmpty(t, args.Contents, "contents must not be empty") + return bytes.NewBufferString(string(args.Contents)) + }), "provide constructor failed") + + require.NoError(t, + c.Provide(contents("hello world")), + "provide value failed") + + require.NoError(t, c.Invoke(func(buff *bytes.Buffer) { + out, err := ioutil.ReadAll(buff) + require.NoError(t, err, "read from buffer failed") + require.Equal(t, "hello world", string(out), "contents don't match") + })) + }) + + t.Run("param pointer", func(t *testing.T) { + c := New() + + require.NoError(t, c.Provide(func() io.Writer { + return new(bytes.Buffer) + }), "provide failed") + + type Args struct { + Param + + Writer io.Writer + } + + require.NoError(t, c.Invoke(func(args *Args) { + require.NotNil(t, args, "args must not be nil") + require.NotNil(t, args.Writer, "writer must not be nil") + }), "invoke failed") + }) + + t.Run("invoke param", func(t *testing.T) { + c := New() + require.NoError(t, c.Provide(func() *bytes.Buffer { + return new(bytes.Buffer) + }), "provide failed") + + type Args struct { + Param + + privateBuffer *bytes.Buffer + + *bytes.Buffer + } + + require.NoError(t, c.Invoke(func(args Args) { + require.Nil(t, args.privateBuffer, "private buffer must be nil") + require.NotNil(t, args.Buffer, "invoke got nil buffer") + })) + }) + t.Run("multiple-type constructor", func(t *testing.T) { c := New() constructor := func() (*bytes.Buffer, []int, error) { @@ -223,6 +296,36 @@ func TestEndToEndSuccess(t *testing.T) { }) } +func TestParamsDontRecurse(t *testing.T) { + t.Parallel() + + t.Run("param field", func(t *testing.T) { + + type anotherParam struct { + Param + + Reader io.Reader + } + + type someParam struct { + Param + + Writer io.Writer + Another anotherParam + } + + c := New() + err := c.Provide(func(a someParam) *bytes.Buffer { + panic("constructor must not be called") + }) + require.Error(t, err, "provide must fail") + require.Contains(t, err.Error(), + "parameter objects may not be used as fields of other parameter objects") + require.Contains(t, err.Error(), + "field Another (type dig.anotherParam) of dig.someParam is a parameter object") + }) +} + func TestProvideRespectsConstructorErrors(t *testing.T) { t.Run("constructor succeeds", func(t *testing.T) { c := New() @@ -265,6 +368,50 @@ func TestCantProvideErrors(t *testing.T) { assert.NoError(t, c.Provide(errors.New("foo"))) } +func TestCantProvideParameterObjects(t *testing.T) { + t.Parallel() + + t.Run("instance", func(t *testing.T) { + type Args struct{ Param } + + c := New() + err := c.Provide(Args{}) + require.Error(t, err, "provide should fail") + require.Contains(t, err.Error(), "can't provide parameter objects") + }) + + t.Run("pointer", func(t *testing.T) { + type Args struct{ Param } + + c := New() + err := c.Provide(&Args{}) + require.Error(t, err, "provide should fail") + require.Contains(t, err.Error(), "can't provide parameter objects") + }) + + t.Run("constructor", func(t *testing.T) { + type Args struct{ Param } + + c := New() + err := c.Provide(func() (Args, error) { + panic("great sadness") + }) + require.Error(t, err, "provide should fail") + require.Contains(t, err.Error(), "can't provide parameter objects") + }) + + t.Run("constructor pointer", func(t *testing.T) { + type Args struct{ Param } + + c := New() + err := c.Provide(func() (*Args, error) { + panic("great sadness") + }) + require.Error(t, err, "provide should fail") + require.Contains(t, err.Error(), "can't provide parameter objects") + }) +} + func TestProvideKnownTypesFails(t *testing.T) { t.Parallel() c := New()