diff --git a/pkg/migrations/op_common.go b/pkg/migrations/op_common.go index 9ca96639..0b54d8fc 100644 --- a/pkg/migrations/op_common.go +++ b/pkg/migrations/op_common.go @@ -19,6 +19,7 @@ const ( OpNameCreateIndex OpName = "create_index" OpNameDropIndex OpName = "drop_index" OpNameRenameColumn OpName = "rename_column" + OpNameSetUnique OpName = "set_unique" ) func TemporaryName(name string) string { @@ -98,6 +99,9 @@ func (v *Operations) UnmarshalJSON(data []byte) error { case OpNameDropIndex: item = &OpDropIndex{} + case OpNameSetUnique: + item = &OpSetUnique{} + default: return fmt.Errorf("unknown migration type: %v", opName) } @@ -154,6 +158,9 @@ func (v Operations) MarshalJSON() ([]byte, error) { case *OpDropIndex: opName = OpNameDropIndex + case *OpSetUnique: + opName = OpNameSetUnique + default: panic(fmt.Errorf("unknown operation for %T", op)) } diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go new file mode 100644 index 00000000..9d5be29f --- /dev/null +++ b/pkg/migrations/op_set_unique.go @@ -0,0 +1,66 @@ +package migrations + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + + "pg-roll/pkg/schema" +) + +type OpSetUnique struct { + Table string `json:"table"` + Columns []string `json:"columns"` +} + +var _ Operation = (*OpSetUnique)(nil) + +func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, schemaName string, stateSchema string, s *schema.Schema) error { + // create unique index concurrently + _, err := conn.ExecContext(ctx, fmt.Sprintf("CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS %s ON %s (%s)", + pq.QuoteIdentifier(IndexName(o.Table, o.Columns)), + pq.QuoteIdentifier(o.Table), + strings.Join(quoteColumnNames(o.Columns), ", "))) + return err +} + +func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB) error { + // create a unique constraint using the unique index + _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s", + pq.QuoteIdentifier(o.Table), + pq.QuoteIdentifier(UniqueConstraintName(o.Table, o.Columns)), + pq.QuoteIdentifier(IndexName(o.Table, o.Columns)))) + + return err +} + +func (o *OpSetUnique) Rollback(ctx context.Context, conn *sql.DB) error { + // drop the index concurrently + _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", + IndexName(o.Table, o.Columns))) + + return err +} + +func (o *OpSetUnique) Validate(ctx context.Context, s *schema.Schema) error { + table := s.GetTable(o.Table) + + if table == nil { + return TableDoesNotExistError{Name: o.Table} + } + + for _, column := range o.Columns { + if table.GetColumn(column) == nil { + return ColumnDoesNotExistError{Table: o.Table, Name: column} + } + } + + return nil +} + +func UniqueConstraintName(table string, columns []string) string { + return "_pgroll_unique_" + table + "_" + strings.Join(columns, "_") +}