From d8a8f700007a67fc369af17c1e4c157754436977 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Thu, 16 Sep 2021 16:24:50 -0700 Subject: [PATCH] util/codegen: add NamedTypes And use it in cmd/cloner. Signed-off-by: Josh Bleecher Snyder --- cmd/cloner/cloner.go | 30 ++++-------------------------- util/codegen/codegen.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index 90ff9d01446b9..1bc4dfbe18beb 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -17,8 +17,6 @@ import ( "bytes" "flag" "fmt" - "go/ast" - "go/token" "go/types" "log" "os" @@ -62,33 +60,13 @@ func main() { pkg := pkgs[0] buf := new(bytes.Buffer) imports := make(map[string]struct{}) + namedTypes := codegen.NamedTypes(pkg) for _, typeName := range typeNames { - found := false - for _, file := range pkg.Syntax { - for _, d := range file.Decls { - decl, ok := d.(*ast.GenDecl) - if !ok || decl.Tok != token.TYPE { - continue - } - for _, s := range decl.Specs { - spec, ok := s.(*ast.TypeSpec) - if !ok || spec.Name.Name != typeName { - continue - } - typeNameObj := pkg.TypesInfo.Defs[spec.Name] - typ, ok := typeNameObj.Type().(*types.Named) - if !ok { - continue - } - pkg := typeNameObj.Pkg() - gen(buf, imports, typ, pkg) - found = true - } - } - } - if !found { + typ, ok := namedTypes[typeName] + if !ok { log.Fatalf("could not find type %s", typeName) } + gen(buf, imports, typ, pkg.Types) } w := func(format string, args ...interface{}) { diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index 95302440bf5e7..01317073504ed 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -8,9 +8,13 @@ package codegen import ( "bytes" "fmt" + "go/ast" "go/format" + "go/token" "go/types" "os" + + "golang.org/x/tools/go/packages" ) // WriteFormatted writes code to path. @@ -41,6 +45,32 @@ func WriteFormatted(code []byte, path string) error { return nil } +// NamedTypes returns all named types in pkg, keyed by their type name. +func NamedTypes(pkg *packages.Package) map[string]*types.Named { + nt := make(map[string]*types.Named) + for _, file := range pkg.Syntax { + for _, d := range file.Decls { + decl, ok := d.(*ast.GenDecl) + if !ok || decl.Tok != token.TYPE { + continue + } + for _, s := range decl.Specs { + spec, ok := s.(*ast.TypeSpec) + if !ok { + continue + } + typeNameObj := pkg.TypesInfo.Defs[spec.Name] + typ, ok := typeNameObj.Type().(*types.Named) + if !ok { + continue + } + nt[spec.Name.Name] = typ + } + } + } + return nt +} + // AssertStructUnchanged generates code that asserts at compile time that type t is unchanged. // tname is the named type corresponding to t. // ctx is a single-word context for this assertion, such as "Clone".