Skip to content

Commit 09eb108

Browse files
committed
Allow passing context where possible
1 parent 3da4357 commit 09eb108

File tree

9 files changed

+98
-51
lines changed

9 files changed

+98
-51
lines changed

cluster.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -724,16 +724,24 @@ func (c *ClusterClient) Close() error {
724724

725725
// Do creates a Cmd from the args and processes the cmd.
726726
func (c *ClusterClient) Do(args ...interface{}) *Cmd {
727+
return c.DoContext(c.ctx, args...)
728+
}
729+
730+
func (c *ClusterClient) DoContext(ctx context.Context, args ...interface{}) *Cmd {
727731
cmd := NewCmd(args...)
728-
c.Process(cmd)
732+
c.ProcessContext(ctx, cmd)
729733
return cmd
730734
}
731735

732736
func (c *ClusterClient) Process(cmd Cmder) error {
733-
return c.hooks.process(c.ctx, cmd, c.process)
737+
return c.ProcessContext(c.ctx, cmd)
738+
}
739+
740+
func (c *ClusterClient) ProcessContext(ctx context.Context, cmd Cmder) error {
741+
return c.hooks.process(ctx, cmd, c.process)
734742
}
735743

736-
func (c *ClusterClient) process(cmd Cmder) error {
744+
func (c *ClusterClient) process(ctx context.Context, cmd Cmder) error {
737745
var node *clusterNode
738746
var ask bool
739747
for attempt := 0; attempt <= c.opt.MaxRedirects; attempt++ {
@@ -755,11 +763,11 @@ func (c *ClusterClient) process(cmd Cmder) error {
755763
pipe := node.Client.Pipeline()
756764
_ = pipe.Process(NewCmd("ASKING"))
757765
_ = pipe.Process(cmd)
758-
_, err = pipe.Exec()
766+
_, err = pipe.ExecContext(ctx)
759767
_ = pipe.Close()
760768
ask = false
761769
} else {
762-
err = node.Client.Process(cmd)
770+
err = node.Client.ProcessContext(ctx, cmd)
763771
}
764772

765773
// If there is no error - we are done.
@@ -1022,11 +1030,11 @@ func (c *ClusterClient) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
10221030
return c.Pipeline().Pipelined(fn)
10231031
}
10241032

1025-
func (c *ClusterClient) processPipeline(cmds []Cmder) error {
1033+
func (c *ClusterClient) processPipeline(ctx context.Context, cmds []Cmder) error {
10261034
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
10271035
}
10281036

1029-
func (c *ClusterClient) _processPipeline(cmds []Cmder) error {
1037+
func (c *ClusterClient) _processPipeline(ctx context.Context, cmds []Cmder) error {
10301038
cmdsMap := newCmdsMap()
10311039
err := c.mapCmdsByNode(cmds, cmdsMap)
10321040
if err != nil {
@@ -1216,11 +1224,11 @@ func (c *ClusterClient) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
12161224
return c.TxPipeline().Pipelined(fn)
12171225
}
12181226

1219-
func (c *ClusterClient) processTxPipeline(cmds []Cmder) error {
1220-
return c.hooks.processPipeline(c.ctx, cmds, c._processTxPipeline)
1227+
func (c *ClusterClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
1228+
return c.hooks.processPipeline(ctx, cmds, c._processTxPipeline)
12211229
}
12221230

1223-
func (c *ClusterClient) _processTxPipeline(cmds []Cmder) error {
1231+
func (c *ClusterClient) _processTxPipeline(ctx context.Context, cmds []Cmder) error {
12241232
state, err := c.state.Get()
12251233
if err != nil {
12261234
return err

iterator.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package redis
22

3-
import "sync"
3+
import (
4+
"sync"
5+
)
46

57
// ScanIterator is used to incrementally iterate over a collection of elements.
68
// It's safe for concurrent use by multiple goroutines.

options.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ import (
1616

1717
// Limiter is the interface of a rate limiter or a circuit breaker.
1818
type Limiter interface {
19-
// Allow returns a nil if operation is allowed or an error otherwise.
20-
// If operation is allowed client must report the result of operation
21-
// whether is a success or a failure.
19+
// Allow returns nil if operation is allowed or an error otherwise.
20+
// If operation is allowed client must ReportResult of the operation
21+
// whether it is a success or a failure.
2222
Allow() error
2323
// ReportResult reports the result of previously allowed operation.
2424
// nil indicates a success, non-nil error indicates a failure.

pipeline.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package redis
22

33
import (
4+
"context"
45
"sync"
56

67
"github.com/go-redis/redis/internal/pool"
78
)
89

9-
type pipelineExecer func([]Cmder) error
10+
type pipelineExecer func(context.Context, []Cmder) error
1011

1112
// Pipeliner is an mechanism to realise Redis Pipeline technique.
1213
//
@@ -28,6 +29,7 @@ type Pipeliner interface {
2829
Close() error
2930
Discard() error
3031
Exec() ([]Cmder, error)
32+
ExecContext(ctx context.Context) ([]Cmder, error)
3133
}
3234

3335
var _ Pipeliner = (*Pipeline)(nil)
@@ -96,6 +98,10 @@ func (c *Pipeline) discard() error {
9698
// Exec always returns list of commands and error of the first failed
9799
// command if any.
98100
func (c *Pipeline) Exec() ([]Cmder, error) {
101+
return c.ExecContext(nil)
102+
}
103+
104+
func (c *Pipeline) ExecContext(ctx context.Context) ([]Cmder, error) {
99105
c.mu.Lock()
100106
defer c.mu.Unlock()
101107

@@ -110,10 +116,10 @@ func (c *Pipeline) Exec() ([]Cmder, error) {
110116
cmds := c.cmds
111117
c.cmds = nil
112118

113-
return cmds, c.exec(cmds)
119+
return cmds, c.exec(ctx, cmds)
114120
}
115121

116-
func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
122+
func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
117123
if err := fn(c); err != nil {
118124
return nil, err
119125
}
@@ -122,16 +128,12 @@ func (c *Pipeline) pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
122128
return cmds, err
123129
}
124130

125-
func (c *Pipeline) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
126-
return c.pipelined(fn)
127-
}
128-
129131
func (c *Pipeline) Pipeline() Pipeliner {
130132
return c
131133
}
132134

133135
func (c *Pipeline) TxPipelined(fn func(Pipeliner) error) ([]Cmder, error) {
134-
return c.pipelined(fn)
136+
return c.Pipelined(fn)
135137
}
136138

137139
func (c *Pipeline) TxPipeline() Pipeliner {

redis.go

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,15 @@ func (hs *hooks) AddHook(hook Hook) {
4545
hs.hooks = append(hs.hooks, hook)
4646
}
4747

48-
func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error {
48+
func (hs hooks) process(
49+
ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
50+
) error {
4951
ctx, err := hs.beforeProcess(ctx, cmd)
5052
if err != nil {
5153
return err
5254
}
5355

54-
cmdErr := fn(cmd)
56+
cmdErr := fn(ctx, cmd)
5557

5658
_, err = hs.afterProcess(ctx, cmd)
5759
if err != nil {
@@ -83,13 +85,15 @@ func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, e
8385
return ctx, nil
8486
}
8587

86-
func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error {
88+
func (hs hooks) processPipeline(
89+
ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
90+
) error {
8791
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
8892
if err != nil {
8993
return err
9094
}
9195

92-
cmdsErr := fn(cmds)
96+
cmdsErr := fn(ctx, cmds)
9397

9498
_, err = hs.afterProcessPipeline(ctx, cmds)
9599
if err != nil {
@@ -246,14 +250,7 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
246250
return nil
247251
}
248252

249-
// Do creates a Cmd from the args and processes the cmd.
250-
func (c *baseClient) Do(args ...interface{}) *Cmd {
251-
cmd := NewCmd(args...)
252-
_ = c.process(cmd)
253-
return cmd
254-
}
255-
256-
func (c *baseClient) process(cmd Cmder) error {
253+
func (c *baseClient) process(ctx context.Context, cmd Cmder) error {
257254
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
258255
if attempt > 0 {
259256
time.Sleep(c.retryBackoff(attempt))
@@ -328,11 +325,11 @@ func (c *baseClient) getAddr() string {
328325
return c.opt.Addr
329326
}
330327

331-
func (c *baseClient) processPipeline(cmds []Cmder) error {
328+
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
332329
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
333330
}
334331

335-
func (c *baseClient) processTxPipeline(cmds []Cmder) error {
332+
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
336333
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
337334
}
338335

@@ -503,16 +500,31 @@ func (c *Client) WithContext(ctx context.Context) *Client {
503500
return &clone
504501
}
505502

503+
// Do creates a Cmd from the args and processes the cmd.
504+
func (c *Client) Do(args ...interface{}) *Cmd {
505+
return c.DoContext(c.ctx, args...)
506+
}
507+
508+
func (c *Client) DoContext(ctx context.Context, args ...interface{}) *Cmd {
509+
cmd := NewCmd(args...)
510+
_ = c.ProcessContext(ctx, cmd)
511+
return cmd
512+
}
513+
506514
func (c *Client) Process(cmd Cmder) error {
507-
return c.hooks.process(c.ctx, cmd, c.baseClient.process)
515+
return c.ProcessContext(c.ctx, cmd)
508516
}
509517

510-
func (c *Client) processPipeline(cmds []Cmder) error {
511-
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline)
518+
func (c *Client) ProcessContext(ctx context.Context, cmd Cmder) error {
519+
return c.hooks.process(ctx, cmd, c.baseClient.process)
512520
}
513521

514-
func (c *Client) processTxPipeline(cmds []Cmder) error {
515-
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline)
522+
func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
523+
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
524+
}
525+
526+
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
527+
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processTxPipeline)
516528
}
517529

518530
// Options returns read-only Options that were used to create the client.
@@ -637,7 +649,11 @@ func newConn(opt *Options, cn *pool.Conn) *Conn {
637649
}
638650

639651
func (c *Conn) Process(cmd Cmder) error {
640-
return c.baseClient.process(cmd)
652+
return c.ProcessContext(context.TODO(), cmd)
653+
}
654+
655+
func (c *Conn) ProcessContext(ctx context.Context, cmd Cmder) error {
656+
return c.baseClient.process(ctx, cmd)
641657
}
642658

643659
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {

ring.go

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,21 @@ func (c *Ring) WithContext(ctx context.Context) *Ring {
396396

397397
// Do creates a Cmd from the args and processes the cmd.
398398
func (c *Ring) Do(args ...interface{}) *Cmd {
399+
return c.DoContext(c.ctx, args...)
400+
}
401+
402+
func (c *Ring) DoContext(ctx context.Context, args ...interface{}) *Cmd {
399403
cmd := NewCmd(args...)
400-
c.Process(cmd)
404+
c.ProcessContext(ctx, cmd)
401405
return cmd
402406
}
403407

404408
func (c *Ring) Process(cmd Cmder) error {
405-
return c.hooks.process(c.ctx, cmd, c.process)
409+
return c.ProcessContext(c.ctx, cmd)
410+
}
411+
412+
func (c *Ring) ProcessContext(ctx context.Context, cmd Cmder) error {
413+
return c.hooks.process(ctx, cmd, c.process)
406414
}
407415

408416
// Options returns read-only Options that were used to create the client.
@@ -532,7 +540,7 @@ func (c *Ring) cmdShard(cmd Cmder) (*ringShard, error) {
532540
return c.shards.GetByKey(firstKey)
533541
}
534542

535-
func (c *Ring) process(cmd Cmder) error {
543+
func (c *Ring) process(ctx context.Context, cmd Cmder) error {
536544
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
537545
if attempt > 0 {
538546
time.Sleep(c.retryBackoff(attempt))
@@ -544,7 +552,7 @@ func (c *Ring) process(cmd Cmder) error {
544552
return err
545553
}
546554

547-
err = shard.Client.Process(cmd)
555+
err = shard.Client.ProcessContext(ctx, cmd)
548556
if err == nil {
549557
return nil
550558
}
@@ -567,11 +575,11 @@ func (c *Ring) Pipeline() Pipeliner {
567575
return &pipe
568576
}
569577

570-
func (c *Ring) processPipeline(cmds []Cmder) error {
571-
return c.hooks.processPipeline(c.ctx, cmds, c._processPipeline)
578+
func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
579+
return c.hooks.processPipeline(ctx, cmds, c._processPipeline)
572580
}
573581

574-
func (c *Ring) _processPipeline(cmds []Cmder) error {
582+
func (c *Ring) _processPipeline(ctx context.Context, cmds []Cmder) error {
575583
cmdsMap := make(map[string][]Cmder)
576584
for _, cmd := range cmds {
577585
cmdInfo := c.cmdInfo(cmd.Name())

sentinel.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,11 @@ func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
136136
}
137137

138138
func (c *SentinelClient) Process(cmd Cmder) error {
139-
return c.baseClient.process(cmd)
139+
return c.ProcessContext(c.ctx, cmd)
140+
}
141+
142+
func (c *SentinelClient) ProcessContext(ctx context.Context, cmd Cmder) error {
143+
return c.baseClient.process(ctx, cmd)
140144
}
141145

142146
func (c *SentinelClient) pubSub() *PubSub {

tx.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@ func (c *Tx) WithContext(ctx context.Context) *Tx {
5656
}
5757

5858
func (c *Tx) Process(cmd Cmder) error {
59-
return c.baseClient.process(cmd)
59+
return c.ProcessContext(c.ctx, cmd)
60+
}
61+
62+
func (c *Tx) ProcessContext(ctx context.Context, cmd Cmder) error {
63+
return c.baseClient.process(ctx, cmd)
6064
}
6165

6266
// Watch prepares a transaction and marks the keys to be watched

universal.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,10 @@ type UniversalClient interface {
162162
Context() context.Context
163163
AddHook(Hook)
164164
Watch(fn func(*Tx) error, keys ...string) error
165+
Do(args ...interface{}) *Cmd
166+
DoContext(ctx context.Context, args ...interface{}) *Cmd
165167
Process(cmd Cmder) error
168+
ProcessContext(ctx context.Context, cmd Cmder) error
166169
Subscribe(channels ...string) *PubSub
167170
PSubscribe(channels ...string) *PubSub
168171
Close() error

0 commit comments

Comments
 (0)