Skip to content

Commit

Permalink
Merge pull request #15 from xuzhenglun/master
Browse files Browse the repository at this point in the history
handler sql error, closes #10
  • Loading branch information
Gustavo Chaín committed Jan 30, 2018
2 parents b4a12ba + aaebe3c commit 928fba3
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 9 deletions.
6 changes: 6 additions & 0 deletions hooks/loghooks/loghooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (co
h.log.Printf("Query: `%s`, Args: `%q`. took: %s", query, args, time.Since(ctx.Value("started").(time.Time)))
return ctx, nil
}

func (h *Hook) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
h.log.Printf("Error: %v, Query: `%s`, Args: `%q`, Took: %s",
err, query, args, time.Since(ctx.Value("started").(time.Time)))
return err
}
22 changes: 19 additions & 3 deletions hooks/othooks/othooks.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package othooks

import "context"
import "github.com/opentracing/opentracing-go"
import "github.com/opentracing/opentracing-go/log"
import (
"context"

"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/log"
)

type Hook struct {
tracer opentracing.Tracer
Expand Down Expand Up @@ -35,3 +38,16 @@ func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (co

return ctx, nil
}

func (h *Hook) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
span := opentracing.SpanFromContext(ctx)
if span != nil {
defer span.Finish()
span.SetTag("error", true)
span.LogFields(
log.Error(err),
)
}

return err
}
27 changes: 24 additions & 3 deletions sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,33 @@ import (
// Hook is the hook callback signature
type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error)

// ErrorHook is the error handling callback signature
type ErrorHook func(ctx context.Context, err error, query string, args ...interface{}) error

// Hooks instances may be passed to Wrap() to define an instrumented driver
type Hooks interface {
Before(ctx context.Context, query string, args ...interface{}) (context.Context, error)
After(ctx context.Context, query string, args ...interface{}) (context.Context, error)
}

// OnErrorer instances will be called if any error happens
type OnErrorer interface {
OnError(ctx context.Context, err error, query string, args ...interface{}) error
}

func handlerErr(ctx context.Context, hooks Hooks, err error, query string, args ...interface{}) error {
h, ok := hooks.(OnErrorer)
if !ok {
return err
}

if err := h.OnError(ctx, err, query, args...); err != nil {
return err
}

return err
}

// Driver implements a database/sql/driver.Driver
type Driver struct {
driver.Driver
Expand Down Expand Up @@ -110,7 +131,7 @@ func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args [

results, err := conn.execContext(ctx, query, args)
if err != nil {
return results, err
return results, handlerErr(ctx, conn.hooks, err, query, list...)
}

if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
Expand Down Expand Up @@ -160,7 +181,7 @@ func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (dr

results, err := stmt.execContext(ctx, args)
if err != nil {
return results, err
return results, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
}

if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
Expand Down Expand Up @@ -194,7 +215,7 @@ func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (d

rows, err := stmt.queryContext(ctx, args)
if err != nil {
return rows, err
return rows, handlerErr(ctx, stmt.hooks, err, stmt.query, list...)
}

if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
Expand Down
1 change: 1 addition & 0 deletions sqlhooks_mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestMySQL(t *testing.T) {
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS")

t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
Expand Down
1 change: 1 addition & 0 deletions sqlhooks_postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestPostgres(t *testing.T) {
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS")

t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
Expand Down
3 changes: 2 additions & 1 deletion sqlhooks_sqlite3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"testing"
"time"

sqlite3 "github.com/mattn/go-sqlite3"
"github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -31,6 +31,7 @@ func TestSQLite3(t *testing.T) {
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS")

t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
Expand Down
31 changes: 29 additions & 2 deletions sqlhooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import (
)

type testHooks struct {
before Hook
after Hook
before Hook
after Hook
onError ErrorHook
}

func (h *testHooks) noop() {
Expand All @@ -34,6 +35,10 @@ func (h *testHooks) After(ctx context.Context, query string, args ...interface{}
return h.after(ctx, query, args...)
}

func (h *testHooks) ErrHook(ctx context.Context, err error, query string, args ...interface{}) error {
return h.onError(ctx, err, query, args...)
}

type suite struct {
db *sql.DB
hooks *testHooks
Expand Down Expand Up @@ -155,6 +160,28 @@ func (s *suite) TestHooksErrors(t *testing.T, query string) {
t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) })
}

func (s *suite) testErrHookHook(t *testing.T, query string, args ...interface{}) {
s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
}

s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
assert.False(t, true, "after hook should not run")
return ctx, nil
}

s.hooks.onError = func(ctx context.Context, err error, query string, args ...interface{}) error {
assert.True(t, true, "onError hook should run")
return err
}

s.db.Query(query)
}

func (s *suite) TestErrHookHook(t *testing.T, query string, args ...interface{}) {
t.Run("TestErrHookHook", func(t *testing.T) { s.testErrHookHook(t, query, args...) })
}

func TestNamedValueToValue(t *testing.T) {
named := []driver.NamedValue{
{Ordinal: 1, Value: "foo"},
Expand Down

0 comments on commit 928fba3

Please sign in to comment.