Skip to content

Commit

Permalink
cache replace types
Browse files Browse the repository at this point in the history
  • Loading branch information
RangelReale committed Feb 21, 2023
1 parent 8e368fa commit 8ca9a1d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
8 changes: 4 additions & 4 deletions docs/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ Replace Types
The `replace-type` parameter allows adding a list of type replacements to be made in package and/or type names.
This can help overcome some parsing problems like type aliases that the Go parser doesn't provide enough information.

This parameter can be specified multiple times.

```shell
mockery --replace-type github.com/vektra/mockery/v2/baz/internal/foo.InternalBaz=baz:github.com/vektra/mockery/v2/baz.Baz
```

This parameter can be specified multiple times.

This will replace any imported named `"github.com/vektra/mockery/v2/baz/internal/foo"`
with `baz "github.com/vektra/mockery/v2/baz"`. The alias is defined with `:` before
the package name. Also, the `InternalBaz` type that comes from this package will be renamed to `baz.Baz`.
Expand Down Expand Up @@ -150,7 +150,7 @@ type Handler struct {
}
```

Mock generated without this parameter:
Invalid mock generated without this parameter (points to an `internal` folder):

```go
import (
Expand All @@ -165,7 +165,7 @@ func (_m *Handler) HandleMessage(m pubsub.Message) error {
}
```

Mock generated with this parameter.
Correct mock generated with this parameter.

```go
import (
Expand Down
45 changes: 31 additions & 14 deletions pkg/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type Generator struct {
localizationCache map[string]string
packagePathToName map[string]string
nameToPackagePath map[string]string
replaceTypeCache []*replaceTypeItem
}

// NewGenerator builds a Generator.
Expand All @@ -51,6 +52,7 @@ func NewGenerator(ctx context.Context, c config.Config, iface *Interface, pkg st
nameToPackagePath: make(map[string]string),
}

g.parseReplaceTypes(ctx)
g.addPackageImportWithName(ctx, "github.com/stretchr/testify/mock", "mock")

return g
Expand Down Expand Up @@ -103,7 +105,7 @@ func (g *Generator) getPackageScopedType(ctx context.Context, o *types.TypeName)
}
pkg := g.addPackageImport(ctx, o.Pkg())
name := o.Name()
g.checkReplaceType(ctx, func(from replaceType, to replaceType) bool {
g.checkReplaceType(ctx, func(from *replaceType, to *replaceType) bool {
if o.Pkg().Path() == from.pkg && name == from.typ {
name = to.typ
return false
Expand All @@ -117,22 +119,16 @@ func (g *Generator) addPackageImport(ctx context.Context, pkg *types.Package) st
return g.addPackageImportWithName(ctx, pkg.Path(), pkg.Name())
}

func (g *Generator) checkReplaceType(ctx context.Context, f func(from replaceType, to replaceType) bool) {
for _, replace := range g.ReplaceType {
r := strings.SplitN(replace, "=", 2)
if len(r) == 2 {
if !f(parseReplaceType(r[0]), parseReplaceType(r[1])) {
break
}
} else {
log := zerolog.Ctx(ctx)
log.Error().Msgf("invalid replace type value: %s", replace)
func (g *Generator) checkReplaceType(ctx context.Context, f func(from *replaceType, to *replaceType) bool) {
for _, item := range g.replaceTypeCache {
if !f(item.from, item.to) {
break
}
}
}

func (g *Generator) addPackageImportWithName(ctx context.Context, path, name string) string {
g.checkReplaceType(ctx, func(from replaceType, to replaceType) bool {
g.checkReplaceType(ctx, func(from *replaceType, to *replaceType) bool {
if path == from.pkg {
path = to.pkg
if to.alias != "" {
Expand All @@ -153,6 +149,22 @@ func (g *Generator) addPackageImportWithName(ctx context.Context, path, name str
return nonConflictingName
}

func (g *Generator) parseReplaceTypes(ctx context.Context) {
for _, replace := range g.Config.ReplaceType {
r := strings.SplitN(replace, "=", 2)
if len(r) != 2 {
log := zerolog.Ctx(ctx)
log.Error().Msgf("invalid replace type value: %s", replace)
continue
}

g.replaceTypeCache = append(g.replaceTypeCache, &replaceTypeItem{
from: parseReplaceType(r[0]),
to: parseReplaceType(r[1]),
})
}
}

func (g *Generator) getNonConflictingName(path, name string) string {
if !g.importNameExists(name) && (!g.InPackage || g.iface.Pkg.Name() != name) {
// do not allow imports with the same name as the package when inPackage
Expand Down Expand Up @@ -949,8 +961,13 @@ type replaceType struct {
typ string
}

func parseReplaceType(t string) replaceType {
ret := replaceType{}
type replaceTypeItem struct {
from *replaceType
to *replaceType
}

func parseReplaceType(t string) *replaceType {
ret := &replaceType{}
r := strings.SplitN(t, ":", 2)
if len(r) > 1 {
ret.alias = r[0]
Expand Down
7 changes: 5 additions & 2 deletions pkg/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2455,6 +2455,10 @@ func (_m *Foo) GetBaz() (*baz.Baz, error) {
ret := _m.Called()
var r0 *baz.Baz
var r1 error
if rf, ok := ret.Get(0).(func() (*baz.Baz, error)); ok {
return rf()
}
if rf, ok := ret.Get(0).(func() *baz.Baz); ok {
r0 = rf()
} else {
Expand All @@ -2463,7 +2467,6 @@ func (_m *Foo) GetBaz() (*baz.Baz, error) {
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
Expand Down Expand Up @@ -2905,6 +2908,6 @@ func TestParseReplaceType(t *testing.T) {

for _, test := range tests {
actual := parseReplaceType(test.value)
assert.Equal(t, test.expected, actual)
assert.Equal(t, test.expected, *actual)
}
}

0 comments on commit 8ca9a1d

Please sign in to comment.