Skip to content

Commit

Permalink
multilimiter,adaptive ratelimit
Browse files Browse the repository at this point in the history
  • Loading branch information
tarunKoyalwar committed Dec 14, 2022
1 parent c2e8e28 commit cb63b75
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 99 deletions.
7 changes: 7 additions & 0 deletions README.md
Expand Up @@ -8,6 +8,13 @@

A Golang rate limit implementation which allows burst of request during the defined duration.


### Differences with 'golang.org/x/time/rate#Limiter'

The original library, i.e. `golang.org/x/time/rate`, implements the classic **token bucket** algorithm: it allows a burst of tokens and then refills them at a specified rate, one unit at a time. This implementation is a variant that allows the same burst of tokens, but performs the refill all at once when the defined interval elapses.

This allows scanners to respect maximum defined rate limits, pause until the allowed interval hits, and then process again at maximum speed. The original library slowed down requests according to the refill ratio.

## Example

An Example showing usage of ratelimit as a library is specified below:
Expand Down
26 changes: 26 additions & 0 deletions adaptive_ratelimit_test.go
@@ -0,0 +1,26 @@
package ratelimit_test

import (
"context"
"testing"
"time"

"github.com/projectdiscovery/ratelimit"
"github.com/stretchr/testify/require"
)

// TestAdaptiveRateLimit exercises SleepandReset: the limiter starts out
// unlimited and — as if the server replied 429 after request 100 — is
// recalibrated to 30 requests every 5 seconds mid-run.
func TestAdaptiveRateLimit(t *testing.T) {
	rl := ratelimit.NewUnlimited(context.Background())
	startedAt := time.Now()

	const totalRequests = 132
	for attempt := 0; attempt < totalRequests; attempt++ {
		rl.Take()
		if attempt != 100 {
			continue
		}
		// Simulate honoring Retry-After and recalibrating (different
		// strategies possible): sleep 5s, then resume at 30 req / 5 sec.
		rl.SleepandReset(5*time.Second, 30, 5*time.Second)
	}
	require.Equal(t, time.Since(startedAt).Round(time.Second), 10*time.Second)
}
14 changes: 0 additions & 14 deletions example/main.go
Expand Up @@ -9,20 +9,6 @@ import (
)

func main() {

fmt.Printf("[+] Complete Tasks Using MulitLimiter with unique key\n")

multiLimiter := ratelimit.NewMultiLimiter(context.Background())
multiLimiter.Add("default", 10)
save1 := time.Now()

for i := 0; i < 11; i++ {
multiLimiter.Take("default")
fmt.Printf("MulitKey Task %v completed after %v\n", i, time.Since(save1))
}

fmt.Printf("\n[+] Complete Tasks Using Limiter\n")

// create a rate limiter by passing context, max tasks/tokens , time interval
limiter := ratelimit.New(context.Background(), 5, time.Duration(10*time.Second))

Expand Down
125 changes: 68 additions & 57 deletions keyratelimit.go
Expand Up @@ -3,87 +3,98 @@ package ratelimit
import (
"context"
"fmt"
"sync"
"time"
)

/*
Note:
This is somewhat modified version of TokenBucket
Here we consider buffer channel as a bucket
*/

// MultiLimiter allows burst of request during defined duration for each key
type MultiLimiter struct {
ticker *time.Ticker
tokens sync.Map // map of buffered channels map[string](chan struct{})
ctx context.Context
// Options of MultiLimiter
type Options struct {
Key string // Unique Identifier
IsUnlimited bool
MaxCount uint
Duration time.Duration
}

func (m *MultiLimiter) run() {
for {
select {
case <-m.ctx.Done():
m.ticker.Stop()
return

case <-m.ticker.C:
// Iterate and fill buffers to their capacity on every tick
m.tokens.Range(func(key, value any) bool {
tokenChan := value.(chan struct{})
if len(tokenChan) == cap(tokenChan) {
// no need to fill buffer/bucket
return true
} else {
for i := 0; i < cap(tokenChan)-len(tokenChan); i++ {
// fill bucket/buffer with tokens
tokenChan <- struct{}{}
}
}
// if it returns false range is stopped
return true
})
// Validate reports whether the options describe a usable limiter: a
// non-empty key plus, for limited instances, a non-zero count and duration.
func (o *Options) Validate() error {
	if o.IsUnlimited {
		// unlimited limiters need no further configuration
		return nil
	}
	switch {
	case o.Key == "":
		return fmt.Errorf("empty keys not allowed")
	case o.MaxCount == 0:
		return fmt.Errorf("maxcount cannot be zero")
	case o.Duration == 0:
		return fmt.Errorf("time duration not set")
	}
	return nil
}

// Adds new bucket with key and given tokenrate returns error if it already exists1
func (m *MultiLimiter) Add(key string, tokensPerMinute uint) error {
_, ok := m.tokens.Load(key)
// MultiLimiter is a wrapper around Limiter that can rate-limit per key.
// NOTE(review): plain map with no mutex — calling Add concurrently with
// Take/GetLimit races on the map; confirm callers register keys up front
// or serialize access.
type MultiLimiter struct {
limiters map[string]*Limiter // one independent Limiter per registered key
ctx context.Context // parent context handed to every child Limiter
}

// Adds new bucket with key
func (m *MultiLimiter) Add(opts *Options) error {
if err := opts.Validate(); err != nil {
return err
}
_, ok := m.limiters[opts.Key]
if ok {
return fmt.Errorf("key already exists")
}
// create a buffered channel of size `tokenPerMinute`
tokenChan := make(chan struct{}, tokensPerMinute)
for i := 0; i < int(tokensPerMinute); i++ {
// fill bucket/buffer with tokens
tokenChan <- struct{}{}
var rlimiter *Limiter
if opts.IsUnlimited {
rlimiter = NewUnlimited(m.ctx)
} else {
rlimiter = New(m.ctx, opts.MaxCount, opts.Duration)
}
m.tokens.Store(key, tokenChan)
m.limiters[opts.Key] = rlimiter
return nil
}

// Take one token from bucket / buffer returns error if key not present
// GetLimit returns the configured rate limit (tokens per interval) for key,
// or an error when the key was never registered.
func (m *MultiLimiter) GetLimit(key string) (uint, error) {
	if l, found := m.limiters[key]; found && l != nil {
		return l.GetLimit(), nil
	}
	return 0, fmt.Errorf("key doesnot exist")
}

// Take one token from bucket returns error if key not present
func (m *MultiLimiter) Take(key string) error {
tokenValue, ok := m.tokens.Load(key)
if !ok {
limiter, ok := m.limiters[key]
if !ok || limiter == nil {
return fmt.Errorf("key doesnot exist")
}
tokenChan := tokenValue.(chan struct{})
<-tokenChan
limiter.Take()
return nil
}

// SleepandReset pauses the limiter registered under opts.Key for SleepTime
// and restarts it with the new limit described by opts (used for adaptive
// ratelimiting). Returns an error for invalid options or an unknown key.
func (m *MultiLimiter) SleepandReset(SleepTime time.Duration, opts *Options) error {
	if err := opts.Validate(); err != nil {
		return err
	}
	l, found := m.limiters[opts.Key]
	if found && l != nil {
		l.SleepandReset(SleepTime, opts.MaxCount, opts.Duration)
		return nil
	}
	return fmt.Errorf("key doesnot exist")
}

// NewMultiLimiter : Limits
func NewMultiLimiter(ctx context.Context) *MultiLimiter {
func NewMultiLimiter(ctx context.Context, opts *Options) (*MultiLimiter, error) {
if err := opts.Validate(); err != nil {
return nil, err
}
multilimiter := &MultiLimiter{
ticker: time.NewTicker(time.Minute), // different implementation than ratelimit
ctx: ctx,
tokens: sync.Map{},
ctx: ctx,
limiters: map[string]*Limiter{},
}

go multilimiter.run()

return multilimiter
return multilimiter, multilimiter.Add(opts)
}
51 changes: 37 additions & 14 deletions keyratelimit_test.go
Expand Up @@ -2,6 +2,7 @@ package ratelimit_test

import (
"context"
"sync"
"testing"
"time"

Expand All @@ -10,21 +11,43 @@ import (
)

// TestMultiLimiter drains two independent keys ("default" and "one"), each
// limited to 100 tokens per 3 seconds, and checks that taking 201 tokens
// needs more than two refill intervals (> 6s) per key.
// (Reconstructed new-side test: the scraped diff interleaved the deleted
// `limiter.Add("default", 20)` / 21-token loop with the new code.)
func TestMultiLimiter(t *testing.T) {
	limiter, err := ratelimit.NewMultiLimiter(context.Background(), &ratelimit.Options{
		Key:         "default",
		IsUnlimited: false,
		MaxCount:    100,
		Duration:    time.Duration(3) * time.Second,
	})
	require.Nil(t, err)
	wg := &sync.WaitGroup{}

	wg.Add(1)
	go func() {
		defer wg.Done()
		defaultStart := time.Now()
		for i := 0; i < 201; i++ {
			errx := limiter.Take("default")
			require.Nil(t, errx, "failed to take")
		}
		require.Greater(t, time.Since(defaultStart), time.Duration(6)*time.Second)
	}()

	err = limiter.Add(&ratelimit.Options{
		Key:         "one",
		IsUnlimited: false,
		MaxCount:    100,
		Duration:    time.Duration(3) * time.Second,
	})
	require.Nil(t, err)

	wg.Add(1)
	go func() {
		defer wg.Done()
		oneStart := time.Now()
		for i := 0; i < 201; i++ {
			errx := limiter.Take("one")
			require.Nil(t, errx)
		}
		require.Greater(t, time.Since(oneStart), time.Duration(6)*time.Second)
	}()
	wg.Wait()
}
63 changes: 49 additions & 14 deletions ratelimit.go
Expand Up @@ -13,16 +13,21 @@ type Limiter struct {
ticker *time.Ticker
tokens chan struct{}
ctx context.Context
// internal
cancelFunc context.CancelFunc
}

func (limiter *Limiter) run() {
func (limiter *Limiter) run(ctx context.Context) {
for {
if limiter.count == 0 {
<-limiter.ticker.C
limiter.count = limiter.maxCount
}

select {
case <-ctx.Done():
// Internal Context
limiter.ticker.Stop()
return
case <-limiter.ctx.Done():
limiter.ticker.Stop()
return
Expand All @@ -39,30 +44,60 @@ func (rateLimiter *Limiter) Take() {
<-rateLimiter.tokens
}

// GetLimit reports how many tokens this limiter grants per interval.
func (limiter *Limiter) GetLimit() uint {
	return limiter.maxCount
}

// SleepandReset stops the current refill goroutine, sleeps for sleepTime,
// and restarts the limiter with a new limit and interval (used for adaptive
// ratelimiting, e.g. after a 429 response with Retry-After).
//
// NOTE(review): fields are reassigned here with no synchronization while
// other goroutines may be inside Take(); confirm callers serialize
// SleepandReset with Take. Closing tokens while the old run goroutine is
// mid-send would panic ("send on closed channel") — presumably cancelFunc()
// stops it first, but that ordering is a race worth verifying.
func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) {
	// stop existing refill goroutine via its internal context
	ratelimiter.cancelFunc()
	// unblock/flush the old unbuffered token channel
	close(ratelimiter.tokens)
	<-ratelimiter.tokens
	// honor the requested pause before resuming
	time.Sleep(sleepTime)
	// reset counters and ticker, then start a fresh refill goroutine
	ratelimiter.maxCount = newLimit
	ratelimiter.count = newLimit
	ratelimiter.ticker = time.NewTicker(duration)
	ratelimiter.tokens = make(chan struct{})
	ctx, cancel := context.WithCancel(context.TODO())
	ratelimiter.cancelFunc = cancel
	go ratelimiter.run(ctx)
}

// New creates a limiter that allows max tokens per duration interval.
// An internal cancellable context lets SleepandReset stop and replace the
// refill goroutine independently of the caller's ctx.
// (Reconstructed: the scraped diff duplicated the struct-literal fields —
// old side without cancelFunc, new side with it — and both `go limiter.run`
// lines; this keeps only the new side. The redundant uint(max) conversion
// on an already-uint parameter is dropped.)
func New(ctx context.Context, max uint, duration time.Duration) *Limiter {
	internalctx, cancel := context.WithCancel(context.TODO())

	limiter := &Limiter{
		maxCount:   max,
		count:      max,
		ticker:     time.NewTicker(duration),
		tokens:     make(chan struct{}),
		ctx:        ctx,
		cancelFunc: cancel,
	}
	go limiter.run(internalctx)

	return limiter
}

// NewUnlimited creates a limiter with approximately unlimited tokens
// (math.MaxUint refilled every millisecond), so Take never meaningfully
// blocks until SleepandReset installs a real limit.
// (Reconstructed: the scraped diff duplicated the struct-literal fields and
// `go limiter.run` lines from both diff sides; this keeps only the new side.)
func NewUnlimited(ctx context.Context) *Limiter {
	internalctx, cancel := context.WithCancel(context.TODO())

	limiter := &Limiter{
		maxCount:   math.MaxUint,
		count:      math.MaxUint,
		ticker:     time.NewTicker(time.Millisecond),
		tokens:     make(chan struct{}),
		ctx:        ctx,
		cancelFunc: cancel,
	}
	go limiter.run(internalctx)

	return limiter
}

0 comments on commit cb63b75

Please sign in to comment.