diff --git a/pkg/state/state.go b/pkg/state/state.go index ec95cccc..79f3597a 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -354,11 +354,28 @@ func New(ctx context.Context, pgURL, stateSchema string) (*State, error) { } func (s *State) Init(ctx context.Context) error { - // ensure pgroll internal tables exist - // TODO: eventually use migrations for this instead of hardcoding - _, err := s.pgConn.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema))) + tx, err := s.pgConn.Begin() + if err != nil { + return err + } + defer tx.Rollback() - return err + // Try to obtain an advisory lock. + // The key is an arbitrary number, used to distinguish the lock from other locks. + // The lock is automatically released when the transaction is committed or rolled back. + const key int64 = 0x2c03057fb9525b + _, err = tx.ExecContext(ctx, "SELECT pg_advisory_xact_lock($1)", key) + if err != nil { + return err + } + + // Perform pgroll state initialization + _, err = tx.ExecContext(ctx, fmt.Sprintf(sqlInit, pq.QuoteIdentifier(s.schema))) + if err != nil { + return err + } + + return tx.Commit() } func (s *State) Close() error { diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 3f5a44c5..0602b023 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "fmt" "strings" + "sync" "testing" "github.com/google/go-cmp/cmp" @@ -132,6 +133,29 @@ func TestPgRollInitializationInANonDefaultSchema(t *testing.T) { }) } +func TestConcurrentInitialization(t *testing.T) { + t.Parallel() + + testutils.WithUninitializedState(t, func(state *state.State) { + ctx := context.Background() + numGoroutines := 10 + + wg := sync.WaitGroup{} + wg.Add(numGoroutines) + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + + if err := state.Init(ctx); err != nil { + t.Error(err) + } + }() + } + + wg.Wait() + }) +} + func TestReadSchema(t *testing.T) { t.Parallel() diff --git a/pkg/testutils/util.go b/pkg/testutils/util.go index 45fe7d0f..74ccea9e 100644 --- a/pkg/testutils/util.go +++ b/pkg/testutils/util.go @@ -89,49 +89,13 @@ func WithStateInSchemaAndConnectionToContainer(t *testing.T, schema string, fn f t.Helper() ctx := context.Background() - tDB, err := sql.Open("postgres", tConnStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := tDB.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - dbName := randomDBName() - - _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) - if err != nil { - t.Fatal(err) - } - - u, err := url.Parse(tConnStr) - if err != nil { - t.Fatal(err) - } - - u.Path = "/" + dbName - connStr := u.String() + db, connStr, _ := setupTestDatabase(t) st, err := state.New(ctx, connStr, schema) if err != nil { t.Fatal(err) } - db, err := sql.Open("postgres", connStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - // init the state if err := st.Init(ctx); err != nil { t.Fatal(err) } @@ -143,46 +107,25 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql. WithStateInSchemaAndConnectionToContainer(t, "pgroll", fn) } -func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { +func WithUninitializedState(t *testing.T, fn func(*state.State)) { t.Helper() ctx := context.Background() - tDB, err := sql.Open("postgres", tConnStr) - if err != nil { - t.Fatal(err) - } - - t.Cleanup(func() { - if err := tDB.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) - - dbName := randomDBName() - - _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) - if err != nil { - t.Fatal(err) - } + _, connStr, _ := setupTestDatabase(t) - u, err := url.Parse(tConnStr) + st, err := state.New(ctx, connStr, "pgroll") if err != nil { t.Fatal(err) } - u.Path = "/" + dbName - connStr := u.String() + fn(st) +} - db, err := sql.Open("postgres", connStr) - if err != nil { - t.Fatal(err) - } +func WithMigratorInSchemaAndConnectionToContainerWithOptions(t *testing.T, schema string, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { + t.Helper() + ctx := context.Background() - t.Cleanup(func() { - if err := db.Close(); err != nil { - t.Fatalf("Failed to close database connection: %v", err) - } - }) + db, connStr, dbName := setupTestDatabase(t) st, err := state.New(ctx, connStr, "pgroll") if err != nil { @@ -230,3 +173,51 @@ func WithMigratorInSchemaAndConnectionToContainer(t *testing.T, schema string, f func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, db *sql.DB)) { WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn) } + +// setupTestDatabase creates a new database in the test container and returns: +// - a connection to the new database +// - the connection string to the new database +// - the name of the new database +func setupTestDatabase(t *testing.T) (*sql.DB, string, string) { + t.Helper() + ctx := context.Background() + + tDB, err := sql.Open("postgres", tConnStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := tDB.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + dbName := randomDBName() + + _, err = tDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s", pq.QuoteIdentifier(dbName))) + if err != nil { + t.Fatal(err) + } + + u, err := url.Parse(tConnStr) + if err != nil { + t.Fatal(err) + } + + u.Path = "/" + dbName + connStr := u.String() + + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + if err := db.Close(); err != nil { + t.Fatalf("Failed to close database connection: %v", err) + } + }) + + return db, connStr, dbName +}