-
Notifications
You must be signed in to change notification settings - Fork 0
/
limitterDatastore.go
88 lines (76 loc) · 2.64 KB
/
limitterDatastore.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
Request limiter whose per-request tracking state is stored in Google Cloud
Datastore (cloud.google.com/go/datastore).
*/
package limitter
import (
"errors"
"time"
"cloud.google.com/go/datastore"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
)
// CreateDatastoreBackedLimitter returns a gin handler that limits requests
// using tracker entities stored in Cloud Datastore.
//
// For every request the handler:
//   - derives a user ID with pUserIdExtractor and combines it with the request
//     URL path (via CreateTrackerName) into the name of an entity of kind
//     pTrackerKind;
//   - inside a Datastore transaction, loads that tracker (creating a new one
//     via NewRequestTrackerWithExpiration when none exists), checks the
//     request with ValidateRequest and, if it passes, writes the tracker back;
//   - passes the transaction's resulting error (nil on success) to
//     ProcessValidateResult together with pIsMiddleware.
//
// A stored entity that only partially matches RequestTracker
// (datastore.ErrFieldMismatch) is still used; any other load error aborts the
// transaction and is handed to ProcessValidateResult.
func CreateDatastoreBackedLimitter(pClient *datastore.Client, pTrackerKind string,
	pUserIdExtractor func(c *gin.Context) string,
	pConfig *LimitterConfig,
	pIsMiddleware bool) func(c *gin.Context) {
	return func(c *gin.Context) {
		userID := pUserIdExtractor(c)
		url := c.Request.URL.Path
		currentTime := time.Now()
		trackerName := CreateTrackerName(userID, url)
		trackerKey := datastore.NameKey(pTrackerKind, trackerName, nil)

		// tracker is declared outside the transaction closure because the trace
		// log below reads it after the transaction finishes.
		tracker := &RequestTracker{}
		_, err := pClient.RunInTransaction(c.Request.Context(), func(tx *datastore.Transaction) error {
			// RunInTransaction calls this function again when the commit conflicts
			// with a concurrent transaction. Start every attempt from a fresh
			// tracker so state mutated by ValidateRequest in a failed attempt does
			// not leak into the retry (tx.Get only overwrites stored properties).
			tracker = &RequestTracker{}

			errTracker := tx.Get(trackerKey, tracker)
			var fieldMismatch *datastore.ErrFieldMismatch
			switch {
			case errTracker == nil:
				if log.IsLevelEnabled(log.TraceLevel) {
					log.Tracef("RequestLimitter: TrackerLoaded, key=%v, tracker=%v", trackerKey, tracker)
				}
			case errors.As(errTracker, &fieldMismatch):
				// The entity was loaded, but some stored properties do not map onto
				// RequestTracker; continue with the fields that did load. Log the
				// actual error (it used to be cleared before logging).
				if log.IsLevelEnabled(log.TraceLevel) {
					log.Tracef("LoadUserTracker: TypeMisMatch, kind=%v, url=%v, userId=%v, error=%v",
						pTrackerKind, url, userID, errTracker)
				}
			case errors.Is(errTracker, datastore.ErrNoSuchEntity):
				// First request for this user/path: start a new tracker.
				if log.IsLevelEnabled(log.TraceLevel) {
					log.Tracef("LoadUserTracker: NotFound, kind=%v, url=%v, userId=%v, error=%v",
						pTrackerKind, url, userID, errTracker)
				}
				tracker = NewRequestTrackerWithExpiration(userID, url, pConfig.CreateExpiration(currentTime))
			default:
				// Any other load failure is critical: abort the transaction.
				log.Errorf("LoadUserTracker: Failed, kind=%v, url=%v, userId=%v, error=%v",
					pTrackerKind, url, userID, errTracker)
				return errTracker
			}

			if errValidate := ValidateRequest(tracker, currentTime, url, c.ClientIP(), pConfig); errValidate != nil {
				// Returning an error rolls the transaction back, so a rejected
				// request does not persist tracker changes.
				return errValidate
			}
			if _, errPut := tx.Put(trackerKey, tracker); errPut != nil {
				log.Errorf("RequestLimitter: UpdateTrackerFailed, UID=%v, key=%v, error=%v", userID, trackerKey, errPut)
				return errPut
			}
			return nil
		})
		ProcessValidateResult(err, c, pIsMiddleware)
		if log.IsLevelEnabled(log.TraceLevel) {
			log.Tracef("RequestLimitter: ValidateFinish, UID=%v, url=%v, IP=%v, calls=%v|%v, window=%v/%v|%v, key=%v, errValidate=%v",
				tracker.UID,
				url,
				c.ClientIP(),
				currentTime.UnixMilli()-tracker.LastCall, pConfig.MinRequestInterval,
				tracker.WindowRequest, pConfig.MaxRequestPerWindow, tracker.WindowNum,
				trackerKey,
				err,
			)
		}
	}
}