Skip to content

Commit

Permalink
add replace-type parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
RangelReale committed Feb 8, 2023
1 parent 8641a5b commit 3fdb7f2
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 2 deletions.
1 change: 1 addition & 0 deletions cmd/mockery.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ func NewRootCmd() *cobra.Command {
pFlags.Bool("unroll-variadic", true, "For functions with variadic arguments, do not unroll the arguments into the underlying testify call. Instead, pass variadic slice as-is.")
pFlags.Bool("exported", false, "Generates public mocks for private interfaces.")
pFlags.Bool("with-expecter", false, "Generate expecter utility around mock's On, Run and Return methods with explicit types. This option is NOT compatible with -unroll-variadic=false")
pFlags.StringArray("replace-type", nil, "Replace types")

viper.BindPFlags(pFlags)

Expand Down
3 changes: 2 additions & 1 deletion pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ type Config struct {
TestOnly bool
UnrollVariadic bool `mapstructure:"unroll-variadic"`
Version bool
WithExpecter bool `mapstructure:"with-expecter"`
WithExpecter bool `mapstructure:"with-expecter"`
ReplaceType []string `mapstructure:"replace-type"`
}
12 changes: 12 additions & 0 deletions pkg/fixtures/example_project/baz/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package baz

import (
ifoo "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo"
)

type Baz = ifoo.InternalBaz

type Foo interface {
DoFoo() string
GetBaz() (*Baz, error)
}
6 changes: 6 additions & 0 deletions pkg/fixtures/example_project/baz/internal/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package foo

type InternalBaz struct {
One string
Two int
}
55 changes: 54 additions & 1 deletion pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,67 @@ func (g *Generator) getPackageScopedType(ctx context.Context, o *types.TypeName)
if o.Pkg() == nil || o.Pkg().Name() == "main" || (!g.KeepTree && g.InPackage && o.Pkg() == g.iface.Pkg) {
return o.Name()
}
return g.addPackageImport(ctx, o.Pkg()) + "." + o.Name()
pkg := g.addPackageImport(ctx, o.Pkg())
name := o.Name()
g.checkReplaceType(ctx, func(from replaceType, to replaceType) bool {
if o.Pkg().Path() == from.pkg && name == from.typ {
name = to.typ
return false
}
return true
})
return pkg + "." + name
}

func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) string {
return g.addPackageImportWithName(ctx, pkg.Path(), pkg.Name())
}

type replaceType struct {
alias string
pkg string
typ string
}

func parseReplaceType(t string) replaceType {
ret := replaceType{}
r := strings.SplitN(t, ":", 2)
if len(r) > 1 {
ret.alias = r[0]
t = r[1]
}
lastInd := strings.LastIndex(t, ".")
ret.pkg = t[:lastInd]
ret.typ = t[lastInd+1:]
return ret
}

func (g *Generator) checkReplaceType(ctx context.Context, f func(from replaceType, to replaceType) bool) {
for _, replace := range g.ReplaceType {
r := strings.SplitN(replace, "=", 2)
if len(r) == 2 {
if !f(parseReplaceType(r[0]), parseReplaceType(r[1])) {
break
}
} else {
log := zerolog.Ctx(ctx)
log.Error().Msgf("invalid replace type value: %s", replace)
}
}
}

func (g *Generator) addPackageImportWithName(ctx context.Context, path, name string) string {
g.checkReplaceType(ctx, func(from replaceType, to replaceType) bool {
if path == from.pkg {
path = to.pkg
if to.alias != "" {
name = to.alias
}
return false
}
return true
})

if existingName, pathExists := g.packagePathToName[path]; pathExists {
return existingName
}
Expand Down
84 changes: 84 additions & 0 deletions pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2329,6 +2329,90 @@ import mock "github.com/stretchr/testify/mock"
s.checkPrologueGeneration(generator, expected)
}

func (s *GeneratorSuite) TestInternalPackagePrologue() {
expected := `package mocks
import baz "github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz"
import mock "github.com/stretchr/testify/mock"
`
generator := NewGenerator(
s.ctx,
config.Config{InPackage: false, LogLevel: "debug", ReplaceType: []string{
"github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz",
}},
s.getInterfaceFromFile("example_project/baz/foo.go", "Foo"),
pkg,
)

s.checkPrologueGeneration(generator, expected)
}

func (s *GeneratorSuite) TestInternalPackage() {
expected := `// Foo is an autogenerated mock type for the Foo type
type Foo struct {
mock.Mock
}
// DoFoo provides a mock function with given fields:
func (_m *Foo) DoFoo() string {
ret := _m.Called()
var r0 string
if rf, ok := ret.Get(0).(func() string); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// GetBaz provides a mock function with given fields:
func (_m *Foo) GetBaz() (*baz.Baz, error) {
ret := _m.Called()
var r0 *baz.Baz
if rf, ok := ret.Get(0).(func() *baz.Baz); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*baz.Baz)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
type mockConstructorTestingTNewFoo interface {
mock.TestingT
Cleanup(func())
}
// NewFoo creates a new instance of Foo. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
func NewFoo(t mockConstructorTestingTNewFoo) *Foo {
mock := &Foo{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
`
cfg := config.Config{InPackage: false, LogLevel: "debug", ReplaceType: []string{
"github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/pkg/fixtures/example_project/baz.Baz",
}}

s.checkGenerationWithConfig("example_project/baz/foo.go", "Foo", cfg, expected)
}

func (s *GeneratorSuite) TestGenericGenerator() {
expected := `// RequesterGenerics is an autogenerated mock type for the RequesterGenerics type
type RequesterGenerics[TAny interface{}, TComparable comparable, TSigned constraints.Signed, TIntf test.GetInt, TExternalIntf io.Writer, TGenIntf test.GetGeneric[TSigned], TInlineType interface{ ~int | ~uint }, TInlineTypeGeneric interface {
Expand Down

0 comments on commit 3fdb7f2

Please sign in to comment.