Skip to content

Commit

Permalink
Add InvokeOnce functionality to container
Browse files Browse the repository at this point in the history
  • Loading branch information
anuptalwalkar committed May 17, 2017
1 parent d88cbc9 commit 16581fc
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 5 deletions.
17 changes: 17 additions & 0 deletions container.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ var (
errReturnKind = errors.New("constructor return type must be a pointer")
errArgKind = errors.New("constructor arguments must be pointers")

// ErrInvokeOnce is returned if function is already invoked
ErrInvokeOnce = errors.New("function is already invoked")

_typeOfError = reflect.TypeOf((*error)(nil)).Elem()

_forCtor = "for constructor %v"
Expand All @@ -50,6 +53,20 @@ type Container struct {
graph.Graph
}

// InvokeOnce only allows function invokation once to register the
// return types. If return types are already registered, specific error
// is returned to the caller
func (c *Container) InvokeOnce(t interface{}) error {
ctype := reflect.TypeOf(t)
if ctype.Kind() != reflect.Func {
return errors.Wrapf(errParamType, _invokeErr, ctype)
}
if err := c.ValidateReturnTypes(ctype, true); err != nil {
return ErrInvokeOnce
}
return c.Invoke(t)
}

// Invoke the function and resolve the dependencies immidiately without providing the
// constructor to the graph. The Invoke function returns error object which can be
// occurred during the execution
Expand Down
26 changes: 26 additions & 0 deletions container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,32 @@ func TestInvokeSuccess(t *testing.T) {
assert.NotNil(t, c1)
}

func TestInvokeOnce(t *testing.T) {
t.Parallel()
c := New()

err := c.Provide(
NewParent1,
NewChild1,
NewGrandchild1,
)
assert.NoError(t, err)
var c1 *Child1

err = c.InvokeOnce(NewParent1)
assert.NoError(t, err)

err = c.InvokeOnce(NewParent1)
assert.Equal(t, ErrInvokeOnce, err)

err = c.InvokeOnce(func(p1 *Parent1) {
require.NotNil(t, p1)
c1 = p1.c1
})
assert.NoError(t, err)
assert.NotNil(t, c1)
}

func TestInvokeAndRegisterSuccess(t *testing.T) {
t.Parallel()
c := New()
Expand Down
22 changes: 17 additions & 5 deletions internal/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (g *Graph) InsertConstructor(ctor interface{}) error {
objTypes[i] = ctype.Out(i)
}

if err := g.ValidateReturnTypes(ctype); err != nil {
if err := g.ValidateReturnTypes(ctype, false); err != nil {
return err
}

Expand Down Expand Up @@ -137,22 +137,34 @@ func (g *Graph) InsertConstructor(ctor interface{}) error {
return nil
}

// ValidateReturnTypes validates if ctor's return type is already insterted in the graph
func (g *Graph) ValidateReturnTypes(ctype reflect.Type) error {
// ValidateReturnTypes validates if Invoke func's return type is Provided to the graph
// checkCachedObjects=true additionally checks if objects are resolved and cached in the graph
func (g *Graph) ValidateReturnTypes(ctype reflect.Type, checkCachedObjects bool) error {
objMap := make(map[reflect.Type]bool, ctype.NumOut())
for i := 0; i < ctype.NumOut(); i++ {
objType := ctype.Out(i)
if _, ok := g.nodes[objType]; ok {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, ctype.Out(i))
if checkCachedObjects {
if g.checkCachedObjects(objType) {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
} else {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
}
if objMap[objType] {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, ctype.Out(i))
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
objMap[objType] = true
}
return nil
}

func (g *Graph) checkCachedObjects(objType reflect.Type) bool {
obj, ok := g.nodes[objType].(*objNode)
return ok && obj.cached
}

// DFS and tracking if same node is visited twice
func (g *Graph) recursiveDetectCycles(n graphNode, l []string) error {
for _, el := range l {
Expand Down
14 changes: 14 additions & 0 deletions internal/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ func TestCtorConflicts(t *testing.T) {
require.Contains(t, err.Error(), "ctor: func() (*graph.Child1, *graph.Child1, error), object type: *graph.Child1: node already exist for the constructor")
}

func TestValidateReturnTypes(t *testing.T) {
t.Parallel()
g := NewGraph()

err := g.InsertConstructor(threeObjects)
require.NoError(t, err)
err = g.ValidateReturnTypes(reflect.TypeOf(oneObject), false)
require.Contains(t, err.Error(), "ctor: func() (*graph.Child1, error), object type: *graph.Child1: node already exist for the constructor")

g.InsertObject(reflect.ValueOf(&Child1{}))
err = g.ValidateReturnTypes(reflect.TypeOf(oneObject), true)
require.Contains(t, err.Error(), "ctor: func() (*graph.Child1, error), object type: *graph.Child1: node already exist for the constructor")
}

func TestMultiObjectRegisterResolve(t *testing.T) {
t.Parallel()
g := NewGraph()
Expand Down

0 comments on commit 16581fc

Please sign in to comment.