Skip to content

Commit

Permalink
util/codegen: add NamedTypes
Browse files Browse the repository at this point in the history
And use it in cmd/cloner.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
  • Loading branch information
josharian committed Sep 17, 2021
1 parent 367a973 commit d8a8f70
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
30 changes: 4 additions & 26 deletions cmd/cloner/cloner.go
Expand Up @@ -17,8 +17,6 @@ import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/token"
"go/types"
"log"
"os"
Expand Down Expand Up @@ -62,33 +60,13 @@ func main() {
pkg := pkgs[0]
buf := new(bytes.Buffer)
imports := make(map[string]struct{})
namedTypes := codegen.NamedTypes(pkg)
for _, typeName := range typeNames {
found := false
for _, file := range pkg.Syntax {
for _, d := range file.Decls {
decl, ok := d.(*ast.GenDecl)
if !ok || decl.Tok != token.TYPE {
continue
}
for _, s := range decl.Specs {
spec, ok := s.(*ast.TypeSpec)
if !ok || spec.Name.Name != typeName {
continue
}
typeNameObj := pkg.TypesInfo.Defs[spec.Name]
typ, ok := typeNameObj.Type().(*types.Named)
if !ok {
continue
}
pkg := typeNameObj.Pkg()
gen(buf, imports, typ, pkg)
found = true
}
}
}
if !found {
typ, ok := namedTypes[typeName]
if !ok {
log.Fatalf("could not find type %s", typeName)
}
gen(buf, imports, typ, pkg.Types)
}

w := func(format string, args ...interface{}) {
Expand Down
30 changes: 30 additions & 0 deletions util/codegen/codegen.go
Expand Up @@ -8,9 +8,13 @@ package codegen
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"os"

"golang.org/x/tools/go/packages"
)

// WriteFormatted writes code to path.
Expand Down Expand Up @@ -41,6 +45,32 @@ func WriteFormatted(code []byte, path string) error {
return nil
}

// NamedTypes returns all named types in pkg, keyed by their type name.
func NamedTypes(pkg *packages.Package) map[string]*types.Named {
nt := make(map[string]*types.Named)
for _, file := range pkg.Syntax {
for _, d := range file.Decls {
decl, ok := d.(*ast.GenDecl)
if !ok || decl.Tok != token.TYPE {
continue
}
for _, s := range decl.Specs {
spec, ok := s.(*ast.TypeSpec)
if !ok {
continue
}
typeNameObj := pkg.TypesInfo.Defs[spec.Name]
typ, ok := typeNameObj.Type().(*types.Named)
if !ok {
continue
}
nt[spec.Name.Name] = typ
}
}
}
return nt
}

// AssertStructUnchanged generates code that asserts at compile time that type t is unchanged.
// tname is the named type corresponding to t.
// ctx is a single-word context for this assertion, such as "Clone".
Expand Down

0 comments on commit d8a8f70

Please sign in to comment.