-
Notifications
You must be signed in to change notification settings - Fork 281
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0d29401
commit 435cb4e
Showing
5 changed files
with
416 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package retry | ||
|
||
import ( | ||
"context" | ||
"reflect" | ||
"time" | ||
|
||
"github.com/cenkalti/backoff/v4" | ||
) | ||
|
||
type config struct { | ||
maxInterval time.Duration | ||
watches []watch | ||
|
||
backoff.BackOff | ||
} | ||
|
||
// watch is a helper struct to watch multiple channels | ||
type watch struct { | ||
name string | ||
ch reflect.Value | ||
fn func(context.Context) error | ||
this bool | ||
} | ||
|
||
// Option configures the retry handler | ||
type Option func(*config) | ||
|
||
// WithWatch adds a watch to the retry handler | ||
// that will be triggered when a value is received on the channel | ||
// and the function will be called, also within a retry handler | ||
func WithWatch[T any](name string, ch <-chan T, fn func(context.Context) error) Option { | ||
return func(cfg *config) { | ||
cfg.watches = append(cfg.watches, watch{name: name, ch: reflect.ValueOf(ch), fn: fn, this: false}) | ||
} | ||
} | ||
|
||
// WithMaxInterval sets the upper bound for the retry handler | ||
func WithMaxInterval(d time.Duration) Option { | ||
return func(cfg *config) { | ||
cfg.maxInterval = d | ||
} | ||
} | ||
|
||
func newConfig(opts ...Option) ([]watch, backoff.BackOff) { | ||
cfg := new(config) | ||
for _, opt := range []Option{ | ||
WithMaxInterval(time.Minute * 5), | ||
} { | ||
opt(cfg) | ||
} | ||
|
||
for _, opt := range opts { | ||
opt(cfg) | ||
} | ||
|
||
for i, w := range cfg.watches { | ||
cfg.watches[i].fn = withRetry(cfg, w) | ||
} | ||
|
||
bo := backoff.NewExponentialBackOff() | ||
bo.MaxInterval = cfg.maxInterval | ||
bo.MaxElapsedTime = 0 | ||
|
||
return cfg.watches, bo | ||
} | ||
|
||
func withRetry(cfg *config, w watch) func(context.Context) error { | ||
if w.fn == nil { | ||
return func(_ context.Context) error { return nil } | ||
} | ||
|
||
return func(ctx context.Context) error { | ||
return Retry(ctx, w.name, w.fn, WithMaxInterval(cfg.maxInterval)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package retry | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
) | ||
|
||
// TerminalError is an error that should not be retried | ||
type TerminalError interface { | ||
error | ||
IsTerminal() | ||
} | ||
|
||
// terminalError is an error that should not be retried | ||
type terminalError struct { | ||
Err error | ||
} | ||
|
||
// Error implements error for terminalError | ||
func (e *terminalError) Error() string { | ||
return fmt.Sprintf("terminal error: %v", e.Err) | ||
} | ||
|
||
// Unwrap implements errors.Unwrap for terminalError | ||
func (e *terminalError) Unwrap() error { | ||
return e.Err | ||
} | ||
|
||
// Is implements errors.Is for terminalError | ||
func (e *terminalError) Is(err error) bool { | ||
//nolint:errorlint | ||
_, ok := err.(*terminalError) | ||
return ok | ||
} | ||
|
||
// IsTerminal implements TerminalError for terminalError | ||
func (e *terminalError) IsTerminal() {} | ||
|
||
// NewTerminalError creates a new terminal error that cannot be retried | ||
func NewTerminalError(err error) error { | ||
return &terminalError{Err: err} | ||
} | ||
|
||
// IsTerminalError returns true if the error is a terminal error | ||
func IsTerminalError(err error) bool { | ||
var te TerminalError | ||
return errors.As(err, &te) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
package retry_test | ||
|
||
import ( | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/pomerium/pomerium/internal/retry" | ||
) | ||
|
||
type testError string | ||
|
||
func (e testError) Error() string { | ||
return string(e) | ||
} | ||
|
||
func (e testError) IsTerminal() {} | ||
|
||
func TestError(t *testing.T) { | ||
t.Run("local terminal error", func(t *testing.T) { | ||
err := fmt.Errorf("wrap: %w", retry.NewTerminalError(fmt.Errorf("inner"))) | ||
require.True(t, retry.IsTerminalError(err)) | ||
}) | ||
t.Run("external terminal error", func(t *testing.T) { | ||
err := fmt.Errorf("wrap: %w", testError("inner")) | ||
require.True(t, retry.IsTerminalError(err)) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
// Package retry provides a retry loop with exponential back-off | ||
// while watching arbitrary signal channels for side effects. | ||
package retry | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"reflect" | ||
"time" | ||
) | ||
|
||
// Retry retries a function (with exponential back-off) until it succeeds. | ||
// It additionally watches arbitrary channels and calls the handler function when a value is received. | ||
// Handler functions are also retried with exponential back-off. | ||
// If a terminal error is returned from the handler function, the retry loop is aborted. | ||
// If the context is canceled, the retry loop is aborted. | ||
func Retry( | ||
ctx context.Context, | ||
name string, | ||
fn func(context.Context) error, | ||
opts ...Option, | ||
) error { | ||
watches, backoff := newConfig(opts...) | ||
ticker := time.NewTicker(backoff.NextBackOff()) | ||
defer ticker.Stop() | ||
|
||
s := makeSelect(ctx, watches, name, ticker.C, fn) | ||
|
||
restart: | ||
for { | ||
err := fn(ctx) | ||
if err == nil { | ||
return nil | ||
} | ||
if IsTerminalError(err) { | ||
return err | ||
} | ||
|
||
backoff.Reset() | ||
backoff: | ||
for { | ||
ticker.Reset(backoff.NextBackOff()) | ||
|
||
next, err := s.Exec(ctx) | ||
switch next { | ||
case nextRestart: | ||
continue restart | ||
case nextBackoff: | ||
continue backoff | ||
case nextExit: | ||
return err | ||
default: | ||
panic("unreachable") | ||
} | ||
} | ||
} | ||
} | ||
|
||
type selectCase struct { | ||
watches []watch | ||
cases []reflect.SelectCase | ||
} | ||
|
||
func makeSelect( | ||
ctx context.Context, | ||
watches []watch, | ||
name string, | ||
ch <-chan time.Time, | ||
fn func(context.Context) error, | ||
) *selectCase { | ||
watches = append(watches, | ||
watch{ | ||
name: "context", | ||
fn: func(ctx context.Context) error { | ||
// unreachable, the context handler will never be called | ||
// as its channel can only be closed | ||
return ctx.Err() | ||
}, | ||
ch: reflect.ValueOf(ctx.Done()), | ||
}, | ||
watch{ | ||
name: name, | ||
fn: fn, | ||
ch: reflect.ValueOf(ch), | ||
this: true, | ||
}, | ||
) | ||
cases := make([]reflect.SelectCase, 0, len(watches)) | ||
for _, w := range watches { | ||
cases = append(cases, reflect.SelectCase{ | ||
Dir: reflect.SelectRecv, | ||
Chan: w.ch, | ||
}) | ||
} | ||
return &selectCase{ | ||
watches: watches, | ||
cases: cases, | ||
} | ||
} | ||
|
||
type next int | ||
|
||
const ( | ||
nextRestart next = iota // try again from the beginning | ||
nextBackoff // backoff and try again | ||
nextExit // exit | ||
) | ||
|
||
func (s *selectCase) Exec(ctx context.Context) (next, error) { | ||
chosen, _, ok := reflect.Select(s.cases) | ||
if !ok { | ||
return nextExit, fmt.Errorf("watch %s closed", s.watches[chosen].name) | ||
} | ||
|
||
w := s.watches[chosen] | ||
|
||
err := w.fn(ctx) | ||
if err != nil { | ||
return onError(w, err) | ||
} | ||
|
||
if !w.this { | ||
return nextRestart, nil | ||
} | ||
|
||
return nextExit, nil | ||
} | ||
|
||
func onError(w watch, err error) (next, error) { | ||
if IsTerminalError(err) { | ||
return nextExit, err | ||
} | ||
|
||
if w.this { | ||
return nextBackoff, fmt.Errorf("retry %s failed: %w", w.name, err) | ||
} | ||
|
||
panic("unreachable, as watches are wrapped in retries and may only return terminal errors") | ||
} |
Oops, something went wrong.