-
Notifications
You must be signed in to change notification settings - Fork 10
/
rateLimiter.go
94 lines (80 loc) · 2.48 KB
/
rateLimiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
package middleware
import (
	"fmt"
	"net"
	"net/http"

	"github.com/RohitAwate/OAuth2Bin/oauth2/cache"
	"github.com/gomodule/redigo/redis"
)
// RatePolicy describes how heavily a single server route may be used.
//
// Route is the request path the policy is attached to, Limit is the number
// of API calls permitted, and Minutes is the length, in minutes, of the
// window over which Limit is counted. The JSON tags allow policies to be
// loaded from configuration.
type RatePolicy struct {
	Route   string `json:"route"`
	Limit   int    `json:"limit"`
	Minutes int    `json:"minutes"`
}
// RateLimiter is an implementation of Middleware that enforces a list of
// per-route RatePolicy values. Its Handle method consults Policies on every
// request; a route with no matching policy is not limited at all.
type RateLimiter struct {
	Policies []RatePolicy
}
// Handle wraps handler with rate limiting. A request whose URL path matches
// a policy's Route is counted against the client's IP address. Once the
// count for the current window exceeds the policy's Limit, the request is
// answered by showError instead of reaching handler.
//
// Requests with no matching policy are passed straight through. So are
// requests whose hit could not be recorded (for example when Redis is
// unavailable): the limiter fails open.
func (rl RateLimiter) Handle(handler http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		policy := rl.getRatePolicy(r.URL.Path)
		if policy == nil {
			// No policy is set for this route, so there is nothing to enforce.
			handler(w, r)
			return
		}

		hits, err := setHit(policy, clientIP(r))
		if err != nil || hits <= policy.Limit {
			// On a counter error there may be an issue with Redis; let the
			// request pass rather than reject legitimate traffic.
			handler(w, r)
			return
		}

		showError(policy, w, r)
	}
}

// clientIP returns the host part of r.RemoteAddr, so that every connection
// from the same client shares one counter. net/http fills RemoteAddr as
// "IP:port", and keying on that would give each new connection (new port)
// a fresh counter. If the address cannot be split, it is used verbatim.
func clientIP(r *http.Request) string {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}
// getRatePolicy looks up the first policy whose Route is exactly equal to
// route. It returns a pointer to a copy of that policy, so callers cannot
// modify rl.Policies through it. It returns nil when no policy applies.
func (rl RateLimiter) getRatePolicy(route string) *RatePolicy {
	for i := range rl.Policies {
		if rl.Policies[i].Route != route {
			continue
		}
		match := rl.Policies[i]
		return &match
	}
	return nil
}
// hitScript atomically increments the counter at KEYS[1]. Whenever the
// counter has no expiry (it was just created, or was left without a TTL by
// an earlier non-atomic update), the script gives it a TTL of ARGV[1]
// seconds. It returns the incremented count. Doing this inside a single
// EVAL means a concurrent request or a key expiring mid-update can never
// leave a counter that lives forever.
const hitScript = `
local hits = redis.call("INCR", KEYS[1])
if redis.call("TTL", KEYS[1]) == -1 then
	redis.call("EXPIRE", KEYS[1], ARGV[1])
end
return hits`

// setHit registers a new hit for the policy's route from ip in Redis. It
// returns the number of hits in the current window, including this one.
//
// The counter is stored under the key "<route>:<ip>". It expires
// policy.Minutes*60 seconds after it is created, which starts a new window.
// On any failure, setHit returns -1 and a non-nil error.
func setHit(policy *RatePolicy, ip string) (int, error) {
	if policy.Minutes <= 0 {
		// A non-positive TTL would make Redis reject or immediately delete
		// the counter, so the policy could never be enforced.
		return -1, fmt.Errorf("rate policy for %q: window must be positive, got %d minute(s)", policy.Route, policy.Minutes)
	}

	conn := cache.NewConn()
	defer cache.CloseConn(conn)

	key := fmt.Sprintf("%s:%s", policy.Route, ip)
	hits, err := redis.Int(conn.Do("EVAL", hitScript, 1, key, policy.Minutes*60))
	if err != nil {
		return -1, fmt.Errorf("recording hit for %q: %w", key, err)
	}
	return hits, nil
}
// showError rejects the request with 429 Too Many Requests. The body is a
// one-line text message stating the policy's Limit and Minutes. The request
// argument is accepted for handler symmetry but is not used.
func showError(policy *RatePolicy, w http.ResponseWriter, r *http.Request) {
	const format = "You have exceeded the rate limit of %d requests per %d minute(s) on this route.\n"

	w.WriteHeader(http.StatusTooManyRequests)
	fmt.Fprintf(w, format, policy.Limit, policy.Minutes)
}