diff --git a/bundle/bundle_test.go b/bundle/bundle_test.go index 91687a5c..bb539543 100644 --- a/bundle/bundle_test.go +++ b/bundle/bundle_test.go @@ -19,8 +19,6 @@ package bundle_test import ( "context" "github.com/rs/zerolog" - "github.com/snyk/code-client-go/deepcode" - mocks2 "github.com/snyk/code-client-go/deepcode/mocks" "testing" "github.com/golang/mock/gomock" @@ -28,6 +26,8 @@ import ( "github.com/stretchr/testify/require" "github.com/snyk/code-client-go/bundle" + "github.com/snyk/code-client-go/deepcode" + deepcodeMocks "github.com/snyk/code-client-go/deepcode/mocks" "github.com/snyk/code-client-go/observability/mocks" ) @@ -42,7 +42,7 @@ func Test_UploadBatch(t *testing.T) { t.Run("when no documents - creates nothing", func(t *testing.T) { ctrl := gomock.NewController(t) - mockSnykCodeClient := mocks2.NewMockSnykCodeClient(ctrl) + mockSnykCodeClient := deepcodeMocks.NewMockSnykCodeClient(ctrl) mockSpan := mocks.NewMockSpan(ctrl) mockSpan.EXPECT().Context().AnyTimes() @@ -59,7 +59,7 @@ func Test_UploadBatch(t *testing.T) { t.Run("when no bundles - creates new deepCodeBundle and sets hash", func(t *testing.T) { ctrl := gomock.NewController(t) - mockSnykCodeClient := mocks2.NewMockSnykCodeClient(ctrl) + mockSnykCodeClient := deepcodeMocks.NewMockSnykCodeClient(ctrl) mockSnykCodeClient.EXPECT().ExtendBundle(gomock.Any(), "testBundleHash", map[string]deepcode.BundleFile{ "file": {}, }, []string{}).Return("testBundleHash", []string{}, nil) @@ -78,7 +78,7 @@ func Test_UploadBatch(t *testing.T) { t.Run("when existing bundles - extends deepCodeBundle and updates hash", func(t *testing.T) { ctrl := gomock.NewController(t) - mockSnykCodeClient := mocks2.NewMockSnykCodeClient(ctrl) + mockSnykCodeClient := deepcodeMocks.NewMockSnykCodeClient(ctrl) mockSnykCodeClient.EXPECT().ExtendBundle(gomock.Any(), "testBundleHash", map[string]deepcode.BundleFile{ "another": {}, "file": {}, diff --git a/http/config.go b/config/config.go similarity index 97% rename from http/config.go rename to config/config.go index dc4ca527..ae4bb93e 100644 --- a/http/config.go +++ b/config/config.go @@ -1,4 +1,4 @@ -package http +package config // Config defines the configurable options for the HTTP client. // diff --git a/http/mocks/config.go b/config/mocks/config.go similarity index 100% rename from http/mocks/config.go rename to config/mocks/config.go diff --git a/deepcode/client.go b/deepcode/client.go index 0442764b..07b9e158 100644 --- a/deepcode/client.go +++ b/deepcode/client.go @@ -17,8 +17,16 @@ package deepcode import ( + "bytes" "context" "encoding/json" + "errors" + "github.com/snyk/code-client-go/config" + "github.com/snyk/code-client-go/internal/util/encoding" + "io" + "net/http" + "net/url" + "regexp" "strconv" "github.com/rs/zerolog" @@ -62,17 +70,21 @@ type BundleResponse struct { } type snykCodeClient struct { - httpClient codeClientHTTP.HTTPClient - instrumentor observability.Instrumentor - logger *zerolog.Logger + httpClient codeClientHTTP.HTTPClient + instrumentor observability.Instrumentor + errorReporter observability.ErrorReporter + logger *zerolog.Logger + config config.Config } func NewSnykCodeClient( logger *zerolog.Logger, httpClient codeClientHTTP.HTTPClient, instrumentor observability.Instrumentor, + errorReporter observability.ErrorReporter, + config config.Config, ) *snykCodeClient { - return &snykCodeClient{httpClient, instrumentor, logger} + return &snykCodeClient{httpClient, instrumentor, errorReporter, logger, config} } func (s *snykCodeClient) GetFilters(ctx context.Context) ( @@ -86,7 +98,12 @@ func (s *snykCodeClient) GetFilters(ctx context.Context) ( span := s.instrumentor.StartSpan(ctx, method) defer s.instrumentor.Finish(span) - responseBody, err := s.httpClient.DoCall(span.Context(), "GET", "/filters", nil) + host, err := s.Host() + if err != nil { + return FiltersResponse{ConfigFiles: nil, Extensions: nil}, err + } + + responseBody, err := s.Request(host, http.MethodGet, "/filters", nil) if err != nil { return FiltersResponse{ConfigFiles: nil, Extensions: nil}, err } @@ -110,12 +127,17 @@ func (s *snykCodeClient) CreateBundle( span := s.instrumentor.StartSpan(ctx, method) defer s.instrumentor.Finish(span) + host, err := s.Host() + if err != nil { + return "", nil, err + } + requestBody, err := json.Marshal(filesToFilehashes) if err != nil { return "", nil, err } - responseBody, err := s.httpClient.DoCall(span.Context(), "POST", "/bundle", requestBody) + responseBody, err := s.Request(host, http.MethodPost, "/bundle", requestBody) if err != nil { return "", nil, err } @@ -143,6 +165,11 @@ func (s *snykCodeClient) ExtendBundle( span := s.instrumentor.StartSpan(ctx, method) defer s.instrumentor.Finish(span) + host, err := s.Host() + if err != nil { + return "", nil, err + } + requestBody, err := json.Marshal(ExtendBundleRequest{ Files: files, RemovedFiles: removedFiles, @@ -151,7 +178,7 @@ func (s *snykCodeClient) ExtendBundle( return "", nil, err } - responseBody, err := s.httpClient.DoCall(span.Context(), "PUT", "/bundle/"+bundleHash, requestBody) + responseBody, err := s.Request(host, http.MethodPut, "/bundle/"+bundleHash, requestBody) if err != nil { return "", nil, err } @@ -159,3 +186,103 @@ func (s *snykCodeClient) ExtendBundle( err = json.Unmarshal(responseBody, &bundleResponse) return bundleResponse.BundleHash, bundleResponse.MissingFiles, err } + +// This is only exported for tests. +func (s *snykCodeClient) Host() (string, error) { + var codeApiRegex = regexp.MustCompile(`^(deeproxy\.)?`) + + snykCodeApiUrl := s.config.SnykCodeApi() + if !s.config.IsFedramp() { + return snykCodeApiUrl, nil + } + u, err := url.Parse(snykCodeApiUrl) + if err != nil { + return "", err + } + + u.Host = codeApiRegex.ReplaceAllString(u.Host, "api.") + + organization := s.config.Organization() + if organization == "" { + return "", errors.New("Organization is required in a fedramp environment") + } + + u.Path = "/hidden/orgs/" + organization + "/code" + + return u.String(), nil +} + +func (s *snykCodeClient) Request( + host string, + method string, + path string, + requestBody []byte, +) ([]byte, error) { + log := s.logger.With().Str("method", "deepcode.Request").Logger() + + bodyBuffer, err := s.encodeIfNeeded(method, requestBody) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(method, host+path, bodyBuffer) + if err != nil { + return nil, err + } + + s.addHeaders(method, req) + + response, err := s.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { + closeErr := response.Body.Close() + if closeErr != nil { + s.logger.Error().Err(closeErr).Msg("Couldn't close response body in call to Snyk Code") + } + }() + responseBody, err := io.ReadAll(response.Body) + if err != nil { + log.Error().Err(err).Msg("error reading response body") + s.errorReporter.CaptureError(err, observability.ErrorReporterOptions{ErrorDiagnosticPath: req.RequestURI}) + return nil, err + } + + return responseBody, nil +} + +func (s *snykCodeClient) addHeaders(method string, req *http.Request) { + // Setting a chosen org name for the request + org := s.config.Organization() + if org != "" { + req.Header.Set("snyk-org-name", org) + } + // https://www.keycdn.com/blog/http-cache-headers + req.Header.Set("Cache-Control", "private, max-age=0, no-cache") + if s.mustBeEncoded(method) { + req.Header.Set("Content-Type", "application/octet-stream") + req.Header.Set("Content-Encoding", "gzip") + } else { + req.Header.Set("Content-Type", "application/json") + } +} + +func (s *snykCodeClient) encodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) { + b := new(bytes.Buffer) + mustBeEncoded := s.mustBeEncoded(method) + if mustBeEncoded { + enc := encoding.NewEncoder(b) + _, err := enc.Write(requestBody) + if err != nil { + return nil, err + } + } else { + b = bytes.NewBuffer(requestBody) + } + return b, nil +} + +func (s *snykCodeClient) mustBeEncoded(method string) bool { + return method == http.MethodPost || method == http.MethodPut +} diff --git a/deepcode/client_pact_test.go b/deepcode/client_pact_test.go index 805358eb..fe8d8e76 100644 --- a/deepcode/client_pact_test.go +++ b/deepcode/client_pact_test.go @@ -26,9 +26,9 @@ import ( "github.com/pact-foundation/pact-go/dsl" "github.com/stretchr/testify/assert" + confMocks "github.com/snyk/code-client-go/config/mocks" "github.com/snyk/code-client-go/deepcode" codeClientHTTP "github.com/snyk/code-client-go/http" - httpmocks "github.com/snyk/code-client-go/http/mocks" "github.com/snyk/code-client-go/internal/util" "github.com/snyk/code-client-go/internal/util/testutil" ) @@ -214,7 +214,7 @@ func setupPact(t *testing.T) { pact.Setup(true) ctrl := gomock.NewController(t) - config := httpmocks.NewMockConfig(ctrl) + config := confMocks.NewMockConfig(ctrl) config.EXPECT().IsFedramp().AnyTimes().Return(false) config.EXPECT().Organization().AnyTimes().Return(orgUUID) snykCodeApiUrl := fmt.Sprintf("http://localhost:%d", pact.Server.Port) @@ -222,10 +222,10 @@ func setupPact(t *testing.T) { instrumentor := testutil.NewTestInstrumentor() errorReporter := testutil.NewTestErrorReporter() - httpClient := codeClientHTTP.NewHTTPClient(newLogger(t), config, func() *http.Client { + httpClient := codeClientHTTP.NewHTTPClient(newLogger(t), func() *http.Client { return http.DefaultClient }, instrumentor, errorReporter) - client = deepcode.NewSnykCodeClient(newLogger(t), httpClient, instrumentor) + client = deepcode.NewSnykCodeClient(newLogger(t), httpClient, instrumentor, errorReporter, config) } func getPutPostHeaderMatcher() dsl.MapMatcher { diff --git a/deepcode/client_test.go b/deepcode/client_test.go index 054fb09b..7ef54183 100644 --- a/deepcode/client_test.go +++ b/deepcode/client_test.go @@ -16,15 +16,20 @@ package deepcode_test import ( + "bytes" "context" "fmt" + "io" + "net/http" "testing" "time" "github.com/golang/mock/gomock" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + confMocks "github.com/snyk/code-client-go/config/mocks" "github.com/snyk/code-client-go/deepcode" httpmocks "github.com/snyk/code-client-go/http/mocks" "github.com/snyk/code-client-go/internal/util" @@ -59,18 +64,31 @@ func TestSnykCodeBackendService_GetFilters(t *testing.T) { mockSpan := mocks.NewMockSpan(ctrl) mockSpan.EXPECT().GetTraceId().AnyTimes() mockSpan.EXPECT().Context().AnyTimes() - mockConfig := httpmocks.NewMockConfig(ctrl) + mockConfig := confMocks.NewMockConfig(ctrl) mockConfig.EXPECT().Organization().AnyTimes().Return("") mockConfig.EXPECT().IsFedramp().AnyTimes().Return(false) + mockConfig.EXPECT().SnykCodeApi().AnyTimes().Return("http://localhost") + mockHTTPClient := httpmocks.NewMockHTTPClient(ctrl) - mockHTTPClient.EXPECT().Config().AnyTimes().Return(mockConfig) - mockHTTPClient.EXPECT().DoCall(gomock.Any(), "GET", "/filters", gomock.Any()).Return([]byte(`{"configFiles": ["test"], "extensions": ["test"]}`), nil).Times(1) + mockHTTPClient.EXPECT().Do( + mock.MatchedBy(func(i interface{}) bool { + req := i.(*http.Request) + return req.URL.String() == "http://localhost/filters" && + req.Method == "GET" && + req.Header.Get("Cache-Control") == "private, max-age=0, no-cache" && + req.Header.Get("Content-Type") == "application/json" + }), + ).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"configFiles": ["test"], "extensions": ["test"]}`))), + }, nil).Times(1) mockInstrumentor := mocks.NewMockInstrumentor(ctrl) mockInstrumentor.EXPECT().StartSpan(gomock.Any(), gomock.Any()).Return(mockSpan).Times(1) mockInstrumentor.EXPECT().Finish(gomock.Any()).Times(1) + mockErrorReporter := mocks.NewMockErrorReporter(ctrl) - s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor) + s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor, mockErrorReporter, mockConfig) filters, err := s.GetFilters(context.Background()) assert.Nil(t, err) assert.Equal(t, 1, len(filters.ConfigFiles)) @@ -82,17 +100,31 @@ func TestSnykCodeBackendService_CreateBundle(t *testing.T) { mockSpan := mocks.NewMockSpan(ctrl) mockSpan.EXPECT().GetTraceId().AnyTimes() mockSpan.EXPECT().Context().AnyTimes() - mockConfig := httpmocks.NewMockConfig(ctrl) + mockConfig := confMocks.NewMockConfig(ctrl) mockConfig.EXPECT().Organization().AnyTimes().Return("") mockConfig.EXPECT().IsFedramp().AnyTimes().Return(false) + mockConfig.EXPECT().SnykCodeApi().AnyTimes().Return("http://localhost") mockHTTPClient := httpmocks.NewMockHTTPClient(ctrl) - mockHTTPClient.EXPECT().Config().AnyTimes().Return(mockConfig) - mockHTTPClient.EXPECT().DoCall(gomock.Any(), "POST", "/bundle", gomock.Any()).Return([]byte(`{"bundleHash": "bundleHash", "missingFiles": ["test"]}`), nil).Times(1) + mockHTTPClient.EXPECT().Do( + mock.MatchedBy(func(i interface{}) bool { + req := i.(*http.Request) + return req.URL.String() == "http://localhost/bundle" && + req.Method == "POST" && + req.Header.Get("Cache-Control") == "private, max-age=0, no-cache" && + req.Header.Get("Content-Encoding") == "gzip" && + req.Header.Get("Content-Type") == "application/octet-stream" + }), + ).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"bundleHash": "bundleHash", "missingFiles": ["test"]}`))), + }, nil).Times(1) + mockInstrumentor := mocks.NewMockInstrumentor(ctrl) mockInstrumentor.EXPECT().StartSpan(gomock.Any(), gomock.Any()).Return(mockSpan).Times(1) mockInstrumentor.EXPECT().Finish(gomock.Any()).Times(1) + mockErrorReporter := mocks.NewMockErrorReporter(ctrl) - s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor) + s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor, mockErrorReporter, mockConfig) files := map[string]string{} randomAddition := fmt.Sprintf("\n public void random() { System.out.println(\"%d\") }", time.Now().UnixMicro()) files[path1] = util.Hash([]byte(content + randomAddition)) @@ -108,18 +140,43 @@ func TestSnykCodeBackendService_ExtendBundle(t *testing.T) { mockSpan := mocks.NewMockSpan(ctrl) mockSpan.EXPECT().GetTraceId().AnyTimes() mockSpan.EXPECT().Context().AnyTimes() - mockConfig := httpmocks.NewMockConfig(ctrl) + mockConfig := confMocks.NewMockConfig(ctrl) mockConfig.EXPECT().Organization().AnyTimes().Return("") mockConfig.EXPECT().IsFedramp().AnyTimes().Return(false) + mockConfig.EXPECT().SnykCodeApi().AnyTimes().Return("http://localhost") mockHTTPClient := httpmocks.NewMockHTTPClient(ctrl) - mockHTTPClient.EXPECT().Config().AnyTimes().Return(mockConfig) - mockHTTPClient.EXPECT().DoCall(gomock.Any(), "POST", "/bundle", gomock.Any()).Return([]byte(`{"bundleHash": "bundleHash", "missingFiles": []}`), nil).Times(1) - mockHTTPClient.EXPECT().DoCall(gomock.Any(), "PUT", "/bundle/bundleHash", gomock.Any()).Return([]byte(`{"bundleHash": "bundleHash", "missingFiles": []}`), nil).Times(1) + mockHTTPClient.EXPECT().Do( + mock.MatchedBy(func(i interface{}) bool { + req := i.(*http.Request) + return req.URL.String() == "http://localhost/bundle" && + req.Method == "POST" && + req.Header.Get("Cache-Control") == "private, max-age=0, no-cache" && + req.Header.Get("Content-Encoding") == "gzip" && + req.Header.Get("Content-Type") == "application/octet-stream" + }), + ).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"bundleHash": "bundleHash", "missingFiles": []}`))), + }, nil).Times(1) + mockHTTPClient.EXPECT().Do( + mock.MatchedBy(func(i interface{}) bool { + req := i.(*http.Request) + return req.URL.String() == "http://localhost/bundle/bundleHash" && + req.Method == "PUT" && + req.Header.Get("Cache-Control") == "private, max-age=0, no-cache" && + req.Header.Get("Content-Encoding") == "gzip" && + req.Header.Get("Content-Type") == "application/octet-stream" + }), + ).Return(&http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader([]byte(`{"bundleHash": "bundleHash", "missingFiles": []}`))), + }, nil).Times(1) mockInstrumentor := mocks.NewMockInstrumentor(ctrl) mockInstrumentor.EXPECT().StartSpan(gomock.Any(), gomock.Any()).Return(mockSpan).Times(2) mockInstrumentor.EXPECT().Finish(gomock.Any()).Times(2) + mockErrorReporter := mocks.NewMockErrorReporter(ctrl) - s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor) + s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor, mockErrorReporter, mockConfig) var removedFiles []string files := map[string]string{} files[path1] = util.Hash([]byte(content)) @@ -132,6 +189,36 @@ func TestSnykCodeBackendService_ExtendBundle(t *testing.T) { assert.NotEmpty(t, bundleHash) } +func Test_Host(t *testing.T) { + ctrl := gomock.NewController(t) + mockConfig := confMocks.NewMockConfig(ctrl) + mockConfig.EXPECT().SnykCodeApi().AnyTimes().Return("https://snyk.io/api/v1") + mockHTTPClient := httpmocks.NewMockHTTPClient(ctrl) + mockInstrumentor := mocks.NewMockInstrumentor(ctrl) + mockErrorReporter := mocks.NewMockErrorReporter(ctrl) + + t.Run("Changes the URL if FedRAMP", func(t *testing.T) { + mockConfig.EXPECT().Organization().AnyTimes().Return("00000000-0000-0000-0000-000000000023") + mockConfig.EXPECT().IsFedramp().Times(1).Return(true) + + s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor, mockErrorReporter, mockConfig) + + actual, err := s.Host() + assert.Nil(t, err) + assert.Contains(t, actual, "https://api.snyk.io/hidden/orgs/00000000-0000-0000-0000-000000000023/code") + }) + + t.Run("Does not change the URL if it's not FedRAMP", func(t *testing.T) { + mockConfig.EXPECT().Organization().AnyTimes().Return("") + mockConfig.EXPECT().IsFedramp().Times(1).Return(false) + s := deepcode.NewSnykCodeClient(newLogger(t), mockHTTPClient, mockInstrumentor, mockErrorReporter, mockConfig) + + actual, err := s.Host() + assert.Nil(t, err) + assert.Contains(t, actual, "https://snyk.io/api/v1") + }) +} + func createTestExtendMap() map[string]deepcode.BundleFile { filesExtend := map[string]deepcode.BundleFile{} diff --git a/http/http.go b/http/http.go index 21a97dec..5a2486eb 100644 --- a/http/http.go +++ b/http/http.go @@ -18,30 +18,18 @@ package http import ( - "bytes" - "context" "errors" - "io" "net/http" - "net/url" - "regexp" "time" "github.com/rs/zerolog" - "github.com/snyk/code-client-go/internal/util/encoding" "github.com/snyk/code-client-go/observability" ) //go:generate mockgen -destination=mocks/http.go -source=http.go -package mocks type HTTPClient interface { - Config() Config - DoCall(ctx context.Context, - method string, - path string, - requestBody []byte, - ) (responseBody []byte, err error) - FormatCodeApiURL() (string, error) + Do(req *http.Request) (*http.Response, error) } type httpClient struct { @@ -49,17 +37,15 @@ type httpClient struct { instrumentor observability.Instrumentor errorReporter observability.ErrorReporter logger *zerolog.Logger - config Config } func NewHTTPClient( logger *zerolog.Logger, - config Config, clientFactory func() *http.Client, instrumentor observability.Instrumentor, errorReporter observability.ErrorReporter, ) HTTPClient { - return &httpClient{clientFactory, instrumentor, errorReporter, logger, config} + return &httpClient{clientFactory, instrumentor, errorReporter, logger} } var retryErrorCodes = map[int]bool{ @@ -69,41 +55,21 @@ var retryErrorCodes = map[int]bool{ http.StatusInternalServerError: true, } -func (s *httpClient) Config() Config { - return s.config -} - -func (s *httpClient) DoCall(ctx context.Context, - method string, - path string, - requestBody []byte, -) (responseBody []byte, err error) { - span := s.instrumentor.StartSpan(ctx, "http.DoCall") +func (s *httpClient) Do(req *http.Request) (response *http.Response, err error) { + span := s.instrumentor.StartSpan(req.Context(), "http.Do") defer s.instrumentor.Finish(span) const retryCount = 3 for i := 0; i < retryCount; i++ { requestId := span.GetTraceId() + req.Header.Set("snyk-request-id", requestId) - var bodyBuffer *bytes.Buffer - bodyBuffer, err = s.encodeIfNeeded(method, requestBody) - if err != nil { - return nil, err - } + s.logger.Trace().Str("snyk-request-id", requestId).Msg("SEND TO REMOTE") - var req *http.Request - req, err = s.newRequest(method, path, bodyBuffer, requestId) - if err != nil { - return nil, err - } - - s.logger.Trace().Str("requestBody", string(requestBody)).Str("snyk-request-id", requestId).Msg("SEND TO REMOTE") - - var response *http.Response - response, responseBody, err = s.httpCall(req) //nolint:bodyclose // Already closed before in httpCall + response, err = s.httpCall(req) - if response != nil && responseBody != nil { - s.logger.Trace().Str("response.Status", response.Status).Str("responseBody", string(responseBody)).Str("snyk-request-id", requestId).Msg("RECEIVED FROM REMOTE") + if response != nil { + s.logger.Trace().Str("response.Status", response.Status).Str("snyk-request-id", requestId).Msg("RECEIVED FROM REMOTE") } else { s.logger.Trace().Str("snyk-request-id", requestId).Msg("RECEIVED FROM REMOTE") } @@ -115,7 +81,7 @@ func (s *httpClient) DoCall(ctx context.Context, err = s.checkResponseCode(response) if err != nil { if retryErrorCodes[response.StatusCode] { - s.logger.Debug().Err(err).Str("method", method).Int("attempts done", i+1).Msg("retrying") + s.logger.Debug().Err(err).Str("method", req.Method).Int("attempts done", i+1).Msg("retrying") if i < retryCount-1 { time.Sleep(5 * time.Second) continue @@ -128,92 +94,19 @@ func (s *httpClient) DoCall(ctx context.Context, // no error, we can break the retry loop break } - return responseBody, err -} - -func (s *httpClient) newRequest( - method string, - path string, - body *bytes.Buffer, - requestId string, -) (*http.Request, error) { - host, err := s.FormatCodeApiURL() - if err != nil { - return nil, err - } - - req, err := http.NewRequest(method, host+path, body) - if err != nil { - return nil, err - } - - s.addOrganization(req) - s.addDefaultHeaders(req, requestId, method) - return req, nil + return response, err } -func (s *httpClient) httpCall(req *http.Request) (*http.Response, []byte, error) { - log := s.logger.With().Str("method", "code.httpCall").Logger() +func (s *httpClient) httpCall(req *http.Request) (*http.Response, error) { + log := s.logger.With().Str("method", "http.httpCall").Logger() response, err := s.clientFactory().Do(req) if err != nil { log.Error().Err(err).Msg("got http error") s.errorReporter.CaptureError(err, observability.ErrorReporterOptions{ErrorDiagnosticPath: req.RequestURI}) - return nil, nil, err - } - - defer func(Body io.ReadCloser) { - closeErr := Body.Close() - if closeErr != nil { - log.Error().Err(closeErr).Msg("Couldn't close response body in call to Snyk Code") - } - }(response.Body) - responseBody, err := io.ReadAll(response.Body) - - if err != nil { - log.Error().Err(err).Msg("error reading response body") - s.errorReporter.CaptureError(err, observability.ErrorReporterOptions{ErrorDiagnosticPath: req.RequestURI}) - return nil, nil, err + return nil, err } - return response, responseBody, nil -} -func (s *httpClient) addOrganization(req *http.Request) { - // Setting a chosen org name for the request - org := s.config.Organization() - if org != "" { - req.Header.Set("snyk-org-name", org) - } -} - -func (s *httpClient) addDefaultHeaders(req *http.Request, requestId string, method string) { - req.Header.Set("snyk-request-id", requestId) - // https://www.keycdn.com/blog/http-cache-headers - req.Header.Set("Cache-Control", "private, max-age=0, no-cache") - if s.mustBeEncoded(method) { - req.Header.Set("Content-Type", "application/octet-stream") - req.Header.Set("Content-Encoding", "gzip") - } else { - req.Header.Set("Content-Type", "application/json") - } -} - -func (s *httpClient) encodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) { - b := new(bytes.Buffer) - mustBeEncoded := s.mustBeEncoded(method) - if mustBeEncoded { - enc := encoding.NewEncoder(b) - _, err := enc.Write(requestBody) - if err != nil { - return nil, err - } - } else { - b = bytes.NewBuffer(requestBody) - } - return b, nil -} - -func (s *httpClient) mustBeEncoded(method string) bool { - return method == http.MethodPost || method == http.MethodPut + return response, nil } func (s *httpClient) checkResponseCode(r *http.Response) error { @@ -222,28 +115,3 @@ func (s *httpClient) checkResponseCode(r *http.Response) error { } return errors.New("Unexpected response code: " + r.Status) } - -var codeApiRegex = regexp.MustCompile(`^(deeproxy\.)?`) - -// This is only exported for tests. -func (s *httpClient) FormatCodeApiURL() (string, error) { - snykCodeApiUrl := s.config.SnykCodeApi() - if !s.Config().IsFedramp() { - return snykCodeApiUrl, nil - } - u, err := url.Parse(snykCodeApiUrl) - if err != nil { - return "", err - } - - u.Host = codeApiRegex.ReplaceAllString(u.Host, "api.") - - organization := s.Config().Organization() - if organization == "" { - return "", errors.New("Organization is required in a fedramp environment") - } - - u.Path = "/hidden/orgs/" + organization + "/code" - - return u.String(), nil -} diff --git a/http/http_test.go b/http/http_test.go index 243eb9df..d4ef559f 100644 --- a/http/http_test.go +++ b/http/http_test.go @@ -16,16 +16,15 @@ package http_test import ( - "context" "net/http" "testing" "github.com/golang/mock/gomock" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" codeClientHTTP "github.com/snyk/code-client-go/http" - httpmocks "github.com/snyk/code-client-go/http/mocks" "github.com/snyk/code-client-go/observability" "github.com/snyk/code-client-go/observability/mocks" ) @@ -60,13 +59,12 @@ func TestSnykCodeBackendService_DoCall_shouldRetry(t *testing.T) { mockInstrumentor.EXPECT().StartSpan(gomock.Any(), gomock.Any()).Return(mockSpan).Times(1) mockInstrumentor.EXPECT().Finish(gomock.Any()).Times(1) mockErrorReporter := mocks.NewMockErrorReporter(ctrl) - config := httpmocks.NewMockConfig(ctrl) - config.EXPECT().IsFedramp().AnyTimes().Return(false) - config.EXPECT().Organization().AnyTimes().Return("") - config.EXPECT().SnykCodeApi().AnyTimes().Return("") - s := codeClientHTTP.NewHTTPClient(newLogger(t), config, dummyClientFactory, mockInstrumentor, mockErrorReporter) - _, err := s.DoCall(context.Background(), "GET", "https: //httpstat.us/500", nil) + req, err := http.NewRequest(http.MethodGet, "https://httpstat.us/500", nil) + require.NoError(t, err) + + s := codeClientHTTP.NewHTTPClient(newLogger(t), dummyClientFactory, mockInstrumentor, mockErrorReporter) + _, err = s.Do(req) assert.Error(t, err) assert.Equal(t, 3, d.calls) } @@ -84,56 +82,13 @@ func TestSnykCodeBackendService_doCall_rejected(t *testing.T) { mockInstrumentor.EXPECT().Finish(gomock.Any()).Times(1) mockErrorReporter := mocks.NewMockErrorReporter(ctrl) mockErrorReporter.EXPECT().CaptureError(gomock.Any(), observability.ErrorReporterOptions{ErrorDiagnosticPath: ""}) - config := httpmocks.NewMockConfig(ctrl) - config.EXPECT().IsFedramp().AnyTimes().Return(false) - config.EXPECT().Organization().AnyTimes().Return("") - config.EXPECT().SnykCodeApi().AnyTimes().Return("") - s := codeClientHTTP.NewHTTPClient(newLogger(t), config, dummyClientFactory, mockInstrumentor, mockErrorReporter) - _, err := s.DoCall(context.Background(), "GET", "https://127.0.0.1", nil) - assert.Error(t, err) -} + req, err := http.NewRequest(http.MethodGet, "https://127.0.0.1", nil) + require.NoError(t, err) -func Test_FormatCodeApiURL(t *testing.T) { - ctrl := gomock.NewController(t) - mockInstrumentor := mocks.NewMockInstrumentor(ctrl) - mockErrorReporter := mocks.NewMockErrorReporter(ctrl) - logger := newLogger(t) - dummyClientFactory := func() *http.Client { - return &http.Client{} - } - - orgUUID := "00000000-0000-0000-0000-000000000023" - - t.Run("Changes the URL if FedRAMP", func(t *testing.T) { - mockHTTPClient := httpmocks.NewMockHTTPClient(ctrl) - config := httpmocks.NewMockConfig(ctrl) - config.EXPECT().IsFedramp().AnyTimes().Return(true) - config.EXPECT().Organization().AnyTimes().Return(orgUUID) - config.EXPECT().SnykCodeApi().AnyTimes().Return("https://snyk.io/api/v1") - mockHTTPClient.EXPECT().Config().AnyTimes().Return(config) - - s := codeClientHTTP.NewHTTPClient(logger, config, dummyClientFactory, mockInstrumentor, mockErrorReporter) - - actual, err := s.FormatCodeApiURL() - assert.Nil(t, err) - assert.Contains(t, actual, "https://api.snyk.io/hidden/orgs/00000000-0000-0000-0000-000000000023/code") - }) - - t.Run("Does not change the URL if it's not FedRAMP", func(t *testing.T) { - mockHTTPClient := httpmocks.NewMockHTTPClient(ctrl) - config := httpmocks.NewMockConfig(ctrl) - config.EXPECT().IsFedramp().AnyTimes().Return(false) - config.EXPECT().Organization().AnyTimes().Return("") - config.EXPECT().SnykCodeApi().AnyTimes().Return("https://snyk.io/api/v1") - mockHTTPClient.EXPECT().Config().AnyTimes().Return(config) - - s := codeClientHTTP.NewHTTPClient(logger, config, dummyClientFactory, mockInstrumentor, mockErrorReporter) - - actual, err := s.FormatCodeApiURL() - assert.Nil(t, err) - assert.Contains(t, actual, "https://snyk.io/api/v1") - }) + s := codeClientHTTP.NewHTTPClient(newLogger(t), dummyClientFactory, mockInstrumentor, mockErrorReporter) + _, err = s.Do(req) + assert.Error(t, err) } func newLogger(t *testing.T) *zerolog.Logger { diff --git a/http/mocks/http.go b/http/mocks/http.go index 876647e7..b996f9b7 100644 --- a/http/mocks/http.go +++ b/http/mocks/http.go @@ -5,11 +5,10 @@ package mocks import ( - context "context" + http "net/http" reflect "reflect" gomock "github.com/golang/mock/gomock" - http "github.com/snyk/code-client-go/http" ) // MockHTTPClient is a mock of HTTPClient interface. @@ -35,46 +34,17 @@ func (m *MockHTTPClient) EXPECT() *MockHTTPClientMockRecorder { return m.recorder } -// Config mocks base method. -func (m *MockHTTPClient) Config() http.Config { +// Do mocks base method. +func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Config") - ret0, _ := ret[0].(http.Config) - return ret0 -} - -// Config indicates an expected call of Config. -func (mr *MockHTTPClientMockRecorder) Config() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Config", reflect.TypeOf((*MockHTTPClient)(nil).Config)) -} - -// DoCall mocks base method. -func (m *MockHTTPClient) DoCall(ctx context.Context, method, path string, requestBody []byte) ([]byte, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DoCall", ctx, method, path, requestBody) - ret0, _ := ret[0].([]byte) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DoCall indicates an expected call of DoCall. -func (mr *MockHTTPClientMockRecorder) DoCall(ctx, method, path, requestBody interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoCall", reflect.TypeOf((*MockHTTPClient)(nil).DoCall), ctx, method, path, requestBody) -} - -// FormatCodeApiURL mocks base method. -func (m *MockHTTPClient) FormatCodeApiURL() (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FormatCodeApiURL") - ret0, _ := ret[0].(string) + ret := m.ctrl.Call(m, "Do", req) + ret0, _ := ret[0].(*http.Response) ret1, _ := ret[1].(error) return ret0, ret1 } -// FormatCodeApiURL indicates an expected call of FormatCodeApiURL. -func (mr *MockHTTPClientMockRecorder) FormatCodeApiURL() *gomock.Call { +// Do indicates an expected call of Do. +func (mr *MockHTTPClientMockRecorder) Do(req interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FormatCodeApiURL", reflect.TypeOf((*MockHTTPClient)(nil).FormatCodeApiURL)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockHTTPClient)(nil).Do), req) }