Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions options.go
Original file line number Diff line number Diff line change
@@ -21,10 +21,12 @@ type Limiter interface {
// Allow returns nil if operation is allowed or an error otherwise.
// If operation is allowed client must ReportResult of the operation
// whether it is a success or a failure.
Allow() error
// The returned context will be passed to ReportResult.
Allow(ctx context.Context) (context.Context, error)
// ReportResult reports the result of the previously allowed operation.
// nil indicates a success, non-nil error usually indicates a failure.
ReportResult(result error)
// Context can be used to access state tracked by previous Allow call.
ReportResult(ctx context.Context, result error)
}

// Options keeps the settings to set up redis connection.
4 changes: 2 additions & 2 deletions osscluster.go
Original file line number Diff line number Diff line change
@@ -1319,7 +1319,7 @@ func (c *ClusterClient) processPipelineNode(
ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
) {
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
cn, err := node.Client.getConn(ctx)
ctx, cn, err := node.Client.getConn(ctx)
if err != nil {
node.MarkAsFailing()
_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
@@ -1504,7 +1504,7 @@ func (c *ClusterClient) processTxPipelineNode(
) {
cmds = wrapMultiExec(ctx, cmds)
_ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
cn, err := node.Client.getConn(ctx)
ctx, cn, err := node.Client.getConn(ctx)
if err != nil {
_ = c.mapCmdsByNode(ctx, failedCmds, cmds)
setCmdsErr(cmds, err)
17 changes: 9 additions & 8 deletions redis.go
Original file line number Diff line number Diff line change
@@ -237,23 +237,24 @@ func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) {
return cn, nil
}

func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) {
func (c *baseClient) getConn(ctx context.Context) (context.Context, *pool.Conn, error) {
var err error
if c.opt.Limiter != nil {
err := c.opt.Limiter.Allow()
ctx, err = c.opt.Limiter.Allow(ctx)
if err != nil {
return nil, err
return ctx, nil, err
}
}

cn, err := c._getConn(ctx)
if err != nil {
if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err)
c.opt.Limiter.ReportResult(ctx, err)
}
return nil, err
return ctx, nil, err
}

return cn, nil
return ctx, cn, nil
}

func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
@@ -365,7 +366,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {

func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) {
if c.opt.Limiter != nil {
c.opt.Limiter.ReportResult(err)
c.opt.Limiter.ReportResult(ctx, err)
}

if isBadConn(err, false, c.opt.Addr) {
@@ -378,7 +379,7 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)
func (c *baseClient) withConn(
ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
cn, err := c.getConn(ctx)
ctx, cn, err := c.getConn(ctx)
if err != nil {
return err
}
Loading
Oops, something went wrong.