Skip to content

Commit

Permalink
feat: rewrite cacheable MGET/JSON.MGET with slice pool
Browse files Browse the repository at this point in the history
  • Loading branch information
rueian committed Aug 2, 2022
1 parent 2e7d5bf commit 02e23f3
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 53 deletions.
26 changes: 5 additions & 21 deletions helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ func clusterMGetCache(cc *clusterClient, ctx context.Context, ttl time.Duration,

ret = make(map[string]RedisMessage, len(keys))

ch := make(chan uint16, len(mgets))
for slot := range mgets {
ch <- slot
ch := make(chan cmds.Cacheable, len(mgets))
for _, cmd := range mgets {
ch <- cmds.Cacheable(cmd)
}
close(ch)

Expand All @@ -61,23 +61,16 @@ func clusterMGetCache(cc *clusterClient, ctx context.Context, ttl time.Duration,
concurrency = cc.cpus
}

width := maxWidth(mgets)

for i := 0; i < concurrency; i++ {
go func() {
keyIdx := make([]string, width)
for slot := range ch {
cmd := cmds.Cacheable(mgets[slot])
for i, k := range cmd.Commands()[1:] {
keyIdx[i] = k
}
for cmd := range ch {
arr, err2 := cc.doCache(ctx, cmd, ttl).ToArray()
mu.Lock()
if err2 != nil {
err = err2
} else {
for i, resp := range arr {
ret[keyIdx[i]] = resp
ret[cmd.MGetCacheKey(i)] = resp
}
}
mu.Unlock()
Expand All @@ -94,12 +87,3 @@ func clusterMGetCache(cc *clusterClient, ctx context.Context, ttl time.Duration,
}
return ret, nil
}

func maxWidth(mgets map[uint16]cmds.Completed) (max int) {
for _, cmd := range mgets {
if l := len(cmd.Commands()); max < l {
max = l
}
}
return max
}
18 changes: 18 additions & 0 deletions internal/cmds/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,25 @@ func (c Arbitrary) ReadOnly() Completed {
return c.Build()
}

// MultiGet is used to complete constructing a command and mark it as mtGetTag command.
func (c Arbitrary) MultiGet() Completed {
if len(c.cs.s) == 0 || len(c.cs.s[0]) == 0 {
panic(arbitraryNoCommand)
}
if c.cs.s[0] != "MGET" && c.cs.s[0] != "JSON.MGET" {
panic(arbitraryMultiGet)
}
c.cf = mtGetTag
return c.Build()
}

// IsZero is used to test if Arbitrary is initialized
func (c Arbitrary) IsZero() bool {
return c.cs == nil
}

var (
arbitraryNoCommand = "Arbitrary should be provided with redis command"
arbitrarySubscribe = "Arbitrary does not support SUBSCRIBE/UNSUBSCRIBE"
arbitraryMultiGet = "Arbitrary.MultiGet is only valid for MGET and JSON.MGET"
)
39 changes: 39 additions & 0 deletions internal/cmds/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ retry:
}
}

func TestArbitraryIsZero(t *testing.T) {
builder := NewBuilder(NoSlot)
if cmd := builder.Arbitrary("any", "cmd"); cmd.IsZero() {
t.Fatalf("arbitrary failed")
}
var cmd Arbitrary
if !cmd.IsZero() {
t.Fatalf("arbitrary failed")
}
}

func TestArbitrary(t *testing.T) {
builder := NewBuilder(NoSlot)
cmd := builder.Arbitrary("any", "cmd").Keys("k1", "k2").Args("a1", "a2")
Expand Down Expand Up @@ -62,3 +73,31 @@ func TestEmptySubscribe(t *testing.T) {
}()
builder.Arbitrary("SUBSCRIBE").Build()
}

func TestEmptyArbitraryMultiGet(t *testing.T) {
builder := NewBuilder(NoSlot)
defer func() {
if e := recover(); e != arbitraryNoCommand {
t.Errorf("arbitrary not check empty")
}
}()
builder.Arbitrary().MultiGet()
}

func TestArbitraryMultiGet(t *testing.T) {
builder := NewBuilder(NoSlot)
cacheable := Cacheable(builder.Arbitrary("MGET").Args("KKK").MultiGet())
if !cacheable.IsMGet() {
t.Fatalf("arbitrary failed")
}
}

func TestArbitraryMultiGetPanic(t *testing.T) {
builder := NewBuilder(NoSlot)
defer func() {
if e := recover(); e != arbitraryMultiGet {
t.Errorf("arbitrary not check MGET command")
}
}()
builder.Arbitrary("SUBSCRIBE").MultiGet()
}
73 changes: 41 additions & 32 deletions pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -773,15 +773,16 @@ func (p *pipe) DoCache(ctx context.Context, cmd cmds.Cacheable, ttl time.Duratio
func (p *pipe) doCacheMGet(ctx context.Context, cmd cmds.Cacheable, ttl time.Duration) RedisResult {
commands := cmd.Commands()
entries := make(map[int]*entry)
builder := cmds.NewBuilder(cmds.InitSlot)
result := RedisResult{val: RedisMessage{typ: '*', values: nil}}
cc := cmd.MGetCacheCmd()
mgetcc := cmd.MGetCacheCmd()
keys := len(commands) - 1
if cc[0] == 'J' {
if mgetcc[0] == 'J' {
keys-- // the last one of JSON.MGET is a path, not a key
}
j := 1
var rewrite cmds.Arbitrary
for i, key := range commands[1 : keys+1] {
v, entry := p.cache.GetOrPrepare(cmd.MGetCacheKey(i), cc, ttl)
v, entry := p.cache.GetOrPrepare(key, mgetcc, ttl)
if v.typ != 0 { // cache hit for one key
if len(result.val.values) == 0 {
result.val.values = make([]RedisMessage, keys)
Expand All @@ -793,46 +794,54 @@ func (p *pipe) doCacheMGet(ctx context.Context, cmd cmds.Cacheable, ttl time.Dur
entries[i] = entry // store entries for later entry.Wait() to avoid MGET deadlock each others.
continue
}
commands[j] = key // rewrite MGET
j++
if rewrite.IsZero() {
rewrite = builder.Arbitrary(commands[0])
}
rewrite = rewrite.Args(key)
}

var partial []RedisMessage
if j != 1 {
last := j
if cc[0] == 'J' { // rewrite JSON.MGET path
commands[j] = commands[len(commands)-1]
last++
}
rewrite := cmds.NewMGetCompleted(commands[:last])
multi := make([]cmds.Completed, 0, len(commands[1:j])+4)
if !rewrite.IsZero() {
var rewritten cmds.Completed
var keys int
if mgetcc[0] == 'J' { // rewrite JSON.MGET path
rewritten = rewrite.Args(commands[len(commands)-1]).MultiGet()
keys = len(rewritten.Commands()) - 2
} else {
rewritten = rewrite.MultiGet()
keys = len(rewritten.Commands()) - 1
}

multi := make([]cmds.Completed, 0, keys+4)
multi = append(multi, cmds.OptInCmd, cmds.MultiCmd)
for _, key := range commands[1:j] {
multi = append(multi, cmds.NewCompleted([]string{"PTTL", key}))
for _, key := range rewritten.Commands()[1 : keys+1] {
multi = append(multi, builder.Pttl().Key(key).Build())
}
multi = append(multi, rewrite, cmds.ExecCmd)
multi = append(multi, rewritten, cmds.ExecCmd)

resp := p.DoMulti(ctx, multi...)
exec, err := resp[len(multi)-1].ToArray()
if err != nil {
var msg RedisMessage
var er2 error
if _, ok := err.(*RedisError); !ok {
msg = RedisMessage{}
er2 = err
} else if resp[len(multi)-2].val.typ != '+' { // EXEC aborted, return err of the input cmd in MULTI block
msg = resp[len(multi)-2].val
er2 = nil
} else {
msg = resp[len(multi)-1].val
er2 = nil
if _, ok := err.(*RedisError); ok {
err = nil
if resp[len(multi)-2].val.typ != '+' { // EXEC aborted, return err of the input cmd in MULTI block
msg = resp[len(multi)-2].val
} else {
msg = resp[len(multi)-1].val
}
}
for _, key := range commands[1:j] {
p.cache.Cancel(key, cc, msg, er2)
for _, key := range rewritten.Commands()[1 : keys+1] {
p.cache.Cancel(key, mgetcc, msg, err)
}
return newResult(msg, er2)
return newResult(msg, err)
}
if last == len(commands) { // all cache miss
defer func() {
for _, cmd := range multi[2 : len(multi)-1] {
cmds.Put(cmd.CommandSlice())
}
}()
if len(rewritten.Commands()) == len(commands) { // all cache miss
return newResult(exec[len(exec)-1], nil)
}
partial = exec[len(exec)-1].values
Expand All @@ -851,7 +860,7 @@ func (p *pipe) doCacheMGet(ctx context.Context, cmd cmds.Cacheable, ttl time.Dur
result.val.values[i] = v
}

j = 0
j := 0
for _, ret := range partial {
for ; j < len(result.val.values); j++ {
if result.val.values[j].typ == 0 {
Expand Down

0 comments on commit 02e23f3

Please sign in to comment.