Skip to content

Commit 8476dfe

Browse files
committed
Replace Wrap* with hooks that support context
1 parent b902746 commit 8476dfe

File tree

10 files changed

+423
-349
lines changed

10 files changed

+423
-349
lines changed

cluster.go

Lines changed: 181 additions & 192 deletions
Large diffs are not rendered by default.

command.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,14 @@ type baseCmd struct {
100100

101101
var _ Cmder = (*Cmd)(nil)
102102

103-
func (cmd *baseCmd) Err() error {
104-
return cmd.err
103+
func (cmd *baseCmd) Name() string {
104+
if len(cmd._args) > 0 {
105+
// Cmd name must be lower cased.
106+
s := internal.ToLower(cmd.stringArg(0))
107+
cmd._args[0] = s
108+
return s
109+
}
110+
return ""
105111
}
106112

107113
func (cmd *baseCmd) Args() []interface{} {
@@ -116,14 +122,8 @@ func (cmd *baseCmd) stringArg(pos int) string {
116122
return s
117123
}
118124

119-
func (cmd *baseCmd) Name() string {
120-
if len(cmd._args) > 0 {
121-
// Cmd name must be lower cased.
122-
s := internal.ToLower(cmd.stringArg(0))
123-
cmd._args[0] = s
124-
return s
125-
}
126-
return ""
125+
func (cmd *baseCmd) Err() error {
126+
return cmd.err
127127
}
128128

129129
func (cmd *baseCmd) readTimeout() *time.Duration {

example_instrumentation_test.go

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,54 @@
11
package redis_test
22

33
import (
4+
"context"
45
"fmt"
56

67
"github.com/go-redis/redis"
78
)
89

10+
type redisHook struct{}
11+
12+
var _ redis.Hook = redisHook{}
13+
14+
func (redisHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
15+
fmt.Printf("starting processing: <%s>\n", cmd)
16+
return ctx, nil
17+
}
18+
19+
func (redisHook) AfterProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
20+
fmt.Printf("finished processing: <%s>\n", cmd)
21+
return ctx, nil
22+
}
23+
24+
func (redisHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
25+
fmt.Printf("pipeline starting processing: %v\n", cmds)
26+
return ctx, nil
27+
}
28+
29+
func (redisHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
30+
fmt.Printf("pipeline finished processing: %v\n", cmds)
31+
return ctx, nil
32+
}
33+
934
func Example_instrumentation() {
10-
redisdb := redis.NewClient(&redis.Options{
35+
rdb := redis.NewClient(&redis.Options{
1136
Addr: ":6379",
1237
})
13-
redisdb.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
14-
return func(cmd redis.Cmder) error {
15-
fmt.Printf("starting processing: <%s>\n", cmd)
16-
err := old(cmd)
17-
fmt.Printf("finished processing: <%s>\n", cmd)
18-
return err
19-
}
20-
})
38+
rdb.AddHook(redisHook{})
2139

22-
redisdb.Ping()
40+
rdb.Ping()
2341
// Output: starting processing: <ping: >
2442
// finished processing: <ping: PONG>
2543
}
2644

2745
func ExamplePipeline_instrumentation() {
28-
redisdb := redis.NewClient(&redis.Options{
46+
rdb := redis.NewClient(&redis.Options{
2947
Addr: ":6379",
3048
})
49+
rdb.AddHook(redisHook{})
3150

32-
redisdb.WrapProcessPipeline(func(old func([]redis.Cmder) error) func([]redis.Cmder) error {
33-
return func(cmds []redis.Cmder) error {
34-
fmt.Printf("pipeline starting processing: %v\n", cmds)
35-
err := old(cmds)
36-
fmt.Printf("pipeline finished processing: %v\n", cmds)
37-
return err
38-
}
39-
})
40-
41-
redisdb.Pipelined(func(pipe redis.Pipeliner) error {
51+
rdb.Pipelined(func(pipe redis.Pipeliner) error {
4252
pipe.Ping()
4353
pipe.Ping()
4454
return nil

redis.go

Lines changed: 125 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,114 @@ func SetLogger(logger *log.Logger) {
2323
internal.Logger = logger
2424
}
2525

26+
//------------------------------------------------------------------------------
27+
28+
type Hook interface {
29+
BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
30+
AfterProcess(ctx context.Context, cmd Cmder) (context.Context, error)
31+
32+
BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
33+
AfterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
34+
}
35+
36+
type hooks struct {
37+
hooks []Hook
38+
}
39+
40+
func (hs *hooks) AddHook(hook Hook) {
41+
hs.hooks = append(hs.hooks, hook)
42+
}
43+
44+
func (hs *hooks) copy() {
45+
hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
46+
}
47+
48+
func (hs hooks) process(ctx context.Context, cmd Cmder, fn func(Cmder) error) error {
49+
ctx, err := hs.beforeProcess(ctx, cmd)
50+
if err != nil {
51+
return err
52+
}
53+
54+
cmdErr := fn(cmd)
55+
56+
_, err = hs.afterProcess(ctx, cmd)
57+
if err != nil {
58+
return err
59+
}
60+
61+
return cmdErr
62+
}
63+
64+
func (hs hooks) beforeProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
65+
for _, h := range hs.hooks {
66+
var err error
67+
ctx, err = h.BeforeProcess(ctx, cmd)
68+
if err != nil {
69+
return nil, err
70+
}
71+
}
72+
return ctx, nil
73+
}
74+
75+
func (hs hooks) afterProcess(ctx context.Context, cmd Cmder) (context.Context, error) {
76+
for _, h := range hs.hooks {
77+
var err error
78+
ctx, err = h.AfterProcess(ctx, cmd)
79+
if err != nil {
80+
return nil, err
81+
}
82+
}
83+
return ctx, nil
84+
}
85+
86+
func (hs hooks) processPipeline(ctx context.Context, cmds []Cmder, fn func([]Cmder) error) error {
87+
ctx, err := hs.beforeProcessPipeline(ctx, cmds)
88+
if err != nil {
89+
return err
90+
}
91+
92+
cmdsErr := fn(cmds)
93+
94+
_, err = hs.afterProcessPipeline(ctx, cmds)
95+
if err != nil {
96+
return err
97+
}
98+
99+
return cmdsErr
100+
}
101+
102+
func (hs hooks) beforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
103+
for _, h := range hs.hooks {
104+
var err error
105+
ctx, err = h.BeforeProcessPipeline(ctx, cmds)
106+
if err != nil {
107+
return nil, err
108+
}
109+
}
110+
return ctx, nil
111+
}
112+
113+
func (hs hooks) afterProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error) {
114+
for _, h := range hs.hooks {
115+
var err error
116+
ctx, err = h.AfterProcessPipeline(ctx, cmds)
117+
if err != nil {
118+
return nil, err
119+
}
120+
}
121+
return ctx, nil
122+
}
123+
124+
//------------------------------------------------------------------------------
125+
26126
type baseClient struct {
27127
opt *Options
28128
connPool pool.Pooler
29129
limiter Limiter
30130

31-
process func(Cmder) error
32-
processPipeline func([]Cmder) error
33-
processTxPipeline func([]Cmder) error
34-
35131
onClose func() error // hook called when client is closed
36132
}
37133

38-
func (c *baseClient) init() {
39-
c.process = c.defaultProcess
40-
c.processPipeline = c.defaultProcessPipeline
41-
c.processTxPipeline = c.defaultProcessTxPipeline
42-
}
43-
44134
func (c *baseClient) String() string {
45135
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB)
46136
}
@@ -159,22 +249,11 @@ func (c *baseClient) initConn(cn *pool.Conn) error {
159249
// Do creates a Cmd from the args and processes the cmd.
160250
func (c *baseClient) Do(args ...interface{}) *Cmd {
161251
cmd := NewCmd(args...)
162-
_ = c.Process(cmd)
252+
_ = c.process(cmd)
163253
return cmd
164254
}
165255

166-
// WrapProcess wraps function that processes Redis commands.
167-
func (c *baseClient) WrapProcess(
168-
fn func(oldProcess func(cmd Cmder) error) func(cmd Cmder) error,
169-
) {
170-
c.process = fn(c.process)
171-
}
172-
173-
func (c *baseClient) Process(cmd Cmder) error {
174-
return c.process(cmd)
175-
}
176-
177-
func (c *baseClient) defaultProcess(cmd Cmder) error {
256+
func (c *baseClient) process(cmd Cmder) error {
178257
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
179258
if attempt > 0 {
180259
time.Sleep(c.retryBackoff(attempt))
@@ -249,18 +328,11 @@ func (c *baseClient) getAddr() string {
249328
return c.opt.Addr
250329
}
251330

252-
func (c *baseClient) WrapProcessPipeline(
253-
fn func(oldProcess func([]Cmder) error) func([]Cmder) error,
254-
) {
255-
c.processPipeline = fn(c.processPipeline)
256-
c.processTxPipeline = fn(c.processTxPipeline)
257-
}
258-
259-
func (c *baseClient) defaultProcessPipeline(cmds []Cmder) error {
331+
func (c *baseClient) processPipeline(cmds []Cmder) error {
260332
return c.generalProcessPipeline(cmds, c.pipelineProcessCmds)
261333
}
262334

263-
func (c *baseClient) defaultProcessTxPipeline(cmds []Cmder) error {
335+
func (c *baseClient) processTxPipeline(cmds []Cmder) error {
264336
return c.generalProcessPipeline(cmds, c.txPipelineProcessCmds)
265337
}
266338

@@ -388,6 +460,7 @@ type Client struct {
388460
cmdable
389461

390462
ctx context.Context
463+
hooks
391464
}
392465

393466
// NewClient returns a client to the Redis Server specified by Options.
@@ -400,7 +473,6 @@ func NewClient(opt *Options) *Client {
400473
connPool: newConnPool(opt),
401474
},
402475
}
403-
c.baseClient.init()
404476
c.init()
405477

406478
return &c
@@ -427,9 +499,22 @@ func (c *Client) WithContext(ctx context.Context) *Client {
427499
}
428500

429501
func (c *Client) clone() *Client {
430-
cp := *c
431-
cp.init()
432-
return &cp
502+
clone := *c
503+
clone.hooks.copy()
504+
clone.init()
505+
return &clone
506+
}
507+
508+
func (c *Client) Process(cmd Cmder) error {
509+
return c.hooks.process(c.ctx, cmd, c.baseClient.process)
510+
}
511+
512+
func (c *Client) processPipeline(cmds []Cmder) error {
513+
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processPipeline)
514+
}
515+
516+
func (c *Client) processTxPipeline(cmds []Cmder) error {
517+
return c.hooks.processPipeline(c.ctx, cmds, c.baseClient.processTxPipeline)
433518
}
434519

435520
// Options returns read-only Options that were used to create the client.
@@ -547,11 +632,14 @@ func newConn(opt *Options, cn *pool.Conn) *Conn {
547632
connPool: pool.NewSingleConnPool(cn),
548633
},
549634
}
550-
c.baseClient.init()
551635
c.statefulCmdable.setProcessor(c.Process)
552636
return &c
553637
}
554638

639+
func (c *Conn) Process(cmd Cmder) error {
640+
return c.baseClient.process(cmd)
641+
}
642+
555643
func (c *Conn) Pipelined(fn func(Pipeliner) error) ([]Cmder, error) {
556644
return c.Pipeline().Pipelined(fn)
557645
}

redis_test.go

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -224,43 +224,6 @@ var _ = Describe("Client", func() {
224224
Expect(err).NotTo(HaveOccurred())
225225
Expect(got).To(Equal(bigVal))
226226
})
227-
228-
It("should call WrapProcess", func() {
229-
var fnCalled bool
230-
231-
client.WrapProcess(func(old func(redis.Cmder) error) func(redis.Cmder) error {
232-
return func(cmd redis.Cmder) error {
233-
fnCalled = true
234-
return old(cmd)
235-
}
236-
})
237-
238-
Expect(client.Ping().Err()).NotTo(HaveOccurred())
239-
Expect(fnCalled).To(BeTrue())
240-
})
241-
242-
It("should call WrapProcess after WithContext", func() {
243-
var fn1Called, fn2Called bool
244-
245-
client.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
246-
return func(cmd redis.Cmder) error {
247-
fn1Called = true
248-
return old(cmd)
249-
}
250-
})
251-
252-
client2 := client.WithContext(client.Context())
253-
client2.WrapProcess(func(old func(cmd redis.Cmder) error) func(cmd redis.Cmder) error {
254-
return func(cmd redis.Cmder) error {
255-
fn2Called = true
256-
return old(cmd)
257-
}
258-
})
259-
260-
Expect(client2.Ping().Err()).NotTo(HaveOccurred())
261-
Expect(fn2Called).To(BeTrue())
262-
Expect(fn1Called).To(BeTrue())
263-
})
264227
})
265228

266229
var _ = Describe("Client timeout", func() {

0 commit comments

Comments
 (0)