Skip to content

Commit

Permalink
Fix inferred migrations format (#259)
Browse files Browse the repository at this point in the history
Inferred migrations do not follow the same format as all other stored
migrations. From the schema history point of view, all migrations in the
list should apply correctly, including these.

This change ensures we store them following the same format, so
replaying a migration history is possible.
  • Loading branch information
exekias committed Feb 1, 2024
1 parent c5a3dbd commit 7e65cda
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 12 deletions.
18 changes: 16 additions & 2 deletions pkg/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ SECURITY DEFINER
SET search_path = %[1]s, pg_catalog, pg_temp AS $$
DECLARE
schemaname TEXT;
migration_id TEXT;
BEGIN
-- Ignore migrations done by pgroll
IF (pg_catalog.current_setting('pgroll.internal', 'TRUE') <> 'TRUE') THEN
Expand Down Expand Up @@ -290,11 +291,24 @@ BEGIN
END IF;
-- Someone did a schema change without pgroll, include it in the history
SELECT INTO migration_id pg_catalog.format('sql_%%s',pg_catalog.substr(pg_catalog.md5(pg_catalog.random()::text), 0, 15));
INSERT INTO %[1]s.migrations (schema, name, migration, resulting_schema, done, parent, migration_type)
VALUES (
schemaname,
pg_catalog.format('sql_%%s',pg_catalog.substr(pg_catalog.md5(pg_catalog.random()::text), 0, 15)),
pg_catalog.json_build_object('sql', pg_catalog.json_build_object('up', pg_catalog.current_query())),
migration_id,
pg_catalog.json_build_object(
'name', migration_id,
'operations', (
SELECT pg_catalog.json_agg(
pg_catalog.json_build_object(
'sql', pg_catalog.json_build_object(
'up', pg_catalog.current_query()
)
)
)
)
),
%[1]s.read_schema(schemaname),
true,
%[1]s.latest_version(schemaname),
Expand Down
73 changes: 63 additions & 10 deletions pkg/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ package state_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/schema"
Expand All @@ -31,11 +35,6 @@ func TestSchemaOptionIsRespected(t *testing.T) {
t.Fatal(err)
}

// init the state
if err := state.Init(ctx); err != nil {
t.Fatal(err)
}

// check that starting a new migration returns the already existing table
currentSchema, err := state.Start(ctx, "public", &migrations.Migration{
Name: "1_add_column",
Expand All @@ -56,6 +55,65 @@ func TestSchemaOptionIsRespected(t *testing.T) {
})
}

func TestInferredMigration(t *testing.T) {
t.Parallel()

testutils.WithStateAndConnectionToContainer(t, func(state *state.State, db *sql.DB) {
ctx := context.Background()

tests := []struct {
name string
sqlStmt string
wantMigration migrations.Migration
}{
{
name: "create table",
sqlStmt: "CREATE TABLE public.table1 (id int)",
wantMigration: migrations.Migration{
Operations: migrations.Operations{
&migrations.OpRawSQL{
Up: "CREATE TABLE public.table1 (id int)",
},
},
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := db.ExecContext(ctx, "DROP SCHEMA public CASCADE; CREATE SCHEMA public"); err != nil {
t.Fatal(err)
}

if _, err := db.ExecContext(ctx, tt.sqlStmt); err != nil {
t.Fatal(err)
}

var migrationStr []byte
err := db.QueryRowContext(ctx,
fmt.Sprintf("SELECT migration FROM %s.migrations WHERE schema=$1", pq.QuoteIdentifier(state.Schema())), "public").
Scan(&migrationStr)
if err != nil {
t.Fatal(err)
}

var gotMigration migrations.Migration
if err := json.Unmarshal(migrationStr, &gotMigration); err != nil {
t.Fatal(err)
}

// test there is a name for the migration, then remove it for the comparison
assert.True(t, strings.HasPrefix(gotMigration.Name, "sql_") && len(gotMigration.Name) > 10)
gotMigration.Name = ""

if diff := cmp.Diff(tt.wantMigration, gotMigration); diff != "" {
t.Errorf("expected schema mismatch (-want +got):\n%s", diff)
}
})
}
})
}

func TestReadSchema(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -314,11 +372,6 @@ func TestReadSchema(t *testing.T) {
},
}

// init the state
if err := state.Init(ctx); err != nil {
t.Fatal(err)
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if _, err := db.ExecContext(ctx, "DROP SCHEMA public CASCADE; CREATE SCHEMA public"); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions pkg/testutils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ func WithStateAndConnectionToContainer(t *testing.T, fn func(*state.State, *sql.
}
})

// init the state
if err := st.Init(ctx); err != nil {
t.Fatal(err)
}

fn(st, db)
}

Expand Down

0 comments on commit 7e65cda

Please sign in to comment.