Skip to content
This repository has been archived by the owner on Dec 8, 2020. It is now read-only.

Commit

Permalink
New: Add sqlutil.WithTx() for handling nested SQL transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
impl committed Mar 7, 2019
1 parent b8ec5e6 commit b15b387
Show file tree
Hide file tree
Showing 54 changed files with 27,414 additions and 5 deletions.
38 changes: 33 additions & 5 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

107 changes: 107 additions & 0 deletions sqlutil/tx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package sqlutil

import (
"context"
"database/sql"
"fmt"
"reflect"
"time"

"github.com/puppetlabs/insights-stdlib/lifecycle"
)

type txContextKey uintptr

type txContextValue struct {
tx *sql.Tx
c uint64
}

func (key txContextKey) Get(ctx context.Context) (v txContextValue, ok bool) {
v, ok = ctx.Value(key).(txContextValue)
return
}

func (key txContextKey) Set(ctx context.Context, v txContextValue) context.Context {
return context.WithValue(ctx, key, v)
}

type txDelegate interface {
Rollback() error
Commit() error
}

type savepointDelegate struct {
ctx context.Context
v txContextValue
}

func (sd *savepointDelegate) Rollback() error {
return lifecycle.NewCloserBuilder().
RequireContext(func(ctx context.Context) error {
_, err := sd.v.tx.ExecContext(sd.ctx, fmt.Sprintf("ROLLBACK TO SAVEPOINT tx_%d", sd.v.c))
return err
}).
Timeout(500 * time.Millisecond).
Build().
Do(sd.ctx)
}

func (sd *savepointDelegate) Commit() error {
_, err := sd.v.tx.ExecContext(sd.ctx, fmt.Sprintf("RELEASE SAVEPOINT tx_%d", sd.v.c))
if err != nil {
sd.Rollback()
return err
}

return nil
}

// WithTx executes the given function within a database transaction.
//
// This function is reentrant. It will create nested transactions using
// SAVEPOINTs if needed. However, transactions are not thread-safe, so care must
// be used when creating goroutines inside transactions: ensure a separate
// context is used that does not carry the transaction with it.
func WithTx(ctx context.Context, db *sql.DB, fn func(ctx context.Context, tx *sql.Tx) error) (err error) {
key := txContextKey(reflect.ValueOf(db).Pointer())

var d txDelegate

v, ok := key.Get(ctx)
if ok {
// Use savepoints.
v.c++

if _, err := v.tx.ExecContext(ctx, fmt.Sprintf("SAVEPOINT tx_%d", v.c)); err != nil {
return err
}

ctx = key.Set(ctx, v)
d = &savepointDelegate{ctx: ctx, v: v}
} else {
// Use BeginTx().
tx, err := db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
}

v = txContextValue{tx: tx}
ctx = key.Set(ctx, v)
d = tx
}

defer func() {
if p := recover(); p != nil {
d.Rollback()
panic(p)
}
}()

if err := fn(ctx, v.tx); err != nil {
d.Rollback()
return err
}

return d.Commit()
}
146 changes: 146 additions & 0 deletions sqlutil/tx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package sqlutil_test

import (
"context"
"database/sql"
"fmt"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/puppetlabs/insights-stdlib/sqlutil"
"github.com/stretchr/testify/require"
)

func TestTxSimple(t *testing.T) {
ctx := context.Background()

db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()

mock.ExpectBegin()
mock.ExpectQuery("SELECT 1").WillReturnRows(sqlmock.NewRows([]string{"<value>"}).AddRow(1))
mock.ExpectCommit()

require.NoError(t, sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.QueryContext(ctx, "SELECT 1")
return err
}))

require.NoError(t, mock.ExpectationsWereMet())
}

func TestTxNested(t *testing.T) {
ctx := context.Background()

db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()

mock.ExpectBegin()
mock.ExpectExec("UPDATE t SET a = 1").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("UPDATE t SET a = 2").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("RELEASE SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("UPDATE t SET a = 3").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("SAVEPOINT tx_2").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("UPDATE t SET a = 4").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("RELEASE SAVEPOINT tx_2").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("RELEASE SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectCommit()

require.NoError(t, sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, "UPDATE t SET a = 1"); err != nil {
return err
}

err := sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err = tx.ExecContext(ctx, "UPDATE t SET a = 2")
return err
})
if err != nil {
return err
}

return sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, "UPDATE t SET a = 3"); err != nil {
return err
}

return sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err = tx.ExecContext(ctx, "UPDATE t SET a = 4")
return err
})
})
}))

require.NoError(t, mock.ExpectationsWereMet())
}

func TestTxSimpleRollback(t *testing.T) {
ctx := context.Background()

db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()

mock.ExpectBegin()
mock.ExpectQuery("SELECT 1").WillReturnError(fmt.Errorf("in test"))
mock.ExpectRollback()

require.NotNil(t, sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err := tx.QueryContext(ctx, "SELECT 1")
return err
}))

require.NoError(t, mock.ExpectationsWereMet())
}

func TestTxNestedRollback(t *testing.T) {
ctx := context.Background()

db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()

mock.ExpectBegin()
mock.ExpectExec("UPDATE t SET a = 1").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("UPDATE t SET a = 2").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("RELEASE SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("SAVEPOINT tx_2").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("UPDATE t SET a = 3").WillReturnError(fmt.Errorf("in test"))
mock.ExpectExec("ROLLBACK TO SAVEPOINT tx_2").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectExec("UPDATE t SET a = 4").WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("RELEASE SAVEPOINT tx_1").WillReturnResult(sqlmock.NewResult(0, 0))
mock.ExpectCommit()

require.NoError(t, sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, "UPDATE t SET a = 1"); err != nil {
return err
}

err := sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err = tx.ExecContext(ctx, "UPDATE t SET a = 2")
return err
})
if err != nil {
return err
}

return sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
err := sqlutil.WithTx(ctx, db, func(ctx context.Context, tx *sql.Tx) error {
_, err = tx.ExecContext(ctx, "UPDATE t SET a = 3")
return err
})
require.NotNil(t, err)

_, err = tx.ExecContext(ctx, "UPDATE t SET a = 4")
return err
})
}))

require.NoError(t, mock.ExpectationsWereMet())
}
3 changes: 3 additions & 0 deletions vendor/github.com/DATA-DOG/go-sqlmock/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions vendor/github.com/DATA-DOG/go-sqlmock/.travis.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b15b387

Please sign in to comment.