Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/migration/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ func FindPendingMigrations(localMigrations, remoteMigrations []string) ([]string
remote := remoteMigrations[i]
filename := filepath.Base(localMigrations[j])
// Check if migration has been applied before, LoadLocalMigrations guarantees a match
local := migrateFilePattern.FindStringSubmatch(filename)[1]
matches := migrateFilePattern.FindStringSubmatch(filename)
if len(matches) < 2 {
return nil, errors.Errorf("invalid migration filename: %s", filename)
}
local := matches[1]
if remote == local {
j++
i++
Expand Down
157 changes: 124 additions & 33 deletions pkg/migration/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
return nil, err
}
file := MigrationFile{Statements: lines}
// Parse version from file name
filename := filepath.Base(path)
matches := migrateFilePattern.FindStringSubmatch(filename)
if len(matches) > 2 {
Expand All @@ -50,7 +49,7 @@
return nil, errors.Errorf("failed to open migration file: %w", err)
}
defer sql.Close()
// Unless explicitly specified, Use file length as max buffer size

if !viper.IsSet("SCANNER_BUFFER_SIZE") {
if fi, err := sql.Stat(); err == nil {
if size := int(fi.Size()); size > parser.MaxScannerCapacity {
Expand All @@ -69,38 +68,81 @@
return &MigrationFile{Statements: lines}, nil
}

// -----------------------------------------------------------------------------
// ExecBatch
// -----------------------------------------------------------------------------

func (m *MigrationFile) ExecBatch(ctx context.Context, conn *pgx.Conn) error {
// Batch migration commands, without using statement cache
batch := &pgconn.Batch{}
for _, line := range m.Statements {
batch.ExecParams(line, nil, nil, nil, nil)
inTx := false

for i, stmt := range m.Statements {
if isNonTransactional(stmt) {

if inTx {
if _, err := conn.Exec(ctx, "COMMIT"); err != nil {
_, _ = conn.Exec(ctx, "ROLLBACK")
return errors.Errorf("failed to commit transaction before non-transactional statement: %v", err)
}
inTx = false
}

if _, err := conn.Exec(ctx, stmt); err != nil {
return formatStatementError(i, stmt, err)
}

} else {

if !inTx {
if _, err := conn.Exec(ctx, "BEGIN"); err != nil {
return errors.Errorf("failed to begin transaction: %v", err)
}
inTx = true
}

if _, err := conn.Exec(ctx, stmt); err != nil {
if inTx {
_, _ = conn.Exec(ctx, "ROLLBACK")
inTx = false
}
return formatStatementError(i, stmt, err)
}
}
}
// Insert into migration history

if inTx {
if _, err := conn.Exec(ctx, "COMMIT"); err != nil {
_, _ = conn.Exec(ctx, "ROLLBACK")
return errors.Errorf("failed to commit transaction: %v", err)
}
}

if len(m.Version) > 0 {
if err := m.insertVersionSQL(conn, batch); err != nil {
if err := m.insertVersionExec(ctx, conn); err != nil {
return err
}
}
// ExecBatch is implicitly transactional
if result, err := conn.PgConn().ExecBatch(ctx, batch).ReadAll(); err != nil {
// Defaults to printing the last statement on error
stat := INSERT_MIGRATION_VERSION
i := len(result)
if i < len(m.Statements) {
stat = m.Statements[i]
}
var msg []string
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
stat = markError(stat, int(pgErr.Position))
if len(pgErr.Detail) > 0 {
msg = append(msg, pgErr.Detail)
}

return nil
}

// -----------------------------------------------------------------------------
// Error Formatting
// -----------------------------------------------------------------------------

func formatStatementError(idx int, stmt string, err error) error {
var msg []string
var pgErr *pgconn.PgError

if errors.As(err, &pgErr) {
stat := markError(stmt, int(pgErr.Position))
if len(pgErr.Detail) > 0 {
msg = append(msg, pgErr.Detail)
}
msg = append(msg, fmt.Sprintf("At statement: %d", i), stat)
msg = append(msg, fmt.Sprintf("At statement: %d", idx), stat)
return errors.Errorf("%w\n%s", err, strings.Join(msg, "\n"))
}
return nil

return err
}

func markError(stat string, pos int) string {
Expand All @@ -110,7 +152,6 @@
pos -= c + 1
continue
}
// Show a caret below the error position
if pos > 0 {
caret := append(bytes.Repeat([]byte{' '}, pos-1), '^')
lines = append(lines[:j+1], string(caret))
Expand All @@ -120,35 +161,82 @@
return strings.Join(lines, "\n")
}

func (m *MigrationFile) insertVersionSQL(conn *pgx.Conn, batch *pgconn.Batch) error {
// -----------------------------------------------------------------------------
// ⭐ insertVersionExec — FIXED FOR PGX V5 ⭐
// -----------------------------------------------------------------------------

func (m *MigrationFile) insertVersionExec(ctx context.Context, conn *pgx.Conn) error {
value := pgtype.TextArray{}
if err := value.Set(m.Statements); err != nil {
return errors.Errorf("failed to set text array: %w", err)
}

ci := conn.ConnInfo()
var err error
var encoded []byte
var valueFormat int16

if conn.Config().PreferSimpleProtocol {
encoded, err = value.EncodeText(ci, encoded)
encoded, err = value.EncodeText(ci, nil)
valueFormat = pgtype.TextFormatCode
} else {
encoded, err = value.EncodeBinary(ci, encoded)
encoded, err = value.EncodeBinary(ci, nil)
valueFormat = pgtype.BinaryFormatCode
}

if err != nil {
return errors.Errorf("failed to encode binary: %w", err)
}
batch.ExecParams(

// ExecParams returns a MultiResultReader WITHOUT Err() in pgx v5
mrr := conn.PgConn().ExecParams(
ctx,
INSERT_MIGRATION_VERSION,
[][]byte{[]byte(m.Version), []byte(m.Name), encoded},
[]uint32{pgtype.TextOID, pgtype.TextOID, pgtype.TextArrayOID},
[]int16{pgtype.TextFormatCode, pgtype.TextFormatCode, valueFormat},
nil,
)

for {
_, readErr := mrr.Read()

Check failure on line 202 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

assignment mismatch: 2 variables but mrr.Read returns 1 value

Check failure on line 202 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

assignment mismatch: 2 variables but mrr.Read returns 1 value

Check failure on line 202 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

assignment mismatch: 2 variables but mrr.Read returns 1 value

Check failure on line 202 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Test

assignment mismatch: 2 variables but mrr.Read returns 1 value

Check failure on line 202 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Link

assignment mismatch: 2 variables but mrr.Read returns 1 value

Check failure on line 202 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Start

assignment mismatch: 2 variables but mrr.Read returns 1 value
if readErr == pgconn.ErrNoMoreResults {

Check failure on line 203 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

undefined: pgconn.ErrNoMoreResults

Check failure on line 203 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

undefined: pgconn.ErrNoMoreResults

Check failure on line 203 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Test

undefined: pgconn.ErrNoMoreResults

Check failure on line 203 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Link

undefined: pgconn.ErrNoMoreResults

Check failure on line 203 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Start

undefined: pgconn.ErrNoMoreResults
break
}
if readErr != nil {
return errors.Errorf("failed to insert migration version: %w", readErr)
}
}

if err := mrr.Close(); err != nil {

Check failure on line 211 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

assignment mismatch: 1 variable but mrr.Close returns 2 values) (typecheck)

Check failure on line 211 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Lint

assignment mismatch: 1 variable but mrr.Close returns 2 values) (typecheck)

Check failure on line 211 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Test

assignment mismatch: 1 variable but mrr.Close returns 2 values

Check failure on line 211 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Link

assignment mismatch: 1 variable but mrr.Close returns 2 values

Check failure on line 211 in pkg/migration/file.go

View workflow job for this annotation

GitHub Actions / Start

assignment mismatch: 1 variable but mrr.Close returns 2 values
return errors.Errorf("failed to insert migration version: %w", err)
}

return nil
}

// -----------------------------------------------------------------------------
// Non-transactional detection
// -----------------------------------------------------------------------------

func isNonTransactional(stmt string) bool {
upper := strings.ToUpper(stmt)

if strings.Contains(upper, "CONCURRENTLY") {
return true
}

if regexp.MustCompile(`(?i)ALTER\s+TYPE\s+.+\s+ADD\s+VALUE`).MatchString(stmt) {
return true
}

return false
}

// -----------------------------------------------------------------------------
// Seed File
// -----------------------------------------------------------------------------

type SeedFile struct {
Path string
Hash string
Expand All @@ -161,31 +249,34 @@
return nil, errors.Errorf("failed to open seed file: %w", err)
}
defer sql.Close()

hash := sha256.New()
if _, err := io.Copy(hash, sql); err != nil {
return nil, errors.Errorf("failed to hash file: %w", err)
}
digest := hex.EncodeToString(hash.Sum(nil))

return &SeedFile{Path: path, Hash: digest}, nil
}

func (m *SeedFile) ExecBatchWithCache(ctx context.Context, conn *pgx.Conn, fsys fs.FS) error {
// Parse each file individually to reduce memory usage
lines, err := parseFile(m.Path, fsys)
if err != nil {
return err
}
// Data statements don't mutate schemas, safe to use statement cache

batch := pgx.Batch{}
if !m.Dirty {
for _, line := range lines {
batch.Queue(line)
}
}

batch.Queue(UPSERT_SEED_FILE, m.Path, m.Hash)
// No need to track version here because there are no schema changes

if err := conn.SendBatch(ctx, &batch).Close(); err != nil {
return errors.Errorf("failed to send batch: %w", err)
}

return nil
}
10 changes: 6 additions & 4 deletions pkg/migration/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ func TestMigrationFile(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(migration.Statements[0]).
// Expect Exec for statement execution and Exec for migration version insert
conn.Exec(migration.Statements[0]).
Reply("CREATE SCHEMA").
Query(INSERT_MIGRATION_VERSION, "0", "", migration.Statements).
Exec(INSERT_MIGRATION_VERSION, "0", "", migration.Statements).
Reply("INSERT 0 1")
// Run test
err := migration.ExecBatch(context.Background(), conn.MockClient(t))
Expand All @@ -67,9 +68,10 @@ func TestMigrationFile(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(migration.Statements[0]).
// Simulate error on executing the statement (Exec), and then the insert attempt
conn.Exec(migration.Statements[0]).
ReplyError(pgerrcode.DuplicateSchema, `schema "public" already exists`).
Query(INSERT_MIGRATION_VERSION, "0", "", migration.Statements).
Exec(INSERT_MIGRATION_VERSION, "0", "", migration.Statements).
Reply("INSERT 0 1")
// Run test
err := migration.ExecBatch(context.Background(), conn.MockClient(t))
Expand Down
Loading