Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the AST visitor faster #7701

Merged
merged 9 commits into from Mar 18, 2021
193 changes: 188 additions & 5 deletions go/tools/asthelpergen/asthelpergen.go
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"go/types"
"io/ioutil"
"log"
"path"
"strings"

Expand Down Expand Up @@ -49,15 +50,42 @@ type generator interface {
createFile(pkgName string) (string, *jen.File)
}

type generatorSPI interface {
addType(t types.Type)
addFunc(name string, t methodType, code jen.Code)
scope() *types.Scope
findImplementations(iff *types.Interface, impl func(types.Type) error) error
iface() *types.Interface
}

type generator2 interface {
interfaceMethod(t types.Type, iface *types.Interface, spi generatorSPI) error
structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error
ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error
ptrToBasicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error
ptrToOtherMethod(t types.Type, ptr *types.Pointer, spi generatorSPI) error
sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error
basicMethod(t types.Type, basic *types.Basic, spi generatorSPI) error
}

// astHelperGen finds implementations of the given interface,
// and uses the supplied `generator`s to produce the output code
type astHelperGen struct {
DebugTypes bool
mod *packages.Module
sizes types.Sizes
namedIface *types.Named
iface *types.Interface
_iface *types.Interface
gens []generator
gens2 []generator2

methods []jen.Code
_scope *types.Scope
todo []types.Type
}

func (gen *astHelperGen) iface() *types.Interface {
return gen._iface
}

func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, generators ...generator) *astHelperGen {
Expand All @@ -66,7 +94,7 @@ func newGenerator(mod *packages.Module, sizes types.Sizes, named *types.Named, g
mod: mod,
sizes: sizes,
namedIface: named,
iface: named.Underlying().(*types.Interface),
_iface: named.Underlying().(*types.Interface),
gens: generators,
}
}
Expand Down Expand Up @@ -96,6 +124,31 @@ func findImplementations(scope *types.Scope, iff *types.Interface, impl func(typ
}
return nil
}
func (gen *astHelperGen) findImplementations(iff *types.Interface, impl func(types.Type) error) error {
for _, name := range gen._scope.Names() {
obj := gen._scope.Lookup(name)
if _, ok := obj.(*types.TypeName); !ok {
continue
}
baseType := obj.Type()
if types.Implements(baseType, iff) {
err := impl(baseType)
if err != nil {
return err
}
continue
}
pointerT := types.NewPointer(baseType)
if types.Implements(pointerT, iff) {
err := impl(pointerT)
if err != nil {
return err
}
continue
}
}
return nil
}

func (gen *astHelperGen) visitStruct(t types.Type, stroct *types.Struct) error {
for _, g := range gen.gens {
Expand Down Expand Up @@ -130,7 +183,7 @@ func (gen *astHelperGen) visitInterface(t types.Type, iface *types.Interface) er
// GenerateCode is the main loop where we build up the code per file.
func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) {
pkg := gen.namedIface.Obj().Pkg()
iface, ok := gen.iface.Underlying().(*types.Interface)
iface, ok := gen._iface.Underlying().(*types.Interface)
if !ok {
return nil, fmt.Errorf("expected interface, but got %T", gen.iface)
}
Expand Down Expand Up @@ -165,6 +218,12 @@ func (gen *astHelperGen) GenerateCode() (map[string]*jen.File, error) {
result[fullPath] = code
}

gen._scope = pkg.Scope()
gen.todo = append(gen.todo, gen.namedIface)
file, code := gen.createFile(pkg.Name())
fullPath := path.Join(gen.mod.Dir, strings.TrimPrefix(pkg.Path(), gen.mod.Path), file)
result[fullPath] = code

return result, nil
}

Expand Down Expand Up @@ -247,13 +306,137 @@ func GenerateASTHelpers(packagePatterns []string, rootIface, exceptCloneType str
return types.Implements(t, iface)
}
rewriter := newRewriterGen(interestingType, nt.Obj().Name())
clone := newCloneGen(iface, scope, exceptCloneType)
generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, rewriter)
generator.gens2 = append(generator.gens2, &equalsGen{})
generator.gens2 = append(generator.gens2, newCloneGen(exceptCloneType))
generator.gens2 = append(generator.gens2, &visitGen{})

generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt, rewriter, clone)
it, err := generator.GenerateCode()
if err != nil {
return nil, err
}

return it, nil
}

var _ generatorSPI = (*astHelperGen)(nil)

func (gen *astHelperGen) scope() *types.Scope {
return gen._scope
}

func (gen *astHelperGen) addType(t types.Type) {
gen.todo = append(gen.todo, t)
}

type methodType int

const (
clone methodType = iota
equals
visit
)

func (gen *astHelperGen) addFunc(name string, typ methodType, code jen.Code) {
var comment string
switch typ {
case clone:
comment = " creates a deep clone of the input."
case equals:
comment = " does deep equals between the two objects."
case visit:
comment = " will visit all parts of the AST"
}
gen.methods = append(gen.methods, jen.Comment(name+comment), code)
}

func (gen *astHelperGen) createFile(pkgName string) (string, *jen.File) {
out := jen.NewFile(pkgName)
out.HeaderComment(licenseFileHeader)
out.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
alreadyDone := map[string]bool{}
for len(gen.todo) > 0 {
t := gen.todo[0]
underlying := t.Underlying()
typeName := printableTypeName(t)
gen.todo = gen.todo[1:]

if alreadyDone[typeName] {
continue
}

switch underlying := underlying.(type) {
case *types.Interface:
gen.allGenerators(func(g generator2) error {
return g.interfaceMethod(t, underlying, gen)
})
case *types.Slice:
gen.allGenerators(func(g generator2) error {
return g.sliceMethod(t, underlying, gen)
})
case *types.Struct:
gen.allGenerators(func(g generator2) error {
return g.structMethod(t, underlying, gen)
})
case *types.Pointer:
ptrToType := underlying.Elem().Underlying()

switch ptrToType := ptrToType.(type) {
case *types.Struct:
gen.allGenerators(func(g generator2) error {
return g.ptrToStructMethod(t, ptrToType, gen)
})
case *types.Basic:
gen.allGenerators(func(g generator2) error {
return g.ptrToBasicMethod(t, ptrToType, gen)
})
default:
panic(fmt.Sprintf("%T", ptrToType))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can do log.fatal here, or is this intentional?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. If there's a good reason to panic, please document it.

//c.makePtrCloneMethod(t, ptr)
}
case *types.Basic:
gen.allGenerators(func(g generator2) error {
return g.basicMethod(t, underlying, gen)
})

default:
log.Fatalf("don't know how to handle %s %T", typeName, underlying)
}

alreadyDone[typeName] = true
}

for _, method := range gen.methods {
out.Add(method)
}

return "ast_helper.go", out
}

func (gen *astHelperGen) allGenerators(f func(g generator2) error) {
for _, g := range gen.gens2 {
err := f(g)

if err != nil {
log.Fatalf("%v", err)
}
}
}

// printableTypeName returns a string that can be used as a valid golang identifier
func printableTypeName(t types.Type) string {
switch t := t.(type) {
case *types.Pointer:
return "RefOf" + printableTypeName(t.Elem())
case *types.Slice:
return "SliceOf" + printableTypeName(t.Elem())
case *types.Named:
return t.Obj().Name()
case *types.Basic:
return strings.Title(t.Name())
case *types.Interface:
return t.String()
default:
panic(fmt.Sprintf("unknown type %T %v", t, t))
}
}