Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ func (drv *Driver) Open(name string) (driver.Conn, error) {
return conn, err
}

// Drivers that don't implement driver.ConnBeginTx are not supported.
if _, ok := conn.(driver.ConnBeginTx); !ok {
return nil, errors.New("driver must implement driver.ConnBeginTx")
}

wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) {
return &ExecerQueryerContextWithSessionResetter{wrapped,
Expand Down Expand Up @@ -97,6 +102,9 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) }
func (conn *Conn) Close() error { return conn.Conn.Close() }
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }
func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
}

// ExecerContext implements a database/sql.driver.ExecerContext
type ExecerContext struct {
Expand Down
25 changes: 25 additions & 0 deletions sqlhooks_interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
*FakeConnQueryer
*FakeConnSessionResetter
}{}, nil
case "NonConnBeginTx":
return &FakeConnUnsupported{}, nil
}

return nil, errors.New("Fake driver not implemented")
Expand All @@ -80,6 +82,9 @@ func (*FakeConnBasic) Close() error {
func (*FakeConnBasic) Begin() (driver.Tx, error) {
return nil, errors.New("Not implemented")
}
func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
return nil, errors.New("Not implemented")
}

type FakeConnExecer struct{}

Expand Down Expand Up @@ -111,6 +116,20 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error {
return errors.New("Not implemented")
}

// FakeConnUnsupported implements a database/sql.driver.Conn but doesn't implement
// driver.ConnBeginTx.
type FakeConnUnsupported struct{}

func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) {
return nil, errors.New("Not implemented")
}
func (*FakeConnUnsupported) Close() error {
return errors.New("Not implemented")
}
func (*FakeConnUnsupported) Begin() (driver.Tx, error) {
return nil, errors.New("Not implemented")
}

func TestInterfaces(t *testing.T) {
drv := Wrap(&fakeDriver{}, &testHooks{})

Expand All @@ -123,3 +142,9 @@ func TestInterfaces(t *testing.T) {
}
}
}

func TestUnsupportedDrivers(t *testing.T) {
drv := Wrap(&fakeDriver{}, &testHooks{})
_, err := drv.Open("NonConnBeginTx")
require.EqualError(t, err, "driver must implement driver.ConnBeginTx")
}