diff --git a/pkg/migrations/backfill.go b/pkg/migrations/backfill.go
index efa119b8..b0cad98a 100644
--- a/pkg/migrations/backfill.go
+++ b/pkg/migrations/backfill.go
@@ -19,19 +19,18 @@ import (
-// 3. Update each row in the batch, setting the value of the primary key column to itself.
+// 3. Update each row in the batch, setting the value of the identity column to itself.
 // 4. Repeat steps 2 and 3 until no more rows are returned.
 func backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...CallbackFn) error {
-	// Get the primary key column for the table
-	pks := table.GetPrimaryKey()
-	if len(pks) != 1 {
-		return errors.New("table must have a single primary key column")
+	// get the backfill column
+	identityColumn := getIdentityColumn(table)
+	if identityColumn == nil {
+		return BackfillNotPossibleError{Table: table.Name}
 	}
-	pk := pks[0]
 
 	// Create a batcher for the table.
 	b := batcher{
-		table:     table,
-		pkColumn:  pk,
-		lastPK:    nil,
-		batchSize: 1000,
+		table:          table,
+		identityColumn: identityColumn,
+		lastValue:      nil,
+		batchSize:      1000,
 	}
 
 	// Update each batch of rows, invoking callbacks for each one.
@@ -51,11 +50,39 @@ func backfill(ctx context.Context, conn *sql.DB, table *schema.Table, cbs ...Cal
 	return nil
 }
 
+// checkBackfill will return an error if the backfill operation is not supported.
+func checkBackfill(table *schema.Table) error {
+	col := getIdentityColumn(table)
+	if col == nil {
+		return BackfillNotPossibleError{Table: table.Name}
+	}
+
+	return nil
+}
+
+// getIdentityColumn will return a column suitable for use in a backfill operation.
+func getIdentityColumn(table *schema.Table) *schema.Column {
+	pks := table.GetPrimaryKey()
+	if len(pks) == 1 {
+		return pks[0]
+	}
+
+	// If there is no primary key, look for a unique not null column
+	for _, col := range table.Columns {
+		if col.Unique && !col.Nullable {
+			return &col
+		}
+	}
+
+	// no suitable column found
+	return nil
+}
+
 type batcher struct {
-	table     *schema.Table
-	pkColumn  *schema.Column
-	lastPK    *string
-	batchSize int
+	table          *schema.Table
+	identityColumn *schema.Column
+	lastValue      *string
+	batchSize      int
 }
 
 // updateBatch updates the next batch of rows in the table.
@@ -72,7 +99,7 @@ func (b *batcher) updateBatch(ctx context.Context, conn *sql.DB) error {
 
 	// Execute the query to update the next batch of rows and update the last PK
 	// value for the next batch
-	err = tx.QueryRowContext(ctx, query).Scan(&b.lastPK)
+	err = tx.QueryRowContext(ctx, query).Scan(&b.lastValue)
 	if err != nil {
 		return err
 	}
@@ -84,8 +111,8 @@ func (b *batcher) updateBatch(ctx context.Context, conn *sql.DB) error {
 
 // buildQuery builds the query used to update the next batch of rows.
func (b *batcher) buildQuery() string { whereClause := "" - if b.lastPK != nil { - whereClause = fmt.Sprintf("WHERE %s > %v", pq.QuoteIdentifier(b.pkColumn.Name), pq.QuoteLiteral(*b.lastPK)) + if b.lastValue != nil { + whereClause = fmt.Sprintf("WHERE %s > %v", pq.QuoteIdentifier(b.identityColumn.Name), pq.QuoteLiteral(*b.lastValue)) } return fmt.Sprintf(` @@ -96,7 +123,7 @@ func (b *batcher) buildQuery() string { ) SELECT LAST_VALUE(%[1]s) OVER() FROM update `, - pq.QuoteIdentifier(b.pkColumn.Name), + pq.QuoteIdentifier(b.identityColumn.Name), pq.QuoteIdentifier(b.table.Name), b.batchSize, whereClause) diff --git a/pkg/migrations/errors.go b/pkg/migrations/errors.go index a3cf0113..96d77095 100644 --- a/pkg/migrations/errors.go +++ b/pkg/migrations/errors.go @@ -157,13 +157,12 @@ func (e MultipleAlterColumnChangesError) Error() string { return fmt.Sprintf("alter column operations require exactly one change, found %d", e.Changes) } -type InvalidPrimaryKeyError struct { - Table string - Fields int +type BackfillNotPossibleError struct { + Table string } -func (e InvalidPrimaryKeyError) Error() string { - return fmt.Sprintf("primary key on table %q must be defined on exactly one column, found %d", e.Table, e.Fields) +func (e BackfillNotPossibleError) Error() string { + return fmt.Sprintf("a backfill is required but table %q doesn't have a single column primary key or a UNIQUE, NOT NULL column", e.Table) } type InvalidReplicaIdentityError struct { diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index 21157c2d..df8d9f75 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -164,11 +164,11 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { } } + // Ensure backfill is possible if o.Up != nil { - // needs backfill, ensure that the table has a primary key defined on exactly one column. 
- pk := table.GetPrimaryKey() - if len(pk) != 1 { - return InvalidPrimaryKeyError{Table: o.Table, Fields: len(pk)} + err := checkBackfill(table) + if err != nil { + return err } } diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 07cd8696..61cfecb5 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -486,6 +486,98 @@ func TestAddColumnWithUpSql(t *testing.T) { triggerFnName := migrations.TriggerFunctionName("products", "description") FunctionMustNotExist(t, db, schema, triggerFnName) + // The trigger has been dropped. + triggerName := migrations.TriggerName("products", "description") + TriggerMustNotExist(t, db, schema, "products", triggerName) + }, + }, + { + name: "add column with up sql and no pk", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "products", + Columns: []migrations.Column{ + { + Name: "id", + Type: "text", + }, + { + Name: "name", + Type: "varchar(255)", + Unique: ptr(true), + Nullable: ptr(false), + }, + }, + }, + // insert some data into the table to test backfill in the next migration + &migrations.OpRawSQL{ + Up: "INSERT INTO products (id, name) VALUES ('c', 'cherries')", + OnComplete: true, + }, + }, + }, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "products", + Up: ptr("UPPER(name)"), + Column: migrations.Column{ + Name: "description", + Type: "varchar(255)", + Nullable: ptr(true), + }, + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // inserting via both the old and the new views works + MustInsert(t, db, schema, "01_add_table", "products", map[string]string{ + "id": "a", + "name": "apple", + }) + MustInsert(t, db, schema, "02_add_column", "products", map[string]string{ + "id": "b", + "name": "banana", + "description": "a yellow banana", + }) + + res := MustSelect(t, 
db, schema, "02_add_column", "products") + assert.Equal(t, []map[string]any{ + // the description column has been populated by the backfill process + {"id": "c", "name": "cherries", "description": "CHERRIES"}, + // the description column has been populated for the product inserted into the old view. + {"id": "a", "name": "apple", "description": "APPLE"}, + // the description column for the product inserted into the new view is as inserted. + {"id": "b", "name": "banana", "description": "a yellow banana"}, + }, res) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // The trigger function has been dropped. + triggerFnName := migrations.TriggerFunctionName("products", "description") + FunctionMustNotExist(t, db, schema, triggerFnName) + + // The trigger has been dropped. + triggerName := migrations.TriggerName("products", "description") + TriggerMustNotExist(t, db, schema, "products", triggerName) + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // after rollback + restart + complete, all 'description' values are the backfilled ones. + res := MustSelect(t, db, schema, "02_add_column", "products") + assert.Equal(t, []map[string]any{ + {"id": "a", "name": "apple", "description": "APPLE"}, + {"id": "b", "name": "banana", "description": "BANANA"}, + {"id": "c", "name": "cherries", "description": "CHERRIES"}, + }, res) + + // The trigger function has been dropped. + triggerFnName := migrations.TriggerFunctionName("products", "description") + FunctionMustNotExist(t, db, schema, triggerFnName) + // The trigger has been dropped. 
triggerName := migrations.TriggerName("products", "description") TriggerMustNotExist(t, db, schema, "products", triggerName) @@ -497,70 +589,72 @@ func TestAddColumnWithUpSql(t *testing.T) { func TestAddNotNullColumnWithNoDefault(t *testing.T) { t.Parallel() - ExecuteTests(t, TestCases{{ - name: "add not null column with no default", - migrations: []migrations.Migration{ - { - Name: "01_add_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "products", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: ptr(true), - }, - { - Name: "name", - Type: "varchar(255)", - Unique: ptr(true), + ExecuteTests(t, TestCases{ + { + name: "add not null column with no default", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "products", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: ptr(true), + }, + { + Name: "name", + Type: "varchar(255)", + Unique: ptr(true), + }, }, }, }, }, - }, - { - Name: "02_add_column", - Operations: migrations.Operations{ - &migrations.OpAddColumn{ - Table: "products", - Up: ptr("UPPER(name)"), - Column: migrations.Column{ - Name: "description", - Type: "varchar(255)", - Nullable: ptr(false), + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "products", + Up: ptr("UPPER(name)"), + Column: migrations.Column{ + Name: "description", + Type: "varchar(255)", + Nullable: ptr(false), + }, }, }, }, }, + afterStart: func(t *testing.T, db *sql.DB, schema string) { + // Inserting a null description through the old view works (due to `up` sql populating the column). + MustInsert(t, db, schema, "01_add_table", "products", map[string]string{ + "name": "apple", + }) + // Inserting a null description through the new view fails. 
+ MustNotInsert(t, db, schema, "02_add_column", "products", map[string]string{ + "name": "banana", + }, testutils.CheckViolationErrorCode) + }, + afterRollback: func(t *testing.T, db *sql.DB, schema string) { + // the check constraint has been dropped. + constraintName := migrations.NotNullConstraintName("description") + CheckConstraintMustNotExist(t, db, schema, "products", constraintName) + }, + afterComplete: func(t *testing.T, db *sql.DB, schema string) { + // the check constraint has been dropped. + constraintName := migrations.NotNullConstraintName("description") + CheckConstraintMustNotExist(t, db, schema, "products", constraintName) + + // can't insert a null description into the new view; the column now has a NOT NULL constraint. + MustNotInsert(t, db, schema, "02_add_column", "products", map[string]string{ + "name": "orange", + }, testutils.NotNullViolationErrorCode) + }, }, - afterStart: func(t *testing.T, db *sql.DB, schema string) { - // Inserting a null description through the old view works (due to `up` sql populating the column). - MustInsert(t, db, schema, "01_add_table", "products", map[string]string{ - "name": "apple", - }) - // Inserting a null description through the new view fails. - MustNotInsert(t, db, schema, "02_add_column", "products", map[string]string{ - "name": "banana", - }, testutils.CheckViolationErrorCode) - }, - afterRollback: func(t *testing.T, db *sql.DB, schema string) { - // the check constraint has been dropped. - constraintName := migrations.NotNullConstraintName("description") - CheckConstraintMustNotExist(t, db, schema, "products", constraintName) - }, - afterComplete: func(t *testing.T, db *sql.DB, schema string) { - // the check constraint has been dropped. - constraintName := migrations.NotNullConstraintName("description") - CheckConstraintMustNotExist(t, db, schema, "products", constraintName) - - // can't insert a null description into the new view; the column now has a NOT NULL constraint. 
- MustNotInsert(t, db, schema, "02_add_column", "products", map[string]string{ - "name": "orange", - }, testutils.NotNullViolationErrorCode) - }, - }}) + }) } func TestAddColumnValidation(t *testing.T) { @@ -587,6 +681,48 @@ func TestAddColumnValidation(t *testing.T) { }, } + addTableMigrationNoPKNullable := migrations.Migration{ + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + }, + { + Name: "name", + Type: "varchar(255)", + Unique: ptr(true), + Nullable: ptr(true), + }, + }, + }, + }, + } + + addTableMigrationNoPKNotNull := migrations.Migration{ + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + }, + { + Name: "name", + Type: "varchar(255)", + Unique: ptr(true), + Nullable: ptr(false), + }, + }, + }, + }, + } + ExecuteTests(t, TestCases{ { name: "table must exist", @@ -674,7 +810,7 @@ func TestAddColumnValidation(t *testing.T) { }, }, }, - wantStartErr: migrations.InvalidPrimaryKeyError{Table: "orders", Fields: 2}, + wantStartErr: migrations.BackfillNotPossibleError{Table: "orders"}, }, { name: "table has no restrictions on primary keys if up is not defined", @@ -704,6 +840,48 @@ func TestAddColumnValidation(t *testing.T) { }, wantStartErr: nil, }, + { + name: "table must have a primary key on exactly one column or a unique not null if up is defined", + migrations: []migrations.Migration{ + addTableMigrationNoPKNullable, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "users", + Up: ptr("UPPER(name)"), + Column: migrations.Column{ + Default: ptr("'foo'"), + Name: "description", + Type: "text", + }, + }, + }, + }, + }, + wantStartErr: migrations.BackfillNotPossibleError{Table: "users"}, + }, + { + name: "table with a unique not null column can be backfilled", + 
migrations: []migrations.Migration{ + addTableMigrationNoPKNotNull, + { + Name: "02_add_column", + Operations: migrations.Operations{ + &migrations.OpAddColumn{ + Table: "users", + Up: ptr("UPPER(name)"), + Column: migrations.Column{ + Default: ptr("'foo'"), + Name: "description", + Type: "text", + }, + }, + }, + }, + }, + wantStartErr: nil, + }, }) } diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index 84e22a5f..ca2474fa 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -45,9 +45,9 @@ func (o *OpAlterColumn) Validate(ctx context.Context, s *schema.Schema) error { } // Ensure that the column has a primary key defined on exactly one column. - pk := table.GetPrimaryKey() - if len(pk) != 1 { - return InvalidPrimaryKeyError{Table: o.Table, Fields: len(pk)} + err := checkBackfill(table) + if err != nil { + return err } // Apply any special validation rules for the inner operation diff --git a/pkg/migrations/op_alter_column_test.go b/pkg/migrations/op_alter_column_test.go index fff77f79..71e0b607 100644 --- a/pkg/migrations/op_alter_column_test.go +++ b/pkg/migrations/op_alter_column_test.go @@ -177,7 +177,7 @@ func TestAlterColumnValidation(t *testing.T) { }, }, }, - wantStartErr: migrations.InvalidPrimaryKeyError{Table: "orders", Fields: 2}, + wantStartErr: migrations.BackfillNotPossibleError{Table: "orders"}, }, }) } diff --git a/pkg/state/state.go b/pkg/state/state.go index 79f3597a..d5db99ec 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -131,7 +131,7 @@ BEGIN SELECT 1 FROM pg_constraint WHERE conrelid = attr.attrelid - AND conkey::int[] @> ARRAY[attr.attnum::int] + AND ARRAY[attr.attnum::int] @> conkey::int[] AND contype = 'u' ) OR EXISTS ( SELECT 1 @@ -139,7 +139,7 @@ BEGIN JOIN pg_class ON pg_class.oid = pg_index.indexrelid WHERE indrelid = attr.attrelid AND indisunique - AND pg_index.indkey::int[] @> ARRAY[attr.attnum::int] + AND ARRAY[attr.attnum::int] @> 
pg_index.indkey::int[] )) AS unique FROM pg_attribute AS attr diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 0602b023..3b9a6023 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -430,6 +430,53 @@ func TestReadSchema(t *testing.T) { }, }, }, + { + name: "multicolumn unique constraint", + createStmt: "CREATE TABLE public.table1 (id int PRIMARY KEY, name TEXT, CONSTRAINT name_id_unique UNIQUE(id, name));", + wantSchema: &schema.Schema{ + Name: "public", + Tables: map[string]schema.Table{ + "table1": { + Name: "table1", + Columns: map[string]schema.Column{ + "id": { + Name: "id", + Type: "integer", + Nullable: false, + Unique: true, + }, + "name": { + Name: "name", + Type: "text", + Nullable: true, + Unique: false, + }, + }, + PrimaryKey: []string{"id"}, + Indexes: map[string]schema.Index{ + "table1_pkey": { + Name: "table1_pkey", + Unique: true, + Columns: []string{"id"}, + }, + "name_id_unique": { + Name: "name_id_unique", + Unique: true, + Columns: []string{"id", "name"}, + }, + }, + ForeignKeys: map[string]schema.ForeignKey{}, + CheckConstraints: map[string]schema.CheckConstraint{}, + UniqueConstraints: map[string]schema.UniqueConstraint{ + "name_id_unique": { + Name: "name_id_unique", + Columns: []string{"id", "name"}, + }, + }, + }, + }, + }, + }, { name: "multi-column index", createStmt: "CREATE TABLE public.table1 (a text, b text); CREATE INDEX idx_ab ON public.table1 (a, b);",