diff --git a/cmd/db.go b/cmd/db.go index cc8bca65b..9558dc274 100644 --- a/cmd/db.go +++ b/cmd/db.go @@ -271,7 +271,7 @@ func init() { dumpFlags := dbDumpCmd.Flags() dumpFlags.BoolVar(&dryRun, "dry-run", false, "Prints the pg_dump script that would be executed.") dumpFlags.BoolVar(&dataOnly, "data-only", false, "Dumps only data records.") - dumpFlags.BoolVar(&useCopy, "use-copy", false, "Uses copy statements in place of inserts.") + dumpFlags.BoolVar(&useCopy, "use-copy", false, "Use copy statements in place of inserts.") dumpFlags.StringSliceVarP(&excludeTable, "exclude", "x", []string{}, "List of schema.tables to exclude from data-only dump.") dumpFlags.BoolVar(&roleOnly, "role-only", false, "Dumps only cluster roles.") dbDumpCmd.MarkFlagsMutuallyExclusive("role-only", "data-only") diff --git a/cmd/link.go b/cmd/link.go index 9ed09fd86..f2422ff48 100644 --- a/cmd/link.go +++ b/cmd/link.go @@ -13,6 +13,8 @@ import ( ) var ( + skipPooler bool + linkCmd = &cobra.Command{ GroupID: groupLocalDev, Use: "link", @@ -35,7 +37,7 @@ var ( } // TODO: move this to root cmd cobra.CheckErr(viper.BindPFlag("DB_PASSWORD", cmd.Flags().Lookup("password"))) - return link.Run(ctx, flags.ProjectRef, fsys) + return link.Run(ctx, flags.ProjectRef, skipPooler, fsys) }, } ) @@ -44,6 +46,7 @@ func init() { linkFlags := linkCmd.Flags() linkFlags.StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.") linkFlags.StringVarP(&dbPassword, "password", "p", "", "Password to your remote Postgres database.") + linkFlags.BoolVar(&skipPooler, "skip-pooler", false, "Use direct connection instead of pooler.") // For some reason, BindPFlag only works for StringVarP instead of StringP rootCmd.AddCommand(linkCmd) } diff --git a/internal/bootstrap/bootstrap.go b/internal/bootstrap/bootstrap.go index b2986656a..3fd02fe65 100644 --- a/internal/bootstrap/bootstrap.go +++ b/internal/bootstrap/bootstrap.go @@ -99,7 +99,7 @@ func Run(ctx context.Context, starter StarterTemplate, fsys afero.Fs, options .. if err := flags.LoadConfig(fsys); err != nil { return err } - link.LinkServices(ctx, flags.ProjectRef, tenant.NewApiKey(keys).Anon, fsys) + link.LinkServices(ctx, flags.ProjectRef, tenant.NewApiKey(keys).Anon, false, fsys) if err := utils.WriteFile(utils.ProjectRefPath, []byte(flags.ProjectRef), fsys); err != nil { return err } diff --git a/internal/link/link.go b/internal/link/link.go index 73635aea8..d3e48fb22 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -22,7 +22,7 @@ import ( "github.com/supabase/cli/pkg/queue" ) -func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { +func Run(ctx context.Context, projectRef string, skipPooler bool, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { majorVersion := utils.Config.Db.MajorVersion if err := checkRemoteProjectStatus(ctx, projectRef, fsys); err != nil { return err @@ -33,7 +33,7 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func( if err != nil { return err } - LinkServices(ctx, projectRef, keys.ServiceRole, fsys) + LinkServices(ctx, projectRef, keys.ServiceRole, skipPooler, fsys) // 2. Check database connection config := flags.NewDbConfigWithPassword(ctx, projectRef) @@ -58,21 +58,35 @@ major_version = %d return nil } -func LinkServices(ctx context.Context, projectRef, serviceKey string, fsys afero.Fs) { +func LinkServices(ctx context.Context, projectRef, serviceKey string, skipPooler bool, fsys afero.Fs) { jq := queue.NewJobQueue(5) - logger := utils.GetDebugLogger() - fmt.Fprintln(logger, jq.Put(func() error { return linkDatabaseSettings(ctx, projectRef) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkNetworkRestrictions(ctx, projectRef) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkPostgrest(ctx, projectRef) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkGotrue(ctx, projectRef) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkStorage(ctx, projectRef) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkPooler(ctx, projectRef, fsys) })) api := tenant.NewTenantAPI(ctx, projectRef, serviceKey) - fmt.Fprintln(logger, jq.Put(func() error { return linkPostgrestVersion(ctx, api, fsys) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkGotrueVersion(ctx, api, fsys) })) - fmt.Fprintln(logger, jq.Put(func() error { return linkStorageVersion(ctx, api, fsys) })) + jobs := []func() error{ + func() error { return linkDatabaseSettings(ctx, projectRef) }, + func() error { return linkNetworkRestrictions(ctx, projectRef) }, + func() error { return linkPostgrest(ctx, projectRef) }, + func() error { return linkGotrue(ctx, projectRef) }, + func() error { return linkStorage(ctx, projectRef) }, + func() error { + if skipPooler { + return fsys.RemoveAll(utils.PoolerUrlPath) + } + return linkPooler(ctx, projectRef, fsys) + }, + func() error { return linkPostgrestVersion(ctx, api, fsys) }, + func() error { return linkGotrueVersion(ctx, api, fsys) }, + func() error { return linkStorageVersion(ctx, api, fsys) }, + } // Ignore non-fatal errors linking services - fmt.Fprintln(logger, jq.Collect()) + logger := utils.GetDebugLogger() + for _, job := range jobs { + if err := jq.Put(job); err != nil { + fmt.Fprintln(logger, err) + } + } + if err := jq.Collect(); err != nil { + fmt.Fprintln(logger, err) + } } func linkPostgrest(ctx context.Context, projectRef string) error { @@ -129,7 +143,6 @@ func linkStorageVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs if err != nil { return err } - fmt.Fprintln(os.Stderr, version) return utils.WriteFile(utils.StorageVersionPath, []byte(version), fsys) } diff --git a/internal/link/link_test.go b/internal/link/link_test.go index aff34d937..c6a7339e0 100644 --- a/internal/link/link_test.go +++ b/internal/link/link_test.go @@ -119,7 +119,7 @@ func TestLinkCommand(t *testing.T) { Reply(200). BodyString(storage) // Run test - err := Run(context.Background(), project, fsys, conn.Intercept) + err := Run(context.Background(), project, false, fsys, conn.Intercept) // Check error assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -192,7 +192,7 @@ func TestLinkCommand(t *testing.T) { Get("/storage/v1/version"). ReplyError(errors.New("network error")) // Run test - err := Run(context.Background(), project, fsys, func(cc *pgx.ConnConfig) { + err := Run(context.Background(), project, false, fsys, func(cc *pgx.ConnConfig) { cc.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { return nil, errors.New("hostname resolving error") } @@ -269,7 +269,7 @@ func TestLinkCommand(t *testing.T) { Get("/v1/projects"). ReplyError(errors.New("network error")) // Run test - err := Run(context.Background(), project, fsys, conn.Intercept) + err := Run(context.Background(), project, false, fsys, conn.Intercept) // Check error assert.ErrorContains(t, err, "operation not permitted") assert.Empty(t, apitest.ListUnmatchedRequests())