Skip to content

Commit

Permalink
Parameterize schema in which migration tests run via `PGROLL_TEST_SCH…
Browse files Browse the repository at this point in the history
…EMA` env var (#276)

Parameterize the migration tests so that the schema in which migrations
are applied is taken from the `PGROLL_TEST_SCHEMA` environment variable,
or `public` if unset.

#273 highlights an area in which there is a lack of test coverage:
migration tests apply migrations only in the `public` schema. #273 is an
error that is only reproducible when running migrations in a
non-`public` schema and there are currently also other problems with
migration validation in non-`public` schema.

This PR makes repetitive changes to all tests:
* Update the `afterStart`, `afterComplete` and `afterRollback` hooks to
take a `schema string` parameter.
* Update all hard-coded occurences of `"public"` in those hooks with the
`schema` parameter.

Later PRs will fix the issues with running migrations in non-`public`
schema and run non-`public` migration tests in CI. For now, to run
migration tests in a non-`public` schema run:

```
PGROLL_TEST_SCHEMA=foo go test ./...
```

Part of #273
  • Loading branch information
andrew-farries committed Feb 5, 2024
1 parent a5cf473 commit c95bc78
Show file tree
Hide file tree
Showing 19 changed files with 721 additions and 705 deletions.
156 changes: 78 additions & 78 deletions pkg/migrations/op_add_column_test.go

Large diffs are not rendered by default.

102 changes: 51 additions & 51 deletions pkg/migrations/op_change_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,91 +60,91 @@ func TestChangeColumnType(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
newVersionSchema := roll.VersionedSchemaName("public", "02_change_type")
afterStart: func(t *testing.T, db *sql.DB, schema string) {
newVersionSchema := roll.VersionedSchemaName(schema, "02_change_type")

// The new (temporary) `rating` column should exist on the underlying table.
ColumnMustExist(t, db, "public", "reviews", migrations.TemporaryName("rating"))
ColumnMustExist(t, db, schema, "reviews", migrations.TemporaryName("rating"))

// The `rating` column in the new view must have the correct type.
ColumnMustHaveType(t, db, newVersionSchema, "reviews", "rating", "integer")

// Inserting into the new `rating` column should work.
MustInsert(t, db, "public", "02_change_type", "reviews", map[string]string{
MustInsert(t, db, schema, "02_change_type", "reviews", map[string]string{
"username": "alice",
"product": "apple",
"rating": "5",
})

// The value inserted into the new `rating` column has been backfilled into
// the old `rating` column.
rows := MustSelect(t, db, "public", "01_add_table", "reviews")
rows := MustSelect(t, db, schema, "01_add_table", "reviews")
assert.Equal(t, []map[string]any{
{"id": 1, "username": "alice", "product": "apple", "rating": "5"},
}, rows)

// Inserting into the old `rating` column should work.
MustInsert(t, db, "public", "01_add_table", "reviews", map[string]string{
MustInsert(t, db, schema, "01_add_table", "reviews", map[string]string{
"username": "bob",
"product": "banana",
"rating": "8",
})

// The value inserted into the old `rating` column has been backfilled into
// the new `rating` column.
rows = MustSelect(t, db, "public", "02_change_type", "reviews")
rows = MustSelect(t, db, schema, "02_change_type", "reviews")
assert.Equal(t, []map[string]any{
{"id": 1, "username": "alice", "product": "apple", "rating": 5},
{"id": 2, "username": "bob", "product": "banana", "rating": 8},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The new (temporary) `rating` column should not exist on the underlying table.
ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("rating"))
ColumnMustNotExist(t, db, schema, "reviews", migrations.TemporaryName("rating"))

// The up function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "rating"))
FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("reviews", "rating"))
// The down function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating")))
FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating")))

// The up trigger no longer exists.
TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "rating"))
TriggerMustNotExist(t, db, schema, "reviews", migrations.TriggerName("reviews", "rating"))
// The down trigger no longer exists.
TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating")))
TriggerMustNotExist(t, db, schema, "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating")))
},
afterComplete: func(t *testing.T, db *sql.DB) {
newVersionSchema := roll.VersionedSchemaName("public", "02_change_type")
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
newVersionSchema := roll.VersionedSchemaName(schema, "02_change_type")

// The new (temporary) `rating` column should not exist on the underlying table.
ColumnMustNotExist(t, db, "public", "reviews", migrations.TemporaryName("rating"))
ColumnMustNotExist(t, db, schema, "reviews", migrations.TemporaryName("rating"))

// The `rating` column in the new view must have the correct type.
ColumnMustHaveType(t, db, newVersionSchema, "reviews", "rating", "integer")

// Inserting into the new view should work.
MustInsert(t, db, "public", "02_change_type", "reviews", map[string]string{
MustInsert(t, db, schema, "02_change_type", "reviews", map[string]string{
"username": "carl",
"product": "carrot",
"rating": "3",
})

// Selecting from the new view should succeed.
rows := MustSelect(t, db, "public", "02_change_type", "reviews")
rows := MustSelect(t, db, schema, "02_change_type", "reviews")
assert.Equal(t, []map[string]any{
{"id": 1, "username": "alice", "product": "apple", "rating": 5},
{"id": 2, "username": "bob", "product": "banana", "rating": 8},
{"id": 3, "username": "carl", "product": "carrot", "rating": 3},
}, rows)

// The up function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", "rating"))
FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("reviews", "rating"))
// The down function no longer exists.
FunctionMustNotExist(t, db, "public", migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating")))
FunctionMustNotExist(t, db, schema, migrations.TriggerFunctionName("reviews", migrations.TemporaryName("rating")))

// The up trigger no longer exists.
TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", "rating"))
TriggerMustNotExist(t, db, schema, "reviews", migrations.TriggerName("reviews", "rating"))
// The down trigger no longer exists.
TriggerMustNotExist(t, db, "public", "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating")))
TriggerMustNotExist(t, db, schema, "reviews", migrations.TriggerName("reviews", migrations.TemporaryName("rating")))
},
},
{
Expand Down Expand Up @@ -212,15 +212,15 @@ func TestChangeColumnType(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// A temporary FK constraint has been created on the temporary column
ValidatedForeignKeyMustExist(t, db, "public", "employees", migrations.DuplicationName("fk_employee_department"))
ValidatedForeignKeyMustExist(t, db, schema, "employees", migrations.DuplicationName("fk_employee_department"))
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// The foreign key constraint still exists on the column
ValidatedForeignKeyMustExist(t, db, "public", "employees", "fk_employee_department")
ValidatedForeignKeyMustExist(t, db, schema, "employees", "fk_employee_department")
},
},
{
Expand Down Expand Up @@ -260,28 +260,28 @@ func TestChangeColumnType(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// A row can be inserted into the new version of the table.
MustInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustInsert(t, db, schema, "02_change_type", "users", map[string]string{
"id": "1",
})

// The newly inserted row respects the default value of the column.
rows := MustSelect(t, db, "public", "02_change_type", "users")
rows := MustSelect(t, db, schema, "02_change_type", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "username": "alice"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// A row can be inserted into the new version of the table.
MustInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustInsert(t, db, schema, "02_change_type", "users", map[string]string{
"id": "2",
})

// The newly inserted row respects the default value of the column.
rows := MustSelect(t, db, "public", "02_change_type", "users")
rows := MustSelect(t, db, schema, "02_change_type", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "username": "alice"},
{"id": 2, "username": "alice"},
Expand Down Expand Up @@ -328,18 +328,18 @@ func TestChangeColumnType(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustNotInsert(t, db, schema, "02_change_type", "users", map[string]string{
"id": "1",
"username": "a",
}, testutils.CheckViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Inserting a row that violates the check constraint should fail.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustNotInsert(t, db, schema, "02_change_type", "users", map[string]string{
"id": "2",
"username": "b",
}, testutils.CheckViolationErrorCode)
Expand Down Expand Up @@ -381,17 +381,17 @@ func TestChangeColumnType(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Inserting a row that violates the NOT NULL constraint fails.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustNotInsert(t, db, schema, "02_change_type", "users", map[string]string{
"id": "1",
}, testutils.NotNullViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Inserting a row that violates the NOT NULL constraint fails.
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustNotInsert(t, db, schema, "02_change_type", "users", map[string]string{
"id": "2",
}, testutils.NotNullViolationErrorCode)
},
Expand Down Expand Up @@ -432,27 +432,27 @@ func TestChangeColumnType(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Inserting an initial row succeeds
MustInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustInsert(t, db, schema, "02_change_type", "users", map[string]string{
"username": "alice",
})

// Inserting a row with a duplicate `username` value fails
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustNotInsert(t, db, schema, "02_change_type", "users", map[string]string{
"username": "alice",
}, testutils.UniqueViolationErrorCode)
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Inserting a row with a duplicate `username` value fails
MustNotInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustNotInsert(t, db, schema, "02_change_type", "users", map[string]string{
"username": "alice",
}, testutils.UniqueViolationErrorCode)

// Inserting a row with a different `username` value succeeds
MustInsert(t, db, "public", "02_change_type", "users", map[string]string{
MustInsert(t, db, schema, "02_change_type", "users", map[string]string{
"username": "bob",
})
},
Expand Down
16 changes: 9 additions & 7 deletions pkg/migrations/op_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ type TestCase struct {
name string
migrations []migrations.Migration
wantStartErr error
afterStart func(t *testing.T, db *sql.DB)
afterComplete func(t *testing.T, db *sql.DB)
afterRollback func(t *testing.T, db *sql.DB)
afterStart func(t *testing.T, db *sql.DB, schema string)
afterComplete func(t *testing.T, db *sql.DB, schema string)
afterRollback func(t *testing.T, db *sql.DB, schema string)
}

type TestCases []TestCase
Expand All @@ -33,9 +33,11 @@ func TestMain(m *testing.M) {
}

func ExecuteTests(t *testing.T, tests TestCases) {
testSchema := testutils.TestSchema()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testutils.WithMigratorAndConnectionToContainer(t, func(mig *roll.Roll, db *sql.DB) {
testutils.WithMigratorInSchemaAndConnectionToContainer(t, testSchema, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// run all migrations except the last one
Expand Down Expand Up @@ -63,7 +65,7 @@ func ExecuteTests(t *testing.T, tests TestCases) {

// run the afterStart hook
if tt.afterStart != nil {
tt.afterStart(t, db)
tt.afterStart(t, db, testSchema)
}

// roll back the migration
Expand All @@ -73,7 +75,7 @@ func ExecuteTests(t *testing.T, tests TestCases) {

// run the afterRollback hook
if tt.afterRollback != nil {
tt.afterRollback(t, db)
tt.afterRollback(t, db, testSchema)
}

// re-start the last migration
Expand All @@ -88,7 +90,7 @@ func ExecuteTests(t *testing.T, tests TestCases) {

// run the afterComplete hook
if tt.afterComplete != nil {
tt.afterComplete(t, db)
tt.afterComplete(t, db, testSchema)
}
})
})
Expand Down
20 changes: 10 additions & 10 deletions pkg/migrations/op_create_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ func TestCreateIndex(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The index has been created on the underlying table.
IndexMustExist(t, db, "public", "users", "idx_users_name")
IndexMustExist(t, db, schema, "users", "idx_users_name")
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The index has been dropped from the the underlying table.
IndexMustNotExist(t, db, "public", "users", "idx_users_name")
IndexMustNotExist(t, db, schema, "users", "idx_users_name")
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Complete is a no-op.
},
}})
Expand Down Expand Up @@ -102,15 +102,15 @@ func TestCreateIndexOnMultipleColumns(t *testing.T) {
},
},
},
afterStart: func(t *testing.T, db *sql.DB) {
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// The index has been created on the underlying table.
IndexMustExist(t, db, "public", "users", "idx_users_name_email")
IndexMustExist(t, db, schema, "users", "idx_users_name_email")
},
afterRollback: func(t *testing.T, db *sql.DB) {
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
// The index has been dropped from the the underlying table.
IndexMustNotExist(t, db, "public", "users", "idx_users_name_email")
IndexMustNotExist(t, db, schema, "users", "idx_users_name_email")
},
afterComplete: func(t *testing.T, db *sql.DB) {
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Complete is a no-op.
},
}})
Expand Down
Loading

0 comments on commit c95bc78

Please sign in to comment.