-
Notifications
You must be signed in to change notification settings - Fork 0
/
typefuncs.go
109 lines (101 loc) · 2.68 KB
/
typefuncs.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package dispel
import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"os"
	"sort"
)
// walker lets a plain func(ast.Node) bool be used as an ast.Visitor.
// The function is invoked for each node ast.Walk reaches; its result decides
// whether the walk descends into that node's children (true) or prunes them
// (false). Note that ast.Walk also calls Visit(nil) after a node's children
// have been walked, so the function may receive a nil node.
type walker func(ast.Node) bool

// Visit implements ast.Visitor: it hands node to the wrapped function and
// keeps using the same walker for the children unless the function said no.
func (w walker) Visit(node ast.Node) ast.Visitor {
	if !w(node) {
		return nil
	}
	return w
}
// FindTypesFuncs parses the Go files of package pkgName found in directory
// path (skipping any file whose name is listed in excludeFiles) and collects
// the methods declared on the types listed in typeNames.
//
// Value and pointer receivers are both matched, as are parenthesized receiver
// types and generic receivers with a single type parameter (T[P], *T[P]).
// Receivers with several type parameters (T[K, V]) are not unwrapped and so
// never match.
//
// It returns a map of method names -> *ast.FuncDecl. The map is keyed by the
// method name alone: if several listed types declare a method with the same
// name, only one entry survives, namely the one declared last when files are
// visited in file-name order. The error is the parser's error, or a
// "package not found" error when path holds no package named pkgName.
func FindTypesFuncs(path string, pkgName string, typeNames []string, excludeFiles []string) (map[string]*ast.FuncDecl, error) {
	files, err := parsePackageFiles(path, pkgName, excludeFiles)
	if err != nil {
		return nil, err
	}
	wanted := make(map[string]bool, len(typeNames))
	for _, name := range typeNames {
		wanted[name] = true
	}
	funcDecls := make(map[string]*ast.FuncDecl)
	for _, file := range files {
		// Function and method declarations only ever appear at the top level
		// of a file, so there is no need to walk into function bodies.
		for _, decl := range file.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || fn.Recv == nil || len(fn.Recv.List) == 0 {
				continue // plain function (or malformed receiver list)
			}
			recvName, ok := receiverTypeName(fn.Recv.List[0].Type)
			if ok && wanted[recvName] {
				funcDecls[fn.Name.Name] = fn
			}
		}
	}
	return funcDecls, nil
}

// FindTypes parses the Go files of package pkgName found in directory path,
// skipping any file whose name is listed in excludeFiles, and collects the
// package-level type declarations, both single `type T ...` declarations and
// grouped `type ( ... )` blocks. Types declared inside function bodies are
// not package-level and are ignored, so they cannot shadow a package type of
// the same name in the result.
//
// It returns a map of type names -> *ast.TypeSpec. The error is the parser's
// error, or a "package not found" error when path holds no package named
// pkgName.
func FindTypes(path string, pkgName string, excludeFiles []string) (map[string]*ast.TypeSpec, error) {
	files, err := parsePackageFiles(path, pkgName, excludeFiles)
	if err != nil {
		return nil, err
	}
	typeSpecs := make(map[string]*ast.TypeSpec)
	for _, file := range files {
		for _, decl := range file.Decls {
			gen, ok := decl.(*ast.GenDecl)
			if !ok || gen.Tok != token.TYPE {
				continue
			}
			for _, spec := range gen.Specs {
				if ts, ok := spec.(*ast.TypeSpec); ok {
					typeSpecs[ts.Name.Name] = ts
				}
			}
		}
	}
	return typeSpecs, nil
}

// parsePackageFiles parses the Go source files in directory path, skipping
// every file whose name appears in excludeFiles, and returns the files that
// belong to package pkgName, sorted by file name. The fixed order keeps
// callers that fill maps from these files deterministic when keys collide.
// Parsing uses parser.DeclarationErrors, so declaration errors are reported.
func parsePackageFiles(path string, pkgName string, excludeFiles []string) ([]*ast.File, error) {
	excluded := make(map[string]bool, len(excludeFiles))
	for _, name := range excludeFiles {
		excluded[name] = true
	}
	fset := token.NewFileSet()
	pkgs, err := parser.ParseDir(fset, path, func(fi os.FileInfo) bool {
		return !excluded[fi.Name()]
	}, parser.DeclarationErrors)
	if err != nil {
		return nil, err
	}
	pkg, ok := pkgs[pkgName]
	if !ok {
		return nil, fmt.Errorf("%s: package not found in %q", pkgName, path)
	}
	names := make([]string, 0, len(pkg.Files))
	for name := range pkg.Files {
		names = append(names, name)
	}
	sort.Strings(names)
	files := make([]*ast.File, 0, len(names))
	for _, name := range names {
		files = append(files, pkg.Files[name])
	}
	return files, nil
}

// receiverTypeName returns the name of the named type behind a method
// receiver type expression. It looks through pointers (*T), parentheses ((T))
// and single type-argument instantiations (T[P]), and reports false for any
// other expression shape. Multi-argument instantiations (ast.IndexListExpr,
// Go 1.18+) are deliberately not referenced so the file still builds with
// older toolchains.
func receiverTypeName(expr ast.Expr) (string, bool) {
	for {
		switch t := expr.(type) {
		case *ast.Ident:
			return t.Name, true
		case *ast.StarExpr:
			expr = t.X
		case *ast.ParenExpr:
			expr = t.X
		case *ast.IndexExpr:
			expr = t.X
		default:
			return "", false
		}
	}
}