Skip to content

Commit

Permalink
Add panic variations of all DIG functions (#360)
Browse files Browse the repository at this point in the history
  • Loading branch information
glibsm authored Mar 7, 2017
1 parent 66e987f commit be1e111
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 2 deletions.
22 changes: 21 additions & 1 deletion dig/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,41 @@ func Register(i interface{}) error {
return defaultGraph.Register(i)
}

// MustRegister calls Register and panics if an error is encountered
func MustRegister(i interface{}) {
defaultGraph.MustRegister(i)
}

// RegisterAll into the default graph
func RegisterAll(is ...interface{}) error {
return defaultGraph.RegisterAll(is...)
}

// MustRegisterAll into the default graph
func MustRegisterAll(is ...interface{}) {
defaultGraph.MustRegisterAll(is...)
}

// Resolve an object through the default graph
func Resolve(i interface{}) error {
return defaultGraph.Resolve(i)
}

// ResolveAll the passed in pointers through the dependency graph
// MustResolve through the default graph
func MustResolve(i interface{}) {
defaultGraph.MustResolve(i)
}

// ResolveAll through the default graph
func ResolveAll(is ...interface{}) error {
return defaultGraph.ResolveAll(is...)
}

// MustResolveAll on the default graph
func MustResolveAll(is ...interface{}) {
defaultGraph.MustResolveAll(is...)
}

// Reset the default graph
func Reset() {
defaultGraph.Reset()
Expand Down
18 changes: 17 additions & 1 deletion dig/default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ type Type3 struct {
f float32
}

type Type4 struct{}

func TestDefaultGraph(t *testing.T) {
defer Reset()

Expand All @@ -46,19 +48,33 @@ func TestDefaultGraph(t *testing.T) {

t2 := &Type2{s: "42"}
t3 := &Type3{f: 4.2}
require.NoError(t, RegisterAll(t2, t3))
t4 := &Type4{}
require.NoError(t, RegisterAll(t2))
require.NotPanics(t, func() {
MustRegister(t3)
MustRegisterAll(t4)
})

var t1g *Type1
require.NoError(t, Resolve(&t1g))
require.NotPanics(t, func() {
MustResolve(&t1g)
})
require.True(t, t1g == t1)

var t2g *Type2
var t3g *Type3
require.NoError(t, ResolveAll(&t2g, &t3g))
require.NotPanics(t, func() {
MustResolveAll(&t2g, &t3g)
})
require.True(t, t2g == t2)
require.True(t, t3g == t3)

var t2g2 *Type2
require.NoError(t, DefaultGraph().Resolve(&t2g2))
require.NotPanics(t, func() {
MustResolve(&t2g2)
})
require.Equal(t, t2, t2g2)
}
28 changes: 28 additions & 0 deletions dig/dig.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ func (g *Graph) Register(c interface{}) error {
}
}

// MustRegister will attempt to register the object and panic if error is encountered
func (g *Graph) MustRegister(c interface{}) {
if err := g.Register(c); err != nil {
panic(err)
}
}

// Resolve all of the dependencies of the provided class
//
// Provided object must be a pointer
Expand Down Expand Up @@ -119,6 +126,13 @@ func (g *Graph) Resolve(obj interface{}) (err error) {
return nil
}

// MustResolve calls Resolve and panics if an error is encountered
func (g *Graph) MustResolve(obj interface{}) {
if err := g.Resolve(obj); err != nil {
panic(err)
}
}

// ResolveAll the dependencies of each provided object
// Returns the first error encountered
func (g *Graph) ResolveAll(objs ...interface{}) error {
Expand All @@ -130,6 +144,13 @@ func (g *Graph) ResolveAll(objs ...interface{}) error {
return nil
}

// MustResolveAll calls ResolveAll and panics if an error is encountered
func (g *Graph) MustResolveAll(objs ...interface{}) {
if err := g.ResolveAll(objs...); err != nil {
panic(err)
}
}

// RegisterAll registers all the provided args in the dependency graph
func (g *Graph) RegisterAll(cs ...interface{}) error {
for _, c := range cs {
Expand All @@ -140,6 +161,13 @@ func (g *Graph) RegisterAll(cs ...interface{}) error {
return nil
}

// MustRegisterAll calls RegisterAll and panics is an error is encountered
func (g *Graph) MustRegisterAll(cs ...interface{}) {
if err := g.RegisterAll(cs...); err != nil {
panic(err)
}
}

// Reset the graph by removing all the registered nodes
func (g *Graph) Reset() {
g.Lock()
Expand Down
62 changes: 62 additions & 0 deletions dig/dig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,65 @@ func TestPanicConstructor(t *testing.T) {
require.Contains(t, err.Error(), "panic during Resolve")
require.Contains(t, err.Error(), "RUH ROH")
}

func TestMustFunctions(t *testing.T) {
t.Parallel()
tts := []struct {
name string
f func(g *Graph)
panicExpected bool
}{
{
"wrong register type",
func(g *Graph) { g.MustRegister(2) },
true,
},
{
"wrong register all types",
func(g *Graph) { g.MustRegisterAll("2", "3") },
true,
},
{
"unregistered type",
func(g *Graph) {
var v *Type1
g.MustResolve(&v)
},
true,
},
{
"correct register",
func(g *Graph) { g.MustRegister(NewChild1) },
false,
},
{
"correct register all",
func(g *Graph) { g.MustRegisterAll(NewChild1, NewChild2) },
false,
},
{
"unregistered types",
func(g *Graph) {
var v *Type1
var v2 *Type2
g.MustResolveAll(&v, &v2)
},
true,
},
}

for _, tc := range tts {
t.Run(tc.name, func(t *testing.T) {
g := New()
f := func() {
tc.f(g)
}

if tc.panicExpected {
require.Panics(t, f)
} else {
require.NotPanics(t, f)
}
})
}
}
41 changes: 41 additions & 0 deletions dig/node_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2017 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package dig

import (
"testing"

"reflect"

"github.com/stretchr/testify/require"
)

func TestNodeStrings(t *testing.T) {
n := objNode{}
require.Contains(t, n.String(), "(object)")

fn := funcNode{
node: node{
objType: reflect.TypeOf(n),
},
}
require.Contains(t, fn.String(), "(function)")
}

0 comments on commit be1e111

Please sign in to comment.