From 695db46ff92e4b5ee55b078acdf545df78dc516c Mon Sep 17 00:00:00 2001 From: Qiao Han Date: Tue, 16 May 2023 14:25:07 +0800 Subject: [PATCH] feat: support reseting remote database --- cmd/db.go | 12 ++++- internal/db/diff/migra.go | 27 +--------- internal/db/diff/migra_test.go | 3 +- internal/db/reset/reset.go | 69 ++++++++++++++++++++++-- internal/db/reset/reset_test.go | 78 ++++++++++++++++++++++++++-- internal/db/reset/templates/drop.sql | 35 +++++++++++++ 6 files changed, 190 insertions(+), 34 deletions(-) create mode 100644 internal/db/reset/templates/drop.sql diff --git a/cmd/db.go b/cmd/db.go index 057a88eab..0d481358d 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -166,8 +166,14 @@ var ( Use: "reset", Short: "Resets the local database to current migrations", RunE: func(cmd *cobra.Command, args []string) error { + fsys := afero.NewOsFs() + if linked || len(dbUrl) > 0 { + if err := parseDatabaseConfig(fsys); err != nil { + return err + } + } ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt) - return reset.Run(ctx, afero.NewOsFs()) + return reset.Run(ctx, dbConfig, fsys) }, } @@ -187,7 +193,7 @@ var ( } } ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt) - return lint.Run(ctx, schema, level.Value, dbConfig, afero.NewOsFs()) + return lint.Run(ctx, schema, level.Value, dbConfig, fsys) }, } @@ -255,6 +261,8 @@ func init() { dbRemoteCmd.AddCommand(dbRemoteCommitCmd) dbCmd.AddCommand(dbRemoteCmd) // Build reset command + resetFlags := dbResetCmd.Flags() + resetFlags.BoolVar(&linked, "linked", false, "Resets the linked project to current migrations.") dbCmd.AddCommand(dbResetCmd) // Build lint command lintFlags := dbLintCmd.Flags() diff --git a/internal/db/diff/migra.go b/internal/db/diff/migra.go index a1663afc9..961067982 100644 --- a/internal/db/diff/migra.go +++ b/internal/db/diff/migra.go @@ -17,13 +17,12 @@ import ( "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" "github.com/spf13/afero" + "github.com/supabase/cli/internal/db/reset" "github.com/supabase/cli/internal/db/start" "github.com/supabase/cli/internal/migration/apply" "github.com/supabase/cli/internal/utils" ) -const LIST_SCHEMAS = "SELECT schema_name FROM information_schema.schemata WHERE NOT schema_name LIKE ANY($1) ORDER BY schema_name" - var ( //go:embed templates/migra.sh diffSchemaScript string @@ -94,29 +93,7 @@ func LoadUserSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([] "supabase_migrations", }, utils.SystemSchemas...) } - exclude = likeEscapeSchema(exclude) - rows, err := conn.Query(ctx, LIST_SCHEMAS, exclude) - if err != nil { - return nil, err - } - schemas := []string{} - for rows.Next() { - var name string - if err := rows.Scan(&name); err != nil { - return nil, err - } - schemas = append(schemas, name) - } - return schemas, nil -} - -func likeEscapeSchema(schemas []string) (result []string) { - // Treat _ as literal, * as any character - replacer := strings.NewReplacer("_", `\_`, "*", "%") - for _, sch := range schemas { - result = append(result, replacer.Replace(sch)) - } - return result + return reset.ListSchemas(ctx, conn, exclude...) } func CreateShadowDatabase(ctx context.Context) (string, error) { diff --git a/internal/db/diff/migra_test.go b/internal/db/diff/migra_test.go index a0d785f29..f41c4546a 100644 --- a/internal/db/diff/migra_test.go +++ b/internal/db/diff/migra_test.go @@ -15,6 +15,7 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/supabase/cli/internal/db/reset" "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/testing/pgtest" "github.com/supabase/cli/internal/utils" @@ -267,7 +268,7 @@ func TestUserSchema(t *testing.T) { // Setup mock postgres conn := pgtest.NewConn() defer conn.Close(t) - conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public}'")). + conn.Query(strings.ReplaceAll(reset.LIST_SCHEMAS, "$1", "'{public}'")). Reply("SELECT 1", []interface{}{"test"}) // Connect to mock ctx := context.Background() diff --git a/internal/db/reset/reset.go b/internal/db/reset/reset.go index 7c711af58..6e3ec2282 100644 --- a/internal/db/reset/reset.go +++ b/internal/db/reset/reset.go @@ -2,6 +2,7 @@ package reset import ( "context" + _ "embed" "errors" "fmt" "io" @@ -19,13 +20,23 @@ import ( "github.com/supabase/cli/internal/utils" ) -const SET_POSTGRES_ROLE = "SET ROLE postgres;" +const ( + SET_POSTGRES_ROLE = "SET ROLE postgres;" + LIST_SCHEMAS = "SELECT schema_name FROM information_schema.schemata WHERE NOT schema_name LIKE ANY($1) ORDER BY schema_name" +) var ( healthTimeout = 5 * time.Second + //go:embed templates/drop.sql + dropObjects string ) -func Run(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + if len(config.Password) > 0 { + fmt.Fprintln(os.Stderr, "Resetting remote database...") + return resetRemote(ctx, config, fsys, options...) + } + // Sanity checks. { if err := utils.LoadConfigFS(fsys); err != nil { @@ -52,7 +63,7 @@ func Run(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) e } func resetDatabase(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - fmt.Fprintln(os.Stderr, "Resetting database...") + fmt.Fprintln(os.Stderr, "Resetting local database...") if err := RecreateDatabase(ctx, options...); err != nil { return err } @@ -180,3 +191,55 @@ func WaitForHealthyService(ctx context.Context, container string, timeout time.D } return RetryEverySecond(ctx, probe, timeout) } + +func resetRemote(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { + conn, err := utils.ConnectRemotePostgres(ctx, config, options...) + if err != nil { + return err + } + defer conn.Close(context.Background()) + // List user defined schemas + excludes := append([]string{"public"}, utils.InternalSchemas...) + userSchemas, err := ListSchemas(ctx, conn, excludes...) + if err != nil { + return err + } + userSchemas = append(userSchemas, "supabase_migrations") + // Drop user defined objects + migration := apply.MigrationFile{} + for _, schema := range userSchemas { + sql := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema) + migration.Lines = append(migration.Lines, sql) + } + migration.Lines = append(migration.Lines, dropObjects) + if err := migration.ExecBatch(ctx, conn); err != nil { + return err + } + return InitialiseDatabase(ctx, conn, fsys) +} + +func ListSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]string, error) { + exclude = likeEscapeSchema(exclude) + rows, err := conn.Query(ctx, LIST_SCHEMAS, exclude) + if err != nil { + return nil, err + } + schemas := []string{} + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, err + } + schemas = append(schemas, name) + } + return schemas, nil +} + +func likeEscapeSchema(schemas []string) (result []string) { + // Treat _ as literal, * as any character + replacer := strings.NewReplacer("_", `\_`, "*", "%") + for _, sch := range schemas { + result = append(result, replacer.Replace(sch)) + } + return result +} diff --git a/internal/db/reset/reset_test.go b/internal/db/reset/reset_test.go index c83c1044e..b3d0862b3 100644 --- a/internal/db/reset/reset_test.go +++ b/internal/db/reset/reset_test.go @@ -7,11 +7,13 @@ import ( "io" "net/http" "os" + "strings" "testing" "time" "github.com/docker/docker/api/types" "github.com/docker/docker/client" + "github.com/jackc/pgconn" "github.com/jackc/pgerrcode" "github.com/spf13/afero" "github.com/stretchr/testify/assert" @@ -25,7 +27,7 @@ import ( func TestResetCommand(t *testing.T) { t.Run("throws error on missing config", func(t *testing.T) { - err := Run(context.Background(), afero.NewMemMapFs()) + err := Run(context.Background(), pgconn.Config{}, afero.NewMemMapFs()) assert.ErrorIs(t, err, os.ErrNotExist) }) @@ -40,7 +42,7 @@ func TestResetCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/containers"). Reply(http.StatusServiceUnavailable) // Run test - err := Run(context.Background(), fsys) + err := Run(context.Background(), pgconn.Config{}, fsys) // Check error assert.ErrorIs(t, err, utils.ErrNotRunning) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -63,7 +65,7 @@ func TestResetCommand(t *testing.T) { conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;"). ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`) // Run test - err := Run(context.Background(), fsys, conn.Intercept) + err := Run(context.Background(), pgconn.Config{}, fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, "ERROR: cannot disallow connections for current database (SQLSTATE 22023)") assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -354,3 +356,73 @@ func TestRestartDatabase(t *testing.T) { assert.Empty(t, apitest.ListUnmatchedRequests()) }) } + +func TestResetRemote(t *testing.T) { + dbConfig := pgconn.Config{ + Host: "localhost", + Port: 5432, + User: "admin", + Password: "password", + Database: "postgres", + } + + t.Run("resets remote database", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")). + Reply("SELECT 1", []interface{}{"private"}). + Query("DROP SCHEMA IF EXISTS private CASCADE"). + Reply("DROP SCHEMA"). + Query("DROP SCHEMA IF EXISTS supabase_migrations CASCADE"). + Reply("DROP SCHEMA"). + Query(dropObjects). + Reply("INSERT 0") + // Run test + err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept) + // Check error + assert.NoError(t, err) + }) + + t.Run("throws error on connect failure", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Run test + err := resetRemote(context.Background(), pgconn.Config{}, fsys) + // Check error + assert.ErrorContains(t, err, "invalid port (outside range)") + }) + + t.Run("throws error on list schema failure", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")). + ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation information_schema") + // Run test + err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept) + // Check error + assert.ErrorContains(t, err, "ERROR: permission denied for relation information_schema (SQLSTATE 42501)") + }) + + t.Run("throws error on drop schema failure", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")). + Reply("SELECT 0"). + Query("DROP SCHEMA IF EXISTS supabase_migrations CASCADE"). + ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation supabase_migrations"). + Query(dropObjects) + // Run test + err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept) + // Check error + assert.ErrorContains(t, err, "ERROR: permission denied for relation supabase_migrations (SQLSTATE 42501)") + }) +} diff --git a/internal/db/reset/templates/drop.sql b/internal/db/reset/templates/drop.sql new file mode 100644 index 000000000..b94096535 --- /dev/null +++ b/internal/db/reset/templates/drop.sql @@ -0,0 +1,35 @@ +do $$ declare + rec record; +begin + -- functions + for rec in + select * + from pg_proc p + where p.pronamespace::regnamespace::name = 'public' + loop + -- supports aggregate, function, and procedure + execute format('drop routine if exists %I.%I(%s) cascade', rec.pronamespace::regnamespace::name, rec.proname, pg_catalog.pg_get_function_identity_arguments(rec.oid)); + end loop; + + -- in order: tables (cascade to views), sequences + for rec in + select * + from pg_class c + where + c.relnamespace::regnamespace::name = 'public' + and c.relkind not in ('c', 'v', 'm') + order by c.relkind desc + loop + -- supports all kinds of relations, except views and complex types + execute format('drop table if exists %I.%I cascade', rec.relnamespace::regnamespace::name, rec.relname); + end loop; + + -- types + for rec in + select * + from pg_type t + where t.typnamespace::regnamespace::name = 'public' + loop + execute format('drop type if exists %I.%I cascade', rec.typnamespace::regnamespace::name, rec.typname); + end loop; +end $$;