Skip to content

Commit

Permalink
Merge fa42214 into a2d2066
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinav committed Jun 1, 2017
2 parents a2d2066 + fa42214 commit 3731e03
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 62 deletions.
49 changes: 17 additions & 32 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ var (
_noValue reflect.Value
_errType = reflect.TypeOf((*error)(nil)).Elem()
_parameterObjectType = reflect.TypeOf((*parameterObject)(nil)).Elem()
_paramType = reflect.TypeOf(Param{})
)

// A Container is a directed, acyclic graph of dependencies. Dependencies are
Expand Down Expand Up @@ -173,11 +172,17 @@ func (c *Container) isAcyclic(n node) error {
return detectCycles(n, c.nodes, nil, make(map[reflect.Type]struct{}))
}

// Retrieve a type from the container
func (c *Container) get(t reflect.Type) (reflect.Value, error) {
if v, ok := c.cache[t]; ok {
return v, nil
}

if t.Implements(_parameterObjectType) {
// No caching
return c.createParamObject(t)
}

n, ok := c.nodes[t]
if !ok {
return _noValue, fmt.Errorf("type %v isn't in the container", t)
Expand Down Expand Up @@ -231,18 +236,7 @@ func (c *Container) remove(nodes []node) {
func (c *Container) constructorArgs(ctype reflect.Type) ([]reflect.Value, error) {
args := make([]reflect.Value, 0, ctype.NumIn())
for i := 0; i < ctype.NumIn(); i++ {
var (
arg reflect.Value
err error
)

t := ctype.In(i)
if t.Implements(_parameterObjectType) {
arg, err = c.getParameterObject(t)
} else {
arg, err = c.get(t)
}

arg, err := c.get(ctype.In(i))
if err != nil {
return nil, fmt.Errorf("couldn't get arguments for constructor %v: %v", ctype, err)
}
Expand Down Expand Up @@ -340,25 +334,22 @@ func getParameterDependencies(t reflect.Type) ([]reflect.Type, error) {
continue // skip private fields
}

// Skip the embedded Param type.
if f.Anonymous && f.Type == _paramType {
continue
}

// The user added a parameter object as a dependency. We don't recurse
// /yet/ so let's try to give an informative error message.
if f.Type.Implements(_parameterObjectType) {
return nil, fmt.Errorf(
"dig parameter objects may not be used as fields of other parameter objects: "+
"field %v (type %v) of %v is a parameter object", f.Name, f.Type, t)
newDeps, err := getParameterDependencies(f.Type)
if err != nil {
return nil, err
}
deps = append(deps, newDeps...)
} else {
deps = append(deps, f.Type)
}

deps = append(deps, f.Type)
}
return deps, nil
}

func (c *Container) getParameterObject(t reflect.Type) (reflect.Value, error) {
// Returns a new Param parent object with all the dependency fields
// populated from the dig container.
func (c *Container) createParamObject(t reflect.Type) (reflect.Value, error) {
dest := reflect.New(t).Elem()
result := dest
for t.Kind() == reflect.Ptr {
Expand All @@ -373,11 +364,6 @@ func (c *Container) getParameterObject(t reflect.Type) (reflect.Value, error) {
continue // skip private fields
}

// Skip the embedded Param type.
if f.Anonymous && f.Type == _paramType {
continue
}

v, err := c.get(f.Type)
if err != nil {
if f.Tag.Get("optional") == "true" {
Expand All @@ -390,6 +376,5 @@ func (c *Container) getParameterObject(t reflect.Type) (reflect.Value, error) {

dest.Field(i).Set(v)
}

return result, nil
}
77 changes: 47 additions & 30 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,53 @@ func TestEndToEndSuccess(t *testing.T) {
}))
})

t.Run("param wrapper", func(t *testing.T) {
c := New()
require.NoError(t, c.Provide(func() *bytes.Buffer {
return new(bytes.Buffer)
}), "provide failed")

type MyParam struct{ Param }

type Args struct {
MyParam

Buffer *bytes.Buffer
}

require.NoError(t, c.Invoke(func(args Args) {
require.NotNil(t, args.Buffer, "invoke got nil buffer")
}))
})

t.Run("param recurse", func(t *testing.T) {
type anotherParam struct {
Param

Buffer *bytes.Buffer
}

type someParam struct {
Param

Buffer *bytes.Buffer
Another *anotherParam
}

c := New()
require.NoError(t, c.Provide(func() *bytes.Buffer {
return new(bytes.Buffer)
}), "provide must not fail")

require.NoError(t, c.Invoke(func(p *someParam) {
require.NotNil(t, p, "someParam must not be nil")
require.NotNil(t, p.Buffer, "someParam must not be nil")
require.NotNil(t, p.Another, "someParam must not be nil")
require.NotNil(t, p.Another.Buffer, "someParam must not be nil")
require.True(t, p.Buffer == p.Another.Buffer, "buffer must be the same")
}), "invoke must not fail")
})

t.Run("multiple-type constructor", func(t *testing.T) {
c := New()
constructor := func() (*bytes.Buffer, []int, error) {
Expand Down Expand Up @@ -390,36 +437,6 @@ func TestProvideConstructorErrors(t *testing.T) {
})
}

func TestParamsDontRecurse(t *testing.T) {
t.Parallel()

t.Run("param field", func(t *testing.T) {

type anotherParam struct {
Param

Reader io.Reader
}

type someParam struct {
Param

Writer io.Writer
Another anotherParam
}

c := New()
err := c.Provide(func(a someParam) *bytes.Buffer {
panic("constructor must not be called")
})
require.Error(t, err, "provide must fail")
require.Contains(t, err.Error(),
"parameter objects may not be used as fields of other parameter objects")
require.Contains(t, err.Error(),
"field Another (type dig.anotherParam) of dig.someParam is a parameter object")
})
}

func TestProvideRespectsConstructorErrors(t *testing.T) {
t.Run("constructor succeeds", func(t *testing.T) {
c := New()
Expand Down

0 comments on commit 3731e03

Please sign in to comment.