From aceeb19ea10bfa1261a56abe0c74a029bb29aae6 Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Wed, 28 Feb 2024 11:42:56 +0000 Subject: [PATCH] Add test for `WithKickstartReplication` option --- pkg/roll/execute_test.go | 68 +++++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index b9b5f350..ab974c71 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -82,21 +82,6 @@ func TestDisabledSchemaManagement(t *testing.T) { }) } -func schemaExists(t *testing.T, db *sql.DB, schema string) bool { - t.Helper() - var exists bool - err := db.QueryRow(` - SELECT EXISTS( - SELECT 1 - FROM pg_catalog.pg_namespace - WHERE nspname = $1 - )`, schema).Scan(&exists) - if err != nil { - t.Fatal(err) - } - return exists -} - func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { t.Parallel() @@ -547,6 +532,27 @@ func TestWithSettingsOnMigrationStartIsRespected(t *testing.T) { }) } +func TestWithKickstartReplicationCleansUp(t *testing.T) { + t.Parallel() + + options := []roll.Option{roll.WithKickstartReplication()} + + testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", options, func(mig *roll.Roll, db *sql.DB) { + ctx := context.Background() + + // Start a create table migration + if err := mig.Start(ctx, &migrations.Migration{Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + t.Fatalf("Failed to start migration: %v", err) + } + + // Ensure that there is only one table in the schema - the no-op table + // created by the WithKickstartReplication option has been removed + if count := tableCount(t, db, "public"); count != 1 { + t.Errorf("Expected 1 table in schema %q, got %d", "public", count) + } + }) +} + func createTableOp(tableName string) *migrations.OpCreateTable { return &migrations.OpCreateTable{ Name: tableName, @@ -619,6 +625,38 @@ func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[s return res } +func schemaExists(t *testing.T, db *sql.DB, schema string) bool { + t.Helper() + var exists bool + err := db.QueryRow(` + SELECT EXISTS( + SELECT 1 + FROM pg_catalog.pg_namespace + WHERE nspname = $1 + )`, schema).Scan(&exists) + if err != nil { + t.Fatal(err) + } + return exists +} + +func tableCount(t *testing.T, db *sql.DB, schema string) int { + t.Helper() + + var count int + err := db.QueryRow(` + SELECT COUNT(*) + FROM pg_catalog.pg_tables + WHERE schemaname = $1 + `, + schema).Scan(&count) + if err != nil { + t.Fatal(err) + } + + return count +} + func ptr[T any](v T) *T { return &v }