Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func NewDefaultClientFactory() HTTPClientFactory {
return clientFunc
}

func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method string) {
func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method string, needsEncoding bool) {
// if requestId is empty it will be enriched from the Gateway
if len(requestId) > 0 {
req.Header.Set("snyk-request-id", requestId)
Expand All @@ -175,7 +175,7 @@ func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method
// https://www.keycdn.com/blog/http-cache-headers
req.Header.Set("Cache-Control", "private, max-age=0, no-cache")

if mustBeEncoded(method) {
if mustBeEncoded(method, needsEncoding) {
req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("Content-Encoding", "gzip")
} else {
Expand All @@ -185,9 +185,9 @@ func AddDefaultHeaders(req *http.Request, requestId string, orgId string, method

// EncodeIfNeeded returns a byte buffer for the requestBody. Depending on the request method, it may encode the buffer.
// (See http.mustBeEncoded for the list of methods which require encoding the request body.)
func EncodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) {
func EncodeIfNeeded(method string, requestBody []byte, needsEncoding bool) (*bytes.Buffer, error) {
b := new(bytes.Buffer)
if mustBeEncoded(method) {
if mustBeEncoded(method, needsEncoding) {
enc := encoding.NewEncoder(b)
_, err := enc.Write(requestBody)
if err != nil {
Expand All @@ -200,6 +200,6 @@ func EncodeIfNeeded(method string, requestBody []byte) (*bytes.Buffer, error) {
}

// mustBeEncoded returns true if the request method requires the request body to be encoded.
func mustBeEncoded(method string) bool {
return method == http.MethodPost || method == http.MethodPut
func mustBeEncoded(method string, needsEncoding bool) bool {
return needsEncoding && (method == http.MethodPost || method == http.MethodPut)
}
4 changes: 2 additions & 2 deletions internal/analysis/analysis_legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (a *analysisOrchestrator) RunLegacyTest(ctx context.Context, bundleHash str
httpMethod := http.MethodPost

// Encode the request body
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(http.MethodPost, requestBody)
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(http.MethodPost, requestBody, true)
if err != nil {
a.logger.Err(err).Str("requestBody", string(requestBody)).Msg("error encoding request body")
return nil, scan.LegacyScanStatus{}, err
Expand All @@ -176,7 +176,7 @@ func (a *analysisOrchestrator) RunLegacyTest(ctx context.Context, bundleHash str
a.logger.Err(err).Str("method", method).Msg("error creating HTTP request")
return nil, scan.LegacyScanStatus{}, err
}
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), a.config.Organization(), httpMethod)
codeClientHTTP.AddDefaultHeaders(req, span.GetTraceId(), a.config.Organization(), httpMethod, true)

// Make HTTP call
resp, err := a.httpClient.Do(req)
Expand Down
4 changes: 2 additions & 2 deletions internal/deepcode/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ func (s *deepcodeClient) Request(
return nil, err
}

bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(method, requestBody)
bodyBuffer, err := codeClientHTTP.EncodeIfNeeded(method, requestBody, true)
if err != nil {
return nil, err
}
Expand All @@ -225,7 +225,7 @@ func (s *deepcodeClient) Request(
return nil, err
}

codeClientHTTP.AddDefaultHeaders(req, codeClientHTTP.NoRequestId, s.config.Organization(), method)
codeClientHTTP.AddDefaultHeaders(req, codeClientHTTP.NoRequestId, s.config.Organization(), method, true)

response, err := s.httpClient.Do(req)
if err != nil {
Expand Down
12 changes: 6 additions & 6 deletions llm/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
}
}

responseBody, err := d.submitRequest(span.Context(), u, requestBody, "")
responseBody, err := d.submitRequest(span.Context(), u, requestBody, "", false)
if err != nil {
return Explanations{}, err
}
Expand All @@ -63,14 +63,14 @@ func (d *DeepCodeLLMBindingImpl) runExplain(ctx context.Context, options Explain
return explains, nil
}

func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL, requestBody []byte, orgId string) ([]byte, error) {
func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL, requestBody []byte, orgId string, needsEncoding bool) ([]byte, error) {
logger := d.logger.With().Str("method", "submitRequest").Logger()
logger.Trace().Str("payload body: %s\n", string(requestBody)).Msg("Marshaled payload")
span := d.instrumentor.StartSpan(ctx, "code.SubmitRequest")
defer span.Finish()

// Encode the request body
bodyBuffer, err := http2.EncodeIfNeeded(http.MethodPost, requestBody)
bodyBuffer, err := http2.EncodeIfNeeded(http.MethodPost, requestBody, needsEncoding)
if err != nil {
logger.Err(err).Str("requestBody", string(requestBody)).Msg("error encoding request body")
return nil, err
Expand All @@ -82,7 +82,7 @@ func (d *DeepCodeLLMBindingImpl) submitRequest(ctx context.Context, url *url.URL
return nil, err
}

http2.AddDefaultHeaders(req, http2.NoRequestId, orgId, http.MethodPost)
http2.AddDefaultHeaders(req, http2.NoRequestId, orgId, http.MethodPost, needsEncoding)

resp, err := d.httpClientFunc().Do(req) //nolint:bodyclose // this seems to be a false positive
if err != nil {
Expand Down Expand Up @@ -152,7 +152,7 @@ func (d *DeepCodeLLMBindingImpl) runAutofix(ctx context.Context, options Autofix
}

logger.Info().Msg("Started obtaining autofix Response")
responseBody, err := d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId)
responseBody, err := d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId, true)
logger.Info().Msg("Finished obtaining autofix Response")

if err != nil {
Expand Down Expand Up @@ -228,7 +228,7 @@ func (d *DeepCodeLLMBindingImpl) submitAutofixFeedback(ctx context.Context, opti
}

logger.Info().Msg("Started obtaining autofix Response")
_, err = d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId)
_, err = d.submitRequest(span.Context(), endpoint, requestBody, options.CodeRequestContext.Org.PublicId, true)
logger.Info().Msg("Finished obtaining autofix Response")

return err
Expand Down
175 changes: 174 additions & 1 deletion llm/api_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,30 @@ func testLogger(t *testing.T) *zerolog.Logger {
func TestAddDefaultHeadersWithExistingHeaders(t *testing.T) {
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}

http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodGet)
http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodPost, true)

cacheControl := req.Header.Get("Cache-Control")
contentType := req.Header.Get("Content-Type")
existingHeader := req.Header.Get("Existing-Header")

if cacheControl != "private, max-age=0, no-cache" {
t.Errorf("Expected Cache-Control header to be 'private, max-age=0, no-cache', got %s", cacheControl)
}

if contentType != "application/octet-stream" {
t.Errorf("Expected Content-Type header to be 'application/json', got %s", contentType)
}

if existingHeader != "existing-value" {
t.Errorf("Expected Existing-Header to be 'existing-value', got %s", existingHeader)
}
}

// Test with existing headers
func TestAddDefaultHeadersWithSkipEncodingEnabled(t *testing.T) {
req := &http.Request{Header: http.Header{"Existing-Header": {"existing-value"}}}

http2.AddDefaultHeaders(req, http2.NoRequestId, "", http.MethodPost, false)

cacheControl := req.Header.Get("Cache-Control")
contentType := req.Header.Get("Content-Type")
Expand Down Expand Up @@ -350,3 +373,153 @@ func TestAutofixRequestBody(t *testing.T) {

assert.Equal(t, expectedBody, body)
}

func TestRunExplain_WithHeaderValidation(t *testing.T) {
t.Run("vulnerability explanation with headers", func(t *testing.T) {
ruleKey := "test-rule-key"
derivation := "test-derivation"
ruleMessage := "test-rule-message"

expectedResponse := Explanations{
"explanation1": "This is the first explanation",
"explanation2": "This is the second explanation",
}

response := explainResponse{
Status: completeStatus,
Explanation: expectedResponse,
}

responseBodyBytes, err := json.Marshal(response)
require.NoError(t, err)

// Create a test server that validates headers
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify headers
assert.Equal(t, "private, max-age=0, no-cache", r.Header.Get("Cache-Control"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, http.MethodPost, r.Method)

// Verify request body
body, readErr := io.ReadAll(r.Body)
require.NoError(t, readErr)

var requestData explainVulnerabilityRequest
err = json.Unmarshal(body, &requestData)
require.NoError(t, err)

assert.Equal(t, ruleKey, requestData.RuleId)
assert.Equal(t, derivation, requestData.Derivation)
assert.Equal(t, ruleMessage, requestData.RuleMessage)
assert.Equal(t, SHORT, requestData.ExplanationLength)

// Send response
w.WriteHeader(http.StatusOK)
_, _ = w.Write(responseBodyBytes)
}))
defer server.Close()

// Parse server URL
u, err := url.Parse(server.URL)
require.NoError(t, err)

// Create options
options := ExplainOptions{
RuleKey: ruleKey,
Derivation: derivation,
RuleMessage: ruleMessage,
Endpoint: u,
}

// Create DeepCodeLLMBinding
d := NewDeepcodeLLMBinding()

// Run the test
ctx := t.Context()
ctx = observability.GetContextWithTraceId(ctx, "test-trace-id")

result, err := d.runExplain(ctx, options)

// Verify results
require.NoError(t, err)
assert.Equal(t, expectedResponse, result)
})

t.Run("fix explanation with base64 encoded diffs and headers", func(t *testing.T) {
ruleKey := "test-rule-key"
testDiffs := []string{
"--- a/file.txt\n+++ b/file.txt\n@@ -1,1 +1,1 @@\n-old line\n+new line\n",
}

expectedResponse := Explanations{
"explanation1": "This explains the fix",
}

response := explainResponse{
Status: completeStatus,
Explanation: expectedResponse,
}

responseBodyBytes, err := json.Marshal(response)
require.NoError(t, err)

// Create a test server that validates headers and base64 encoded diffs
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify headers
assert.Equal(t, "private, max-age=0, no-cache", r.Header.Get("Cache-Control"))
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, http.MethodPost, r.Method)

// Verify request body
body, readErr := io.ReadAll(r.Body)
require.NoError(t, readErr)

var requestData explainFixRequest
err = json.Unmarshal(body, &requestData)
require.NoError(t, err)

assert.Equal(t, ruleKey, requestData.RuleId)
assert.Equal(t, SHORT, requestData.ExplanationLength)

// Verify diffs are base64 encoded
require.Len(t, requestData.Diffs, 1)

// Decode the base64 diff to verify it was encoded properly
decodedDiff, decodeErr := base64.StdEncoding.DecodeString(requestData.Diffs[0])
require.NoError(t, decodeErr)

// The prepareDiffs function strips --- and +++ headers and adds a newline
expectedDecodedDiff := "@@ -1,1 +1,1 @@\n-old line\n+new line\n\n"
assert.Equal(t, expectedDecodedDiff, string(decodedDiff))

// Send response
w.WriteHeader(http.StatusOK)
_, _ = w.Write(responseBodyBytes)
}))
defer server.Close()

// Parse server URL
u, err := url.Parse(server.URL)
require.NoError(t, err)

// Create options
options := ExplainOptions{
RuleKey: ruleKey,
Diffs: testDiffs,
Endpoint: u,
}

// Create DeepCodeLLMBinding
d := NewDeepcodeLLMBinding()

// Run the test
ctx := t.Context()
ctx = observability.GetContextWithTraceId(ctx, "test-trace-id")

result, err := d.runExplain(ctx, options)

// Verify results
require.NoError(t, err)
assert.Equal(t, expectedResponse, result)
})
}