Skip to content

Commit

Permalink
param: Pointers are not parameter objects (#81)
Browse files Browse the repository at this point in the history
This narrows down the definition of parameter objects to only structs
embedding `dig.Param`. Pointers to structs embedding `dig.Param` are
treated like any other struct pointer.
  • Loading branch information
abhinav committed Jun 1, 2017
1 parent 9d83d4b commit 4ff4da6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 48 deletions.
23 changes: 10 additions & 13 deletions dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (c *Container) provideInstance(val interface{}) error {
if vtype == _errType {
return errors.New("can't provide errors")
}
if vtype.Implements(_parameterObjectType) {
if isParameterObject(vtype) {
return errors.New("can't provide parameter objects")
}
if _, ok := c.nodes[vtype]; ok {
Expand All @@ -135,7 +135,7 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err
// Don't register errors into the container.
continue
}
if rt.Implements(_parameterObjectType) {
if isParameterObject(rt) {
return errors.New("can't provide parameter objects")
}
if _, ok := returnTypes[rt]; ok {
Expand Down Expand Up @@ -180,7 +180,7 @@ func (c *Container) get(t reflect.Type) (reflect.Value, error) {
return v, nil
}

if t.Implements(_parameterObjectType) {
if isParameterObject(t) {
// We do not want parameter objects to be cached.
return c.createParamObject(t)
}
Expand Down Expand Up @@ -270,7 +270,7 @@ func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) (node,

// Retrives the dependencies for the parameter of a constructor.
func getCtorParamDependencies(t reflect.Type) (deps []reflect.Type) {
if !t.Implements(_parameterObjectType) {
if !isParameterObject(t) {
deps = append(deps, t)
return
}
Expand Down Expand Up @@ -334,17 +334,14 @@ type parameterObject interface {
parameterObject()
}

func isParameterObject(t reflect.Type) bool {
return t.Implements(_parameterObjectType) && t.Kind() == reflect.Struct
}

// Returns a new Param parent object with all the dependency fields
// populated from the dig container.
func (c *Container) createParamObject(t reflect.Type) (reflect.Value, error) {
dest := reflect.New(t).Elem()
result := dest
for t.Kind() == reflect.Ptr {
t = t.Elem()
dest.Set(reflect.New(t))
dest = dest.Elem()
}

for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
Expand All @@ -357,12 +354,12 @@ func (c *Container) createParamObject(t reflect.Type) (reflect.Value, error) {
case "true", "yes":
v = reflect.Zero(f.Type)
default:
return result, fmt.Errorf(
return dest, fmt.Errorf(
"could not get field %v (type %v) of %v: %v", f.Name, f.Type, t, err)
}
}

dest.Field(i).Set(v)
}
return result, nil
return dest, nil
}
56 changes: 21 additions & 35 deletions dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,25 +247,6 @@ func TestEndToEndSuccess(t *testing.T) {
}))
})

t.Run("param pointer", func(t *testing.T) {
c := New()

require.NoError(t, c.Provide(func() io.Writer {
return new(bytes.Buffer)
}), "provide failed")

type Args struct {
Param

Writer io.Writer
}

require.NoError(t, c.Invoke(func(args *Args) {
require.NotNil(t, args, "args must not be nil")
require.NotNil(t, args.Writer, "writer must not be nil")
}), "invoke failed")
})

t.Run("invoke param", func(t *testing.T) {
c := New()
require.NoError(t, c.Provide(func() *bytes.Buffer {
Expand Down Expand Up @@ -316,15 +297,15 @@ func TestEndToEndSuccess(t *testing.T) {
Param

Buffer *bytes.Buffer
Another *anotherParam
Another anotherParam
}

c := New()
require.NoError(t, c.Provide(func() *bytes.Buffer {
return new(bytes.Buffer)
}), "provide must not fail")

require.NoError(t, c.Invoke(func(p *someParam) {
require.NoError(t, c.Invoke(func(p someParam) {
require.NotNil(t, p, "someParam must not be nil")
require.NotNil(t, p.Buffer, "someParam must not be nil")
require.NotNil(t, p.Another, "someParam must not be nil")
Expand Down Expand Up @@ -502,10 +483,9 @@ func TestCanProvideErrorLikeType(t *testing.T) {
c := New()
require.NoError(t, c.Provide(tt), "provide must not fail")

require.NoError(t, c.Invoke(
func(err *someError) {
assert.NotNil(t, err, "invoke received nil")
}), "invoke must not fail")
require.NoError(t, c.Invoke(func(err *someError) {
assert.NotNil(t, err, "invoke received nil")
}), "invoke must not fail")
})
}
}
Expand All @@ -525,10 +505,13 @@ func TestCantProvideParameterObjects(t *testing.T) {
t.Run("pointer", func(t *testing.T) {
type Args struct{ Param }

args := &Args{}

c := New()
err := c.Provide(&Args{})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
require.NoError(t, c.Provide(args), "provide failed")
require.NoError(t, c.Invoke(func(a *Args) {
require.True(t, args == a, "args must match")
}), "invoke failed")
})

t.Run("constructor", func(t *testing.T) {
Expand All @@ -542,15 +525,18 @@ func TestCantProvideParameterObjects(t *testing.T) {
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("constructor pointer", func(t *testing.T) {
t.Run("pointer from constructor", func(t *testing.T) {
type Args struct{ Param }

args := &Args{}

c := New()
err := c.Provide(func() (*Args, error) {
panic("great sadness")
})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
require.NoError(t, c.Provide(func() (*Args, error) {
return args, nil
}), "provide failed")
require.NoError(t, c.Invoke(func(a *Args) {
require.True(t, args == a, "args must match")
}), "invoke failed")
})
}

Expand Down Expand Up @@ -684,7 +670,7 @@ func TestInvokeFailures(t *testing.T) {
}

c := New()
err := c.Invoke(func(a *args) {
err := c.Invoke(func(a args) {
t.Fatal("function must not be called")
})

Expand Down

0 comments on commit 4ff4da6

Please sign in to comment.