From ddb91d1e8bf110888b2002b416ce2fb92b1b789a Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Tue, 16 Jan 2024 12:11:58 +0000 Subject: [PATCH] Preserve column properties on add `CHECK` constraint operation (#236) Preserve properties of columns when duplicating them for backfilling to add a`CHECK` constraint. Currently, the column properties that are preserved are: * `DEFAULT`s * foreign key constraints but this list will grow as more work is done on #227. --- pkg/migrations/op_set_check.go | 11 +- pkg/migrations/op_set_check_test.go | 371 +++++++++++++++++++--------- 2 files changed, 265 insertions(+), 117 deletions(-) diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index 2dce1426..497db2f8 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -27,7 +27,8 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn *sql.DB, stateSch column := table.GetColumn(o.Column) // Create a copy of the column on the underlying table. - if err := duplicateColumn(ctx, conn, table, *column); err != nil { + d := NewColumnDuplicator(conn, table, column) + if err := d.Duplicate(ctx); err != nil { return fmt.Errorf("failed to duplicate column: %w", err) } @@ -112,11 +113,9 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn *sql.DB, s *sc } // Rename the new column to the old column name - _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", - pq.QuoteIdentifier(o.Table), - pq.QuoteIdentifier(TemporaryName(o.Column)), - pq.QuoteIdentifier(o.Column))) - if err != nil { + table := s.GetTable(o.Table) + column := table.GetColumn(o.Column) + if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil { return err } diff --git a/pkg/migrations/op_set_check_test.go b/pkg/migrations/op_set_check_test.go index 7d745f41..8baae581 100644 --- a/pkg/migrations/op_set_check_test.go +++ b/pkg/migrations/op_set_check_test.go @@ -13,129 +13,278 @@ import ( func TestSetCheckConstraint(t *testing.T) { t.Parallel() - ExecuteTests(t, TestCases{{ - name: "add check constraint", - migrations: []migrations.Migration{ - { - Name: "01_add_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "posts", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: true, + ExecuteTests(t, TestCases{ + { + name: "add check constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "title", + Type: "text", + }, }, - { - Name: "title", - Type: "text", + }, + }, + }, + { + Name: "02_add_check_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "posts", + Column: "title", + Check: &migrations.CheckConstraint{ + Name: "check_title_length", + Constraint: "length(title) > 3", }, + Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)", + Down: "title", }, }, }, }, - { - Name: "02_add_check_constraint", - Operations: migrations.Operations{ - &migrations.OpAlterColumn{ - Table: "posts", - Column: "title", - Check: &migrations.CheckConstraint{ - Name: "check_title_length", - Constraint: "length(title) > 3", + afterStart: func(t *testing.T, db *sql.DB) { + // The new (temporary) `title` column should exist on the underlying table. + ColumnMustExist(t, db, "public", "posts", migrations.TemporaryName("title")) + + // Inserting a row that meets the check constraint into the old view works. + MustInsert(t, db, "public", "01_add_table", "posts", map[string]string{ + "title": "post by alice", + }) + + // Inserting a row that does not meet the check constraint into the old view also works. + MustInsert(t, db, "public", "01_add_table", "posts", map[string]string{ + "title": "b", + }) + + // Both rows have been backfilled into the new view; the short title has + // been rewritten using `up` SQL to meet the length constraint. + rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "post by alice"}, + {"id": 2, "title": "---b"}, + }, rows) + + // Inserting a row that meets the check constraint into the new view works. + MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "title": "post by carl", + }) + + // Inserting a row that does not meet the check constraint into the new view fails. + MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "title": "d", + }) + + // The row that was inserted into the new view has been backfilled into the old view. + rows = MustSelect(t, db, "public", "01_add_table", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "post by alice"}, + {"id": 2, "title": "b"}, + {"id": 3, "title": "post by carl"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The new (temporary) `title` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("title")) + + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "title")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("title"))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "title")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("title"))) + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Inserting a row that meets the check constraint into the new view works. + MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "title": "post by dana", + }) + + // Inserting a row that does not meet the check constraint into the new view fails. + MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "title": "e", + }) + + // The data in the new `posts` view is as expected. + rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "post by alice"}, + {"id": 2, "title": "---b"}, + {"id": 3, "title": "post by carl"}, + {"id": 5, "title": "post by dana"}, + }, rows) + + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "title")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("title"))) + + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "title")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("title"))) + }, + }, + { + name: "column defaults are preserved when adding a check constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "posts", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "title", + Type: "text", + Default: ptr("'untitled'"), + }, + }, + }, + }, + }, + { + Name: "02_add_check_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "posts", + Column: "title", + Check: &migrations.CheckConstraint{ + Name: "check_title_length", + Constraint: "length(title) > 3", + }, + Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)", + Down: "title", }, - Up: "(SELECT CASE WHEN length(title) <= 3 THEN LPAD(title, 4, '-') ELSE title END)", - Down: "title", }, }, }, + afterStart: func(t *testing.T, db *sql.DB) { + // A row can be inserted into the new version of the table. + MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "id": "1", + }) + + // The newly inserted row respects the default value of the column. + rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "untitled"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // A row can be inserted into the new version of the table. + MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ + "id": "2", + }) + + // The newly inserted row respects the default value of the column. + rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts") + assert.Equal(t, []map[string]any{ + {"id": 1, "title": "untitled"}, + {"id": 2, "title": "untitled"}, + }, rows) + }, }, - afterStart: func(t *testing.T, db *sql.DB) { - // The new (temporary) `title` column should exist on the underlying table. - ColumnMustExist(t, db, "public", "posts", migrations.TemporaryName("title")) - - // Inserting a row that meets the check constraint into the old view works. - MustInsert(t, db, "public", "01_add_table", "posts", map[string]string{ - "title": "post by alice", - }) - - // Inserting a row that does not meet the check constraint into the old view also works. - MustInsert(t, db, "public", "01_add_table", "posts", map[string]string{ - "title": "b", - }) - - // Both rows have been backfilled into the new view; the short title has - // been rewritten using `up` SQL to meet the length constraint. - rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts") - assert.Equal(t, []map[string]any{ - {"id": 1, "title": "post by alice"}, - {"id": 2, "title": "---b"}, - }, rows) - - // Inserting a row that meets the check constraint into the new view works. - MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ - "title": "post by carl", - }) - - // Inserting a row that does not meet the check constraint into the new view fails. - MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ - "title": "d", - }) - - // The row that was inserted into the new view has been backfilled into the old view. - rows = MustSelect(t, db, "public", "01_add_table", "posts") - assert.Equal(t, []map[string]any{ - {"id": 1, "title": "post by alice"}, - {"id": 2, "title": "b"}, - {"id": 3, "title": "post by carl"}, - }, rows) - }, - afterRollback: func(t *testing.T, db *sql.DB) { - // The new (temporary) `title` column should not exist on the underlying table. - ColumnMustNotExist(t, db, "public", "posts", migrations.TemporaryName("title")) - - // The up function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "title")) - // The down function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("title"))) - - // The up trigger no longer exists. - TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "title")) - // The down trigger no longer exists. - TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("title"))) - }, - afterComplete: func(t *testing.T, db *sql.DB) { - // Inserting a row that meets the check constraint into the new view works. - MustInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ - "title": "post by dana", - }) - - // Inserting a row that does not meet the check constraint into the new view fails. - MustNotInsert(t, db, "public", "02_add_check_constraint", "posts", map[string]string{ - "title": "e", - }) - - // The data in the new `posts` view is as expected. - rows := MustSelect(t, db, "public", "02_add_check_constraint", "posts") - assert.Equal(t, []map[string]any{ - {"id": 1, "title": "post by alice"}, - {"id": 2, "title": "---b"}, - {"id": 3, "title": "post by carl"}, - {"id": 5, "title": "post by dana"}, - }, rows) - - // The up function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", "title")) - // The down function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("posts", migrations.TemporaryName("title"))) - - // The up trigger no longer exists. - TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", "title")) - // The down trigger no longer exists. - TriggerMustNotExist(t, db, "public", "posts", migrations.TriggerName("posts", migrations.TemporaryName("title"))) + { + name: "foreign keys are preserved when adding a check constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_departments_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "departments", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: false, + }, + }, + }, + }, + }, + { + Name: "02_add_employees_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "employees", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: false, + }, + { + Name: "department_id", + Type: "integer", + Nullable: true, + References: &migrations.ForeignKeyReference{ + Name: "fk_employee_department", + Table: "departments", + Column: "id", + }, + }, + }, + }, + }, + }, + { + Name: "03_add_check_constraint", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "employees", + Column: "department_id", + Check: &migrations.CheckConstraint{ + Name: "check_valid_department_id", + Constraint: "department_id > 1", + }, + Up: "(SELECT CASE WHEN department_id <= 1 THEN 2 ELSE department_id END)", + Down: "department_id", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // A temporary FK constraint has been created on the temporary column + ConstraintMustExist(t, db, "public", "employees", migrations.TemporaryName("fk_employee_department")) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The foreign key constraint still exists on the column + ConstraintMustExist(t, db, "public", "employees", "fk_employee_department") + }, }, - }}) + }) } func TestSetCheckConstraintValidation(t *testing.T) {