Skip to content

Commit

Permalink
feat: default conn write timeout and periodic PING
Browse files Browse the repository at this point in the history
  • Loading branch information
rueian committed May 15, 2022
1 parent d932192 commit fdedc0a
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 20 deletions.
4 changes: 4 additions & 0 deletions internal/cmds/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ var (
QuitCmd = Completed{
cs: &CommandSlice{s: []string{"QUIT"}},
}
// PingCmd is predefined PING
PingCmd = Completed{
cs: &CommandSlice{s: []string{"PING"}},
}
// SlotCmd is predefined CLUSTER SLOTS
SlotCmd = Completed{
cs: &CommandSlice{s: []string{"CLUSTER", "SLOTS"}},
Expand Down
64 changes: 58 additions & 6 deletions pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type pipe struct {
state int32
slept int32
version int32
timeout time.Duration

once sync.Once
cond sync.Cond
Expand Down Expand Up @@ -68,6 +69,8 @@ func newPipe(conn net.Conn, option *ClientOption) (p *pipe, err error) {

subs: newSubs(),
psubs: newSubs(),

timeout: option.ConnWriteTimeout,
}

helloCmd := []string{"HELLO", "3"}
Expand Down Expand Up @@ -137,22 +140,32 @@ func (p *pipe) background() {

func (p *pipe) _background() {
wg := sync.WaitGroup{}
wg.Add(2)
exit := func() {
// stop accepting new requests
atomic.CompareAndSwapInt32(&p.state, 1, 2)
_ = p.conn.Close() // force both read & write goroutine to exit
wg.Done()
}
wg.Add(1)
go func() {
p._backgroundWrite()
exit()
}()
wg.Add(1)
go func() {
p._backgroundRead()
exit()
p._awake()
}()
if p.timeout > 0 {
go func() {
if err := p._backgroundPing(); err != ErrClosing {
p.error.CompareAndSwap(nil, &errs{error: err})
atomic.CompareAndSwapInt32(&p.state, 1, 2)
_ = p.conn.Close() // force both read & write goroutine to exit
}
}()
}
wg.Wait()

p.subs.Close()
Expand Down Expand Up @@ -316,6 +329,40 @@ func (p *pipe) _backgroundRead() {
}
}

func (p *pipe) _backgroundPing() error {
var timer *time.Timer
for atomic.LoadInt32(&p.state) == 1 {
ws := atomic.AddInt32(&p.waits, 1)
ch := p.queue.PutOne(cmds.PingCmd)
if ws == 1 {
p._awake()
}
if timer == nil {
timer = time.NewTimer(p.timeout)
} else {
timer.Reset(p.timeout)
}
select {
case resp := <-ch:
atomic.AddInt32(&p.waits, -1)
if !timer.Stop() {
<-timer.C
}
if err := resp.NonRedisError(); err != nil {
return err
}
case <-timer.C:
go func() {
<-ch
atomic.AddInt32(&p.waits, -1)
}()
return context.DeadlineExceeded
}
time.Sleep(time.Second)
}
return ErrClosing
}

func (p *pipe) handlePush(values []RedisMessage) {
if len(values) < 2 {
return
Expand Down Expand Up @@ -534,6 +581,9 @@ func (p *pipe) syncDo(ctx context.Context, cmd cmds.Completed) (resp RedisResult
if dl, ok := ctx.Deadline(); ok {
p.conn.SetDeadline(dl)
defer p.conn.SetDeadline(time.Time{})
} else if p.timeout > 0 {
p.conn.SetDeadline(time.Now().Add(p.timeout))
defer p.conn.SetDeadline(time.Time{})
}

var msg RedisMessage
Expand All @@ -558,6 +608,9 @@ func (p *pipe) syncDoMulti(ctx context.Context, resp []RedisResult, multi []cmds
if dl, ok := ctx.Deadline(); ok {
p.conn.SetDeadline(dl)
defer p.conn.SetDeadline(time.Time{})
} else if p.timeout > 0 {
p.conn.SetDeadline(time.Now().Add(p.timeout))
defer p.conn.SetDeadline(time.Time{})
}

var err error
Expand Down Expand Up @@ -644,16 +697,15 @@ func (p *pipe) Error() error {

func (p *pipe) Close() {
p.error.CompareAndSwap(nil, errClosing)
atomic.CompareAndSwapInt32(&p.state, 0, 2)
atomic.CompareAndSwapInt32(&p.state, 1, 2)
atomic.AddInt32(&p.waits, 1)
stopping1 := atomic.CompareAndSwapInt32(&p.state, 0, 2)
stopping2 := atomic.CompareAndSwapInt32(&p.state, 1, 2)
if p.queue != nil {
p.background()
p._awake()
for atomic.LoadInt32(&p.waits) != 1 {
runtime.Gosched()
if stopping1 || stopping2 {
<-p.queue.PutOne(cmds.QuitCmd)
}
<-p.queue.PutOne(cmds.QuitCmd)
}
atomic.AddInt32(&p.waits, -1)
atomic.CompareAndSwapInt32(&p.state, 2, 3)
Expand Down
49 changes: 40 additions & 9 deletions pipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"io"
"net"
"os"
"runtime"
"strconv"
"strings"
Expand Down Expand Up @@ -1022,6 +1021,16 @@ func TestOngoingDeadlineContextInSyncMode_Do(t *testing.T) {
p.Close()
}

func TestWriteDeadlineInSyncMode_Do(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 1 * time.Second / 2})
defer closeConn()

if err := p.Do(context.Background(), cmds.NewCompleted([]string{"GET", "a"})).NonRedisError(); err != context.DeadlineExceeded {
t.Fatalf("unexpected err %v", err)
}
p.Close()
}

func TestOngoingDeadlineContextInSyncMode_DoMulti(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{})
defer closeConn()
Expand All @@ -1035,6 +1044,16 @@ func TestOngoingDeadlineContextInSyncMode_DoMulti(t *testing.T) {
p.Close()
}

func TestWriteDeadlineInSyncMode_DoMulti(t *testing.T) {
p, _, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 2})
defer closeConn()

if err := p.DoMulti(context.Background(), cmds.NewCompleted([]string{"GET", "a"}))[0].NonRedisError(); err != context.DeadlineExceeded {
t.Fatalf("unexpected err %v", err)
}
p.Close()
}

func TestOngoingCancelContextInPipelineMode_Do(t *testing.T) {
p, mock, close, closeConn := setup(t, ClientOption{})
defer closeConn()
Expand Down Expand Up @@ -1063,7 +1082,7 @@ func TestOngoingCancelContextInPipelineMode_Do(t *testing.T) {
time.Sleep(time.Millisecond * 100)
}
cancel()
if atomic.LoadInt32(&canceled) != 50 {
for atomic.LoadInt32(&canceled) != 50 {
t.Logf("wait canceled count to be 50 %v", atomic.LoadInt32(&canceled))
time.Sleep(time.Millisecond * 100)
}
Expand All @@ -1075,7 +1094,7 @@ func TestOngoingCancelContextInPipelineMode_Do(t *testing.T) {
}

func TestOngoingWriteTimeoutInPipelineMode_Do(t *testing.T) {
p, mock, _, closeConn := setup(t, ClientOption{ConnReadWriteTimeout: time.Second / 2})
p, mock, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 2})
defer closeConn()

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
Expand All @@ -1089,7 +1108,7 @@ func TestOngoingWriteTimeoutInPipelineMode_Do(t *testing.T) {
s, err := p.Do(ctx, cmds.NewCompleted([]string{"GET", "a"})).ToString()
if s == "OK" {
atomic.AddInt32(&success, 1)
} else if errors.Is(err, os.ErrDeadlineExceeded) {
} else if errors.Is(err, context.DeadlineExceeded) {
atomic.AddInt32(&timeout, 1)
}
}()
Expand All @@ -1099,7 +1118,7 @@ func TestOngoingWriteTimeoutInPipelineMode_Do(t *testing.T) {
t.Logf("wait success count to be 1 %v", atomic.LoadInt32(&success))
time.Sleep(time.Millisecond * 100)
}
if atomic.LoadInt32(&timeout) != 99 {
for atomic.LoadInt32(&timeout) != 99 {
t.Logf("wait timeout count to be 99 %v", atomic.LoadInt32(&timeout))
time.Sleep(time.Millisecond * 100)
}
Expand Down Expand Up @@ -1134,7 +1153,7 @@ func TestOngoingCancelContextInPipelineMode_DoMulti(t *testing.T) {
time.Sleep(time.Millisecond * 100)
}
cancel()
if atomic.LoadInt32(&canceled) != 50 {
for atomic.LoadInt32(&canceled) != 50 {
t.Logf("wait canceled count to be 50 %v", atomic.LoadInt32(&canceled))
time.Sleep(time.Millisecond * 100)
}
Expand All @@ -1146,7 +1165,7 @@ func TestOngoingCancelContextInPipelineMode_DoMulti(t *testing.T) {
}

func TestOngoingWriteTimeoutInPipelineMode_DoMulti(t *testing.T) {
p, mock, _, closeConn := setup(t, ClientOption{ConnReadWriteTimeout: time.Second / 2})
p, mock, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: time.Second / 2})
defer closeConn()

ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
Expand All @@ -1160,7 +1179,7 @@ func TestOngoingWriteTimeoutInPipelineMode_DoMulti(t *testing.T) {
s, err := p.DoMulti(ctx, cmds.NewCompleted([]string{"GET", "a"}))[0].ToString()
if s == "OK" {
atomic.AddInt32(&success, 1)
} else if errors.Is(err, os.ErrDeadlineExceeded) {
} else if errors.Is(err, context.DeadlineExceeded) {
atomic.AddInt32(&timeout, 1)
}
}()
Expand All @@ -1170,13 +1189,25 @@ func TestOngoingWriteTimeoutInPipelineMode_DoMulti(t *testing.T) {
t.Logf("wait success count to be 1 %v", atomic.LoadInt32(&success))
time.Sleep(time.Millisecond * 100)
}
if atomic.LoadInt32(&timeout) != 99 {
for atomic.LoadInt32(&timeout) != 99 {
t.Logf("wait timeout count to be 99 %v", atomic.LoadInt32(&timeout))
time.Sleep(time.Millisecond * 100)
}
p.Close()
}

func TestPingOnConnError(t *testing.T) {
p, mock, _, closeConn := setup(t, ClientOption{ConnWriteTimeout: 3 * time.Second})
p.background()
mock.Expect("PING")
closeConn()
time.Sleep(time.Second / 2)
p.Close()
if err := p.Error(); !strings.HasPrefix(err.Error(), "io:") {
t.Fatalf("unexpect err %v", err)
}
}

func TestDeadPipe(t *testing.T) {
ctx := context.Background()
if err := dead.Error(); err != ErrClosing {
Expand Down
13 changes: 9 additions & 4 deletions redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,10 @@ func run(t *testing.T, client Client, cases ...func(*testing.T, Client)) {
}

func TestSingleClientIntegration(t *testing.T) {
client, err := NewClient(ClientOption{InitAddress: []string{"127.0.0.1:6379"}})
client, err := NewClient(ClientOption{
InitAddress: []string{"127.0.0.1:6379"},
ConnWriteTimeout: 120 * time.Second,
})
if err != nil {
t.Fatal(err)
}
Expand All @@ -319,7 +322,8 @@ func TestSingleClientIntegration(t *testing.T) {

func TestSentinelClientIntegration(t *testing.T) {
client, err := NewClient(ClientOption{
InitAddress: []string{"127.0.0.1:26379"},
InitAddress: []string{"127.0.0.1:26379"},
ConnWriteTimeout: 120 * time.Second,
Sentinel: SentinelOption{
MasterSet: "test",
},
Expand All @@ -335,8 +339,9 @@ func TestSentinelClientIntegration(t *testing.T) {

func TestClusterClientIntegration(t *testing.T) {
client, err := NewClient(ClientOption{
InitAddress: []string{"127.0.0.1:7001", "127.0.0.1:7002", "127.0.0.1:7003"},
ShuffleInit: true,
InitAddress: []string{"127.0.0.1:7001", "127.0.0.1:7002", "127.0.0.1:7003"},
ConnWriteTimeout: 120 * time.Second,
ShuffleInit: true,
})
if err != nil {
t.Fatal(err)
Expand Down
6 changes: 5 additions & 1 deletion rueidis.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ type ClientOption struct {
// The default is DefaultPoolSize.
BlockingPoolSize int

// ConnWriteTimeout is used to apply net.Conn.SetWriteDeadline
// ConnWriteTimeout is applied net.Conn.SetWriteDeadline and periodic PING to redis
// Since the Dialer.KeepAlive will not be triggered if there is data in the outgoing buffer,
// ConnWriteTimeout should be set in order to detect local congestion or unresponsive redis server.
// This default is ClientOption.Dialer.KeepAlive * (9+1), where 9 is the default of tcp_keepalive_probes on Linux.
ConnWriteTimeout time.Duration

// ShuffleInit is a handy flag that shuffles the InitAddress after passing to the NewClient() if it is true
Expand Down Expand Up @@ -187,6 +188,9 @@ func dial(dst string, opt *ClientOption) (conn net.Conn, err error) {
if opt.Dialer.KeepAlive == 0 {
opt.Dialer.KeepAlive = DefaultTCPKeepAlive
}
if opt.ConnWriteTimeout == 0 {
opt.ConnWriteTimeout = opt.Dialer.KeepAlive * 10
}
if opt.TLSConfig != nil {
conn, err = tls.DialWithDialer(&opt.Dialer, "tcp", dst, opt.TLSConfig)
} else {
Expand Down

0 comments on commit fdedc0a

Please sign in to comment.