Skip to content

Commit

Permalink
Add test for WithKickstartReplication option
Browse files Browse the repository at this point in the history
  • Loading branch information
andrew-farries committed Feb 28, 2024
1 parent 869add8 commit aceeb19
Showing 1 changed file with 53 additions and 15 deletions.
68 changes: 53 additions & 15 deletions pkg/roll/execute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,6 @@ func TestDisabledSchemaManagement(t *testing.T) {
})
}

func schemaExists(t *testing.T, db *sql.DB, schema string) bool {
t.Helper()
var exists bool
err := db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_namespace
WHERE nspname = $1
)`, schema).Scan(&exists)
if err != nil {
t.Fatal(err)
}
return exists
}

func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -547,6 +532,27 @@ func TestWithSettingsOnMigrationStartIsRespected(t *testing.T) {
})
}

func TestWithKickstartReplicationCleansUp(t *testing.T) {
t.Parallel()

options := []roll.Option{roll.WithKickstartReplication()}

testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", options, func(mig *roll.Roll, db *sql.DB) {
ctx := context.Background()

// Start a create table migration
if err := mig.Start(ctx, &migrations.Migration{Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}}); err != nil {
t.Fatalf("Failed to start migration: %v", err)
}

// Ensure that there is only one table in the schema - the no-op table
// created by the WithKickstartReplication option has been removed
if count := tableCount(t, db, "public"); count != 1 {
t.Errorf("Expected 1 table in schema %q, got %d", "public", count)
}
})
}

func createTableOp(tableName string) *migrations.OpCreateTable {
return &migrations.OpCreateTable{
Name: tableName,
Expand Down Expand Up @@ -619,6 +625,38 @@ func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[s
return res
}

func schemaExists(t *testing.T, db *sql.DB, schema string) bool {
t.Helper()
var exists bool
err := db.QueryRow(`
SELECT EXISTS(
SELECT 1
FROM pg_catalog.pg_namespace
WHERE nspname = $1
)`, schema).Scan(&exists)
if err != nil {
t.Fatal(err)
}
return exists
}

func tableCount(t *testing.T, db *sql.DB, schema string) int {
t.Helper()

var count int
err := db.QueryRow(`
SELECT COUNT(*)
FROM pg_catalog.pg_tables
WHERE schemaname = $1
`,
schema).Scan(&count)
if err != nil {
t.Fatal(err)
}

return count
}

func ptr[T any](v T) *T {
return &v
}

0 comments on commit aceeb19

Please sign in to comment.