Skip to content
20 changes: 12 additions & 8 deletions async_handoff_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &mockNetConn{addr: "original:6379"}, nil
},
PoolSize: int32(5),
PoolTimeout: time.Second,
PoolSize: int32(5),
MaxConcurrentDials: 5,
PoolTimeout: time.Second,
})

// Add the hook to the pool after creation
Expand Down Expand Up @@ -153,8 +154,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
return &mockNetConn{addr: "original:6379"}, nil
},

PoolSize: int32(10),
PoolTimeout: time.Second,
PoolSize: int32(10),
MaxConcurrentDials: 10,
PoolTimeout: time.Second,
})
defer testPool.Close()

Expand Down Expand Up @@ -225,8 +227,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
return &mockNetConn{addr: "original:6379"}, nil
},

PoolSize: int32(3),
PoolTimeout: time.Second,
PoolSize: int32(3),
MaxConcurrentDials: 3,
PoolTimeout: time.Second,
})
defer testPool.Close()

Expand Down Expand Up @@ -288,8 +291,9 @@ func TestEventDrivenHandoffIntegration(t *testing.T) {
return &mockNetConn{addr: "original:6379"}, nil
},

PoolSize: int32(2),
PoolTimeout: time.Second,
PoolSize: int32(2),
MaxConcurrentDials: 2,
PoolTimeout: time.Second,
})
defer testPool.Close()

Expand Down
22 changes: 12 additions & 10 deletions internal/pool/bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ func BenchmarkPoolGetPut(b *testing.B) {
for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
Dialer: dummyDialer,
PoolSize: int32(bm.poolSize),
MaxConcurrentDials: bm.poolSize,
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
})

b.ResetTimer()
Expand Down Expand Up @@ -75,11 +76,12 @@ func BenchmarkPoolGetRemove(b *testing.B) {
for _, bm := range benchmarks {
b.Run(bm.String(), func(b *testing.B) {
connPool := pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(bm.poolSize),
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
Dialer: dummyDialer,
PoolSize: int32(bm.poolSize),
MaxConcurrentDials: bm.poolSize,
PoolTimeout: time.Second,
DialTimeout: 1 * time.Second,
ConnMaxIdleTime: time.Hour,
})

b.ResetTimer()
Expand Down
36 changes: 20 additions & 16 deletions internal/pool/buffer_size_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ var _ = Describe("Buffer Size Configuration", func() {

It("should use default buffer sizes when not specified", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(1),
PoolTimeout: 1000,
Dialer: dummyDialer,
PoolSize: int32(1),
MaxConcurrentDials: 1,
PoolTimeout: 1000,
})

cn, err := connPool.NewConn(ctx)
Expand All @@ -47,11 +48,12 @@ var _ = Describe("Buffer Size Configuration", func() {
customWriteSize := 64 * 1024 // 64KB

connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
Dialer: dummyDialer,
PoolSize: int32(1),
MaxConcurrentDials: 1,
PoolTimeout: 1000,
ReadBufferSize: customReadSize,
WriteBufferSize: customWriteSize,
})

cn, err := connPool.NewConn(ctx)
Expand All @@ -68,11 +70,12 @@ var _ = Describe("Buffer Size Configuration", func() {

It("should handle zero buffer sizes by using defaults", func() {
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(1),
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
Dialer: dummyDialer,
PoolSize: int32(1),
MaxConcurrentDials: 1,
PoolTimeout: 1000,
ReadBufferSize: 0, // Should use default
WriteBufferSize: 0, // Should use default
})

cn, err := connPool.NewConn(ctx)
Expand Down Expand Up @@ -104,9 +107,10 @@ var _ = Describe("Buffer Size Configuration", func() {
// Test the scenario where someone creates a pool directly (like in tests)
// without setting ReadBufferSize and WriteBufferSize
connPool = pool.NewConnPool(&pool.Options{
Dialer: dummyDialer,
PoolSize: int32(1),
PoolTimeout: 1000,
Dialer: dummyDialer,
PoolSize: int32(1),
MaxConcurrentDials: 1,
PoolTimeout: 1000,
// ReadBufferSize and WriteBufferSize are not set (will be 0)
})

Expand Down
5 changes: 3 additions & 2 deletions internal/pool/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ func TestPoolWithHooks(t *testing.T) {
Dialer: func(ctx context.Context) (net.Conn, error) {
return &net.TCPConn{}, nil // Mock connection
},
PoolSize: 1,
DialTimeout: time.Second,
PoolSize: 1,
MaxConcurrentDials: 1,
DialTimeout: time.Second,
}

pool := NewConnPool(opt)
Expand Down
108 changes: 102 additions & 6 deletions internal/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ type Options struct {

PoolFIFO bool
PoolSize int32
MaxConcurrentDials int
DialTimeout time.Duration
PoolTimeout time.Duration
MinIdleConns int32
Expand Down Expand Up @@ -119,7 +120,9 @@ type ConnPool struct {
dialErrorsNum uint32 // atomic
lastDialError atomic.Value

queue chan struct{}
queue chan struct{}
dialsInProgress chan struct{}
dialsQueue *wantConnQueue

connsMu sync.Mutex
conns map[uint64]*Conn
Expand All @@ -145,9 +148,11 @@ func NewConnPool(opt *Options) *ConnPool {
p := &ConnPool{
cfg: opt,

queue: make(chan struct{}, opt.PoolSize),
conns: make(map[uint64]*Conn),
idleConns: make([]*Conn, 0, opt.PoolSize),
queue: make(chan struct{}, opt.PoolSize),
conns: make(map[uint64]*Conn),
dialsInProgress: make(chan struct{}, opt.MaxConcurrentDials),
dialsQueue: newWantConnQueue(),
idleConns: make([]*Conn, 0, opt.PoolSize),
}

// Only create MinIdleConns if explicitly requested (> 0)
Expand Down Expand Up @@ -226,6 +231,7 @@ func (p *ConnPool) checkMinIdleConns() {
return
}
}

}

func (p *ConnPool) addIdleConn() error {
Expand Down Expand Up @@ -473,9 +479,8 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {

atomic.AddUint32(&p.stats.Misses, 1)

newcn, err := p.newConn(ctx, true)
newcn, err := p.queuedNewConn(ctx)
if err != nil {
p.freeTurn()
return nil, err
}

Expand All @@ -495,6 +500,97 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) {
return newcn, nil
}

func (p *ConnPool) queuedNewConn(ctx context.Context) (*Conn, error) {
select {
case p.dialsInProgress <- struct{}{}:
// Got permission, proceed to create connection
case <-ctx.Done():
p.freeTurn()
return nil, ctx.Err()
}

dialCtx, cancel := context.WithTimeout(context.Background(), p.cfg.DialTimeout)

w := &wantConn{
ctx: dialCtx,
cancelCtx: cancel,
result: make(chan wantConnResult, 1),
}
var err error
defer func() {
if err != nil {
w.cancel(ctx, p)
}
}()

p.dialsQueue.enqueue(w)

go func(w *wantConn) {
var freeTurnCalled bool
defer func() {
if err := recover(); err != nil {
if !freeTurnCalled {
p.freeTurn()
}
internal.Logger.Printf(context.Background(), "queuedNewConn panic: %+v", err)
}
}()

defer w.cancelCtx()
defer func() { <-p.dialsInProgress }() // Release connection creation permission

dialCtx := w.getCtxForDial()
cn, cnErr := p.newConn(dialCtx, true)
delivered := w.tryDeliver(cn, cnErr)
if cnErr == nil && delivered {
return
} else if cnErr == nil && !delivered {
p.putIdleConn(dialCtx, cn)
freeTurnCalled = true
} else {
p.freeTurn()
freeTurnCalled = true
}
}(w)

select {
case <-ctx.Done():
err = ctx.Err()
return nil, err
case result := <-w.result:
err = result.err
return result.cn, err
}
}

func (p *ConnPool) putIdleConn(ctx context.Context, cn *Conn) {
for {
w, ok := p.dialsQueue.dequeue()
if !ok {
break
}
if w.tryDeliver(cn, nil) {
return
}
}

cn.SetUsable(true)

p.connsMu.Lock()
defer p.connsMu.Unlock()

if p.closed() {
_ = cn.Close()
return
}

// poolSize is increased in newConn
p.idleConns = append(p.idleConns, cn)
p.idleConnsLen.Add(1)

p.freeTurn()
}

func (p *ConnPool) waitTurn(ctx context.Context) error {
select {
case <-ctx.Done():
Expand Down
Loading
Loading