Skip to content

Commit

Permalink
simplified invoke validation
Browse files Browse the repository at this point in the history
  • Loading branch information
anuptalwalkar committed May 18, 2017
1 parent 427b672 commit b59414b
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
2 changes: 1 addition & 1 deletion container.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (c *Container) InvokeOnce(t interface{}) error {
if ctype.Kind() != reflect.Func {
return errors.Wrapf(errParamType, _invokeErr, ctype)
}
if err := c.ValidateReturnTypes(ctype, true); err != nil {
if err := c.ValidateInvokeReturnTypes(ctype); err != nil {
return ErrInvokeOnce
}
return c.Invoke(t)
Expand Down
45 changes: 31 additions & 14 deletions internal/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (g *Graph) InsertConstructor(ctor interface{}) error {
objTypes[i] = ctype.Out(i)
}

if err := g.ValidateReturnTypes(ctype, false); err != nil {
if err := g.validateCtorReturnTypes(ctype); err != nil {
return err
}

Expand Down Expand Up @@ -137,25 +137,42 @@ func (g *Graph) InsertConstructor(ctor interface{}) error {
return nil
}

// ValidateReturnTypes validates if Invoke func's return type is Provided to the graph
// checkCachedObjects=true ensures to throw an error only when the node is already
// resolved and cached.
//
func (g *Graph) ValidateReturnTypes(ctype reflect.Type, checkCachedObjects bool) error {
objMap := make(map[reflect.Type]bool, ctype.NumOut())
// ValidateInvokeReturnTypes validates Invoke return types and returns an error
// if the grah node of return type is resolved and cached.
func (g *Graph) ValidateInvokeReturnTypes(ctype reflect.Type) error {
if err := g.checkDuplicateReturns(ctype); err != nil {
return err
}
for i := 0; i < ctype.NumOut(); i++ {
objType := ctype.Out(i)
if _, ok := g.nodes[objType]; ok {
// for Invoke validation, graphNode may not be resolved (still in the form of funcNode).
// checkCachedObjects ensures to throw an error only when the node is resolved and cached.
if checkCachedObjects {
if obj, ok := g.nodes[objType].(*objNode); ok && obj.cached {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
} else {
if obj, ok := g.nodes[objType].(*objNode); ok && obj.cached {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
}
}
return nil
}

// validateCtorReturnTypes validates provided constructor and returns error
// when graph node for ctor return type is not already provided
func (g *Graph) validateCtorReturnTypes(ctype reflect.Type) error {
if err := g.checkDuplicateReturns(ctype); err != nil {
return err
}
for i := 0; i < ctype.NumOut(); i++ {
objType := ctype.Out(i)
if _, ok := g.nodes[objType]; ok {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
}
return nil
}

func (g *Graph) checkDuplicateReturns(ctype reflect.Type) error {
objMap := make(map[reflect.Type]bool, ctype.NumOut())
for i := 0; i < ctype.NumOut(); i++ {
objType := ctype.Out(i)
if objMap[objType] {
return errors.Wrapf(errRetNode, "ctor: %v, object type: %v", ctype, objType)
}
Expand Down
11 changes: 8 additions & 3 deletions internal/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,22 @@ func TestCtorConflicts(t *testing.T) {
require.Contains(t, err.Error(), "ctor: func() (*graph.Child1, *graph.Child1, error), object type: *graph.Child1: node already exist for the constructor")
}

func TestConstructorOverrideReturnsError(t *testing.T) {
func TestCtorOverrideReturnsError(t *testing.T) {
t.Parallel()
g := NewGraph()

err := g.InsertConstructor(threeObjects)
require.NoError(t, err)
err = g.ValidateReturnTypes(reflect.TypeOf(oneObject), false)
err = g.validateCtorReturnTypes(reflect.TypeOf(oneObject))
require.Contains(t, err.Error(), "ctor: func() (*graph.Child1, error), object type: *graph.Child1: node already exist for the constructor")
}

func TestInvokeOverrideReturnsError(t *testing.T) {
t.Parallel()
g := NewGraph()

g.InsertObject(reflect.ValueOf(&Child1{}))
err = g.ValidateReturnTypes(reflect.TypeOf(oneObject), true)
err := g.ValidateInvokeReturnTypes(reflect.TypeOf(oneObject))
require.Contains(t, err.Error(), "ctor: func() (*graph.Child1, error), object type: *graph.Child1: node already exist for the constructor")
}

Expand Down

0 comments on commit b59414b

Please sign in to comment.