diff --git a/README.md b/README.md index 3e4facf..6c48cdb 100644 --- a/README.md +++ b/README.md @@ -10,60 +10,3 @@ go get github.com/gchaincl/sqlhooks ``` # Usage [![GoDoc](https://godoc.org/github.com/gchaincl/dotsql?status.svg)](https://godoc.org/github.com/gchaincl/sqlhooks) -```go - package main - - import ( - "log" - "time" - - "github.com/gchaincl/sqlhooks" - _ "github.com/mattn/go-sqlite3" - ) - - // Hooks satisfies sqlhooks.Queryer interface - type Hooks struct { - count int - } - - func (h *Hooks) BeforeQuery(ctx *sqlhooks.Context) error { - h.count++ - ctx.Set("t", time.Now()) - ctx.Set("id", h.count) - log.Printf("[query#%d] %s, args: %v", ctx.Get("id").(int), ctx.Query, ctx.Args) - return nil - } - - func (h *Hooks) AfterQuery(ctx *sqlhooks.Context) error { - d := time.Since(ctx.Get("t").(time.Time)) - log.Printf("[query#%d] took %s (err: %v)", ctx.Get("id").(int), d, ctx.Error) - return ctx.Error - } - - func main() { - db, _ := sqlhooks.Open("sqlite3", ":memory:", &Hooks{}) - - // Do you're stuff - db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))") - db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar") - db.Query("SELECT id, text FROM t") - db.Query("Invalid Query") - } - -``` - -``` -2016/06/02 14:28:24 [query#1] SELECT id, text FROM t, args: [] -2016/06/02 14:28:24 [query#1] took 122.406µs (err: ) -2016/06/02 14:28:24 [query#2] Invalid Query, args: [] -2016/06/02 14:28:24 [query#2] took 23.148µs (err: near "Invalid": syntax error) -``` - -# Benchmark -``` -PASS -BenchmarkExec-4 500000 4604 ns/op -BenchmarkExecWithSQLHooks-4 300000 5726 ns/op -BenchmarkPreparedExec-4 1000000 1820 ns/op -BenchmarkPreparedExecWithSQLHooks-4 1000000 2088 ns/op -``` diff --git a/benchmark_test.go b/benchmark_test.go deleted file mode 100644 index 356616d..0000000 --- a/benchmark_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package sqlhooks - -import ( - "database/sql" - "testing" -) - -func init() { - sql.Register("sqlhooks", NewDriver("test", NewHooksMock( - func(ctx *Context) error { - return nil - }, - func(ctx *Context) error { - return ctx.Error - }, - ))) -} - -func newDB(b *testing.B, driver string) *sql.DB { - db, err := sql.Open(driver, "db") - if err != nil { - b.Fatalf("Open: %v", err) - } - - if _, err := db.Exec("WIPE"); err != nil { - b.Fatalf("WIPE: %v", err) - } - - if _, err := db.Exec("CREATE|t|f1=string"); err != nil { - b.Fatalf("CREATE: %v", err) - } - - return db -} - -func BenchmarkExec(b *testing.B) { - db := newDB(b, "test") - for i := 0; i < b.N; i++ { - _, err := db.Exec("INSERT|t|f1=?", "xxx") - if err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkExecWithSQLHooks(b *testing.B) { - db := newDB(b, "sqlhooks") - for i := 0; i < b.N; i++ { - _, err := db.Exec("INSERT|t|f1=?", "xxx") - if err != nil { - b.Fatal(err) - } - } -} -func BenchmarkPreparedExec(b *testing.B) { - db := newDB(b, "test") - stmt, err := db.Prepare("INSERT|t|f1=?") - if err != nil { - b.Fatalf("prepare: %v", err) - } - - for i := 0; i < b.N; i++ { - if _, err := stmt.Exec("xxx"); err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkPreparedExecWithSQLHooks(b *testing.B) { - db := newDB(b, "sqlhooks") - stmt, err := db.Prepare("INSERT|t|f1=?") - if err != nil { - b.Fatalf("prepare: %v", err) - } - - for i := 0; i < b.N; i++ { - if _, err := stmt.Exec("xxx"); err != nil { - b.Fatal(err) - } - } -} diff --git a/context.go b/context.go deleted file mode 100644 index 6faaef0..0000000 --- a/context.go +++ /dev/null @@ -1,33 +0,0 @@ -package sqlhooks - -type Context struct { - Error error - Query string - Args []interface{} - - values map[string]interface{} -} - -func NewContext() *Context { - return &Context{} -} - -func (ctx *Context) Get(key string) interface{} { - if ctx.values == nil { - ctx.values = make(map[string]interface{}) - } - - if v, ok := ctx.values[key]; ok { - return v - } - - return nil -} - -func (ctx *Context) Set(key string, value interface{}) { - if ctx.values == nil { - ctx.values = make(map[string]interface{}) - } - - ctx.values[key] = value -} diff --git a/driver.go b/driver.go deleted file mode 100644 index 279bdef..0000000 --- a/driver.go +++ /dev/null @@ -1,258 +0,0 @@ -package sqlhooks - -import ( - "database/sql" - "database/sql/driver" -) - -func driverToInterface(args []driver.Value) []interface{} { - r := make([]interface{}, len(args)) - for i, arg := range args { - r[i] = arg - } - return r -} - -func interfaceToDriver(args []interface{}) []driver.Value { - r := make([]driver.Value, len(args)) - for i, arg := range args { - r[i] = arg - } - return r -} - -type tx struct { - driver.Tx - hooks HookType - ctx *Context -} - -func (t tx) Commit() error { - var ctx *Context - - if v, ok := t.hooks.(Commiter); ok { - ctx = NewContext() - if err := v.BeforeCommit(ctx); err != nil { - return err - } - } - - err := t.Tx.Commit() - - if v, ok := t.hooks.(Commiter); ok { - ctx.Error = err - err = v.AfterCommit(ctx) - } - - return err -} - -func (t tx) Rollback() error { - var ctx *Context - - if v, ok := t.hooks.(Rollbacker); ok { - ctx = NewContext() - if err := v.BeforeRollback(ctx); err != nil { - return err - } - } - - err := t.Tx.Rollback() - - if v, ok := t.hooks.(Rollbacker); ok { - ctx.Error = err - err = v.AfterRollback(ctx) - } - - return err -} - -type stmt struct { - driver.Stmt - hooks HookType - ctx *Context -} - -func (s stmt) Close() error { - return s.Stmt.Close() -} - -func (s stmt) Exec(args []driver.Value) (res driver.Result, err error) { - if t, ok := s.hooks.(Stmter); ok { - s.ctx.Args = driverToInterface(args) - if err := t.BeforeStmtExec(s.ctx); err != nil { - return nil, err - } - args = interfaceToDriver(s.ctx.Args) - } - - return s.Stmt.Exec(args) -} - -func (s stmt) NumInput() int { - return s.Stmt.NumInput() -} - -func (s stmt) Query(args []driver.Value) (driver.Rows, error) { - if t, ok := s.hooks.(Stmter); ok { - s.ctx.Args = driverToInterface(args) - if err := t.BeforeStmtQuery(s.ctx); err != nil { - return nil, err - } - args = interfaceToDriver(s.ctx.Args) - } - - rows, err := s.Stmt.Query(args) - - if t, ok := s.hooks.(Stmter); ok { - s.ctx.Error = err - err = t.AfterStmtQuery(s.ctx) - } - - return rows, err -} - -type conn struct { - driver.Conn - hooks HookType -} - -func (c conn) Prepare(query string) (driver.Stmt, error) { - var ctx *Context - - if t, ok := c.hooks.(Stmter); ok { - ctx = NewContext() - ctx.Query = query - - if err := t.BeforePrepare(ctx); err != nil { - return nil, err - } - - query = ctx.Query - } - - _stmt, err := c.Conn.Prepare(query) - - if t, ok := c.hooks.(Stmter); ok { - err = t.AfterPrepare(ctx) - } - - return stmt{_stmt, c.hooks, ctx}, err -} - -func (c conn) Query(query string, args []driver.Value) (driver.Rows, error) { - if queryer, ok := c.Conn.(driver.Queryer); ok { - var ctx *Context - if t, ok := c.hooks.(Queryer); ok { - ctx = NewContext() - ctx.Query = query - ctx.Args = driverToInterface(args) - - if err := t.BeforeQuery(ctx); err != nil { - return nil, err - } - - query = ctx.Query - args = interfaceToDriver(ctx.Args) - } - - rows, err := queryer.Query(query, args) - - if t, ok := c.hooks.(Queryer); ok { - ctx.Error = err - err = t.AfterQuery(ctx) - } - - return rows, err - } - - // Not implemented by underlying driver - return nil, driver.ErrSkip -} - -func (c conn) Exec(query string, args []driver.Value) (driver.Result, error) { - if execer, ok := c.Conn.(driver.Execer); ok { - var ctx *Context - if t, ok := c.hooks.(Execer); ok { - ctx = NewContext() - ctx.Query = query - ctx.Args = driverToInterface(args) - - if err := t.BeforeExec(ctx); err != nil { - return nil, err - } - - query = ctx.Query - args = interfaceToDriver(ctx.Args) - - } - - res, err := execer.Exec(query, args) - - if t, ok := c.hooks.(Execer); ok { - ctx.Error = err - err = t.AfterExec(ctx) - } - - return res, err - } - - // Not implemented by underlying driver - return nil, driver.ErrSkip -} - -func (c conn) Close() error { - return c.Conn.Close() -} - -func (c conn) Begin() (driver.Tx, error) { - var ctx *Context - - if t, ok := c.hooks.(Beginner); ok { - ctx = NewContext() - - if err := t.BeforeBegin(ctx); err != nil { - return nil, err - } - } - - _tx, err := c.Conn.Begin() - - if t, ok := c.hooks.(Beginner); ok { - ctx.Error = err - err = t.AfterBegin(ctx) - } - - return tx{_tx, c.hooks, ctx}, err -} - -// Driver it's a proxy for a specific sql driver -type Driver struct { - driver driver.Driver - name string - hooks HookType -} - -// NewDriver will create a Proxy Driver with defined Hooks -// name is the underlying driver name -func NewDriver(name string, hooks HookType) *Driver { - return &Driver{name: name, hooks: hooks} -} - -// Open returns a new connection to the database, using the underlying specified driver -func (d *Driver) Open(dsn string) (driver.Conn, error) { - if d.driver == nil { - // Get Driver by Opening a new connection - db, err := sql.Open(d.name, dsn) - if err != nil { - return nil, err - } - if err := db.Close(); err != nil { - return nil, err - } - d.driver = db.Driver() - } - - _conn, err := d.driver.Open(dsn) - return conn{_conn, d.hooks}, err -} diff --git a/example_test.go b/example_test.go deleted file mode 100644 index c500226..0000000 --- a/example_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package sqlhooks - -import "database/sql" - -type MyQueryer struct{} - -func (q MyQueryer) BeforeQuery(ctx *Context) error { - return nil -} - -func (q MyQueryer) AfterQuery(ctx *Context) error { - return nil -} - -func ExampleNewDriver() { - // MyQueryer satisfies Queryer interface - hooks := MyQueryer{} - - // mysql is the driver we're going to attach to - driver := NewDriver("mysql", &hooks) - sql.Register("sqlhooks-mysql", driver) -} - -func ExampleOpen() { - // Where using nil as HookType, so no hooks will be attached. - // In order attach hooks, the HookType should implement one of the following interfaces: - // - Beginner - // - Commiter - // - Rollbacker - // - Stmter - // - Queryer - // - Execer - db, err := Open("mysql", "user:pass@/db", nil) - if err != nil { - panic(err) - } - - db.Query("SELECT 1+1") -} diff --git a/examples/gorp/main.go b/examples/gorp/main.go deleted file mode 100644 index c18e399..0000000 --- a/examples/gorp/main.go +++ /dev/null @@ -1,65 +0,0 @@ -package main - -import ( - "log" - "strings" - - "github.com/gchaincl/sqlhooks" - "github.com/go-gorp/gorp" - _ "github.com/mattn/go-sqlite3" -) - -type Post struct { - // db tag lets you specify the column name if it differs from the struct field - Id int64 `db:"post_id"` - Title string `db:",size:50"` // Column size set to 50 - Body string `db:"article_body,size:1024"` // Set both column name and size -} - -type Hooks struct{} - -func (h Hooks) BeforeQuery(ctx *sqlhooks.Context) error { - log.Println(ctx.Query, ctx.Args) - return nil -} - -func (h Hooks) AfterQuery(ctx *sqlhooks.Context) error { - return ctx.Error -} - -// Update Post's title field Before Inserting -func (h Hooks) BeforeExec(ctx *sqlhooks.Context) error { - if strings.HasPrefix(ctx.Query, `insert into "posts"`) { - ctx.Args[0] = "[updated] " + ctx.Args[0].(string) - } - return nil -} - -func (h Hooks) AfterExec(ctx *sqlhooks.Context) error { - return ctx.Error -} - -func main() { - db, err := sqlhooks.Open("sqlite3", ":memory:", &Hooks{}) - if err != nil { - panic(err) - } - - dbmap := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} - dbmap.AddTableWithName(Post{}, "posts").SetKeys(true, "Id") - if err := dbmap.CreateTablesIfNotExists(); err != nil { - panic(err) - } - - dbmap.Insert( - &Post{Title: "Foo", Body: "Some Content"}, - &Post{Title: "Bar", Body: "More Content"}, - ) - post := Post{} - p, err := dbmap.Get(&post, 1) - if err != nil { - panic(err) - } - - log.Printf("%#v", p) -} diff --git a/examples/meddler/main.go b/examples/meddler/main.go deleted file mode 100644 index 36b9def..0000000 --- a/examples/meddler/main.go +++ /dev/null @@ -1,44 +0,0 @@ -package main - -import ( - "log" - "time" - - "github.com/gchaincl/sqlhooks" - _ "github.com/mattn/go-sqlite3" - "github.com/russross/meddler" -) - -type Person struct { - ID int `meddler:"id,pk"` - Name string `meddler:"name"` - Age int `meddler:"age"` - Created time.Time `meddler:"created,localtime"` -} - -type MyQueyer struct { - count int -} - -func (mq *MyQueyer) BeforeQuery(ctx *sqlhooks.Context) error { - mq.count++ - - ctx.Set("id", mq.count) - log.Printf("[query#%d] %s %q", ctx.Get("id").(int), ctx.Query, ctx.Args) - return nil -} - -func (mq MyQueyer) AfterQuery(ctx *sqlhooks.Context) error { - log.Printf("[query#%d] done (err = %v)", ctx.Get("id").(int), ctx.Error) - return ctx.Error -} - -func main() { - db, err := sqlhooks.Open("sqlite3", ":memory:", &MyQueyer{}) - if err != nil { - panic(err) - } - - p := new(Person) - meddler.Load(db, "person", p, 1) -} diff --git a/fakedb_test.go b/fakedb_test.go deleted file mode 100644 index 595c407..0000000 --- a/fakedb_test.go +++ /dev/null @@ -1,881 +0,0 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package sqlhooks - -import ( - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "io" - "log" - "strconv" - "strings" - "sync" - "testing" - "time" -) - -var _ = log.Printf - -// fakeDriver is a fake database that implements Go's driver.Driver -// interface, just for testing. -// -// It speaks a query language that's semantically similar to but -// syntactically different and simpler than SQL. The syntax is as -// follows: -// -// WIPE -// CREATE||=,=,... -// where types are: "string", [u]int{8,16,32,64}, "bool" -// INSERT||col=val,col2=val2,col3=? -// SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? -// -// Any of these can be preceded by PANIC||, to cause the -// named method on fakeStmt to panic. -// -// When opening a fakeDriver's database, it starts empty with no -// tables. All tables and data are stored in memory only. -type fakeDriver struct { - mu sync.Mutex // guards 3 following fields - openCount int // conn opens - closeCount int // conn closes - waitCh chan struct{} - waitingCh chan struct{} - dbs map[string]*fakeDB -} - -type fakeDB struct { - name string - - mu sync.Mutex - free []*fakeConn - tables map[string]*table - badConn bool -} - -type table struct { - mu sync.Mutex - colname []string - coltype []string - rows []*row -} - -func (t *table) columnIndex(name string) int { - for n, nname := range t.colname { - if name == nname { - return n - } - } - return -1 -} - -type row struct { - cols []interface{} // must be same size as its table colname + coltype -} - -func (r *row) clone() *row { - nrow := &row{cols: make([]interface{}, len(r.cols))} - copy(nrow.cols, r.cols) - return nrow -} - -type fakeConn struct { - db *fakeDB // where to return ourselves to - - currTx *fakeTx - - // Stats for tests: - mu sync.Mutex - stmtsMade int - stmtsClosed int - numPrepare int - - // bad connection tests; see isBad() - bad bool - stickyBad bool -} - -func (c *fakeConn) incrStat(v *int) { - c.mu.Lock() - *v++ - c.mu.Unlock() -} - -type fakeTx struct { - c *fakeConn -} - -type fakeStmt struct { - c *fakeConn - q string // just for debugging - - cmd string - table string - panic string - - closed bool - - colName []string // used by CREATE, INSERT, SELECT (selected columns) - colType []string // used by CREATE - colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) - placeholders int // used by INSERT/SELECT: number of ? params - - whereCol []string // used by SELECT (all placeholders) - - placeholderConverter []driver.ValueConverter // used by INSERT -} - -var fdriver driver.Driver = &fakeDriver{} - -func init() { - sql.Register("test", fdriver) -} - -func contains(list []string, y string) bool { - for _, x := range list { - if x == y { - return true - } - } - return false -} - -type Dummy struct { - driver.Driver -} - -// hook to simulate connection failures -var hookOpenErr struct { - sync.Mutex - fn func() error -} - -func setHookOpenErr(fn func() error) { - hookOpenErr.Lock() - defer hookOpenErr.Unlock() - hookOpenErr.fn = fn -} - -// Supports dsn forms: -// -// ; (only currently supported option is `badConn`, -// which causes driver.ErrBadConn to be returned on -// every other conn.Begin()) -func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { - hookOpenErr.Lock() - fn := hookOpenErr.fn - hookOpenErr.Unlock() - if fn != nil { - if err := fn(); err != nil { - return nil, err - } - } - parts := strings.Split(dsn, ";") - if len(parts) < 1 { - return nil, errors.New("fakedb: no database name") - } - name := parts[0] - - db := d.getDB(name) - - d.mu.Lock() - d.openCount++ - d.mu.Unlock() - conn := &fakeConn{db: db} - - if len(parts) >= 2 && parts[1] == "badConn" { - conn.bad = true - } - if d.waitCh != nil { - d.waitingCh <- struct{}{} - <-d.waitCh - d.waitCh = nil - d.waitingCh = nil - } - return conn, nil -} - -func (d *fakeDriver) getDB(name string) *fakeDB { - d.mu.Lock() - defer d.mu.Unlock() - if d.dbs == nil { - d.dbs = make(map[string]*fakeDB) - } - db, ok := d.dbs[name] - if !ok { - db = &fakeDB{name: name} - d.dbs[name] = db - } - return db -} - -func (db *fakeDB) wipe() { - db.mu.Lock() - defer db.mu.Unlock() - db.tables = nil -} - -func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { - db.mu.Lock() - defer db.mu.Unlock() - if db.tables == nil { - db.tables = make(map[string]*table) - } - if _, exist := db.tables[name]; exist { - return fmt.Errorf("table %q already exists", name) - } - if len(columnNames) != len(columnTypes) { - return fmt.Errorf("create table of %q len(names) != len(types): %d vs %d", - name, len(columnNames), len(columnTypes)) - } - db.tables[name] = &table{colname: columnNames, coltype: columnTypes} - return nil -} - -// must be called with db.mu lock held -func (db *fakeDB) table(table string) (*table, bool) { - if db.tables == nil { - return nil, false - } - t, ok := db.tables[table] - return t, ok -} - -func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { - db.mu.Lock() - defer db.mu.Unlock() - t, ok := db.table(table) - if !ok { - return - } - for n, cname := range t.colname { - if cname == column { - return t.coltype[n], true - } - } - return "", false -} - -func (c *fakeConn) isBad() bool { - if c.stickyBad { - return true - } else if c.bad { - // alternate between bad conn and not bad conn - c.db.badConn = !c.db.badConn - return c.db.badConn - } else { - return false - } -} - -func (c *fakeConn) Begin() (driver.Tx, error) { - if c.isBad() { - return nil, driver.ErrBadConn - } - if c.currTx != nil { - return nil, errors.New("already in a transaction") - } - c.currTx = &fakeTx{c: c} - return c.currTx, nil -} - -var hookPostCloseConn struct { - sync.Mutex - fn func(*fakeConn, error) -} - -func setHookpostCloseConn(fn func(*fakeConn, error)) { - hookPostCloseConn.Lock() - defer hookPostCloseConn.Unlock() - hookPostCloseConn.fn = fn -} - -var testStrictClose *testing.T - -// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close -// fails to close. If nil, the check is disabled. -func setStrictFakeConnClose(t *testing.T) { - testStrictClose = t -} - -func (c *fakeConn) Close() (err error) { - drv := fdriver.(*fakeDriver) - defer func() { - if err != nil && testStrictClose != nil { - testStrictClose.Errorf("failed to close a test fakeConn: %v", err) - } - hookPostCloseConn.Lock() - fn := hookPostCloseConn.fn - hookPostCloseConn.Unlock() - if fn != nil { - fn(c, err) - } - if err == nil { - drv.mu.Lock() - drv.closeCount++ - drv.mu.Unlock() - } - }() - if c.currTx != nil { - return errors.New("can't close fakeConn; in a Transaction") - } - if c.db == nil { - return errors.New("can't close fakeConn; already closed") - } - if c.stmtsMade > c.stmtsClosed { - return errors.New("can't close; dangling statement(s)") - } - c.db = nil - return nil -} - -func checkSubsetTypes(args []driver.Value) error { - for n, arg := range args { - switch arg.(type) { - case int64, float64, bool, nil, []byte, string, time.Time: - default: - return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) - } - } - return nil -} - -func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { - // This is an optional interface, but it's implemented here - // just to check that all the args are of the proper types. - // ErrSkip is returned so the caller acts as if we didn't - // implement this at all. - err := checkSubsetTypes(args) - if err != nil { - return nil, err - } - return nil, driver.ErrSkip -} - -func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { - // This is an optional interface, but it's implemented here - // just to check that all the args are of the proper types. - // ErrSkip is returned so the caller acts as if we didn't - // implement this at all. - err := checkSubsetTypes(args) - if err != nil { - return nil, err - } - return nil, driver.ErrSkip -} - -func errf(msg string, args ...interface{}) error { - return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) -} - -// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? -// (note that where columns must always contain ? marks, -// just a limitation for fakedb) -func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { - if len(parts) != 3 { - stmt.Close() - return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) - } - stmt.table = parts[0] - stmt.colName = strings.Split(parts[1], ",") - for n, colspec := range strings.Split(parts[2], ",") { - if colspec == "" { - continue - } - nameVal := strings.Split(colspec, "=") - if len(nameVal) != 2 { - stmt.Close() - return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) - } - column, value := nameVal[0], nameVal[1] - _, ok := c.db.columnType(stmt.table, column) - if !ok { - stmt.Close() - return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) - } - if value != "?" { - stmt.Close() - return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", - stmt.table, column) - } - stmt.whereCol = append(stmt.whereCol, column) - stmt.placeholders++ - } - return stmt, nil -} - -// parts are table|col=type,col2=type2 -func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { - if len(parts) != 2 { - stmt.Close() - return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) - } - stmt.table = parts[0] - for n, colspec := range strings.Split(parts[1], ",") { - nameType := strings.Split(colspec, "=") - if len(nameType) != 2 { - stmt.Close() - return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) - } - stmt.colName = append(stmt.colName, nameType[0]) - stmt.colType = append(stmt.colType, nameType[1]) - } - return stmt, nil -} - -// parts are table|col=?,col2=val -func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { - if len(parts) != 2 { - stmt.Close() - return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) - } - stmt.table = parts[0] - for n, colspec := range strings.Split(parts[1], ",") { - nameVal := strings.Split(colspec, "=") - if len(nameVal) != 2 { - stmt.Close() - return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) - } - column, value := nameVal[0], nameVal[1] - ctype, ok := c.db.columnType(stmt.table, column) - if !ok { - stmt.Close() - return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) - } - stmt.colName = append(stmt.colName, column) - - if value != "?" { - var subsetVal interface{} - // Convert to driver subset type - switch ctype { - case "string": - subsetVal = []byte(value) - case "blob": - subsetVal = []byte(value) - case "int32": - i, err := strconv.Atoi(value) - if err != nil { - stmt.Close() - return nil, errf("invalid conversion to int32 from %q", value) - } - subsetVal = int64(i) // int64 is a subset type, but not int32 - default: - stmt.Close() - return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) - } - stmt.colValue = append(stmt.colValue, subsetVal) - } else { - stmt.placeholders++ - stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) - stmt.colValue = append(stmt.colValue, "?") - } - } - return stmt, nil -} - -// hook to simulate broken connections -var hookPrepareBadConn func() bool - -func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { - c.numPrepare++ - if c.db == nil { - panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) - } - - if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { - return nil, driver.ErrBadConn - } - - parts := strings.Split(query, "|") - if len(parts) < 1 { - return nil, errf("empty query") - } - stmt := &fakeStmt{q: query, c: c} - if len(parts) >= 3 && parts[0] == "PANIC" { - stmt.panic = parts[1] - parts = parts[2:] - } - cmd := parts[0] - stmt.cmd = cmd - parts = parts[1:] - - c.incrStat(&c.stmtsMade) - switch cmd { - case "WIPE": - // Nothing - case "SELECT": - return c.prepareSelect(stmt, parts) - case "CREATE": - return c.prepareCreate(stmt, parts) - case "INSERT": - return c.prepareInsert(stmt, parts) - case "NOSERT": - // Do all the prep-work like for an INSERT but don't actually insert the row. - // Used for some of the concurrent tests. - return c.prepareInsert(stmt, parts) - default: - stmt.Close() - return nil, errf("unsupported command type %q", cmd) - } - return stmt, nil -} - -func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { - if s.panic == "ColumnConverter" { - panic(s.panic) - } - if len(s.placeholderConverter) == 0 { - return driver.DefaultParameterConverter - } - return s.placeholderConverter[idx] -} - -func (s *fakeStmt) Close() error { - if s.panic == "Close" { - panic(s.panic) - } - if s.c == nil { - panic("nil conn in fakeStmt.Close") - } - if s.c.db == nil { - panic("in fakeStmt.Close, conn's db is nil (already closed)") - } - if !s.closed { - s.c.incrStat(&s.c.stmtsClosed) - s.closed = true - } - return nil -} - -var errClosed = errors.New("fakedb: statement has been closed") - -// hook to simulate broken connections -var hookExecBadConn func() bool - -func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { - if s.panic == "Exec" { - panic(s.panic) - } - if s.closed { - return nil, errClosed - } - - if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { - return nil, driver.ErrBadConn - } - - err := checkSubsetTypes(args) - if err != nil { - return nil, err - } - - db := s.c.db - switch s.cmd { - case "WIPE": - db.wipe() - return driver.ResultNoRows, nil - case "CREATE": - if err := db.createTable(s.table, s.colName, s.colType); err != nil { - return nil, err - } - return driver.ResultNoRows, nil - case "INSERT": - return s.execInsert(args, true) - case "NOSERT": - // Do all the prep-work like for an INSERT but don't actually insert the row. - // Used for some of the concurrent tests. - return s.execInsert(args, false) - } - fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) - return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) -} - -// When doInsert is true, add the row to the table. -// When doInsert is false do prep-work and error checking, but don't -// actually add the row to the table. -func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { - db := s.c.db - if len(args) != s.placeholders { - panic("error in pkg db; should only get here if size is correct") - } - db.mu.Lock() - t, ok := db.table(s.table) - db.mu.Unlock() - if !ok { - return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) - } - - t.mu.Lock() - defer t.mu.Unlock() - - var cols []interface{} - if doInsert { - cols = make([]interface{}, len(t.colname)) - } - argPos := 0 - for n, colname := range s.colName { - colidx := t.columnIndex(colname) - if colidx == -1 { - return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) - } - var val interface{} - if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { - val = args[argPos] - argPos++ - } else { - val = s.colValue[n] - } - if doInsert { - cols[colidx] = val - } - } - - if doInsert { - t.rows = append(t.rows, &row{cols: cols}) - } - return driver.RowsAffected(1), nil -} - -// hook to simulate broken connections -var hookQueryBadConn func() bool - -func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { - if s.panic == "Query" { - panic(s.panic) - } - if s.closed { - return nil, errClosed - } - - if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { - return nil, driver.ErrBadConn - } - - err := checkSubsetTypes(args) - if err != nil { - return nil, err - } - - db := s.c.db - if len(args) != s.placeholders { - panic("error in pkg db; should only get here if size is correct") - } - - db.mu.Lock() - t, ok := db.table(s.table) - db.mu.Unlock() - if !ok { - return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) - } - - if s.table == "magicquery" { - if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { - if args[0] == "sleep" { - time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) - } - } - } - - t.mu.Lock() - defer t.mu.Unlock() - - colIdx := make(map[string]int) // select column name -> column index in table - for _, name := range s.colName { - idx := t.columnIndex(name) - if idx == -1 { - return nil, fmt.Errorf("fakedb: unknown column name %q", name) - } - colIdx[name] = idx - } - - mrows := []*row{} -rows: - for _, trow := range t.rows { - // Process the where clause, skipping non-match rows. This is lazy - // and just uses fmt.Sprintf("%v") to test equality. Good enough - // for test code. - for widx, wcol := range s.whereCol { - idx := t.columnIndex(wcol) - if idx == -1 { - return nil, fmt.Errorf("db: invalid where clause column %q", wcol) - } - tcol := trow.cols[idx] - if bs, ok := tcol.([]byte); ok { - // lazy hack to avoid sprintf %v on a []byte - tcol = string(bs) - } - if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { - continue rows - } - } - mrow := &row{cols: make([]interface{}, len(s.colName))} - for seli, name := range s.colName { - mrow.cols[seli] = trow.cols[colIdx[name]] - } - mrows = append(mrows, mrow) - } - - cursor := &rowsCursor{ - pos: -1, - rows: mrows, - cols: s.colName, - errPos: -1, - } - return cursor, nil -} - -func (s *fakeStmt) NumInput() int { - if s.panic == "NumInput" { - panic(s.panic) - } - return s.placeholders -} - -// hook to simulate broken connections -var hookCommitBadConn func() bool - -func (tx *fakeTx) Commit() error { - tx.c.currTx = nil - if hookCommitBadConn != nil && hookCommitBadConn() { - return driver.ErrBadConn - } - return nil -} - -// hook to simulate broken connections -var hookRollbackBadConn func() bool - -func (tx *fakeTx) Rollback() error { - tx.c.currTx = nil - if hookRollbackBadConn != nil && hookRollbackBadConn() { - return driver.ErrBadConn - } - return nil -} - -type rowsCursor struct { - cols []string - pos int - rows []*row - closed bool - - // errPos and err are for making Next return early with error. - errPos int - err error - - // a clone of slices to give out to clients, indexed by the - // the original slice's first byte address. we clone them - // just so we're able to corrupt them on close. - bytesClone map[*byte][]byte -} - -func (rc *rowsCursor) Close() error { - if !rc.closed { - for _, bs := range rc.bytesClone { - bs[0] = 255 // first byte corrupted - } - } - rc.closed = true - return nil -} - -func (rc *rowsCursor) Columns() []string { - return rc.cols -} - -var rowsCursorNextHook func(dest []driver.Value) error - -func (rc *rowsCursor) Next(dest []driver.Value) error { - if rowsCursorNextHook != nil { - return rowsCursorNextHook(dest) - } - - if rc.closed { - return errors.New("fakedb: cursor is closed") - } - rc.pos++ - if rc.pos == rc.errPos { - return rc.err - } - if rc.pos >= len(rc.rows) { - return io.EOF // per interface spec - } - for i, v := range rc.rows[rc.pos].cols { - // TODO(bradfitz): convert to subset types? naah, I - // think the subset types should only be input to - // driver, but the sql package should be able to handle - // a wider range of types coming out of drivers. all - // for ease of drivers, and to prevent drivers from - // messing up conversions or doing them differently. - dest[i] = v - - if bs, ok := v.([]byte); ok { - if rc.bytesClone == nil { - rc.bytesClone = make(map[*byte][]byte) - } - clone, ok := rc.bytesClone[&bs[0]] - if !ok { - clone = make([]byte, len(bs)) - copy(clone, bs) - rc.bytesClone[&bs[0]] = clone - } - dest[i] = clone - } - } - return nil -} - -// fakeDriverString is like driver.String, but indirects pointers like -// DefaultValueConverter. -// -// This could be surprising behavior to retroactively apply to -// driver.String now that Go1 is out, but this is convenient for -// our TestPointerParamsAndScans. -// -type fakeDriverString struct{} - -func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { - switch c := v.(type) { - case string, []byte: - return v, nil - case *string: - if c == nil { - return nil, nil - } - return *c, nil - } - return fmt.Sprintf("%v", v), nil -} - -func converterForType(typ string) driver.ValueConverter { - switch typ { - case "bool": - return driver.Bool - case "nullbool": - return driver.Null{Converter: driver.Bool} - case "int32": - return driver.Int32 - case "string": - return driver.NotNull{Converter: fakeDriverString{}} - case "nullstring": - return driver.Null{Converter: fakeDriverString{}} - case "int64": - // TODO(coopernurse): add type-specific converter - return driver.NotNull{Converter: driver.DefaultParameterConverter} - case "nullint64": - // TODO(coopernurse): add type-specific converter - return driver.Null{Converter: driver.DefaultParameterConverter} - case "float64": - // TODO(coopernurse): add type-specific converter - return driver.NotNull{Converter: driver.DefaultParameterConverter} - case "nullfloat64": - // TODO(coopernurse): add type-specific converter - return driver.Null{Converter: driver.DefaultParameterConverter} - case "datetime": - return driver.DefaultParameterConverter - } - panic("invalid fakedb column type of " + typ) -} diff --git a/hooks/hooks.go b/hooks/hooks.go deleted file mode 100644 index 4e49cfd..0000000 --- a/hooks/hooks.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package hooks provides ready-to-use hook implementations -package hooks diff --git a/hooks/logger/logger.go b/hooks/logger/logger.go deleted file mode 100644 index e09c465..0000000 --- a/hooks/logger/logger.go +++ /dev/null @@ -1,69 +0,0 @@ -// Package logger provides a query logger -package logger - -import ( - "log" - "os" - "sync/atomic" - "time" - - "github.com/gchaincl/sqlhooks" -) - -type Logger interface { - Printf(format string, v ...interface{}) -} - -type hook struct { - id uint64 - Log Logger -} - -func (h *hook) next() uint64 { - return atomic.AddUint64(&h.id, 1) -} - -func New() *hook { - return &hook{ - Log: log.New(os.Stderr, "", log.LstdFlags), - } -} - -func (h *hook) before(ctx *sqlhooks.Context) error { - id := h.next() - ctx.Set("start", time.Now()) - ctx.Set("id", id) - - h.Log.Printf("[query#%09d] %s %v", id, ctx.Query, ctx.Args) - return nil - -} - -func (h *hook) after(ctx *sqlhooks.Context) error { - id := ctx.Get("id") - took := time.Since(ctx.Get("start").(time.Time)) - - if err := ctx.Error; err != nil { - h.Log.Printf("[query#%09d] Finished with error: %v", id, err) - return err - } - - h.Log.Printf("[query#%09d] took %s", id, took) - return nil -} - -func (h *hook) BeforeQuery(ctx *sqlhooks.Context) error { - return h.before(ctx) -} - -func (h *hook) AfterQuery(ctx *sqlhooks.Context) error { - return h.after(ctx) -} - -func (h *hook) BeforeExec(ctx *sqlhooks.Context) error { - return h.before(ctx) -} - -func (h *hook) AfterExec(ctx *sqlhooks.Context) error { - return h.after(ctx) -} diff --git a/hooks/logger/logger_test.go b/hooks/logger/logger_test.go deleted file mode 100644 index e4749d9..0000000 --- a/hooks/logger/logger_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package logger - -import ( - "bytes" - "errors" - "log" - "testing" - - "github.com/gchaincl/sqlhooks" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func newTestHook() (*hook, *bytes.Buffer) { - buf := bytes.Buffer{} - hook := New() - hook.Log = log.New(&buf, "", 0) - - return hook, &buf -} - -func TestLoggerQuery(t *testing.T) { - hook, buf := newTestHook() - - ctx := sqlhooks.NewContext() - ctx.Query = "SELECT * FROM table" - ctx.Args = []interface{}{"1", 2} - - require.NoError(t, hook.BeforeQuery(ctx)) - assert.Contains(t, buf.String(), "[query#000000001] ") - assert.Contains(t, buf.String(), ctx.Query) - assert.Contains(t, buf.String(), "[1 2]") - - buf.Reset() - require.NoError(t, hook.AfterQuery(ctx)) - assert.Contains(t, buf.String(), "[query#000000001] ") - assert.Contains(t, buf.String(), "took") -} - -func TestLoggerExec(t *testing.T) { - hook, buf := newTestHook() - - ctx := sqlhooks.NewContext() - ctx.Query = "INSERT INTO table (foo, bar) VALUES (?, ?)" - ctx.Args = []interface{}{"x", "z"} - - require.NoError(t, hook.BeforeExec(ctx)) - assert.Contains(t, buf.String(), "[query#000000001] ") - assert.Contains(t, buf.String(), ctx.Query) - assert.Contains(t, buf.String(), "[x z]") - - buf.Reset() - require.NoError(t, hook.AfterExec(ctx)) - assert.Contains(t, buf.String(), "[query#000000001] ") - assert.Contains(t, buf.String(), "took") -} - -func TestLoggerWithErrors(t *testing.T) { - hook, buf := newTestHook() - - ctx := sqlhooks.NewContext() - ctx.Query = "INSERT INTO table (foo, bar) VALUES (?, ?)" - ctx.Args = []interface{}{"x", "z"} - ctx.Error = errors.New("boom") - - require.NoError(t, hook.BeforeExec(ctx)) - - buf.Reset() - require.Error(t, hook.AfterExec(ctx)) - assert.Contains(t, buf.String(), "Finished with error: boom") -} - -func TestLoggerIncrementsQueryCounter(t *testing.T) { - hook, buf := newTestHook() - - ctx := sqlhooks.NewContext() - for _ = range [9]bool{} { - hook.BeforeQuery(ctx) - } - - buf.Reset() - hook.BeforeQuery(ctx) - assert.Contains(t, buf.String(), "[query#000000010] ") - - buf.Reset() - hook.AfterQuery(ctx) - assert.Contains(t, buf.String(), "[query#000000010] ") -} diff --git a/hooks/loghooks/example_test.go b/hooks/loghooks/example_test.go new file mode 100644 index 0000000..d2462a1 --- /dev/null +++ b/hooks/loghooks/example_test.go @@ -0,0 +1,17 @@ +package loghooks + +import ( + "database/sql" + + "github.com/gchaincl/sqlhooks" + sqlite3 "github.com/mattn/go-sqlite3" +) + +func Example() { + driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New()) + sql.Register("sqlite3-logger", driver) + db, _ := sql.Open("sqlite3-logger", ":memory:") + + // This query will output logs + db.Query("SELECT 1+1") +} diff --git a/hooks/loghooks/loghooks.go b/hooks/loghooks/loghooks.go new file mode 100644 index 0000000..17c0036 --- /dev/null +++ b/hooks/loghooks/loghooks.go @@ -0,0 +1,30 @@ +package loghooks + +import ( + "context" + "log" + "os" + "time" +) + +type logger interface { + Printf(string, ...interface{}) +} + +type Hook struct { + log logger +} + +func New() *Hook { + return &Hook{ + log: log.New(os.Stderr, "", log.LstdFlags), + } +} +func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return context.WithValue(ctx, "started", time.Now()), nil +} + +func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + h.log.Printf("Query: `%s`, Args: `%q`. took: %s", query, args, time.Since(ctx.Value("started").(time.Time))) + return ctx, nil +} diff --git a/sqlhooks.go b/sqlhooks.go index 8cc747c..200c000 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -1,141 +1,181 @@ -/* -Package Sqlhooks provides a mechanism to execute a callbacks around specific database/sql functions. +package sqlhooks -The purpose of sqlhooks is to provide a way to instrument your database operations, -making really to log queries and arguments, measure execution time, -modifies queries before the are executed or stop execution if some conditions are met. +import ( + "context" + "database/sql/driver" +) -Example: - package main +type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error) - import ( - "log" - "time" +type Hooks interface { + Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) + After(ctx context.Context, query string, args ...interface{}) (context.Context, error) +} - "github.com/gchaincl/sqlhooks" - _ "github.com/mattn/go-sqlite3" - ) +type Driver struct { + driver.Driver + hooks Hooks +} - // Hooks satisfies sqlhooks.Queryer interface - type Hooks struct { - count int +func (drv *Driver) Open(name string) (driver.Conn, error) { + conn, err := drv.Driver.Open(name) + if err != nil { + return conn, err } - func (h *Hooks) BeforeQuery(ctx *sqlhooks.Context) error { - h.count++ - ctx.Set("t", time.Now()) - ctx.Set("id", h.count) - log.Printf("[query#%d] %s, args: %v", ctx.Get("id").(int), ctx.Query, ctx.Args) - return nil + return &Conn{conn, drv.hooks}, nil +} + +type Conn struct { + Conn driver.Conn + hooks Hooks +} + +func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + var ( + stmt driver.Stmt + err error + ) + + if c, ok := conn.Conn.(driver.ConnPrepareContext); ok { + stmt, err = c.PrepareContext(ctx, query) + } else { + stmt, err = conn.Prepare(query) } - func (h *Hooks) AfterQuery(ctx *sqlhooks.Context) error { - d := time.Since(ctx.Get("t").(time.Time)) - log.Printf("[query#%d] took %s (err: %v)", ctx.Get("id").(int), d, ctx.Error) - return ctx.Error + if err != nil { + return stmt, err } - func main() { - hooks := &Hooks{} + return &Stmt{stmt, conn.hooks, query}, nil +} - // Connect to attached driver - db, _ := sqlhooks.Open("sqlite3", ":memory:", hooks) +func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) } +func (conn *Conn) Close() error { return conn.Conn.Close() } +func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() } - // Do you're stuff - db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))") - db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar") - db.Query("SELECT id, text FROM t") - db.Query("Invalid Query") +type Stmt struct { + Stmt driver.Stmt + hooks Hooks + query string +} + +func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if s, ok := stmt.Stmt.(driver.StmtExecContext); ok { + return s.ExecContext(ctx, args) } -*/ -package sqlhooks -import ( - "database/sql" - "fmt" - "time" -) + values := make([]driver.Value, len(args)) + for _, arg := range args { + values[arg.Ordinal-1] = arg.Value + } -var ( - drivers = make(map[interface{}]string) -) + return stmt.Exec(values) +} + +func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + var err error -// Open Register a sqlhook driver and opens a connection against it, -// driverName is the driver where we're attaching to. -func Open(driverName, dsn string, hooks HookType) (*sql.DB, error) { - if registeredName, ok := drivers[hooks]; ok { - return sql.Open(registeredName, dsn) + // Exec `Before` Hooks + if ctx, err = stmt.hooks.Before(ctx, stmt.query, args); err != nil { + return nil, err } - registeredName := fmt.Sprintf("sqlhooks:%d", time.Now().UnixNano()) - sql.Register(registeredName, NewDriver(driverName, hooks)) - drivers[hooks] = registeredName + results, err := stmt.execContext(ctx, args) + if err != nil { + return results, err + } - return sql.Open(registeredName, dsn) -} + if ctx, err = stmt.hooks.After(ctx, stmt.query, args); err != nil { + return nil, err + } -/* -HookType is the type of Hook. -In order to reduce the amount boilerplate, it's organized by database operations, -so you can only implement the hooks you need for certain operation -This type is an alias for interface{}, however the hook should implement at least one of the following interfaces: - - Beginner - - Commiter - - Rollbacker - - Stmter - - Queryer - - Execer - -Every hook can be attached Before or After the operation. -Before hooks are triggered just before execute the operation (Begin, Commit, Rollback, Prepare, Query, Exec), -if they returns an error, neither the operation nor the After hook will executed, and the error will be returned to the caller - -After hooks are triggered after the operation complete, the there is an error it will be passed inside *Context. -The error returned by an After hook will override the error returned from the operation, that's why in most cases -an after hooks should: - return ctx.Error + return results, err +} -*/ -type HookType interface{} +func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok { + return s.QueryContext(ctx, args) + } -// Beginner is the interface implemented by objects that wants to hook to Begin function -type Beginner interface { - BeforeBegin(*Context) error - AfterBegin(*Context) error + values := make([]driver.Value, len(args)) + for _, arg := range args { + values[arg.Ordinal-1] = arg.Value + } + return stmt.Query(values) } -// Commiter is the interface implemented by objects that wants to hook to Commit function -type Commiter interface { - BeforeCommit(*Context) error - AfterCommit(*Context) error +func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + var err error + + list := namedToInterface(args) + + // Exec Before Hooks + if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil { + return nil, err + } + + rows, err := stmt.queryContext(ctx, args) + if err != nil { + return rows, err + } + + if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil { + return nil, err + } + + return rows, err } -// Rollbacker is the interface implemented by objects that wants to hook to Rollback function -type Rollbacker interface { - BeforeRollback(*Context) error - AfterRollback(*Context) error +func (stmt *Stmt) Close() error { return stmt.Stmt.Close() } +func (stmt *Stmt) NumInput() int { return stmt.Stmt.NumInput() } +func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.Stmt.Exec(args) } +func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.Stmt.Query(args) } + +func Wrap(driver driver.Driver, hooks Hooks) driver.Driver { + return &Driver{driver, hooks} } -// Stmter is the interface implemented by objects that wants to hook to Statement related functions -type Stmter interface { - BeforePrepare(*Context) error - AfterPrepare(*Context) error +func namedToInterface(args []driver.NamedValue) []interface{} { + list := make([]interface{}, len(args)) + for i, a := range args { + list[i] = a.Value + } + return list +} - BeforeStmtQuery(*Context) error - AfterStmtQuery(*Context) error +/* +type hooks struct { +} - BeforeStmtExec(*Context) error - AfterStmtExec(*Context) error +func (h *hooks) Before(ctx context.Context, query string, args ...interface{}) error { + log.Printf("before> ctx = %+v, q=%s, args = %+v\n", ctx, query, args) + return nil } -// Queryer is the interface implemented by objects that wants to hook to Query function -type Queryer interface { - BeforeQuery(*Context) error - AfterQuery(*Context) error +func (h *hooks) After(ctx context.Context, query string, args ...interface{}) error { + log.Printf("after> ctx = %+v, q=%s, args = %+v\n", ctx, query, args) + return nil } -// Execer is the interface implemented by objects that wants to hook to Exec function -type Execer interface { - BeforeExec(*Context) error - AfterExec(*Context) error +func main() { + sql.Register("sqlite3-proxy", Wrap(&sqlite3.SQLiteDriver{}, &hooks{})) + db, err := sql.Open("sqlite3-proxy", ":memory:") + if err != nil { + log.Fatalln(err) + } + + if _, ok := driver.Stmt(&Stmt{}).(driver.StmtExecContext); !ok { + panic("NOPE") + } + + if _, err := db.Exec("CREATE table users(id int)"); err != nil { + log.Printf("|err| = %+v\n", err) + } + + if _, err := db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = ?", 1); err != nil { + log.Printf("err = %+v\n", err) + } + } +*/ diff --git a/sqlhooks_fakedb_test.go b/sqlhooks_fakedb_test.go deleted file mode 100644 index 050890d..0000000 --- a/sqlhooks_fakedb_test.go +++ /dev/null @@ -1,11 +0,0 @@ -package sqlhooks - -func init() { - queries["test"] = ops{ - wipe: "WIPE", - create: "CREATE|t|f1=string,f2=string", - insert: "INSERT|t|f1=?,f2=?", - selectwhere: "SELECT|t|f1,f2|f1=?,f2=?", - selectall: "SELECT|t|f1,f2|", - } -} diff --git a/sqlhooks_mock_test.go b/sqlhooks_mock_test.go deleted file mode 100644 index c6c8e04..0000000 --- a/sqlhooks_mock_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package sqlhooks - -type HooksMock struct { - beforeQuery func(c *Context) error - afterQuery func(c *Context) error - - beforeExec func(c *Context) error - afterExec func(c *Context) error - - beforeBegin func(c *Context) error - afterBegin func(c *Context) error - - beforeCommit func(c *Context) error - afterCommit func(c *Context) error - - beforeRollback func(c *Context) error - afterRollback func(c *Context) error - - beforePrepare func(c *Context) error - afterPrepare func(c *Context) error - - beforeStmtQuery func(c *Context) error - afterStmtQuery func(c *Context) error - - beforeStmtExec func(c *Context) error - afterStmtExec func(*Context) error -} - -func (h HooksMock) BeforeQuery(c *Context) error { - if h.beforeQuery != nil { - return h.beforeQuery(c) - } - return nil -} - -func (h HooksMock) AfterQuery(c *Context) error { - if h.afterQuery != nil { - return h.afterQuery(c) - } - return nil -} - -func (h HooksMock) BeforeExec(c *Context) error { - if h.beforeExec != nil { - return h.beforeExec(c) - } - return nil -} - -func (h HooksMock) AfterExec(c *Context) error { - if h.afterExec != nil { - return h.afterExec(c) - } - return nil -} - -func (h HooksMock) BeforeBegin(c *Context) error { - if h.beforeBegin != nil { - return h.beforeBegin(c) - } - return nil -} - -func (h HooksMock) AfterBegin(c *Context) error { - if h.afterBegin != nil { - return h.afterBegin(c) - } - return nil -} - -func (h HooksMock) BeforeCommit(c *Context) error { - if h.beforeCommit != nil { - return h.beforeCommit(c) - } - return nil -} - -func (h HooksMock) AfterCommit(c *Context) error { - if h.afterCommit != nil { - return h.afterCommit(c) - } - return nil -} - -func (h HooksMock) BeforeRollback(c *Context) error { - if h.beforeRollback != nil { - return h.beforeRollback(c) - } - return nil -} - -func (h HooksMock) AfterRollback(c *Context) error { - if h.afterRollback != nil { - return h.afterRollback(c) - } - return nil -} - -func (h HooksMock) BeforePrepare(c *Context) error { - if h.beforePrepare != nil { - return h.beforePrepare(c) - } - return nil -} - -func (h HooksMock) AfterPrepare(c *Context) error { - if h.afterPrepare != nil { - return h.afterPrepare(c) - } - return nil -} - -func (h HooksMock) BeforeStmtQuery(c *Context) error { - if h.beforeStmtQuery != nil { - return h.beforeStmtQuery(c) - } - return nil -} - -func (h HooksMock) AfterStmtQuery(c *Context) error { - if h.afterStmtQuery != nil { - return h.afterStmtQuery(c) - } - return nil -} - -func (h HooksMock) BeforeStmtExec(c *Context) error { - if h.beforeStmtExec != nil { - return h.beforeStmtExec(c) - } - return nil -} - -func (h HooksMock) AfterStmtExec(c *Context) error { - if h.afterStmtExec != nil { - return h.afterStmtExec(c) - } - return nil -} - -func NewHooksMock(before, after func(*Context) error) *HooksMock { - return &HooksMock{ - beforeQuery: before, - beforeExec: before, - beforeBegin: before, - beforeCommit: before, - beforeRollback: before, - beforePrepare: before, - beforeStmtQuery: before, - beforeStmtExec: before, - afterQuery: after, - afterExec: after, - afterBegin: after, - afterCommit: after, - afterRollback: after, - afterPrepare: after, - afterStmtQuery: after, - afterStmtExec: after, - } -} diff --git a/sqlhooks_mysql_test.go b/sqlhooks_mysql_test.go index bb4abb8..6dd1af2 100644 --- a/sqlhooks_mysql_test.go +++ b/sqlhooks_mysql_test.go @@ -1,15 +1,56 @@ -// +build mysql - package sqlhooks -import _ "github.com/go-sql-driver/mysql" +import ( + "database/sql" + "os" + "testing" + + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setUpMySQL(t *testing.T, dsn string) { + db, err := sql.Open("mysql", dsn) + require.NoError(t, err) + require.NoError(t, db.Ping()) + defer db.Close() -func init() { - queries["mysql"] = ops{ - wipe: "DROP TABLE IF EXISTS t", - create: "CREATE TABLE t(f1 varchar(32), f2 varchar(32))", - insert: "INSERT INTO t VALUES(?, ?)", - selectwhere: "SELECT f1, f2 FROM t WHERE f1=? AND f2=?", - selectall: "SELECT f1, f2 FROM t", + _, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)") + require.NoError(t, err) +} + +func TestMySQL(t *testing.T) { + dsn := os.Getenv("SQLHOOKS_MYSQL_DSN") + if dsn == "" { + t.Skipf("SQLHOOKS_MYSQL_DSN not set") } + + setUpMySQL(t, dsn) + + s := newSuite(t, &mysql.MySQLDriver{}, dsn) + + s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1) + s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus") + s.TestHooksErrors(t, "SELECT 1+1") + + t.Run("DBWorks", func(t *testing.T) { + s.hooks.noop() + if _, err := s.db.Exec("DELETE FROM users"); err != nil { + t.Fatal(err) + } + + stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)") + require.NoError(t, err) + for i := range [5]struct{}{} { + _, err := stmt.Exec(i, "gus") + require.NoError(t, err) + } + + var count int + require.NoError(t, + s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), + ) + assert.Equal(t, 5, count) + }) } diff --git a/sqlhooks_postgres_test.go b/sqlhooks_postgres_test.go index 10ae76d..7cfbea0 100644 --- a/sqlhooks_postgres_test.go +++ b/sqlhooks_postgres_test.go @@ -1,15 +1,56 @@ -// +build postgres - package sqlhooks -import _ "github.com/lib/pq" +import ( + "database/sql" + "os" + "testing" + + "github.com/lib/pq" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setUpPostgres(t *testing.T, dsn string) { + db, err := sql.Open("postgres", dsn) + require.NoError(t, err) + require.NoError(t, db.Ping()) + defer db.Close() -func init() { - queries["postgres"] = ops{ - wipe: "DROP TABLE IF EXISTS t", - create: "CREATE TABLE t(f1 varchar(32), f2 varchar(32))", - insert: "INSERT INTO t VALUES($1, $2)", - selectwhere: "SELECT f1, f2 FROM t WHERE f1=$1 AND f2=$2", - selectall: "SELECT f1, f2 FROM t", + _, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)") + require.NoError(t, err) +} + +func TestPostgres(t *testing.T) { + dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN") + if dsn == "" { + t.Skipf("SQLHOOKS_POSTGRES_DSN not set") } + + setUpPostgres(t, dsn) + + s := newSuite(t, &pq.Driver{}, dsn) + + s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1) + s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus") + s.TestHooksErrors(t, "SELECT 1+1") + + t.Run("DBWorks", func(t *testing.T) { + s.hooks.noop() + if _, err := s.db.Exec("DELETE FROM users"); err != nil { + t.Fatal(err) + } + + stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES($1, $2)") + require.NoError(t, err) + for i := range [5]struct{}{} { + _, err := stmt.Exec(i, "gus") + require.NoError(t, err) + } + + var count int + require.NoError(t, + s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), + ) + assert.Equal(t, 5, count) + }) } diff --git a/sqlhooks_sqlite3_test.go b/sqlhooks_sqlite3_test.go index 9fa5f50..bf910c8 100644 --- a/sqlhooks_sqlite3_test.go +++ b/sqlhooks_sqlite3_test.go @@ -1,15 +1,54 @@ -// +build sqlite3 - package sqlhooks -import _ "github.com/mattn/go-sqlite3" +import ( + "database/sql" + "os" + "testing" + "time" + + sqlite3 "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setUp(t *testing.T) func() { + dbName := "sqlite3test.db" + + db, err := sql.Open("sqlite3", dbName) + require.NoError(t, err) + defer db.Close() + + _, err = db.Exec("CREATE table users(id int, name text)") + require.NoError(t, err) + + return func() { os.Remove(dbName) } +} + +func TestSQLite3(t *testing.T) { + defer setUp(t)() + s := newSuite(t, &sqlite3.SQLiteDriver{}, "sqlite3test.db") + + s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1) + s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus") + s.TestHooksErrors(t, "SELECT 1+1") + + t.Run("DBWorks", func(t *testing.T) { + s.hooks.noop() + if _, err := s.db.Exec("DELETE FROM users"); err != nil { + t.Fatal(err) + } + + stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)") + require.NoError(t, err) + for range [5]struct{}{} { + _, err := stmt.Exec(time.Now().UnixNano(), "gus") + require.NoError(t, err) + } -func init() { - queries["sqlite3"] = ops{ - wipe: "DROP TABLE IF EXISTS t", - create: "CREATE TABLE t(f1, f2)", - insert: "INSERT INTO t VALUES(?, ?)", - selectwhere: "SELECT f1, f2 FROM t WHERE f1=? AND f2=?", - selectall: "SELECT f1, f2 FROM t", - } + var count int + require.NoError(t, + s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), + ) + assert.Equal(t, 5, count) + }) } diff --git a/sqlhooks_test.go b/sqlhooks_test.go index 40090c1..7093382 100644 --- a/sqlhooks_test.go +++ b/sqlhooks_test.go @@ -1,287 +1,149 @@ package sqlhooks import ( + "context" "database/sql" - "flag" + "database/sql/driver" + "errors" "fmt" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var ( - driverFlag = flag.String("driver", "test", "SQL Driver") - dsnFlag = flag.String("dsn", "db", "DSN") -) - -type ops struct { - wipe string - create string - insert string - selectwhere string - selectall string +type testHooks struct { + before Hook + after Hook } -var queries = make(map[string]ops) - -func openDBWithHooks(t *testing.T, hooks interface{}, dsnArgs ...string) *sql.DB { - q := queries[*driverFlag] - - dsn := *dsnFlag - for _, arg := range dsnArgs { - dsn = dsn + arg - } - - // First, we connect directly using `test` driver - if db, err := sql.Open(*driverFlag, dsn); err != nil { - t.Fatalf("sql.Open: %v", err) - return nil - } else { - if _, err := db.Exec(q.wipe); err != nil { - t.Fatalf("WIPE: %v", err) - } - - if _, err := db.Exec(q.create); err != nil { - t.Fatalf("CREATE: %v", err) - } - if err := db.Close(); err != nil { - t.Fatalf("db.Close: %v", err) - } +func (h *testHooks) noop() { + noop := func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, nil } - db, err := Open(*driverFlag, dsn, hooks) - if err != nil { - t.Fatalf("sql.Open: %v", err) - } - - return db + h.before, h.after = noop, noop } -func TestBeforeAndAfterHooks(t *testing.T) { - q := queries[*driverFlag] - - for _, hook := range []string{"Query", "Exec", "Begin", "Commit", "Rollback"} { - beforeOk := false - before := func(ctx *Context) error { - beforeOk = true - return nil - } - - afterOk := false - after := func(ctx *Context) error { - afterOk = true - return ctx.Error - } - - hooks := NewHooksMock(before, after) - db := openDBWithHooks(t, hooks) - - switch hook { - case "Query": - db.Query(q.selectall) - case "Exec": - db.Exec(q.insert) - case "Begin": - tx, _ := db.Begin() - - hooks.beforeCommit = nil - hooks.afterCommit = nil - tx.Commit() - case "Commit": - hooks.beforeBegin = nil - hooks.afterBegin = nil - - tx, _ := db.Begin() - tx.Commit() - case "Rollback": - hooks.beforeBegin = nil - hooks.afterBegin = nil - - tx, _ := db.Begin() - tx.Rollback() - } - - assert.True(t, beforeOk, "'Before%s' hook didn't run", hook) - assert.True(t, afterOk, "'After%s' hook didn't run", hook) - } +func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return h.before(ctx, query, args...) } -func TestBeforeQueryStopsAndReturnsError(t *testing.T) { - q := queries[*driverFlag] - - for _, hook := range []string{"Query", "Exec", "Begin", "Commit", "Rollback"} { - someErr := fmt.Errorf("Some Error") - before := func(ctx *Context) error { - return someErr - } - - // this hook should never run - after := func(ctx *Context) error { - assert.True(t, false, "'After%s' should not run", hook) - return nil - } - - hooks := NewHooksMock(before, after) - db := openDBWithHooks(t, hooks) - - var err error - switch hook { - case "Query": - _, err = db.Query(q.selectall) - case "Exec": - _, err = db.Exec(q.insert) - case "Begin": - var tx *sql.Tx - tx, err = db.Begin() - assert.Nil(t, tx) - case "Commit": - hooks.beforeBegin = nil - hooks.afterBegin = nil - tx, _ := db.Begin() - - err = tx.Commit() - case "Rollback": - hooks.beforeBegin = nil - hooks.afterBegin = nil - tx, _ := db.Begin() - - err = tx.Rollback() - } - - assert.Equal(t, someErr, err, "On %s hooks", hook) - } +func (h *testHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return h.after(ctx, query, args...) } -func TestBeforeModifiesQueryAndArgs(t *testing.T) { - if *driverFlag == "test" { - t.SkipNow() - } - - q := queries[*driverFlag] - - // this hook convert the select where into a select all - before := func(ctx *Context) error { - ctx.Args = nil - ctx.Query = q.selectall - return nil - } - - after := func(ctx *Context) error { - assert.Equal(t, q.selectall, ctx.Query) - assert.Equal(t, []interface{}(nil), ctx.Args) - return ctx.Error - } - - hooks := &HooksMock{ - beforeQuery: before, - afterQuery: after, - } - db := openDBWithHooks(t, hooks) - - db.Exec(q.insert, "x", "y") - rows, err := db.Query(q.selectwhere, "a", "b") - assert.NoError(t, err) - - found := false - for rows.Next() { - found = true - } - - assert.True(t, found) +type suite struct { + db *sql.DB + hooks *testHooks } -func TestBeforePrepare(t *testing.T) { - q := queries[*driverFlag] +func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite { + hooks := &testHooks{} + driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String()) + sql.Register(driverName, Wrap(driver, hooks)) - before := func(ctx *Context) error { - ctx.Query = q.selectall - return nil - } - - db := openDBWithHooks(t, &HooksMock{beforePrepare: before}) + db, err := sql.Open(driverName, dsn) + require.NoError(t, err) + require.NoError(t, db.Ping()) - _, err := db.Prepare("invalid query") - assert.NoError(t, err) + return &suite{db, hooks} } -func TestAfterReceivesAndHideTheError(t *testing.T) { - for _, hook := range []string{"Query", "Exec"} { - after := func(ctx *Context) error { - assert.Error(t, ctx.Error) - return nil // hide the error - } - - db := openDBWithHooks(t, &HooksMock{ - afterQuery: after, - afterExec: after, - }) - - var err error - switch hook { - case "Query": - _, err = db.Query("invalid query") - case "Exec": - _, err = db.Exec("invalid query") - } - assert.NoError(t, err) - } +func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) { + var before, after bool + + s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { + before = true + return ctx, nil + } + s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { + after = true + return ctx, nil + } + + t.Run("Query", func(t *testing.T) { + before, after = false, false + _, err := s.db.Query(query, args...) + require.NoError(t, err) + assert.True(t, before, "Before Hook did not run for query: "+query) + assert.True(t, after, "After Hook did not run for query: "+query) + }) + + t.Run("QueryContext", func(t *testing.T) { + before, after = false, false + _, err := s.db.QueryContext(context.Background(), query, args...) + require.NoError(t, err) + assert.True(t, before, "Before Hook did not run for query: "+query) + assert.True(t, after, "After Hook did not run for query: "+query) + }) + + t.Run("Exec", func(t *testing.T) { + before, after = false, false + _, err := s.db.Exec(query, args...) + require.NoError(t, err) + assert.True(t, before, "Before Hook did not run for query: "+query) + assert.True(t, after, "After Hook did not run for query: "+query) + }) + + t.Run("ExecContext", func(t *testing.T) { + before, after = false, false + _, err := s.db.ExecContext(context.Background(), query, args...) + require.NoError(t, err) + assert.True(t, before, "Before Hook did not run for query: "+query) + assert.True(t, after, "After Hook did not run for query: "+query) + }) + + t.Run("Statements", func(t *testing.T) { + before, after = false, false + stmt, err := s.db.Prepare(query) + require.NoError(t, err) + + // Hooks just run when the stmt is executed (Query or Exec) + assert.False(t, before, "Before Hook run before execution: "+query) + assert.False(t, after, "After Hook run before execution: "+query) + + stmt.Query(args...) + assert.True(t, before, "Before Hook did not run for query: "+query) + assert.True(t, after, "After Hook did not run for query: "+query) + }) } -func TestDriverItWorksWithNilHooks(t *testing.T) { - q := queries[*driverFlag] - - db := openDBWithHooks(t, nil) - - for _ = range [10]bool{} { - _, err := db.Exec(q.insert, "foo", "bar") - assert.NoError(t, err) - } - - rows, err := db.Query(q.selectall) - assert.NoError(t, err) - - items := 0 - for rows.Next() { - items++ +func (s *suite) testHooksArguments(t *testing.T, query string, args ...interface{}) { + hook := func(ctx context.Context, q string, a ...interface{}) (context.Context, error) { + assert.Equal(t, query, q) + assert.Equal(t, args, a) + assert.Equal(t, "val", ctx.Value("key").(string)) + return ctx, nil } + s.hooks.before = hook + s.hooks.after = hook - assert.Equal(t, 10, items) + ctx := context.WithValue(context.Background(), "key", "val") + _, err := s.db.QueryContext(ctx, query, args...) + require.NoError(t, err) } -func TestValuesAreSavedAndRetrievedFromCtx(t *testing.T) { - q := queries[*driverFlag] +func (s *suite) TestHooksArguments(t *testing.T, query string, args ...interface{}) { + t.Run("TestHooksArguments", func(t *testing.T) { s.testHooksArguments(t, query, args...) }) +} - before := func(ctx *Context) error { - ctx.Set("foo", 123) - ctx.Set("bar", "sqlhooks") - return nil +func (s *suite) testHooksErrors(t *testing.T, query string) { + boom := errors.New("boom") + s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + return ctx, boom } - after := func(ctx *Context) error { - assert.Equal(t, 123, ctx.Get("foo").(int)) - assert.Equal(t, "sqlhooks", ctx.Get("bar").(string)) - return ctx.Error + s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { + assert.False(t, true, "this should not run") + return ctx, nil } - hooks := NewHooksMock(before, after) - db := openDBWithHooks(t, hooks) - - _, err := db.Query(q.selectall) - assert.NoError(t, err) + _, err := s.db.Query(query) + assert.Equal(t, boom, err) } -func TestDriverIsNotRegisteredTwice(t *testing.T) { - registeredDrivers := sql.Drivers() - - for i := 0; i < 100; i++ { - _, err := Open("test", "db", nil) - if err != nil { - t.Fatalf("Unexpected error, got %v", err) - } - } - - registeredAfterOpen := len(sql.Drivers()) - len(registeredDrivers) - if registeredAfterOpen > 1 { - t.Errorf("Driver registered %d times more than expected", registeredAfterOpen-1) - } +func (s *suite) TestHooksErrors(t *testing.T, query string) { + t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) }) }