Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cmd/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,14 @@ var (
Use: "reset",
Short: "Resets the local database to current migrations",
RunE: func(cmd *cobra.Command, args []string) error {
fsys := afero.NewOsFs()
if linked || len(dbUrl) > 0 {
if err := parseDatabaseConfig(fsys); err != nil {
return err
}
}
ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt)
return reset.Run(ctx, afero.NewOsFs())
return reset.Run(ctx, dbConfig, fsys)
},
}

Expand All @@ -187,7 +193,7 @@ var (
}
}
ctx, _ := signal.NotifyContext(cmd.Context(), os.Interrupt)
return lint.Run(ctx, schema, level.Value, dbConfig, afero.NewOsFs())
return lint.Run(ctx, schema, level.Value, dbConfig, fsys)
},
}

Expand Down Expand Up @@ -255,6 +261,8 @@ func init() {
dbRemoteCmd.AddCommand(dbRemoteCommitCmd)
dbCmd.AddCommand(dbRemoteCmd)
// Build reset command
resetFlags := dbResetCmd.Flags()
resetFlags.BoolVar(&linked, "linked", false, "Resets the linked project to current migrations.")
dbCmd.AddCommand(dbResetCmd)
// Build lint command
lintFlags := dbLintCmd.Flags()
Expand Down
27 changes: 2 additions & 25 deletions internal/db/diff/migra.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@ import (
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/spf13/afero"
"github.com/supabase/cli/internal/db/reset"
"github.com/supabase/cli/internal/db/start"
"github.com/supabase/cli/internal/migration/apply"
"github.com/supabase/cli/internal/utils"
)

const LIST_SCHEMAS = "SELECT schema_name FROM information_schema.schemata WHERE NOT schema_name LIKE ANY($1) ORDER BY schema_name"

var (
//go:embed templates/migra.sh
diffSchemaScript string
Expand Down Expand Up @@ -94,29 +93,7 @@ func LoadUserSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]
"supabase_migrations",
}, utils.SystemSchemas...)
}
exclude = likeEscapeSchema(exclude)
rows, err := conn.Query(ctx, LIST_SCHEMAS, exclude)
if err != nil {
return nil, err
}
schemas := []string{}
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
schemas = append(schemas, name)
}
return schemas, nil
}

func likeEscapeSchema(schemas []string) (result []string) {
// Treat _ as literal, * as any character
replacer := strings.NewReplacer("_", `\_`, "*", "%")
for _, sch := range schemas {
result = append(result, replacer.Replace(sch))
}
return result
return reset.ListSchemas(ctx, conn, exclude...)
}

func CreateShadowDatabase(ctx context.Context) (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion internal/db/diff/migra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/supabase/cli/internal/db/reset"
"github.com/supabase/cli/internal/testing/apitest"
"github.com/supabase/cli/internal/testing/pgtest"
"github.com/supabase/cli/internal/utils"
Expand Down Expand Up @@ -267,7 +268,7 @@ func TestUserSchema(t *testing.T) {
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public}'")).
conn.Query(strings.ReplaceAll(reset.LIST_SCHEMAS, "$1", "'{public}'")).
Reply("SELECT 1", []interface{}{"test"})
// Connect to mock
ctx := context.Background()
Expand Down
69 changes: 66 additions & 3 deletions internal/db/reset/reset.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package reset

import (
"context"
_ "embed"
"errors"
"fmt"
"io"
Expand All @@ -19,13 +20,23 @@ import (
"github.com/supabase/cli/internal/utils"
)

const SET_POSTGRES_ROLE = "SET ROLE postgres;"
const (
SET_POSTGRES_ROLE = "SET ROLE postgres;"
LIST_SCHEMAS = "SELECT schema_name FROM information_schema.schemata WHERE NOT schema_name LIKE ANY($1) ORDER BY schema_name"
)

var (
healthTimeout = 5 * time.Second
//go:embed templates/drop.sql
dropObjects string
)

func Run(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
func Run(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
if len(config.Password) > 0 {
fmt.Fprintln(os.Stderr, "Resetting remote database...")
return resetRemote(ctx, config, fsys, options...)
}

// Sanity checks.
{
if err := utils.LoadConfigFS(fsys); err != nil {
Expand All @@ -52,7 +63,7 @@ func Run(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) e
}

func resetDatabase(ctx context.Context, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
fmt.Fprintln(os.Stderr, "Resetting database...")
fmt.Fprintln(os.Stderr, "Resetting local database...")
if err := RecreateDatabase(ctx, options...); err != nil {
return err
}
Expand Down Expand Up @@ -180,3 +191,55 @@ func WaitForHealthyService(ctx context.Context, container string, timeout time.D
}
return RetryEverySecond(ctx, probe, timeout)
}

func resetRemote(ctx context.Context, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
conn, err := utils.ConnectRemotePostgres(ctx, config, options...)
if err != nil {
return err
}
defer conn.Close(context.Background())
// List user defined schemas
excludes := append([]string{"public"}, utils.InternalSchemas...)
userSchemas, err := ListSchemas(ctx, conn, excludes...)
if err != nil {
return err
}
userSchemas = append(userSchemas, "supabase_migrations")
// Drop user defined objects
migration := apply.MigrationFile{}
for _, schema := range userSchemas {
sql := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", schema)
migration.Lines = append(migration.Lines, sql)
}
migration.Lines = append(migration.Lines, dropObjects)
if err := migration.ExecBatch(ctx, conn); err != nil {
return err
}
return InitialiseDatabase(ctx, conn, fsys)
}

func ListSchemas(ctx context.Context, conn *pgx.Conn, exclude ...string) ([]string, error) {
exclude = likeEscapeSchema(exclude)
rows, err := conn.Query(ctx, LIST_SCHEMAS, exclude)
if err != nil {
return nil, err
}
schemas := []string{}
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
schemas = append(schemas, name)
}
return schemas, nil
}

func likeEscapeSchema(schemas []string) (result []string) {
// Treat _ as literal, * as any character
replacer := strings.NewReplacer("_", `\_`, "*", "%")
for _, sch := range schemas {
result = append(result, replacer.Replace(sch))
}
return result
}
78 changes: 75 additions & 3 deletions internal/db/reset/reset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ import (
"io"
"net/http"
"os"
"strings"
"testing"
"time"

"github.com/docker/docker/api/types"
"github.com/docker/docker/client"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
Expand All @@ -25,7 +27,7 @@ import (

func TestResetCommand(t *testing.T) {
t.Run("throws error on missing config", func(t *testing.T) {
err := Run(context.Background(), afero.NewMemMapFs())
err := Run(context.Background(), pgconn.Config{}, afero.NewMemMapFs())
assert.ErrorIs(t, err, os.ErrNotExist)
})

Expand All @@ -40,7 +42,7 @@ func TestResetCommand(t *testing.T) {
Get("/v" + utils.Docker.ClientVersion() + "/containers").
Reply(http.StatusServiceUnavailable)
// Run test
err := Run(context.Background(), fsys)
err := Run(context.Background(), pgconn.Config{}, fsys)
// Check error
assert.ErrorIs(t, err, utils.ErrNotRunning)
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand All @@ -63,7 +65,7 @@ func TestResetCommand(t *testing.T) {
conn.Query("ALTER DATABASE postgres ALLOW_CONNECTIONS false;").
ReplyError(pgerrcode.InvalidParameterValue, `cannot disallow connections for current database`)
// Run test
err := Run(context.Background(), fsys, conn.Intercept)
err := Run(context.Background(), pgconn.Config{}, fsys, conn.Intercept)
// Check error
assert.ErrorContains(t, err, "ERROR: cannot disallow connections for current database (SQLSTATE 22023)")
assert.Empty(t, apitest.ListUnmatchedRequests())
Expand Down Expand Up @@ -354,3 +356,73 @@ func TestRestartDatabase(t *testing.T) {
assert.Empty(t, apitest.ListUnmatchedRequests())
})
}

func TestResetRemote(t *testing.T) {
dbConfig := pgconn.Config{
Host: "localhost",
Port: 5432,
User: "admin",
Password: "password",
Database: "postgres",
}

t.Run("resets remote database", func(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")).
Reply("SELECT 1", []interface{}{"private"}).
Query("DROP SCHEMA IF EXISTS private CASCADE").
Reply("DROP SCHEMA").
Query("DROP SCHEMA IF EXISTS supabase_migrations CASCADE").
Reply("DROP SCHEMA").
Query(dropObjects).
Reply("INSERT 0")
// Run test
err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept)
// Check error
assert.NoError(t, err)
})

t.Run("throws error on connect failure", func(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Run test
err := resetRemote(context.Background(), pgconn.Config{}, fsys)
// Check error
assert.ErrorContains(t, err, "invalid port (outside range)")
})

t.Run("throws error on list schema failure", func(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")).
ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation information_schema")
// Run test
err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept)
// Check error
assert.ErrorContains(t, err, "ERROR: permission denied for relation information_schema (SQLSTATE 42501)")
})

t.Run("throws error on drop schema failure", func(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
conn.Query(strings.ReplaceAll(LIST_SCHEMAS, "$1", "'{public,auth,extensions,pgbouncer,realtime,\"\\\\_realtime\",storage,\"\\\\_analytics\",\"supabase\\\\_functions\",\"supabase\\\\_migrations\",\"information\\\\_schema\",\"pg\\\\_%\",cron,graphql,\"graphql\\\\_public\",net,pgsodium,\"pgsodium\\\\_masks\",pgtle,repack,tiger,\"tiger\\\\_data\",\"timescaledb\\\\_%\",\"\\\\_timescaledb\\\\_%\",topology,vault}'")).
Reply("SELECT 0").
Query("DROP SCHEMA IF EXISTS supabase_migrations CASCADE").
ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation supabase_migrations").
Query(dropObjects)
// Run test
err := resetRemote(context.Background(), dbConfig, fsys, conn.Intercept)
// Check error
assert.ErrorContains(t, err, "ERROR: permission denied for relation supabase_migrations (SQLSTATE 42501)")
})
}
35 changes: 35 additions & 0 deletions internal/db/reset/templates/drop.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
do $$ declare
rec record;
begin
-- functions
for rec in
select *
from pg_proc p
where p.pronamespace::regnamespace::name = 'public'
loop
-- supports aggregate, function, and procedure
execute format('drop routine if exists %I.%I(%s) cascade', rec.pronamespace::regnamespace::name, rec.proname, pg_catalog.pg_get_function_identity_arguments(rec.oid));
end loop;

-- in order: tables (cascade to views), sequences
for rec in
select *
from pg_class c
where
c.relnamespace::regnamespace::name = 'public'
and c.relkind not in ('c', 'v', 'm')
order by c.relkind desc
loop
-- supports all kinds of relations, except views and complex types
execute format('drop table if exists %I.%I cascade', rec.relnamespace::regnamespace::name, rec.relname);
end loop;

-- types
for rec in
select *
from pg_type t
where t.typnamespace::regnamespace::name = 'public'
loop
execute format('drop type if exists %I.%I cascade', rec.typnamespace::regnamespace::name, rec.typname);
end loop;
end $$;