Skip to content

Commit

Permalink
Merge pull request #27 from projectdiscovery/issue-26-data-race
Browse files Browse the repository at this point in the history
fix race conditions in ratelimit
  • Loading branch information
Mzack9999 committed Jan 31, 2023
2 parents f514092 + 69db5c5 commit d7e697e
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 95 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v3

- name: Test
run: go test ./...
run: go test -race ./...

- name: Build Example
run: go build example/main.go
37 changes: 14 additions & 23 deletions adaptive_ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
package ratelimit_test

import (
"context"
"testing"
"time"
// func TestAdaptiveRateLimit(t *testing.T) {
// limiter := ratelimit.NewUnlimited(context.Background())
// start := time.Now()

"github.com/projectdiscovery/ratelimit"
"github.com/stretchr/testify/require"
)

func TestAdaptiveRateLimit(t *testing.T) {
limiter := ratelimit.NewUnlimited(context.Background())
start := time.Now()

for i := 0; i < 132; i++ {
limiter.Take()
// got 429 / hit ratelimit after 100
if i == 100 {
// Retry-After and new limiter (calibrate using different statergies)
// new expected ratelimit 30req every 5 sec
limiter.SleepandReset(time.Duration(5)*time.Second, 30, time.Duration(5)*time.Second)
}
}
require.Equal(t, time.Since(start).Round(time.Second), time.Duration(10)*time.Second)
}
// for i := 0; i < 132; i++ {
// limiter.Take()
// // got 429 / hit ratelimit after 100
// if i == 100 {
// // Retry-After and new limiter (calibrate using different statergies)
// // new expected ratelimit 30req every 5 sec
// limiter.SleepandReset(time.Duration(5)*time.Second, 30, time.Duration(5)*time.Second)
// }
// }
// require.Equal(t, time.Since(start).Round(time.Second), time.Duration(10)*time.Second)
// }
6 changes: 2 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ module github.com/projectdiscovery/ratelimit

go 1.18

require (
github.com/stretchr/testify v1.8.1
golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15
)
require github.com/stretchr/testify v1.8.1

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/projectdiscovery/utils v0.0.6 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/projectdiscovery/utils v0.0.6 h1:6SDn/5E5NxrAfcYrZ7omXPmiU9n8p0rKXZ4BAOQyzbw=
github.com/projectdiscovery/utils v0.0.6/go.mod h1:PCwA5YuCYWPgHaGiZmr53/SA9iGQmAnw7DSHuhr8VPQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15 h1:5oN1Pz/eDhCpbMbLstvIPa0b/BEQo6g6nwV3pLjfM6w=
golang.org/x/exp v0.0.0-20221217163422-3c43f8badb15/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Expand Down
88 changes: 51 additions & 37 deletions keyratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@ package ratelimit

import (
"context"
"fmt"
"sync"
"time"

"golang.org/x/exp/maps"
errorutil "github.com/projectdiscovery/utils/errors"
)

var (
ErrKeyAlreadyExists = errorutil.NewWithTag("MultiLimiter", "key already exists")
ErrKeyMissing = errorutil.NewWithTag("MultiLimiter", "key does not exist")
)

// Options of MultiLimiter
Expand All @@ -20,21 +25,21 @@ type Options struct {
func (o *Options) Validate() error {
if !o.IsUnlimited {
if o.Key == "" {
return fmt.Errorf("empty keys not allowed")
return errorutil.NewWithTag("MultiLimiter", "empty keys not allowed")
}
if o.MaxCount == 0 {
return fmt.Errorf("maxcount cannot be zero")
return errorutil.NewWithTag("MultiLimiter", "maxcount cannot be zero")
}
if o.Duration == 0 {
return fmt.Errorf("time duration not set")
return errorutil.NewWithTag("MultiLimiter", "time duration not set")
}
}
return nil
}

// MultiLimiter is wrapper around Limiter than can limit based on a key
type MultiLimiter struct {
limiters map[string]*Limiter
limiters sync.Map // map of limiters
ctx context.Context
}

Expand All @@ -43,68 +48,77 @@ func (m *MultiLimiter) Add(opts *Options) error {
if err := opts.Validate(); err != nil {
return err
}
_, ok := m.limiters[opts.Key]
if ok {
return fmt.Errorf("key already exists")
}
var rlimiter *Limiter
if opts.IsUnlimited {
rlimiter = NewUnlimited(m.ctx)
} else {
rlimiter = New(m.ctx, opts.MaxCount, opts.Duration)
}
m.limiters[opts.Key] = rlimiter
// ok if true if key already exists
_, ok := m.limiters.LoadOrStore(opts.Key, rlimiter)
if ok {
return ErrKeyAlreadyExists.Msgf("key: %v", opts.Key)
}
return nil
}

// GetLimit returns current ratelimit of given key
func (m *MultiLimiter) GetLimit(key string) (uint, error) {
limiter, ok := m.limiters[key]
if !ok || limiter == nil {
return 0, fmt.Errorf("key doesnot exist")
limiter, err := m.get(key)
if err != nil {
return 0, err
}
return limiter.GetLimit(), nil
}

// Take one token from bucket returns error if key not present
func (m *MultiLimiter) Take(key string) error {
limiter, ok := m.limiters[key]
if !ok || limiter == nil {
return fmt.Errorf("key doesnot exist")
limiter, err := m.get(key)
if err != nil {
return err
}
limiter.Take()
return nil
}

// Stop internal limiters with defined keys or all if no key is provided
func (m *MultiLimiter) Stop(keys ...string) {
if len(keys) > 0 {
m.stopWithKeys(keys...)
} else {
m.stopWithKeys(maps.Keys(m.limiters)...)
// AddAndTake adds key if not present and then takes token from bucket
func (m *MultiLimiter) AddAndTake(opts *Options) {
if limiter, err := m.get(opts.Key); err == nil {
limiter.Take()
return
}
_ = m.Add(opts)
_ = m.Take(opts.Key)
}

// stopWithKeys stops the internal limiters matching keys
func (m *MultiLimiter) stopWithKeys(keys ...string) {
for _, key := range keys {
if limiter, ok := m.limiters[key]; ok {
// Stop internal limiters with defined keys or all if no key is provided
func (m *MultiLimiter) Stop(keys ...string) {
if len(keys) == 0 {
m.limiters.Range(func(key, value any) bool {
if limiter, ok := value.(*Limiter); ok {
limiter.Stop()
}
return true
})
return
}
for _, v := range keys {
if limiter, err := m.get(v); err == nil {
limiter.Stop()
}
}
}

// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting)
func (m *MultiLimiter) SleepandReset(SleepTime time.Duration, opts *Options) error {
if err := opts.Validate(); err != nil {
return err
// get returns *Limiter instance
func (m *MultiLimiter) get(key string) (*Limiter, error) {
val, _ := m.limiters.Load(key)
if val == nil {
return nil, ErrKeyMissing.Msgf("key: %v", key)
}
limiter, ok := m.limiters[opts.Key]
if !ok || limiter == nil {
return fmt.Errorf("key doesnot exist")
if limiter, ok := val.(*Limiter); ok {
return limiter, nil
}
limiter.SleepandReset(SleepTime, opts.MaxCount, opts.Duration)
return nil
return nil, errorutil.NewWithTag("MultiLimiter", "type assertion of rateLimiter failed in multiLimiter")
}

// NewMultiLimiter : Limits
Expand All @@ -114,7 +128,7 @@ func NewMultiLimiter(ctx context.Context, opts *Options) (*MultiLimiter, error)
}
multilimiter := &MultiLimiter{
ctx: ctx,
limiters: map[string]*Limiter{},
limiters: sync.Map{},
}
return multilimiter, multilimiter.Add(opts)
}
7 changes: 5 additions & 2 deletions keyratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func TestMultiLimiter(t *testing.T) {
})
require.Nil(t, err)
wg := &sync.WaitGroup{}
expectedTime := (time.Duration(6) * time.Second).Round(time.Millisecond)

wg.Add(1)
go func() {
Expand All @@ -28,7 +29,8 @@ func TestMultiLimiter(t *testing.T) {
errx := limiter.Take("default")
require.Nil(t, errx, "failed to take")
}
require.Greater(t, time.Since(defaultStart), time.Duration(6)*time.Second)
timeTaken := time.Since(defaultStart).Round(time.Millisecond)
require.GreaterOrEqualf(t, timeTaken.Nanoseconds(), expectedTime.Nanoseconds(), "more token sent than requested in given timeframe")
}()

err = limiter.Add(&ratelimit.Options{
Expand All @@ -47,7 +49,8 @@ func TestMultiLimiter(t *testing.T) {
errx := limiter.Take("one")
require.Nil(t, errx)
}
require.Greater(t, time.Since(oneStart), time.Duration(6)*time.Second)
timeTaken := time.Since(oneStart).Round(time.Millisecond)
require.GreaterOrEqualf(t, timeTaken.Nanoseconds(), expectedTime.Nanoseconds(), "more token sent than requested in given timeframe")
}()
wg.Wait()
}
40 changes: 21 additions & 19 deletions ratelimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Limiter struct {
}

func (limiter *Limiter) run(ctx context.Context) {
defer close(limiter.tokens)
for {
if limiter.count == 0 {
<-limiter.ticker.C
Expand Down Expand Up @@ -49,28 +50,29 @@ func (ratelimiter *Limiter) GetLimit() uint {
return ratelimiter.maxCount
}

// SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting)
func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) {
// stop existing Limiter using internalContext
ratelimiter.cancelFunc()
// drain any token
close(ratelimiter.tokens)
<-ratelimiter.tokens
// sleep
time.Sleep(sleepTime)
//reset and start
ratelimiter.maxCount = newLimit
ratelimiter.count = newLimit
ratelimiter.ticker = time.NewTicker(duration)
ratelimiter.tokens = make(chan struct{})
ctx, cancel := context.WithCancel(context.TODO())
ratelimiter.cancelFunc = cancel
go ratelimiter.run(ctx)
}
// TODO: SleepandReset should be able to handle multiple calls without resetting multiple times
// Which is not possible in this implementation
// // SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting)
// func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) {
// // stop existing Limiter using internalContext
// ratelimiter.cancelFunc()
// // drain any token
// close(ratelimiter.tokens)
// <-ratelimiter.tokens
// // sleep
// time.Sleep(sleepTime)
// //reset and start
// ratelimiter.maxCount = newLimit
// ratelimiter.count = newLimit
// ratelimiter.ticker = time.NewTicker(duration)
// ratelimiter.tokens = make(chan struct{})
// ctx, cancel := context.WithCancel(context.TODO())
// ratelimiter.cancelFunc = cancel
// go ratelimiter.run(ctx)
// }

// Stop the rate limiter canceling the internal context
func (ratelimiter *Limiter) Stop() {
defer close(ratelimiter.tokens)
if ratelimiter.cancelFunc != nil {
ratelimiter.cancelFunc()
}
Expand Down
27 changes: 20 additions & 7 deletions ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ func TestRateLimit(t *testing.T) {
limiter.Take()
count++
}
took := time.Since(start)
took := time.Since(start).Nanoseconds()
require.Equal(t, count, 10)
require.True(t, took < expected)
require.Less(t, took, expected.Nanoseconds())
// take another one above max
limiter.Take()
took = time.Since(start)
require.True(t, took >= expected)
took = time.Since(start).Nanoseconds()
require.GreaterOrEqual(t, took, expected.Nanoseconds())
})

t.Run("Unlimited Rate Limit", func(t *testing.T) {
Expand All @@ -41,7 +41,7 @@ func TestRateLimit(t *testing.T) {
}
took := time.Since(start)
require.Equal(t, count, 1000)
require.True(t, took < time.Duration(1*time.Second))
require.Lessf(t, took.Nanoseconds(), time.Duration(1*time.Second).Nanoseconds(), "burst rate of unlimited ratelimit is too low")
})

t.Run("Concurrent Rate Limit Use", func(t *testing.T) {
Expand All @@ -67,9 +67,22 @@ func TestRateLimit(t *testing.T) {
}

wg.Wait()
limiter.Stop()
took := time.Since(start)
require.Equal(t, expected, int(count))
require.True(t, took >= time.Duration(10*time.Second))
require.True(t, took <= time.Duration(12*time.Second))
require.GreaterOrEqualf(t, took, time.Duration(10*time.Second).Nanoseconds(), "ratelimit timeframe mismatch should be > 10s")
require.LessOrEqualf(t, took, time.Duration(12*time.Second).Nanoseconds(), "ratelimit timeframe mismatch should be < 10s")
})

t.Run("Time comparsion", func(t *testing.T) {
limiter := New(context.TODO(), 100, time.Duration(3)*time.Second)
// if ratelimit works correctly it should take at least 6 sec to take/consume 201 tokens
startTime := time.Now()
for i := 0; i < 201; i++ {
limiter.Take()
}
timetaken := time.Since(startTime)
expected := time.Duration(6) * time.Second
require.GreaterOrEqualf(t, timetaken.Nanoseconds(), expected.Nanoseconds(), "more tokens sent than expected with ratelimit")
})
}

0 comments on commit d7e697e

Please sign in to comment.