From 2d574ff412d745f218eed1a33db908b62bf08fb0 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Fri, 22 Sep 2023 10:00:32 +0100 Subject: [PATCH] Make `down` sql optional for the `set_unique` operation (#119) Follow up to https://github.com/xataio/pg-roll/pull/118 to make the `down` SQL optional. When making an existing column unique, the `down` sql is almost always going to be a simple copy from the new column to the old. --- pkg/migrations/op_set_unique.go | 11 +- pkg/migrations/op_set_unique_test.go | 240 +++++++++++++++++---------- 2 files changed, 163 insertions(+), 88 deletions(-) diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index bf08303d..11a7b705 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -69,7 +69,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin TableName: o.Table, PhysicalColumn: o.Column, StateSchema: stateSchema, - SQL: o.Down, + SQL: o.downSQL(), }) if err != nil { return fmt.Errorf("failed to create down trigger: %w", err) @@ -173,3 +173,12 @@ func (o *OpSetUnique) addUniqueIndex(ctx context.Context, conn *sql.DB) error { return err } + +// Down SQL is either user-specified or defaults to copying the value from the new column to the old. +func (o *OpSetUnique) downSQL() string { + if o.Down != "" { + return o.Down + } + + return o.Column +} diff --git a/pkg/migrations/op_set_unique_test.go b/pkg/migrations/op_set_unique_test.go index 9d9c3124..8ae4e3cc 100644 --- a/pkg/migrations/op_set_unique_test.go +++ b/pkg/migrations/op_set_unique_test.go @@ -4,112 +4,178 @@ import ( "database/sql" "testing" + "github.com/stretchr/testify/assert" "github.com/xataio/pg-roll/pkg/migrations" ) func TestSetColumnUnique(t *testing.T) { t.Parallel() - ExecuteTests(t, TestCases{{ - name: "set unique", - migrations: []migrations.Migration{ - { - Name: "01_add_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "reviews", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - PrimaryKey: true, - }, - { - Name: "username", - Type: "text", - Nullable: false, - }, - { - Name: "product", - Type: "text", - Nullable: false, - }, - { - Name: "review", - Type: "text", - Nullable: false, + ExecuteTests(t, TestCases{ + { + name: "set unique with default down sql", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "reviews", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "username", + Type: "text", + Nullable: false, + }, + { + Name: "product", + Type: "text", + Nullable: false, + }, + { + Name: "review", + Type: "text", + Nullable: false, + }, }, }, }, }, - }, - { - Name: "02_set_unique", - Operations: migrations.Operations{ - &migrations.OpAlterColumn{ - Table: "reviews", - Column: "review", - Unique: &migrations.UniqueConstraint{ - Name: "reviews_review_unique", + { + Name: "02_set_unique", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "reviews", + Column: "review", + Unique: &migrations.UniqueConstraint{ + Name: "reviews_review_unique", + }, + Up: "review || '-' || (random()*1000000)::integer", }, - Up: "review || '-' || (random()*1000000)::integer", - Down: "review", }, }, }, - }, - afterStart: func(t *testing.T, db *sql.DB) { - // Inserting values into the old schema that violate uniqueness should succeed. - MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{ - "username": "alice", "product": "apple", "review": "good", - }) - MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{ - "username": "bob", "product": "banana", "review": "good", - }) + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting values into the old schema that violate uniqueness should succeed. + MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{ + "username": "alice", "product": "apple", "review": "good", + }) + MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{ + "username": "bob", "product": "banana", "review": "good", + }) - // Inserting values into the new schema that violate uniqueness should fail. - MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ - "username": "carl", "product": "carrot", "review": "bad", - }) - MustNotInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ - "username": "dana", "product": "durian", "review": "bad", - }) - }, - afterRollback: func(t *testing.T, db *sql.DB) { - // The new (temporary) `review` column should not exist on the underlying table. - ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("review")) + // Inserting values into the new schema that violate uniqueness should fail. + MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "carl", "product": "carrot", "review": "bad", + }) + MustNotInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "dana", "product": "durian", "review": "bad", + }) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + // The new (temporary) `review` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("review")) - // The up function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "review")) - // The down function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("review"))) + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "review")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("review"))) - // The up trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "review")) - // The down trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("review"))) - }, - afterComplete: func(t *testing.T, db *sql.DB) { - // The new (temporary) `review` column should not exist on the underlying table. - ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("review")) + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "review")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("review"))) + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The new (temporary) `review` column should not exist on the underlying table. + ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("review")) - // The up function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "review")) - // The down function no longer exists. - FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("review"))) + // The up function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "review")) + // The down function no longer exists. + FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("review"))) - // The up trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "review")) - // The down trigger no longer exists. - TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("review"))) + // The up trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "review")) + // The down trigger no longer exists. + TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("review"))) - // Inserting values into the new schema that violate uniqueness should fail. - MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ - "username": "earl", "product": "elderberry", "review": "ok", - }) - MustNotInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ - "username": "flora", "product": "fig", "review": "ok", - }) + // Inserting values into the new schema that violate uniqueness should fail. + MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "earl", "product": "elderberry", "review": "ok", + }) + MustNotInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "flora", "product": "fig", "review": "ok", + }) + }, + }, + { + name: "set unique with default user supplied down sql", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "reviews", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + PrimaryKey: true, + }, + { + Name: "username", + Type: "text", + Nullable: false, + }, + { + Name: "product", + Type: "text", + Nullable: false, + }, + { + Name: "review", + Type: "text", + Nullable: false, + }, + }, + }, + }, + }, + { + Name: "02_set_unique", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "reviews", + Column: "review", + Unique: &migrations.UniqueConstraint{ + Name: "reviews_review_unique", + }, + Up: "review || '-' || (random()*1000000)::integer", + Down: "review || '!'", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // Inserting values into the new schema backfills the old column using the `down` SQL. + MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "username": "carl", "product": "carrot", "review": "bad", + }) + + rows := MustSelect(t, db, "public", "01_add_table", "reviews") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "carl", "product": "carrot", "review": "bad!"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + }, }, - }}) + }) }