-
Notifications
You must be signed in to change notification settings - Fork 0
/
limitterRedis.go
111 lines (92 loc) · 3.68 KB
/
limitterRedis.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/*
Package limitter: Redis-backed implementation of the request limiter.
*/
package limitter
import (
"context"
"fmt"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus"
)
// Package-level state shared by the Redis-backed limitter; all three values
// are assigned by InitRedis.
var (
	// rdb is the Redis client used by the limitter handler.
	rdb *redis.Client
	// environment is the second segment of every tracker key.
	environment string
	// keyPrefix is the first segment of every tracker key.
	keyPrefix string
)
func InitRedis(pKeyPrefix string, pEnvironment string, pRedisServerAddress string, pRedisPassword string, pRedisDatabase int) {
rdb = redis.NewClient(&redis.Options{
Addr: pRedisServerAddress,
Password: pRedisPassword,
DB: pRedisDatabase, // use default DB
})
environment = pEnvironment
keyPrefix = pKeyPrefix
log.Infof("RedisRequestLimitter: Init, redisHost=%v, redisPass=%v, redisDB=%v, environment=%v, keyPrefix=%v",
rdb.Options().Addr, len(rdb.Options().Password), rdb.Options().DB, environment, keyPrefix)
}
// CreateRedisTrackerKey returns the Redis key for a user/URL pair, laid out as
// "<keyPrefix>:<environment>:<pUserId>:<pUrl>" using the package-level
// keyPrefix and environment values.
func CreateRedisTrackerKey(pUserId string, pUrl string) string {
	segments := []interface{}{keyPrefix, environment, pUserId, pUrl}
	return fmt.Sprintf("%v:%v:%v:%v", segments...)
}
// LoadRedisRequestTracker reads the tracker hash for the given user and URL
// from Redis and scans it into a tracker created by NewRequestTracker.
//
// The tracker is returned even when the read or scan fails; in that case the
// returned error is non-nil and the tracker holds whatever NewRequestTracker
// produced plus any fields that were scanned before the failure.
func LoadRedisRequestTracker(ctx context.Context, rClient *redis.Client, userId string, url string) (*RequestTracker, error) {
	tracker := NewRequestTracker(userId, url)
	key := CreateRedisTrackerKey(userId, url)
	err := rClient.HGetAll(ctx, key).Scan(tracker)
	return tracker, err
}
// SaveRedisRequestTracker writes the tracker's fields to its Redis hash and,
// when expireSecond is greater than zero, sets the key's TTL to that many
// seconds. Both commands are sent in a single pipeline.
//
// Parameters:
//   - ctx: context passed to every Redis command.
//   - rClient: client the pipeline runs on; when nil, the package-level
//     client set up by InitRedis is used instead.
//   - tracker: tracker whose UID and URL determine the key and whose fields
//     are stored under "uid", "url", "winNum", "winReq", "last" and "exp".
//   - expireSecond: TTL in seconds; values <= 0 leave the TTL untouched.
//
// Returns the error reported by the pipeline, or nil on success.
func SaveRedisRequestTracker(ctx context.Context, rClient *redis.Client, tracker *RequestTracker, expireSecond int64) error {
	// The client argument used to be ignored in favour of the package-level
	// client, so callers passing nil worked; keep that working.
	if rClient == nil {
		rClient = rdb
	}

	trackerKey := CreateRedisTrackerKey(tracker.UID, tracker.URL)
	_, err := rClient.Pipelined(ctx, func(pipe redis.Pipeliner) error {
		// One multi-field HSET writes all fields in a single command instead
		// of six separate ones.
		pipe.HSet(ctx, trackerKey,
			"uid", tracker.UID,
			"url", tracker.URL,
			"winNum", tracker.WindowNum,
			"winReq", tracker.WindowRequest,
			"last", tracker.LastCall,
			"exp", tracker.Exp,
		)
		if expireSecond > 0 {
			pipe.Expire(ctx, trackerKey, time.Duration(expireSecond)*time.Second)
		}
		return nil
	})
	return err
}
// CreateRedisBackedLimitter returns a gin handler that validates each request
// against a per-user, per-path tracker stored in Redis.
//
// For every request the handler:
//  1. derives the user id via pUserIdExtractor and the path from the request URL,
//  2. loads the tracker from Redis (a load error is logged and processing
//     continues with the tracker returned by the loader),
//  3. calls ValidateRequest and then saves the tracker back with the TTL
//     pConfig.ExpSec, regardless of the validation outcome,
//  4. passes the resulting error (or nil) to ProcessValidateResult together
//     with pIsMiddleware.
//
// A save failure is always logged. It only turns an otherwise valid request
// into a failure when pConfig.AbortOnFail is set.
func CreateRedisBackedLimitter(pUserIdExtractor func(c *gin.Context) string,
	pConfig *LimitterConfig, pIsMiddleware bool) func(c *gin.Context) {
	return func(c *gin.Context) {
		userId := pUserIdExtractor(c)
		url := c.Request.URL.Path
		clientIP := c.ClientIP()
		ctx := c.Request.Context()
		currentTime := time.Now()
		trackerKey := CreateRedisTrackerKey(userId, url)

		tracker, errGetTracker := LoadRedisRequestTracker(ctx, rdb, userId, url)
		if errGetTracker != nil {
			// NOTE(review): the request is still validated against whatever
			// tracker the loader returned — confirm this is the intended
			// behaviour when Redis is unavailable.
			log.Errorf("RedisLimitter: userId=%v, url=%v, error=%v", userId, url, errGetTracker)
		}

		errValidate := ValidateRequest(tracker, currentTime, url, clientIP, pConfig)

		// The tracker is persisted even when validation failed.
		errSetTracker := SaveRedisRequestTracker(ctx, rdb, tracker, pConfig.ExpSec)
		if errSetTracker != nil {
			// Always surface save failures in the log; previously they were
			// dropped silently unless they also aborted the request.
			log.Errorf("RedisLimitter: SaveTrackerFailed, userId=%v, key=%v, error=%v", userId, trackerKey, errSetTracker)
		}

		switch {
		case errValidate != nil:
			log.Errorf("RedisLimitter: ValidateTrackerFailed, userId=%v, key=%v, sinceLastCall=%v, error=%v",
				userId, trackerKey, currentTime.UnixMilli()-tracker.LastCall, errValidate)
		case errSetTracker != nil && pConfig.AbortOnFail:
			// Validation passed but the tracker could not be stored; fail the
			// request only when the configuration asks for it.
			errValidate = errSetTracker
		}

		ProcessValidateResult(errValidate, c, pIsMiddleware)

		if log.IsLevelEnabled(log.TraceLevel) {
			log.Tracef("RequestLimitter: ValidateFinish, UID=%v, url=%v, IP=%v, calls=%v|%v, window=%v/%v|%v, key=%v, errValidate=%v",
				tracker.UID,
				url,
				clientIP,
				currentTime.UnixMilli()-tracker.LastCall, pConfig.MinRequestInterval,
				tracker.WindowRequest, pConfig.MaxRequestPerWindow, tracker.WindowNum,
				trackerKey,
				errValidate,
			)
		}
	}
}