Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support dynamic return values #742

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 78 additions & 11 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,15 @@ type Call struct {
Arguments Arguments

// Holds the arguments that should be returned when
// this method is called.
// this method is called. If the first and only value is
// function which takes and returns Arguments, that will be invoked
// on each call of the mock to determine what to return.
ReturnArguments Arguments

// if Run() was given a function which returns arguments, we'll call that whenever
// this call is invoked and use its return values as the arguments to return.
returnFunc func(args Arguments) Arguments

// Holds the caller info for the On() call
callerInfo []string

Expand Down Expand Up @@ -102,15 +108,27 @@ func (c *Call) unlock() {
c.Parent.mutex.Unlock()
}

// Return specifies the return arguments for the expectation.
// If the only return arg is a function which takes and returns Arguments, invoke it instead of returning it as the value
func (c *Call) getReturnArguments(args Arguments) Arguments {
if c.returnFunc != nil && len(c.ReturnArguments) > 0 {
panic("Cannot specify a function with Run() that returns arguments and also specify a Return() fixed set of return arguments")
}

if c.returnFunc != nil {
return c.returnFunc(args)
}

return c.ReturnArguments
}

// Return specifies fixed return arguments for the expectation, that will be returned for every invocation.
// If you want to specify dynamic return values see the Run(fn) function.
//
// Mock.On("DoSomething").Return(errors.New("failed"))
func (c *Call) Return(returnArguments ...interface{}) *Call {
c.lock()
defer c.unlock()

c.ReturnArguments = returnArguments

return c
}

Expand Down Expand Up @@ -172,18 +190,67 @@ func (c *Call) After(d time.Duration) *Call {
return c
}

// Run sets a handler to be called before returning. It can be used when
// mocking a method (such as an unmarshaler) that takes a pointer to a struct and
// sets properties in such struct
// Run sets a handler to be called before returning, possibly determining the return values of the call too.
//
// You can pass three types of functions to it:
//
// Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
// 1) func(Arguments) that will not affect what is returned (you can still call Return() to specify them)
//
// Mock.On("Unmarshal", mock.AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
// arg := args.Get(0).(*map[string]interface{})
// arg["foo"] = "bar"
// })
func (c *Call) Run(fn func(args Arguments)) *Call {
//
// 2) A function which matches the signature of your mocked function itself, and determines the return values dynamically.
//
// Mock.On("HelloWorld", mock.Anything).Run(func(name string) string {
// return "Hello " + name
// })
//
// 3) func(Arguments) Arguments which behaves like (2) except you need to do the typecasting yourself
//
// Mock.On("HelloWorld", mock.Anything).Run(func(args mock.Arguments) mock.Arguments {
// return mock.Arguments([]any{"Hello " + args[0].(string)})
// })
func (c *Call) Run(fn interface{}) *Call {
c.lock()
defer c.unlock()
c.RunFn = fn
switch f := fn.(type) {
case func(Arguments):
c.RunFn = f
case func(Arguments) Arguments:
c.returnFunc = f
default:
fnVal := reflect.ValueOf(fn)
if fnVal.Kind() != reflect.Func {
panic(fmt.Sprintf("Invalid argument passed to Run(), must be a function, is a %T", fn))
}
fnType := fnVal.Type()
c.returnFunc = func(args Arguments) (resp Arguments) {
var argVals []reflect.Value
for i, arg := range args {
if i == len(args)-1 && fnType.IsVariadic() {
// splat the variadic arg back out in the call, as expected by reflect.Value#Call
argVal := reflect.ValueOf(arg)
for j := 0; j < argVal.Len(); j++ {
argVals = append(argVals, argVal.Index(j))
}
} else {
argVals = append(argVals, reflect.ValueOf(arg))
}
}

// actually call the fn
ret := fnVal.Call(argVals)

for _, val := range ret {
resp = append(resp, val.Interface())
}

return resp
}
}

return c
}

Expand Down Expand Up @@ -555,7 +622,7 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen
}

m.mutex.Lock()
returnArgs := call.ReturnArguments
returnArgs := call.getReturnArguments(arguments)
m.mutex.Unlock()

return returnArgs
Expand Down
89 changes: 89 additions & 0 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,95 @@ func Test_Mock_Return_Run_Out_Of_Order(t *testing.T) {
assert.NotNil(t, call.Run)
}

func Test_Mock_Run_ReturnFunc(t *testing.T) {

// make a test impl object
var mockedService = new(TestExampleImplementation)

t.Run("can dynamically set the return values", func(t *testing.T) {
mockedService.On("TheExampleMethod", Anything, Anything, Anything).
Run(func(a, b, c int) (int, error) {
return a + 40, fmt.Errorf("hmm")
}).
Twice()

answer, _ := mockedService.TheExampleMethod(2, 4, 5)
assert.Equal(t, 42, answer)

answer, _ = mockedService.TheExampleMethod(44, 4, 5)
assert.Equal(t, 84, answer)
})

t.Run("handles func(Args) Args style", func(t *testing.T) {
mockedService.On("TheExampleMethod", Anything, Anything, Anything).
Run(func(args Arguments) Arguments {
return []interface{}{args[0].(int) + 40, fmt.Errorf("hmm")}
}).
Twice()

answer, _ := mockedService.TheExampleMethod(2, 4, 5)
assert.Equal(t, 42, answer)

answer, _ = mockedService.TheExampleMethod(44, 4, 5)
assert.Equal(t, 84, answer)
})

t.Run("handles pointer input args", func(t *testing.T) {
mockedService.On("TheExampleMethod3", Anything).Run(func(et *ExampleType) error {
if et == nil {
return fmt.Errorf("Nil obj")
}
return nil
}).Twice()

err := mockedService.TheExampleMethod3(nil)
assert.Error(t, err)

err = mockedService.TheExampleMethod3(&ExampleType{})
assert.NoError(t, err)
})

t.Run("handles no return args", func(t *testing.T) {
mockedService.On("TheExampleMethod2", Anything).Run(func(yesno bool) {
// nothing to return
}).Once()

mockedService.TheExampleMethod2(true)
})

t.Run("handles variadic input args", func(t *testing.T) {
mockedService.
On("TheExampleMethodMixedVariadic", Anything, Anything).
Run(func(a int, b ...int) error {
var sum = a
for _, v := range b {
sum += v
}
return fmt.Errorf("%v", sum)
})

assert.Equal(t, "42", mockedService.TheExampleMethodMixedVariadic(40, 1, 1).Error())
assert.Equal(t, "40", mockedService.TheExampleMethodMixedVariadic(40).Error())
})

t.Run("panics if Run() called with an invalid value", func(t *testing.T) {
assert.PanicsWithValue(t,
"Invalid argument passed to Run(), must be a function, is a int",
func() { mockedService.On("TheExampleMethod").Run(42) },
)
})

t.Run("panics if both Return() and Run() are called specifying return args", func(t *testing.T) {
mockedService.On("TheExampleMethod", Anything, Anything, Anything).
Run(func(a, b, c int) (int, error) {
return a + 40, fmt.Errorf("hmm")
}).
Return(80, nil)

assert.PanicsWithValue(t, "Cannot specify a function with Run() that returns arguments and also specify a Return() fixed set of return arguments", func() { mockedService.TheExampleMethod(1, 2, 3) })
})
}

func Test_Mock_Return_Once(t *testing.T) {

// make a test impl object
Expand Down