forked from vulcand/oxy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
connlimit.go
162 lines (137 loc) · 3.92 KB
/
connlimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Package connlimit provides control over the number of simultaneous connections coming from the same source.
package connlimit
import (
"fmt"
"net/http"
"sync"
"abstraction.fr/oxy/v2/utils"
)
// ConnLimiter is an http.Handler that tracks the number of concurrent
// connections per source token. A request whose token has already reached
// maxConnections is not forwarded to the next handler; it is passed to the
// error handler instead.
type ConnLimiter struct {
// mutex is locked by acquire and release while they update connections and
// totalConnections.
mutex *sync.Mutex
// extract yields the token and the amount to account for a request.
extract utils.SourceExtractor
// connections maps a token to the amount currently held for it; release
// deletes the entry once it drops back to zero.
connections map[string]int64
// maxConnections is the per-token threshold checked by acquire.
maxConnections int64
// totalConnections is the sum of the amounts held across all tokens.
totalConnections int64
// next is the handler invoked when the request is admitted.
next http.Handler
// errHandler answers requests that fail extraction or are rejected.
errHandler utils.ErrorHandler
// log and debug are handed to the default ConnErrHandler built by New; log
// is also used directly by ServeHTTP.
log utils.Logger
debug utils.LoggerDebugFunc
}
// New builds a ConnLimiter that forwards admitted requests to next.
//
// extract is mandatory and is used to derive the token and amount of every
// request; maxConnections is the per-token threshold. The options are applied
// in order, and the first one that fails aborts construction with its error.
// If no option installed an error handler, a ConnErrHandler sharing the
// limiter's logger and debug function is used.
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
	if extract == nil {
		return nil, fmt.Errorf("Extract function can not be nil")
	}

	limiter := new(ConnLimiter)
	limiter.mutex = new(sync.Mutex)
	limiter.extract = extract
	limiter.maxConnections = maxConnections
	limiter.connections = map[string]int64{}
	limiter.next = next
	limiter.log = &utils.DefaultLogger{}
	limiter.debug = utils.DefaultLoggerDebugFunc

	for i := range options {
		if err := options[i](limiter); err != nil {
			return nil, err
		}
	}

	// The default handler is created after the options ran so that it picks
	// up any logger or debug function they configured.
	if limiter.errHandler != nil {
		return limiter, nil
	}
	limiter.errHandler = &ConnErrHandler{log: limiter.log, debug: limiter.debug}
	return limiter, nil
}
// Logger returns an option that makes the connection limiter log through l.
// The option never fails.
func Logger(l utils.Logger) ConnLimitOption {
	setLogger := func(limiter *ConnLimiter) error {
		limiter.log = l
		return nil
	}
	return setLogger
}
// Debug returns an option that installs d as the function deciding whether
// debug logs are generated. Whether they are actually printed still depends on
// the configured logger. The option never fails.
func Debug(d utils.LoggerDebugFunc) ConnLimitOption {
	setDebug := func(limiter *ConnLimiter) error {
		limiter.debug = d
		return nil
	}
	return setDebug
}
// Wrap replaces the handler that the connection limiter forwards admitted
// requests to.
func (cl *ConnLimiter) Wrap(h http.Handler) {
	cl.next = h
}
// ServeHTTP admits or rejects the request according to its source token.
//
// The token and amount are obtained from the extractor. If extraction fails,
// or the token is already at its limit, the error handler answers the request
// and the next handler is not called. Otherwise the amount is reserved for the
// duration of the next handler's ServeHTTP call and released when it returns.
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	source, weight, extractErr := cl.extract.Extract(r)
	if extractErr != nil {
		cl.log.Errorf("failed to extract source of the connection: %v", extractErr)
		cl.errHandler.ServeHTTP(w, r, extractErr)
		return
	}

	limitErr := cl.acquire(source, weight)
	if limitErr != nil {
		cl.log.Debugf("limiting request source %s: %v", source, limitErr)
		cl.errHandler.ServeHTTP(w, r, limitErr)
		return
	}
	// Deferred so the reservation is returned even if the next handler panics.
	defer cl.release(source, weight)

	cl.next.ServeHTTP(w, r)
}
// acquire reserves amount connections for token under the limiter's mutex.
// It returns a *MaxConnError, without reserving anything, when the token
// already holds maxConnections or more.
//
// NOTE(review): the check looks only at the current count, not at
// current+amount, so a single request with amount > 1 can push a token past
// maxConnections — confirm this is intended.
func (cl *ConnLimiter) acquire(token string, amount int64) error {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	if current := cl.connections[token]; current >= cl.maxConnections {
		return &MaxConnError{max: cl.maxConnections}
	}

	cl.totalConnections += amount
	cl.connections[token] += amount
	return nil
}
// release gives back amount connections previously reserved for token, under
// the limiter's mutex.
func (cl *ConnLimiter) release(token string, amount int64) {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	cl.totalConnections -= amount

	remaining := cl.connections[token] - amount
	if remaining == 0 {
		// Drop idle tokens, otherwise the map would grow forever.
		delete(cl.connections, token)
		return
	}
	cl.connections[token] = remaining
}
// MaxConnError is the error reported when a source has reached the maximum
// number of connections; max records the limit that was hit.
type MaxConnError struct {
	max int64
}

// Error implements the error interface and includes the limit in the message.
func (m *MaxConnError) Error() string {
	const format = "max connections reached: %d"
	return fmt.Sprintf(format, m.max)
}
// ConnErrHandler is the error handler New installs when none is supplied. It
// answers a *MaxConnError with HTTP status 429 and hands every other error to
// utils.DefaultHandler.
type ConnErrHandler struct {
// log receives the debug traces written while an error is handled.
log utils.Logger
// debug is called on every error to decide whether the request is dumped to
// the debug log.
debug utils.LoggerDebugFunc
}
// ServeHTTP writes the response for a request the limiter could not serve.
//
// A *MaxConnError is answered with 429 Too Many Requests and the error text as
// the body; any other error is delegated to utils.DefaultHandler. When debug
// logging is enabled the request dump is logged on entry and again on exit.
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
	if e.debug() {
		dump := utils.DumpHttpRequest(req)
		e.log.Debugf("connlimit: begin ServeHttp on request: %s", dump)
		defer e.log.Debugf("connlimit: completed ServeHttp on request: %s", dump)
	}

	if _, ok := err.(*MaxConnError); ok {
		w.WriteHeader(http.StatusTooManyRequests)
		// WriteHeader has already been called, so a failed body write cannot
		// change the response any more; record it instead of dropping it.
		if _, writeErr := w.Write([]byte(err.Error())); writeErr != nil {
			e.log.Debugf("connlimit: failed to write response body: %v", writeErr)
		}
		return
	}

	utils.DefaultHandler.ServeHTTP(w, req, err)
}
// ConnLimitOption configures a ConnLimiter while New is building it. New
// applies the options in order and returns the error of the first one that
// fails.
type ConnLimitOption func(l *ConnLimiter) error
// ErrorHandler returns an option that makes the connection limiter answer
// failed or rejected requests with h instead of the default ConnErrHandler.
// The option never fails.
func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
	setHandler := func(limiter *ConnLimiter) error {
		limiter.errHandler = h
		return nil
	}
	return setHandler
}