/
imptree.go
167 lines (141 loc) · 4.37 KB
/
imptree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// Copyright 2022 tobbstr. All rights reserved.
// Use of this source code is governed by a MIT-
// license that can be found in the LICENSE file.
package imptree
import (
"fmt"
"os"
"golang.org/x/tools/go/packages"
)
// MatchPkg is a predicate that decides which Go packages are included in the
// tree. The tree builder calls it for the loaded root package and for every
// imported package it encounters; a package for which it returns false is left
// out, and the builder does not descend into that package's imports.
type MatchPkg func(*packages.Package) bool
// Node represents a Go source package in an import tree. Links are kept in
// both directions. A package imported by several others can have several
// Parents, so in general the structure is a directed graph rather than a
// strict tree.
type Node struct {
	// Children are links to the packages this package imports.
	Children []*Node
	// Parents are links to the packages that import this package.
	Parents []*Node
	// PkgPath is the import path of this package.
	PkgPath string
}

// Remove detaches node from every node reachable from n through Children
// links, n included. For each parent p found holding node as a child, node is
// deleted from p.Children and p is deleted from node.Parents, so both link
// directions stay consistent. node's own Children are left untouched, so its
// subtree stays attached to it. Remove is a no-op when n or node is nil.
func (n *Node) Remove(node *Node) {
	removeNodeRecursively(n, node)
}

// removeNodeRecursively implements Remove: it unlinks removableNode from n and
// from every node below n.
func removeNodeRecursively(n *Node, removableNode *Node) {
	if n == nil || removableNode == nil {
		return
	}
	removeNodeBelow(n, removableNode, make(map[*Node]bool))
}

// removeNodeBelow unlinks removable from n and recurses into n's remaining
// children. visited holds the nodes already processed. A node shared by
// several parents is reachable along several paths, and walking it once per
// path can grow exponentially with the depth of the graph (and would never
// terminate on a cyclic graph).
func removeNodeBelow(n, removable *Node, visited map[*Node]bool) {
	if visited[n] {
		return
	}
	visited[n] = true

	var removed bool
	n.Children, removed = deleteNodeFromSlice(n.Children, removable)
	if removed {
		// Drop the back-pointer as well so removable no longer claims n as
		// an importer.
		removable.Parents, _ = deleteNodeFromSlice(removable.Parents, n)
	}
	for _, child := range n.Children {
		removeNodeBelow(child, removable, visited)
	}
}

// deleteNodeFromSlice removes every occurrence of node from slc in place and
// reports whether any occurrence was found. The vacated tail of the backing
// array is cleared so it does not keep removed nodes reachable.
func deleteNodeFromSlice(slc []*Node, node *Node) ([]*Node, bool) {
	kept := slc[:0]
	for _, candidate := range slc {
		if candidate != node {
			kept = append(kept, candidate)
		}
	}
	for i := len(kept); i < len(slc); i++ {
		slc[i] = nil
	}
	return kept, len(kept) != len(slc)
}
// Builder constructs import trees from loaded Go packages. Create one with
// NewBuilder: the zero value has no node map and no loader hooks and is not
// usable. Instantiate a new Builder for each tree rather than reusing one.
type Builder struct {
	// nodes holds every Node created so far, keyed by import path.
	nodes map[string]*Node
	// loadPkgs loads the requested packages. It is a field so tests can
	// replace the real loader; see
	// https://pkg.go.dev/golang.org/x/tools/go/packages#Load for the
	// production implementation.
	loadPkgs func(cfg *packages.Config, patterns ...string) ([]*packages.Package, error)
	// printLoadPkgsErrors reports the errors recorded on loaded packages and
	// returns their count. It is a field so tests can replace the real
	// reporter; see https://pkg.go.dev/golang.org/x/tools/go/packages#PrintErrors
	// for the production implementation.
	printLoadPkgsErrors func(pkgs []*packages.Package) int
}

// NewBuilder returns a ready-to-use Builder. It has an empty node map, and its
// hooks are wired to packages.Load and packages.PrintErrors.
func NewBuilder() *Builder {
	b := new(Builder)
	b.nodes = map[string]*Node{}
	b.loadPkgs = packages.Load
	b.printLoadPkgsErrors = packages.PrintErrors
	return b
}
// Build loads the package named by importPath and returns the root of a
// doubly-linked tree of its imports. Each Node's Children are the packages it
// imports, and its Parents are the packages that import it. The root is the
// node of the loaded package itself. Only packages matched by matchPkg are
// included in the tree.
//
// Each call starts from an empty set of nodes, so nodes from an earlier Build
// on the same Builder are never mixed into the result.
//
// An error is returned in any of these cases:
//   - loading fails;
//   - loading reports package errors (they are printed via
//     printLoadPkgsErrors);
//   - loading does not yield exactly one package;
//   - matchPkg rejects the loaded package itself.
//
// Example:
//
//	root, err := builder.Build("github.com/johndoe/example/cmd/acme", func(pkg *packages.Package) bool {
//		// include only packages belonging to the same module
//		return strings.HasPrefix(pkg.PkgPath, "github.com/johndoe/example")
//	})
func (b *Builder) Build(importPath string, matchPkg MatchPkg) (*Node, error) {
	cfg := &packages.Config{
		// NeedName fills in PkgPath, and NeedImports fills in Imports. Per
		// the go/packages LoadMode documentation, the imported packages are
		// only placeholders carrying an ID unless NeedDeps is also set.
		// Without NeedDeps they would lack the PkgPath and Imports that the
		// tree walk relies on.
		Mode: packages.NeedName | packages.NeedImports | packages.NeedDeps,
		// Bypass default vendor mode, as we need a package not available in
		// the std module vendor folder.
		Env: append(os.Environ(), "GOFLAGS=-mod=mod"),
	}
	pkgs, err := b.loadPkgs(cfg, importPath)
	if err != nil {
		return nil, err
	}
	if b.printLoadPkgsErrors(pkgs) > 0 || len(pkgs) != 1 {
		return nil, fmt.Errorf("failed to load source package")
	}
	pkg := pkgs[0]

	b.nodes = make(map[string]*Node)
	b.buildTree(pkg, matchPkg)

	// The root is the loaded package's own node. Looking it up directly is
	// deterministic. Walking Parents[0] upward from an arbitrary entry of the
	// randomly ordered nodes map could end at any parentless node.
	root, ok := b.nodes[pkg.PkgPath]
	if !ok {
		return nil, fmt.Errorf("could not find tree root node")
	}
	return root, nil
}
// buildTree adds pkg and, transitively, every package it imports to b.nodes.
// It links each importer and imported package in both directions. Packages
// for which matchPkg returns false are left out, and so is anything reachable
// only through them.
//
// Nodes are keyed by packages.Package.PkgPath everywhere. The keys of
// pkg.Imports are import paths as written in source. For vendored packages,
// for example, these can differ from the imported package's PkgPath, and
// keying by them would create a second, unlinked node for such a package.
func (b *Builder) buildTree(pkg *packages.Package, matchPkg MatchPkg) {
	if !matchPkg(pkg) {
		return
	}
	b.linkImports(b.getOrCreateNode(pkg.PkgPath), pkg, matchPkg, make(map[*Node]bool))
}

// getOrCreateNode returns the node registered in b.nodes under pkgPath. If no
// such node exists, it creates and registers one first.
func (b *Builder) getOrCreateNode(pkgPath string) *Node {
	if n, ok := b.nodes[pkgPath]; ok {
		return n
	}
	n := &Node{PkgPath: pkgPath}
	b.nodes[pkgPath] = n
	return n
}

// linkImports links node to the nodes of pkg's matching imports and recurses
// into each of them. expanded records nodes whose imports have already been
// linked. Without it, a package imported from many places would be re-walked
// once per path leading to it, which can grow exponentially with the depth
// of the import graph.
func (b *Builder) linkImports(node *Node, pkg *packages.Package, matchPkg MatchPkg, expanded map[*Node]bool) {
	if expanded[node] {
		return
	}
	expanded[node] = true

	for _, childPkg := range pkg.Imports {
		if !matchPkg(childPkg) {
			continue
		}
		childNode := b.getOrCreateNode(childPkg.PkgPath)
		// Guard against duplicate links if the same edge is seen again.
		if !containsNode(childNode.Parents, node) {
			childNode.Parents = append(childNode.Parents, node)
		}
		if !containsNode(node.Children, childNode) {
			node.Children = append(node.Children, childNode)
		}
		b.linkImports(childNode, childPkg, matchPkg, expanded)
	}
}
// containsNode reports whether node occurs in slc, comparing by pointer
// identity. A nil node is never reported as present, even if slc holds nil
// entries.
func containsNode(slc []*Node, node *Node) bool {
	if node == nil {
		return false
	}
	for _, candidate := range slc {
		if candidate == node {
			return true
		}
	}
	return false
}