-
Notifications
You must be signed in to change notification settings - Fork 23
/
copy.go
183 lines (151 loc) · 5.1 KB
/
copy.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
package internal
import (
"bytes"
"fmt"
"github.com/markbates/pkger"
"github.com/thoas/go-funk"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"golang.org/x/tools/go/ast/astutil"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"strings"
)
// CopyModule copies every go file (source and test files) of the package at
// toCopy into dest. Each copy is renamed by getFileName and has its package
// clause rewritten to packageName.
//
// toCopy is resolved with pkger.Info and walked with pkger.Walk. Only files whose
// base name appears in the package's GoFiles or TestGoFiles are copied.
func CopyModule(toCopy, dest, packageName string) error {
	// walk through the dir, see: https://github.com/markbates/pkger/blob/09e9684b656b/examples/app/main.go#L29
	pkgInfo, err := pkger.Info(toCopy)
	if err != nil {
		return fmt.Errorf("could not resolve %s: %w", toCopy, err)
	}

	// Collect the go file names into a fresh slice. Appending straight onto
	// pkgInfo.GoFiles could write into that slice's spare capacity and mutate
	// data owned by pkger.
	goFiles := make([]string, 0, len(pkgInfo.GoFiles)+len(pkgInfo.TestGoFiles))
	goFiles = append(goFiles, pkgInfo.GoFiles...)
	goFiles = append(goFiles, pkgInfo.TestGoFiles...)

	err = pkger.Walk(toCopy, func(filePath string, fi fs.FileInfo, err error) error {
		if err != nil {
			return fmt.Errorf("error while walking: %w", err)
		}
		// Skip anything that isn't one of the package's go files, including a
		// directory that merely shares a go file's name.
		if fi.IsDir() || !funk.ContainsString(goFiles, fi.Name()) {
			return nil
		}
		return copyGoFile(filePath, packageName, dest, fi)
	})
	if err != nil {
		return fmt.Errorf("error while copying: %w", err)
	}
	return nil
}
// copyGoFile writes a rewritten copy of the go file at filePath into dest.
//
// The contents come from getUpdatedFileContents (generation header plus the
// package clause renamed to packageName). The output file name is derived from
// info.Name() via getFileName. An existing file at that path is truncated.
func copyGoFile(filePath, packageName, dest string, info fs.FileInfo) error {
	fileContents, err := getUpdatedFileContents(filePath, packageName)
	if err != nil {
		return fmt.Errorf("could not get updated file contents: %w", err)
	}

	newFile := filepath.Join(dest, getFileName(info.Name()))

	//nolint: gosec
	f, err := os.Create(newFile)
	if err != nil {
		return fmt.Errorf("could not open file %s: %w", newFile, err)
	}

	// write the contents to the file
	if _, err = f.Write(fileContents); err != nil {
		// Release the handle on failure too; the write error is the one worth
		// reporting, so the close error is intentionally ignored here.
		_ = f.Close()
		return fmt.Errorf("could not write to file: %w", err)
	}

	if err = f.Close(); err != nil {
		return fmt.Errorf("could not close file: %w", err)
	}
	return nil
}
// CopyFile copies a single go file. This will not bring dependencies.
//
// fileToCopy is a slash-separated pkger path whose last element must be a .go
// file. Its parent directory is walked with pkger, and a file with that base
// name is written to dest. The copy is renamed by getFileName and has its
// package clause rewritten to packageName.
//
// An error is returned if no file with that name is found.
func CopyFile(fileToCopy, dest, packageName string) error {
	// first things first, pkger operates on go modules, so we need to trim the
	// file name off to get the directory to walk
	modulePath := path.Dir(fileToCopy)
	fileName := path.Base(fileToCopy)

	// make sure the last element is a file
	if ext := filepath.Ext(fileName); ext != ".go" {
		return fmt.Errorf("must specify a .go file after module, got %s", ext)
	}

	found := false
	err := pkger.Walk(modulePath, func(filePath string, info fs.FileInfo, err error) error {
		if err != nil {
			return fmt.Errorf("error while walking: %w", err)
		}
		// only copy the target file (never a directory that shares its name)
		if info.IsDir() || info.Name() != fileName {
			return nil
		}
		// NOTE(review): if pkger.Walk recurses into sub-directories, a
		// same-named file in a nested package would also match and overwrite
		// this copy in dest — confirm against pkger's Walk semantics.
		found = true
		return copyGoFile(filePath, packageName, dest, info)
	})
	if err != nil {
		return fmt.Errorf("error while copying: %w", err)
	}
	if !found {
		// without this, a typo in the file name would silently copy nothing
		return fmt.Errorf("could not find %s in %s", fileName, modulePath)
	}
	return nil
}
// getFileName returns the file name to use for a copied file.
//
// "_gen" is inserted before the extension for ordinary files
// ("foo.go" -> "foo_gen.go"). For test files it is inserted before "_test.go",
// so the copy is still recognized as a test file ("foo_test.go" ->
// "foo_gen_test.go").
//
// A file counts as a test only when its name ends in "_test" plus the
// extension, matching the go tool's rule. "_test" elsewhere in the name (e.g.
// "my_testing.go") is left alone.
func getFileName(originalName string) string {
	const testSuffix = "_test"

	ext := filepath.Ext(originalName)
	base := strings.TrimSuffix(originalName, ext)

	// move a trailing "_test" into the suffix so "_gen" lands before it
	if strings.HasSuffix(base, testSuffix) {
		base = strings.TrimSuffix(base, testSuffix)
		ext = testSuffix + ext
	}
	return base + "_gen" + ext
}
// getUpdatedFileContents reads the go file at filePath through pkger and returns
// its rewritten contents:
//   - the generation header from makeGeneratedHeader is prepended;
//   - the package clause is renamed to newPackageName;
//   - the result is run through format.Source.
//
// On error the returned slice is nil.
func getUpdatedFileContents(filePath, newPackageName string) ([]byte, error) {
	file, err := pkger.Open(filePath)
	if err != nil {
		return nil, fmt.Errorf("could not open file at %s: %w", filePath, err)
	}
	// The file is only read, so a close error carries nothing actionable.
	defer func() { _ = file.Close() }()

	original, err := io.ReadAll(file)
	if err != nil {
		return nil, fmt.Errorf("could not read file %s: %w", filePath, err)
	}

	// prepend the header to the file
	src := append([]byte(makeGeneratedHeader(filePath)+"\n\n"), original...)

	// rename the package by modifying the ast
	fset := token.NewFileSet()
	fileAst, err := parser.ParseFile(fset, filepath.Base(filePath), src, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("could not parse ast. This could indicate an invalid source file: %w", err)
	}

	newAst := astutil.Apply(fileAst, nil, func(cursor *astutil.Cursor) bool {
		ident, ok := cursor.Node().(*ast.Ident)
		// Only the package clause identifier (the File's Name field) is renamed.
		// Checking the parent explicitly avoids relying on it being the first
		// identifier visited.
		if !ok || cursor.Parent() != fileAst || cursor.Name() != "Name" {
			return true
		}
		cursor.Replace(&ast.Ident{
			NamePos: ident.NamePos,
			Name:    newPackageName,
			Obj:     ident.Obj,
		})
		// returning false aborts the traversal: nothing else needs rewriting
		return false
	})

	var buf bytes.Buffer
	if err := printer.Fprint(&buf, fset, newAst); err != nil {
		return nil, fmt.Errorf("could not write resulting ast: %w", err)
	}

	// TODO: use golangci-lint
	formatted, err := format.Source(buf.Bytes())
	if err != nil {
		return nil, fmt.Errorf("could not format: %w", err)
	}
	return formatted, nil
}
// makeGeneratedHeader returns the header comment prepended to every copied file,
// naming origin as the file it was copied from.
//
// note: this must conform to https://github.com/golangci/golangci-lint/blob/1fb67fe448da8a3fb525ecef28decceb23b42d7a/pkg/result/processors/autogenerated_exclude.go#L76
// to bypass linters.
func makeGeneratedHeader(origin string) string {
	// Keep the line ending in exactly "DO NOT EDIT." (no trailing characters)
	// so it reads as the conventional generated-code marker.
	return fmt.Sprintf("// Code copied from %s for testing by synapse modulecopier DO NOT EDIT.", origin)
}