diff --git a/internal/retry/config.go b/internal/retry/config.go new file mode 100644 index 00000000000..a5929c77f20 --- /dev/null +++ b/internal/retry/config.go @@ -0,0 +1,76 @@ +package retry + +import ( + "context" + "reflect" + "time" + + "github.com/cenkalti/backoff/v4" +) + +type config struct { + maxInterval time.Duration + watches []watch + + backoff.BackOff +} + +// watch is a helper struct to watch multiple channels +type watch struct { + name string + ch reflect.Value + fn func(context.Context) error + this bool +} + +// Option configures the retry handler +type Option func(*config) + +// WithWatch adds a watch to the retry handler +// that will be triggered when a value is received on the channel +// and the function will be called, also within a retry handler +func WithWatch[T any](name string, ch <-chan T, fn func(context.Context) error) Option { + return func(cfg *config) { + cfg.watches = append(cfg.watches, watch{name: name, ch: reflect.ValueOf(ch), fn: fn, this: false}) + } +} + +// WithMaxInterval sets the upper bound for the retry handler +func WithMaxInterval(d time.Duration) Option { + return func(cfg *config) { + cfg.maxInterval = d + } +} + +func newConfig(opts ...Option) ([]watch, backoff.BackOff) { + cfg := new(config) + for _, opt := range []Option{ + WithMaxInterval(time.Minute * 5), + } { + opt(cfg) + } + + for _, opt := range opts { + opt(cfg) + } + + for i, w := range cfg.watches { + cfg.watches[i].fn = withRetry(cfg, w) + } + + bo := backoff.NewExponentialBackOff() + bo.MaxInterval = cfg.maxInterval + bo.MaxElapsedTime = 0 + + return cfg.watches, bo +} + +func withRetry(cfg *config, w watch) func(context.Context) error { + if w.fn == nil { + return func(_ context.Context) error { return nil } + } + + return func(ctx context.Context) error { + return Retry(ctx, w.name, w.fn, WithMaxInterval(cfg.maxInterval)) + } +} diff --git a/internal/retry/error.go b/internal/retry/error.go new file mode 100644 index 00000000000..bc7d25a31de --- /dev/null +++ b/internal/retry/error.go @@ -0,0 +1,48 @@ +package retry + +import ( + "errors" + "fmt" +) + +// TerminalError is an error that should not be retried +type TerminalError interface { + error + IsTerminal() +} + +// terminalError is an error that should not be retried +type terminalError struct { + Err error +} + +// Error implements error for terminalError +func (e *terminalError) Error() string { + return fmt.Sprintf("terminal error: %v", e.Err) +} + +// Unwrap implements errors.Unwrap for terminalError +func (e *terminalError) Unwrap() error { + return e.Err +} + +// Is implements errors.Is for terminalError +func (e *terminalError) Is(err error) bool { + //nolint:errorlint + _, ok := err.(*terminalError) + return ok +} + +// IsTerminal implements TerminalError for terminalError +func (e *terminalError) IsTerminal() {} + +// NewTerminalError creates a new terminal error that cannot be retried +func NewTerminalError(err error) error { + return &terminalError{Err: err} +} + +// IsTerminalError returns true if the error is a terminal error +func IsTerminalError(err error) bool { + var te TerminalError + return errors.As(err, &te) +} diff --git a/internal/retry/error_test.go b/internal/retry/error_test.go new file mode 100644 index 00000000000..b4f2ddb00eb --- /dev/null +++ b/internal/retry/error_test.go @@ -0,0 +1,29 @@ +package retry_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/retry" +) + +type testError string + +func (e testError) Error() string { + return string(e) +} + +func (e testError) IsTerminal() {} + +func TestError(t *testing.T) { + t.Run("local terminal error", func(t *testing.T) { + err := fmt.Errorf("wrap: %w", retry.NewTerminalError(fmt.Errorf("inner"))) + require.True(t, retry.IsTerminalError(err)) + }) + t.Run("external terminal error", func(t *testing.T) { + err := fmt.Errorf("wrap: %w", testError("inner")) + require.True(t, retry.IsTerminalError(err)) + }) +} diff --git a/internal/retry/retry.go b/internal/retry/retry.go new file mode 100644 index 00000000000..02042e349c9 --- /dev/null +++ b/internal/retry/retry.go @@ -0,0 +1,139 @@ +// Package retry provides a retry loop with exponential back-off +// while watching arbitrary signal channels for side effects. +package retry + +import ( + "context" + "fmt" + "reflect" + "time" +) + +// Retry retries a function (with exponential back-off) until it succeeds. +// It additionally watches arbitrary channels and calls the handler function when a value is received. +// Handler functions are also retried with exponential back-off. +// If a terminal error is returned from the handler function, the retry loop is aborted. +// If the context is canceled, the retry loop is aborted. +func Retry( + ctx context.Context, + name string, + fn func(context.Context) error, + opts ...Option, +) error { + watches, backoff := newConfig(opts...) + ticker := time.NewTicker(backoff.NextBackOff()) + defer ticker.Stop() + + s := makeSelect(ctx, watches, name, ticker.C, fn) + +restart: + for { + err := fn(ctx) + if err == nil { + return nil + } + if IsTerminalError(err) { + return err + } + + backoff.Reset() + backoff: + for { + ticker.Reset(backoff.NextBackOff()) + + next, err := s.Exec(ctx) + switch next { + case nextRestart: + continue restart + case nextBackoff: + continue backoff + case nextExit: + return err + default: + panic("unreachable") + } + } + } +} + +type selectCase struct { + watches []watch + cases []reflect.SelectCase +} + +func makeSelect( + ctx context.Context, + watches []watch, + name string, + ch <-chan time.Time, + fn func(context.Context) error, +) *selectCase { + watches = append(watches, + watch{ + name: "context", + fn: func(ctx context.Context) error { + // unreachable, the context handler will never be called + // as its channel can only be closed + return ctx.Err() + }, + ch: reflect.ValueOf(ctx.Done()), + }, + watch{ + name: name, + fn: fn, + ch: reflect.ValueOf(ch), + this: true, + }, + ) + cases := make([]reflect.SelectCase, 0, len(watches)) + for _, w := range watches { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: w.ch, + }) + } + return &selectCase{ + watches: watches, + cases: cases, + } +} + +type next int + +const ( + nextRestart next = iota // try again from the beginning + nextBackoff // backoff and try again + nextExit // exit +) + +func (s *selectCase) Exec(ctx context.Context) (next, error) { + chosen, _, ok := reflect.Select(s.cases) + if !ok { + return nextExit, fmt.Errorf("watch %s closed", s.watches[chosen].name) + } + + w := s.watches[chosen] + + err := w.fn(ctx) + if err != nil { + return onError(w, err) + } + + if !w.this { + return nextRestart, nil + } + + return nextExit, nil +} + +func onError(w watch, err error) (next, error) { + if IsTerminalError(err) { + return nextExit, err + } + + if w.this { + return nextBackoff, fmt.Errorf("retry %s failed: %w", w.name, err) + } + + panic("unreachable, as watches are wrapped in retries and may only return terminal errors") +} diff --git a/internal/retry/retry_test.go b/internal/retry/retry_test.go new file mode 100644 index 00000000000..c131653e9a4 --- /dev/null +++ b/internal/retry/retry_test.go @@ -0,0 +1,124 @@ +package retry_test + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/retry" +) + +func TestRetry(t *testing.T) { + t.Parallel() + + ctx := context.Background() + limit := retry.WithMaxInterval(time.Second * 5) + + t.Run("no error", func(t *testing.T) { + t.Parallel() + + err := retry.Retry(ctx, "test", func(_ context.Context) error { + return nil + }, limit) + require.NoError(t, err) + }) + + t.Run("eventually succeeds", func(t *testing.T) { + t.Parallel() + i := 0 + err := retry.Retry(ctx, "test", func(_ context.Context) error { + if i++; i > 2 { + return nil + } + return fmt.Errorf("transient %d", i) + }, limit) + require.NoError(t, err) + }) + + t.Run("eventually fails", func(t *testing.T) { + t.Parallel() + i := 0 + err := retry.Retry(ctx, "test", func(_ context.Context) error { + if i++; i > 2 { + return retry.NewTerminalError(errors.New("the end")) + } + return fmt.Errorf("transient %d", i) + }) + require.Error(t, err) + }) + + t.Run("context canceled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(ctx) + cancel() + err := retry.Retry(ctx, "test", func(_ context.Context) error { + return fmt.Errorf("retry") + }) + require.Error(t, err) + }) + + t.Run("context canceled after retry", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + err := retry.Retry(ctx, "test", func(_ context.Context) error { + cancel() + return fmt.Errorf("retry") + }) + require.Error(t, err) + }) + + t.Run("success after watch hook", func(t *testing.T) { + t.Parallel() + ch := make(chan struct{}, 1) + ch <- struct{}{} + ok := false + err := retry.Retry(ctx, "test", func(_ context.Context) error { + if ok { + return nil + } + return fmt.Errorf("retry") + }, retry.WithWatch("watch", ch, func(_ context.Context) error { + ok = true + return nil + }), limit) + require.NoError(t, err) + }) + + t.Run("success after watch hook retried", func(t *testing.T) { + t.Parallel() + ch := make(chan struct{}, 1) + ch <- struct{}{} + ok := false + i := 0 + err := retry.Retry(ctx, "test", func(_ context.Context) error { + if ok { + return nil + } + return fmt.Errorf("retry test") + }, retry.WithWatch("watch", ch, func(_ context.Context) error { + if i++; i > 1 { + ok = true + return nil + } + return fmt.Errorf("retry watch") + }), limit) + require.NoError(t, err) + }) + + t.Run("watch hook fails", func(t *testing.T) { + t.Parallel() + ch := make(chan struct{}, 1) + ch <- struct{}{} + err := retry.Retry(ctx, "test", func(_ context.Context) error { + return fmt.Errorf("retry") + }, retry.WithWatch("watch", ch, func(_ context.Context) error { + return retry.NewTerminalError(fmt.Errorf("watch")) + }), limit) + require.Error(t, err) + }) +}