Skip to content

Commit

Permalink
Add QueryerContext interface
Browse files Browse the repository at this point in the history
If we don't support QueryerContext, the db.Query() call will always do
"prepare" statement
  • Loading branch information
surki committed Jan 2, 2019
1 parent 7408a7f commit 98e9bd8
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,16 @@ func (drv *Driver) Open(name string) (driver.Conn, error) {
}

wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) {
if isExecer(conn) && isQueryer(conn) {
return &ExecerQueryerContext{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}}, nil
} else if isExecer(conn) {
// If conn implements an Execer interface, return a driver.Conn which
// also implements Execer
return &ExecerContext{wrapped}, nil
} else if isQueryer(conn) {
// If conn implements an Queryer interface, return a driver.Conn which
// also implements Queryer
return &QueryerContext{wrapped}, nil
}
return wrapped, nil
}
Expand Down Expand Up @@ -149,6 +155,68 @@ func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Resul
return nil, errors.New("Exec was called when ExecContext was implemented")
}

// QueryerContext implements a database/sql.driver.QueryerContext
type QueryerContext struct {
*Conn
}

func isQueryer(conn driver.Conn) bool {
switch conn.(type) {
case driver.QueryerContext:
return true
case driver.Queryer:
return true
default:
return false
}
}

func (conn *QueryerContext) queryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
switch c := conn.Conn.Conn.(type) {
case driver.QueryerContext:
return c.QueryContext(ctx, query, args)
case driver.Queryer:
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.Query(query, dargs)
default:
// This should not happen
return nil, errors.New("QueryerContext created for a non Queryer driver.Conn")
}
}

func (conn *QueryerContext) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
var err error

list := namedToInterface(args)

// Query `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
return nil, err
}

results, err := conn.queryContext(ctx, query, args)
if err != nil {
return results, handlerErr(ctx, conn.hooks, err, query, list...)
}

if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
return nil, err
}

return results, err
}

// ExecerQueryerContext implements database/sql.driver.ExecerContext and
// database/sql.driver.QueryerContext
type ExecerQueryerContext struct {
*Conn
*ExecerContext
*QueryerContext
}

// Stmt implements a database/sql/driver.Stmt
type Stmt struct {
Stmt driver.Stmt
Expand Down

0 comments on commit 98e9bd8

Please sign in to comment.