Skip to content

Commit

Permalink
Preserve foreign key constraints on columns duplicated for backfilling (
Browse files Browse the repository at this point in the history
#230)

When duplicating a column for backfilling to add `NOT NULL` constraint,
ensure that any foreign keys on the duplicated column are preserved.

This fixes the issue where adding `NOT NULL` to an FK column would drop
the FK from the column.

This is part of a larger class of issues where duplicated columns are
not faithfully preserving all properties of the original, tracked in
#227.
  • Loading branch information
andrew-farries committed Jan 15, 2024
1 parent 4b124b0 commit b81c42c
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 7 deletions.
70 changes: 70 additions & 0 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// SPDX-License-Identifier: Apache-2.0

package migrations

import (
"context"
"database/sql"
"fmt"
"slices"
"strings"

"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/schema"
)

type Duplicator struct {
conn *sql.DB
table *schema.Table
column *schema.Column
asName string
}

// NewColumnDuplicator creates a new Duplicator for a column.
func NewColumnDuplicator(conn *sql.DB, table *schema.Table, column *schema.Column) *Duplicator {
return &Duplicator{
conn: conn,
table: table,
column: column,
asName: TemporaryName(column.Name),
}
}

// Duplicate creates a new column with the same type and foreign key
// constraints as the original column.
func (d *Duplicator) Duplicate(ctx context.Context) error {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)`
)

sql := fmt.Sprintf(cAlterTableSQL,
pq.QuoteIdentifier(d.table.Name),
pq.QuoteIdentifier(d.asName),
d.column.Type)

for _, fk := range d.table.ForeignKeys {
if slices.Contains(fk.Columns, d.column.Name) {
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
pq.QuoteIdentifier(TemporaryName(fk.Name)),
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, d.column.Name, d.asName)), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "))
}
}

_, err := d.conn.ExecContext(ctx, sql)

return err
}

func copyAndReplace(xs []string, oldValue, newValue string) []string {
ys := slices.Clone(xs)

for i, c := range ys {
if c == oldValue {
ys[i] = newValue
}
}
return ys
}
9 changes: 8 additions & 1 deletion pkg/migrations/op_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"strings"
)

type OpName string
Expand All @@ -33,8 +34,14 @@ const (
OpNameChangeType OpName = "change_type"
)

const temporaryPrefix = "_pgroll_new_"

func TemporaryName(name string) string {
return "_pgroll_new_" + name
return temporaryPrefix + name
}

func StripTemporaryPrefix(name string) string {
return strings.TrimPrefix(name, temporaryPrefix)
}

func ReadMigration(r io.Reader) (*Migration, error) {
Expand Down
11 changes: 5 additions & 6 deletions pkg/migrations/op_set_notnull.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn *sql.DB, stateSchema stri
column := table.GetColumn(o.Column)

// Create a copy of the column on the underlying table.
if err := duplicateColumn(ctx, conn, table, *column); err != nil {
d := NewColumnDuplicator(conn, table, column)
if err := d.Duplicate(ctx); err != nil {
return fmt.Errorf("failed to duplicate column: %w", err)
}

Expand Down Expand Up @@ -130,11 +131,9 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn *sql.DB, s *schema.Sch
}

// Rename the new column to the old column name
_, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(TemporaryName(o.Column)),
pq.QuoteIdentifier(o.Column)))
if err != nil {
table := s.GetTable(o.Table)
column := table.GetColumn(o.Column)
if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil {
return err
}

Expand Down
77 changes: 77 additions & 0 deletions pkg/migrations/op_set_notnull_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,83 @@ func TestSetNotNull(t *testing.T) {
afterComplete: func(t *testing.T, db *sql.DB) {
},
},
{
name: "setting a foreign key column to not null retains the foreign key constraint",
migrations: []migrations.Migration{
{
Name: "01_add_departments_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "departments",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "name",
Type: "text",
Nullable: false,
},
},
},
},
},
{
Name: "02_add_employees_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "employees",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
{
Name: "name",
Type: "text",
Nullable: false,
},
{
Name: "department_id",
Type: "integer",
Nullable: true,
References: &migrations.ForeignKeyReference{
Name: "fk_employee_department",
Table: "departments",
Column: "id",
},
},
},
},
},
},
{
Name: "03_set_not_null",
Operations: migrations.Operations{
&migrations.OpAlterColumn{
Table: "employees",
Column: "department_id",
Nullable: ptr(false),
Up: "(SELECT CASE WHEN department_id IS NULL THEN 1 ELSE department_id END)",
Down: "department_id",
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// A temporary FK constraint has been created on the temporary column
ConstraintMustExist(t, db, "public", "employees", migrations.TemporaryName("fk_employee_department"))
},
afterRollback: func(t *testing.T, db *sql.DB) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The foreign key constraint still exists on the column
ConstraintMustExist(t, db, "public", "employees", "fk_employee_department")
},
},
})
}

Expand Down
53 changes: 53 additions & 0 deletions pkg/migrations/rename.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0

package migrations

import (
"context"
"database/sql"
"fmt"
"slices"

"github.com/lib/pq"
"github.com/xataio/pgroll/pkg/schema"
)

// RenameDuplicatedColumn renames a duplicated column to its original name and renames any foreign keys
// on the duplicated column to their original name.
func RenameDuplicatedColumn(ctx context.Context, conn *sql.DB, table *schema.Table, column *schema.Column) error {
const (
cRenameColumnSQL = `ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s`
cRenameConstraintSQL = `ALTER TABLE IF EXISTS %s RENAME CONSTRAINT %s TO %s`
)

// Rename the old column to the new column name
renameColumnSQL := fmt.Sprintf(cRenameColumnSQL,
pq.QuoteIdentifier(table.Name),
pq.QuoteIdentifier(TemporaryName(column.Name)),
pq.QuoteIdentifier(column.Name))

_, err := conn.ExecContext(ctx, renameColumnSQL)
if err != nil {
return fmt.Errorf("failed to rename duplicated column %q: %w", column.Name, err)
}

// Rename any foreign keys on the duplicated column from their temporary name
// to their original name
var renameConstraintSQL string
for _, fk := range table.ForeignKeys {
if slices.Contains(fk.Columns, TemporaryName(column.Name)) {
renameConstraintSQL = fmt.Sprintf(cRenameConstraintSQL,
pq.QuoteIdentifier(table.Name),
pq.QuoteIdentifier(fk.Name),
pq.QuoteIdentifier(StripTemporaryPrefix(fk.Name)),
)

_, err = conn.ExecContext(ctx, renameConstraintSQL)
if err != nil {
return fmt.Errorf("failed to rename column constraint %q: %w", fk.Name, err)
}
}
}

return nil
}

0 comments on commit b81c42c

Please sign in to comment.