Skip to content

Commit

Permalink
Make state initialization concurrency safe (#285)
Browse files Browse the repository at this point in the history
Make `pgroll` state initialization concurrency safe by using Postgres
advisory locking to ensure at most one connection can initialize at at
time.

See docs on Postgres advisory locking:
*
https://www.postgresql.org/docs/current/functions-admin.html#FUNCTIONS-ADVISORY-LOCKS
*
https://www.postgresql.org/docs/current/explicit-locking.html#ADVISORY-LOCKS

Closes #283
  • Loading branch information
andrew-farries committed Feb 26, 2024
1 parent 161fde6 commit b4e3044
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 71 deletions.
25 changes: 21 additions & 4 deletions pkg/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,28 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) {
}

func (s *State) Init(ctx context.Context) error {
// ensure pgroll internal tables exist
// TODO: eventually use migrations for this instead of hardcoding
_, err := s.pgConn.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
tx, err := s.pgConn.Begin()
if err != nil {
return err
}
defer tx.Rollback()

return err
// Try to obtain an advisory lock.
// The key is an arbitrary number, used to distinguish the lock from other locks.
// The lock is automatically released when the transaction is committed or rolled back.
const key int64 = 0x2c03057fb9525b
_, err = tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", key)
if err != nil {
return err
}

// Perform pgroll state initialization
_, err = tx.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema)))
if err != nil {
return err
}

return tx.Commit()
}

func (s *State) Close() error {
Expand Down
24 changes: 24 additions & 0 deletions pkg/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"fmt"
"strings"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -132,6 +133,29 @@ func TestPgRollInitializationInANonDefaultSchema(t *testing.T) {
})
}

func TestConcurrentInitialization(t *testing.T) {
t.Parallel()

testutils.WithUninitializedState(t, func(state *state.State) {
ctx := context.Background()
numGoroutines := 10

wg := sync.WaitGroup{}
wg.Add(numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()

if err := state.Init(ctx); err != nil {
t.Error(err)
}
}()
}

wg.Wait()
})
}

func TestReadSchema(t *testing.T) {
t.Parallel()

Expand Down
125 changes: 58 additions & 67 deletions pkg/testutils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,49 +89,13 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()
db, connStr, _ := setupTestDatabase(t)

st, err := state.New(ctx, connStr, schema)
if err != nil {
t.Fatal(err)
}

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

// init the state
if err := st.Init(ctx); err != nil {
t.Fatal(err)
}
Expand All @@ -143,46 +107,25 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn)
}

func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
func WithUninitializedState(t *testing.T, fn func(*state.State)) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}
_, connStr, _ := setupTestDatabase(t)

u, err := url.Parse(tConnStr)
st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()
fn(st)
}

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}
func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
t.Helper()
ctx := context.Background()

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})
db, connStr, dbName := setupTestDatabase(t)

st, err := state.New(ctx, connStr, "pgroll")
if err != nil {
Expand Down Expand Up @@ -230,3 +173,51 @@ func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, f
func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}

// setupTestDatabase creates a new database in the test container and returns:
// - a connection to the new database
// - the connection string to the new database
// - the name of the new database
func setupTestDatabase(t *testing.T) (*sql.DB, string, string) {
t.Helper()
ctx := context.Background()

tDB, err := sql.Open("postgres", tConnStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := tDB.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

dbName := randomDBName()

_, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName)))
if err != nil {
t.Fatal(err)
}

u, err := url.Parse(tConnStr)
if err != nil {
t.Fatal(err)
}

u.Path = "/" + dbName
connStr := u.String()

db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatal(err)
}

t.Cleanup(func() {
if err := db.Close(); err != nil {
t.Fatalf("Failed to close database connection: %v", err)
}
})

return db, connStr, dbName
}

0 comments on commit b4e3044

Please sign in to comment.