Skip to content

Commit

Permalink
Support creating foreign key constraints on the create table operation (
Browse files Browse the repository at this point in the history
#79)

Allow creating foreign key columns when doing a **create table**
operation. For example:

```json
{
  "name": "19_create_orders_table",
  "operations": [
    {
      "create_table": {
        "name": "orders",
        "columns": [
          {
            "name": "id",
            "type": "serial",
            "pk": true
          },
          {
            "name": "user_id",
            "type": "integer",
            "references": {
              "table": "users",
              "column": "id"
            }
          },
          {
            "name": "quantity",
            "type": "int"
          }
        ]
      }
    }
  ]
}
```
Here the `user_id` column references the `id` column in the `users`
table.

The constraint is added to the table on `Start` and removed on
`Rollback`.
  • Loading branch information
andrew-farries committed Sep 5, 2023
1 parent 0341c94 commit f94d252
Show file tree
Hide file tree
Showing 5 changed files with 296 additions and 40 deletions.
29 changes: 29 additions & 0 deletions examples/19_create_orders_table.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"name": "19_create_orders_table",
"operations": [
{
"create_table": {
"name": "orders",
"columns": [
{
"name": "id",
"type": "serial",
"pk": true
},
{
"name": "user_id",
"type": "integer",
"references": {
"table": "users",
"column": "id"
}
},
{
"name": "quantity",
"type": "int"
}
]
}
}
]
}
17 changes: 17 additions & 0 deletions pkg/migrations/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,20 @@ type FieldRequiredError struct {
func (e FieldRequiredError) Error() string {
return fmt.Sprintf("field %q is required", e.Name)
}

type ColumnReferenceError struct {
Table string
Column string
Err error
}

func (e ColumnReferenceError) Unwrap() error {
return e.Err
}

func (e ColumnReferenceError) Error() string {
return fmt.Sprintf("column reference to column %q in table %q is invalid: %s",
e.Column,
e.Table,
e.Err.Error())
}
4 changes: 4 additions & 0 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return errors.New("adding primary key columns is not supported")
}

if o.Column.References != nil {
return errors.New("adding foreign key columns is not supported")
}

return nil
}

Expand Down
55 changes: 49 additions & 6 deletions pkg/migrations/op_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@ type OpCreateTable struct {
}

type Column struct {
Name string `json:"name"`
Type string `json:"type"`
Nullable bool `json:"nullable"`
Unique bool `json:"unique"`
PrimaryKey bool `json:"pk"`
Default *string `json:"default"`
Name string `json:"name"`
Type string `json:"type"`
Nullable bool `json:"nullable"`
Unique bool `json:"unique"`
PrimaryKey bool `json:"pk"`
Default *string `json:"default"`
References *ColumnReference `json:"references"`
}

type ColumnReference struct {
Table string `json:"table"`
Column string `json:"column"`
}

func (o *OpCreateTable) Start(ctx context.Context, conn *sql.DB, stateSchema string, s *schema.Schema) error {
Expand Down Expand Up @@ -70,6 +76,30 @@ func (o *OpCreateTable) Validate(ctx context.Context, s *schema.Schema) error {
if table != nil {
return TableAlreadyExistsError{Name: o.Name}
}

for _, col := range o.Columns {
if col.References != nil {
table := s.GetTable(col.References.Table)
if table == nil {
return ColumnReferenceError{
Table: o.Name,
Column: col.Name,
Err: TableDoesNotExistError{Name: col.References.Table},
}
}
if _, ok := table.Columns[col.References.Column]; !ok {
return ColumnReferenceError{
Table: o.Name,
Column: col.Name,
Err: ColumnDoesNotExistError{
Table: col.References.Table,
Name: col.References.Column,
},
}
}
}
}

return nil
}

Expand Down Expand Up @@ -99,5 +129,18 @@ func ColumnToSQL(col Column) string {
if col.Default != nil {
sql += fmt.Sprintf(" DEFAULT %s", pq.QuoteLiteral(*col.Default))
}
if col.References != nil {
tableRef := col.References.Table
columnRef := col.References.Column

sql += fmt.Sprintf(" CONSTRAINT %s REFERENCES %s(%s)",
pq.QuoteIdentifier(ForeignKeyConstraintName(col.Name, tableRef, columnRef)),
pq.QuoteIdentifier(tableRef),
pq.QuoteIdentifier(columnRef))
}
return sql
}

func ForeignKeyConstraintName(columnName, tableRef, columnRef string) string {
return "_pgroll_fk_" + columnName + "_" + tableRef + "_" + columnRef
}
231 changes: 197 additions & 34 deletions pkg/migrations/op_create_table_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,173 @@ import (
func TestCreateTable(t *testing.T) {
t.Parallel()

ExecuteTests(t, TestCases{
{
name: "create table",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "varchar(255)",
Unique: true,
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The new view exists in the new version schema.
ViewMustExist(t, db, "public", "01_create_table", "users")

// Data can be inserted into the new view.
MustInsert(t, db, "public", "01_create_table", "users", map[string]string{
"name": "Alice",
})

// Data can be retrieved from the new view.
rows := MustSelect(t, db, "public", "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "Alice"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB) {
// The underlying table has been dropped.
TableMustNotExist(t, db, "public", "users")
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The view still exists
ViewMustExist(t, db, "public", "01_create_table", "users")

// Data can be inserted into the new view.
MustInsert(t, db, "public", "01_create_table", "users", map[string]string{
"name": "Alice",
})

// Data can be retrieved from the new view.
rows := MustSelect(t, db, "public", "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "Alice"},
}, rows)
},
},
{
name: "create table with foreign key",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "name",
Type: "varchar(255)",
Unique: true,
},
},
},
},
},
{
Name: "02_create_table_with_fk",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "orders",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "user_id",
Type: "integer",
References: &migrations.ColumnReference{
Table: "users",
Column: "id",
},
},
{
Name: "quantity",
Type: "integer",
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The foreign key constraint exists on the new table.
constraintName := migrations.ForeignKeyConstraintName("user_id", "users", "id")
ConstraintMustExist(t, db, "public", migrations.TemporaryName("orders"), constraintName)

// Inserting a row into the referenced table succeeds.
MustInsert(t, db, "public", "01_create_table", "users", map[string]string{
"name": "alice",
})

// Inserting a row into the referencing table succeeds as the referenced row exists.
MustInsert(t, db, "public", "02_create_table_with_fk", "orders", map[string]string{
"user_id": "1",
"quantity": "100",
})

// Inserting a row into the referencing table fails as the referenced row does not exist.
MustNotInsert(t, db, "public", "02_create_table_with_fk", "orders", map[string]string{
"user_id": "2",
"quantity": "200",
})
},
afterRollback: func(t *testing.T, db *sql.DB) {
// The table has been dropped, so the foreign key constraint is gone.
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The foreign key constraint still exists on the new table.
constraintName := migrations.ForeignKeyConstraintName("user_id", "users", "id")
ConstraintMustExist(t, db, "public", "orders", constraintName)

// Inserting a row into the referenced table succeeds.
MustInsert(t, db, "public", "02_create_table_with_fk", "users", map[string]string{
"name": "bob",
})

// Inserting a row into the referencing table succeeds as the referenced row exists.
MustInsert(t, db, "public", "02_create_table_with_fk", "orders", map[string]string{
"user_id": "2",
"quantity": "200",
})

// Inserting a row into the referencing table fails as the referenced row does not exist.
MustNotInsert(t, db, "public", "02_create_table_with_fk", "orders", map[string]string{
"user_id": "3",
"quantity": "300",
})
},
},
})
}

func TestCreateTableValidation(t *testing.T) {
t.Parallel()

ExecuteTests(t, TestCases{TestCase{
name: "create table",
name: "foreign key validity",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Expand All @@ -35,40 +200,38 @@ func TestCreateTable(t *testing.T) {
},
},
},
{
Name: "02_create_table_with_fk",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "orders",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
PrimaryKey: true,
},
{
Name: "user_id",
Type: "integer",
References: &migrations.ColumnReference{
Table: "users",
Column: "doesntexist",
},
},
{
Name: "quantity",
Type: "integer",
},
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
// The new view exists in the new version schema.
ViewMustExist(t, db, "public", "01_create_table", "users")

// Data can be inserted into the new view.
MustInsert(t, db, "public", "01_create_table", "users", map[string]string{
"name": "Alice",
})

// Data can be retrieved from the new view.
rows := MustSelect(t, db, "public", "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "Alice"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB) {
// The underlying table has been dropped.
TableMustNotExist(t, db, "public", "users")
},
afterComplete: func(t *testing.T, db *sql.DB) {
// The view still exists
ViewMustExist(t, db, "public", "01_create_table", "users")

// Data can be inserted into the new view.
MustInsert(t, db, "public", "01_create_table", "users", map[string]string{
"name": "Alice",
})

// Data can be retrieved from the new view.
rows := MustSelect(t, db, "public", "01_create_table", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "Alice"},
}, rows)
wantStartErr: migrations.ColumnReferenceError{
Table: "orders",
Column: "user_id",
Err: migrations.ColumnDoesNotExistError{Table: "users", Name: "doesntexist"},
},
}})
}

0 comments on commit f94d252

Please sign in to comment.