diff --git a/pkg/migrations/op_common.go b/pkg/migrations/op_common.go index 5958ef0d..9ca96639 100644 --- a/pkg/migrations/op_common.go +++ b/pkg/migrations/op_common.go @@ -11,13 +11,14 @@ import ( type OpName string const ( - OpNameCreateTable OpName = "create_table" - OpNameRenameTable OpName = "rename_table" - OpNameDropTable OpName = "drop_table" - OpNameAddColumn OpName = "add_column" - OpNameDropColumn OpName = "drop_column" - OpNameCreateIndex OpName = "create_index" - OpNameDropIndex OpName = "drop_index" + OpNameCreateTable OpName = "create_table" + OpNameRenameTable OpName = "rename_table" + OpNameDropTable OpName = "drop_table" + OpNameAddColumn OpName = "add_column" + OpNameDropColumn OpName = "drop_column" + OpNameCreateIndex OpName = "create_index" + OpNameDropIndex OpName = "drop_index" + OpNameRenameColumn OpName = "rename_column" ) func TemporaryName(name string) string { @@ -88,6 +89,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error { case OpNameDropColumn: item = &OpDropColumn{} + case OpNameRenameColumn: + item = &OpRenameColumn{} + case OpNameCreateIndex: item = &OpCreateIndex{} @@ -141,6 +145,9 @@ func (v Operations) MarshalJSON() ([]byte, error) { case *OpDropColumn: opName = OpNameDropColumn + case *OpRenameColumn: + opName = OpNameRenameColumn + case *OpCreateIndex: opName = OpNameCreateIndex diff --git a/pkg/migrations/op_rename_column.go b/pkg/migrations/op_rename_column.go new file mode 100644 index 00000000..77e0e97a --- /dev/null +++ b/pkg/migrations/op_rename_column.go @@ -0,0 +1,56 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + + "pg-roll/pkg/schema" +) + +type OpRenameColumn struct { + Table string `json:"table"` + From string `json:"from"` + To string `json:"to"` +} + +var _ Operation = (*OpRenameColumn)(nil) + +func (o *OpRenameColumn) Start(ctx context.Context, conn *sql.DB, schemaName string, stateSchema string, s *schema.Schema) error { + table := s.GetTable(o.Table) + table.RenameColumn(o.From, o.To) + return nil +} + +func (o *OpRenameColumn) Complete(ctx context.Context, conn *sql.DB) error { + // rename the column in the underlying table + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(o.From), + pq.QuoteIdentifier(o.To))) + return err +} + +func (o *OpRenameColumn) Rollback(ctx context.Context, conn *sql.DB) error { + // no-op + return nil +} + +func (o *OpRenameColumn) Validate(ctx context.Context, s *schema.Schema) error { + table := s.GetTable(o.Table) + if table == nil { + return TableDoesNotExistError{Name: o.Table} + } + + if table.GetColumn(o.From) == nil { + return ColumnDoesNotExistError{Table: o.Table, Name: o.From} + } + + if table.GetColumn(o.To) != nil { + return ColumnAlreadyExistsError{Table: o.Table, Name: o.From} + } + + return nil +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 801a29a3..da42ab27 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -106,3 +106,8 @@ func (t *Table) AddColumn(name string, c Column) { func (t *Table) RemoveColumn(column string) { delete(t.Columns, column) } + +func (t *Table) RenameColumn(from, to string) { + t.Columns[to] = t.Columns[from] + delete(t.Columns, from) +}