From e3e1062546309bfcc5f9aaacdc154b900b31b86b Mon Sep 17 00:00:00 2001 From: andrewwillette Date: Sat, 16 Mar 2024 23:02:34 -0500 Subject: [PATCH] adding support for mock.Anything in slices --- mock/mock.go | 167 ++++++++++++++++++++++++++-------------------- mock/mock_test.go | 30 +++++++++ 2 files changed, 123 insertions(+), 74 deletions(-) diff --git a/mock/mock.go b/mock/mock.go index 3da53cdaf..10df1c80d 100644 --- a/mock/mock.go +++ b/mock/mock.go @@ -930,99 +930,118 @@ func (args Arguments) Diff(objects []interface{}) (string, int) { } for i := 0; i < maxArgCount; i++ { - var actual, expected interface{} - var actualFmt, expectedFmt string + var expected, actual interface{} + if len(args) <= i { + expected = "(Missing)" + } else { + expected = args[i] + } if len(objects) <= i { actual = "(Missing)" - actualFmt = "(Missing)" } else { actual = objects[i] - actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) } - - if len(args) <= i { - expected = "(Missing)" - expectedFmt = "(Missing)" - } else { - expected = args[i] - expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + equal, elementOutput := compareElements(expected, actual, i, false) + output += elementOutput + if !equal { + differences++ } + } - if matcher, ok := expected.(argumentMatcher); ok { - var matches bool - func() { - defer func() { - if r := recover(); r != nil { - actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) - } - }() - matches = matcher.Matches(actual) + if differences == 0 { + return "No differences.", differences + } + + return output, differences +} + +func compareElements(expected, actual interface{}, i int, isSlice bool) (bool, string) { + var expectedFmt, actualFmt string + if expected == "(Missing)" { + expectedFmt = expected.(string) + } else { + expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected) + } + if actual == "(Missing)" { + actualFmt = actual.(string) + } else { + actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual) + } + if matcher, ok := expected.(argumentMatcher); ok { + var matches bool + func() { + defer func() { + if r := recover(); r != nil { + actualFmt = fmt.Sprintf("panic in argument matcher: %v", r) + } }() - if matches { - output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) + matches = matcher.Matches(actual) + }() + if matches { + return true, fmt.Sprintf("\t%d: PASS: %s matched by %s\n", i, actualFmt, matcher) + } else { + return false, fmt.Sprintf("\t%d: FAIL: %s not matched by %s\n", i, actualFmt, matcher) + } + } else { + switch expected := expected.(type) { + case anythingOfTypeArgument: + if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { + return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected, reflect.TypeOf(actual).Name(), actualFmt) } else { - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher) + return true, "" } - } else { - switch expected := expected.(type) { - case anythingOfTypeArgument: - // type checking - if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) { - // not match - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt) - } - case *IsTypeArgument: - actualT := reflect.TypeOf(actual) - if actualT != expected.t { - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt) - } - case *FunctionalOptionsArgument: - t := expected.value + case *IsTypeArgument: + actualT := reflect.TypeOf(actual) + if actualT != expected.t { + return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, expected.t, reflect.TypeOf(actual).Name(), actualFmt) + } else { + return true, "" + } + case *FunctionalOptionsArgument: + t := expected.value - var name string - tValue := reflect.ValueOf(t) - if tValue.Len() > 0 { - name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String() - } + var name string + tValue := reflect.ValueOf(t) + if tValue.Len() > 0 { + name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String() + } - tName := reflect.TypeOf(t).Name() - if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 { - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt) + tName := reflect.TypeOf(t).Name() + if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 { + return false, fmt.Sprintf("\t%d: FAIL: type %s != type %s - %s\n", i, tName, reflect.TypeOf(actual).Name(), actualFmt) + } else { + if ef, af := assertOpts(t, actual); ef == "" && af == "" { + return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, tName, tName) } else { - if ef, af := assertOpts(t, actual); ef == "" && af == "" { - // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName) - } else { - // not match - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef) - } + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, af, ef) } - - default: - if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { - // match - output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt) - } else { - // not match - differences++ - output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt) + } + case []interface{}: + if ev, av := reflect.ValueOf(expected), reflect.ValueOf(actual); ev.Kind() == reflect.Slice && av.Kind() == reflect.Slice { + // Unroll slices to check for Anything / AnythingOFType + if ev.Len() != av.Len() { + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) + } + for e := 0; e < ev.Len(); e++ { + equal, _ := compareElements(ev.Index(e).Interface(), av.Index(e).Interface(), i, true) + if !equal { + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) + } } + return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt) + } + default: + if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) { + // match + return true, fmt.Sprintf("\t%d: PASS: %s == %s\n", i, actualFmt, expectedFmt) + } else { + // not match + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) } } - } - - if differences == 0 { - return "No differences.", differences - } - - return output, differences + return false, fmt.Sprintf("\t%d: FAIL: %s != %s\n", i, actualFmt, expectedFmt) } // Assert compares the arguments with the specified objects and fails if diff --git a/mock/mock_test.go b/mock/mock_test.go index b80a8a75b..e267dd351 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -2146,3 +2146,33 @@ type user interface { type mockUser struct{ Mock } func (m *mockUser) Use(c caller) { m.Called(c) } + +func TestAnythingInSlices(t *testing.T) { + m := &TestExampleImplementation{} + + m.On("TheExampleMethodVariadic", []interface{}{1, Anything, 3, Anything, 5}).Return(nil) + var err error + + assert.NotPanics(t, func() { + err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5) + }) + + assert.NoError(t, err) + m.AssertExpectations(t) + m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{Anything, 2, Anything, 4, Anything}) +} + +func TestAnythingOfTypeInSlices(t *testing.T) { + m := &TestExampleImplementation{} + + m.On("TheExampleMethodVariadic", []interface{}{1, AnythingOfType("int"), 3, AnythingOfType("int"), 5}).Return(nil) + var err error + + assert.NotPanics(t, func() { + err = m.TheExampleMethodVariadic(1, 2, 3, 4, 5) + }) + + assert.NoError(t, err) + m.AssertExpectations(t) + m.AssertCalled(t, "TheExampleMethodVariadic", []interface{}{AnythingOfType("int"), 2, AnythingOfType("int"), 4, AnythingOfType("int")}) +}