Skip to content

Commit

Permalink
Fix for Broken ... params (Issue #38)
Browse files Browse the repository at this point in the history
  • Loading branch information
Shawn Burke committed Aug 4, 2015
1 parent 14e99be commit 5a49587
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
5 changes: 5 additions & 0 deletions mockery/fixtures/requester_variable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package test

type RequesterVariable interface {
Get(values ...string) bool
}
19 changes: 18 additions & 1 deletion mockery/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,23 @@ func (g *Generator) Generate() error {
g.printf("(%s) {\n", strings.Join(returns, ", "))
}

formatParamNames := func() string {
names := ""
for i, name := range paramNames {
if i > 0 {
names += ", "
}

paramType := paramTypes[i]
// for variable args, move the ... to the end.
if strings.Index(paramType, "...") == 0 {
name += "..."
}
names += name
}
return names
}

if len(returnTypes) > 0 {
g.printf("\tret := _m.Called(%s)\n\n", strings.Join(paramNames, ", "))

Expand All @@ -310,7 +327,7 @@ func (g *Generator) Generate() error {
for idx, typ := range returnTypes {
g.printf("\tvar r%d %s\n", idx, typ)
g.printf("\tif rf, ok := ret.Get(%d).(func(%s) %s); ok {\n", idx, strings.Join(paramTypes, ", "), typ)
g.printf("\t\tr%d = rf(%s)\n", idx, strings.Join(paramNames, ", "))
g.printf("\t\tr%d = rf(%s)\n", idx, formatParamNames())
g.printf("\t} else {\n")
if typ == "error" {
g.printf("\t\tr%d = ret.Error(%d)\n", idx, idx)
Expand Down
32 changes: 32 additions & 0 deletions mockery/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,38 @@ func (_m *RequesterElided) Get(path string, url string) error {
assert.Equal(t, expected, gen.buf.String())
}

func TestGeneratorVariableArgs(t *testing.T) {

parser := NewParser()
parser.Parse(filepath.Join(fixturePath, "requester_variable.go"))

iface, err := parser.Find("RequesterVariable")

gen := NewGenerator(iface)

err = gen.Generate()
assert.NoError(t, err)
expected := `type RequesterVariable struct {
mock.Mock
}
func (_m *RequesterVariable) Get(values ...string) bool {
ret := _m.Called(values)
var r0 bool
if rf, ok := ret.Get(0).(func(...string) bool); ok {
r0 = rf(values...)
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
`

assert.Equal(t, expected, gen.buf.String())
}

func TestGeneratorFuncType(t *testing.T) {
parser := NewParser()
parser.Parse(filepath.Join(fixturePath, "func_type.go"))
Expand Down

0 comments on commit 5a49587

Please sign in to comment.