diff --git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go index b0cad98a..35ef729f 100644 --- a/pkg/migrations/backfill.go +++ b/pkg/migrations/backfill.go @@ -18,7 +18,7 @@ import ( // 2. Get the first batch of rows from the table, ordered by the primary key. // 3. Update each row in the batch, setting the value of the primary key column to itself. // 4. Repeat steps 2 and 3 until no more rows are returned. -func backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error { +func Backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error { // get the backfill column identityColumn := getIdentityColumn(table) if identityColumn == nil { diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index 7182478c..a8d422e5 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -17,7 +17,8 @@ type Operation interface { // Start will apply the required changes to enable supporting the new schema // version in the database (through a view) // update the given views to expose the new schema version - Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error + // Returns the table that requires backfilling, if any. + Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) // Complete will update the database schema to match the current version // after calling Start. diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index df8d9f75..c08999f6 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -15,31 +15,32 @@ import ( var _ Operation = (*OpAddColumn)(nil) -func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) if err := addColumn(ctx, conn, *o, table); err != nil { - return fmt.Errorf("failed to start add column operation: %w", err) + return nil, fmt.Errorf("failed to start add column operation: %w", err) } if o.Column.Comment != nil { if err := addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column.Name), *o.Column.Comment); err != nil { - return fmt.Errorf("failed to add comment to column: %w", err) + return nil, fmt.Errorf("failed to add comment to column: %w", err) } } if !o.Column.IsNullable() && o.Column.Default == nil { if err := addNotNullConstraint(ctx, conn, o.Table, o.Column.Name, TemporaryName(o.Column.Name)); err != nil { - return fmt.Errorf("failed to add not null constraint: %w", err) + return nil, fmt.Errorf("failed to add not null constraint: %w", err) } } if o.Column.Check != nil { if err := o.addCheckConstraint(ctx, conn); err != nil { - return fmt.Errorf("failed to add check constraint: %w", err) + return nil, fmt.Errorf("failed to add check constraint: %w", err) } } + var tableToBackfill *schema.Table if o.Up != nil { err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, o.Column.Name), @@ -52,18 +53,16 @@ func (o *OpAddColumn) Start(ctx context.Context, conn *sql.DB, stateSchema strin SQL: *o.Up, }) if err != nil { - return fmt.Errorf("failed to create trigger: %w", err) - } - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create trigger: %w", err) } + tableToBackfill = table } table.AddColumn(o.Column.Name, schema.Column{ Name: TemporaryName(o.Column.Name), }) - return nil + return tableToBackfill, nil } func (o *OpAddColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index ca2474fa..9ede6b84 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -11,7 +11,7 @@ import ( var _ Operation = (*OpAlterColumn)(nil) -func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpAlterColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { op := o.innerOperation() return op.Start(ctx, conn, stateSchema, s, cbs...) diff --git a/pkg/migrations/op_change_type.go b/pkg/migrations/op_change_type.go index 9a6cd2e2..f50d5a9e 100644 --- a/pkg/migrations/op_change_type.go +++ b/pkg/migrations/op_change_type.go @@ -21,14 +21,14 @@ type OpChangeType struct { var _ Operation = (*OpChangeType)(nil) -func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column).WithType(o.Type) if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -43,12 +43,7 @@ func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema stri SQL: o.Up, }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -70,10 +65,10 @@ func (o *OpChangeType) Start(ctx context.Context, conn *sql.DB, stateSchema stri SQL: o.Down, }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpChangeType) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_create_index.go b/pkg/migrations/op_create_index.go index 49396cb2..8316ec66 100644 --- a/pkg/migrations/op_create_index.go +++ b/pkg/migrations/op_create_index.go @@ -14,13 +14,13 @@ import ( var _ Operation = (*OpCreateIndex)(nil) -func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpCreateIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // create index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", pq.QuoteIdentifier(o.Name), pq.QuoteIdentifier(o.Table), strings.Join(quoteColumnNames(o.Columns), ", "))) - return err + return nil, err } func (o *OpCreateIndex) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index b5d73fc8..8fcf9401 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -13,20 +13,20 @@ import ( var _ Operation = (*OpCreateTable)(nil) -func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { tempName := TemporaryName(o.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s (%s)", pq.QuoteIdentifier(tempName), columnsToSQL(o.Columns))) if err != nil { - return err + return nil, err } // Add comments to any columns that have them for _, col := range o.Columns { if col.Comment != nil { if err := addCommentToColumn(ctx, conn, tempName, col.Name, *col.Comment); err != nil { - return fmt.Errorf("failed to add comment to column: %w", err) + return nil, fmt.Errorf("failed to add comment to column: %w", err) } } } @@ -34,7 +34,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str // Add comment to the table itself if o.Comment != nil { if err := addCommentToTable(ctx, conn, tempName, *o.Comment); err != nil { - return fmt.Errorf("failed to add comment to table: %w", err) + return nil, fmt.Errorf("failed to add comment to table: %w", err) } } @@ -50,7 +50,7 @@ func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema str Columns: columns, }) - return nil + return nil, nil } func (o *OpCreateTable) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index 61e75915..27a50e15 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -13,7 +13,7 @@ import ( var _ Operation = (*OpDropColumn)(nil) -func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { if o.Down != nil { err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, o.Column), @@ -26,12 +26,12 @@ func (o *OpDropColumn) Start(ctx context.Context, conn *sql.DB, stateSchema stri SQL: *o.Down, }) if err != nil { - return err + return nil, err } } s.GetTable(o.Table).RemoveColumn(o.Column) - return nil + return nil, nil } func (o *OpDropColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index a650753b..bf3fcfc9 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -13,14 +13,14 @@ import ( var _ Operation = (*OpDropConstraint)(nil) -func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column).WithoutConstraint(o.Name) if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -35,12 +35,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema SQL: o.upSQL(), }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -62,9 +57,9 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema SQL: o.Down, }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpDropConstraint) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_index.go b/pkg/migrations/op_drop_index.go index c360579a..5a437dd4 100644 --- a/pkg/migrations/op_drop_index.go +++ b/pkg/migrations/op_drop_index.go @@ -12,9 +12,9 @@ import ( var _ Operation = (*OpDropIndex)(nil) -func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpDropIndex) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // no-op - return nil + return nil, nil } func (o *OpDropIndex) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_not_null.go b/pkg/migrations/op_drop_not_null.go index 5d502a9d..d0147d8a 100644 --- a/pkg/migrations/op_drop_not_null.go +++ b/pkg/migrations/op_drop_not_null.go @@ -20,14 +20,14 @@ type OpDropNotNull struct { var _ Operation = (*OpDropNotNull)(nil) -func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column).WithoutNotNull() if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -42,12 +42,7 @@ func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema str SQL: o.upSQL(), }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -69,10 +64,10 @@ func (o *OpDropNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema str SQL: o.Down, }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpDropNotNull) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_drop_table.go b/pkg/migrations/op_drop_table.go index f43b1ca0..a6fc316b 100644 --- a/pkg/migrations/op_drop_table.go +++ b/pkg/migrations/op_drop_table.go @@ -13,9 +13,9 @@ import ( var _ Operation = (*OpDropTable)(nil) -func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpDropTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { s.RemoveTable(o.Name) - return nil + return nil, nil } func (o *OpDropTable) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_raw_sql.go b/pkg/migrations/op_raw_sql.go index 1719ed9b..bb6ec8c0 100644 --- a/pkg/migrations/op_raw_sql.go +++ b/pkg/migrations/op_raw_sql.go @@ -11,12 +11,12 @@ import ( var _ Operation = (*OpRawSQL)(nil) -func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpRawSQL) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { if !o.OnComplete { _, err := conn.ExecContext(ctx, o.Up) - return err + return nil, err } - return nil + return nil, nil } func (o *OpRawSQL) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_rename_column.go b/pkg/migrations/op_rename_column.go index 063cee49..272aef57 100644 --- a/pkg/migrations/op_rename_column.go +++ b/pkg/migrations/op_rename_column.go @@ -19,10 +19,10 @@ type OpRenameColumn struct { var _ Operation = (*OpRenameColumn)(nil) -func (o *OpRenameColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpRenameColumn) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) table.RenameColumn(o.From, o.To) - return nil + return nil, nil } func (o *OpRenameColumn) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_rename_table.go b/pkg/migrations/op_rename_table.go index ae10b174..afba756c 100644 --- a/pkg/migrations/op_rename_table.go +++ b/pkg/migrations/op_rename_table.go @@ -13,8 +13,8 @@ import ( var _ Operation = (*OpRenameTable)(nil) -func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { - return s.RenameTable(o.From, o.To) +func (o *OpRenameTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { + return nil, s.RenameTable(o.From, o.To) } func (o *OpRenameTable) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index 497db2f8..975bb2a0 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -22,19 +22,19 @@ type OpSetCheckConstraint struct { var _ Operation = (*OpSetCheckConstraint)(nil) -func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column) if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Add the check constraint to the new column as NOT VALID. if err := o.addCheckConstraint(ctx, conn); err != nil { - return fmt.Errorf("failed to add check constraint: %w", err) + return nil, fmt.Errorf("failed to add check constraint: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -49,12 +49,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch SQL: o.Up, }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -76,9 +71,9 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch SQL: o.Down, }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_fk.go b/pkg/migrations/op_set_fk.go index 0308c9f8..1137ad96 100644 --- a/pkg/migrations/op_set_fk.go +++ b/pkg/migrations/op_set_fk.go @@ -21,19 +21,19 @@ type OpSetForeignKey struct { var _ Operation = (*OpSetForeignKey)(nil) -func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column) if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Create a NOT VALID foreign key constraint on the new column. if err := o.addForeignKeyConstraint(ctx, conn); err != nil { - return fmt.Errorf("failed to add foreign key constraint: %w", err) + return nil, fmt.Errorf("failed to add foreign key constraint: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -48,12 +48,7 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema s SQL: o.Up, }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -75,10 +70,10 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn *sql.DB, stateSchema s SQL: o.Down, }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpSetForeignKey) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_notnull.go b/pkg/migrations/op_set_notnull.go index 1f8541e4..3d6e938f 100644 --- a/pkg/migrations/op_set_notnull.go +++ b/pkg/migrations/op_set_notnull.go @@ -20,19 +20,19 @@ type OpSetNotNull struct { var _ Operation = (*OpSetNotNull)(nil) -func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column) if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Add an unchecked NOT NULL constraint to the new column. if err := addNotNullConstraint(ctx, conn, o.Table, o.Column, TemporaryName(o.Column)); err != nil { - return fmt.Errorf("failed to add not null constraint: %w", err) + return nil, fmt.Errorf("failed to add not null constraint: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -47,12 +47,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri SQL: o.Up, }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -74,10 +69,10 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri SQL: o.downSQL(), }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_replica_identity.go b/pkg/migrations/op_set_replica_identity.go index 425ee3d1..c2c2370d 100644 --- a/pkg/migrations/op_set_replica_identity.go +++ b/pkg/migrations/op_set_replica_identity.go @@ -15,7 +15,7 @@ import ( var _ Operation = (*OpSetReplicaIdentity)(nil) -func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { // build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type identitySQL := strings.ToUpper(o.Identity.Type) if identitySQL == "INDEX" { @@ -26,7 +26,7 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn *sql.DB, stateSch _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s REPLICA IDENTITY %s", pq.QuoteIdentifier(o.Table), identitySQL)) - return err + return nil, err } func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index 3b61494f..c6b6f597 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -21,19 +21,19 @@ type OpSetUnique struct { var _ Operation = (*OpSetUnique)(nil) -func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) error { +func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema, cbs ...CallbackFn) (*schema.Table, error) { table := s.GetTable(o.Table) column := table.GetColumn(o.Column) // create a copy of the column on the underlying table. d := NewColumnDuplicator(conn, table, column) if err := d.Duplicate(ctx); err != nil { - return fmt.Errorf("failed to duplicate column: %w", err) + return nil, fmt.Errorf("failed to duplicate column: %w", err) } // Add a unique index to the new column if err := o.addUniqueIndex(ctx, conn); err != nil { - return fmt.Errorf("failed to add unique index: %w", err) + return nil, fmt.Errorf("failed to add unique index: %w", err) } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. @@ -48,12 +48,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin SQL: o.Up, }) if err != nil { - return fmt.Errorf("failed to create up trigger: %w", err) - } - - // Backfill the new column with values from the old column. - if err := backfill(ctx, conn, table, cbs...); err != nil { - return fmt.Errorf("failed to backfill column: %w", err) + return nil, fmt.Errorf("failed to create up trigger: %w", err) } // Add the new column to the internal schema representation. This is done @@ -75,10 +70,10 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin SQL: o.downSQL(), }) if err != nil { - return fmt.Errorf("failed to create down trigger: %w", err) + return nil, fmt.Errorf("failed to create down trigger: %w", err) } - return nil + return table, nil } func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, s *schema.Schema) error { diff --git a/pkg/roll/execute.go b/pkg/roll/execute.go index 03374246..9578ceee 100644 --- a/pkg/roll/execute.go +++ b/pkg/roll/execute.go @@ -40,8 +40,9 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs . } // execute operations + var tablesToBackfill []*schema.Table for _, op := range migration.Operations { - err := op.Start(ctx, m.pgConn, m.state.Schema(), newSchema, cbs...) + table, err := op.Start(ctx, m.pgConn, m.state.Schema(), newSchema, cbs...) if err != nil { errRollback := m.Rollback(ctx) @@ -60,6 +61,16 @@ func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs . } } } + if table != nil { + tablesToBackfill = append(tablesToBackfill, table) + } + } + + // perform backfill operations for those tables that require it + for _, table := range tablesToBackfill { + if err := migrations.Backfill(ctx, m.pgConn, table, cbs...); err != nil { + return fmt.Errorf("unable to backfill table %q: %w", table.Name, err) + } } if m.disableVersionSchemas {