diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index 627c397a..6dced048 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -41,9 +41,10 @@ func (d *Duplicator) WithType(t string) *Duplicator { // constraints as the original column. func (d *Duplicator) Duplicate(ctx context.Context) error { const ( - cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s` - cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s` - cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)` + cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s` + cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s` + cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)` + cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID` ) // Generate SQL to duplicate the column's name and type @@ -68,6 +69,16 @@ func (d *Duplicator) Duplicate(ctx context.Context) error { } } + // Generate SQL to duplicate any check constraints on the column + for _, cc := range d.table.CheckConstraints { + if slices.Contains(cc.Columns, d.column.Name) { + sql += fmt.Sprintf(", "+cAddCheckConstraintSQL, + pq.QuoteIdentifier(DuplicationName(cc.Name)), + rewriteCheckExpression(cc.Definition, d.column.Name, d.asName), + ) + } + } + _, err := d.conn.ExecContext(ctx, sql) return err diff --git a/pkg/migrations/op_change_type_test.go b/pkg/migrations/op_change_type_test.go index 4e6baf7c..3de7f4d9 100644 --- a/pkg/migrations/op_change_type_test.go +++ b/pkg/migrations/op_change_type_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" + "github.com/xataio/pgroll/pkg/testutils" ) func TestChangeColumnType(t *testing.T) { @@ -287,6 +288,63 @@ func TestChangeColumnType(t *testing.T) { }, rows) }, }, + { + name: "changing column type preserves any check constraints on the column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "integer", + Pk: true, + }, + { + Name: "username", + Type: "text", + Nullable: true, + Check: &migrations.CheckConstraint{ + Name: "username_length", + Constraint: "length(username) > 3", + }, + }, + }, + }, + }, + }, + { + Name: "02_change_type", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "users", + Column: "username", + Type: "varchar(255)", + Up: "username", + Down: "username", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{ + "id": "1", + "username": "a", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{ + "id": "2", + "username": "b", + }, testutils.CheckViolationErrorCode) + }, + }, }) } diff --git a/pkg/migrations/op_set_check_test.go b/pkg/migrations/op_set_check_test.go index 7bf4b724..ecf959ee 100644 --- a/pkg/migrations/op_set_check_test.go +++ b/pkg/migrations/op_set_check_test.go @@ -294,6 +294,71 @@ func TestSetCheckConstraint(t *testing.T) { ValidatedForeignKeyMustExist(t, db, "public", "employees", "fk_employee_department") }, }, + { + name: "existing check constraints are preserved when adding a check constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "title", + Type: "text", + Check: &migrations.CheckConstraint{ + Name: "check_title_length", + Constraint: "length(title) > 3", + }, + }, + { + Name: "body", + Type: "text", + }, + }, + }, + }, + }, + { + Name: "02_add_check_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "posts", + Column: "body", + Check: &migrations.CheckConstraint{ + Name: "check_body_length", + Constraint: "length(body) > 3", + }, + Up: "(SELECT CASE WHEN length(body) <= 3 THEN LPAD(body, 4, '-') ELSE body END)", + Down: "body", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // The check constraint on the `title` column still exists. + MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "id": "1", + "title": "a", + "body": "this is the post body", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The check constraint on the `title` column still exists. + MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "id": "2", + "title": "b", + "body": "this is another post body", + }, testutils.CheckViolationErrorCode) + }, + }, }) } diff --git a/pkg/migrations/op_set_fk_test.go b/pkg/migrations/op_set_fk_test.go index 709a4daa..dcdbdab8 100644 --- a/pkg/migrations/op_set_fk_test.go +++ b/pkg/migrations/op_set_fk_test.go @@ -357,6 +357,91 @@ func TestSetForeignKey(t *testing.T) { ValidatedForeignKeyMustExist(t, db, "public", "posts", "fk_users_id_1") }, }, + { + name: "check constraints on a column are preserved when adding a foreign key constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_tables", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + }, + }, + }, + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "title", + Type: "text", + Check: &migrations.CheckConstraint{ + Name: "title_length", + Constraint: "length(title) > 3", + }, + }, + { + Name: "user_id", + Type: "integer", + }, + }, + }, + }, + }, + { + Name: "02_add_fk_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "posts", + Column: "user_id", + References: &migrations.ForeignKeyReference{ + Name: "fk_users_id", + Table: "users", + Column: "id", + }, + Up: "(SELECT CASE WHEN EXISTS (SELECT 1 FROM users WHERE users.id = user_id) THEN user_id ELSE NULL END)", + Down: "user_id", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Set up the users table with a reference row + MustInsert(t, db, "public", "02_add_fk_constraint", "users", map[string]string{ + "name": "alice", + }) + + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "id": "1", + "user_id": "1", + "title": "a", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_add_fk_constraint", "posts", map[string]string{ + "id": "2", + "user_id": "1", + "title": "b", + }, testutils.CheckViolationErrorCode) + }, + }, }) } diff --git a/pkg/migrations/op_set_notnull_test.go b/pkg/migrations/op_set_notnull_test.go index d619386c..b8098924 100644 --- a/pkg/migrations/op_set_notnull_test.go +++ b/pkg/migrations/op_set_notnull_test.go @@ -301,7 +301,7 @@ func TestSetNotNull(t *testing.T) { }, }, { - name: "setting a nullable column to not null retains any default defined on the column", + name: "setting a column to not null retains any default defined on the column", migrations: []migrations.Migration{ { Name: "01_add_table", @@ -364,6 +364,62 @@ func TestSetNotNull(t *testing.T) { }, rows) }, }, + { + name: "setting a column to not null retains any check constraints defined on the column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "integer", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: true, + Check: &migrations.CheckConstraint{ + Name: "name_length", + Constraint: "length(name) > 3", + }, + }, + }, + }, + }, + }, + { + Name: "02_set_not_null", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "users", + Column: "name", + Nullable: ptr(false), + Up: "(SELECT CASE WHEN name IS NULL THEN 'anonymous' ELSE name END)", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_set_not_null", "users", map[string]string{ + "id": "1", + "name": "a", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_set_not_null", "users", map[string]string{ + "id": "2", + "name": "b", + }, testutils.CheckViolationErrorCode) + }, + }, }) } diff --git a/pkg/migrations/op_set_unique_test.go b/pkg/migrations/op_set_unique_test.go index aca893db..6e732401 100644 --- a/pkg/migrations/op_set_unique_test.go +++ b/pkg/migrations/op_set_unique_test.go @@ -339,5 +339,67 @@ func TestSetColumnUnique(t *testing.T) { ValidatedForeignKeyMustExist(t, db, "public", "employees", "fk_employee_department") }, }, + { + name: "check constraints are preserved when adding a unique constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "reviews", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "username", + Type: "text", + }, + { + Name: "review", + Type: "text", + Check: &migrations.CheckConstraint{ + Name: "reviews_review_check", + Constraint: "length(review) > 3", + }, + }, + }, + }, + }, + }, + { + Name: "02_set_unique", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "reviews", + Column: "username", + Unique: &migrations.UniqueConstraint{ + Name: "reviews_username_unique", + }, + Up: "username", + Down: "username", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "alice", + "review": "x", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that violates the check constraint should fail. + MustNotInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "bob", + "review": "y", + }, testutils.CheckViolationErrorCode) + }, + }, }) } diff --git a/pkg/migrations/rename.go b/pkg/migrations/rename.go index 7ac9c2be..bbda06a6 100644 --- a/pkg/migrations/rename.go +++ b/pkg/migrations/rename.go @@ -12,12 +12,15 @@ import ( "github.com/xataio/pgroll/pkg/schema" ) -// RenameDuplicatedColumn renames a duplicated column to its original name and renames any foreign keys -// on the duplicated column to their original name. +// RenameDuplicatedColumn: +// * renames a duplicated column to its original name +// * renames any foreign keys on the duplicated column to their original name. +// * Validates and renames any temporary `CHECK` constraints on the duplicated column. func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error { const ( - cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s` - cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s` + cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s` + cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s` + cValidateConstraintSQL = `ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s` ) // Rename the old column to the new column name @@ -33,14 +36,13 @@ func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Tab // Rename any foreign keys on the duplicated column from their temporary name // to their original name - var renameConstraintSQL string for _, fk := range table.ForeignKeys { if !IsDuplicatedName(fk.Name) { continue } if slices.Contains(fk.Columns, TemporaryName(column.Name)) { - renameConstraintSQL = fmt.Sprintf(cRenameConstraintSQL, + renameConstraintSQL := fmt.Sprintf(cRenameConstraintSQL, pq.QuoteIdentifier(table.Name), pq.QuoteIdentifier(fk.Name), pq.QuoteIdentifier(StripDuplicationPrefix(fk.Name)), @@ -48,7 +50,38 @@ func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Tab _, err = conn.ExecContext(ctx, renameConstraintSQL) if err != nil { - return fmt.Errorf("failed to rename column constraint %q: %w", fk.Name, err) + return fmt.Errorf("failed to rename foreign key constraint %q: %w", fk.Name, err) + } + } + } + + // Validate and rename any temporary `CHECK` constraints on the duplicated + // column. + for _, cc := range table.CheckConstraints { + if !IsDuplicatedName(cc.Name) { + continue + } + + if slices.Contains(cc.Columns, TemporaryName(column.Name)) { + validateConstraintSQL := fmt.Sprintf(cValidateConstraintSQL, + pq.QuoteIdentifier(table.Name), + pq.QuoteIdentifier(cc.Name), + ) + + _, err = conn.ExecContext(ctx, validateConstraintSQL) + if err != nil { + return fmt.Errorf("failed to validate check constraint %q: %w", cc.Name, err) + } + + renameConstraintSQL := fmt.Sprintf(cRenameConstraintSQL, + pq.QuoteIdentifier(table.Name), + pq.QuoteIdentifier(cc.Name), + pq.QuoteIdentifier(StripDuplicationPrefix(cc.Name)), + ) + + _, err = conn.ExecContext(ctx, renameConstraintSQL) + if err != nil { + return fmt.Errorf("failed to rename check constraint %q: %w", cc.Name, err) } } }