diff --git a/dig_test.go b/dig_test.go index 93ff05f7..a9ea730c 100644 --- a/dig_test.go +++ b/dig_test.go @@ -247,7 +247,27 @@ func TestEndToEndSuccess(t *testing.T) { require.NoError(t, c.Invoke(func(a *A, b *B) {}), "AB invoke failed") require.Equal(t, 1, count, "Constructor must be called once") }) + t.Run("method invocation inside Invoke", func(t *testing.T) { + c := New() + type A struct{} + type B struct{} + cA := func() (*A, error) { + return &A{}, nil + } + cB := func() (*B, error) { + return &B{}, nil + } + getA := func(a *A) { + c.Invoke(func(b *B) { + assert.NotNil(t, b, "got nil B") + }) + assert.NotNil(t, a, "got nil A") + } + require.NoError(t, c.Provide(cA), "provide failed") + require.NoError(t, c.Provide(cB), "provide failed") + require.NoError(t, c.Invoke(getA), "A invoke failed") + }) t.Run("collections and instances of same type", func(t *testing.T) { c := New() require.NoError(t, c.Provide(func() []*bytes.Buffer { @@ -259,6 +279,17 @@ func TestEndToEndSuccess(t *testing.T) { }) } +func TestProvideConstructorErrors(t *testing.T) { + t.Run("multiple-type constructor returns multiple objects of same type", func(t *testing.T) { + c := New() + type A struct{} + constructor := func() (*A, *A, error) { + return &A{}, &A{}, nil + } + require.Error(t, c.Provide(constructor), "provide failed") + }) +} + func TestProvideRespectsConstructorErrors(t *testing.T) { t.Run("constructor succeeds", func(t *testing.T) { c := New() @@ -349,6 +380,21 @@ func TestProvideKnownTypesFails(t *testing.T) { }) } + t.Run("provide constructor twice", func(t *testing.T) { + c := New() + assert.NoError(t, c.Provide(func() *bytes.Buffer { return nil })) + assert.Error(t, c.Provide(func() *bytes.Buffer { return nil })) + }) + t.Run("provide instance and constructor fails", func(t *testing.T) { + c := New() + assert.NoError(t, c.Provide(&bytes.Buffer{})) + assert.Error(t, c.Provide(func() *bytes.Buffer { return nil })) + }) + t.Run("provide constructor then object instance fails", func(t *testing.T) { + c := New() + assert.NoError(t, c.Provide(func() *bytes.Buffer { return nil })) + assert.Error(t, c.Provide(&bytes.Buffer{})) + }) } func TestProvideCycleFails(t *testing.T) {