diff --git a/internal/link/link.go b/internal/link/link.go index 0c2aa1de4..8a75a8d3d 100644 --- a/internal/link/link.go +++ b/internal/link/link.go @@ -1,7 +1,6 @@ package link import ( - "bytes" "context" "fmt" "os" @@ -9,7 +8,6 @@ import ( "strings" "sync" - "github.com/BurntSushi/toml" "github.com/go-errors/errors" "github.com/jackc/pgconn" "github.com/jackc/pgx/v4" @@ -20,15 +18,20 @@ import ( "github.com/supabase/cli/internal/utils/flags" "github.com/supabase/cli/internal/utils/tenant" "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" cliConfig "github.com/supabase/cli/pkg/config" + "github.com/supabase/cli/pkg/diff" "github.com/supabase/cli/pkg/migration" ) func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { - original := toTomlBytes(map[string]interface{}{ + original, err := cliConfig.ToTomlBytes(map[string]interface{}{ "api": utils.Config.Api, "db": utils.Config.Db, }) + if err != nil { + fmt.Fprintln(utils.GetDebugLogger(), err) + } if err := checkRemoteProjectStatus(ctx, projectRef); err != nil { return err @@ -60,28 +63,21 @@ func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func( fmt.Fprintln(os.Stdout, "Finished "+utils.Aqua("supabase link")+".") // 4. Suggest config update - updated := toTomlBytes(map[string]interface{}{ + updated, err := cliConfig.ToTomlBytes(map[string]interface{}{ "api": utils.Config.Api, "db": utils.Config.Db, }) - // if lineDiff := cmp.Diff(original, updated); len(lineDiff) > 0 { - if lineDiff := Diff(utils.ConfigPath, original, projectRef, updated); len(lineDiff) > 0 { + if err != nil { + fmt.Fprintln(utils.GetDebugLogger(), err) + } + + if lineDiff := diff.Diff(utils.ConfigPath, original, projectRef, updated); len(lineDiff) > 0 { fmt.Fprintln(os.Stderr, utils.Yellow("WARNING:"), "Local config differs from linked project. Try updating", utils.Bold(utils.ConfigPath)) fmt.Println(string(lineDiff)) } return nil } -func toTomlBytes(config any) []byte { - var buf bytes.Buffer - enc := toml.NewEncoder(&buf) - enc.Indent = "" - if err := enc.Encode(config); err != nil { - fmt.Fprintln(utils.GetDebugLogger(), "failed to marshal toml config:", err) - } - return buf.Bytes() -} - func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) { // Ignore non-fatal errors linking services var wg sync.WaitGroup @@ -147,7 +143,7 @@ func linkPostgrestVersion(ctx context.Context, api tenant.TenantAPI, fsys afero. } func updateApiConfig(config api.PostgrestConfigWithJWTSecretResponse) { - utils.Config.Api.MaxRows = uint(config.MaxRows) + utils.Config.Api.MaxRows = cast.IntToUint(config.MaxRows) utils.Config.Api.ExtraSearchPath = readCsv(config.DbExtraSearchPath) utils.Config.Api.Schemas = readCsv(config.DbSchema) } diff --git a/pkg/cast/cast.go b/pkg/cast/cast.go new file mode 100644 index 000000000..b72cadbbf --- /dev/null +++ b/pkg/cast/cast.go @@ -0,0 +1,25 @@ +package cast + +import "math" + +// UintToInt converts a uint to an int, handling potential overflow +func UintToInt(value uint) int { + if value <= math.MaxInt { + result := int(value) + return result + } + maxInt := math.MaxInt + return maxInt +} + +// IntToUint converts an int to a uint, handling negative values +func IntToUint(value int) uint { + if value < 0 { + return 0 + } + return uint(value) +} + +func Ptr[T any](v T) *T { + return &v +} diff --git a/pkg/config/api.go b/pkg/config/api.go new file mode 100644 index 000000000..403e290e5 --- /dev/null +++ b/pkg/config/api.go @@ -0,0 +1,101 @@ +package config + +import ( + "strings" + + v1API "github.com/supabase/cli/pkg/api" + "github.com/supabase/cli/pkg/cast" + "github.com/supabase/cli/pkg/diff" +) + +type ( + api struct { + Enabled bool `toml:"enabled"` + Schemas []string `toml:"schemas"` + ExtraSearchPath []string `toml:"extra_search_path"` + MaxRows uint `toml:"max_rows"` + // Local only config + Image string `toml:"-"` + KongImage string `toml:"-"` + Port uint16 `toml:"port"` + Tls tlsKong `toml:"tls"` + // TODO: replace [auth|studio].api_url + ExternalUrl string `toml:"external_url"` + } + + tlsKong struct { + Enabled bool `toml:"enabled"` + } +) + +func (a *api) ToUpdatePostgrestConfigBody() v1API.UpdatePostgrestConfigBody { + body := v1API.UpdatePostgrestConfigBody{} + + // When the api is disabled, remote side it just set the dbSchema to an empty value + if !a.Enabled { + body.DbSchema = cast.Ptr("") + return body + } + + // Convert Schemas to a comma-separated string + if len(a.Schemas) > 0 { + schemas := strings.Join(a.Schemas, ",") + body.DbSchema = &schemas + } + + // Convert ExtraSearchPath to a comma-separated string + if len(a.ExtraSearchPath) > 0 { + extraSearchPath := strings.Join(a.ExtraSearchPath, ",") + body.DbExtraSearchPath = &extraSearchPath + } + + // Convert MaxRows to int pointer + if a.MaxRows > 0 { + body.MaxRows = cast.Ptr(cast.UintToInt(a.MaxRows)) + } + + // Note: DbPool is not present in the Api struct, so it's not set here + return body +} + +func (a *api) fromRemoteApiConfig(remoteConfig v1API.PostgrestConfigWithJWTSecretResponse) api { + result := *a + if remoteConfig.DbSchema == "" { + result.Enabled = false + return result + } + + result.Enabled = true + // Update Schemas if present in remoteConfig + schemas := strings.Split(remoteConfig.DbSchema, ",") + result.Schemas = make([]string, len(schemas)) + // TODO: use slices.Map when upgrade go version + for i, schema := range schemas { + result.Schemas[i] = strings.TrimSpace(schema) + } + + // Update ExtraSearchPath if present in remoteConfig + extraSearchPath := strings.Split(remoteConfig.DbExtraSearchPath, ",") + result.ExtraSearchPath = make([]string, len(extraSearchPath)) + for i, path := range extraSearchPath { + result.ExtraSearchPath[i] = strings.TrimSpace(path) + } + + // Update MaxRows if present in remoteConfig + result.MaxRows = cast.IntToUint(remoteConfig.MaxRows) + + return result +} + +func (a *api) DiffWithRemote(remoteConfig v1API.PostgrestConfigWithJWTSecretResponse) ([]byte, error) { + // Convert the config values into easily comparable remoteConfig values + currentValue, err := ToTomlBytes(a) + if err != nil { + return nil, err + } + remoteCompare, err := ToTomlBytes(a.fromRemoteApiConfig(remoteConfig)) + if err != nil { + return nil, err + } + return diff.Diff("remote[api]", remoteCompare, "local[api]", currentValue), nil +} diff --git a/pkg/config/api_test.go b/pkg/config/api_test.go new file mode 100644 index 000000000..d48001244 --- /dev/null +++ b/pkg/config/api_test.go @@ -0,0 +1,143 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/assert" + v1API "github.com/supabase/cli/pkg/api" +) + +func TestApiToUpdatePostgrestConfigBody(t *testing.T) { + t.Run("converts all fields correctly", func(t *testing.T) { + api := &api{ + Enabled: true, + Schemas: []string{"public", "private"}, + ExtraSearchPath: []string{"extensions", "public"}, + MaxRows: 1000, + } + + body := api.ToUpdatePostgrestConfigBody() + + assert.Equal(t, "public,private", *body.DbSchema) + assert.Equal(t, "extensions,public", *body.DbExtraSearchPath) + assert.Equal(t, 1000, *body.MaxRows) + }) + + t.Run("handles empty fields", func(t *testing.T) { + api := &api{} + + body := api.ToUpdatePostgrestConfigBody() + + // remote api will be false by default, leading to an empty schema on api side + assert.Equal(t, "", *body.DbSchema) + }) +} + +func TestApiDiffWithRemote(t *testing.T) { + t.Run("detects differences", func(t *testing.T) { + api := &api{ + Enabled: true, + Schemas: []string{"public", "private"}, + ExtraSearchPath: []string{"extensions", "public"}, + MaxRows: 1000, + } + + remoteConfig := v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "public", + DbExtraSearchPath: "public", + MaxRows: 500, + } + + diff, err := api.DiffWithRemote(remoteConfig) + assert.NoError(t, err, string(diff)) + + assert.Contains(t, string(diff), "-schemas = [\"public\"]") + assert.Contains(t, string(diff), "+schemas = [\"public\", \"private\"]") + assert.Contains(t, string(diff), "-extra_search_path = [\"public\"]") + assert.Contains(t, string(diff), "+extra_search_path = [\"extensions\", \"public\"]") + assert.Contains(t, string(diff), "-max_rows = 500") + assert.Contains(t, string(diff), "+max_rows = 1000") + }) + + t.Run("handles no differences", func(t *testing.T) { + api := &api{ + Enabled: true, + Schemas: []string{"public"}, + ExtraSearchPath: []string{"public"}, + MaxRows: 500, + } + + remoteConfig := v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "public", + DbExtraSearchPath: "public", + MaxRows: 500, + } + + diff, err := api.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Empty(t, diff) + }) + + t.Run("handles multiple schemas and search paths with spaces", func(t *testing.T) { + api := &api{ + Enabled: true, + Schemas: []string{"public", "private"}, + ExtraSearchPath: []string{"extensions", "public"}, + MaxRows: 500, + } + + remoteConfig := v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "public, private", + DbExtraSearchPath: "extensions, public", + MaxRows: 500, + } + + diff, err := api.DiffWithRemote(remoteConfig) + assert.NoError(t, err) + + assert.Empty(t, diff) + }) + + t.Run("handles api disabled on remote side", func(t *testing.T) { + api := &api{ + Enabled: true, + Schemas: []string{"public", "private"}, + ExtraSearchPath: []string{"extensions", "public"}, + MaxRows: 500, + } + + remoteConfig := v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "", + DbExtraSearchPath: "", + MaxRows: 0, + } + + diff, err := api.DiffWithRemote(remoteConfig) + assert.NoError(t, err, string(diff)) + + assert.Contains(t, string(diff), "-enabled = false") + assert.Contains(t, string(diff), "+enabled = true") + }) + + t.Run("handles api disabled on local side", func(t *testing.T) { + api := &api{ + Enabled: false, + Schemas: []string{"public"}, + ExtraSearchPath: []string{"public"}, + MaxRows: 500, + } + + remoteConfig := v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "public", + DbExtraSearchPath: "public", + MaxRows: 500, + } + + diff, err := api.DiffWithRemote(remoteConfig) + assert.NoError(t, err, string(diff)) + + assert.Contains(t, string(diff), "-enabled = true") + assert.Contains(t, string(diff), "+enabled = false") + }) +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 82b9d0a46..3bf566bb1 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -145,23 +145,6 @@ type ( Remotes map[string]baseConfig `toml:"-"` } - api struct { - Enabled bool `toml:"enabled"` - Image string `toml:"-"` - KongImage string `toml:"-"` - Port uint16 `toml:"port"` - Schemas []string `toml:"schemas"` - ExtraSearchPath []string `toml:"extra_search_path"` - MaxRows uint `toml:"max_rows"` - Tls tlsKong `toml:"tls"` - // TODO: replace [auth|studio].api_url - ExternalUrl string `toml:"external_url"` - } - - tlsKong struct { - Enabled bool `toml:"enabled"` - } - db struct { Image string `toml:"-"` Port uint16 `toml:"port"` @@ -723,6 +706,7 @@ func (c *config) Load(path string, fsys fs.FS) error { if err := c.baseConfig.Validate(fsys); err != nil { return err } + idToName := map[string]string{} c.Remotes = make(map[string]baseConfig, len(c.Overrides)) for name, remote := range c.Overrides { base := c.baseConfig.Clone() @@ -735,10 +719,18 @@ func (c *config) Load(path string, fsys fs.FS) error { if metadata, err := toml.NewDecoder(&buf).Decode(&base); err != nil { return errors.Errorf("failed to decode remote config: %w", err) } else if undecoded := metadata.Undecoded(); len(undecoded) > 0 { - fmt.Fprintf(os.Stderr, "Unknown config fields: %+v\n", undecoded) + fmt.Fprintf(os.Stderr, "WARN: unknown config fields: %+v\n", undecoded) + } + // Cross validate remote project id + if base.ProjectId == c.baseConfig.ProjectId { + fmt.Fprintf(os.Stderr, "WARN: project_id is missing for [remotes.%s]\n", name) + } else if other, exists := idToName[base.ProjectId]; exists { + return errors.Errorf("duplicate project_id for [remotes.%s] and [remotes.%s]", other, name) + } else { + idToName[base.ProjectId] = name } if err := base.Validate(fsys); err != nil { - return err + return errors.Errorf("invalid config for [remotes.%s]: %w", name, err) } c.Remotes[name] = base } @@ -1332,3 +1324,35 @@ func (a *auth) ResolveJWKS(ctx context.Context) (string, error) { return string(jwksEncoded), nil } + +// Retrieve the final base config to use taking into account the remotes override +func (c *config) GetRemoteByProjectRef(projectRef string) (baseConfig, error) { + var result []string + // Iterate over all the config.Remotes + for name, remoteConfig := range c.Remotes { + // Check if there is one matching project_id + if remoteConfig.ProjectId == projectRef { + // Check for duplicate project IDs across remotes + result = append(result, name) + } + } + // If no matching remote config is found, return the base config + if len(result) == 0 { + return c.baseConfig, errors.Errorf("no remote found for project_id: %s", projectRef) + } + remote := c.Remotes[result[0]] + if len(result) > 1 { + return remote, errors.Errorf("multiple remotes %v have the same project_id: %s", result, projectRef) + } + return remote, nil +} + +func ToTomlBytes(config any) ([]byte, error) { + var buf bytes.Buffer + enc := toml.NewEncoder(&buf) + enc.Indent = "" + if err := enc.Encode(config); err != nil { + return nil, errors.Errorf("failed to marshal toml config: %w", err) + } + return buf.Bytes(), nil +} diff --git a/pkg/config/updater.go b/pkg/config/updater.go new file mode 100644 index 000000000..467b7bb63 --- /dev/null +++ b/pkg/config/updater.go @@ -0,0 +1,49 @@ +package config + +import ( + "context" + "fmt" + "os" + + "github.com/go-errors/errors" + v1API "github.com/supabase/cli/pkg/api" +) + +type ConfigUpdater struct { + client v1API.ClientWithResponses +} + +func NewConfigUpdater(client v1API.ClientWithResponses) ConfigUpdater { + return ConfigUpdater{client: client} +} + +func (u *ConfigUpdater) UpdateRemoteConfig(ctx context.Context, remote baseConfig) error { + if err := u.UpdateApiConfig(ctx, remote.ProjectId, remote.Api); err != nil { + return err + } + // TODO: implement other service configs, ie. auth + return nil +} + +func (u *ConfigUpdater) UpdateApiConfig(ctx context.Context, projectRef string, c api) error { + apiConfig, err := u.client.V1GetPostgrestServiceConfigWithResponse(ctx, projectRef) + if err != nil { + return errors.Errorf("failed to read API config: %w", err) + } else if apiConfig.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", apiConfig.StatusCode(), string(apiConfig.Body)) + } + apiDiff, err := c.DiffWithRemote(*apiConfig.JSON200) + if err != nil { + return err + } else if len(apiDiff) == 0 { + fmt.Fprintln(os.Stderr, "Remote API config is up to date.") + return nil + } + fmt.Fprintln(os.Stderr, "Updating API service with config:", string(apiDiff)) + if resp, err := u.client.V1UpdatePostgrestServiceConfigWithResponse(ctx, projectRef, c.ToUpdatePostgrestConfigBody()); err != nil { + return errors.Errorf("failed to update API config: %w", err) + } else if resp.JSON200 == nil { + return errors.Errorf("unexpected status %d: %s", resp.StatusCode(), string(resp.Body)) + } + return nil +} diff --git a/pkg/config/updater_test.go b/pkg/config/updater_test.go new file mode 100644 index 000000000..241d612e9 --- /dev/null +++ b/pkg/config/updater_test.go @@ -0,0 +1,65 @@ +package config + +import ( + "context" + "net/http" + "testing" + + "github.com/h2non/gock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + v1API "github.com/supabase/cli/pkg/api" +) + +func TestUpdateApi(t *testing.T) { + server := "http://localhost" + client, err := v1API.NewClientWithResponses(server) + require.NoError(t, err) + + t.Run("updates remote config", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Setup mock server + defer gock.Off() + gock.New(server). + Get("/v1/projects/test-project/postgrest"). + Reply(http.StatusOK). + JSON(v1API.PostgrestConfigWithJWTSecretResponse{}) + gock.New(server). + Patch("/v1/projects/test-project/postgrest"). + Reply(http.StatusOK). + JSON(v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "public,graphql_public", + DbExtraSearchPath: "public,extensions", + MaxRows: 1000, + }) + // Run test + err := updater.UpdateApiConfig(context.Background(), "test-project", api{ + Enabled: true, + Schemas: []string{"public", "graphql_public"}, + ExtraSearchPath: []string{"public", "extensions"}, + MaxRows: 1000, + }) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) + + t.Run("skips update if no diff", func(t *testing.T) { + updater := NewConfigUpdater(*client) + // Setup mock server + defer gock.Off() + gock.New(server). + Get("/v1/projects/test-project/postgrest"). + Reply(http.StatusOK). + JSON(v1API.PostgrestConfigWithJWTSecretResponse{ + DbSchema: "", + DbExtraSearchPath: "public,extensions", + MaxRows: 1000, + }) + // Run test + err := updater.UpdateApiConfig(context.Background(), "test-project", api{}) + // Check result + assert.NoError(t, err) + assert.True(t, gock.IsDone()) + }) +} diff --git a/internal/link/diff.go b/pkg/diff/diff.go similarity index 99% rename from internal/link/diff.go rename to pkg/diff/diff.go index 84f2e3494..6a40b23fc 100644 --- a/internal/link/diff.go +++ b/pkg/diff/diff.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package link +package diff import ( "bytes" diff --git a/pkg/parser/token.go b/pkg/parser/token.go index 018a6f5d4..db0084342 100644 --- a/pkg/parser/token.go +++ b/pkg/parser/token.go @@ -8,6 +8,7 @@ import ( "github.com/go-errors/errors" "github.com/spf13/viper" + "github.com/supabase/cli/pkg/cast" ) // Equal to `startBufSize` from `bufio/scan.go` @@ -83,7 +84,7 @@ func Split(sql io.Reader, transform ...func(string) string) (stats []string, err // Increase scanner capacity to support very long lines containing e.g. geodata buf := make([]byte, startBufSize) - maxbuf := int(viper.GetSizeInBytes("SCANNER_BUFFER_SIZE")) + maxbuf := cast.UintToInt(viper.GetSizeInBytes("SCANNER_BUFFER_SIZE")) if maxbuf == 0 { maxbuf = MaxScannerCapacity }