Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions compose.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package sqlhooks

import (
"context"
"fmt"
)

// Compose allows for composing multiple Hooks into one.
// It runs every callback on every hook in argument order,
// even if previous hooks return an error.
// If multiple hooks return errors, the error return value will be
// MultipleErrors, which allows for introspecting the errors if necessary.
func Compose(hooks ...Hooks) Hooks {
return composed(hooks)
}

type composed []Hooks

func (c composed) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
var errors []error
for _, hook := range c {
c, err := hook.Before(ctx, query, args...)
if err != nil {
errors = append(errors, err)
}
if c != nil {
ctx = c
}
}
return ctx, wrapErrors(nil, errors)
}

func (c composed) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
var errors []error
for _, hook := range c {
var err error
c, err := hook.After(ctx, query, args...)
if err != nil {
errors = append(errors, err)
}
if c != nil {
ctx = c
}
}
return ctx, wrapErrors(nil, errors)
}

func (c composed) OnError(ctx context.Context, cause error, query string, args ...interface{}) error {
var errors []error
for _, hook := range c {
if onErrorer, ok := hook.(OnErrorer); ok {
if err := onErrorer.OnError(ctx, cause, query, args...); err != nil && err != cause {
errors = append(errors, err)
}
}
}
return wrapErrors(cause, errors)
}

func wrapErrors(def error, errors []error) error {
switch len(errors) {
case 0:
return def
case 1:
return errors[0]
default:
return MultipleErrors(errors)
}
}

// MultipleErrors is an error that contains multiple errors.
type MultipleErrors []error

func (m MultipleErrors) Error() string {
return fmt.Sprint("multiple errors:", []error(m))
}
26 changes: 26 additions & 0 deletions compose_1_13.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// +build go1.13

package sqlhooks

import "errors"

// Is returns true if any of the wrapped errors is target according to errors.Is()
func (m MultipleErrors) Is(target error) bool {
for _, err := range m {
if errors.Is(err, target) {
return true
}
}
return false
}

// Is tries to convert each wrapped error to target with errors.As() and returns true that succeeds.
// If none of the errors are convertible, returns false.
func (m MultipleErrors) As(target interface{}) bool {
for _, err := range m {
if errors.As(err, &target) {
return true
}
}
return false
}
96 changes: 96 additions & 0 deletions compose_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package sqlhooks

import (
"context"
"errors"
"reflect"
"testing"
)

var (
oops = errors.New("oops")
oopsHook = &testHooks{
before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, oops
},
after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, oops
},
onError: func(ctx context.Context, err error, query string, args ...interface{}) error {
return oops
},
}
okHook = &testHooks{
before: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
},
after: func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
},
onError: func(ctx context.Context, err error, query string, args ...interface{}) error {
return nil
},
}
)

func TestCompose(t *testing.T) {
for _, it := range []struct {
name string
hooks Hooks
want error
}{
{"happy case", Compose(okHook, okHook), nil},
{"no hooks", Compose(), nil},
{"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, oops})},
{"single error", Compose(okHook, oopsHook, okHook), oops},
} {
t.Run(it.name, func(t *testing.T) {
t.Run("Before", func(t *testing.T) {
_, got := it.hooks.Before(context.Background(), "query")
if !reflect.DeepEqual(it.want, got) {
t.Errorf("unexpected error. want: %q, got: %q", it.want, got)
}
})
t.Run("After", func(t *testing.T) {
_, got := it.hooks.After(context.Background(), "query")
if !reflect.DeepEqual(it.want, got) {
t.Errorf("unexpected error. want: %q, got: %q", it.want, got)
}
})
t.Run("OnError", func(t *testing.T) {
cause := errors.New("crikey")
want := it.want
if want == nil {
want = cause
}
got := it.hooks.(OnErrorer).OnError(context.Background(), cause, "query")
if !reflect.DeepEqual(want, got) {
t.Errorf("unexpected error. want: %q, got: %q", want, got)
}
})
})
}
}

func TestWrapErrors(t *testing.T) {
var (
err1 = errors.New("oops")
err2 = errors.New("oops2")
)
for _, it := range []struct {
name string
def error
errors []error
want error
}{
{"no errors", err1, nil, err1},
{"single error", nil, []error{err1}, err1},
{"multiple errors", nil, []error{err1, err2}, MultipleErrors([]error{err1, err2})},
} {
t.Run(it.name, func(t *testing.T) {
if want, got := it.want, wrapErrors(it.def, it.errors); !reflect.DeepEqual(want, got) {
t.Errorf("unexpected wrapping. want: %q, got %q", want, got)
}
})
}
}
2 changes: 1 addition & 1 deletion sqlhooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (h *testHooks) After(ctx context.Context, query string, args ...interface{}
return h.after(ctx, query, args...)
}

func (h *testHooks) ErrHook(ctx context.Context, err error, query string, args ...interface{}) error {
func (h *testHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
return h.onError(ctx, err, query, args...)
}

Expand Down