Skip to content

Commit

Permalink
Prevent self-referential dependency recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul M Fox committed Sep 12, 2015
1 parent 4837b88 commit 547bfba
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 7 deletions.
41 changes: 41 additions & 0 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,44 @@ func assertConcreteValue(c ConcreteType, t *testing.T) {
t.Errorf("Test Stringer: got %s, expected %s", g, e)
}
}

//////////////////////////////////////////
// Self-referential, ouroboros types
//////////////////////////////////////////

type Valuer interface {
Value() int
}

// Self-referential valuer #1
type Ouroboros1 struct {
A Valuer `inj:""`
B Valuer `inj:""`
V int
}

func (o Ouroboros1) Value() int { return o.V }

// Self-referential valuer #2
type Ouroboros2 struct {
A Valuer `inj:""`
B Valuer `inj:""`
V int
}

func (o Ouroboros2) Value() int { return o.V }

// Self-referential valuer #3
type Ouroboros3 struct {
V int
}

func (o Ouroboros3) Value() int { return o.V }

// Self-referential valuer #4
type Ouroboros4 struct {
Ouroboros3 `inj:""`
V int
}

func (o Ouroboros4) Value() int { return o.V }
24 changes: 21 additions & 3 deletions graph_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ func (g *Graph) connect() {

func (g *Graph) assignValueToNode(o reflect.Value, dep GraphNodeDependency) error {

v, err := g.findFieldValue(o, dep.Path)
parents := []reflect.Value{}
v, err := g.findFieldValue(o, dep.Path, &parents)

if err != nil {
return err
Expand All @@ -47,6 +48,21 @@ func (g *Graph) assignValueToNode(o reflect.Value, dep GraphNodeDependency) erro
// Run through the graph and see if anything is settable
for typ, node := range g.Nodes {

valid := true

// Don't assign anything to itself or its children
for _, parent := range parents {

if parent.Interface() == node.Value.Interface() {
valid = false
break
}
}

if !valid {
continue
}

if typ.AssignableTo(v.Type()) {
v.Set(node.Value)
return nil
Expand All @@ -57,7 +73,9 @@ func (g *Graph) assignValueToNode(o reflect.Value, dep GraphNodeDependency) erro
}

// Required a struct type
func (g *Graph) findFieldValue(parent reflect.Value, path StructPath) (reflect.Value, error) {
func (g *Graph) findFieldValue(parent reflect.Value, path StructPath, linneage *[]reflect.Value) (reflect.Value, error) {

*linneage = append(*linneage, parent)

// Dereference incoming values
if parent.Kind() == reflect.Ptr {
Expand Down Expand Up @@ -85,5 +103,5 @@ func (g *Graph) findFieldValue(parent reflect.Value, path StructPath) (reflect.V
}

// Otherwise recurse
return g.findFieldValue(f, path)
return g.findFieldValue(f, path, linneage)
}
6 changes: 3 additions & 3 deletions graph_connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func Test_ConnectFindFieldValue(t *testing.T) {
}

for _, d := range descs {
rv, e := g.findFieldValue(v, d.path)
rv, e := g.findFieldValue(v, d.path, &[]reflect.Value{})

if e != nil {
t.Errorf("findFieldValue: %s", e.Error())
Expand All @@ -274,7 +274,7 @@ func Test_ConnectF(t *testing.T) {

g := NewGraph()

_, e := g.findFieldValue(reflect.ValueOf("123"), ".Child1")
_, e := g.findFieldValue(reflect.ValueOf("123"), ".Child1", &[]reflect.Value{})

if e == nil {
fmt.Errorf("Didn't error when type wasn't struct")
Expand All @@ -286,7 +286,7 @@ func Test_ConnectG(t *testing.T) {

g, p := NewGraph(), &validConnectTester{}

_, e := g.findFieldValue(reflect.ValueOf(p), ".This.Doesnt.Exist")
_, e := g.findFieldValue(reflect.ValueOf(p), ".This.Doesnt.Exist", &[]reflect.Value{})

if e == nil {
fmt.Errorf("Didn't error when path was wrong")
Expand Down
100 changes: 99 additions & 1 deletion graph_provide_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func Test_ProvideOverride1(t *testing.T) {
}
}

// Embedded structs should parsed like any other
// Embedded structs should be parsed like any other
func Test_EmbeddedStructProvision(t *testing.T) {

g := NewGraph()
Expand All @@ -132,6 +132,104 @@ func Test_EmbeddedStructProvision(t *testing.T) {
}
}

// Self-referential dependencies shouldn't be assigned
func Test_SelfReferencingDoesntHappen(t *testing.T) {

g := NewGraph()

o := Ouroboros1{}

g.Provide(&o)

valid, errs := g.Assert()

if valid {
t.Fatalf("g.Assert() is valid when it shouldn't be")
}

// There are two deps that should have missed
if g, e := len(errs), 2; g != e {
t.Fatalf("Expected %d errors, got %d (%v)", e, g, errs)
}
}

// Self-referential dependencies shouldn't impede proper injection
func Test_SelfReferencingShouldntCircumentInjection(t *testing.T) {

g := NewGraph()

o1 := Ouroboros1{V: 1}
o2 := Ouroboros2{V: 2}

g.Provide(&o1, &o2)

valid, errs := g.Assert()

if !valid {
t.Fatalf("g.Assert() is not valid when it should be (%v)", errs)
}

// The values should now be 'crossed'
if o1.A.Value() != o1.B.Value() {
t.Errorf("o1.A != o1.B")
}

if o1.A.Value() != o2.Value() {
t.Errorf("o1.B and B aren't equal to o2")
}

if o2.A.Value() != o2.B.Value() {
t.Errorf("o2.A != o2.B")
}

if o2.A.Value() != o1.Value() {
t.Errorf("o2.B and B aren't equal to o1")
}
}

// Self-referential prevention must extend to embedding
func Test_EmbeddedSelfReferencingDoesntHappen(t *testing.T) {

g := NewGraph()

o := Ouroboros4{}

g.Provide(&o)

valid, errs := g.Assert()

if valid {
t.Fatalf("g.Assert() is valid when it shouldn't be")
}

// There is one dep that should have missed
if g, e := len(errs), 1; g != e {
t.Fatalf("Expected %d error, got %d (%v)", e, g, errs)
}
}

// Self-referential dependencies shouldn't impede proper injection
func Test_EmbeddedSelfReferencingShouldntCircumentInjection(t *testing.T) {

g := NewGraph()

o1 := Ouroboros3{V: 1}
o2 := Ouroboros4{V: 2}

g.Provide(o1, &o2)

valid, errs := g.Assert()

if !valid {
t.Fatalf("g.Assert() is not valid when it should be (%v)", errs)
}

// The value should now be assigned
if o2.Ouroboros3.Value() != o1.Value() {
t.Errorf("o2.A != o2.B")
}
}

//////////////////////////////////////////
// Benchmark tests
//////////////////////////////////////////
Expand Down

0 comments on commit 547bfba

Please sign in to comment.