diff --git a/compose.go b/compose.go new file mode 100644 index 0000000..767c9e9 --- /dev/null +++ b/compose.go @@ -0,0 +1,76 @@ +package sqlhooks + +import ( + "context" + "fmt" +) + +// Compose allows for composing multiple Hooks into one. +// It runs every callback on every hook in argument order, +// even if previous hooks return an error. +// If multiple hooks return errors, the error return value will be +// MultipleErrors, which allows for introspecting the errors if necessary. +func Compose(hooks ...Hooks) Hooks { + return composed(hooks) +} + +type composed []Hooks + +func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + var errors []error + for _, hook := range c { + c, err := hook.Before(ctx, query, args...) + if err != nil { + errors = append(errors, err) + } + if c != nil { + ctx = c + } + } + return ctx, wrapErrors(nil, errors) +} + +func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + var errors []error + for _, hook := range c { + var err error + c, err := hook.After(ctx, query, args...) + if err != nil { + errors = append(errors, err) + } + if c != nil { + ctx = c + } + } + return ctx, wrapErrors(nil, errors) +} + +func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error { + var errors []error + for _, hook := range c { + if onErrorer, ok := hook.(OnErrorer); ok { + if err := onErrorer.OnError(ctx, cause, query, args...); err != nil && err != cause { + errors = append(errors, err) + } + } + } + return wrapErrors(cause, errors) +} + +func wrapErrors(def error, errors []error) error { + switch len(errors) { + case 0: + return def + case 1: + return errors[0] + default: + return MultipleErrors(errors) + } +} + +// MultipleErrors is an error that contains multiple errors. +type MultipleErrors []error + +func (m MultipleErrors) Error() string { + return fmt.Sprint("multiple errors:", []error(m)) +} diff --git a/compose_1_13.go b/compose_1_13.go new file mode 100644 index 0000000..7337d70 --- /dev/null +++ b/compose_1_13.go @@ -0,0 +1,26 @@ +// +build go1.13 + +package sqlhooks + +import "errors" + +// Is returns true if any of the wrapped errors is target according to errors.Is() +func (m MultipleErrors) Is(target error) bool { + for _, err := range m { + if errors.Is(err, target) { + return true + } + } + return false +} + +// Is tries to convert each wrapped error to target with errors.As() and returns true that succeeds. +// If none of the errors are convertible, returns false. +func (m MultipleErrors) As(target interface{}) bool { + for _, err := range m { + if errors.As(err, &target) { + return true + } + } + return false +} diff --git a/compose_test.go b/compose_test.go new file mode 100644 index 0000000..216ddbe --- /dev/null +++ b/compose_test.go @@ -0,0 +1,96 @@ +package sqlhooks + +import ( + "context" + "errors" + "reflect" + "testing" +) + +var ( + oops = errors.New("oops") + oopsHook = &testHooks{ + before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, oops + }, + after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, oops + }, + onError: func(ctx context.Context, err error, query string, args ...interface{}) error { + return oops + }, + } + okHook = &testHooks{ + before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, nil + }, + after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, nil + }, + onError: func(ctx context.Context, err error, query string, args ...interface{}) error { + return nil + }, + } +) + +func TestCompose(t *testing.T) { + for _, it := range []struct { + name string + hooks Hooks + want error + }{ + {"happy case", Compose(okHook, okHook), nil}, + {"no hooks", Compose(), nil}, + {"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, oops})}, + {"single error", Compose(okHook, oopsHook, okHook), oops}, + } { + t.Run(it.name, func(t *testing.T) { + t.Run("Before", func(t *testing.T) { + _, got := it.hooks.Before(context.Background(), "query") + if !reflect.DeepEqual(it.want, got) { + t.Errorf("unexpected error. want: %q, got: %q", it.want, got) + } + }) + t.Run("After", func(t *testing.T) { + _, got := it.hooks.After(context.Background(), "query") + if !reflect.DeepEqual(it.want, got) { + t.Errorf("unexpected error. want: %q, got: %q", it.want, got) + } + }) + t.Run("OnError", func(t *testing.T) { + cause := errors.New("crikey") + want := it.want + if want == nil { + want = cause + } + got := it.hooks.(OnErrorer).OnError(context.Background(), cause, "query") + if !reflect.DeepEqual(want, got) { + t.Errorf("unexpected error. want: %q, got: %q", want, got) + } + }) + }) + } +} + +func TestWrapErrors(t *testing.T) { + var ( + err1 = errors.New("oops") + err2 = errors.New("oops2") + ) + for _, it := range []struct { + name string + def error + errors []error + want error + }{ + {"no errors", err1, nil, err1}, + {"single error", nil, []error{err1}, err1}, + {"multiple errors", nil, []error{err1, err2}, MultipleErrors([]error{err1, err2})}, + } { + t.Run(it.name, func(t *testing.T) { + if want, got := it.want, wrapErrors(it.def, it.errors); !reflect.DeepEqual(want, got) { + t.Errorf("unexpected wrapping. want: %q, got %q", want, got) + } + }) + } +} diff --git a/sqlhooks_test.go b/sqlhooks_test.go index ac9bfad..904a920 100644 --- a/sqlhooks_test.go +++ b/sqlhooks_test.go @@ -35,7 +35,7 @@ func (h *testHooks) After(ctx context.Context, query string, args ...interface{} return h.after(ctx, query, args...) } -func (h *testHooks) ErrHook(ctx context.Context, err error, query string, args ...interface{}) error { +func (h *testHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error { return h.onError(ctx, err, query, args...) }