-
Notifications
You must be signed in to change notification settings - Fork 402
/
ratelimiter.go
152 lines (134 loc) · 4.05 KB
/
ratelimiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Copyright (C) 2019 Storj Labs, Inc.
// See LICENSE for copying information.
package web
import (
"context"
"net"
"net/http"
"strings"
"sync"
"time"
"golang.org/x/time/rate"
)
// IPRateLimiterConfig configures an IPRateLimiter.
//
// Duration is used in three places: it is the spacing between events once an
// IP's burst is used up (each per-IP limiter allows 1s/Duration events per
// second), the idle time after which a per-IP entry counts as expired, and the
// tick period of the cleanup loop in Run.
//
// NOTE(review): the help text for Duration calls it a "rate", but the value is
// an interval between allowed requests.
type IPRateLimiterConfig struct {
Duration time.Duration `help:"the rate at which request are allowed" default:"5m"`
// Burst is the burst size passed to each per-IP rate.Limiter, i.e. how many
// requests an IP may make back-to-back before limiting applies.
Burst int `help:"number of events before the limit kicks in" default:"5"`
// NumLimits caps how many per-IP entries are kept in memory; when the table
// is full, getUserLimit evicts entries before adding a new IP.
NumLimits int `help:"number of IPs whose rate limits we store" default:"1000"`
}
// IPRateLimiter imposes a rate limit per HTTP user IP.
//
// It keeps one limiter per client IP in ipLimits. Construct it with
// NewIPRateLimiter: the zero value has a nil ipLimits map, and inserting the
// first IP would panic.
type IPRateLimiter struct {
config IPRateLimiterConfig
// mu guards ipLimits and the lastSeen field of every entry in it.
mu sync.Mutex
// ipLimits maps a client IP string (as produced by getRequestIP) to its
// limiter and last-seen time.
ipLimits map[string]*userLimit
}
// userLimit is the per-IP limiter.
type userLimit struct {
// limiter enforces the rate for this IP. Limit calls its Allow method after
// rl.mu has been released.
limiter *rate.Limiter
// lastSeen is refreshed on every lookup of this IP; it drives expiry in
// cleanupLimiters and eviction in getUserLimit.
lastSeen time.Time
}
// NewIPRateLimiter returns a limiter that applies config and starts out
// tracking no IPs.
func NewIPRateLimiter(config IPRateLimiterConfig) *IPRateLimiter {
	rl := new(IPRateLimiter)
	rl.config = config
	rl.ipLimits = map[string]*userLimit{}
	return rl
}
// Run periodically removes rate-limit entries that have been idle for longer
// than config.Duration. It blocks until ctx is done.
//
// time.NewTicker panics on a non-positive interval. So when config.Duration is
// <= 0, no periodic cleanup is scheduled and Run just waits for ctx. (With a
// zero Duration the per-IP limiters allow unlimited requests anyway.) The
// table stays bounded by config.NumLimits through the eviction done when new
// IPs are added.
func (rl *IPRateLimiter) Run(ctx context.Context) {
	if rl.config.Duration <= 0 {
		<-ctx.Done()
		return
	}

	cleanupTicker := time.NewTicker(rl.config.Duration)
	defer cleanupTicker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-cleanupTicker.C:
			rl.cleanupLimiters()
		}
	}
}
// cleanupLimiters frees memory by dropping every entry whose last lookup is
// more than config.Duration ago.
func (rl *IPRateLimiter) cleanupLimiters() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	maxIdle := rl.config.Duration
	for key, entry := range rl.ipLimits {
		// Deleting the current key while ranging over a map is allowed in Go.
		if idle := time.Since(entry.lastSeen); idle > maxIdle {
			delete(rl.ipLimits, key)
		}
	}
}
// Limit wraps next so that every request is first charged against the limiter
// of the client IP that sent it.
//
// If the client IP cannot be determined, the request is answered with
// 500 Internal Server Error. If the IP is over its limit, the answer is
// 429 Too Many Requests. Otherwise the request is passed on to next.
func (rl *IPRateLimiter) Limit(next http.Handler) http.Handler {
	limited := func(w http.ResponseWriter, r *http.Request) {
		ip, err := getRequestIP(r)
		switch {
		case err != nil:
			status := http.StatusInternalServerError
			http.Error(w, http.StatusText(status), status)
		case !rl.getUserLimit(ip).Allow():
			status := http.StatusTooManyRequests
			http.Error(w, http.StatusText(status), status)
		default:
			next.ServeHTTP(w, r)
		}
	}
	return http.HandlerFunc(limited)
}
// getRequestIP returns the client IP address used as the rate-limiting key.
//
// Sources are tried in this order:
//   - the X-Real-IP header;
//   - the first (left-most) entry of X-Forwarded-For;
//   - the host part of r.RemoteAddr.
//
// Header values are trimmed of surrounding whitespace. Empty values are
// skipped so that a malformed header does not put many clients on the same ""
// key. X-Forwarded-For entries are comma-separated with optional spaces, so
// both "a,b" and "a, b" yield "a". The only error returned is from parsing
// r.RemoteAddr.
//
// NOTE(review): both headers are set by the client unless a trusted reverse
// proxy overwrites them. A client reaching this server directly can forge them
// to evade per-IP limits. Confirm the deployment sits behind such a proxy.
func getRequestIP(r *http.Request) (ip string, err error) {
	if realIP := strings.TrimSpace(r.Header.Get("X-REAL-IP")); realIP != "" {
		return realIP, nil
	}

	if forwarded := r.Header.Get("X-FORWARDED-FOR"); forwarded != "" {
		first := forwarded
		if comma := strings.IndexByte(forwarded, ','); comma >= 0 {
			first = forwarded[:comma]
		}
		if first = strings.TrimSpace(first); first != "" {
			return first, nil
		}
	}

	ip, _, err = net.SplitHostPort(r.RemoteAddr)
	return ip, err
}
// getUserLimit returns the rate limiter for ip, creating one if needed.
//
// Every call refreshes the entry's lastSeen time. Before a new entry is added
// to a table that already holds config.NumLimits entries, makeRoomLocked
// frees space so memory stays bounded.
func (rl *IPRateLimiter) getUserLimit(ip string) *rate.Limiter {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	if existing, ok := rl.ipLimits[ip]; ok {
		existing.lastSeen = now
		return existing.limiter
	}

	if len(rl.ipLimits) >= rl.config.NumLimits {
		rl.makeRoomLocked()
	}

	// rate.Limit is events per second; one event per config.Duration is
	// 1s/Duration events per second.
	limiter := rate.NewLimiter(rate.Limit(time.Second)/rate.Limit(rl.config.Duration), rl.config.Burst)
	rl.ipLimits[ip] = &userLimit{limiter: limiter, lastSeen: now}
	return limiter
}

// makeRoomLocked frees space in rl.ipLimits. The caller must hold rl.mu.
//
// Tracking only NumLimits entries prevents an out-of-memory DoS. Answering
// every new IP with StatusTooManyRequests would be just as bad, so the
// least-bad option is eviction:
//   - drop every expired entry;
//   - if the table is still full, drop the least-recently-seen surviving entry.
//
// An explicit found flag, rather than an empty-key sentinel, is used so that
// an entry keyed "" can still be evicted.
func (rl *IPRateLimiter) makeRoomLocked() {
	var (
		oldestKey  string
		oldestSeen time.Time
		found      bool
	)
	for key, entry := range rl.ipLimits {
		if time.Since(entry.lastSeen) > rl.config.Duration {
			delete(rl.ipLimits, key)
			continue
		}
		if !found || entry.lastSeen.Before(oldestSeen) {
			oldestKey, oldestSeen, found = key, entry.lastSeen, true
		}
	}
	if found && len(rl.ipLimits) >= rl.config.NumLimits {
		delete(rl.ipLimits, oldestKey)
	}
}
// Burst reports how many requests a single IP may make back-to-back before
// the limiter starts rejecting them, as set in the configuration.
func (rl *IPRateLimiter) Burst() int {
	burst := rl.config.Burst
	return burst
}
// Duration reports the configured spacing between allowed requests from a
// single IP once its burst has been used up.
func (rl *IPRateLimiter) Duration() time.Duration {
	interval := rl.config.Duration
	return interval
}