-
Notifications
You must be signed in to change notification settings - Fork 34
/
imports.go
122 lines (98 loc) · 2.2 KB
/
imports.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Package importers helps with dynamic imports for templating
package importers
import (
"bytes"
"fmt"
"regexp"
"sort"
"strings"
"sync"
"golang.org/x/tools/go/packages"
)
// Package-level state shared by the import helpers in this file.
//
//nolint:gochecknoglobals
var (
	// pkgRgx captures the text between the first pair of double quotes in an
	// import spec; submatch 1 is the import path.
	pkgRgx = regexp.MustCompile(`"([^"]+)"`)

	// standardPackages is the set of standard-library import paths. It starts
	// empty and is filled exactly once by getStandardPackages.
	standardPackages = map[string]struct{}{}

	// stdPkgOnce guards the one-time population of standardPackages.
	stdPkgOnce sync.Once
)
// getStandardPackages returns the set of standard-library import paths, as
// reported by golang.org/x/tools/go/packages for the "std" pattern.
//
// The set is loaded on the first call and cached in standardPackages; every
// call returns that same map, which callers should treat as read-only.
//
// It panics if the package list cannot be loaded, because the signature has no
// error return through which to report the failure.
func getStandardPackages() map[string]struct{} {
	stdPkgOnce.Do(func() {
		// Only import paths (PkgPath) are read below, so request just NeedName.
		// NOTE(review): per the go/packages docs, a nil Config / zero Mode falls
		// back to a broader legacy mode that also asks for files and compiled
		// files, which makes the underlying `go list` do needless work.
		cfg := &packages.Config{Mode: packages.NeedName}
		pkgs, err := packages.Load(cfg, "std")
		if err != nil {
			panic(fmt.Errorf("importers: loading standard library package list: %w", err))
		}
		for _, p := range pkgs {
			standardPackages[p.PkgPath] = struct{}{}
		}
	})
	return standardPackages
}
// List is a collection of Go import specs. Each entry is written verbatim as
// one line of an import declaration, e.g. `"fmt"`, `_ "embed"` or
// `foo "bar/foo"`.
type List []string

// Len implements sort.Interface.
func (l List) Len() int {
	return len(l)
}

// Swap implements sort.Interface.
func (l List) Swap(i, j int) {
	l[i], l[j] = l[j], l[i]
}

// Less implements sort.Interface.
//
// Specs are ordered by their text with leading underscores and spaces removed,
// so a blank import such as `_ "embed"` sorts among the plain imports by its
// path. Specs that are equal after trimming fall back to comparing their full
// text, which makes the ordering total: the sorted result does not depend on
// the order of the input.
//
// Less is strict — it reports false for equal elements — as sort.Interface
// requires.
func (l List) Less(i, j int) bool {
	a := strings.TrimLeft(l[i], "_ ")
	b := strings.TrimLeft(l[j], "_ ")
	if a != b {
		return a < b
	}
	// Tie-break on the untrimmed spec, e.g. `"fmt"` vs `_ "fmt"`.
	return l[i] < l[j]
}
// GetSorted splits l into standard-library specs and all other specs, and
// returns the two groups (in that order) each sorted with List's ordering.
//
// Blank entries are dropped. A spec is classified by the first double-quoted
// string it contains, so an aliased spec like `foo "bar/baz"` is looked up as
// "bar/baz"; a spec with no quoted string is looked up under "".
func (l List) GetSorted() (List, List) {
	var stdImports, otherImports List //nolint:prealloc
	for _, spec := range l {
		if spec == "" {
			continue
		}

		path := ""
		if match := pkgRgx.FindStringSubmatch(spec); match != nil {
			path = match[1]
		}

		if _, isStd := getStandardPackages()[path]; isStd {
			stdImports = append(stdImports, spec)
		} else {
			otherImports = append(otherImports, spec)
		}
	}

	// Sort both groups so that the rendered output is consistent.
	sort.Sort(stdImports)
	sort.Sort(otherImports)
	return stdImports, otherImports
}
// Format renders l as Go source for an import declaration, grouped the way
// goimports groups imports.
//
//   - An empty list, or a list whose only entry is blank, yields an empty slice.
//   - A single spec yields the one-line form `import <spec>` (no trailing
//     newline).
//   - Otherwise the result is a parenthesized block ending in "\n)\n":
//     standard-library specs first, then a blank line if both groups are
//     non-empty, then all other specs. Each group is sorted and blank entries
//     are skipped.
func (l List) Format() []byte {
	switch len(l) {
	case 0:
		return []byte{}
	case 1:
		if l[0] == "" {
			// Rendering a lone blank entry would produce the invalid
			// declaration "import "; drop it instead, consistent with how the
			// multi-spec path (via GetSorted) skips blank entries.
			return []byte{}
		}
		return []byte(fmt.Sprintf("import %s", l[0]))
	}

	standard, thirdparty := l.GetSorted()
	buf := &bytes.Buffer{}
	buf.WriteString("import (")
	for _, spec := range standard {
		fmt.Fprintf(buf, "\n\t%s", spec)
	}
	// Separate the standard-library group from the rest with a blank line.
	if len(standard) != 0 && len(thirdparty) != 0 {
		buf.WriteString("\n")
	}
	for _, spec := range thirdparty {
		fmt.Fprintf(buf, "\n\t%s", spec)
	}
	buf.WriteString("\n)\n")
	return buf.Bytes()
}
// combineStringSlices returns a newly allocated slice holding the elements of
// a followed by the elements of b. The result never aliases either input and
// is non-nil even when both inputs are empty.
func combineStringSlices(a, b []string) []string {
	joined := make([]string, 0, len(a)+len(b))
	joined = append(joined, a...)
	return append(joined, b...)
}