From 96430ab23e09eb6d551a168464d333736aaae7ab Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Thu, 22 Feb 2024 20:42:02 +0000 Subject: [PATCH 1/4] Add test for concurrent initialization --- pkg/state/state_test.go | 24 ++++++++++++++++++++++++ pkg/testutils/util.go | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 3f5a44c5..0602b023 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "testing" "github.com/google/go-cmp/cmp" @@ -132,6 +133,29 @@ func TestPgRollInitializationInANonDefaultSchema(t *testing.T) { }) } +func TestConcurrentInitialization(t *testing.T) { + t.Parallel() + + testutils.WithUninitializedState(t, func(state *state.State) { + ctx := context.Background() + numGoroutines := 10 + + wg := sync.WaitGroup{} + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + + if err := state.Init(ctx); err != nil { + t.Error(err) + } + }() + } + + wg.Wait() + }) +} + func TestReadSchema(t *testing.T) { t.Parallel() diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index 45fe7d0f..3f528191 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -143,6 +143,44 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql. WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) } +func WithUninitializedState(t *testing.T, fn func(*state.State)) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + fn(st) +} + func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { t.Helper() ctx := context.Background() From 0cba49d165254733f150296656e8649ef2826b0b Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Thu, 22 Feb 2024 16:15:56 +0000 Subject: [PATCH 2/4] Obtain an advisory lock before running Init --- pkg/state/state.go | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/pkg/state/state.go b/pkg/state/state.go index ec95cccc..1ad7b845 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -354,11 +354,26 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) { } func (s *State) Init(ctx context.Context) error { - // ensure pgroll internal tables exist - // TODO: eventually use migrations for this instead of hardcoding - _, err := s.pgConn.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema))) + tx, err := s.pgConn.Begin() + if err != nil { + return err + } + defer tx.Rollback() - return err + // Try to obtain an advisory lock + const key int64 = 0x2c03057fb9525b + _, err = tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", key) + if err != nil { + return err + } + + // Perform pgroll state initialization + _, err = tx.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema))) + if err != nil { + return err + } + + return tx.Commit() } func (s *State) Close() error { From 001f9ac03a930f98516f157e6339f1ab5172120c Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Thu, 22 Feb 2024 21:14:06 +0000 Subject: [PATCH 3/4] Refactor testutils helpers Reduce duplication between the different With* helper functions. --- pkg/testutils/util.go | 143 ++++++++++++++---------------------------- 1 file changed, 48 insertions(+), 95 deletions(-) diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index 3f528191..74ccea9e 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -89,49 +89,13 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f t.Helper() ctx := context.Background() - tDB, err := sql.Open("postgres", tConnStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := tDB.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - dbName := randomDBName() - - _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) - if err != nil { - t.Fatal(err) - } - - u, err := url.Parse(tConnStr) - if err != nil { - t.Fatal(err) - } - - u.Path = "/" + dbName - connStr := u.String() + db, connStr, _ := setupTestDatabase(t) st, err := state.New(ctx, connStr, schema) if err != nil { t.Fatal(err) } - db, err := sql.Open("postgres", connStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - // init the state if err := st.Init(ctx); err != nil { t.Fatal(err) } @@ -147,41 +111,74 @@ func WithUninitializedState(t *testing.T, fn func(*state.State)) { t.Helper() ctx := context.Background() - tDB, err := sql.Open("postgres", tConnStr) + _, connStr, _ := setupTestDatabase(t) + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + fn(st) +} + +func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() + + db, connStr, dbName := setupTestDatabase(t) + + st, err := state.New(ctx, connStr, "pgroll") + if err != nil { + t.Fatal(err) + } + + err = st.Init(ctx) + if err != nil { + t.Fatal(err) + } + + mig, err := roll.New(ctx, connStr, schema, st, opts...) if err != nil { t.Fatal(err) } t.Cleanup(func() { - if err := tDB.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) + if err := mig.Close(); err != nil { + t.Fatalf("Failed to close migrator connection: %v", err) } }) - dbName := randomDBName() - - _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) if err != nil { t.Fatal(err) } - u, err := url.Parse(tConnStr) + _, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema)) if err != nil { t.Fatal(err) } - u.Path = "/" + dbName - connStr := u.String() - - st, err := state.New(ctx, connStr, "pgroll") + _, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName)) if err != nil { t.Fatal(err) } - fn(st) + fn(mig, db) } -func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { +func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { + WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn) +} + +func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { + WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn) +} + +// setupTestDatabase creates a new database in the test container and returns: +// - a connection to the new database +// - the connection string to the new database +// - the name of the new database +func setupTestDatabase(t *testing.T) (*sql.DB, string, string) { t.Helper() ctx := context.Background() @@ -222,49 +219,5 @@ func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schem } }) - st, err := state.New(ctx, connStr, "pgroll") - if err != nil { - t.Fatal(err) - } - - err = st.Init(ctx) - if err != nil { - t.Fatal(err) - } - - mig, err := roll.New(ctx, connStr, schema, st, opts...) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := mig.Close(); err != nil { - t.Fatalf("Failed to close migrator connection: %v", err) - } - }) - - _, err = db.ExecContext(ctx, fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", schema)) - if err != nil { - t.Fatal(err) - } - - _, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON SCHEMA %s TO pgroll", schema)) - if err != nil { - t.Fatal(err) - } - - _, err = db.ExecContext(ctx, fmt.Sprintf("GRANT ALL PRIVILEGES ON DATABASE %s TO pgroll", dbName)) - if err != nil { - t.Fatal(err) - } - - fn(mig, db) -} - -func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, fn func(mig *roll.Roll, db *sql.DB)) { - WithMigratorInSchemaAndConnectionToContainerWithOptions(t, schema, []roll.Option{roll.WithLockTimeoutMs(500)}, fn) -} - -func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { - WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn) + return db, connStr, dbName } From 4cf0e6aad45b927ce706ea929e7c77f04aab7efc Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Mon, 26 Feb 2024 07:49:35 +0000 Subject: [PATCH 4/4] Add comment to locking code --- pkg/state/state.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/state/state.go b/pkg/state/state.go index 1ad7b845..79f3597a 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -360,7 +360,9 @@ func (s *State) Init(ctx context.Context) error { } defer tx.Rollback() - // Try to obtain an advisory lock + // Try to obtain an advisory lock. + // The key is an arbitrary number, used to distinguish the lock from other locks. + // The lock is automatically released when the transaction is committed or rolled back. const key int64 = 0x2c03057fb9525b _, err = tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", key) if err != nil {