diff --git a/sqlhooks.go b/sqlhooks.go index 2964249..a55d826 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -160,6 +160,13 @@ func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Resul return nil, errors.New("Exec was called when ExecContext was implemented") } +func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if ciCtx, is := conn.Conn.(driver.ConnBeginTx); is { + return ciCtx.BeginTx(ctx, opts) + } + return nil, errors.New("driver does not implement driver.ConnBeginTx") +} + // QueryerContext implements a database/sql.driver.QueryerContext type QueryerContext struct { *Conn diff --git a/sqlhooks_postgres_test.go b/sqlhooks_postgres_test.go index 6a6560b..bf5d92e 100644 --- a/sqlhooks_postgres_test.go +++ b/sqlhooks_postgres_test.go @@ -1,6 +1,7 @@ package sqlhooks import ( + "context" "database/sql" "os" "testing" @@ -53,5 +54,14 @@ func TestPostgres(t *testing.T) { s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count), ) assert.Equal(t, 5, count) + + { // Should execute the query successfully when a transaction with non default isolation level is used. + var count int + tx, err := s.db.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}) + if assert.NoError(t, err) { + require.NoError(t, tx.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)) + assert.Equal(t, 5, count) + } + } }) }