Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions internal/pool/double_freeturn_simple_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package pool_test

import (
"context"
"net"
"sync/atomic"
"testing"
"time"

"github.com/redis/go-redis/v9/internal/pool"
)

// TestDoubleFreeTurnSimple tests the double-free bug with a simple scenario.
// This test FAILS with the OLD code and PASSES with the NEW code.
//
// Scenario:
// 1. Request A times out, Dial A completes and delivers connection to Request B
// 2. Request B's own Dial B completes later
// 3. With the bug: Dial B frees Request B's turn (even though Request B is using connection A)
// 4. Then Request B calls Put() and frees the turn AGAIN (double-free)
// 5. This allows more concurrent operations than PoolSize permits
//
// Detection method:
// - Try to acquire PoolSize+1 connections after the double-free
// - With the bug: All succeed (pool size violated)
// - With the fix: Only PoolSize succeed
func TestDoubleFreeTurnSimple(t *testing.T) {
ctx := context.Background()

var dialCount atomic.Int32
dialBComplete := make(chan struct{})
requestBGotConn := make(chan struct{})
requestBCalledPut := make(chan struct{})

controlledDialer := func(ctx context.Context) (net.Conn, error) {
count := dialCount.Add(1)

if count == 1 {
// Dial A: takes 150ms
time.Sleep(150 * time.Millisecond)
t.Logf("Dial A completed")
} else if count == 2 {
// Dial B: takes 300ms (longer than Dial A)
time.Sleep(300 * time.Millisecond)
t.Logf("Dial B completed")
close(dialBComplete)
} else {
// Other dials: fast
time.Sleep(10 * time.Millisecond)
}

return newDummyConn(), nil
}

testPool := pool.NewConnPool(&pool.Options{
Dialer: controlledDialer,
PoolSize: 2, // Only 2 concurrent operations allowed
MaxConcurrentDials: 5,
DialTimeout: 1 * time.Second,
PoolTimeout: 1 * time.Second,
})
defer testPool.Close()

// Request A: Short timeout (100ms), will timeout before dial completes (150ms)
go func() {
shortCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()

_, err := testPool.Get(shortCtx)
if err != nil {
t.Logf("Request A: Timed out as expected: %v", err)
}
}()

// Wait for Request A to start
time.Sleep(20 * time.Millisecond)

// Request B: Long timeout, will receive connection from Request A's dial
requestBDone := make(chan struct{})
go func() {
defer close(requestBDone)

longCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()

cn, err := testPool.Get(longCtx)
if err != nil {
t.Errorf("Request B: Should have received connection but got error: %v", err)
return
}

t.Logf("Request B: Got connection from Request A's dial")
close(requestBGotConn)

// Wait for dial B to complete
<-dialBComplete

t.Logf("Request B: Dial B completed")

// Wait a bit to allow Dial B goroutine to finish and call freeTurn()
time.Sleep(100 * time.Millisecond)

// Signal that we're ready for the test to check semaphore state
close(requestBCalledPut)

// Wait for the test to check QueueLen
time.Sleep(200 * time.Millisecond)

t.Logf("Request B: Now calling Put()")
testPool.Put(ctx, cn)
t.Logf("Request B: Put() called")
}()

// Wait for Request B to get the connection
<-requestBGotConn

// Wait for Dial B to complete and freeTurn() to be called
<-requestBCalledPut

// NOW WE'RE IN THE CRITICAL WINDOW
// Request B is holding a connection (from Dial A)
// Dial B has completed and returned (freeTurn() has been called)
// With the bug:
// - Dial B freed Request B's turn (BUG!)
// - QueueLen should be 0
// With the fix:
// - Dial B did NOT free Request B's turn
// - QueueLen should be 1 (Request B still holds the turn)

t.Logf("\n=== CRITICAL CHECK: QueueLen ===")
t.Logf("Request B is holding a connection, Dial B has completed and returned")
queueLen := testPool.QueueLen()
t.Logf("QueueLen: %d", queueLen)

// Wait for Request B to finish
select {
case <-requestBDone:
case <-time.After(1 * time.Second):
t.Logf("Request B timed out")
}

t.Logf("\n=== Results ===")
t.Logf("QueueLen during critical window: %d", queueLen)
t.Logf("Expected with fix: 1 (Request B still holds the turn)")
t.Logf("Expected with bug: 0 (Dial B freed Request B's turn)")

if queueLen == 0 {
t.Errorf("DOUBLE-FREE BUG DETECTED!")
t.Errorf("QueueLen is 0, meaning Dial B freed Request B's turn")
t.Errorf("But Request B is still holding a connection, so its turn should NOT be freed yet")
} else if queueLen == 1 {
t.Logf("✓ CORRECT: QueueLen is 1")
t.Logf("Request B is still holding the turn (will be freed when Request B calls Put())")
} else {
t.Logf("Unexpected QueueLen: %d (expected 1 with fix, 0 with bug)", queueLen)
}
}

229 changes: 229 additions & 0 deletions internal/pool/double_freeturn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package pool

import (
"context"
"net"
"sync"
"sync/atomic"
"testing"
"time"
)

// TestDoubleFreeTurnBug demonstrates the double freeTurn bug where:
// 1. Dial goroutine creates a connection
// 2. Original waiter times out
// 3. putIdleConn delivers connection to another waiter
// 4. Dial goroutine calls freeTurn() (FIRST FREE)
// 5. Second waiter uses connection and calls Put()
// 6. Put() calls freeTurn() (SECOND FREE - BUG!)
//
// This causes the semaphore to be released twice, allowing more concurrent
// operations than PoolSize allows.
func TestDoubleFreeTurnBug(t *testing.T) {
var dialCount atomic.Int32
var putCount atomic.Int32

// Slow dialer - 150ms per dial
slowDialer := func(ctx context.Context) (net.Conn, error) {
dialCount.Add(1)
select {
case <-time.After(150 * time.Millisecond):
server, client := net.Pipe()
go func() {
defer server.Close()
buf := make([]byte, 1024)
for {
_, err := server.Read(buf)
if err != nil {
return
}
}
}()
return client, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

opt := &Options{
Dialer: slowDialer,
PoolSize: 10, // Small pool to make bug easier to trigger
MaxConcurrentDials: 10,
MinIdleConns: 0,
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 5 * time.Second,
}

connPool := NewConnPool(opt)
defer connPool.Close()

// Scenario:
// 1. Request A starts dial (100ms timeout - will timeout before dial completes)
// 2. Request B arrives (500ms timeout - will wait in queue)
// 3. Request A times out at 100ms
// 4. Dial completes at 150ms
// 5. putIdleConn delivers connection to Request B
// 6. Dial goroutine calls freeTurn() - FIRST FREE
// 7. Request B uses connection and calls Put()
// 8. Put() calls freeTurn() - SECOND FREE (BUG!)

var wg sync.WaitGroup

// Request A: Short timeout, will timeout before dial completes
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

cn, err := connPool.Get(ctx)
if err != nil {
// Expected to timeout
t.Logf("Request A timed out as expected: %v", err)
} else {
// Should not happen
t.Errorf("Request A should have timed out but got connection")
connPool.Put(ctx, cn)
putCount.Add(1)
}
}()

// Wait a bit for Request A to start dialing
time.Sleep(10 * time.Millisecond)

// Request B: Long timeout, will receive the connection from putIdleConn
wg.Add(1)
go func() {
defer wg.Done()
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()

cn, err := connPool.Get(ctx)
if err != nil {
t.Errorf("Request B should have succeeded but got error: %v", err)
} else {
t.Logf("Request B got connection successfully")
// Use the connection briefly
time.Sleep(50 * time.Millisecond)
connPool.Put(ctx, cn)
putCount.Add(1)
}
}()

wg.Wait()

// Check results
t.Logf("\n=== Results ===")
t.Logf("Dials: %d", dialCount.Load())
t.Logf("Puts: %d", putCount.Load())

// The bug is hard to detect directly without instrumenting freeTurn,
// but we can verify the scenario works correctly:
// - Request A should timeout
// - Request B should succeed and get the connection
// - 1-2 dials may occur (Request A starts one, Request B may start another)
// - 1 put should occur (Request B returning the connection)

if putCount.Load() != 1 {
t.Errorf("Expected 1 put, got %d", putCount.Load())
}

t.Logf("✓ Scenario completed successfully")
t.Logf("Note: The double freeTurn bug would cause semaphore to be released twice,")
t.Logf("allowing more concurrent operations than PoolSize permits.")
t.Logf("With the fix, putIdleConn returns true when delivering to a waiter,")
t.Logf("preventing the dial goroutine from calling freeTurn (waiter will call it later).")
}

// TestDoubleFreeTurnHighConcurrency tests the bug under high concurrency
func TestDoubleFreeTurnHighConcurrency(t *testing.T) {
var dialCount atomic.Int32
var getSuccesses atomic.Int32
var getFailures atomic.Int32

slowDialer := func(ctx context.Context) (net.Conn, error) {
dialCount.Add(1)
select {
case <-time.After(200 * time.Millisecond):
server, client := net.Pipe()
go func() {
defer server.Close()
buf := make([]byte, 1024)
for {
_, err := server.Read(buf)
if err != nil {
return
}
}
}()
return client, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

opt := &Options{
Dialer: slowDialer,
PoolSize: 20,
MaxConcurrentDials: 20,
MinIdleConns: 0,
PoolTimeout: 100 * time.Millisecond,
DialTimeout: 5 * time.Second,
}

connPool := NewConnPool(opt)
defer connPool.Close()

// Create many requests with varying timeouts
// Some will timeout before dial completes, triggering the putIdleConn delivery path
const numRequests = 100
var wg sync.WaitGroup

for i := 0; i < numRequests; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()

// Vary timeout: some short (will timeout), some long (will succeed)
timeout := 100 * time.Millisecond
if id%3 == 0 {
timeout = 500 * time.Millisecond
}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

cn, err := connPool.Get(ctx)
if err != nil {
getFailures.Add(1)
} else {
getSuccesses.Add(1)
time.Sleep(10 * time.Millisecond)
connPool.Put(ctx, cn)
}
}(i)

// Stagger requests
if i%10 == 0 {
time.Sleep(5 * time.Millisecond)
}
}

wg.Wait()

t.Logf("\n=== High Concurrency Results ===")
t.Logf("Requests: %d", numRequests)
t.Logf("Successes: %d", getSuccesses.Load())
t.Logf("Failures: %d", getFailures.Load())
t.Logf("Dials: %d", dialCount.Load())

// Verify that some requests succeeded despite timeouts
// This exercises the putIdleConn delivery path
if getSuccesses.Load() == 0 {
t.Errorf("Expected some successful requests, got 0")
}

t.Logf("✓ High concurrency test completed")
t.Logf("Note: This test exercises the putIdleConn delivery path where the bug occurs")
}

Loading
Loading