/
main.go
119 lines (100 loc) · 2.3 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package main
import (
"flag"
"go/ast"
"go/token"
"log"
"os"
"path"
"strings"
"text/template"
"golang.org/x/tools/go/packages"
)
var typeNames = flag.String("type", "", "comma-separated list of type names; must be set")
var tmpl = template.New("all")
// handleError terminates the program through log.Fatal when err is non-nil.
// A nil error is a no-op, so callers can pass any returned error straight in.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
// inspect is an ast.Inspect visitor that appends to *names, in declaration
// order, the names of constants whose type is the identifier typeName.
//
// Within a const declaration the current spec's type is determined as:
//   - "X T = ..."  — the explicit type, if it is a bare identifier;
//   - "X = T(v)"   — the function identifier of a call on the first value;
//   - "X"          — (no type, no values) the type carried over from the
//     previous spec, mirroring Go's implicit repetition in const blocks.
//
// Specs whose type cannot be resolved to a bare identifier (e.g. "X pkg.T",
// "X = 1") are skipped. Their unresolved type is carried forward so that
// later implicitly repeated specs are skipped too. Blank identifiers ("_")
// are never recorded.
//
// It returns false for const declarations, so their children are not
// visited, and true for every other node so the traversal continues.
func inspect(node ast.Node, typeName string, names *[]string) bool {
	decl, ok := node.(*ast.GenDecl)
	if !ok || decl.Tok != token.CONST {
		return true
	}
	// typ is the type name of the current spec. It persists across iterations
	// so a spec with neither type nor values inherits the previous one.
	typ := ""
	for _, spec := range decl.Specs {
		// Guaranteed to succeed: every spec of a CONST GenDecl is a ValueSpec.
		vspec := spec.(*ast.ValueSpec)
		switch {
		case vspec.Type != nil:
			// "X T": the type is written out.
			ident, ok := vspec.Type.(*ast.Ident)
			if !ok {
				// Not a bare identifier (e.g. a qualified pkg.T). Reset typ so
				// following implicit specs don't inherit an earlier, stale type.
				typ = ""
				continue
			}
			typ = ident.Name
		case len(vspec.Values) > 0:
			// "X = 1" with no type but a value: infer from a conversion-style call.
			typ = ""
			ce, ok := vspec.Values[0].(*ast.CallExpr)
			if !ok {
				continue
			}
			id, ok := ce.Fun.(*ast.Ident)
			if !ok {
				continue
			}
			typ = id.Name
		}
		if typ != typeName {
			// This is not the type we're looking for.
			continue
		}
		for _, name := range vspec.Names {
			if name.Name == "_" {
				continue
			}
			*names = append(*names, name.Name)
		}
	}
	return false
}
// loadPackage loads the package for the file that triggered go:generate.
//
// The file name is read from $GOFILE and joined onto the current working
// directory. The result is passed to packages.Load as the only pattern, with
// just package names and parsed syntax trees requested (no type checking).
// Any load error, or a result that is not exactly one package, terminates
// the program via log.Fatal.
//
// NOTE(review): a bare .go file path as the pattern may yield an ad-hoc
// package containing only that file rather than its whole directory; confirm
// whether sibling files need to be loaded (e.g. a "file=" query).
func loadPackage() *packages.Package {
	wd, err := os.Getwd()
	handleError(err)
	// Named target rather than path so the imported path package isn't shadowed.
	target := path.Join(wd, os.Getenv("GOFILE"))

	cfg := &packages.Config{
		Mode:  packages.NeedName | packages.NeedSyntax,
		Tests: false,
	}
	pkgs, err := packages.Load(cfg, target)
	handleError(err)

	if n := len(pkgs); n != 1 {
		log.Fatalf("error: %d packages found", n)
	}
	return pkgs[0]
}
// main is the enumall entry point.
//
// It parses the required -type flag (comma-separated type names) and loads
// the package via loadPackage. For each requested type, it gathers the
// matching constant names from every loaded file with inspect, then calls
// the generator exactly once with the complete list.
func main() {
	flag.Parse()
	log.SetFlags(0)
	log.SetPrefix("enumall: ")

	// -type is required. Trim spaces around commas ("A, B") and drop empty
	// entries; an empty name can never correspond to a real type.
	var types []string
	for _, name := range strings.Split(*typeNames, ",") {
		if name = strings.TrimSpace(name); name != "" {
			types = append(types, name)
		}
	}
	if len(types) == 0 {
		flag.Usage()
		os.Exit(2)
	}

	pkg := loadPackage()
	if len(pkg.Syntax) == 0 {
		log.Fatal("error: no Go files loaded")
	}

	// Iterate types in the outer loop so values from all files are collected
	// before generating. This produces one generate call per type instead of
	// one per (file, type) pair with partial values.
	for _, typeName := range types {
		gen := generator{
			PackageName: pkg.Name,
			TypeName:    typeName,
		}
		for _, file := range pkg.Syntax {
			ast.Inspect(file, func(n ast.Node) bool {
				return inspect(n, typeName, &gen.Values)
			})
		}
		gen.generate()
	}
}