Skip to content

Commit

Permalink
fix: handle failures of subscribe/unsubscribe commands correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
rueian committed Jun 23, 2022
1 parent 0044e1d commit 2eba08d
Show file tree
Hide file tree
Showing 6 changed files with 432 additions and 181 deletions.
143 changes: 87 additions & 56 deletions pipe.go
Expand Up @@ -48,8 +48,9 @@ type pipe struct {
once sync.Once

info map[string]RedisMessage
subs *subs
nsubs *subs
psubs *subs
ssubs *subs
pshks atomic.Value
error atomic.Value
}
Expand All @@ -62,8 +63,9 @@ func newPipe(conn net.Conn, option *ClientOption) (p *pipe, err error) {
r: bufio.NewReader(conn),
w: bufio.NewWriter(conn),

subs: newSubs(),
nsubs: newSubs(),
psubs: newSubs(),
ssubs: newSubs(),

timeout: option.ConnWriteTimeout,
pinggap: option.Dialer.KeepAlive,
Expand Down Expand Up @@ -142,8 +144,9 @@ func (p *pipe) _background() {
}
<-wait

p.subs.Close()
p.nsubs.Close()
p.psubs.Close()
p.ssubs.Close()
if old := p.pshks.Swap(emptypshks).(*pshks); old.close != nil {
old.close <- p.Error()
close(old.close)
Expand All @@ -159,24 +162,13 @@ func (p *pipe) _background() {
// clean up cache and free pending calls
p.cache.FreeAndClose(RedisMessage{typ: '-', string: p.Error().Error()})
for atomic.LoadInt32(&p.waits) != 0 {
if ones[0], multi, ch = p.queue.NextWriteCmd(); ch != nil {
if multi == nil {
multi = ones
}
for _, one := range multi {
if one.NoReply() {
ch <- newErrResult(p.Error())
}
}
}
p.queue.NextWriteCmd()
if ones[0], multi, ch, cond = p.queue.NextResultCh(); ch != nil {
if multi == nil {
multi = ones
}
for _, one := range multi {
if !one.NoReply() {
ch <- newErrResult(p.Error())
}
for range multi {
ch <- newErrResult(p.Error())
}
cond.L.Unlock()
cond.Signal()
Expand Down Expand Up @@ -216,10 +208,7 @@ func (p *pipe) _backgroundWrite() (err error) {
multi = ones
}
for _, cmd := range multi {
if err = writeCmd(p.w, cmd.Commands()); cmd.NoReply() {
err = p.w.Flush()
ch <- newErrResult(err)
}
err = writeCmd(p.w, cmd.Commands())
}
if err != nil {
if err != ErrClosing { // ignore ErrClosing to allow final QUIT command to be sent
Expand All @@ -239,15 +228,15 @@ func (p *pipe) _backgroundRead() (err error) {
multi []cmds.Completed
ch chan RedisResult
ff int // fulfilled count
skip int // skip rest push messages
ver = p.version
pr bool // push reply
)

defer func() {
if err != nil && ff < len(multi) {
for ; ff < len(multi); ff++ {
if !multi[ff].NoReply() {
ch <- newResult(msg, err)
}
ch <- newErrResult(err)
}
cond.L.Unlock()
cond.Signal()
Expand All @@ -259,8 +248,14 @@ func (p *pipe) _backgroundRead() (err error) {
return
}
if msg.typ == '>' {
p.handlePush(msg.values)
continue
if pr = p.handlePush(msg.values); !pr {
continue
}
if skip > 0 {
skip--
pr = false
continue
}
} else if ver < 7 && len(msg.values) != 0 {
// This is a workaround for Redis 6's broken invalidation protocol: https://github.com/redis/redis/issues/8935
// When Redis 6 handles MULTI, MGET, or other multi-keys command,
Expand Down Expand Up @@ -293,7 +288,6 @@ func (p *pipe) _backgroundRead() (err error) {
p.cache.Update(ck, cc, cp, msg.values[0].integer)
}
}
nextCMD:
if ff == len(multi) {
ff = 0
ones[0], multi, ch, cond = p.queue.NextResultCh() // ch should not be nil, otherwise it must be a protocol bug
Expand All @@ -304,20 +298,29 @@ func (p *pipe) _backgroundRead() (err error) {
multi = ones
}
}
if multi[ff].NoReply() {
ff++
if ff == len(multi) {
cond.L.Unlock()
cond.Signal()
if pr {
if !multi[ff].NoReply() {
panic(protocolbug)
}
goto nextCMD
} else {
ff++
ch <- newResult(msg, err)
if ff == len(multi) {
cond.L.Unlock()
cond.Signal()
if len(multi[ff].Commands()) == 1 { // wildcard unsubscribe
switch strings.ToUpper(multi[ff].Commands()[0]) {
case "UNSUBSCRIBE":
skip = p.nsubs.Confirmed()
case "PUNSUBSCRIBE":
skip = p.psubs.Confirmed()
case "SUNSUBSCRIBE":
skip = p.ssubs.Confirmed()
}
} else {
skip = len(multi[ff].Commands()) - 2
}
msg = RedisMessage{} // override successful subscribe/unsubscribe response to empty
pr = false
}
ch <- newResult(msg, err)
if ff++; ff == len(multi) {
cond.L.Unlock()
cond.Signal()
}
}
}
Expand All @@ -332,7 +335,7 @@ func (p *pipe) _backgroundPing() (err error) {
return err
}

func (p *pipe) handlePush(values []RedisMessage) {
func (p *pipe) handlePush(values []RedisMessage) (reply bool) {
if len(values) < 2 {
return
}
Expand All @@ -346,10 +349,10 @@ func (p *pipe) handlePush(values []RedisMessage) {
} else {
p.cache.Delete(values[1].values)
}
case "message", "smessage":
case "message":
if len(values) >= 3 {
m := PubSubMessage{Channel: values[1].string, Message: values[2].string}
p.subs.Publish(values[1].string, m)
p.nsubs.Publish(values[1].string, m)
p.pshks.Load().(*pshks).hooks.OnMessage(m)
}
case "pmessage":
Expand All @@ -358,39 +361,67 @@ func (p *pipe) handlePush(values []RedisMessage) {
p.psubs.Publish(values[1].string, m)
p.pshks.Load().(*pshks).hooks.OnMessage(m)
}
case "unsubscribe", "sunsubscribe":
p.subs.Unsubscribe(values[1].string)
case "smessage":
if len(values) >= 3 {
m := PubSubMessage{Channel: values[1].string, Message: values[2].string}
p.ssubs.Publish(values[1].string, m)
p.pshks.Load().(*pshks).hooks.OnMessage(m)
}
case "unsubscribe":
p.nsubs.Unsubscribe(values[1].string)
if len(values) >= 3 {
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
}
p.queue.CleanNoReply()
return true
case "punsubscribe":
p.psubs.Unsubscribe(values[1].string)
if len(values) >= 3 {
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
}
p.queue.CleanNoReply()
case "subscribe", "psubscribe", "ssubscribe":
return true
case "sunsubscribe":
p.ssubs.Unsubscribe(values[1].string)
if len(values) >= 3 {
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
}
return true
case "subscribe":
p.nsubs.Confirm(values[1].string)
if len(values) >= 3 {
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
}
p.queue.CleanNoReply()
return true
case "psubscribe":
p.psubs.Confirm(values[1].string)
if len(values) >= 3 {
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
}
return true
case "ssubscribe":
p.ssubs.Confirm(values[1].string)
if len(values) >= 3 {
p.pshks.Load().(*pshks).hooks.OnSubscription(PubSubSubscription{Kind: values[0].string, Channel: values[1].string, Count: values[2].integer})
}
return true
}
return false
}

func (p *pipe) Receive(ctx context.Context, subscribe cmds.Completed, fn func(message PubSubMessage)) error {
if p.subs == nil || p.psubs == nil {
if p.nsubs == nil || p.psubs == nil || p.ssubs == nil {
return ErrClosing
}

var sb *subs
cmd, args := subscribe.Commands()[0], subscribe.Commands()[1:]

switch cmd {
case "SUBSCRIBE", "SSUBSCRIBE":
sb = p.subs
case "SUBSCRIBE":
sb = p.nsubs
case "PSUBSCRIBE":
sb = p.psubs
case "SSUBSCRIBE":
sb = p.ssubs
default:
panic(wrongreceive)
}
Expand Down Expand Up @@ -523,11 +554,12 @@ func (p *pipe) DoMulti(ctx context.Context, multi ...cmds.Completed) []RedisResu
}

isOptIn := multi[0].IsOptIn() // len(multi) > 0 should have already been checked by upper layer
noReply := multi[0].NoReply()
noReply := false

for _, cmd := range multi[1:] {
if noReply != cmd.NoReply() {
panic(prohibitmix)
for _, cmd := range multi {
if cmd.NoReply() {
noReply = true
break
}
}

Expand Down Expand Up @@ -749,7 +781,6 @@ func deadFn() *pipe {

const (
protocolbug = "protocol bug, message handled out of order"
prohibitmix = "mixing SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE, PUNSUBSCRIBE with other commands in DoMulti is prohibited"
wrongreceive = `only SUBSCRIBE, SSUBSCRIBE, or PSUBSCRIBE command are allowed in Receive`
)

Expand Down

0 comments on commit 2eba08d

Please sign in to comment.