diff --git a/sqlhooks.go b/sqlhooks.go index 2964249..1da05ba 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -49,6 +49,11 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { return conn, err } + // Drivers that don't implement driver.ConnBeginTx are not supported. + if _, ok := conn.(driver.ConnBeginTx); !ok { + return nil, errors.New("driver must implement driver.ConnBeginTx") + } + wrapped := &Conn{conn, drv.hooks} if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { return &ExecerQueryerContextWithSessionResetter{wrapped, @@ -97,6 +102,9 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) } func (conn *Conn) Close() error { return conn.Conn.Close() } func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() } +func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts) +} // ExecerContext implements a database/sql.driver.ExecerContext type ExecerContext struct { diff --git a/sqlhooks_interface_test.go b/sqlhooks_interface_test.go index 462b12b..dd55c60 100644 --- a/sqlhooks_interface_test.go +++ b/sqlhooks_interface_test.go @@ -63,6 +63,8 @@ func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { *FakeConnQueryer *FakeConnSessionResetter }{}, nil + case "NonConnBeginTx": + return &FakeConnUnsupported{}, nil } return nil, errors.New("Fake driver not implemented") @@ -80,6 +82,9 @@ func (*FakeConnBasic) Close() error { func (*FakeConnBasic) Begin() (driver.Tx, error) { return nil, errors.New("Not implemented") } +func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) { + return nil, errors.New("Not implemented") +} type FakeConnExecer struct{} @@ -111,6 +116,20 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error { return errors.New("Not implemented") } +// FakeConnUnsupported implements a database/sql.driver.Conn but doesn't implement +// driver.ConnBeginTx. +type FakeConnUnsupported struct{} + +func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("Not implemented") +} +func (*FakeConnUnsupported) Close() error { + return errors.New("Not implemented") +} +func (*FakeConnUnsupported) Begin() (driver.Tx, error) { + return nil, errors.New("Not implemented") +} + func TestInterfaces(t *testing.T) { drv := Wrap(&fakeDriver{}, &testHooks{}) @@ -123,3 +142,9 @@ func TestInterfaces(t *testing.T) { } } } + +func TestUnsupportedDrivers(t *testing.T) { + drv := Wrap(&fakeDriver{}, &testHooks{}) + _, err := drv.Open("NonConnBeginTx") + require.EqualError(t, err, "driver must implement driver.ConnBeginTx") +}