forked from vulcand/vulcand
/
connlimiter.go
94 lines (80 loc) · 2.51 KB
/
connlimiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// Package connlimit provides a limiter on simultaneous connections.
package connlimit
import (
"fmt"
"github.com/mailgun/vulcand/Godeps/_workspace/src/github.com/mailgun/vulcan/errors"
"github.com/mailgun/vulcand/Godeps/_workspace/src/github.com/mailgun/vulcan/limit"
"github.com/mailgun/vulcand/Godeps/_workspace/src/github.com/mailgun/vulcan/netutils"
"github.com/mailgun/vulcand/Godeps/_workspace/src/github.com/mailgun/vulcan/request"
"net/http"
"sync"
)
// ConnectionLimiter caps the number of simultaneous in-flight requests per
// token. A mapper function derives the token, and the amount each request
// counts for, from the request. A request whose token is already at the limit
// is rejected with a "too many requests" response instead of being admitted.
type ConnectionLimiter struct {
	// mutex guards the connection counters below.
	mutex *sync.Mutex
	// mapper extracts the token and the amount a request is charged.
	mapper limit.MapperFn
	// connections holds the outstanding amount per token; an entry is
	// removed once its count drops back to zero.
	connections map[string]int64
	// maxConnections is the per-token ceiling; totalConnections is the sum
	// of outstanding amounts across all tokens.
	maxConnections, totalConnections int64
}
// NewClientIpLimiter returns a ConnectionLimiter that uses limit.MapClientIp
// as its mapper. That presumably yields one token per client IP address,
// so each client may hold at most maxConnections simultaneous requests.
// It fails under the same conditions as NewConnectionLimiter.
func NewClientIpLimiter(maxConnections int64) (*ConnectionLimiter, error) {
	byClientIP := limit.MapClientIp
	return NewConnectionLimiter(byClientIP, maxConnections)
}
// NewConnectionLimiter returns a limiter that allows at most maxConnections
// simultaneous requests per token, where mapper derives the token and the
// amount each request counts for.
//
// It returns an error if mapper is nil or if maxConnections is not strictly
// positive.
func NewConnectionLimiter(mapper limit.MapperFn, maxConnections int64) (*ConnectionLimiter, error) {
	if mapper == nil {
		return nil, fmt.Errorf("Mapper function can not be nil")
	}
	if maxConnections <= 0 {
		// Zero is rejected as well, so the message must say "> 0" (it used to
		// claim ">= 0", contradicting this check).
		return nil, fmt.Errorf("Max connections should be > 0, got %d", maxConnections)
	}
	return &ConnectionLimiter{
		mutex:          &sync.Mutex{},
		mapper:         mapper,
		maxConnections: maxConnections,
		connections:    make(map[string]int64),
	}, nil
}
// ProcessRequest decides whether to admit r.
//
// The mapper's token and amount are computed under the lock. If the token's
// current count is already at or above the limit, a text response with status
// errors.StatusTooManyRequests is returned and nothing is charged. Otherwise
// the amount is added to both the token's count and the overall total, and
// (nil, nil) lets the request continue. A mapper error is returned unchanged.
func (cl *ConnectionLimiter) ProcessRequest(r request.Request) (*http.Response, error) {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	token, amount, err := cl.mapper(r)
	if err != nil {
		return nil, err
	}

	if current := cl.connections[token]; current >= cl.maxConnections {
		msg := fmt.Sprintf("Connection limit reached. Max is: %d, yours: %d", cl.maxConnections, current)
		return netutils.NewTextResponse(r.GetHttpRequest(), errors.StatusTooManyRequests, msg), nil
	}

	// amount is int64 already (it is added to a map[string]int64 value), so
	// no conversion is needed for the running total either.
	cl.connections[token] += amount
	cl.totalConnections += amount
	return nil, nil
}
// ProcessResponse releases the amount that ProcessRequest charged for r.
//
// The token and amount are recomputed with the mapper; if the mapper fails,
// the call is a no-op. The attempt argument is not used. A token whose count
// reaches exactly zero is removed from the map, so the map does not grow with
// every token ever seen.
// NOTE(review): this presumes it is only invoked for requests that
// ProcessRequest admitted; confirm against the middleware chain.
func (cl *ConnectionLimiter) ProcessResponse(r request.Request, a request.Attempt) {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()

	token, amount, err := cl.mapper(r)
	if err != nil {
		return
	}

	cl.totalConnections -= amount
	remaining := cl.connections[token] - amount
	if remaining == 0 {
		delete(cl.connections, token)
		return
	}
	cl.connections[token] = remaining
}
// GetConnectionCount reports the total outstanding amount across all tokens,
// read under the limiter's mutex.
func (cl *ConnectionLimiter) GetConnectionCount() int64 {
	cl.mutex.Lock()
	total := cl.totalConnections
	cl.mutex.Unlock()
	return total
}
// GetMaxConnections returns the current per-token connection limit.
//
// The read takes the mutex because SetMaxConnections may change the limit
// while requests are being processed.
func (cl *ConnectionLimiter) GetMaxConnections() int64 {
	cl.mutex.Lock()
	defer cl.mutex.Unlock()
	return cl.maxConnections
}

// SetMaxConnections replaces the per-token connection limit.
//
// The new limit applies to requests processed afterwards. Tokens already above
// it keep their in-flight requests, but new ones are rejected until the count
// drops below the limit. The value is not validated. Because counts start at
// zero, a non-positive limit causes every request to be rejected.
func (cl *ConnectionLimiter) SetMaxConnections(max int64) {
	// ProcessRequest reads maxConnections while holding the mutex, so the write
	// must hold it too; otherwise runtime reconfiguration is a data race.
	cl.mutex.Lock()
	defer cl.mutex.Unlock()
	cl.maxConnections = max
}