diff --git a/pkg/migration/apply.go b/pkg/migration/apply.go index e40b58be2..f1d08c5f4 100644 --- a/pkg/migration/apply.go +++ b/pkg/migration/apply.go @@ -25,7 +25,11 @@ func FindPendingMigrations(localMigrations, remoteMigrations []string) ([]string remote := remoteMigrations[i] filename := filepath.Base(localMigrations[j]) // Check if migration has been applied before, LoadLocalMigrations guarantees a match - local := migrateFilePattern.FindStringSubmatch(filename)[1] + matches := migrateFilePattern.FindStringSubmatch(filename) + if len(matches) < 2 { + return nil, errors.Errorf("invalid migration filename: %s", filename) + } + local := matches[1] if remote == local { j++ i++ diff --git a/pkg/migration/file.go b/pkg/migration/file.go index fbd4a3b7f..66e137cba 100644 --- a/pkg/migration/file.go +++ b/pkg/migration/file.go @@ -34,7 +34,6 @@ func NewMigrationFromFile(path string, fsys fs.FS) (*MigrationFile, error) { return nil, err } file := MigrationFile{Statements: lines} - // Parse version from file name filename := filepath.Base(path) matches := migrateFilePattern.FindStringSubmatch(filename) if len(matches) > 2 { @@ -50,7 +49,7 @@ func parseFile(path string, fsys fs.FS) ([]string, error) { return nil, errors.Errorf("failed to open migration file: %w", err) } defer sql.Close() - // Unless explicitly specified, Use file length as max buffer size + if !viper.IsSet("SCANNER_BUFFER_SIZE") { if fi, err := sql.Stat(); err == nil { if size := int(fi.Size()); size > parser.MaxScannerCapacity { @@ -69,38 +68,81 @@ func NewMigrationFromReader(sql io.Reader) (*MigrationFile, error) { return &MigrationFile{Statements: lines}, nil } +// ----------------------------------------------------------------------------- +// ExecBatch +// ----------------------------------------------------------------------------- + func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error { - // Batch migration commands, without using statement cache - batch := &pgconn.Batch{} - for _, line := range m.Statements { - batch.ExecParams(line, nil, nil, nil, nil) + inTx := false + + for i, stmt := range m.Statements { + if isNonTransactional(stmt) { + + if inTx { + if _, err := conn.Exec(ctx, "COMMIT"); err != nil { + _, _ = conn.Exec(ctx, "ROLLBACK") + return errors.Errorf("failed to commit transaction before non-transactional statement: %v", err) + } + inTx = false + } + + if _, err := conn.Exec(ctx, stmt); err != nil { + return formatStatementError(i, stmt, err) + } + + } else { + + if !inTx { + if _, err := conn.Exec(ctx, "BEGIN"); err != nil { + return errors.Errorf("failed to begin transaction: %v", err) + } + inTx = true + } + + if _, err := conn.Exec(ctx, stmt); err != nil { + if inTx { + _, _ = conn.Exec(ctx, "ROLLBACK") + inTx = false + } + return formatStatementError(i, stmt, err) + } + } } - // Insert into migration history + + if inTx { + if _, err := conn.Exec(ctx, "COMMIT"); err != nil { + _, _ = conn.Exec(ctx, "ROLLBACK") + return errors.Errorf("failed to commit transaction: %v", err) + } + } + if len(m.Version) > 0 { - if err := m.insertVersionSQL(conn, batch); err != nil { + if err := m.insertVersionExec(ctx, conn); err != nil { return err } } - // ExecBatch is implicitly transactional - if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil { - // Defaults to printing the last statement on error - stat := INSERT_MIGRATION_VERSION - i := len(result) - if i < len(m.Statements) { - stat = m.Statements[i] - } - var msg []string - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - stat = markError(stat, int(pgErr.Position)) - if len(pgErr.Detail) > 0 { - msg = append(msg, pgErr.Detail) - } + + return nil +} + +// ----------------------------------------------------------------------------- +// Error Formatting +// ----------------------------------------------------------------------------- + +func formatStatementError(idx int, stmt string, err error) error { + var msg []string + var pgErr *pgconn.PgError + + if errors.As(err, &pgErr) { + stat := markError(stmt, int(pgErr.Position)) + if len(pgErr.Detail) > 0 { + msg = append(msg, pgErr.Detail) } - msg = append(msg, fmt.Sprintf("At statement: %d", i), stat) + msg = append(msg, fmt.Sprintf("At statement: %d", idx), stat) return errors.Errorf("%w\n%s", err, strings.Join(msg, "\n")) } - return nil + + return err } func markError(stat string, pos int) string { @@ -110,7 +152,6 @@ func markError(stat string, pos int) string { pos -= c + 1 continue } - // Show a caret below the error position if pos > 0 { caret := append(bytes.Repeat([]byte{' '}, pos-1), '^') lines = append(lines[:j+1], string(caret)) @@ -120,35 +161,82 @@ func markError(stat string, pos int) string { return strings.Join(lines, "\n") } -func (m *MigrationFile) insertVersionSQL(conn *pgx.Conn, batch *pgconn.Batch) error { +// ----------------------------------------------------------------------------- +// ⭐ insertVersionExec — FIXED FOR PGX V5 ⭐ +// ----------------------------------------------------------------------------- + +func (m *MigrationFile) insertVersionExec(ctx context.Context, conn *pgx.Conn) error { value := pgtype.TextArray{} if err := value.Set(m.Statements); err != nil { return errors.Errorf("failed to set text array: %w", err) } + ci := conn.ConnInfo() var err error var encoded []byte var valueFormat int16 + if conn.Config().PreferSimpleProtocol { - encoded, err = value.EncodeText(ci, encoded) + encoded, err = value.EncodeText(ci, nil) valueFormat = pgtype.TextFormatCode } else { - encoded, err = value.EncodeBinary(ci, encoded) + encoded, err = value.EncodeBinary(ci, nil) valueFormat = pgtype.BinaryFormatCode } + if err != nil { return errors.Errorf("failed to encode binary: %w", err) } - batch.ExecParams( + + // ExecParams returns a MultiResultReader WITHOUT Err() in pgx v5 + mrr := conn.PgConn().ExecParams( + ctx, INSERT_MIGRATION_VERSION, [][]byte{[]byte(m.Version), []byte(m.Name), encoded}, []uint32{pgtype.TextOID, pgtype.TextOID, pgtype.TextArrayOID}, []int16{pgtype.TextFormatCode, pgtype.TextFormatCode, valueFormat}, nil, ) + + for { + _, readErr := mrr.Read() + if readErr == pgconn.ErrNoMoreResults { + break + } + if readErr != nil { + return errors.Errorf("failed to insert migration version: %w", readErr) + } + } + + if err := mrr.Close(); err != nil { + return errors.Errorf("failed to insert migration version: %w", err) + } + return nil } +// ----------------------------------------------------------------------------- +// Non-transactional detection +// ----------------------------------------------------------------------------- + +func isNonTransactional(stmt string) bool { + upper := strings.ToUpper(stmt) + + if strings.Contains(upper, "CONCURRENTLY") { + return true + } + + if regexp.MustCompile(`(?i)ALTER\s+TYPE\s+.+\s+ADD\s+VALUE`).MatchString(stmt) { + return true + } + + return false +} + +// ----------------------------------------------------------------------------- +// Seed File +// ----------------------------------------------------------------------------- + type SeedFile struct { Path string Hash string @@ -161,31 +249,34 @@ func NewSeedFile(path string, fsys fs.FS) (*SeedFile, error) { return nil, errors.Errorf("failed to open seed file: %w", err) } defer sql.Close() + hash := sha256.New() if _, err := io.Copy(hash, sql); err != nil { return nil, errors.Errorf("failed to hash file: %w", err) } digest := hex.EncodeToString(hash.Sum(nil)) + return &SeedFile{Path: path, Hash: digest}, nil } func (m *SeedFile) ExecBatchWithCache(ctx context.Context, conn *pgx.Conn, fsys fs.FS) error { - // Parse each file individually to reduce memory usage lines, err := parseFile(m.Path, fsys) if err != nil { return err } - // Data statements don't mutate schemas, safe to use statement cache + batch := pgx.Batch{} if !m.Dirty { for _, line := range lines { batch.Queue(line) } } + batch.Queue(UPSERT_SEED_FILE, m.Path, m.Hash) - // No need to track version here because there are no schema changes + if err := conn.SendBatch(ctx, &batch).Close(); err != nil { return errors.Errorf("failed to send batch: %w", err) } + return nil } diff --git a/pkg/migration/file_test.go b/pkg/migration/file_test.go index 45bee71b6..bdc6dfb6e 100644 --- a/pkg/migration/file_test.go +++ b/pkg/migration/file_test.go @@ -49,9 +49,10 @@ func TestMigrationFile(t *testing.T) { // Setup mock postgres conn := pgtest.NewConn() defer conn.Close(t) - conn.Query(migration.Statements[0]). + // Expect Exec for statement execution and Exec for migration version insert + conn.Exec(migration.Statements[0]). Reply("CREATE SCHEMA"). - Query(INSERT_MIGRATION_VERSION, "0", "", migration.Statements). + Exec(INSERT_MIGRATION_VERSION, "0", "", migration.Statements). Reply("INSERT 0 1") // Run test err := migration.ExecBatch(context.Background(), conn.MockClient(t)) @@ -67,9 +68,10 @@ func TestMigrationFile(t *testing.T) { // Setup mock postgres conn := pgtest.NewConn() defer conn.Close(t) - conn.Query(migration.Statements[0]). + // Simulate error on executing the statement (Exec), and then the insert attempt + conn.Exec(migration.Statements[0]). ReplyError(pgerrcode.DuplicateSchema, `schema "public" already exists`). - Query(INSERT_MIGRATION_VERSION, "0", "", migration.Statements). + Exec(INSERT_MIGRATION_VERSION, "0", "", migration.Statements). Reply("INSERT 0 1") // Run test err := migration.ExecBatch(context.Background(), conn.MockClient(t))