/
tag_error_types.go
396 lines (333 loc) · 12.4 KB
/
tag_error_types.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
package analysis
import (
"fmt"
"go/ast"
"go/types"
"sort"
"strings"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
)
// ErrorType is a fact about a ree.Error type,
// declaring which error codes Code() might return,
// and/or what field gets returned by a call to Code().
type ErrorType struct {
	Codes []string        // error codes, or nil
	Field *ErrorCodeField // field information, or nil
}

// ErrorCodeField is part of ErrorType,
// and declares the field that might be returned by the Code() method of the ree.Error.
type ErrorCodeField struct {
	Name     string // name of the struct field returned by Code()
	Position int    // zero-based position of the field in the struct declaration
}

// AFact marks ErrorType as an analysis fact, so it can be exported for and
// imported from type objects.
func (*ErrorType) AFact() {}

// String renders the fact with its error codes in sorted order.
//
// The codes are sorted on a copy: formatting a fact (e.g. in test output or
// diagnostics) must never mutate the fact itself.
func (e *ErrorType) String() string {
	codes := append([]string(nil), e.Codes...)
	sort.Strings(codes)
	return fmt.Sprintf("ErrorType{Field:%v, Codes:%v}", e.Field, strings.Join(codes, " "))
}

// String renders the field as {Name:"...", Position:N}.
// A nil field renders as "<nil>", matching what fmt prints for nil pointers.
func (f *ErrorCodeField) String() string {
	if f == nil {
		return "<nil>"
	}
	return fmt.Sprintf("{Name:%q, Position:%d}", f.Name, f.Position)
}
// findAndTagErrorTypes visits every type declaration of the package and,
// for each declared type that implements tReeError (either itself or through
// its pointer type), tries to export an ErrorType fact via tagErrorType.
// If tagging fails, the error is reported as a diagnostic on the declaration.
func findAndTagErrorTypes(pass *analysis.Pass, lookup *funcLookup) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)

	// Type specs only ever appear inside GenDecls, so nothing else is requested.
	onlyGenDecls := []ast.Node{(*ast.GenDecl)(nil)}

	insp.Nodes(onlyGenDecls, func(node ast.Node, _ bool) bool {
		decl := node.(*ast.GenDecl)
		for _, spec := range decl.Specs {
			typeSpec, isTypeSpec := spec.(*ast.TypeSpec)
			if !isTypeSpec {
				continue
			}

			candidate := pass.TypesInfo.Defs[typeSpec.Name].Type()
			// Prefer T itself; fall back to *T; skip the type if neither implements the interface.
			if !types.Implements(candidate, tReeError) {
				candidate = types.NewPointer(candidate)
				if !types.Implements(candidate, tReeError) {
					continue
				}
			}

			if tagErr := tagErrorType(pass, lookup, candidate, typeSpec); tagErr != nil {
				pass.ReportRangef(node, "%v", tagErr)
			}
		}
		// Nothing below a GenDecl is of interest: do not descend.
		return false
	})
}
// tagErrorType analyses the Code() method of the error type err (declared by
// spec) and, if the analysis yields at least one error code or a usable error
// code field, exports the resulting ErrorType fact on the type's object.
//
// Interface types are accepted silently without exporting anything; only
// concrete implementations are tagged. A non-nil error describes why the
// type could not be tagged.
func tagErrorType(pass *analysis.Pass, lookup *funcLookup, err types.Type, spec *ast.TypeSpec) error {
	named := getNamedType(err)
	if named == nil {
		logf("err type: %#v\n", err)
		return fmt.Errorf("type is an invalid error type")
	}
	obj := named.Obj()

	if _, isInterface := named.Underlying().(*types.Interface); isInterface {
		return nil
	}

	codeFunc, recv := getCodeFuncFromError(pass, lookup, err)
	if codeFunc == nil {
		return fmt.Errorf(`found no method "Code() string"`)
	}

	fact := analyseCodeMethod(pass, spec, codeFunc, recv)
	if fact == nil {
		return fmt.Errorf("type %q is an invalid error type: could not find any error codes", obj.Name())
	}

	// Codes assigned to the error code field by other methods are merged into the fact before export.
	analyseMethodsOfErrorType(pass, lookup, fact, err)
	pass.ExportObjectFact(obj, fact)
	return nil
}
// getErrorTypeForError returns the ErrorType fact that was exported for the
// object of the named type underlying err (for example by tagErrorType).
//
// Nothing is computed or cached here: if no fact is available for the type,
// the result is (nil, nil), so callers must treat a nil *ErrorType without an
// error as "not a tagged error type". An error is returned only when
// getNamedType finds no named type for err.
func getErrorTypeForError(pass *analysis.Pass, err types.Type) (*ErrorType, error) {
	named := getNamedType(err)
	if named == nil {
		logf("err type: %#v\n", err)
		return nil, fmt.Errorf("passed invalid err type to getErrorTypeForError")
	}
	var fact ErrorType
	if !pass.ImportObjectFact(named.Obj(), &fact) {
		return nil, nil
	}
	return &fact, nil
}
// getCodeFuncFromError looks up the declaration of the "Code() string" method
// whose receiver type matches the given error type (see errorTypesSubset).
//
// The second result is the receiver identifier of that method, or nil if the
// receiver is unnamed. Both results are nil when no matching method exists.
func getCodeFuncFromError(pass *analysis.Pass, lookup *funcLookup, err types.Type) (result *ast.FuncDecl, receiver *ast.Ident) {
	candidates, found := lookup.methods["Code"]
	if !found {
		return nil, nil
	}
	// Several types may declare a Code method; the first one whose receiver fits err wins.
	for _, candidate := range candidates {
		// A method declaration always has exactly one receiver field.
		recvField := candidate.Recv.List[0]
		if !errorTypesSubset(pass.TypesInfo.TypeOf(recvField.Type), err) {
			continue
		}
		result = candidate
		if len(recvField.Names) == 1 {
			receiver = recvField.Names[0]
		}
		return result, receiver
	}
	return nil, nil
}
// errorTypesSubset reports whether type1 is identical to type2 or, when type2
// is a pointer type, identical to the element type that type2 points to.
func errorTypesSubset(type1, type2 types.Type) bool {
	if types.Identical(type1, type2) {
		return true
	}
	ptr, isPointer := type2.(*types.Pointer)
	if !isPointer {
		return false
	}
	return types.Identical(type1, ptr.Elem())
}
// codeMethodAnalysis holds the state of analysing a single Code() method:
// the method itself, bookkeeping shared with the taint analysis, and the
// results collected from the method's return statements.
type codeMethodAnalysis struct {
	// Input.
	pass     *analysis.Pass
	funcDecl *ast.FuncDecl // the Code() method under analysis
	receiver *ast.Ident    // receiver identifier of funcDecl, or nil if unnamed

	// Passed to every taintSpreadForIdentOfImmutableType call; presumably
	// tracks objects already followed so they are not revisited.
	visited map[*ast.Object]struct{}

	// Output.
	codes          CodeSet    // constant error codes returned by the method
	errorCodeField *ast.Ident // receiver field returned by the method, or nil
}
// analyseCodeMethod inspects the Code() method funcDecl of the error type
// declared by spec and derives its ErrorType.
//
// Every return statement of the method is classified. Return statements
// inside nested function literals are ignored.
//   - A constant string value is an error code. Several return statements may
//     return different codes. Only constant expressions are considered, since
//     anything else would be hard to impossible to track. Empty strings are
//     allowed, but are not recorded as a code.
//   - A single struct field of the receiver is the error code field. Its
//     position is needed to track creation with a positional constructor, its
//     name to track keyed creation and assignments to the field.
//   - Everything else is reported as invalid via diagnostics.
//
// It returns nil when neither a code nor a usable field was found. Invalid
// return statements have already been reported at that point.
func analyseCodeMethod(pass *analysis.Pass, spec *ast.TypeSpec, funcDecl *ast.FuncDecl, receiver *ast.Ident) *ErrorType {
	st := codeMethodAnalysis{
		pass:     pass,
		funcDecl: funcDecl,
		receiver: receiver,
		visited:  map[*ast.Object]struct{}{},
		codes:    Set(),
	}

	ast.Inspect(funcDecl, func(n ast.Node) bool {
		if _, isLit := n.(*ast.FuncLit); isLit {
			// Returns inside nested function literals do not return from Code().
			return false
		}
		ret, isReturn := n.(*ast.ReturnStmt)
		if !isReturn {
			return true
		}
		switch len(ret.Results) {
		case 0: // bare return: the result must be named
			st.analyseNamedReturn()
		case 1:
			st.analyseReturnedExpression(ret.Results[0])
		default:
			panic("should be unreachable: we already know that the method returns a single value. Return statements that don't do so should lead to a compile time error.")
		}
		return true
	})

	var field *ErrorCodeField
	if name := st.errorCodeField; name != nil {
		pos := getFieldPosition(spec, name)
		if pos < 0 {
			pass.Reportf(funcDecl.Pos(), "returned field %q is not a valid error code field (promoted fields are not supported currently, but might be added in the future)", name)
		} else {
			field = &ErrorCodeField{Name: name.Name, Position: pos}
		}
	}

	if field == nil && len(st.codes) == 0 {
		return nil
	}
	return &ErrorType{Codes: st.codes.Slice(), Field: field}
}
// analyseReturnedExpression classifies one expression returned by the Code()
// method (directly, or reached through the taint analysis of a returned
// variable) and records the outcome in state:
//   - a constant becomes an error code if getErrorCodeFromConstant accepts it
//     (the empty string is ignored); otherwise its error is reported;
//   - a selector on the method's receiver (recv.field) becomes the error code
//     field, and all such selectors must name the same field;
//   - a plain identifier is followed via analyseReturnedIdentTaint;
//   - anything else is reported as invalid.
//
// Whether a selector's operand is the receiver is decided by the type
// checker's object identity (types.Info Defs/Uses). The deprecated
// ast.Ident.Obj links are nil for every identifier when files are parsed with
// parser.SkipObjectResolution, and comparing them would then make any x.f look
// like a receiver field.
func (state *codeMethodAnalysis) analyseReturnedExpression(node ast.Expr) {
	pass := state.pass
	returnResult := astutil.Unparen(node)

	// If the return statement returns a constant string value:
	// check if it is a valid error code and if so add it to the error code constants.
	if value := pass.TypesInfo.Types[returnResult].Value; value != nil {
		code, err := getErrorCodeFromConstant(value)
		switch {
		case err != nil:
			pass.ReportRangef(node, "%v", err)
		case code != "": // Ignore empty string result of Code method.
			state.codes.Add(code)
		}
		return
	}

	// Otherwise check if a single field of the receiver is returned.
	// Make sure that always the same field is returned and otherwise emit a diagnostic.
	if sel, ok := returnResult.(*ast.SelectorExpr); ok && state.receiver != nil {
		recvObj := pass.TypesInfo.Defs[state.receiver]
		operand, isIdent := astutil.Unparen(sel.X).(*ast.Ident)
		if isIdent && recvObj != nil && pass.TypesInfo.Uses[operand] == recvObj {
			if state.errorCodeField == nil {
				state.errorCodeField = sel.Sel
			} else if state.errorCodeField.Name != sel.Sel.Name {
				pass.ReportRangef(node, "only single field allowed: cannot return field %q because field %q was returned previously", sel.Sel.Name, state.errorCodeField.Name)
			}
			return
		}
	}

	// If an ident is returned: analyse the ident taint.
	// This also checks that the ident is allowed to be returned (i.e. that it is local).
	if ident, ok := returnResult.(*ast.Ident); ok {
		state.analyseReturnedIdentTaint(ident)
		return
	}

	pass.ReportRangef(node, `function %q should always return a string constant or a single field`, state.funcDecl.Name.Name)
}
// analyseNamedReturn handles a bare "return" in the Code() method. The value
// returned is then the method's single named result, which is analysed like
// any other returned identifier.
func (state *codeMethodAnalysis) analyseNamedReturn() {
	results := state.funcDecl.Type.Results
	if results == nil || len(results.List) != 1 {
		panic("should be unreachable: we already know that the method returns a single value.")
	}
	names := results.List[0].Names
	if len(names) != 1 {
		panic("should be unreachable: we already know that the method returns a single named value. (Encountered empty return, so returned value must be named.)")
	}
	state.analyseReturnedIdentTaint(names[0])
}
// analyseReturnedIdentTaint follows a returned identifier through the Code()
// method via taintSpreadForIdentOfImmutableType and validates every source it
// reports. Identifiers that are out of scope (parameters, receivers, globals)
// and values obtained by assigning a function call's result are reported.
// Every other source expression is analysed like a directly returned
// expression.
func (state *codeMethodAnalysis) analyseReturnedIdentTaint(ident *ast.Ident) {
	pass := state.pass
	spread := taintSpreadForIdentOfImmutableType(pass, state.visited, ident, &funcDefinition{state.funcDecl, nil})

	for _, outOfScope := range spread.identOutOfScope {
		pass.ReportRangef(outOfScope, "error code variable may not be a parameter, receiver or global variable")
	}
	for _, assign := range spread.destructAssignment {
		pass.ReportRangef(assign.source, "unsupported: assigning result of function call to variable %q is not allowed", assign.target.Name)
	}
	for _, source := range spread.expressions {
		state.analyseReturnedExpression(source)
	}
}
// getFieldPosition returns the zero-based position of the field named
// fieldName in the struct type declared by errorTypeSpec.
//
// Every declared field name counts as one position, and so does every
// embedded field (which has no name of its own and is never matched). This is
// the index an unkeyed (positional) composite literal would use.
//
// It returns -1 when the declared type is not a struct type literal, when the
// struct has no field list, or when no directly declared field has that name.
// Promoted fields are not found.
func getFieldPosition(errorTypeSpec *ast.TypeSpec, fieldName *ast.Ident) int {
	structType, ok := errorTypeSpec.Type.(*ast.StructType)
	if !ok || structType.Fields == nil {
		return -1
	}
	position := 0
	for _, field := range structType.Fields.List {
		if len(field.Names) == 0 {
			// Embedded field: occupies one position, but has no explicit name to match.
			position++
			continue
		}
		for _, name := range field.Names {
			if name.Name == fieldName.Name {
				return position
			}
			position++
		}
	}
	return -1
}
// analyseMethodsOfErrorType looks at the methods of the given error type
// (as returned by collectMethodsForErrorType) and makes sure there are no
// invalid assignments to the error code field. Error codes found in
// assignments to the field are merged into errorType.Codes.
// Nothing is done when the error type has no error code field.
func analyseMethodsOfErrorType(pass *analysis.Pass, lookup *funcLookup, errorType *ErrorType, err types.Type) {
	if errorType.Field == nil {
		return
	}

	found := Set()
	for _, method := range collectMethodsForErrorType(pass, lookup, err) {
		// Assigning to the field requires referring to the receiver,
		// so methods with an unnamed receiver are skipped.
		recvNames := method.Recv.List[0].Names
		if len(recvNames) != 1 {
			continue
		}
		assigned := findCodesAssignedToErrorCodeField(pass, &funcDefinition{method, nil}, errorType, recvNames[0].Obj)
		found = Union(found, assigned)
	}

	if len(found) == 0 {
		return
	}
	errorType.Codes = Union(SliceToSet(errorType.Codes), found).Slice()
}
// collectMethodsForErrorType finds all methods defined for the given error
// type in the current package, with both value and pointer receivers.
// It returns nil if getNamedType finds no named type for err.
//
// Receiver types are compared after stripping one level of pointer
// indirection on both sides. For a value error type E, the pointer-receiver
// methods of *E are therefore included. Those are the methods that can
// actually modify E's fields, including the error code field.
func collectMethodsForErrorType(pass *analysis.Pass, lookup *funcLookup, err types.Type) []*ast.FuncDecl {
	namedErr := getNamedType(err)
	if namedErr == nil {
		return nil
	}

	// base strips a single pointer so that T and *T share the same base type.
	base := func(t types.Type) types.Type {
		if ptr, ok := t.(*types.Pointer); ok {
			return ptr.Elem()
		}
		return t
	}
	errBase := base(err)

	// Only consider method names that were discovered by the type checker.
	result := make([]*ast.FuncDecl, 0, namedErr.NumMethods())
	for i := 0; i < namedErr.NumMethods(); i++ {
		for _, funcDecl := range lookup.methods[namedErr.Method(i).Name()] {
			// A method declaration always has exactly one receiver field.
			recvType := pass.TypesInfo.TypeOf(funcDecl.Recv.List[0].Type)
			if types.Identical(base(recvType), errBase) {
				result = append(result, funcDecl)
			}
		}
	}
	return result
}