diff --git a/cmd/restrictions.go b/cmd/restrictions.go index 4358d6132..6bc1973d0 100644 --- a/cmd/restrictions.go +++ b/cmd/restrictions.go @@ -16,12 +16,13 @@ var ( dbCidrsToAllow []string bypassCidrChecks bool + appendMode bool restrictionsUpdateCmd = &cobra.Command{ Use: "update", Short: "Update network restrictions", RunE: func(cmd *cobra.Command, args []string) error { - return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks) + return update.Run(cmd.Context(), flags.ProjectRef, dbCidrsToAllow, bypassCidrChecks, appendMode) }, } @@ -38,6 +39,7 @@ func init() { restrictionsCmd.PersistentFlags().StringVar(&flags.ProjectRef, "project-ref", "", "Project ref of the Supabase project.") restrictionsUpdateCmd.Flags().StringSliceVar(&dbCidrsToAllow, "db-allow-cidr", []string{}, "CIDR to allow DB connections from.") restrictionsUpdateCmd.Flags().BoolVar(&bypassCidrChecks, "bypass-cidr-checks", false, "Bypass some of the CIDR validation checks.") + restrictionsUpdateCmd.Flags().BoolVar(&appendMode, "append", false, "Append to existing restrictions instead of replacing them.") restrictionsCmd.AddCommand(restrictionsGetCmd) restrictionsCmd.AddCommand(restrictionsUpdateCmd) rootCmd.AddCommand(restrictionsCmd) diff --git a/internal/restrictions/get/get.go b/internal/restrictions/get/get.go index f8e27342d..fa5e54d95 100644 --- a/internal/restrictions/get/get.go +++ b/internal/restrictions/get/get.go @@ -6,6 +6,7 @@ import ( "github.com/go-errors/errors" "github.com/supabase/cli/internal/utils" + "github.com/supabase/cli/pkg/api" ) func Run(ctx context.Context, projectRef string) error { @@ -19,6 +20,6 @@ func Run(ctx context.Context, projectRef string) error { fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrs) fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON200.Config.DbAllowedCidrsV6) - fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == "applied") + fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == api.NetworkRestrictionsResponseStatusApplied) return nil } diff --git a/internal/restrictions/update/update.go b/internal/restrictions/update/update.go index 98a70da98..eb9e2f555 100644 --- a/internal/restrictions/update/update.go +++ b/internal/restrictions/update/update.go @@ -10,7 +10,8 @@ import ( "github.com/supabase/cli/pkg/api" ) -func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool) error { +// Run updates the network restriction lists using the provided CIDRs. +func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypassCidrChecks bool, appendMode bool) error { // 1. separate CIDR to v4 and v6 body := api.V1UpdateNetworkRestrictionsJSONRequestBody{ DbAllowedCidrs: &[]string{}, @@ -31,6 +32,10 @@ func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypass } } + if appendMode { + return ApplyPatch(ctx, projectRef, body) + } + // 2. update restrictions resp, err := utils.GetSupabase().V1UpdateNetworkRestrictionsWithResponse(ctx, projectRef, body) if err != nil { @@ -42,6 +47,44 @@ func Run(ctx context.Context, projectRef string, dbCidrsToAllow []string, bypass fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrs) fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", resp.JSON201.Config.DbAllowedCidrsV6) - fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == "applied") + fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON201.Status == api.NetworkRestrictionsResponseStatusApplied) + return nil +} + +// ApplyPatch submits a network restriction payload using PATCH (add/remove mode). +func ApplyPatch(ctx context.Context, projectRef string, body api.V1UpdateNetworkRestrictionsJSONRequestBody) error { + patchBody := api.V1PatchNetworkRestrictionsJSONRequestBody{ + Add: &struct { + DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"` + DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"` + }{ + DbAllowedCidrs: body.DbAllowedCidrs, + DbAllowedCidrsV6: body.DbAllowedCidrsV6, + }, + } + + resp, err := utils.GetSupabase().V1PatchNetworkRestrictionsWithResponse(ctx, projectRef, patchBody) + if err != nil { + return errors.Errorf("failed to apply network restrictions: %w", err) + } + if resp.JSON200 == nil { + return errors.New("failed to apply network restrictions: " + string(resp.Body)) + } + + var allowedIPv4, allowedIPv6 []string + if allowed := resp.JSON200.Config.DbAllowedCidrs; allowed != nil { + for _, cidr := range *allowed { + switch cidr.Type { + case api.NetworkRestrictionsV2ResponseConfigDbAllowedCidrsTypeV4: + allowedIPv4 = append(allowedIPv4, cidr.Address) + case api.NetworkRestrictionsV2ResponseConfigDbAllowedCidrsTypeV6: + allowedIPv6 = append(allowedIPv6, cidr.Address) + } + } + } + + fmt.Printf("DB Allowed IPv4 CIDRs: %+v\n", &allowedIPv4) + fmt.Printf("DB Allowed IPv6 CIDRs: %+v\n", &allowedIPv6) + fmt.Printf("Restrictions applied successfully: %+v\n", resp.JSON200.Status == api.NetworkRestrictionsV2ResponseStatusApplied) return nil } diff --git a/internal/restrictions/update/update_test.go b/internal/restrictions/update/update_test.go index 5a6583ea8..a927e2711 100644 --- a/internal/restrictions/update/update_test.go +++ b/internal/restrictions/update/update_test.go @@ -19,22 +19,52 @@ func TestUpdateRestrictionsCommand(t *testing.T) { token := apitest.RandomAccessToken(t) t.Setenv("SUPABASE_ACCESS_TOKEN", string(token)) - t.Run("updates v4 and v6 CIDR", func(t *testing.T) { + t.Run("replaces v4 and v6 CIDR", func(t *testing.T) { // Setup mock api defer gock.OffAll() + expectedV4 := []string{"12.3.4.5/32", "1.2.3.1/24"} + expectedV6 := []string{"2001:db8:abcd:0012::0/64"} gock.New(utils.DefaultApiHost). Post("/v1/projects/" + projectRef + "/network-restrictions/apply"). MatchType("json"). JSON(api.NetworkRestrictionsRequest{ - DbAllowedCidrs: &[]string{"12.3.4.5/32", "1.2.3.1/24"}, - DbAllowedCidrsV6: &[]string{"2001:db8:abcd:0012::0/64"}, + DbAllowedCidrs: &expectedV4, + DbAllowedCidrsV6: &expectedV6, }). Reply(http.StatusCreated). JSON(api.NetworkRestrictionsResponse{ Status: api.NetworkRestrictionsResponseStatus("applied"), }) // Run test - err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false) + err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false, false) + // Check error + assert.NoError(t, err) + assert.Empty(t, apitest.ListUnmatchedRequests()) + }) + + t.Run("appends v4 and v6 CIDR", func(t *testing.T) { + // Setup mock api + defer gock.OffAll() + addV4 := []string{"12.3.4.5/32", "1.2.3.1/24"} + addV6 := []string{"2001:db8:abcd:0012::0/64"} + gock.New(utils.DefaultApiHost). + Patch("/v1/projects/" + projectRef + "/network-restrictions"). + MatchType("json"). + JSON(api.NetworkRestrictionsPatchRequest{ + Add: &struct { + DbAllowedCidrs *[]string `json:"dbAllowedCidrs,omitempty"` + DbAllowedCidrsV6 *[]string `json:"dbAllowedCidrsV6,omitempty"` + }{ + DbAllowedCidrs: &addV4, + DbAllowedCidrsV6: &addV6, + }, + }). + Reply(http.StatusOK). + JSON(api.NetworkRestrictionsV2Response{ + Status: api.NetworkRestrictionsV2ResponseStatus("applied"), + }) + // Run test + err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "1.2.3.1/24", "2001:db8:abcd:0012::0/64"}, false, true) // Check error assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -53,7 +83,7 @@ func TestUpdateRestrictionsCommand(t *testing.T) { }). ReplyError(errNetwork) // Run test - err := Run(context.Background(), projectRef, []string{}, true) + err := Run(context.Background(), projectRef, []string{}, true, false) // Check error assert.ErrorIs(t, err, errNetwork) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -71,7 +101,7 @@ func TestUpdateRestrictionsCommand(t *testing.T) { }). Reply(http.StatusServiceUnavailable) // Run test - err := Run(context.Background(), projectRef, []string{}, true) + err := Run(context.Background(), projectRef, []string{}, true, false) // Check error assert.ErrorContains(t, err, "failed to apply network restrictions:") assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -99,7 +129,7 @@ func TestValidateCIDR(t *testing.T) { Status: api.NetworkRestrictionsResponseStatus("applied"), }) // Run test - err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true) + err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true, false) // Check error assert.NoError(t, err) assert.Empty(t, apitest.ListUnmatchedRequests()) @@ -107,14 +137,14 @@ func TestValidateCIDR(t *testing.T) { t.Run("throws error on private subnet", func(t *testing.T) { // Run test - err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false) + err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false, false) // Check error assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8") }) t.Run("throws error on invalid subnet", func(t *testing.T) { // Run test - err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false) + err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false, false) // Check error assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5") })