// Forked from vulcand/oxy (connlimit.go).

// Package connlimit provides control over simultaneous connections coming from the same source.
package connlimit
import (
"fmt"
"net/http"
"sync"
"abstraction.fr/oxy/v2/utils"
)
// ConnLimiter is an http.Handler that tracks in-flight requests per source
// token and rejects a request once that token's count has reached
// maxConnections.
type ConnLimiter struct {
	// next is the handler invoked for requests that pass the limit check.
	next http.Handler
	// extract derives the source token and its amount from a request.
	extract utils.SourceExtractor
	// errHandler writes the response for extraction failures and rejected requests.
	errHandler utils.ErrorHandler
	// log receives the error and debug messages emitted while serving.
	log utils.Logger

	// mutex guards connections and totalConnections.
	mutex *sync.Mutex
	// connections maps a source token to the amount it currently holds.
	connections map[string]int64
	// totalConnections is the sum of the amounts currently held by all tokens.
	totalConnections int64
	// maxConnections is the per-token count at which new requests are rejected.
	maxConnections int64
}
// New builds a ConnLimiter that forwards admitted requests to next.
//
// extract identifies the source of each request and must not be nil.
// maxConnections is the per-source count at which further requests are
// rejected. options are applied in order after the defaults are set; the first
// option that returns an error aborts construction and that error is returned.
// When no option installs an error handler, a ConnErrHandler sharing the
// limiter's logger is used.
func New(next http.Handler, extract utils.SourceExtractor, maxConnections int64, options ...ConnLimitOption) (*ConnLimiter, error) {
	if extract == nil {
		return nil, fmt.Errorf("Extract function can not be nil")
	}

	limiter := new(ConnLimiter)
	limiter.mutex = new(sync.Mutex)
	limiter.connections = map[string]int64{}
	limiter.extract = extract
	limiter.maxConnections = maxConnections
	limiter.next = next
	limiter.log = &utils.DefaultLogger{}

	for i := range options {
		if err := options[i](limiter); err != nil {
			return nil, err
		}
	}

	if limiter.errHandler != nil {
		return limiter, nil
	}

	// The default handler is created after the options ran so that it picks
	// up a logger installed by one of them.
	limiter.errHandler = &ConnErrHandler{log: limiter.log}
	return limiter, nil
}
// Logger returns an option that makes the connection limiter log through l
// instead of the default logger installed by New.
func Logger(l utils.Logger) ConnLimitOption {
	var useLogger ConnLimitOption = func(cl *ConnLimiter) error {
		cl.log = l
		return nil
	}
	return useLogger
}
// Wrap sets h as the next handler, i.e. the handler that the connection
// limiter calls for every request it admits.
func (cl *ConnLimiter) Wrap(h http.Handler) {
	cl.next = h
}
// ServeHTTP admits or rejects r based on the current count for its source.
//
// The source token and amount come from the configured extractor. If
// extraction fails, or the token has already reached its limit, the error is
// logged and handed to the error handler, and the next handler is not called.
// Otherwise the amount is held while the next handler runs and is released
// when it returns.
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	token, amount, extractErr := cl.extract.Extract(r)
	if extractErr != nil {
		cl.log.Errorf("failed to extract source of the connection: %v", extractErr)
		cl.errHandler.ServeHTTP(w, r, extractErr)
		return
	}

	limitErr := cl.acquire(token, amount)
	if limitErr != nil {
		cl.log.Debugf("limiting request source %s: %v", token, limitErr)
		cl.errHandler.ServeHTTP(w, r, limitErr)
		return
	}

	// Deferred so the reservation is given back however the next handler exits.
	defer cl.release(token, amount)
	cl.next.ServeHTTP(w, r)
}
// acquire reserves amount for token under the limiter's mutex. It returns a
// *MaxConnError, and reserves nothing, when the token's current count has
// already reached maxConnections.
//
// The comparison uses the count before amount is added, so an admitted request
// with amount > 1 can leave the token above maxConnections.
func (cl *ConnLimiter) acquire(token string, amount int64) error {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	if current := cl.connections[token]; current < cl.maxConnections {
		cl.connections[token] = current + amount
		cl.totalConnections += amount
		return nil
	}
	return &MaxConnError{max: cl.maxConnections}
}
// release gives back amount for token under the limiter's mutex and removes
// the token's map entry once its count is back to zero.
func (cl *ConnLimiter) release(token string, amount int64) {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	cl.totalConnections -= amount

	remaining := cl.connections[token] - amount
	if remaining == 0 {
		// Drop idle tokens; otherwise the map would grow forever.
		delete(cl.connections, token)
		return
	}
	cl.connections[token] = remaining
}
// MaxConnError is the error reported when a source has reached the maximum
// number of connections.
type MaxConnError struct {
	// max is the connection limit that was hit.
	max int64
}

// Error implements the error interface; the message includes the limit.
func (m *MaxConnError) Error() string {
	// Sprint inserts no separator here because one operand is a string.
	return fmt.Sprint("max connections reached: ", m.max)
}
// ConnErrHandler is the error handler a ConnLimiter uses when none is
// configured through an option.
type ConnErrHandler struct {
	// log is the logger handed over by New when it creates the handler.
	log utils.Logger
}
// ServeHTTP writes the response for err. A *MaxConnError is answered with
// status 429 (Too Many Requests) and the error message as the body; every
// other error is delegated to utils.DefaultHandler.
func (e *ConnErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
	switch err.(type) {
	case *MaxConnError:
		w.WriteHeader(http.StatusTooManyRequests)
		// The write result is intentionally ignored, as before.
		_, _ = w.Write([]byte(err.Error()))
	default:
		utils.DefaultHandler.ServeHTTP(w, req, err)
	}
}
// ConnLimitOption is a functional option applied to a ConnLimiter by New.
// Returning a non-nil error makes New fail with that error.
type ConnLimitOption func(l *ConnLimiter) error
// ErrorHandler returns an option that makes the limiter report extraction
// failures and rejected requests through h. If h is nil, New falls back to
// its default ConnErrHandler.
func ErrorHandler(h utils.ErrorHandler) ConnLimitOption {
	var useHandler ConnLimitOption = func(cl *ConnLimiter) error {
		cl.errHandler = h
		return nil
	}
	return useHandler
}