Skip to content

Commit

Permalink
Switch to source importing and allow additional imports.
Browse files Browse the repository at this point in the history
Go 1.5 introduced the 'source' compiler. See
golang/go#11415 for details.

This avoids the problem of using out of date information when generating
code. It also means we don't have to have to run `go install` prior to
generating.

I also added the ability to specify additional imports since we use
codegen to generate files that depend on types outside testify.
  • Loading branch information
nomis52 committed Apr 19, 2019
1 parent 34c6fa2 commit 567a8d0
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions _codegen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ import (
)

var (
pkg = flag.String("assert-path", "github.com/stretchr/testify/assert", "Path to the assert package")
includeF = flag.Bool("include-format-funcs", false, "include format functions such as Errorf and Equalf")
outputPkg = flag.String("output-package", "", "package for the resulting code")
tmplFile = flag.String("template", "", "What file to load the function template from")
out = flag.String("out", "", "What file to write the source code to")
pkg = flag.String("assert-path", "github.com/stretchr/testify/assert", "Path to the assert package")
includeF = flag.Bool("include-format-funcs", false, "include format functions such as Errorf and Equalf")
outputPkg = flag.String("output-package", "", "package for the resulting code")
tmplFile = flag.String("template", "", "What file to load the function template from")
out = flag.String("out", "", "What file to write the source code to")
additionalImports = flag.String("additional-imports", "", "additional packages to import")
)

func main() {
Expand Down Expand Up @@ -61,13 +62,20 @@ func generateCode(importer imports.Importer, funcs []testFunc) error {
return err
}

imports := importer.Imports()
if len(*additionalImports) > 0 {
for _, pkg := range strings.Split(*additionalImports, ",") {
imports[pkg] = path.Base(pkg)
}
}

// Generate header
if err := tmplHead.Execute(buff, struct {
Name string
Imports map[string]string
}{
*outputPkg,
importer.Imports(),
imports,
}); err != nil {
return err
}
Expand Down Expand Up @@ -194,7 +202,7 @@ func parsePackageSource(pkg string) (*types.Scope, *doc.Package, error) {
}

cfg := types.Config{
Importer: importer.Default(),
Importer: importer.For("source", nil),
}
info := types.Info{
Defs: make(map[*ast.Ident]types.Object),
Expand Down

0 comments on commit 567a8d0

Please sign in to comment.