diff --git a/limiter_fixed_truncated_window.go b/limiter_fixed_truncated_window.go index 2c82f71..c5caf74 100644 --- a/limiter_fixed_truncated_window.go +++ b/limiter_fixed_truncated_window.go @@ -38,7 +38,7 @@ type FixedTruncatedWindowRateLimiter struct { mu sync.Mutex - rate time.Duration + rate Rate window time.Time capacity int64 rateLimitReached bool @@ -63,14 +63,21 @@ func (l *FixedTruncatedWindowRateLimiter) Check(ctx context.Context) (Result, er } func (l *FixedTruncatedWindowRateLimiter) Dump(ctx context.Context) (r Result, err error) { + l.mu.Lock() + defer l.mu.Unlock() + now := l.clock.Now() - window := now.Truncate(l.rate) - c, err := l.db.Get(ctx, window) + if TimeGTE(l.window.Add(l.rate.Duration()), now) { + l.rateLimitReached = false + l.window = now.Truncate(l.rate.Unit) + } + + c, err := l.db.Get(ctx, l.window) if c >= l.capacity { // rate limit exceeded - r = res(l.rate-now.Sub(window), l.capacity-c) + r = res(l.window.Add(l.rate.Duration()).Sub(now), l.capacity-c) } else { r = res(0, l.capacity-c) } @@ -89,21 +96,19 @@ func (l *FixedTruncatedWindowRateLimiter) try(ctx context.Context, tokens int64) now := l.clock.Now() - window := now.Truncate(l.rate) - - if !l.window.Equal(window) { + if TimeGTE(l.window.Add(l.rate.Duration()), now) { l.rateLimitReached = false - l.window = window + l.window = now.Truncate(l.rate.Unit) } - ttw := l.rate - now.Sub(window) + ttw := l.window.Add(l.rate.Duration()).Sub(now) if l.rateLimitReached { return res(ttw, 0), ErrRateLimitExceeded } c, err := l.db.Inc(ctx, FixedWindowIncArgs{ - Window: window, + Window: l.window, Tokens: tokens, Capacity: l.capacity, TTL: ttw, @@ -134,22 +139,20 @@ func (l *FixedTruncatedWindowRateLimiter) check(ctx context.Context, tokens int6 now := l.clock.Now() - window := now.Truncate(l.rate) - - if !l.window.Equal(window) { + if TimeGTE(l.window.Add(l.rate.Duration()), now) { // new window so no rate Limit l.rateLimitReached = false - l.window = window + l.window = now.Truncate(l.rate.Unit) return res(0, l.capacity), nil } - ttw := l.rate - now.Sub(window) + ttw := l.window.Add(l.rate.Duration()).Sub(now) if l.rateLimitReached { return res(ttw, 0), ErrRateLimitExceeded } - c, err := l.db.Get(ctx, window) + c, err := l.db.Get(ctx, l.window) if err != nil { return nores, err @@ -177,7 +180,7 @@ func NewFixedTruncatedWindowRateLimiter( ) *FixedTruncatedWindowRateLimiter { return &FixedTruncatedWindowRateLimiter{ capacity: args.Capacity, - rate: args.Rate.Duration(), + rate: args.Rate, clock: args.Clock, db: args.DB, rateLimitReached: false, diff --git a/limiter_fixed_truncated_window_test.go b/limiter_fixed_truncated_window_test.go index 846ce20..84c65c2 100644 --- a/limiter_fixed_truncated_window_test.go +++ b/limiter_fixed_truncated_window_test.go @@ -158,15 +158,15 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { }, { method: check, - passTime: 0, + passTime: time.Second, expectedErr: ErrRateLimitExceeded, - expectedTtw: time.Second * 4, + expectedTtw: time.Second * 10, expectedFreeSlots: 0, }, { method: try, passTime: 0, - expectedTtw: time.Second * 4, + expectedTtw: time.Second * 9, expectedErr: ErrRateLimitExceeded, expectedFreeSlots: 0, }, @@ -201,10 +201,10 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { }, { method: check, - passTime: 0, - expectedErr: nil, - expectedFreeSlots: 2, - expectedTtw: 0, + passTime: time.Second * 7, + expectedErr: ErrRateLimitExceeded, + expectedFreeSlots: 0, + expectedTtw: time.Second * 7, }, { method: try, @@ -256,18 +256,18 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { expectedFreeSlots: 0, expectedErr: nil, }, - // Rate Limit is reached and 1 second passes... { method: check, passTime: 0, expectedErr: ErrRateLimitExceeded, expectedFreeSlots: 0, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, }, + // Rate Limit is reached and 1 second passes... { method: try, passTime: time.Second, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, expectedErr: ErrRateLimitExceeded, expectedFreeSlots: 0, }, @@ -275,23 +275,23 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { method: check, passTime: 0, expectedErr: ErrRateLimitExceeded, - expectedTtw: time.Second, + expectedTtw: time.Second * 9, expectedFreeSlots: 0, }, // Rate limit is still held. Moving 2 seconds and getting into next window { method: try, passTime: time.Second * 2, - expectedTtw: time.Second, + expectedTtw: time.Second * 9, expectedErr: ErrRateLimitExceeded, expectedFreeSlots: 0, }, { method: check, - passTime: 0, - expectedErr: nil, - expectedTtw: 0, - expectedFreeSlots: 2, + passTime: time.Second * 7, + expectedErr: ErrRateLimitExceeded, + expectedTtw: time.Second * 7, + expectedFreeSlots: 0, }, // Requests check be made again { @@ -346,15 +346,15 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { }, { method: dump, - passTime: 0, + passTime: time.Second * 2, expectedErr: nil, expectedFreeSlots: 0, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, }, { method: try, - passTime: 0, - expectedTtw: time.Second * 2, + passTime: time.Second * 8, + expectedTtw: time.Second * 8, expectedFreeSlots: 0, expectedErr: ErrRateLimitExceeded, }, @@ -362,8 +362,8 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { method: dump, passTime: 0, expectedErr: nil, - expectedFreeSlots: 0, - expectedTtw: time.Second * 2, + expectedFreeSlots: 2, + expectedTtw: 0, }, }, }, @@ -403,10 +403,10 @@ func TestNewFixedTruncatedWindowRateLimiter(t *testing.T) { }, { method: dump, - passTime: time.Second * 2, + passTime: time.Second * 10, expectedErr: nil, expectedFreeSlots: 0, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, }, { method: dump, diff --git a/limiter_token_fixed_window_test.go b/limiter_token_fixed_window_test.go index 7dcca46..353b1d5 100644 --- a/limiter_token_fixed_window_test.go +++ b/limiter_token_fixed_window_test.go @@ -173,7 +173,7 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { method: check, forwardAfter: 0, expectedFreeSlots: 6, - expectedTtw: time.Second * 4, + expectedTtw: time.Second * 10, expectedErr: ErrRateLimitExceeded, requestTokens: 20, }, @@ -182,14 +182,14 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { requestTokens: 7, // 21 -> Rate Limit! forwardAfter: 0, expectedFreeSlots: 0, - expectedTtw: time.Second * 4, + expectedTtw: time.Second * 10, expectedErr: ErrRateLimitExceeded, }, { method: check, forwardAfter: 0, expectedFreeSlots: 0, - expectedTtw: time.Second * 4, + expectedTtw: time.Second * 10, expectedErr: ErrRateLimitExceeded, requestTokens: 1, }, @@ -227,8 +227,8 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { }, { method: try, - requestTokens: 1, // 2 - forwardAfter: time.Second * 2, // 2022-02-05 00:00:11 + requestTokens: 1, + forwardAfter: time.Second * 9, // now it's 2022-02-05 00:00:09 expectedFreeSlots: 0, expectedTtw: 0, // TODO(@sonirico): Should this be zero? expectedErr: nil, @@ -252,7 +252,7 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { method: check, forwardAfter: 0, expectedFreeSlots: 0, - expectedTtw: time.Second * 9, + expectedTtw: time.Second * 10, expectedErr: ErrRateLimitExceeded, requestTokens: 1, }, @@ -301,33 +301,31 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { forwardAfter: 0, expectedFreeSlots: 2, expectedErr: ErrRateLimitExceeded, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, requestTokens: 3, }, - // Rate Limit is reached and 1 second passes... { method: try, requestTokens: 3, // 11 forwardAfter: time.Second, expectedFreeSlots: 0, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, expectedErr: ErrRateLimitExceeded, }, { method: check, forwardAfter: 0, expectedFreeSlots: 0, - expectedTtw: time.Second, + expectedTtw: time.Second * 9, expectedErr: ErrRateLimitExceeded, requestTokens: 1, }, - // Rate limit is still held. Moving 2 seconds and getting into next window { method: try, requestTokens: 3, // 11 - forwardAfter: time.Second * 2, + forwardAfter: time.Second * 9, expectedFreeSlots: 0, - expectedTtw: time.Second, + expectedTtw: time.Second * 9, expectedErr: ErrRateLimitExceeded, }, { @@ -338,7 +336,6 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { expectedErr: nil, requestTokens: 1, }, - // Requests check be made again { method: try, requestTokens: 3, // 3 @@ -392,7 +389,7 @@ func TestNewTokenFixedWindowRateLimiter_WindowTruncated(t *testing.T) { { method: dump, expectedFreeSlots: 0, - expectedTtw: time.Second * 2, + expectedTtw: time.Second * 10, expectedErr: nil, }, }, diff --git a/tests/main.go b/tests/main.go index b70d359..61758e3 100644 --- a/tests/main.go +++ b/tests/main.go @@ -20,35 +20,46 @@ func main() { redisCli := redis.NewClient(redisOpts) rateLimit := - pacemaker.NewFixedWindowRateLimiter( - pacemaker.FixedWindowArgs{ + pacemaker.NewFixedTruncatedWindowRateLimiter( + pacemaker.FixedTruncatedWindowArgs{ Capacity: 1200, Rate: pacemaker.Rate{ - Unit: time.Hour, - Amount: 1, + Unit: time.Minute, + Amount: 3, }, Clock: pacemaker.NewClock(), DB: pacemaker.NewFixedWindowRedisStorage( redisCli, pacemaker.FixedWindowRedisStorageOpts{ - Prefix: "pacemaker", + Prefix: "pacemaker-test-marcos-2", }, ), }, ) - result, err := rateLimit.Try(ctx) + result, err := rateLimit.Dump(ctx) + + if err != nil { + log.Printf("error dump: '%v'", err) + } + + log.Printf("Dump Result: '%v'", result) + + result, err = rateLimit.Try(ctx) if err != nil { log.Printf("error try: '%v'", err) } log.Printf("Try Result: '%v'", result) - result, err = rateLimit.Dump(ctx) + for { + time.Sleep(time.Second) + result, err = rateLimit.Dump(ctx) - if err != nil { - log.Printf("error dump: '%v'", err) - } + if err != nil { + log.Printf("error dump: '%v'", err) + } - log.Printf("Dump Result: '%v'", result) + log.Printf("Dump Result: '%v'", result) + } } diff --git a/utils.go b/utils.go index 53223a8..80f3576 100644 --- a/utils.go +++ b/utils.go @@ -4,6 +4,7 @@ import ( "crypto" "encoding/hex" "strings" + "time" "golang.org/x/exp/constraints" ) @@ -40,3 +41,9 @@ func min[T constraints.Integer](a, b T) T { } return a } + +// TimeGTE returns true if `target` is greater than or equals `from` +func TimeGTE(from time.Time, target time.Time) bool { + // return target.After(from) || target.Equal(from) + return !target.Before(from) +}