/
astutil.go
114 lines (93 loc) · 2.41 KB
/
astutil.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package astutil
import (
"errors"
"fmt"
"go/ast"
"strings"
"golang.org/x/tools/go/packages"
)
// buildTagPrefix marks a legacy build-constraint comment line (`// +build ...`).
const buildTagPrefix = "+build"
// PackageImports maps an import path to the name of the package found at that
// path, e.g. "github.com/pkg/errors" -> "errors".
type PackageImports map[string]string
// UsesImport reports whether f uses its import of importPath.
//
// packageImports maps import paths to package names (as produced by
// LoadPackageDependencies). It supplies the identifier through which an
// unnamed import is referenced. An explicit alias is used as-is.
//
// The import counts as used when f contains a selector expression X.Sel whose
// X is one of the identifiers bound to importPath and was not resolved by the
// parser to a local declaration (i.e. X is not shadowed).
//
// A blank ("_") or dot (".") import of importPath is always reported as used,
// because such imports are never referenced through a qualifier and their use
// cannot be detected here. If f does not import importPath, or the name of an
// unnamed import is unknown to packageImports, UsesImport returns false.
func UsesImport(f *ast.File, packageImports PackageImports, importPath string) bool {
	// Identifiers through which importPath can be referenced in f. Go allows
	// importing the same path more than once under different names, so collect
	// every one of them.
	names := make(map[string]struct{})
	for _, spec := range f.Imports {
		// Import paths cannot contain quote characters, so trimming both quote
		// styles handles interpreted ("...") and raw (`...`) literals alike.
		if strings.Trim(spec.Path.Value, "\"`") != importPath {
			continue
		}
		if spec.Name == nil {
			if pkgName := packageImports[importPath]; pkgName != "" {
				names[pkgName] = struct{}{}
			}
			continue
		}
		switch spec.Name.Name {
		case "_", ".":
			// Not referenced through a qualifier: err on the side of caution.
			return true
		default:
			names[spec.Name.Name] = struct{}{}
		}
	}
	if len(names) == 0 {
		return false
	}

	used := false
	ast.Inspect(f, func(node ast.Node) bool {
		if used {
			return false // a use was already found; stop descending
		}
		sel, ok := node.(*ast.SelectorExpr)
		if !ok {
			return true
		}
		ident, ok := sel.X.(*ast.Ident)
		if !ok {
			return true
		}
		// A non-nil Obj means the parser resolved the identifier to a local
		// declaration that shadows the package name. NOTE: this relies on the
		// parser's object resolution; files parsed with
		// parser.SkipObjectResolution have Obj == nil everywhere.
		if _, ok := names[ident.Name]; ok && ident.Obj == nil {
			used = true
		}
		return true
	})
	return used
}
// LoadPackageDependencies loads the packages in dir, including their test
// variants. It returns the direct imports of every loaded package, each
// mapped to the imported package's name, e.g.
// "github.com/pkg/errors" -> "errors".
//
// A non-empty buildTag is forwarded to the build system as -tags=<buildTag>.
// If any package in the loaded import graph reports errors, they are printed
// to stderr (via packages.PrintErrors) and an error is returned.
func LoadPackageDependencies(dir, buildTag string) (PackageImports, error) {
	cfg := &packages.Config{
		Dir:   dir,
		Tests: true,
		// NOTE(review): the go/packages docs state that without NeedDeps the
		// Imports entries may be stubs with only ID set; confirm Name is
		// populated with the driver in use, or add packages.NeedDeps.
		Mode: packages.NeedName | packages.NeedImports,
	}
	if buildTag != "" {
		cfg.BuildFlags = []string{"-tags=" + buildTag}
	}

	pkgs, err := packages.Load(cfg)
	if err != nil {
		return PackageImports{}, fmt.Errorf("loading packages in %q: %w", dir, err)
	}
	if packages.PrintErrors(pkgs) > 0 {
		return PackageImports{}, errors.New("packages contain errors")
	}

	result := PackageImports{}
	for _, p := range pkgs {
		// p.Imports is keyed by the import path as written in p's sources.
		for importPath, imported := range p.Imports {
			result[importPath] = imported.Name
		}
	}
	return result, nil
}
// ParseBuildTag returns the constraint expression of the first legacy
// `// +build ...` line found in the comments that precede f's package clause,
// e.g. "linux,amd64" for `// +build linux,amd64`. It returns "" when there is
// no such line.
//
// f must have been parsed with parser.ParseComments, otherwise f.Comments is
// empty. `//go:build` lines are not recognized, because
// ast.CommentGroup.Text drops //go: directives.
func ParseBuildTag(f *ast.File) string {
	for _, group := range f.Comments {
		// Build constraints only take effect before the package clause. Any
		// comment after it belongs to the file body.
		if f.Package.IsValid() && group.Pos() > f.Package {
			break
		}
		for _, line := range strings.Split(group.Text(), "\n") {
			line = strings.TrimSpace(line)
			if !strings.HasPrefix(line, buildTagPrefix) {
				continue
			}
			// TrimPrefix removes the literal "+build" once. strings.Trim would
			// treat it as a character set and eat tag characters, e.g. the
			// trailing "d" and "i" of "android".
			rest := strings.TrimPrefix(line, buildTagPrefix)
			if rest != "" && rest[0] != ' ' && rest[0] != '\t' {
				continue // e.g. "+builder": not a build constraint
			}
			return strings.TrimSpace(rest)
		}
	}
	return ""
}
// visitFn lets an ordinary function serve as an ast.Visitor. ast.Walk calls it
// for every node it visits, and once more with a nil node after a node's
// children have been walked.
type visitFn func(node ast.Node)

// Compile-time check that visitFn satisfies ast.Visitor.
var _ ast.Visitor = visitFn(nil)

// Visit hands node to the wrapped function and returns the same visitor, so
// the walk always continues into node's children.
func (fn visitFn) Visit(node ast.Node) ast.Visitor {
	fn(node)
	return fn
}