forked from zeromicro/go-zero
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tokenlimit.go
182 lines (157 loc) · 4.69 KB
/
tokenlimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
package limit
import (
"context"
"errors"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
xrate "golang.org/x/time/rate"
)
const (
	// pingInterval is the pause between two Redis pings while the limiter
	// is waiting for Redis to become reachable again.
	pingInterval = 100 * time.Millisecond

	// tokenFormat is the template of the Redis key that stores the number of
	// remaining tokens; %s is replaced by the limiter key.
	// NOTE(review): the braces presumably act as a Redis Cluster hash tag so
	// that both keys land in the same slot — confirm.
	tokenFormat = "{%s}.tokens"
	// timestampFormat is the template of the Redis key that stores the time
	// of the last refresh; %s is replaced by the limiter key.
	timestampFormat = "{%s}.ts"
)
// script is the token bucket routine that is evaluated inside Redis.
//
// To be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key,
// so KEYS[n] is referenced directly everywhere.
//
//	KEYS[1] as tokens_key
//	KEYS[2] as timestamp_key
//	ARGV[1] rate: tokens added per unit of `now`
//	ARGV[2] capacity: maximum number of tokens in the bucket
//	ARGV[3] now: current timestamp
//	ARGV[4] requested: number of tokens to take
//
// A missing tokens key counts as a full bucket and a missing timestamp key as 0.
// The bucket is refilled according to the elapsed time, capped at capacity, and the
// requested tokens are subtracted only when enough are available. Both keys are then
// rewritten with a TTL of twice the time needed to fill an empty bucket.
//
// The TTL is clamped to at least 1: floor(2*capacity/rate) is 0 whenever
// capacity < rate/2, and SETEX rejects a non-positive expire time, which made the
// whole script fail for such configurations.
//
// The script returns the Lua boolean `allowed`.
var script = redis.NewScript(`local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])
local fill_time = capacity/rate
local ttl = math.max(1, math.floor(fill_time*2))
local last_tokens = tonumber(redis.call("get", KEYS[1]))
if last_tokens == nil then
    last_tokens = capacity
end

local last_refreshed = tonumber(redis.call("get", KEYS[2]))
if last_refreshed == nil then
    last_refreshed = 0
end

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
if allowed then
    new_tokens = filled_tokens - requested
end

redis.call("setex", KEYS[1], ttl, new_tokens)
redis.call("setex", KEYS[2], ttl, now)

return allowed`)
// A TokenLimiter controls how frequently events are allowed to happen within one second.
// It evaluates a token bucket script in Redis and switches to an in-process
// limiter while Redis calls are failing.
type TokenLimiter struct {
	// rate is passed to the Redis script as the refill rate and is also used
	// to configure rescueLimiter.
	rate int
	// burst is passed to the Redis script as the bucket capacity and is also
	// the burst size of rescueLimiter.
	burst int
	// store is the Redis client used to run the script and to ping Redis.
	store *redis.Redis
	// tokenKey is the Redis key holding the remaining tokens (KEYS[1]).
	tokenKey string
	// timestampKey is the Redis key holding the last refresh time (KEYS[2]).
	timestampKey string
	// rescueLock guards monitorStarted.
	rescueLock sync.Mutex
	// redisAlive is 1 while Redis is used and 0 while rescueLimiter is used;
	// it is read and written through sync/atomic.
	redisAlive uint32
	// monitorStarted is true while the waitForRedis goroutine is running.
	monitorStarted bool
	// rescueLimiter is the in-process limiter used when Redis is unavailable.
	rescueLimiter *xrate.Limiter
}
// NewTokenLimiter returns a new TokenLimiter that allows events up to rate per second
// and permits bursts of at most burst tokens.
//
// store is the Redis client that evaluates the limiter script, and key is the
// identifier from which the two Redis keys (tokens and timestamp) are derived.
// The limiter starts out assuming Redis is reachable.
func NewTokenLimiter(rate, burst int, store *redis.Redis, key string) *TokenLimiter {
	return &TokenLimiter{
		rate:         rate,
		burst:        burst,
		store:        store,
		tokenKey:     fmt.Sprintf(tokenFormat, key),
		timestampKey: fmt.Sprintf(timestampFormat, key),
		redisAlive:   1,
		// xrate.Limit is expressed in events per second, so rate maps onto it
		// directly. Deriving it from time.Second/time.Duration(rate) panicked
		// with a division by zero for rate == 0 and truncated to a zero
		// interval (an unlimited limiter) for rates above one per nanosecond.
		rescueLimiter: xrate.NewLimiter(xrate.Limit(rate), burst),
	}
}
// Allow reports whether a single event may happen right now.
// It is equivalent to AllowN(time.Now(), 1).
func (lim *TokenLimiter) Allow() bool {
	now := time.Now()
	return lim.AllowN(now, 1)
}
// AllowCtx reports whether a single event may happen right now, using the
// incoming context for the underlying call.
// It is equivalent to AllowNCtx(ctx, time.Now(), 1).
func (lim *TokenLimiter) AllowCtx(ctx context.Context) bool {
	now := time.Now()
	return lim.AllowNCtx(ctx, now, 1)
}
// AllowN reports whether n events may happen at time now.
// Use this method if you intend to drop / skip events that exceed the rate.
// The check runs with a background context.
func (lim *TokenLimiter) AllowN(now time.Time, n int) bool {
	ctx := context.Background()
	return lim.reserveN(ctx, now, n)
}
// AllowNCtx reports whether n events may happen at time now, using the
// incoming context for the underlying call.
// Use this method if you intend to drop / skip events that exceed the rate.
func (lim *TokenLimiter) AllowNCtx(ctx context.Context, now time.Time, n int) bool {
	allowed := lim.reserveN(ctx, now, n)
	return allowed
}
// reserveN tries to take n tokens at time now and reports whether they were granted.
//
// While Redis is marked as unavailable the in-process rescue limiter decides.
// Otherwise the limiter script is evaluated in Redis with now truncated to
// whole seconds (now.Unix()), and the outcome is mapped as follows:
//   - redis.Nil: the script returned false, the request is rejected;
//   - context deadline exceeded / canceled: the error is logged and the request
//     is rejected, without switching to the rescue limiter;
//   - any other error, or a reply that is not an int64: the error is logged,
//     the Redis monitor is started and the rescue limiter decides;
//   - integer reply: the request is allowed only when the value is 1.
func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) bool {
	if atomic.LoadUint32(&lim.redisAlive) == 0 {
		return lim.rescueLimiter.AllowN(now, n)
	}

	resp, err := lim.store.ScriptRunCtx(ctx,
		script,
		[]string{
			lim.tokenKey,
			lim.timestampKey,
		},
		[]string{
			strconv.Itoa(lim.rate),
			strconv.Itoa(lim.burst),
			strconv.FormatInt(now.Unix(), 10),
			strconv.Itoa(n),
		})
	// redis allowed == false
	// Lua boolean false -> Redis Nil bulk reply
	// errors.Is also matches a wrapped redis.Nil, in line with the context
	// checks below; a plain == comparison would send a wrapped Nil down the
	// rescue path.
	if errors.Is(err, redis.Nil) {
		return false
	}
	if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
		logx.Errorf("fail to use rate limiter: %s", err)
		return false
	}
	if err != nil {
		logx.Errorf("fail to use rate limiter: %s, use in-process limiter for rescue", err)
		lim.startMonitor()
		return lim.rescueLimiter.AllowN(now, n)
	}

	code, ok := resp.(int64)
	if !ok {
		logx.Errorf("fail to eval redis script: %v, use in-process limiter for rescue", resp)
		lim.startMonitor()
		return lim.rescueLimiter.AllowN(now, n)
	}

	// redis allowed == true
	// Lua boolean true -> Redis integer reply with value of 1
	return code == 1
}
// startMonitor marks Redis as unavailable and launches the background
// goroutine that waits for it to come back. When a monitor is already
// running, the call does nothing.
func (lim *TokenLimiter) startMonitor() {
	lim.rescueLock.Lock()
	defer lim.rescueLock.Unlock()

	if !lim.monitorStarted {
		lim.monitorStarted = true
		// From here on reserveN uses the in-process rescue limiter.
		atomic.StoreUint32(&lim.redisAlive, 0)
		go lim.waitForRedis()
	}
}
// waitForRedis pings Redis every pingInterval until a ping succeeds, then
// marks Redis as alive again. On return it stops the ticker and clears
// monitorStarted so that a later failure can start a new monitor.
func (lim *TokenLimiter) waitForRedis() {
	ticker := time.NewTicker(pingInterval)
	// Deferred calls run last-in first-out: the ticker is stopped first,
	// then the monitor flag is cleared under the lock.
	defer func() {
		lim.rescueLock.Lock()
		lim.monitorStarted = false
		lim.rescueLock.Unlock()
	}()
	defer ticker.Stop()

	for {
		<-ticker.C
		if !lim.store.Ping() {
			continue
		}

		atomic.StoreUint32(&lim.redisAlive, 1)
		return
	}
}