Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cli/cmd/vm_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd
import (
"fmt"
"os"
"strings"
"time"

"github.com/moby/moby/pkg/namesgenerator"
Expand Down Expand Up @@ -147,6 +148,9 @@ func (r *runners) createVM(cmd *cobra.Command, args []string) error {
func (r *runners) createAndWaitForVM(opts kotsclient.CreateVMOpts) ([]*types.VM, error) {
vms, ve, err := r.kotsAPI.CreateVM(opts)
if errors.Cause(err) == platformclient.ErrForbidden {
if isRBACDeniedError(err) {
return nil, errors.New(err.Error())
}
return nil, ErrCompatibilityMatrixTermsNotAccepted
} else if err != nil {
return nil, errors.Wrap(err, "create vm")
Expand Down Expand Up @@ -175,6 +179,11 @@ func (r *runners) createAndWaitForVM(opts kotsclient.CreateVMOpts) ([]*types.VM,
return vms, nil
}

func isRBACDeniedError(err error) bool {
message := strings.TrimSpace(strings.ToLower(err.Error()))
return strings.HasPrefix(message, "access to ") && strings.HasSuffix(message, " is denied")
}

Comment on lines +182 to +186
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there no better way to decide if its an RBAC error than by string matching?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RBAC errors are unfortunately just a 403 and a string
it's not json or anything

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The literal API response is access to "kots/vm/create" is denied lol

func waitForVMs(kotsRestClient *kotsclient.VendorV3Client, vms []*types.VM, duration time.Duration) ([]*types.VM, error) {
start := time.Now()
runningVMs := map[string]*types.VM{}
Expand Down
63 changes: 63 additions & 0 deletions cli/cmd/vm_create_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package cmd

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/replicatedhq/replicated/pkg/kotsclient"
"github.com/replicatedhq/replicated/pkg/platformclient"
"github.com/stretchr/testify/require"
)

func TestCreateAndWaitForVM_ForbiddenErrors(t *testing.T) {
tests := []struct {
name string
body string
contentType string
expectedError string
}{
{
name: "rbac denial returns server message",
body: `access to "kots/vm/create" is denied`,
contentType: "text/plain",
expectedError: `access to "kots/vm/create" is denied`,
},
{
name: "non-rbac forbidden returns compatibility matrix message",
body: `{"error":{"message":"You must read and accept the Compatibility Matrix Terms of Service"}}`,
contentType: "application/json",
expectedError: ErrCompatibilityMatrixTermsNotAccepted.Error(),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v3/vm" || r.Method != http.MethodPost {
http.NotFound(w, r)
return
}

w.Header().Set("Content-Type", tt.contentType)
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()

httpClient := platformclient.NewHTTPClient(server.URL, "fake-api-key")
runner := &runners{
kotsAPI: &kotsclient.VendorV3Client{HTTPClient: *httpClient},
}

_, err := runner.createAndWaitForVM(kotsclient.CreateVMOpts{
Name: "test-vm",
Distribution: "ubuntu",
Version: "22.04",
Count: 1,
})
require.Error(t, err)
require.Equal(t, tt.expectedError, err.Error())
})
}
}
61 changes: 47 additions & 14 deletions pkg/platformclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,26 @@ func (e APIError) Error() string {
return fmt.Sprintf("%s %s %d: %s", e.Method, e.Endpoint, e.StatusCode, e.Message)
}

type ForbiddenError struct {
Message string
Body []byte
}

func (e ForbiddenError) Error() string {
if strings.TrimSpace(e.Message) != "" {
return e.Message
}
return ErrForbidden.Error()
}

func (e ForbiddenError) Cause() error {
return ErrForbidden
}

func (e ForbiddenError) Unwrap() error {
return ErrForbidden
}

type AppOptions struct {
Name string
}
Expand Down Expand Up @@ -185,25 +205,12 @@ func (c *HTTPClient) DoJSON(ctx context.Context, method string, path string, suc
}
if resp.StatusCode != successStatus {
if resp.StatusCode == http.StatusForbidden {
// look for a response message in the body
body, err := io.ReadAll(resp.Body)
if err != nil {
return ErrForbidden
}

// some of the methods in the api have a standardized response for 403
type forbiddenResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
var fr forbiddenResponse
if err := json.Unmarshal(body, &fr); err == nil {
return errors.New(fr.Error.Message)
}

return ErrForbidden
return parseForbiddenError(body)
}
body, _ := io.ReadAll(resp.Body)
return APIError{
Expand All @@ -228,6 +235,32 @@ func (c *HTTPClient) DoJSON(ctx context.Context, method string, path string, suc
return nil
}

func parseForbiddenError(body []byte) error {
type forbiddenResponse struct {
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}

var fr forbiddenResponse
if err := json.Unmarshal(body, &fr); err == nil && strings.TrimSpace(fr.Error.Message) != "" {
return ForbiddenError{
Message: fr.Error.Message,
Body: body,
}
}

if message := strings.TrimSpace(string(body)); message != "" {
return ForbiddenError{
Message: message,
Body: body,
}
}

return ErrForbidden
}

func addGitHubActionsHeaders(req *http.Request) error {
// anyone can set this to false to disable this behavior
if os.Getenv("CI") != "true" {
Expand Down
Loading