Skip to content

Commit

Permalink
Minimal parameter objects implementation
Browse files Browse the repository at this point in the history
This adds support for parameter objects to dig. Structs opt-in to be
treated as parameter objects by embedding `dig.Param`.

For the sake of simplicity from the users' point-of-view, we don't
recurse into parameter objects inside other parameter objects. This can
be changed in the future if there's demand for it.
  • Loading branch information
abhinav committed May 26, 2017
1 parent 13204cd commit 80fdf72
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 9 deletions.
132 changes: 123 additions & 9 deletions dig.go
Expand Up @@ -28,8 +28,10 @@ import (
)

var (
_noValue reflect.Value
_errType = reflect.TypeOf((*error)(nil)).Elem()
_noValue reflect.Value
_errType = reflect.TypeOf((*error)(nil)).Elem()
_parameterObjectType = reflect.TypeOf((*parameterObject)(nil)).Elem()
_paramType = reflect.TypeOf(Param{})
)

// A Container is a directed, acyclic graph of dependencies. Dependencies are
Expand Down Expand Up @@ -110,6 +112,9 @@ func (c *Container) provideInstance(val interface{}) error {
if vtype == _errType {
return errors.New("can't provide errors")
}
if vtype.Implements(_parameterObjectType) {
return errors.New("can't provide parameter objects")
}
if _, ok := c.nodes[vtype]; ok {
return errors.New("already in container")
}
Expand All @@ -126,6 +131,9 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err
// Don't register errors into the container.
continue
}
if rt.Implements(_parameterObjectType) {
return errors.New("can't provide parameter objects")
}
if _, ok := returnTypes[rt]; ok {
return fmt.Errorf("returns multiple %v", rt)
}
Expand All @@ -140,7 +148,10 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err

nodes := make([]node, 0, len(returnTypes))
for rt := range returnTypes {
n := newNode(rt, ctor, ctype)
n, err := newNode(rt, ctor, ctype)
if err != nil {
return err
}
nodes = append(nodes, n)
c.nodes[rt] = n
}
Expand Down Expand Up @@ -222,7 +233,18 @@ func (c *Container) remove(nodes []node) {
func (c *Container) constructorArgs(ctype reflect.Type) ([]reflect.Value, error) {
args := make([]reflect.Value, 0, ctype.NumIn())
for i := 0; i < ctype.NumIn(); i++ {
arg, err := c.get(ctype.In(i))
var (
arg reflect.Value
err error
)

t := ctype.In(i)
if t.Implements(_parameterObjectType) {
arg, err = getParameterObject(c, t)
} else {
arg, err = c.get(t)
}

if err != nil {
return nil, fmt.Errorf("couldn't get arguments for constructor %v: %v", ctype, err)
}
Expand All @@ -238,17 +260,27 @@ type node struct {
deps []reflect.Type
}

func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) node {
deps := make([]reflect.Type, ctype.NumIn())
for i := range deps {
deps[i] = ctype.In(i)
func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) (node, error) {
deps := make([]reflect.Type, 0, ctype.NumIn())
for i := 0; i < ctype.NumIn(); i++ {
t := ctype.In(i)
if t.Implements(_parameterObjectType) {
pdeps, err := getParameterDependencies(t)
if err != nil {
return node{}, err
}
deps = append(deps, pdeps...)
} else {
deps = append(deps, t)
}
}

return node{
provides: provides,
ctor: ctor,
ctype: ctype,
deps: deps,
}
}, nil
}

func cycleError(cycle []reflect.Type, last reflect.Type) error {
Expand Down Expand Up @@ -278,3 +310,85 @@ func detectCycles(n node, graph map[reflect.Type]node, path []reflect.Type) erro
}
return nil
}

// Param is embedded inside structs to opt those structs in as Dig parameter
// objects.
//
// TODO usage docs
type Param struct{}

var _ parameterObject = Param{}

// Param is the only instance of parameterObject.
func (Param) parameterObject() {}

// Users embed the Param struct to opt a struct in as a parameter object.
// Param implements this interface so the struct into which Param is embedded
// also implements this interface. This provides us an easy way to check if
// something embeds Param without iterating through all its fields.
type parameterObject interface {
parameterObject()
}

// Returns dependencies introduced by a parameter object.
func getParameterDependencies(t reflect.Type) ([]reflect.Type, error) {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}

var deps []reflect.Type
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue // skip private fields
}

// Skip the embedded Param type.
if f.Anonymous && f.Type == _paramType {
continue
}

// THe user added a parameter object as a dependency. We don't recurse
// /yet/ so let's try to give an informative error message.
if f.Type.Implements(_parameterObjectType) {
return nil, fmt.Errorf(
"dig parameter objects may not be used as fields of other parameter objects: "+
"field %v (type %v) of %v is a parameter object", f.Name, f.Type, t)
}

deps = append(deps, f.Type)
}
return deps, nil
}

func getParameterObject(c *Container, t reflect.Type) (reflect.Value, error) {
dest := reflect.New(t).Elem()
result := dest
for t.Kind() == reflect.Ptr {
t = t.Elem()
dest.Set(reflect.New(t))
dest = dest.Elem()
}

for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue // skip private fields
}

// Skip the embedded Param type.
if f.Anonymous && f.Type == _paramType {
continue
}

v, err := c.get(f.Type)
if err != nil {
return result, fmt.Errorf(
"could not get field %v (type %v) of %v: %v", f.Name, f.Type, t, err)
}

dest.Field(i).Set(v)
}

return result, nil
}
147 changes: 147 additions & 0 deletions dig_test.go
Expand Up @@ -24,6 +24,7 @@ import (
"bytes"
"errors"
"io"
"io/ioutil"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -199,6 +200,78 @@ func TestEndToEndSuccess(t *testing.T) {
}), "invoke failed")
})

t.Run("param", func(t *testing.T) {
c := New()

type contents string

type Args struct {
Param

privateContents contents
Contents contents
}

require.NoError(t,
c.Provide(func(args Args) *bytes.Buffer {
// testify's Empty doesn't work on string aliases for some
// reason
require.Len(t, args.privateContents, 0, "private contents must be empty")

require.NotEmpty(t, args.Contents, "contents must not be empty")
return bytes.NewBufferString(string(args.Contents))
}), "provide constructor failed")

require.NoError(t,
c.Provide(contents("hello world")),
"provide value failed")

require.NoError(t, c.Invoke(func(buff *bytes.Buffer) {
out, err := ioutil.ReadAll(buff)
require.NoError(t, err, "read from buffer failed")
require.Equal(t, "hello world", string(out), "contents don't match")
}))
})

t.Run("param pointer", func(t *testing.T) {
c := New()

require.NoError(t, c.Provide(func() io.Writer {
return new(bytes.Buffer)
}), "provide failed")

type Args struct {
Param

Writer io.Writer
}

require.NoError(t, c.Invoke(func(args *Args) {
require.NotNil(t, args, "args must not be nil")
require.NotNil(t, args.Writer, "writer must not be nil")
}), "invoke failed")
})

t.Run("invoke param", func(t *testing.T) {
c := New()
require.NoError(t, c.Provide(func() *bytes.Buffer {
return new(bytes.Buffer)
}), "provide failed")

type Args struct {
Param

privateBuffer *bytes.Buffer

*bytes.Buffer
}

require.NoError(t, c.Invoke(func(args Args) {
require.Nil(t, args.privateBuffer, "private buffer must be nil")
require.NotNil(t, args.Buffer, "invoke got nil buffer")
}))
})

t.Run("multiple-type constructor", func(t *testing.T) {
c := New()
constructor := func() (*bytes.Buffer, []int, error) {
Expand All @@ -223,6 +296,36 @@ func TestEndToEndSuccess(t *testing.T) {
})
}

func TestParamsDontRecurse(t *testing.T) {
t.Parallel()

t.Run("param field", func(t *testing.T) {

type anotherParam struct {
Param

Reader io.Reader
}

type someParam struct {
Param

Writer io.Writer
Another anotherParam
}

c := New()
err := c.Provide(func(a someParam) *bytes.Buffer {
panic("constructor must not be called")
})
require.Error(t, err, "provide must fail")
require.Contains(t, err.Error(),
"parameter objects may not be used as fields of other parameter objects")
require.Contains(t, err.Error(),
"field Another (type dig.anotherParam) of dig.someParam is a parameter object")
})
}

func TestProvideRespectsConstructorErrors(t *testing.T) {
t.Run("constructor succeeds", func(t *testing.T) {
c := New()
Expand Down Expand Up @@ -265,6 +368,50 @@ func TestCantProvideErrors(t *testing.T) {
assert.NoError(t, c.Provide(errors.New("foo")))
}

func TestCantProvideParameterObjects(t *testing.T) {
t.Parallel()

t.Run("instance", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(Args{})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("pointer", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(&Args{})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("constructor", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(func() (Args, error) {
panic("great sadness")
})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("constructor pointer", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(func() (*Args, error) {
panic("great sadness")
})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})
}

func TestProvideKnownTypesFails(t *testing.T) {
t.Parallel()
c := New()
Expand Down

0 comments on commit 80fdf72

Please sign in to comment.