forked from system-pclub/GCatch
/
fillreturns.go
259 lines (233 loc) · 7.04 KB
/
fillreturns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package fillreturns defines an Analyzer that will attempt to
// automatically fill in a return statement that has missing
// values with zero value elements.
package fillreturns
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/types"
"regexp"
"strconv"
"strings"
"github.com/system-pclub/GCatch/GCatch/tools/go/analysis"
"github.com/system-pclub/GCatch/GCatch/tools/go/ast/astutil"
"github.com/system-pclub/GCatch/GCatch/tools/internal/analysisinternal"
"github.com/system-pclub/GCatch/GCatch/tools/internal/typeparams"
)
// Doc is the documentation text for the fillreturns analyzer. It is the
// value assigned to the Doc field of Analyzer.
const Doc = `suggested fixes for "wrong number of return values (want %d, got %d)"
This checker provides suggested fixes for type errors of the
type "wrong number of return values (want %d, got %d)". For example:
func m() (int, string, *bool, error) {
return
}
will turn into
func m() (int, string, *bool, error) {
return 0, "", nil, nil
}
This functionality is similar to https://github.com/sqs/goreturns.
`
// Analyzer is the fillreturns analyzer. It declares no prerequisite
// analyzers and uses run as its entry point. RunDespiteErrors is set
// because run works from the type errors recorded for the package.
var Analyzer = &analysis.Analyzer{
Name: "fillreturns",
Doc: Doc,
Requires: []*analysis.Analyzer{},
Run: run,
RunDespiteErrors: true,
}
// wrongReturnNumRegex matches the type-error message
// "wrong number of return values (want N, got M)" and captures the two
// counts N and M as submatches 1 and 2.
var wrongReturnNumRegex = regexp.MustCompile(`wrong number of return values \(want (\d+), got (\d+)\)`)
// run walks the type errors recorded for the package and, for every
// "wrong number of return values" error that sits on a return statement,
// reports a diagnostic carrying a suggested fix that rewrites the statement
// so that it supplies one value per declared result.
//
// For each declared result type the fix reuses the leftmost returned value of
// a matching type (preferring a non-zero value over a zero one). When no
// returned value matches, it picks the in-scope identifier whose name best
// matches the type, or else the type's zero value. Returned values that were
// not consumed and are not zero values are appended after the declared
// results so that nothing the user wrote is silently dropped.
//
// Errors that cannot be located in a file, whose enclosing function cannot be
// determined, or whose type information is incomplete are skipped. A few
// conditions (no enclosing AST path, error not on a return statement, generic
// enclosing function, no enclosing function declaration, no zero value
// available) end the whole run without reporting further diagnostics.
//
// The result value is always nil. A non-nil error is returned when the pass
// has no type information, when a result type has no entry in the identifier
// match table, or when the rewritten statement cannot be formatted.
func run(pass *analysis.Pass) (interface{}, error) {
	info := pass.TypesInfo
	if info == nil {
		return nil, fmt.Errorf("nil TypeInfo")
	}

	errors := analysisinternal.GetTypeErrors(pass)
outer:
	for _, typeErr := range errors {
		// Filter out the errors that are not relevant to this analyzer.
		if !FixesError(typeErr.Msg) {
			continue
		}

		// Find the file whose extent contains the error position.
		var file *ast.File
		for _, f := range pass.Files {
			if f.Pos() <= typeErr.Pos && typeErr.Pos <= f.End() {
				file = f
				break
			}
		}
		if file == nil {
			continue
		}

		// Get the end position of the error.
		var buf bytes.Buffer
		if err := format.Node(&buf, pass.Fset, file); err != nil {
			continue
		}
		typeErrEndPos := analysisinternal.TypeErrorEndPos(pass.Fset, buf.Bytes(), typeErr.Pos)

		// Get the path for the relevant range.
		path, _ := astutil.PathEnclosingInterval(file, typeErr.Pos, typeErrEndPos)
		if len(path) == 0 {
			return nil, nil
		}

		// Check to make sure the node of interest is a ReturnStmt.
		ret, ok := path[0].(*ast.ReturnStmt)
		if !ok {
			return nil, nil
		}

		// Get the function type that encloses the ReturnStmt. The path runs
		// from the innermost node outwards, so the first function found is
		// the one the statement returns from.
		var enclosingFunc *ast.FuncType
		for _, n := range path {
			switch node := n.(type) {
			case *ast.FuncLit:
				enclosingFunc = node.Type
			case *ast.FuncDecl:
				enclosingFunc = node.Type
			}
			if enclosingFunc != nil {
				break
			}
		}
		// FuncType.Results is nil for a function that declares no results;
		// there is nothing to fill in then, and dereferencing it below
		// would panic.
		if enclosingFunc == nil || enclosingFunc.Results == nil {
			continue
		}

		// Skip any generic enclosing functions, since type parameters don't
		// have 0 values.
		// TODO(rstambler): We should be able to handle this if the return
		// values are all concrete types.
		if tparams := typeparams.ForFuncType(enclosingFunc); tparams != nil && tparams.NumFields() > 0 {
			return nil, nil
		}

		// Find the function declaration that encloses the ReturnStmt. Only
		// its presence matters; the declaration itself is not used further.
		var decl *ast.FuncDecl
		for _, p := range path {
			if fd, ok := p.(*ast.FuncDecl); ok {
				decl = fd
				break
			}
		}
		if decl == nil {
			return nil, nil
		}

		// Skip any return statements that contain function calls with multiple return values.
		for _, expr := range ret.Results {
			e, ok := expr.(*ast.CallExpr)
			if !ok {
				continue
			}
			if tup, ok := info.TypeOf(e).(*types.Tuple); ok && tup.Len() > 1 {
				continue outer
			}
		}

		// Duplicate the return values to track which values have been matched.
		remaining := make([]ast.Expr, len(ret.Results))
		copy(remaining, ret.Results)

		fixed := make([]ast.Expr, len(enclosingFunc.Results.List))

		// For each value in the return function declaration, find the leftmost element
		// in the return statement that has the desired type. If no such element exists,
		// fill in the missing value with the appropriate "zero" value.
		var retTyps []types.Type
		for _, field := range enclosingFunc.Results.List {
			retTyp := info.TypeOf(field.Type)
			if retTyp == nil {
				// The package has type errors, so type information may be
				// incomplete. Without the result type no fix can be built.
				continue outer
			}
			retTyps = append(retTyps, retTyp)
		}
		matches :=
			analysisinternal.FindMatchingIdents(retTyps, file, ret.Pos(), info, pass.Pkg)
		for i, retTyp := range retTyps {
			var match ast.Expr
			var idx int
			for j, val := range remaining {
				// A returned expression may have no recorded type when it
				// failed to type-check; it can never be matched.
				valTyp := info.TypeOf(val)
				if valTyp == nil || !matchingTypes(valTyp, retTyp) {
					continue
				}
				if !analysisinternal.IsZeroValue(val) {
					match, idx = val, j
					break
				}
				// If the current match is a "zero" value, we keep searching in
				// case we find a non-"zero" value match. If we do not find a
				// non-"zero" value, we will use the "zero" value.
				match, idx = val, j
			}

			if match != nil {
				fixed[i] = match
				remaining = append(remaining[:idx], remaining[idx+1:]...)
			} else {
				idents, ok := matches[retTyp]
				if !ok {
					return nil, fmt.Errorf("invalid return type: %v", retTyp)
				}
				// Find the identifier whose name is most similar to the return type.
				// If we do not find any identifier that matches the pattern,
				// generate a zero value.
				value := analysisinternal.FindBestMatch(retTyp.String(), idents)
				if value == nil {
					value = analysisinternal.ZeroValue(
						pass.Fset, file, pass.Pkg, retTyp)
				}
				if value == nil {
					return nil, nil
				}
				fixed[i] = value
			}
		}

		// Remove any non-matching "zero values" from the leftover values.
		var nonZeroRemaining []ast.Expr
		for _, expr := range remaining {
			if !analysisinternal.IsZeroValue(expr) {
				nonZeroRemaining = append(nonZeroRemaining, expr)
			}
		}
		// Append leftover return values to end of new return statement.
		fixed = append(fixed, nonZeroRemaining...)

		newRet := &ast.ReturnStmt{
			Return:  ret.Pos(),
			Results: fixed,
		}

		// Convert the new return statement AST to text.
		var newBuf bytes.Buffer
		if err := format.Node(&newBuf, pass.Fset, newRet); err != nil {
			return nil, err
		}

		// The diagnostic covers the extent of the type error, while the edit
		// replaces the whole return statement.
		pass.Report(analysis.Diagnostic{
			Pos:     typeErr.Pos,
			End:     typeErrEndPos,
			Message: typeErr.Msg,
			SuggestedFixes: []analysis.SuggestedFix{{
				Message: "Fill in return values",
				TextEdits: []analysis.TextEdit{{
					Pos:     ret.Pos(),
					End:     ret.End(),
					NewText: newBuf.Bytes(),
				}},
			}},
		})
	}
	return nil, nil
}
// matchingTypes reports whether a value of type want can be used where type
// got is expected. run calls it with the type of a returned expression as
// want and a declared result type as got.
//
// The types match when they are the same or identical, or when want is
// assignable or convertible to got. When want is an untyped basic type and
// got's underlying type is basic, the answer depends only on whether both
// fall into the same constant kind (boolean, numeric or string).
//
// If exactly one of the two types is nil the result is false; this can
// happen when type information is incomplete because the package has errors.
func matchingTypes(want, got types.Type) bool {
	if want == got {
		return true
	}
	// The go/types predicates used below are not safe to call with a nil
	// Type, so an unknown type never matches a known one.
	if want == nil || got == nil {
		return false
	}
	if types.Identical(want, got) {
		return true
	}
	// Code segment to help check for untyped equality from (golang/go#32146).
	if rhs, ok := want.(*types.Basic); ok && rhs.Info()&types.IsUntyped > 0 {
		if lhs, ok := got.Underlying().(*types.Basic); ok {
			return rhs.Info()&types.IsConstType == lhs.Info()&types.IsConstType
		}
	}
	return types.AssignableTo(want, got) || types.ConvertibleTo(want, got)
}
// FixesError reports whether msg is a type-error message this analyzer can
// offer a fix for: after trimming surrounding whitespace it must contain
// "wrong number of return values (want N, got M)" where both N and M parse
// as integers.
func FixesError(msg string) bool {
	groups := wrongReturnNumRegex.FindStringSubmatch(strings.TrimSpace(msg))
	if len(groups) < 3 {
		return false
	}
	// groups[1] is the wanted count and groups[2] the actual count.
	for _, count := range groups[1:3] {
		if _, err := strconv.Atoi(count); err != nil {
			return false
		}
	}
	return true
}