Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Parameter Objects #71

Merged
merged 3 commits into from Jun 1, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
132 changes: 123 additions & 9 deletions dig.go
Expand Up @@ -28,8 +28,10 @@ import (
)

var (
_noValue reflect.Value
_errType = reflect.TypeOf((*error)(nil)).Elem()
_noValue reflect.Value
_errType = reflect.TypeOf((*error)(nil)).Elem()
_parameterObjectType = reflect.TypeOf((*parameterObject)(nil)).Elem()
_paramType = reflect.TypeOf(Param{})
)

// A Container is a directed, acyclic graph of dependencies. Dependencies are
Expand Down Expand Up @@ -113,6 +115,9 @@ func (c *Container) provideInstance(val interface{}) error {
if vtype == _errType {
return errors.New("can't provide errors")
}
if vtype.Implements(_parameterObjectType) {
return errors.New("can't provide parameter objects")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameter term is super generic. Can this error be more explicit? can't provide parameterized object

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been using the term "parameter object" for types that embed dig.Param in all the documentation and errors so far. That's the standard name for objects which represent all the parameter of a function, created solely to avoid having too many arguments on the function. See http://wiki.c2.com/?ParameterObject, https://refactoring.com/catalog/introduceParameterObject.html.

We can call this something else, maybe give it a dig-specific name, but "parameterized object" has a whole different meaning.

}
if _, ok := c.nodes[vtype]; ok {
return errors.New("already in container")
}
Expand All @@ -129,6 +134,9 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err
// Don't register errors into the container.
continue
}
if rt.Implements(_parameterObjectType) {
return errors.New("can't provide parameter objects")
}
if _, ok := returnTypes[rt]; ok {
return fmt.Errorf("returns multiple %v", rt)
}
Expand All @@ -143,7 +151,10 @@ func (c *Container) provideConstructor(ctor interface{}, ctype reflect.Type) err

nodes := make([]node, 0, len(returnTypes))
for rt := range returnTypes {
n := newNode(rt, ctor, ctype)
n, err := newNode(rt, ctor, ctype)
if err != nil {
return err
}
nodes = append(nodes, n)
c.nodes[rt] = n
}
Expand Down Expand Up @@ -220,7 +231,18 @@ func (c *Container) remove(nodes []node) {
func (c *Container) constructorArgs(ctype reflect.Type) ([]reflect.Value, error) {
args := make([]reflect.Value, 0, ctype.NumIn())
for i := 0; i < ctype.NumIn(); i++ {
arg, err := c.get(ctype.In(i))
var (
arg reflect.Value
err error
)

t := ctype.In(i)
if t.Implements(_parameterObjectType) {
arg, err = c.getParameterObject(t)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: getting parameterObject can be part of c.get(t). Then you don't need to declare variables and assign. implementation wise we just want an object back from container.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call. Updating the other PR with this.

} else {
arg, err = c.get(t)
}

if err != nil {
return nil, fmt.Errorf("couldn't get arguments for constructor %v: %v", ctype, err)
}
Expand All @@ -236,17 +258,27 @@ type node struct {
deps []reflect.Type
}

func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) node {
deps := make([]reflect.Type, ctype.NumIn())
for i := range deps {
deps[i] = ctype.In(i)
func newNode(provides reflect.Type, ctor interface{}, ctype reflect.Type) (node, error) {
deps := make([]reflect.Type, 0, ctype.NumIn())
for i := 0; i < ctype.NumIn(); i++ {
t := ctype.In(i)
if t.Implements(_parameterObjectType) {
pdeps, err := getParameterDependencies(t)
if err != nil {
return node{}, err
}
deps = append(deps, pdeps...)
} else {
deps = append(deps, t)
}
}

return node{
provides: provides,
ctor: ctor,
ctype: ctype,
deps: deps,
}
}, nil
}

func cycleError(cycle []reflect.Type, last reflect.Type) error {
Expand Down Expand Up @@ -275,3 +307,85 @@ func detectCycles(n node, graph map[reflect.Type]node, path []reflect.Type, seen
}
return nil
}

// Param is embedded inside structs to opt those structs in as Dig parameter
// objects.
//
// TODO usage docs
type Param struct{}

var _ parameterObject = Param{}

// Param is the only instance of parameterObject.
func (Param) parameterObject() {}

// Users embed the Param struct to opt a struct in as a parameter object.
// Param implements this interface so the struct into which Param is embedded
// also implements this interface. This provides us an easy way to check if
// something embeds Param without iterating through all its fields.
type parameterObject interface {
parameterObject()
}

// Returns dependencies introduced by a parameter object.
func getParameterDependencies(t reflect.Type) ([]reflect.Type, error) {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}

var deps []reflect.Type
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: refactor? pkgPath and Anonymous is been checked in multiple places.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately that's unavoidable. These are two different entry points. This check is necessary to make sure we don't muck with private fields.

continue // skip private fields
}

// Skip the embedded Param type.
if f.Anonymous && f.Type == _paramType {
continue
}

// The user added a parameter object as a dependency. We don't recurse
// /yet/ so let's try to give an informative error message.
if f.Type.Implements(_parameterObjectType) {
return nil, fmt.Errorf(
"dig parameter objects may not be used as fields of other parameter objects: "+
"field %v (type %v) of %v is a parameter object", f.Name, f.Type, t)
}

deps = append(deps, f.Type)
}
return deps, nil
}

func (c *Container) getParameterObject(t reflect.Type) (reflect.Value, error) {
dest := reflect.New(t).Elem()
result := dest
for t.Kind() == reflect.Ptr {
t = t.Elem()
dest.Set(reflect.New(t))
dest = dest.Elem()
}

for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.PkgPath != "" {
continue // skip private fields
}

// Skip the embedded Param type.
if f.Anonymous && f.Type == _paramType {
continue
}

v, err := c.get(f.Type)
if err != nil {
return result, fmt.Errorf(
"could not get field %v (type %v) of %v: %v", f.Name, f.Type, t, err)
}

dest.Field(i).Set(v)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check if the field is settable? (In case dig will support non-pointer types)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The field will always be settable because the struct that contains it is addressable (because we created it with reflect.New).

}

return result, nil
}
147 changes: 147 additions & 0 deletions dig_test.go
Expand Up @@ -25,6 +25,7 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -213,6 +214,78 @@ func TestEndToEndSuccess(t *testing.T) {
}), "invoke failed")
})

t.Run("param", func(t *testing.T) {
c := New()

type contents string

type Args struct {
Param

privateContents contents
Contents contents
}

require.NoError(t,
c.Provide(func(args Args) *bytes.Buffer {
// testify's Empty doesn't work on string aliases for some
// reason
require.Len(t, args.privateContents, 0, "private contents must be empty")

require.NotEmpty(t, args.Contents, "contents must not be empty")
return bytes.NewBufferString(string(args.Contents))
}), "provide constructor failed")

require.NoError(t,
c.Provide(contents("hello world")),
"provide value failed")

require.NoError(t, c.Invoke(func(buff *bytes.Buffer) {
out, err := ioutil.ReadAll(buff)
require.NoError(t, err, "read from buffer failed")
require.Equal(t, "hello world", string(out), "contents don't match")
}))
})

t.Run("param pointer", func(t *testing.T) {
c := New()

require.NoError(t, c.Provide(func() io.Writer {
return new(bytes.Buffer)
}), "provide failed")

type Args struct {
Param

Writer io.Writer
}

require.NoError(t, c.Invoke(func(args *Args) {
require.NotNil(t, args, "args must not be nil")
require.NotNil(t, args.Writer, "writer must not be nil")
}), "invoke failed")
})

t.Run("invoke param", func(t *testing.T) {
c := New()
require.NoError(t, c.Provide(func() *bytes.Buffer {
return new(bytes.Buffer)
}), "provide failed")

type Args struct {
Param

privateBuffer *bytes.Buffer

*bytes.Buffer
}

require.NoError(t, c.Invoke(func(args Args) {
require.Nil(t, args.privateBuffer, "private buffer must be nil")
require.NotNil(t, args.Buffer, "invoke got nil buffer")
}))
})

t.Run("multiple-type constructor", func(t *testing.T) {
c := New()
constructor := func() (*bytes.Buffer, []int, error) {
Expand Down Expand Up @@ -290,6 +363,36 @@ func TestProvideConstructorErrors(t *testing.T) {
})
}

func TestParamsDontRecurse(t *testing.T) {
t.Parallel()

t.Run("param field", func(t *testing.T) {

type anotherParam struct {
Param

Reader io.Reader
}

type someParam struct {
Param

Writer io.Writer
Another anotherParam
}

c := New()
err := c.Provide(func(a someParam) *bytes.Buffer {
panic("constructor must not be called")
})
require.Error(t, err, "provide must fail")
require.Contains(t, err.Error(),
"parameter objects may not be used as fields of other parameter objects")
require.Contains(t, err.Error(),
"field Another (type dig.anotherParam) of dig.someParam is a parameter object")
})
}

func TestProvideRespectsConstructorErrors(t *testing.T) {
t.Run("constructor succeeds", func(t *testing.T) {
c := New()
Expand Down Expand Up @@ -360,6 +463,50 @@ func TestCanProvideErrorLikeType(t *testing.T) {
}
}

func TestCantProvideParameterObjects(t *testing.T) {
t.Parallel()

t.Run("instance", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(Args{})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("pointer", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(&Args{})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("constructor", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(func() (Args, error) {
panic("great sadness")
})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})

t.Run("constructor pointer", func(t *testing.T) {
type Args struct{ Param }

c := New()
err := c.Provide(func() (*Args, error) {
panic("great sadness")
})
require.Error(t, err, "provide should fail")
require.Contains(t, err.Error(), "can't provide parameter objects")
})
}

func TestProvideKnownTypesFails(t *testing.T) {
t.Parallel()

Expand Down