forked from 99designs/gqlgen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
complexity.go
104 lines (92 loc) · 3.16 KB
/
complexity.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package complexity
import (
"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/ast"
)
// Calculate scores the operation op against the executable schema es and
// returns its total complexity. vars supplies the values of any variables
// referenced by field arguments. Each top-level selection of op is scored
// recursively and the per-selection scores are combined with saturating
// addition, so the total never wraps around on overflow.
func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
	w := complexityWalker{es: es, vars: vars}
	w.schema = es.Schema()
	return w.selectionSetComplexity(op.SelectionSet)
}
// complexityWalker carries the read-only state needed while recursively
// scoring the selection sets of a single operation.
type complexityWalker struct {
	// es is consulted for per-field custom complexity via its Complexity method.
	es graphql.ExecutableSchema
	// schema is es.Schema(), used for type lookups and for finding the
	// possible types of an interface.
	schema *ast.Schema
	// vars holds the operation's variables, used to resolve field argument values.
	vars map[string]interface{}
}
// selectionSetComplexity returns the summed complexity of every selection in
// selectionSet.
//
// A field is scored with fieldComplexity, or with interfaceFieldComplexity when
// it is selected on an interface. The score of the field's own sub-selection is
// passed in as childComplexity. Fragment spreads and inline fragments contribute
// the complexity of their selection sets. Any other selection kind contributes
// nothing. All additions saturate via safeAdd.
func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
	total := 0
	for _, sel := range selectionSet {
		var cost int
		switch node := sel.(type) {
		case *ast.Field:
			returnType := cw.schema.Types[node.Definition.Type.Name()]

			// Only object, interface and union results carry a sub-selection to
			// score; every other kind contributes no child complexity.
			children := 0
			if k := returnType.Kind; k == ast.Object || k == ast.Interface || k == ast.Union {
				children = cw.selectionSetComplexity(node.SelectionSet)
			}

			args := node.ArgumentMap(cw.vars)
			parent := node.ObjectDefinition
			if parent.Kind == ast.Interface {
				cost = cw.interfaceFieldComplexity(parent, node.Name, children, args)
			} else {
				cost = cw.fieldComplexity(parent.Name, node.Name, children, args)
			}
		case *ast.FragmentSpread:
			cost = cw.selectionSetComplexity(node.Definition.SelectionSet)
		case *ast.InlineFragment:
			cost = cw.selectionSetComplexity(node.SelectionSet)
		default:
			continue
		}
		total = safeAdd(total, cost)
	}
	return total
}
// interfaceFieldComplexity scores field when it is selected on the interface
// def.
//
// An interface has no field costs of its own, so the worst case is assumed.
// The field is scored against every possible type of def with fieldComplexity,
// and the largest score is returned. If def has no possible types, the result
// is 0.
func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
	worst := 0
	for _, impl := range cw.schema.GetPossibleTypes(def) {
		if c := cw.fieldComplexity(impl.Name, field, childComplexity, args); c > worst {
			worst = c
		}
	}
	return worst
}
// fieldComplexity scores field on the type named object, given the already
// computed complexity of the field's sub-selection.
//
// A custom complexity reported by the executable schema is used only when one
// exists and it is at least childComplexity. This stops a custom function from
// reporting less than the cost of the sub-selection. Otherwise the default of
// 1 + childComplexity (saturating) is returned.
func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
	custom, ok := cw.es.Complexity(object, field, childComplexity, args)
	if !ok || custom < childComplexity {
		// Default cost: the field itself counts as one, plus its children.
		return safeAdd(1, childComplexity)
	}
	return custom
}
// maxInt is the largest value representable by int on the target platform:
// half of the all-ones uint, rounded down.
const maxInt = int(^uint(0) / 2)

// safeAdd adds two complexity values with saturation, disregarding negative
// operands.
//
// Negative operands are treated as absent:
//   - If exactly one of a and b is negative, the other is returned unchanged.
//   - If both are negative, the result is 1.
//
// When the sum of two non-negative operands would not fit in an int, maxInt is
// returned instead of a wrapped-around value.
//
// Saturating rather than wrapping stops an attacker from overflowing the running
// total to make an overly complex query look cheap. Disregarding negative
// operands limits the damage from custom complexities that accidentally return
// negative values.
func safeAdd(a, b int) int {
	switch {
	case a < 0 && b < 0:
		return 1
	case a < 0:
		return b
	case b < 0:
		return a
	case a > maxInt-b:
		// a + b would exceed the int range; clamp to the largest value instead.
		return maxInt
	default:
		return a + b
	}
}