diff --git a/bucket/leaky.go b/bucket/leaky.go new file mode 100644 index 0000000..c4e792f --- /dev/null +++ b/bucket/leaky.go @@ -0,0 +1,54 @@ +package bucket + +import ( + "sync" + "time" +) + +// LeakyLimiter implements a leaky bucket rate limiter. Requests fill the bucket, +// which drains at a constant rate. If the bucket is full, requests are rejected. +// Unlike TokenLimiter, this enforces a smooth output rate rather than allowing bursts. +type LeakyLimiter struct { + mu sync.Mutex + + capacity, level, rate float64 + lastUpdatedAt time.Time +} + +// NewLeakyLimiter creates a new leaky bucket limiter. +// Capacity is the maximum bucket size. Rate is how many requests drain per second. +func NewLeakyLimiter(capacity, rate uint32) *LeakyLimiter { + return &LeakyLimiter{ + capacity: float64(capacity), + rate: float64(rate), + lastUpdatedAt: time.Now(), + } +} + +func (lim *LeakyLimiter) update() { + t := time.Now() + if t.Before(lim.lastUpdatedAt) { + return + } + + lim.level = max(0, lim.level-t.Sub(lim.lastUpdatedAt).Seconds()*lim.rate) + lim.lastUpdatedAt = t +} + +// Allow reports whether a request is allowed. It adds one to the bucket level +// if there is room and returns true. If the bucket is full, it returns false +// without blocking. +func (lim *LeakyLimiter) Allow() bool { + lim.mu.Lock() + defer lim.mu.Unlock() + + lim.update() + + if lim.level+1 <= lim.capacity { + lim.level++ + + return true + } + + return false +} diff --git a/bucket/leaky_test.go b/bucket/leaky_test.go new file mode 100644 index 0000000..62d15e8 --- /dev/null +++ b/bucket/leaky_test.go @@ -0,0 +1,72 @@ +package bucket_test + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/serroba/rate/bucket" + "github.com/stretchr/testify/assert" +) + +func TestLeakyLimiter_Allow(t *testing.T) { + type args struct { + capacity uint32 + rate uint32 + } + + tests := []struct { + name string + args args + previousAttempts int + want bool + }{ + {name: "Test With No Capacity", args: args{capacity: 0, rate: 0}, want: false}, + {name: "Test With Capacity 1", args: args{capacity: 1, rate: 0}, want: true}, + { + name: "Test With Capacity 1 with previous attempt", + args: args{capacity: 1, rate: 0}, + previousAttempts: 1, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lim := bucket.NewLeakyLimiter(tt.args.capacity, tt.args.rate) + for range tt.previousAttempts { + lim.Allow() + } + + assert.Equalf(t, tt.want, lim.Allow(), "Allow()") + }) + } +} + +func TestLeakyLimiter_Allow_Concurrent(t *testing.T) { + var ( + allow atomic.Int64 + deny atomic.Int64 + wg sync.WaitGroup + ) + + lim := bucket.NewLeakyLimiter(10, 5) + + for range 15 { + wg.Add(1) + + go func() { + defer wg.Done() + + if lim.Allow() { + allow.Add(1) + } else { + deny.Add(1) + } + }() + } + + wg.Wait() + + assert.Equal(t, int64(10), allow.Load()) + assert.Equal(t, int64(5), deny.Load()) +} diff --git a/bucket/registry.go b/bucket/registry.go index 049ce22..fa3fae0 100644 --- a/bucket/registry.go +++ b/bucket/registry.go @@ -7,24 +7,28 @@ import ( type ( Identifier string Registry struct { - mu sync.Mutex - limiters map[Identifier]*TokenLimiter - capacity, rate uint32 + mu sync.Mutex + factory LimiterFactory + limiters map[Identifier]Limiter } ) -func NewRegistry(capacity, rate uint32, users ...Identifier) (*Registry, error) { - limiters := make(map[Identifier]*TokenLimiter) +type Limiter interface { + Allow() bool +} + +type LimiterFactory func() Limiter + +func NewRegistry(factory LimiterFactory, keys ...Identifier) (*Registry, error) { + limiters := make(map[Identifier]Limiter) - for _, user := range users { - limiter := NewLimiter(capacity, rate) - limiters[user] = limiter + for _, key := range keys { + limiters[key] = factory() } return &Registry{ limiters: limiters, - capacity: capacity, - rate: rate, + factory: factory, }, nil } @@ -34,7 +38,7 @@ func (r *Registry) Allow(key Identifier) bool { lim, ok := r.limiters[key] if !ok { - lim = NewLimiter(r.capacity, r.rate) + lim = r.factory() r.limiters[key] = lim } diff --git a/bucket/registry_test.go b/bucket/registry_test.go index 5604747..13aefea 100644 --- a/bucket/registry_test.go +++ b/bucket/registry_test.go @@ -10,134 +10,218 @@ import ( "github.com/stretchr/testify/require" ) +var strategies = []struct { + name string + factory func(capacity, rate uint32) bucket.LimiterFactory +}{ + { + name: "token", + factory: func(capacity, rate uint32) bucket.LimiterFactory { + return func() bucket.Limiter { + return bucket.NewTokenLimiter(capacity, rate) + } + }, + }, + { + name: "leaky", + factory: func(capacity, rate uint32) bucket.LimiterFactory { + return func() bucket.Limiter { + return bucket.NewLeakyLimiter(capacity, rate) + } + }, + }, +} + func TestNewRegistry(t *testing.T) { - reg, err := bucket.NewRegistry(10, 2) - require.NoError(t, err) - require.NotNil(t, reg) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(10, 2)) + require.NoError(t, err) + require.NotNil(t, reg) + }) + } } func TestNewRegistry_WithUsers(t *testing.T) { - reg, err := bucket.NewRegistry(10, 2, "alice", "bob") - require.NoError(t, err) - require.NotNil(t, reg) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(10, 2), "alice", "bob") + require.NoError(t, err) + require.NotNil(t, reg) + }) + } } func TestRegistry_Allow_ExistingUser(t *testing.T) { - reg, err := bucket.NewRegistry(2, 0, "alice") - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() - require.True(t, reg.Allow("alice")) - require.True(t, reg.Allow("alice")) - require.False(t, reg.Allow("alice")) + reg, err := bucket.NewRegistry(s.factory(2, 0), "alice") + require.NoError(t, err) + + require.True(t, reg.Allow("alice")) + require.True(t, reg.Allow("alice")) + require.False(t, reg.Allow("alice")) + }) + } } func TestRegistry_Allow_NewUser(t *testing.T) { - reg, err := bucket.NewRegistry(2, 0) - require.NoError(t, err) + t.Parallel() - // First call for a new user should create limiter and allow - require.True(t, reg.Allow("alice")) - require.True(t, reg.Allow("alice")) - require.False(t, reg.Allow("alice")) + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(2, 0)) + require.NoError(t, err) + + // First call for a new user should create limiter and allow + require.True(t, reg.Allow("alice")) + require.True(t, reg.Allow("alice")) + require.False(t, reg.Allow("alice")) + }) + } } func TestRegistry_Allow_IndependentUsers(t *testing.T) { - reg, err := bucket.NewRegistry(1, 0) - require.NoError(t, err) + t.Parallel() - // Each user has their own bucket - require.True(t, reg.Allow("alice")) - require.True(t, reg.Allow("bob")) + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() - // Both exhausted now - require.False(t, reg.Allow("alice")) - require.False(t, reg.Allow("bob")) + reg, err := bucket.NewRegistry(s.factory(1, 0)) + require.NoError(t, err) + + // Each user has their own bucket + require.True(t, reg.Allow("alice")) + require.True(t, reg.Allow("bob")) + + // Both exhausted now + require.False(t, reg.Allow("alice")) + require.False(t, reg.Allow("bob")) + }) + } } func TestRegistry_Allow_Concurrent(t *testing.T) { - reg, err := bucket.NewRegistry(100, 0) - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() - var ( - allowed atomic.Int64 - wg sync.WaitGroup - ) + reg, err := bucket.NewRegistry(s.factory(100, 0)) + require.NoError(t, err) - // 50 goroutines per user, 4 users = 200 goroutines - users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} - for _, user := range users { - for range 50 { - wg.Add(1) + var ( + allowed atomic.Int64 + wg sync.WaitGroup + ) - go func(u bucket.Identifier) { - defer wg.Done() + // 50 goroutines per user, 4 users = 200 goroutines + users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} + for _, user := range users { + for range 50 { + wg.Add(1) - if reg.Allow(u) { - allowed.Add(1) + go func(u bucket.Identifier) { + defer wg.Done() + + if reg.Allow(u) { + allowed.Add(1) + } + }(user) } - }(user) - } - } + } - wg.Wait() + wg.Wait() - // Each user has capacity 100, only 50 requests each, so all should be allowed - require.Equal(t, int64(200), allowed.Load()) + // Each user has capacity 100, only 50 requests each, so all should be allowed + require.Equal(t, int64(200), allowed.Load()) + }) + } } func TestRegistry_Deny_Concurrent(t *testing.T) { - reg, err := bucket.NewRegistry(100, 0) - require.NoError(t, err) - - var ( - allowed atomic.Int64 - deny atomic.Int64 - wg sync.WaitGroup - ) - - // 50 goroutines per user, 4 users = 200 goroutines - users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} - for _, user := range users { - for range 110 { - wg.Add(1) - - go func(u bucket.Identifier) { - defer wg.Done() - - if reg.Allow(u) { - allowed.Add(1) - } else { - deny.Add(1) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(100, 0)) + require.NoError(t, err) + + var ( + allowed atomic.Int64 + deny atomic.Int64 + wg sync.WaitGroup + ) + + users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} + for _, user := range users { + for range 110 { + wg.Add(1) + + go func(u bucket.Identifier) { + defer wg.Done() + + if reg.Allow(u) { + allowed.Add(1) + } else { + deny.Add(1) + } + }(user) } - }(user) - } - } + } - wg.Wait() + wg.Wait() - // Each user has capacity 100, only 50 requests each, so all should be allowed - assert.Equal(t, int64(400), allowed.Load()) - assert.Equal(t, int64(40), deny.Load()) + assert.Equal(t, int64(400), allowed.Load()) + assert.Equal(t, int64(40), deny.Load()) + }) + } } func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) { - reg, err := bucket.NewRegistry(5, 0) - require.NoError(t, err) + t.Parallel() - var wg sync.WaitGroup + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() - // Create 100 different users concurrently - for i := range 100 { - wg.Add(1) + reg, err := bucket.NewRegistry(s.factory(5, 0)) + require.NoError(t, err) - go func(id int) { - defer wg.Done() + var wg sync.WaitGroup - user := bucket.Identifier(rune('a' + id%26)) - reg.Allow(user) - }(i) - } + // Create 100 different users concurrently + for i := range 100 { + wg.Add(1) - wg.Wait() - // If we get here without panic or race, the test passes + go func(id int) { + defer wg.Done() + + user := bucket.Identifier(rune('a' + id%26)) + reg.Allow(user) + }(i) + } + + wg.Wait() + // If we get here without panic or race, the test passes + }) + } } diff --git a/bucket/token.go b/bucket/token.go index 7c5994f..b451aa2 100644 --- a/bucket/token.go +++ b/bucket/token.go @@ -24,9 +24,9 @@ type TokenLimiter struct { clock clock } -// NewLimiter creates a new rate limiter with the given capacity and refill rate. +// NewTokenLimiter creates a new rate limiter with the given capacity and refill rate. // Capacity is the maximum burst size. Rate is tokens added per second. -func NewLimiter(capacity, rate uint32) *TokenLimiter { +func NewTokenLimiter(capacity, rate uint32) *TokenLimiter { return NewLimiterWithClock(capacity, rate, realClock{}) } diff --git a/bucket/token_test.go b/bucket/token_test.go index 64c3f61..f63fd2e 100644 --- a/bucket/token_test.go +++ b/bucket/token_test.go @@ -90,7 +90,7 @@ func TestLimiter_Allow(t *testing.T) { } func TestLimiter_Allow_Concurrent(t *testing.T) { - lim := bucket.NewLimiter(100, 0) + lim := bucket.NewTokenLimiter(100, 0) var ( allowed atomic.Int64