From 19f8321b47db462ff6df11998b2d3b4fb160a64b Mon Sep 17 00:00:00 2001 From: Sebastian Machuca Date: Thu, 1 Jan 2026 12:57:29 +1100 Subject: [PATCH 1/2] Adding coverage for both registries --- bucket/leaky.go | 46 ++++++++ bucket/leaky_test.go | 72 +++++++++++++ bucket/registry.go | 26 +++-- bucket/registry_test.go | 232 +++++++++++++++++++++++++++------------- bucket/token.go | 4 +- bucket/token_test.go | 2 +- 6 files changed, 292 insertions(+), 90 deletions(-) create mode 100644 bucket/leaky.go create mode 100644 bucket/leaky_test.go diff --git a/bucket/leaky.go b/bucket/leaky.go new file mode 100644 index 0000000..48ad9a8 --- /dev/null +++ b/bucket/leaky.go @@ -0,0 +1,46 @@ +package bucket + +import ( + "sync" + "time" +) + +type LeakyLimiter struct { + mu sync.Mutex + + capacity, level, rate float64 + lastUpdatedAt time.Time +} + +func NewLeakyLimiter(capacity, rate uint32) *LeakyLimiter { + return &LeakyLimiter{ + capacity: float64(capacity), + rate: float64(rate), + lastUpdatedAt: time.Now(), + } +} + +func (lim *LeakyLimiter) update() { + t := time.Now() + if t.Before(lim.lastUpdatedAt) { + return + } + + lim.level = max(0, lim.level-t.Sub(lim.lastUpdatedAt).Seconds()*lim.rate) + lim.lastUpdatedAt = t +} + +func (lim *LeakyLimiter) Allow() bool { + lim.mu.Lock() + defer lim.mu.Unlock() + + lim.update() + + if lim.level+1 <= lim.capacity { + lim.level++ + + return true + } + + return false +} diff --git a/bucket/leaky_test.go b/bucket/leaky_test.go new file mode 100644 index 0000000..4370272 --- /dev/null +++ b/bucket/leaky_test.go @@ -0,0 +1,72 @@ +package bucket_test + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/serroba/rate/bucket" + "github.com/stretchr/testify/assert" +) + +func TestLeakyLimiter_Allow(t *testing.T) { + type args struct { + capacity uint32 + rate uint32 + } + + tests := []struct { + name string + args args + previousAttempts int + want bool + }{ + {name: "Test With No Capacity", args: args{capacity: 0, rate: 0}, want: false}, + {name: "Test With Capacity 1", args: args{capacity: 1, rate: 0}, want: true}, + { + name: "Test With Capacity 1 with previous attempt", + args: args{capacity: 1, rate: 0}, + previousAttempts: 1, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lim := bucket.NewLeakyLimiter(tt.args.capacity, tt.args.rate) + for range tt.previousAttempts { + lim.Allow() + } + + assert.Equalf(t, tt.want, lim.Allow(), "Allow()") + }) + } +} + +func TestWithConcurrency(t *testing.T) { + var ( + allow atomic.Int64 + deny atomic.Int64 + wg sync.WaitGroup + ) + + lim := bucket.NewLeakyLimiter(10, 5) + + for range 15 { + wg.Add(1) + + go func() { + defer wg.Done() + + if lim.Allow() { + allow.Add(1) + } else { + deny.Add(1) + } + }() + } + + wg.Wait() + + assert.Equal(t, int64(10), allow.Load()) + assert.Equal(t, int64(5), deny.Load()) +} diff --git a/bucket/registry.go b/bucket/registry.go index 049ce22..8f28158 100644 --- a/bucket/registry.go +++ b/bucket/registry.go @@ -7,24 +7,28 @@ import ( type ( Identifier string Registry struct { - mu sync.Mutex - limiters map[Identifier]*TokenLimiter - capacity, rate uint32 + mu sync.Mutex + factory LimiterFactory + limiters map[Identifier]Limiter } ) -func NewRegistry(capacity, rate uint32, users ...Identifier) (*Registry, error) { - limiters := make(map[Identifier]*TokenLimiter) +type Limiter interface { + Allow() bool +} + +type LimiterFactory func(key Identifier) Limiter + +func NewRegistry(factory LimiterFactory, keys ...Identifier) (*Registry, error) { + limiters := make(map[Identifier]Limiter) - for _, user := range users { - limiter := NewLimiter(capacity, rate) - limiters[user] = limiter + for _, key := range keys { + limiters[key] = factory(key) } return &Registry{ limiters: limiters, - capacity: capacity, - rate: rate, + factory: factory, }, nil } @@ -34,7 +38,7 @@ func (r *Registry) Allow(key Identifier) bool { lim, ok := r.limiters[key] if !ok { - lim = NewLimiter(r.capacity, r.rate) + lim = r.factory(key) r.limiters[key] = lim } diff --git a/bucket/registry_test.go b/bucket/registry_test.go index 5604747..ebe448c 100644 --- a/bucket/registry_test.go +++ b/bucket/registry_test.go @@ -10,118 +10,198 @@ import ( "github.com/stretchr/testify/require" ) +var strategies = []struct { + name string + factory func(capacity, rate uint32) bucket.LimiterFactory +}{ + { + name: "token", + factory: func(capacity, rate uint32) bucket.LimiterFactory { + return func(bucket.Identifier) bucket.Limiter { + return bucket.NewTokenLimiter(capacity, rate) + } + }, + }, + { + name: "leaky", + factory: func(capacity, rate uint32) bucket.LimiterFactory { + return func(bucket.Identifier) bucket.Limiter { + return bucket.NewLeakyLimiter(capacity, rate) + } + }, + }, +} + func TestNewRegistry(t *testing.T) { - reg, err := bucket.NewRegistry(10, 2) - require.NoError(t, err) - require.NotNil(t, reg) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(10, 2)) + require.NoError(t, err) + require.NotNil(t, reg) + }) + } } func TestNewRegistry_WithUsers(t *testing.T) { - reg, err := bucket.NewRegistry(10, 2, "alice", "bob") - require.NoError(t, err) - require.NotNil(t, reg) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(10, 2), "alice", "bob") + require.NoError(t, err) + require.NotNil(t, reg) + }) + } } func TestRegistry_Allow_ExistingUser(t *testing.T) { - reg, err := bucket.NewRegistry(2, 0, "alice") - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(2, 0), "alice") + require.NoError(t, err) - require.True(t, reg.Allow("alice")) - require.True(t, reg.Allow("alice")) - require.False(t, reg.Allow("alice")) + require.True(t, reg.Allow("alice")) + require.True(t, reg.Allow("alice")) + require.False(t, reg.Allow("alice")) + }) + } } func TestRegistry_Allow_NewUser(t *testing.T) { - reg, err := bucket.NewRegistry(2, 0) - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(2, 0)) + require.NoError(t, err) - // First call for a new user should create limiter and allow - require.True(t, reg.Allow("alice")) - require.True(t, reg.Allow("alice")) - require.False(t, reg.Allow("alice")) + // First call for a new user should create limiter and allow + require.True(t, reg.Allow("alice")) + require.True(t, reg.Allow("alice")) + require.False(t, reg.Allow("alice")) + }) + } } func TestRegistry_Allow_IndependentUsers(t *testing.T) { - reg, err := bucket.NewRegistry(1, 0) - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(1, 0)) + require.NoError(t, err) - // Each user has their own bucket - require.True(t, reg.Allow("alice")) - require.True(t, reg.Allow("bob")) + // Each user has their own bucket + require.True(t, reg.Allow("alice")) + require.True(t, reg.Allow("bob")) - // Both exhausted now - require.False(t, reg.Allow("alice")) - require.False(t, reg.Allow("bob")) + // Both exhausted now + require.False(t, reg.Allow("alice")) + require.False(t, reg.Allow("bob")) + }) + } } func TestRegistry_Allow_Concurrent(t *testing.T) { - reg, err := bucket.NewRegistry(100, 0) - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(100, 0)) + require.NoError(t, err) - var ( - allowed atomic.Int64 - wg sync.WaitGroup - ) + var ( + allowed atomic.Int64 + wg sync.WaitGroup + ) - // 50 goroutines per user, 4 users = 200 goroutines - users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} - for _, user := range users { - for range 50 { - wg.Add(1) + // 50 goroutines per user, 4 users = 200 goroutines + users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} + for _, user := range users { + for range 50 { + wg.Add(1) - go func(u bucket.Identifier) { - defer wg.Done() + go func(u bucket.Identifier) { + defer wg.Done() - if reg.Allow(u) { - allowed.Add(1) + if reg.Allow(u) { + allowed.Add(1) + } + }(user) } - }(user) - } - } + } - wg.Wait() + wg.Wait() - // Each user has capacity 100, only 50 requests each, so all should be allowed - require.Equal(t, int64(200), allowed.Load()) + // Each user has capacity 100, only 50 requests each, so all should be allowed + require.Equal(t, int64(200), allowed.Load()) + }) + } } func TestRegistry_Deny_Concurrent(t *testing.T) { - reg, err := bucket.NewRegistry(100, 0) - require.NoError(t, err) - - var ( - allowed atomic.Int64 - deny atomic.Int64 - wg sync.WaitGroup - ) - - // 50 goroutines per user, 4 users = 200 goroutines - users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} - for _, user := range users { - for range 110 { - wg.Add(1) - - go func(u bucket.Identifier) { - defer wg.Done() - - if reg.Allow(u) { - allowed.Add(1) - } else { - deny.Add(1) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() + + reg, err := bucket.NewRegistry(s.factory(100, 0)) + require.NoError(t, err) + + var ( + allowed atomic.Int64 + deny atomic.Int64 + wg sync.WaitGroup + ) + + // 50 goroutines per user, 4 users = 200 goroutines + users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} + for _, user := range users { + for range 110 { + wg.Add(1) + + go func(u bucket.Identifier) { + defer wg.Done() + + if reg.Allow(u) { + allowed.Add(1) + } else { + deny.Add(1) + } + }(user) } - }(user) - } - } + } - wg.Wait() + wg.Wait() - // Each user has capacity 100, only 50 requests each, so all should be allowed - assert.Equal(t, int64(400), allowed.Load()) - assert.Equal(t, int64(40), deny.Load()) + // Each user has capacity 100, only 50 requests each, so all should be allowed + assert.Equal(t, int64(400), allowed.Load()) + assert.Equal(t, int64(40), deny.Load()) + }) + } } func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) { - reg, err := bucket.NewRegistry(5, 0) + reg, err := bucket.NewRegistry(func(bucket.Identifier) bucket.Limiter { + return bucket.NewTokenLimiter(5, 0) + }) require.NoError(t, err) var wg sync.WaitGroup diff --git a/bucket/token.go b/bucket/token.go index 7c5994f..b451aa2 100644 --- a/bucket/token.go +++ b/bucket/token.go @@ -24,9 +24,9 @@ type TokenLimiter struct { clock clock } -// NewLimiter creates a new rate limiter with the given capacity and refill rate. +// NewTokenLimiter creates a new rate limiter with the given capacity and refill rate. // Capacity is the maximum burst size. Rate is tokens added per second. -func NewLimiter(capacity, rate uint32) *TokenLimiter { +func NewTokenLimiter(capacity, rate uint32) *TokenLimiter { return NewLimiterWithClock(capacity, rate, realClock{}) } diff --git a/bucket/token_test.go b/bucket/token_test.go index 64c3f61..f63fd2e 100644 --- a/bucket/token_test.go +++ b/bucket/token_test.go @@ -90,7 +90,7 @@ func TestLimiter_Allow(t *testing.T) { } func TestLimiter_Allow_Concurrent(t *testing.T) { - lim := bucket.NewLimiter(100, 0) + lim := bucket.NewTokenLimiter(100, 0) var ( allowed atomic.Int64 From 195e322c071cd6c114dbd3da5d91328f089f6d1d Mon Sep 17 00:00:00 2001 From: Sebastian Machuca Date: Thu, 1 Jan 2026 13:07:41 +1100 Subject: [PATCH 2/2] Improve test and add comments --- bucket/leaky.go | 8 ++++++++ bucket/leaky_test.go | 2 +- bucket/registry.go | 6 +++--- bucket/registry_test.go | 44 ++++++++++++++++++++++------------------- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/bucket/leaky.go b/bucket/leaky.go index 48ad9a8..c4e792f 100644 --- a/bucket/leaky.go +++ b/bucket/leaky.go @@ -5,6 +5,9 @@ import ( "time" ) +// LeakyLimiter implements a leaky bucket rate limiter. Requests fill the bucket, +// which drains at a constant rate. If the bucket is full, requests are rejected. +// Unlike TokenLimiter, this enforces a smooth output rate rather than allowing bursts. type LeakyLimiter struct { mu sync.Mutex @@ -12,6 +15,8 @@ type LeakyLimiter struct { lastUpdatedAt time.Time } +// NewLeakyLimiter creates a new leaky bucket limiter. +// Capacity is the maximum bucket size. Rate is how many requests drain per second. func NewLeakyLimiter(capacity, rate uint32) *LeakyLimiter { return &LeakyLimiter{ capacity: float64(capacity), @@ -30,6 +35,9 @@ func (lim *LeakyLimiter) update() { lim.lastUpdatedAt = t } +// Allow reports whether a request is allowed. It adds one to the bucket level +// if there is room and returns true. If the bucket is full, it returns false +// without blocking. func (lim *LeakyLimiter) Allow() bool { lim.mu.Lock() defer lim.mu.Unlock() diff --git a/bucket/leaky_test.go b/bucket/leaky_test.go index 4370272..62d15e8 100644 --- a/bucket/leaky_test.go +++ b/bucket/leaky_test.go @@ -42,7 +42,7 @@ func TestLeakyLimiter_Allow(t *testing.T) { } } -func TestWithConcurrency(t *testing.T) { +func TestLeakyLimiter_Allow_Concurrent(t *testing.T) { var ( allow atomic.Int64 deny atomic.Int64 diff --git a/bucket/registry.go b/bucket/registry.go index 8f28158..fa3fae0 100644 --- a/bucket/registry.go +++ b/bucket/registry.go @@ -17,13 +17,13 @@ type Limiter interface { Allow() bool } -type LimiterFactory func(key Identifier) Limiter +type LimiterFactory func() Limiter func NewRegistry(factory LimiterFactory, keys ...Identifier) (*Registry, error) { limiters := make(map[Identifier]Limiter) for _, key := range keys { - limiters[key] = factory(key) + limiters[key] = factory() } return &Registry{ @@ -38,7 +38,7 @@ func (r *Registry) Allow(key Identifier) bool { lim, ok := r.limiters[key] if !ok { - lim = r.factory(key) + lim = r.factory() r.limiters[key] = lim } diff --git a/bucket/registry_test.go b/bucket/registry_test.go index ebe448c..13aefea 100644 --- a/bucket/registry_test.go +++ b/bucket/registry_test.go @@ -17,7 +17,7 @@ var strategies = []struct { { name: "token", factory: func(capacity, rate uint32) bucket.LimiterFactory { - return func(bucket.Identifier) bucket.Limiter { + return func() bucket.Limiter { return bucket.NewTokenLimiter(capacity, rate) } }, @@ -25,7 +25,7 @@ var strategies = []struct { { name: "leaky", factory: func(capacity, rate uint32) bucket.LimiterFactory { - return func(bucket.Identifier) bucket.Limiter { + return func() bucket.Limiter { return bucket.NewLeakyLimiter(capacity, rate) } }, @@ -171,7 +171,6 @@ func TestRegistry_Deny_Concurrent(t *testing.T) { wg sync.WaitGroup ) - // 50 goroutines per user, 4 users = 200 goroutines users := []bucket.Identifier{"alice", "bob", "charlie", "diana"} for _, user := range users { for range 110 { @@ -191,7 +190,6 @@ func TestRegistry_Deny_Concurrent(t *testing.T) { wg.Wait() - // Each user has capacity 100, only 50 requests each, so all should be allowed assert.Equal(t, int64(400), allowed.Load()) assert.Equal(t, int64(40), deny.Load()) }) @@ -199,25 +197,31 @@ func TestRegistry_Deny_Concurrent(t *testing.T) { } func TestRegistry_Allow_ConcurrentNewUsers(t *testing.T) { - reg, err := bucket.NewRegistry(func(bucket.Identifier) bucket.Limiter { - return bucket.NewTokenLimiter(5, 0) - }) - require.NoError(t, err) + t.Parallel() + + for _, s := range strategies { + t.Run(s.name, func(t *testing.T) { + t.Parallel() - var wg sync.WaitGroup + reg, err := bucket.NewRegistry(s.factory(5, 0)) + require.NoError(t, err) - // Create 100 different users concurrently - for i := range 100 { - wg.Add(1) + var wg sync.WaitGroup - go func(id int) { - defer wg.Done() + // Create 100 different users concurrently + for i := range 100 { + wg.Add(1) - user := bucket.Identifier(rune('a' + id%26)) - reg.Allow(user) - }(i) - } + go func(id int) { + defer wg.Done() + + user := bucket.Identifier(rune('a' + id%26)) + reg.Allow(user) + }(i) + } - wg.Wait() - // If we get here without panic or race, the test passes + wg.Wait() + // If we get here without panic or race, the test passes + }) + } }