From 845a699d9388fbbf2095d19865b0b55c64c84306 Mon Sep 17 00:00:00 2001 From: Qiao Han Date: Thu, 10 Aug 2023 22:03:32 +0800 Subject: [PATCH 1/4] feat: support resetting database to a specific version --- cmd/db.go | 4 ++- internal/db/reset/reset.go | 35 ++++++++++++++++++-------- internal/db/reset/reset_test.go | 16 ++++++------ internal/db/start/start.go | 2 +- internal/migration/apply/apply.go | 4 +-- internal/migration/apply/apply_test.go | 6 ++--- 6 files changed, 41 insertions(+), 26 deletions(-) diff --git a/cmd/db.go b/cmd/db.go index b5dbe1f93c..ed9d1a8902 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -175,6 +175,7 @@ var ( dbResetCmd = &cobra.Command{ Use: "reset", Short: "Resets the local database to current migrations", + Args: cobra.ExactArgs(0), RunE: func(cmd *cobra.Command, args []string) error { fsys := afero.NewOsFs() if linked || len(dbUrl) > 0 { @@ -183,7 +184,7 @@ var ( } } ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt) - return reset.Run(ctx, dbConfig, fsys) + return reset.Run(ctx, version, dbConfig, fsys) }, } @@ -283,6 +284,7 @@ func init() { // Build reset command resetFlags := dbResetCmd.Flags() resetFlags.BoolVar(&linked, "linked", false, "Resets the linked project to current migrations.") + resetFlags.StringVar(&version, "version", "", "Reset up to the specified version.") dbCmd.AddCommand(dbResetCmd) // Build lint command lintFlags := dbLintCmd.Flags() diff --git a/internal/db/reset/reset.go b/internal/db/reset/reset.go index 0a465803d6..b65bb76a18 100644 --- a/internal/db/reset/reset.go +++ b/internal/db/reset/reset.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "time" @@ -38,10 +39,14 @@ var ( dropObjects string ) -func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + if len(version) > 0 { + if _, err := strconv.Atoi(version); err != nil { + return repair.ErrInvalidVersion + } + } if len(config.Password) > 0 { - fmt.Fprintln(os.Stderr, "Resetting remote database...") - return resetRemote(ctx, config, fsys, options...) + return resetRemote(ctx, version, config, fsys, options...) } // Sanity checks. @@ -55,7 +60,7 @@ func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...fu } // Reset postgres database because extensions (pg_cron, pg_net) require postgres - if err := resetDatabase(ctx, fsys, options...); err != nil { + if err := resetDatabase(ctx, version, fsys, options...); err != nil { return err } @@ -64,8 +69,8 @@ func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...fu return nil } -func resetDatabase(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - fmt.Fprintln(os.Stderr, "Resetting local database...") +func resetDatabase(ctx context.Context, version string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + fmt.Fprintln(os.Stderr, "Resetting local database"+getMessage(version)) if err := recreateDatabase(ctx, options...); err != nil { return err } @@ -85,7 +90,14 @@ func resetDatabase(ctx context.Context, fsys afero.Fs, options ...func(*pgx.Conn return err } defer conn.Close(context.Background()) - return InitialiseDatabase(ctx, conn, fsys) + return InitialiseDatabase(ctx, version, conn, fsys) +} + +func getMessage(version string) string { + if len(version) > 0 { + return " to version: " + version + } + return "..." } func initDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error { @@ -97,8 +109,8 @@ func initDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error { return apply.BatchExecDDL(ctx, conn, strings.NewReader(utils.InitialSchemaSql)) } -func InitialiseDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error { - if err := apply.MigrateDatabase(ctx, conn, fsys); err != nil { +func InitialiseDatabase(ctx context.Context, version string, conn *pgx.Conn, fsys afero.Fs) error { + if err := apply.MigrateDatabase(ctx, version, conn, fsys); err != nil { return err } return SeedDatabase(ctx, conn, fsys) @@ -225,7 +237,8 @@ func WaitForServiceReady(ctx context.Context, started []string) error { return nil } -func resetRemote(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func resetRemote(ctx context.Context, version string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + fmt.Fprintln(os.Stderr, "Resetting remote database"+getMessage(version)) conn, err := utils.ConnectRemotePostgres(ctx, config, options...) if err != nil { return err @@ -248,7 +261,7 @@ func resetRemote(ctx context.Context, config pgconn.Config, fsys afero.Fs, optio if err := migration.ExecBatch(ctx, conn); err != nil { return err } - return InitialiseDatabase(ctx, conn, fsys) + return InitialiseDatabase(ctx, version, conn, fsys) } func ListSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]string, error) { diff --git a/internal/db/reset/reset_test.go b/internal/db/reset/reset_test.go index 069db9b11d..c9439ecad3 100644 --- a/internal/db/reset/reset_test.go +++ b/internal/db/reset/reset_test.go @@ -29,13 +29,13 @@ func TestResetCommand(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() // Run test - err := Run(context.Background(), pgconn.Config{Password: "postgres"}, fsys) + err := Run(context.Background(), "", pgconn.Config{Password: "postgres"}, fsys) // Check error assert.ErrorContains(t, err, "invalid port (outside range)") }) t.Run("throws error on missing config", func(t *testing.T) { - err := Run(context.Background(), pgconn.Config{}, afero.NewMemMapFs()) + err := Run(context.Background(), "", pgconn.Config{}, afero.NewMemMapFs()) assert.ErrorIs(t, err, os.ErrNotExist) }) @@ -50,7 +50,7 @@ func TestResetCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/containers"). Reply(http.StatusServiceUnavailable) // Run test - err := Run(context.Background(), pgconn.Config{}, fsys) + err := Run(context.Background(), "", pgconn.Config{}, fsys) // Check error assert.ErrorIs(t, err, utils.ErrNotRunning) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -73,7 +73,7 @@ func TestResetCommand(t *testing.T) { conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;"). ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`) // Run test - err := Run(context.Background(), pgconn.Config{}, fsys, conn.Intercept) + err := Run(context.Background(), "", pgconn.Config{}, fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, "ERROR: cannot disallow connections for current database (SQLSTATE 22023)") assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -371,7 +371,7 @@ func TestResetRemote(t *testing.T) { Query(dropObjects). Reply("INSERT 0") // Run test - err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept) + err := resetRemote(context.Background(), "", dbConfig, fsys, conn.Intercept) // Check error assert.NoError(t, err) }) @@ -380,7 +380,7 @@ func TestResetRemote(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() // Run test - err := resetRemote(context.Background(), pgconn.Config{}, fsys) + err := resetRemote(context.Background(), "", pgconn.Config{}, fsys) // Check error assert.ErrorContains(t, err, "invalid port (outside range)") }) @@ -394,7 +394,7 @@ func TestResetRemote(t *testing.T) { conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")). ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation information_schema") // Run test - err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept) + err := resetRemote(context.Background(), "", dbConfig, fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, "ERROR: permission denied for relation information_schema (SQLSTATE 42501)") }) @@ -411,7 +411,7 @@ func TestResetRemote(t *testing.T) { ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation supabase_migrations"). Query(dropObjects) // Run test - err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept) + err := resetRemote(context.Background(), "", dbConfig, fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, "ERROR: permission denied for relation supabase_migrations (SQLSTATE 42501)") }) diff --git a/internal/db/start/start.go b/internal/db/start/start.go index e44ccd7df9..ed21389e8d 100644 --- a/internal/db/start/start.go +++ b/internal/db/start/start.go @@ -163,7 +163,7 @@ func setupDatabase(ctx context.Context, fsys afero.Fs, w io.Writer, options ...f if err := SetupDatabase(ctx, conn, utils.DbId, w, fsys); err != nil { return err } - return reset.InitialiseDatabase(ctx, conn, fsys) + return reset.InitialiseDatabase(ctx, "", conn, fsys) } func SetupDatabase(ctx context.Context, conn *pgx.Conn, host string, w io.Writer, fsys afero.Fs) error { diff --git a/internal/migration/apply/apply.go b/internal/migration/apply/apply.go index 3d6821fd47..f8cff7289d 100644 --- a/internal/migration/apply/apply.go +++ b/internal/migration/apply/apply.go @@ -14,8 +14,8 @@ import ( "github.com/supabase/cli/internal/utils" ) -func MigrateDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error { - migrations, err := list.LoadLocalMigrations(fsys) +func MigrateDatabase(ctx context.Context, version string, conn *pgx.Conn, fsys afero.Fs) error { + migrations, err := list.LoadPartialMigrations(version, fsys) if err != nil { return err } diff --git a/internal/migration/apply/apply_test.go b/internal/migration/apply/apply_test.go index 1e29a6104a..8c574d22ad 100644 --- a/internal/migration/apply/apply_test.go +++ b/internal/migration/apply/apply_test.go @@ -45,20 +45,20 @@ func TestMigrateDatabase(t *testing.T) { require.NoError(t, err) defer mock.Close(ctx) // Run test - err = MigrateDatabase(ctx, mock, fsys) + err = MigrateDatabase(ctx, "", mock, fsys) // Check error assert.NoError(t, err) }) t.Run("ignores empty local directory", func(t *testing.T) { - assert.NoError(t, MigrateDatabase(context.Background(), nil, afero.NewMemMapFs())) + assert.NoError(t, MigrateDatabase(context.Background(), "", nil, afero.NewMemMapFs())) }) t.Run("throws error on write failure", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() // Run test - err := MigrateDatabase(context.Background(), nil, afero.NewReadOnlyFs(fsys)) + err := MigrateDatabase(context.Background(), "", nil, afero.NewReadOnlyFs(fsys)) // Check error assert.ErrorIs(t, err, os.ErrPermission) }) From 2882a681516171598e8af5e5427f204d4426b094 Mon Sep 17 00:00:00 2001 From: Qiao Han Date: Fri, 11 Aug 2023 11:24:04 +0800 Subject: [PATCH 2/4] fix: validate migration version exists --- internal/db/reset/reset.go | 3 +++ internal/migration/repair/repair.go | 20 ++++++++++++++------ internal/migration/squash/squash.go | 3 +++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/internal/db/reset/reset.go b/internal/db/reset/reset.go index b65bb76a18..183502519f 100644 --- a/internal/db/reset/reset.go +++ b/internal/db/reset/reset.go @@ -44,6 +44,9 @@ func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.F if _, err := strconv.Atoi(version); err != nil { return repair.ErrInvalidVersion } + if _, err := repair.GetMigrationFile(version, fsys); err != nil { + return err + } } if len(config.Password) > 0 { return resetRemote(ctx, version, config, fsys, options...) diff --git a/internal/migration/repair/repair.go b/internal/migration/repair/repair.go index cb53ab8ca8..aded29dad3 100644 --- a/internal/migration/repair/repair.go +++ b/internal/migration/repair/repair.go @@ -111,6 +111,18 @@ func DeleteVersionSQL(batch *pgconn.Batch, version string) { ) } +func GetMigrationFile(version string, fsys afero.Fs) (string, error) { + path := filepath.Join(utils.MigrationsDir, version+"_*.sql") + matches, err := afero.Glob(fsys, path) + if err != nil { + return "", err + } + if len(matches) == 0 { + return "", fmt.Errorf("glob %s: %w", path, os.ErrNotExist) + } + return matches[0], nil +} + type MigrationFile struct { Lines []string Version string @@ -118,15 +130,11 @@ type MigrationFile struct { } func NewMigrationFromVersion(version string, fsys afero.Fs) (*MigrationFile, error) { - path := filepath.Join(utils.MigrationsDir, version+"_*.sql") - matches, err := afero.Glob(fsys, path) + name, err := GetMigrationFile(version, fsys) if err != nil { return nil, err } - if len(matches) == 0 { - return nil, fmt.Errorf("glob %s: %w", path, os.ErrNotExist) - } - return NewMigrationFromFile(matches[0], fsys) + return NewMigrationFromFile(name, fsys) } func NewMigrationFromFile(path string, fsys afero.Fs) (*MigrationFile, error) { diff --git a/internal/migration/squash/squash.go b/internal/migration/squash/squash.go index 7c07d9cb57..d095de0772 100644 --- a/internal/migration/squash/squash.go +++ b/internal/migration/squash/squash.go @@ -25,6 +25,9 @@ func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.F if _, err := strconv.Atoi(version); err != nil { return repair.ErrInvalidVersion } + if _, err := repair.GetMigrationFile(version, fsys); err != nil { + return err + } } if err := utils.LoadConfigFS(fsys); err != nil { return err From 968117d81323def79d7fa6d8903e08ae8193c82f Mon Sep 17 00:00:00 2001 From: Qiao Han Date: Fri, 11 Aug 2023 11:41:24 +0800 Subject: [PATCH 3/4] chore: refactor migrate and seed function --- internal/db/push/push.go | 3 +- internal/db/reset/reset.go | 29 ++---------- internal/db/reset/reset_test.go | 58 ----------------------- internal/db/start/start.go | 2 +- internal/migration/apply/apply.go | 20 +++++++- internal/migration/apply/apply_test.go | 64 ++++++++++++++++++++++++-- 6 files changed, 86 insertions(+), 90 deletions(-) diff --git a/internal/db/push/push.go b/internal/db/push/push.go index 879fbec6d1..ef2f6c08b3 100644 --- a/internal/db/push/push.go +++ b/internal/db/push/push.go @@ -10,7 +10,6 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/spf13/afero" - "github.com/supabase/cli/internal/db/reset" "github.com/supabase/cli/internal/migration/apply" "github.com/supabase/cli/internal/migration/up" "github.com/supabase/cli/internal/utils" @@ -51,7 +50,7 @@ func Run(ctx context.Context, dryRun, ignoreVersionMismatch bool, includeRoles, } // Seed database if !dryRun && includeSeed { - if err := reset.SeedDatabase(ctx, conn, fsys); err != nil { + if err := apply.SeedDatabase(ctx, conn, fsys); err != nil { return err } } diff --git a/internal/db/reset/reset.go b/internal/db/reset/reset.go index 183502519f..f88d1222e0 100644 --- a/internal/db/reset/reset.go +++ b/internal/db/reset/reset.go @@ -73,7 +73,7 @@ func Run(ctx context.Context, version string, config pgconn.Config, fsys afero.F } func resetDatabase(ctx context.Context, version string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - fmt.Fprintln(os.Stderr, "Resetting local database"+getMessage(version)) + fmt.Fprintln(os.Stderr, "Resetting local database"+toLogMessage(version)) if err := recreateDatabase(ctx, options...); err != nil { return err } @@ -93,10 +93,10 @@ func resetDatabase(ctx context.Context, version string, fsys afero.Fs, options . return err } defer conn.Close(context.Background()) - return InitialiseDatabase(ctx, version, conn, fsys) + return apply.MigrateAndSeed(ctx, version, conn, fsys) } -func getMessage(version string) string { +func toLogMessage(version string) string { if len(version) > 0 { return " to version: " + version } @@ -112,13 +112,6 @@ func initDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error { return apply.BatchExecDDL(ctx, conn, strings.NewReader(utils.InitialSchemaSql)) } -func InitialiseDatabase(ctx context.Context, version string, conn *pgx.Conn, fsys afero.Fs) error { - if err := apply.MigrateDatabase(ctx, version, conn, fsys); err != nil { - return err - } - return SeedDatabase(ctx, conn, fsys) -} - // Recreate postgres database by connecting to template1 func recreateDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) error { conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{User: "supabase_admin", Database: "template1"}, options...) @@ -139,18 +132,6 @@ func recreateDatabase(ctx context.Context, options ...func(*pgx.ConnConfig)) err return sql.ExecBatch(ctx, conn) } -func SeedDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error { - seed, err := repair.NewMigrationFromFile(utils.SeedDataPath, fsys) - if errors.Is(err, os.ErrNotExist) { - return nil - } else if err != nil { - return err - } - fmt.Fprintln(os.Stderr, "Seeding data "+utils.Bold(utils.SeedDataPath)+"...") - // Batch seed commands, safe to use statement cache - return seed.ExecBatchWithCache(ctx, conn) -} - func DisconnectClients(ctx context.Context, conn *pgx.Conn) error { // Must be executed separately because running in transaction is unsupported disconn := "ALTER DATABASE postgres ALLOW_CONNECTIONS false;" @@ -241,7 +222,7 @@ func WaitForServiceReady(ctx context.Context, started []string) error { } func resetRemote(ctx context.Context, version string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - fmt.Fprintln(os.Stderr, "Resetting remote database"+getMessage(version)) + fmt.Fprintln(os.Stderr, "Resetting remote database"+toLogMessage(version)) conn, err := utils.ConnectRemotePostgres(ctx, config, options...) if err != nil { return err @@ -264,7 +245,7 @@ func resetRemote(ctx context.Context, version string, config pgconn.Config, fsys if err := migration.ExecBatch(ctx, conn); err != nil { return err } - return InitialiseDatabase(ctx, version, conn, fsys) + return apply.MigrateAndSeed(ctx, version, conn, fsys) } func ListSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]string, error) { diff --git a/internal/db/reset/reset_test.go b/internal/db/reset/reset_test.go index c9439ecad3..d16ae53c56 100644 --- a/internal/db/reset/reset_test.go +++ b/internal/db/reset/reset_test.go @@ -18,7 +18,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/testing/apitest" - "github.com/supabase/cli/internal/testing/fstest" "github.com/supabase/cli/internal/testing/pgtest" "github.com/supabase/cli/internal/utils" "gopkg.in/h2non/gock.v1" @@ -116,63 +115,6 @@ func TestInitDatabase(t *testing.T) { }) } -func TestSeedDatabase(t *testing.T) { - t.Run("seeds from file", func(t *testing.T) { - // Setup in-memory fs - fsys := afero.NewMemMapFs() - // Setup seed file - sql := "INSERT INTO employees(name) VALUES ('Alice')" - require.NoError(t, afero.WriteFile(fsys, utils.SeedDataPath, []byte(sql), 0644)) - // Setup mock postgres - conn := pgtest.NewConn() - defer conn.Close(t) - conn.Query(sql). - Reply("INSERT 0 1") - // Connect to mock - ctx := context.Background() - mock, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{Port: 5432}, conn.Intercept) - require.NoError(t, err) - defer mock.Close(ctx) - // Run test - assert.NoError(t, SeedDatabase(ctx, mock, fsys)) - }) - - t.Run("ignores missing seed", func(t *testing.T) { - assert.NoError(t, SeedDatabase(context.Background(), nil, afero.NewMemMapFs())) - }) - - t.Run("throws error on read failure", func(t *testing.T) { - // Setup in-memory fs - fsys := &fstest.OpenErrorFs{DenyPath: utils.SeedDataPath} - // Run test - err := SeedDatabase(context.Background(), nil, fsys) - // Check error - assert.ErrorIs(t, err, os.ErrPermission) - }) - - t.Run("throws error on insert failure", func(t *testing.T) { - // Setup in-memory fs - fsys := afero.NewMemMapFs() - // Setup seed file - sql := "INSERT INTO employees(name) VALUES ('Alice')" - require.NoError(t, afero.WriteFile(fsys, utils.SeedDataPath, []byte(sql), 0644)) - // Setup mock postgres - conn := pgtest.NewConn() - defer conn.Close(t) - conn.Query(sql). - ReplyError(pgerrcode.NotNullViolation, `null value in column "age" of relation "employees"`) - // Connect to mock - ctx := context.Background() - mock, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{Port: 5432}, conn.Intercept) - require.NoError(t, err) - defer mock.Close(ctx) - // Run test - err = SeedDatabase(ctx, mock, fsys) - // Check error - assert.ErrorContains(t, err, `ERROR: null value in column "age" of relation "employees" (SQLSTATE 23502)`) - }) -} - func TestRecreateDatabase(t *testing.T) { t.Run("resets postgres database", func(t *testing.T) { utils.Config.Db.Port = 54322 diff --git a/internal/db/start/start.go b/internal/db/start/start.go index ed21389e8d..229543cda1 100644 --- a/internal/db/start/start.go +++ b/internal/db/start/start.go @@ -163,7 +163,7 @@ func setupDatabase(ctx context.Context, fsys afero.Fs, w io.Writer, options ...f if err := SetupDatabase(ctx, conn, utils.DbId, w, fsys); err != nil { return err } - return reset.InitialiseDatabase(ctx, "", conn, fsys) + return apply.MigrateAndSeed(ctx, "", conn, fsys) } func SetupDatabase(ctx context.Context, conn *pgx.Conn, host string, w io.Writer, fsys afero.Fs) error { diff --git a/internal/migration/apply/apply.go b/internal/migration/apply/apply.go index f8cff7289d..3854421959 100644 --- a/internal/migration/apply/apply.go +++ b/internal/migration/apply/apply.go @@ -2,6 +2,7 @@ package apply import ( "context" + "errors" "fmt" "io" "os" @@ -14,12 +15,27 @@ import ( "github.com/supabase/cli/internal/utils" ) -func MigrateDatabase(ctx context.Context, version string, conn *pgx.Conn, fsys afero.Fs) error { +func MigrateAndSeed(ctx context.Context, version string, conn *pgx.Conn, fsys afero.Fs) error { migrations, err := list.LoadPartialMigrations(version, fsys) if err != nil { return err } - return MigrateUp(ctx, conn, migrations, fsys) + if err := MigrateUp(ctx, conn, migrations, fsys); err != nil { + return err + } + return SeedDatabase(ctx, conn, fsys) +} + +func SeedDatabase(ctx context.Context, conn *pgx.Conn, fsys afero.Fs) error { + seed, err := repair.NewMigrationFromFile(utils.SeedDataPath, fsys) + if errors.Is(err, os.ErrNotExist) { + return nil + } else if err != nil { + return err + } + fmt.Fprintln(os.Stderr, "Seeding data "+utils.Bold(utils.SeedDataPath)+"...") + // Batch seed commands, safe to use statement cache + return seed.ExecBatchWithCache(ctx, conn) } func MigrateUp(ctx context.Context, conn *pgx.Conn, pending []string, fsys afero.Fs) error { diff --git a/internal/migration/apply/apply_test.go b/internal/migration/apply/apply_test.go index 8c574d22ad..fda922563a 100644 --- a/internal/migration/apply/apply_test.go +++ b/internal/migration/apply/apply_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/migration/repair" + "github.com/supabase/cli/internal/testing/fstest" "github.com/supabase/cli/internal/testing/pgtest" "github.com/supabase/cli/internal/utils" ) @@ -45,25 +46,82 @@ func TestMigrateDatabase(t *testing.T) { require.NoError(t, err) defer mock.Close(ctx) // Run test - err = MigrateDatabase(ctx, "", mock, fsys) + err = MigrateAndSeed(ctx, "", mock, fsys) // Check error assert.NoError(t, err) }) t.Run("ignores empty local directory", func(t *testing.T) { - assert.NoError(t, MigrateDatabase(context.Background(), "", nil, afero.NewMemMapFs())) + assert.NoError(t, MigrateAndSeed(context.Background(), "", nil, afero.NewMemMapFs())) }) t.Run("throws error on write failure", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() // Run test - err := MigrateDatabase(context.Background(), "", nil, afero.NewReadOnlyFs(fsys)) + err := MigrateAndSeed(context.Background(), "", nil, afero.NewReadOnlyFs(fsys)) // Check error assert.ErrorIs(t, err, os.ErrPermission) }) } +func TestSeedDatabase(t *testing.T) { + t.Run("seeds from file", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup seed file + sql := "INSERT INTO employees(name) VALUES ('Alice')" + require.NoError(t, afero.WriteFile(fsys, utils.SeedDataPath, []byte(sql), 0644)) + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(sql). + Reply("INSERT 0 1") + // Connect to mock + ctx := context.Background() + mock, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{Port: 5432}, conn.Intercept) + require.NoError(t, err) + defer mock.Close(ctx) + // Run test + assert.NoError(t, SeedDatabase(ctx, mock, fsys)) + }) + + t.Run("ignores missing seed", func(t *testing.T) { + assert.NoError(t, SeedDatabase(context.Background(), nil, afero.NewMemMapFs())) + }) + + t.Run("throws error on read failure", func(t *testing.T) { + // Setup in-memory fs + fsys := &fstest.OpenErrorFs{DenyPath: utils.SeedDataPath} + // Run test + err := SeedDatabase(context.Background(), nil, fsys) + // Check error + assert.ErrorIs(t, err, os.ErrPermission) + }) + + t.Run("throws error on insert failure", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup seed file + sql := "INSERT INTO employees(name) VALUES ('Alice')" + require.NoError(t, afero.WriteFile(fsys, utils.SeedDataPath, []byte(sql), 0644)) + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(sql). + ReplyError(pgerrcode.NotNullViolation, `null value in column "age" of relation "employees"`) + // Connect to mock + ctx := context.Background() + mock, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{Port: 5432}, conn.Intercept) + require.NoError(t, err) + defer mock.Close(ctx) + // Run test + err = SeedDatabase(ctx, mock, fsys) + // Check error + assert.ErrorContains(t, err, `ERROR: null value in column "age" of relation "employees" (SQLSTATE 23502)`) + }) +} + func TestMigrateUp(t *testing.T) { t.Run("throws error on exec failure", func(t *testing.T) { // Setup in-memory fs From c94a11c0cbf7b87a2087f18d75274a53784c35c3 Mon Sep 17 00:00:00 2001 From: Qiao Han Date: Fri, 11 Aug 2023 11:44:56 +0800 Subject: [PATCH 4/4] chore: remove exact arg check for backwards compatibility --- cmd/db.go | 1 - 1 file changed, 1 deletion(-) diff --git a/cmd/db.go b/cmd/db.go index ed9d1a8902..6d32bf0e5b 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -175,7 +175,6 @@ var ( dbResetCmd = &cobra.Command{ Use: "reset", Short: "Resets the local database to current migrations", - Args: cobra.ExactArgs(0), RunE: func(cmd *cobra.Command, args []string) error { fsys := afero.NewOsFs() if linked || len(dbUrl) > 0 {