From 162bd06f9f3074e7dcc898a3ce045e80312714fc Mon Sep 17 00:00:00 2001 From: Andrew Farries Date: Wed, 6 Mar 2024 09:56:52 +0000 Subject: [PATCH] Fix duplicate inferred migrations when dropping columns outside of a migration (#305) Ensure that only one `inferred` migration is created in the `pgroll.migrations` table when a column is dropped outside of a migration. From the Postgres [docs](https://www.postgresql.org/docs/current/event-trigger-definition.html): > The sql_drop event occurs just before the ddl_command_end event trigger for any operation that drops database objects This means that when the `raw_migration` function is run in response to `sql_drop` and `ddl_command_end`, duplicate entries will be created in `pgroll.migrations`; once as the function is run for `sql_drop` and again when it's run for `ddl_command_end`. Change the definition of the `pg_roll_handle_drop` event trigger to only run on those kinds of drops that won't result in duplicates when the `pg_roll_handle_ddl` trigger runs for the same change. `DROP TABLE` and `DROP VIEW` won't result in duplicate migrations because their schema can't be inferred by the `ddl_command_event` trigger because the object has already been dropped when the trigger runs. Update the inferred migration tests with two new testcases covering dropping tables and columns. Fixes #304 --- pkg/state/state.go | 3 +- pkg/state/state_test.go | 186 ++++++++++++++++++++++++++++++++++------ 2 files changed, 163 insertions(+), 26 deletions(-) diff --git a/pkg/state/state.go b/pkg/state/state.go index d5db99ec..2b2d7eec 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -268,7 +268,7 @@ BEGIN RETURN; END IF; - IF tg_event = 'sql_drop' THEN + IF tg_event = 'sql_drop' and tg_tag != 'ALTER TABLE' THEN -- Guess the schema from drop commands SELECT schema_name INTO schemaname FROM pg_catalog.pg_event_trigger_dropped_objects() WHERE schema_name IS NOT NULL; @@ -324,7 +324,6 @@ CREATE EVENT TRIGGER pg_roll_handle_ddl ON ddl_command_end DROP EVENT TRIGGER IF EXISTS pg_roll_handle_drop; CREATE EVENT TRIGGER pg_roll_handle_drop ON sql_drop EXECUTE FUNCTION %[1]s.raw_migration(); - ` type State struct { diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 3b9a6023..4c567243 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -13,7 +13,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/lib/pq" "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/schema" @@ -63,17 +62,137 @@ func TestInferredMigration(t *testing.T) { ctx := context.Background() tests := []struct { - name string - sqlStmt string - wantMigration migrations.Migration + name string + sqlStmts []string + wantMigrations []migrations.Migration }{ { - name: "create table", - sqlStmt: "CREATE TABLE public.table1 (id int)", - wantMigration: migrations.Migration{ - Operations: migrations.Operations{ - &migrations.OpRawSQL{ - Up: "CREATE TABLE public.table1 (id int)", + name: "create table", + sqlStmts: []string{"CREATE TABLE public.table1 (id int)"}, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE TABLE public.table1 (id int)"}, + }, + }, + }, + }, + { + name: "create/drop table", + sqlStmts: []string{ + "CREATE TABLE table1 (id int)", + "DROP TABLE table1", + }, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE TABLE table1 (id int)"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "DROP TABLE table1"}, + }, + }, + }, + }, + { + name: "create/drop column", + sqlStmts: []string{ + "CREATE TABLE table1 (id int, b text)", + "ALTER TABLE table1 DROP COLUMN b", + }, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE TABLE table1 (id int, b text)"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "ALTER TABLE table1 DROP COLUMN b"}, + }, + }, + }, + }, + { + name: "create/drop check constraint", + sqlStmts: []string{ + "CREATE TABLE table1 (id int, age integer, CONSTRAINT check_age CHECK (age > 0))", + "ALTER TABLE table1 DROP CONSTRAINT check_age", + }, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE TABLE table1 (id int, age integer, CONSTRAINT check_age CHECK (age > 0))"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "ALTER TABLE table1 DROP CONSTRAINT check_age"}, + }, + }, + }, + }, + { + name: "create/drop unique constraint", + sqlStmts: []string{ + "CREATE TABLE table1 (id int, b text, CONSTRAINT unique_b UNIQUE(b))", + "ALTER TABLE table1 DROP CONSTRAINT unique_b", + }, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE TABLE table1 (id int, b text, CONSTRAINT unique_b UNIQUE(b))"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "ALTER TABLE table1 DROP CONSTRAINT unique_b"}, + }, + }, + }, + }, + { + name: "create/drop index", + sqlStmts: []string{ + "CREATE TABLE table1 (id int, b text)", + "CREATE INDEX idx_b ON table1(b)", + "DROP INDEX idx_b", + }, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE TABLE table1 (id int, b text)"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE INDEX idx_b ON table1(b)"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "DROP INDEX idx_b"}, + }, + }, + }, + }, + { + name: "create/drop function", + sqlStmts: []string{ + "CREATE FUNCTION foo() RETURNS void AS $$ BEGIN END; $$ LANGUAGE plpgsql", + "DROP FUNCTION foo", + }, + wantMigrations: []migrations.Migration{ + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "CREATE FUNCTION foo() RETURNS void AS $$ BEGIN END; $$ LANGUAGE plpgsql"}, + }, + }, + { + Operations: migrations.Operations{ + &migrations.OpRawSQL{Up: "DROP FUNCTION foo"}, }, }, }, @@ -86,29 +205,48 @@ func TestInferredMigration(t *testing.T) { t.Fatal(err) } - if _, err := db.ExecContext(ctx, tt.sqlStmt); err != nil { + if _, err := db.ExecContext(ctx, fmt.Sprintf("TRUNCATE %s.migrations", state.Schema())); err != nil { t.Fatal(err) } - var migrationStr []byte - err := db.QueryRowContext(ctx, - fmt.Sprintf("SELECT migration FROM %s.migrations WHERE schema=$1", pq.QuoteIdentifier(state.Schema())), "public"). - Scan(&migrationStr) - if err != nil { - t.Fatal(err) + for _, stmt := range tt.sqlStmts { + if _, err := db.ExecContext(ctx, stmt); err != nil { + t.Fatal(err) + } } - var gotMigration migrations.Migration - if err := json.Unmarshal(migrationStr, &gotMigration); err != nil { + rows, err := db.QueryContext(ctx, + fmt.Sprintf("SELECT migration FROM %s.migrations ORDER BY created_at ASC", state.Schema())) + if err != nil { t.Fatal(err) } + defer rows.Close() + + var gotMigrations []migrations.Migration + for rows.Next() { + var migrationStr []byte + if err := rows.Scan(&migrationStr); err != nil { + t.Fatal(err) + } + var gotMigration migrations.Migration + if err := json.Unmarshal(migrationStr, &gotMigration); err != nil { + t.Fatal(err) + } + gotMigrations = append(gotMigrations, gotMigration) + } - // test there is a name for the migration, then remove it for the comparison - assert.True(t, strings.HasPrefix(gotMigration.Name, "sql_") && len(gotMigration.Name) > 10) - gotMigration.Name = "" + assert.Equal(t, len(tt.wantMigrations), len(gotMigrations), "unexpected number of migrations") - if diff := cmp.Diff(tt.wantMigration, gotMigration); diff != "" { - t.Errorf("expected schema mismatch (-want +got):\n%s", diff) + for i, wantMigration := range tt.wantMigrations { + gotMigration := gotMigrations[i] + + // test there is a name for the migration, then remove it for the comparison + assert.True(t, strings.HasPrefix(gotMigration.Name, "sql_") && len(gotMigration.Name) > 10) + gotMigration.Name = "" + + if diff := cmp.Diff(wantMigration, gotMigration); diff != "" { + t.Errorf("expected schema mismatch (-want +got):\n%s", diff) + } } }) }