/
redisConcurrentRateLimiter.kt
59 lines (51 loc) · 1.72 KB
/
redisConcurrentRateLimiter.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package org.sdf.rkm
import es.moki.ratelimitj.core.limiter.concurrent.Baton
import es.moki.ratelimitj.core.limiter.concurrent.ConcurrentLimitRule
import es.moki.ratelimitj.core.limiter.concurrent.ConcurrentRequestLimiter
import es.moki.ratelimitj.core.limiter.request.RequestLimitRule
import es.moki.ratelimitj.redis.request.RedisRateLimiterFactory
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.function.Supplier
/**
 * [ConcurrentRequestLimiter] that approximates a concurrency limit with a Redis-backed request-rate limiter.
 *
 * The [ConcurrentLimitRule] is translated into a single [RequestLimitRule] that allows
 * `rule.concurrentLimit` units per window of `rule.timeoutMillis` milliseconds. Permits are never handed
 * back to Redis: the batons returned here make no Redis call on release. Capacity therefore frees up only
 * as the limiter's window moves on.
 * NOTE(review): presumably about `timeoutMillis` after acquisition; confirm against the
 * RedisRateLimiterFactory implementation.
 *
 * @param factory creates the underlying request-rate limiter for the translated rule.
 * @param rule concurrency limit and timeout to enforce.
 * @throws IllegalArgumentException if `rule.timeoutMillis` does not fit in an [Int]. The translated
 *   [RequestLimitRule] is built from an Int duration.
 */
class RedisConcurrentRateLimiter(factory: RedisRateLimiterFactory, rule: ConcurrentLimitRule): ConcurrentRequestLimiter {

    private val rules = run {
        val timeoutMillis = rule.timeoutMillis
        // RequestLimitRule.of is called with an Int duration. A bare toInt() would silently wrap values
        // outside the Int range (over ~24.8 days) into a wrong, possibly negative, window.
        require(timeoutMillis >= Int.MIN_VALUE && timeoutMillis <= Int.MAX_VALUE) {
            "timeoutMillis must fit in an Int, was $timeoutMillis"
        }
        setOf(RequestLimitRule.of(timeoutMillis.toInt(), TimeUnit.MILLISECONDS, rule.concurrentLimit.toLong()))
    }

    private val rateLimiter = factory.getInstance(rules)

    /** Same as `acquire(key, 1)`. */
    override fun acquire(key: String): Baton = acquire(key, 1)

    /**
     * Calls `overLimitWhenIncremented(key, weight)` on the Redis limiter.
     *
     * @return a baton that is not acquired if the limiter reports [key] over its limit; otherwise an
     *   acquired baton.
     */
    override fun acquire(key: String, weight: Int): Baton =
        RedisBaton(acquired = !rateLimiter.overLimitWhenIncremented(key, weight))
}
/**
 * [Baton] handed out by the Redis-backed concurrent limiter.
 *
 * The baton holds no Redis resources. [release] sends nothing to Redis, because the limiter's counter
 * expires on its own; it only marks this baton as no longer acquired. A baton runs at most one
 * [doAction] / [get]. After that, or after [release], [hasAcquired] returns `false`.
 *
 * NOTE(review): [acquired] is a plain, unsynchronized `var`. Presumably a baton is used from a single
 * thread; confirm before sharing instances.
 *
 * @property acquired whether a permit was granted and has not yet been used or released.
 */
class RedisBaton(var acquired: Boolean = false): Baton {

    override fun hasAcquired(): Boolean = acquired

    /** Runs [action] only while the permit is held, then clears the permit even if [action] throws. */
    override fun doAction(action: Runnable) {
        if (!acquired) return
        try {
            action.run()
        } finally {
            acquired = false
        }
    }

    /**
     * Gives up the permit so this baton can no longer run actions.
     *
     * Nothing is sent to Redis; the limiter-side count expires automatically.
     */
    override fun release() {
        acquired = false
    }

    /**
     * Runs [action] while the permit is held and wraps its result. Returns [Optional.empty] if the permit
     * is not held. The permit is cleared afterwards, even if [action] throws.
     *
     * @throws NullPointerException if [action] returns `null` (possible from Java callers), via [Optional.of].
     */
    override fun <T : Any> get(action: Supplier<T>): Optional<T> {
        if (!acquired) return Optional.empty()
        return try {
            Optional.of(action.get())
        } finally {
            acquired = false
        }
    }
}