diff --git a/internal/bans/get/get.go b/internal/bans/get/get.go index a9e52901d..5b747325c 100644 --- a/internal/bans/get/get.go +++ b/internal/bans/get/get.go @@ -4,23 +4,15 @@ import ( "context" "fmt" - "github.com/go-errors/errors" "github.com/spf13/afero" - "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/internal/utils/flags" ) func Run(ctx context.Context, projectRef string, fsys afero.Fs) error { - // 1. Sanity checks. - // 2. get network bans - { - resp, err := utils.GetSupabase().V1ListAllNetworkBansWithResponse(ctx, projectRef) - if err != nil { - return errors.Errorf("failed to retrieve network bans: %w", err) - } - if resp.JSON201 == nil { - return errors.New("Unexpected error retrieving network bans: " + string(resp.Body)) - } - fmt.Printf("DB banned IPs: %+v\n", resp.JSON201.BannedIpv4Addresses) - return nil + ips, err := flags.ListNetworkBans(ctx, projectRef) + if err != nil { + return err } + fmt.Printf("DB banned IPs: %+v\n", ips) + return nil } diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index 3fd02fe65..6e93c0acb 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -112,7 +112,10 @@ func Run(ctx context.Context, starter StarterTemplate, fsys afero.Fs, options .. return err } // 6. Push migrations - config := flags.NewDbConfigWithPassword(ctx, flags.ProjectRef) + config, err := flags.NewDbConfigWithPassword(ctx, flags.ProjectRef) + if err != nil { + fmt.Fprintln(os.Stderr, err) + } if err := writeDotEnv(keys, config, fsys); err != nil { fmt.Fprintln(os.Stderr, "Failed to create .env file:", err) } diff --git a/internal/branches/get/get.go b/internal/branches/get/get.go index a207901cc..1f726ff77 100644 --- a/internal/branches/get/get.go +++ b/internal/branches/get/get.go @@ -27,7 +27,7 @@ func Run(ctx context.Context, branchId string, fsys afero.Fs) error { if err != nil { return err } - pooler, err := getPoolerConfig(ctx, detail.Ref) + pooler, err := utils.GetPoolerConfigPrimary(ctx, detail.Ref) if err != nil { return err } @@ -81,22 +81,6 @@ func getBranchDetail(ctx context.Context, branchId string) (api.BranchDetailResp return *resp.JSON200, nil } -func getPoolerConfig(ctx context.Context, ref string) (api.SupavisorConfigResponse, error) { - var result api.SupavisorConfigResponse - resp, err := utils.GetSupabase().V1GetPoolerConfigWithResponse(ctx, ref) - if err != nil { - return result, errors.Errorf("failed to get pooler: %w", err) - } else if resp.JSON200 == nil { - return result, errors.Errorf("unexpected get pooler status %d: %s", resp.StatusCode(), string(resp.Body)) - } - for _, config := range *resp.JSON200 { - if config.DatabaseType == api.SupavisorConfigResponseDatabaseTypePRIMARY { - return config, nil - } - } - return result, errors.Errorf("primary database not found: %s", ref) -} - func toStandardEnvs(detail api.BranchDetailResponse, pooler api.SupavisorConfigResponse, keys []api.ApiKeyResponse) map[string]string { direct := pgconn.Config{ Host: detail.DbHost, diff --git a/internal/link/link.go b/internal/link/link.go index 98e3b4814..9c7d0ce97 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -18,7 +18,6 @@ import ( "github.com/supabase/cli/pkg/api" "github.com/supabase/cli/pkg/cast" cliConfig "github.com/supabase/cli/pkg/config" - "github.com/supabase/cli/pkg/migration" "github.com/supabase/cli/pkg/queue" ) @@ -36,8 +35,9 @@ func Run(ctx context.Context, projectRef string, skipPooler bool, fsys afero.Fs, LinkServices(ctx, projectRef, keys.ServiceRole, skipPooler, fsys) // 2. Check database connection - config := flags.NewDbConfigWithPassword(ctx, projectRef) - if err := linkDatabase(ctx, config, fsys, options...); err != nil { + if config, err := flags.NewDbConfigWithPassword(ctx, projectRef); err != nil { + fmt.Fprintln(os.Stderr, utils.Yellow("WARN:"), err) + } else if err := linkDatabase(ctx, config, fsys, options...); err != nil { return err } @@ -186,14 +186,7 @@ func linkDatabase(ctx context.Context, config pgconn.Config, fsys afero.Fs, opti } defer conn.Close(context.Background()) updatePostgresConfig(conn) - if err := linkStorageMigration(ctx, conn, fsys); err != nil { - fmt.Fprintln(os.Stderr, err) - } - // If `schema_migrations` doesn't exist on the remote database, create it. - if err := migration.CreateMigrationTable(ctx, conn); err != nil { - return err - } - return migration.CreateSeedTable(ctx, conn) + return linkStorageMigration(ctx, conn, fsys) } func updatePostgresConfig(conn *pgx.Conn) { @@ -207,18 +200,11 @@ func updatePostgresConfig(conn *pgx.Conn) { } func linkPooler(ctx context.Context, projectRef string, fsys afero.Fs) error { - resp, err := utils.GetSupabase().V1GetPoolerConfigWithResponse(ctx, projectRef) + primary, err := utils.GetPoolerConfigPrimary(ctx, projectRef) if err != nil { - return errors.Errorf("failed to get pooler config: %w", err) - } - if resp.JSON200 == nil { - return errors.Errorf("%w: %s", tenant.ErrAuthToken, string(resp.Body)) - } - for _, config := range *resp.JSON200 { - if config.DatabaseType == api.SupavisorConfigResponseDatabaseTypePRIMARY { - updatePoolerConfig(config) - } + return err } + updatePoolerConfig(primary) return utils.WriteFile(utils.PoolerUrlPath, []byte(utils.Config.Db.Pooler.ConnectionString), fsys) } diff --git a/internal/link/link_test.go b/internal/link/link_test.go index c6a7339e0..bf7c85761 100644 --- a/internal/link/link_test.go +++ b/internal/link/link_test.go @@ -15,11 +15,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/supabase/cli/internal/testing/apitest" "github.com/supabase/cli/internal/testing/fstest" - "github.com/supabase/cli/internal/testing/helper" "github.com/supabase/cli/internal/utils" "github.com/supabase/cli/internal/utils/tenant" "github.com/supabase/cli/pkg/api" - "github.com/supabase/cli/pkg/migration" "github.com/supabase/cli/pkg/pgtest" "github.com/zalando/go-keyring" ) @@ -44,15 +42,6 @@ func TestLinkCommand(t *testing.T) { t.Cleanup(fstest.MockStdin(t, "\n")) // Setup in-memory fs fsys := afero.NewMemMapFs() - // Setup mock postgres - conn := pgtest.NewConn() - defer conn.Close(t) - conn.Query(utils.SET_SESSION_ROLE). - Reply("SET ROLE"). - Query(GET_LATEST_STORAGE_MIGRATION). - Reply("SELECT 1", []any{"custom-metadata"}) - helper.MockMigrationHistory(conn) - helper.MockSeedHistory(conn) // Flush pending mocks after test execution defer gock.OffAll() // Mock project status @@ -119,7 +108,7 @@ func TestLinkCommand(t *testing.T) { Reply(200). BodyString(storage) // Run test - err := Run(context.Background(), project, false, fsys, conn.Intercept) + err := Run(context.Background(), project, false, fsys) // Check error assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -198,22 +187,13 @@ func TestLinkCommand(t *testing.T) { } }) // Check error - assert.ErrorContains(t, err, "hostname resolving error") + assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) }) t.Run("throws error on write failure", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewReadOnlyFs(afero.NewMemMapFs()) - // Setup mock postgres - conn := pgtest.NewConn() - defer conn.Close(t) - conn.Query(utils.SET_SESSION_ROLE). - Reply("SET ROLE"). - Query(GET_LATEST_STORAGE_MIGRATION). - Reply("SELECT 1", []any{"custom-metadata"}) - helper.MockMigrationHistory(conn) - helper.MockSeedHistory(conn) // Flush pending mocks after test execution defer gock.OffAll() // Mock project status @@ -269,7 +249,7 @@ func TestLinkCommand(t *testing.T) { Get("/v1/projects"). ReplyError(errors.New("network error")) // Run test - err := Run(context.Background(), project, false, fsys, conn.Intercept) + err := Run(context.Background(), project, false, fsys) // Check error assert.ErrorContains(t, err, "operation not permitted") assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -419,6 +399,23 @@ func TestLinkPostgrest(t *testing.T) { } func TestLinkDatabase(t *testing.T) { + t.Run("syncs storage migration", func(t *testing.T) { + // Setup in-memory fs + fsys := afero.NewMemMapFs() + // Setup mock postgres + conn := pgtest.NewConn() + defer conn.Close(t) + conn.Query(GET_LATEST_STORAGE_MIGRATION). + Reply("SELECT 1", []any{"custom-metadata"}) + // Run test + err := linkDatabase(context.Background(), dbConfig, fsys, conn.Intercept) + // Check error + assert.NoError(t, err) + storage, err := afero.ReadFile(fsys, utils.StorageMigrationPath) + assert.NoError(t, err) + assert.Equal(t, []byte("custom-metadata"), storage) + }) + t.Run("throws error on connect failure", func(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() @@ -438,8 +435,6 @@ func TestLinkDatabase(t *testing.T) { defer conn.Close(t) conn.Query(GET_LATEST_STORAGE_MIGRATION). Reply("SELECT 1", []any{"custom-metadata"}) - helper.MockMigrationHistory(conn) - helper.MockSeedHistory(conn) // Run test err := linkDatabase(context.Background(), dbConfig, fsys, conn.Intercept) // Check error @@ -461,8 +456,6 @@ func TestLinkDatabase(t *testing.T) { defer conn.Close(t) conn.Query(GET_LATEST_STORAGE_MIGRATION). Reply("SELECT 1", []any{"custom-metadata"}) - helper.MockMigrationHistory(conn) - helper.MockSeedHistory(conn) // Run test err := linkDatabase(context.Background(), dbConfig, fsys, conn.Intercept) // Check error @@ -481,18 +474,11 @@ func TestLinkDatabase(t *testing.T) { conn := pgtest.NewConn() defer conn.Close(t) conn.Query(GET_LATEST_STORAGE_MIGRATION). - ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation migrations"). - Query(migration.SET_LOCK_TIMEOUT). - Query(migration.CREATE_VERSION_SCHEMA). - Reply("CREATE SCHEMA"). - Query(migration.CREATE_VERSION_TABLE). - ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation supabase_migrations"). - Query(migration.ADD_STATEMENTS_COLUMN). - Query(migration.ADD_NAME_COLUMN) + ReplyError(pgerrcode.InsufficientPrivilege, "permission denied for relation migrations") // Run test err := linkDatabase(context.Background(), dbConfig, fsys, conn.Intercept) // Check error - assert.ErrorContains(t, err, "ERROR: permission denied for relation supabase_migrations (SQLSTATE 42501)") + assert.ErrorContains(t, err, "ERROR: permission denied for relation migrations (SQLSTATE 42501)") exists, err := afero.Exists(fsys, utils.StorageMigrationPath) assert.NoError(t, err) assert.False(t, exists) diff --git a/internal/projects/create/create.go b/internal/projects/create/create.go index 6be3fc8ab..9b8e03d1a 100644 --- a/internal/projects/create/create.go +++ b/internal/projects/create/create.go @@ -10,7 +10,6 @@ import ( "github.com/spf13/afero" "github.com/spf13/viper" "github.com/supabase/cli/internal/utils" - "github.com/supabase/cli/internal/utils/credentials" "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/pkg/api" ) @@ -30,9 +29,6 @@ func Run(ctx context.Context, params api.V1CreateProjectBody, fsys afero.Fs) err flags.ProjectRef = resp.JSON201.Id viper.Set("DB_PASSWORD", params.DbPass) - if err := credentials.StoreProvider.Set(flags.ProjectRef, params.DbPass); err != nil { - fmt.Fprintln(os.Stderr, "Failed to save database password:", err) - } projectUrl := fmt.Sprintf("%s/project/%s", utils.GetSupabaseDashboardURL(), resp.JSON201.Id) fmt.Fprintf(os.Stderr, "Created a new project at %s\n", utils.Bold(projectUrl)) diff --git a/internal/utils/connect.go b/internal/utils/connect.go index bd8d199f7..5fc4e3b9b 100644 --- a/internal/utils/connect.go +++ b/internal/utils/connect.go @@ -15,6 +15,7 @@ import ( "github.com/jackc/pgx/v4" "github.com/spf13/viper" "github.com/supabase/cli/internal/debug" + "github.com/supabase/cli/pkg/api" "github.com/supabase/cli/pkg/pgxv5" "golang.org/x/net/publicsuffix" ) @@ -43,6 +44,22 @@ func ToPostgresURL(config pgconn.Config) string { ) } +func GetPoolerConfigPrimary(ctx context.Context, ref string) (api.SupavisorConfigResponse, error) { + var result api.SupavisorConfigResponse + resp, err := GetSupabase().V1GetPoolerConfigWithResponse(ctx, ref) + if err != nil { + return result, errors.Errorf("failed to get pooler: %w", err) + } else if resp.JSON200 == nil { + return result, errors.Errorf("unexpected get pooler status %d: %s", resp.StatusCode(), string(resp.Body)) + } + for _, config := range *resp.JSON200 { + if config.DatabaseType == api.SupavisorConfigResponseDatabaseTypePRIMARY { + return config, nil + } + } + return result, errors.Errorf("primary database not found: %s", ref) +} + func GetPoolerConfig(projectRef string) *pgconn.Config { logger := GetDebugLogger() if len(Config.Db.Pooler.ConnectionString) == 0 { diff --git a/internal/utils/flags/db_url.go b/internal/utils/flags/db_url.go index 7ee2417e5..4dca1d936 100644 --- a/internal/utils/flags/db_url.go +++ b/internal/utils/flags/db_url.go @@ -6,6 +6,7 @@ import ( _ "embed" "fmt" "math/big" + "net" "net/http" "os" "strings" @@ -82,7 +83,10 @@ func ParseDatabaseConfig(ctx context.Context, flagSet *pflag.FlagSet, fsys afero if err := LoadConfig(fsys); err != nil { return err } - DbConfig = NewDbConfigWithPassword(ctx, ProjectRef) + var err error + if DbConfig, err = NewDbConfigWithPassword(ctx, ProjectRef); err != nil { + return err + } case proxy: token, err := utils.LoadAccessTokenFS(fsys) if err != nil { @@ -115,67 +119,96 @@ func RandomString(size int) (string, error) { return string(data), nil } -func NewDbConfigWithPassword(ctx context.Context, projectRef string) pgconn.Config { - config := getDbConfig(projectRef) - config.Password = viper.GetString("DB_PASSWORD") +const suggestEnvVar = "Connect to your database by setting the env var: SUPABASE_DB_PASSWORD" + +func NewDbConfigWithPassword(ctx context.Context, projectRef string) (pgconn.Config, error) { + config := pgconn.Config{ + Host: utils.GetSupabaseDbHost(projectRef), + Port: 5432, + User: "postgres", + Password: viper.GetString("DB_PASSWORD"), + Database: "postgres", + } + logger := utils.GetDebugLogger() + // Use pooler if host is not reachable directly + if _, err := net.DefaultResolver.LookupIPAddr(ctx, config.Host); err != nil { + if poolerConfig := utils.GetPoolerConfig(projectRef); poolerConfig != nil { + if len(config.Password) > 0 { + fmt.Fprintln(logger, "Using database password from env var...") + poolerConfig.Password = config.Password + } else if err := initPoolerLogin(ctx, projectRef, poolerConfig); err != nil { + utils.CmdSuggestion = suggestEnvVar + return *poolerConfig, err + } + return *poolerConfig, nil + } + utils.CmdSuggestion = fmt.Sprintf("Run %s to setup IPv4 connection.", utils.Aqua("supabase link --project-ref "+projectRef)) + return config, errors.Errorf("IPv6 is not supported on your current network: %w", err) + } + // Connect via direct connection if len(config.Password) > 0 { - return config - } - loginRole, err := initLoginRole(ctx, projectRef, config) - if err == nil { - return loginRole - } else if errors.Is(err, context.Canceled) { - return config - } - // Proceed with password prompt - fmt.Fprintln(utils.GetDebugLogger(), err) - if config.Password, err = credentials.StoreProvider.Get(projectRef); err == nil { - return config - } - resetUrl := fmt.Sprintf("%s/project/%s/settings/database", utils.GetSupabaseDashboardURL(), projectRef) - fmt.Fprintln(os.Stderr, "Forgot your password? Reset it from the Dashboard:", utils.Bold(resetUrl)) - fmt.Fprint(os.Stderr, "Enter your database password: ") - config.Password = credentials.PromptMasked(os.Stdin) - return config + fmt.Fprintln(logger, "Using database password from env var...") + } else if err := initLoginRole(ctx, projectRef, &config); err != nil { + // Do not prompt because reading masked input is buggy on windows + utils.CmdSuggestion = suggestEnvVar + return config, err + } + return config, nil } -func initLoginRole(ctx context.Context, projectRef string, config pgconn.Config) (pgconn.Config, error) { +func initLoginRole(ctx context.Context, projectRef string, config *pgconn.Config) error { fmt.Fprintln(os.Stderr, "Initialising login role...") body := api.CreateRoleBody{ReadOnly: false} resp, err := utils.GetSupabase().V1CreateLoginRoleWithResponse(ctx, projectRef, body) if err != nil { - return pgconn.Config{}, errors.Errorf("failed to initialise login role: %w", err) + return errors.Errorf("failed to initialise login role: %w", err) } else if resp.JSON201 == nil { - return pgconn.Config{}, errors.Errorf("unexpected login role status %d: %s", resp.StatusCode(), string(resp.Body)) + return errors.Errorf("unexpected login role status %d: %s", resp.StatusCode(), string(resp.Body)) + } + config.User = resp.JSON201.Role + config.Password = resp.JSON201.Password + return nil +} + +func initPoolerLogin(ctx context.Context, projectRef string, poolerConfig *pgconn.Config) error { + poolerUser := poolerConfig.User + if err := initLoginRole(ctx, projectRef, poolerConfig); err != nil { + return err } - // Direct connection can be tried immediately suffix := "." + projectRef - if !strings.HasSuffix(config.User, suffix) { - config.User = resp.JSON201.Role - config.Password = resp.JSON201.Password - return config, nil + if strings.HasSuffix(poolerUser, suffix) { + poolerConfig.User += suffix } // Wait for pooler to refresh password - config.User = resp.JSON201.Role + suffix - config.Password = resp.JSON201.Password login := func() error { - conn, err := pgconn.ConnectConfig(ctx, &config) + conn, err := pgconn.ConnectConfig(ctx, poolerConfig) if err != nil { return errors.Errorf("failed to connect as temp role: %w", err) } return conn.Close(ctx) } - // Fallback to password prompt on error notify := utils.NewErrorCallback(func(attempt uint) error { - if attempt%3 > 0 { + if attempt < 3 { return nil } - return UnbanIP(ctx, projectRef) + if ips, err := ListNetworkBans(ctx, projectRef); err != nil { + return err + } else if len(ips) > 0 { + return UnbanIP(ctx, projectRef, ips...) + } + return nil }) - if err := backoff.RetryNotify(login, utils.NewBackoffPolicy(ctx), notify); err != nil { - return pgconn.Config{}, err + return backoff.RetryNotify(login, utils.NewBackoffPolicy(ctx), notify) +} + +func ListNetworkBans(ctx context.Context, projectRef string) ([]string, error) { + resp, err := utils.GetSupabase().V1ListAllNetworkBansWithResponse(ctx, projectRef) + if err != nil { + return nil, errors.Errorf("failed to list network bans: %w", err) + } else if resp.JSON201 == nil { + return nil, errors.Errorf("unexpected list bans status %d: %s", resp.StatusCode(), string(resp.Body)) } - return config, nil + return resp.JSON201.BannedIpv4Addresses, nil } func UnbanIP(ctx context.Context, projectRef string, addrs ...string) error { @@ -214,15 +247,3 @@ func PromptPassword(stdin *os.File) string { } return string(password) } - -func getDbConfig(projectRef string) pgconn.Config { - if poolerConfig := utils.GetPoolerConfig(projectRef); poolerConfig != nil { - return *poolerConfig - } - return pgconn.Config{ - Host: utils.GetSupabaseDbHost(projectRef), - Port: 5432, - User: "postgres", - Database: "postgres", - } -} diff --git a/internal/utils/flags/db_url_test.go b/internal/utils/flags/db_url_test.go index 1bb3fcc79..c79c7f936 100644 --- a/internal/utils/flags/db_url_test.go +++ b/internal/utils/flags/db_url_test.go @@ -2,11 +2,14 @@ package flags import ( "context" + "fmt" "os" + "strings" "testing" "github.com/spf13/afero" "github.com/spf13/pflag" + "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/supabase/cli/internal/testing/apitest" @@ -70,10 +73,15 @@ func TestParseDatabaseConfig(t *testing.T) { err = afero.WriteFile(fsys, utils.ProjectRefPath, []byte(project), 0644) require.NoError(t, err) + dbURL := fmt.Sprintf("postgres://postgres:postgres@db.%s.supabase.co:6543/postgres", project) + err = afero.WriteFile(fsys, utils.PoolerUrlPath, []byte(dbURL), 0644) + require.NoError(t, err) + + viper.Set("DB_PASSWORD", "test") err = ParseDatabaseConfig(context.Background(), flagSet, fsys) assert.NoError(t, err) - assert.Equal(t, utils.GetSupabaseDbHost(project), DbConfig.Host) + assert.True(t, strings.HasPrefix(DbConfig.Host, utils.GetSupabaseDbHost(project))) }) }