diff --git a/CHANGELOG.md b/CHANGELOG.md index f6ad9156..2272559a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +# v0.16.1 - 30 September 2020 + +## Fixes + +- (#158) PII: make the PII scrubbing of In-App WAF attack events + case-insensitive in order to correctly scrub transformed request parameters. + +- (#159) Monitoring: fix the content type and length monitoring of HTTP + responses. + +- (#157) Gin middleware: use the request Go context instead of Gin's so that the + agent can properly manage the request execution context, but also to correctly + propagate values stored in the Go context before the middleware function. + + # v0.16.0 - 22 September 2020 ## New Feature diff --git a/README.md b/README.md index da8142df..0e638ffe 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

-Sqreen for Go +Sqreen for Go

# [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go @@ -9,7 +9,9 @@ security component into your app. Sqreen’s microagent automatically monitors sensitive app’s routines, blocks attacks and reports actionable infos to your dashboard. -![Dashboard](https://sqreen-assets.s3-eu-west-1.amazonaws.com/miscellaneous/dashboard.gif) +

+Sqreen for Go +

Sqreen provides automatic defense against attacks: @@ -90,4 +92,4 @@ Congratulations, your Go web application is now protected by Sqreen! Optionally, use the SDK to perform [user monitoring](https://docs.sqreen.com/go/user-monitoring/) or [custom security events](https://docs.sqreen.com/go/custom-events/) you would -like to track and possibly block. \ No newline at end of file +like to track and possibly block. diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go index ba138ae4..36d55f31 100644 --- a/internal/backend/api/api.go +++ b/internal/backend/api/api.go @@ -6,9 +6,10 @@ package api import ( "encoding/json" - "strings" + "regexp" "time" + "github.com/sqreen/go-agent/internal/sqlib/sqerrors" "github.com/sqreen/go-agent/internal/sqlib/sqsanitize" ) @@ -451,16 +452,22 @@ func (i *WAFAttackInfo) Scrub(scrubber *sqsanitize.Scrubber, info sqsanitize.Inf // were scrubbed. The caller must have stored into info the values scrubbed // from the request. redactedString := scrubber.RedactedValueMask() - for e := range wafInfo { + for sanitized := range info { + re, err := regexp.Compile(`(?i)`+regexp.QuoteMeta(sanitized)) + if err != nil { + return false, sqerrors.Wrapf(err, "could not ") + } + + for e := range wafInfo { for f := range wafInfo[e].Filter { - for v := range info { resolvedValue := wafInfo[e].Filter[f].ResolvedValue - newStr := strings.ReplaceAll(resolvedValue, v, redactedString) + + newStr := re.ReplaceAllString(resolvedValue, redactedString) if newStr != resolvedValue { // The string was changed wafInfo[e].Filter[f].ResolvedValue = newStr if wafInfo[e].Filter[f].MatchStatus != "" { - wafInfo[e].Filter[f].MatchStatus = strings.ReplaceAll(wafInfo[e].Filter[f].MatchStatus, v, redactedString) + wafInfo[e].Filter[f].MatchStatus = re.ReplaceAllString(wafInfo[e].Filter[f].MatchStatus, redactedString) } scrubbed = true } diff --git a/internal/backend/api/api_test.go b/internal/backend/api/api_test.go new file mode 100644 index 00000000..14c1ff7e --- /dev/null +++ b/internal/backend/api/api_test.go @@ -0,0 +1,111 @@ +package api_test + +import ( + "encoding/json" + "regexp" + "strings" + "testing" + + "github.com/sqreen/go-agent/internal/backend/api" + "github.com/sqreen/go-agent/internal/sqlib/sqsanitize" + "github.com/stretchr/testify/require" +) + +func TestWAFAttackInfo_Scrub(t *testing.T) { + t.Run("AGO-137", func(t *testing.T) { + // Test case covering issue https://sqreen.atlassian.net/browse/AGO-137 + + // The idea of the following mess is to mock a WAF attack with a lowercase + // transformation of the request parameters. The fix should correctly scrub + // the request parameters in the attack. + + // Set of PII values used in this test. They are uppercase here while they + // are lowercased in the attack information. + pii := map[string]string{ + "access_token": "PASSWORD_1", + "api_key": "PASSWORD_2", + "apikey": "PASSWORD_3", + "authorization": "PASSWORD_4", + } + + // Create test data including the resolved values of the WAF using lowercase + // parameters and the request parameters including a separate attack + // parameter. + resolvedValues := make(map[string]string, len(pii)) + params := make(map[string][]interface{}, len(pii)) + for k, v := range pii { + resolvedValues[k] = strings.ToLower(v) + params[k] = []interface{}{v} + } + resolvedValues["attack"] = "java.lang.processbuilder" + + // Prepare the WAF data json string + resolvedValuesJSON, err := json.Marshal(resolvedValues) + require.NoError(t, err) + resolvedValuesJSONStr, err := json.Marshal(string(resolvedValuesJSON)) + require.NoError(t, err) + wafData := []byte(`[ + { + "ret_code": 1, + "flow": "shell_injection-monitoring", + "step": "start", + "rule": "rule_944100", + "filter": [ + { + "operator": "@rx", + "operator_value": "java\\.lang\\.(?:runtime|processbuilder)", + "binding_accessor": "#.Request.Body.String", + "resolved_value": ` + string(resolvedValuesJSONStr) + `, + "match_status": "java.lang.processbuilder" + } + ] + } + ]`) + + // Create a fake request record with the interesting parts for this test + // only + record := api.RequestRecord{ + Request: api.RequestRecord_Request{ + Parameters: api.RequestRecord_Request_Parameters{ + Params: params, + }, + }, + Observed: api.RequestRecord_Observed{ + Attacks: []*api.RequestRecord_Observed_Attack{ + {Info: api.WAFAttackInfo{WAFData: wafData}}, + }, + }, + } + + // Create a scrubber of the PII values + keyRE := regexp.MustCompile(`(?i)(passw(((or)?d))|(phrase))|(secret)|(authorization)|(api_?key)|((access_?)?token)`) + valueRE := regexp.MustCompile(`(?:\d[ -]*?){13,16}`) + redactionString := "Redacted by Test" + scrubber := sqsanitize.NewScrubber(keyRE, valueRE, redactionString) + + // Scrub the request record + info := sqsanitize.Info{} + scrubbed, err := record.Scrub(scrubber, info) + + // It shouldn't fail and it should have scrubbed + require.NoError(t, err) + require.True(t, scrubbed) + + + scrubbedWAFData := string(record.Observed.Attacks[0].Info.(api.WAFAttackInfo).WAFData) + + // Check that the count of redactions in the WAF info string is correct: + // one per PII value. + require.Equal(t, len(pii), strings.Count(scrubbedWAFData, redactionString)) + + // For each PII value + for k, v := range pii { + // Check that the returned scrubbed values contain the PII value + require.Contains(t, info, v) + // Check that the WAF data has been scrubbed + require.NotContains(t, scrubbedWAFData, v) + // Check that the request parameter has been scrubbed + require.Equal(t, redactionString, record.Request.Parameters.Params[k][0].(string)) + } + }) +} diff --git a/internal/sqlib/sqsanitize/sanitize.go b/internal/sqlib/sqsanitize/sanitize.go index 619cb946..977bdadb 100644 --- a/internal/sqlib/sqsanitize/sanitize.go +++ b/internal/sqlib/sqsanitize/sanitize.go @@ -100,9 +100,7 @@ walk: v = v.Elem() goto walk - case reflect.Array: - fallthrough - case reflect.Slice: + case reflect.Array, reflect.Slice: return s.scrubSlice(v, info) case reflect.Map: @@ -212,6 +210,8 @@ func (s *Scrubber) scrubSlice(v reflect.Value, info Info) (scrubbed bool) { } func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) { + var scrubEverything *Scrubber + vt := v.Type().Elem() hasInterfaceValueType := vt.Kind() == reflect.Interface hasStringKeyType := v.Type().Key().Kind() == reflect.String @@ -223,9 +223,12 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) { // value regular expression. key := iter.Key() if hasStringKeyType && !s.scrubEveryString && matchString(s.keyRegexp, key.String()) { - scrubber = new(Scrubber) - *scrubber = *s - scrubber.scrubEveryString = true + if scrubEverything == nil { + scrubEverything = new(Scrubber) + *scrubEverything = *scrubber + scrubEverything.scrubEveryString = true + } + scrubber = scrubEverything } // Map entries cannot be set. We therefore create a new value in order @@ -247,9 +250,10 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) { // the scrubber. newVal := reflect.New(valT).Elem() newVal.Set(val) + // Scrub it - if scrubbedElement := scrubber.scrubValue(newVal, info); scrubbedElement { - // Set it + if scrubber.scrubValue(newVal, info) { + // Replace it v.SetMapIndex(key, newVal) scrubbed = true } @@ -258,6 +262,8 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) { } func (s *Scrubber) scrubStruct(v reflect.Value, info Info) (scrubbed bool) { + var scrubEverything *Scrubber + l := v.NumField() vt := v.Type() for i := 0; i < l; i++ { @@ -269,9 +275,12 @@ func (s *Scrubber) scrubStruct(v reflect.Value, info Info) (scrubbed bool) { scrubber := s if !s.scrubEveryString && matchString(s.keyRegexp, ft.Name) { - scrubber = new(Scrubber) - *scrubber = *s - scrubber.scrubEveryString = true + if scrubEverything == nil { + scrubEverything = new(Scrubber) + *scrubEverything = *scrubber + scrubEverything.scrubEveryString = true + } + scrubber = scrubEverything } f := v.Field(i) diff --git a/internal/sqlib/sqsanitize/sanitize_test.go b/internal/sqlib/sqsanitize/sanitize_test.go index fe5541a9..9f3a296f 100644 --- a/internal/sqlib/sqsanitize/sanitize_test.go +++ b/internal/sqlib/sqsanitize/sanitize_test.go @@ -1247,8 +1247,16 @@ func TestScrubber(t *testing.T) { fuzzer.Fuzz(&multipartForm) // Insert some values forbidden by the regular expression - postForm.Add("password", "1234") - postForm.Add("password", "5678") + postForm.Add("password", "password10") + postForm.Add("password", "password11") + postForm.Add("password", "password12") + postForm.Add("passwd", "password1") + postForm.Add("api_key", "password2") + postForm.Add("apikey", "password3") + postForm.Add("authorization", "password4") + postForm.Add("access_token", "password5") + postForm.Add("secret", "password6") + messageFormat := "here is my credit card number %s." stringWithCreditCardNb := fmt.Sprintf(messageFormat, "4533-3432-3234-3334") form.Add("message", stringWithCreditCardNb) @@ -1276,12 +1284,23 @@ func TestScrubber(t *testing.T) { require.True(t, scrubbed) // Check values were scrubbed - require.Equal(t, []string{expectedMask, expectedMask}, req.PostForm["password"]) + require.Equal(t, []string{expectedMask, expectedMask, expectedMask}, req.PostForm["password"]) + require.Equal(t, []string{expectedMask}, req.PostForm["passwd"]) + require.Equal(t, []string{expectedMask}, req.PostForm["api_key"]) + require.Equal(t, []string{expectedMask}, req.PostForm["apikey"]) + require.Equal(t, []string{expectedMask}, req.PostForm["authorization"]) + require.Equal(t, []string{expectedMask}, req.PostForm["access_token"]) + require.Equal(t, []string{expectedMask}, req.PostForm["secret"]) require.Equal(t, []string{fmt.Sprintf(messageFormat, expectedMask)}, req.Form["message"]) require.Contains(t, info, stringWithCreditCardNb) - require.Contains(t, info, "1234") - require.Contains(t, info, "5678") + require.Contains(t, info, "password10") + require.Contains(t, info, "password11") + require.Contains(t, info, "password2") + require.Contains(t, info, "password3") + require.Contains(t, info, "password4") + require.Contains(t, info, "password5") + require.Contains(t, info, "password6") }) }) }) diff --git a/internal/version/version.go b/internal/version/version.go index be419b56..9d9eb8a4 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -4,6 +4,6 @@ package version -const version = "0.16.0" +const version = "0.16.1" func Version() string { return version } diff --git a/sdk/middleware/sqecho/echo.go b/sdk/middleware/sqecho/echo.go index a1ace8d1..dc07447f 100644 --- a/sdk/middleware/sqecho/echo.go +++ b/sdk/middleware/sqecho/echo.go @@ -255,17 +255,22 @@ type observedResponse struct { } func newObservedResponse(r *responseWriterImpl) *observedResponse { + response := r.c.Response() + + headers := response.Header() + // Content-Type will be not empty only when explicitly set. // It could be guessed as net/http does. Not implemented for now. - ct := r.Header().Get("Content-Type") - - response := r.c.Response() + ct := headers.Get("Content-Type") - // Content-Length is either explicitly set or the amount of written data. + // Content-Length is either explicitly set or the amount of written data. It's + // 0 by default with Echo. cl := response.Size - if contentLength := r.Header().Get("Content-Length"); contentLength != "" { - if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil { - cl = l + if cl == 0 { + if contentLength := headers.Get("Content-Length"); contentLength != "" { + if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil { + cl = l + } } } diff --git a/sdk/middleware/sqecho/echo_test.go b/sdk/middleware/sqecho/echo_test.go index b0c2d626..8647c9b3 100644 --- a/sdk/middleware/sqecho/echo_test.go +++ b/sdk/middleware/sqecho/echo_test.go @@ -5,6 +5,7 @@ package sqecho import ( + "context" "errors" "io/ioutil" "net/http" @@ -100,7 +101,7 @@ func TestMiddleware(t *testing.T) { require.Equal(t, body, rec.Body.String()) }) - t.Run("control flow", func(t *testing.T) { + t.Run("data and control flow", func(t *testing.T) { middlewareResponseBody := testlib.RandUTF8String(4096) middlewareResponseStatus := 433 handlerResponseBody := testlib.RandUTF8String(4096) @@ -133,6 +134,11 @@ func TestMiddleware(t *testing.T) { handler func(echo.Context) error test func(t *testing.T, rec *httptest.ResponseRecorder, err error) }{ + // + // Control flow tests + // When an handlers, including middlewares, block. + // + { name: "sqreen first/the middleware aborts before the handler", middlewares: []echo.MiddlewareFunc{ @@ -341,6 +347,54 @@ func TestMiddleware(t *testing.T) { require.Equal(t, middlewareResponseBody+handlerResponseBody+middlewareResponseBody, rec.Body.String()) }, }, + + // + // Context data flow tests + // + { + name: "middleware1, sqreen, middleware2, handler", + middlewares: []echo.MiddlewareFunc{ + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("m10", "v10") + c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m11", "v11"))) + return next(c) + } + }, + middleware(tc.agent), + func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + c.Set("m20", "v20") + c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m21", "v21"))) + return next(c) + } + }, + }, + handler: func(c echo.Context) error { + // From Gin's context + if v, ok := c.Get("m10").(string); !ok || v != "v10" { + panic("couldn't get the context value m10") + } + if v, ok := c.Get("m20").(string); !ok || v != "v20" { + panic("couldn't get the context value m20") + } + + // From the request context + reqCtx := c.Request().Context() + if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" { + panic("couldn't get the context value m11") + } + if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" { + panic("couldn't get the context value m21") + } + + return c.NoContent(http.StatusOK) + }, + test: func(t *testing.T, rec *httptest.ResponseRecorder, err error) { + require.NoError(t, err) + require.Equal(t, http.StatusOK, rec.Code) + }, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -370,15 +424,23 @@ func TestMiddleware(t *testing.T) { t.Run("response observation", func(t *testing.T) { expectedStatusCode := 433 + expectedContentLength := int64(len("\"hello\"\n")) + expectedContentType := echo.MIMEApplicationJSONCharsetUTF8 agent := &mockups.AgentMockup{} agent.ExpectConfig().Return(&mockups.AgentConfigMockup{}).Once() agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once() agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once() - var responseStatusCode int + var ( + responseStatusCode int + responseContentType string + responseContentLength int64 + ) agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool { resp := recorded.Response() responseStatusCode = resp.Status() + responseContentLength = resp.ContentLength() + responseContentType = resp.ContentType() return true })).Return(nil) defer agent.AssertExpectations(t) @@ -387,7 +449,7 @@ func TestMiddleware(t *testing.T) { router := echo.New() router.Use(middleware(agent)) router.GET("/", func(c echo.Context) error { - return c.NoContent(expectedStatusCode) + return c.JSON(expectedStatusCode, "hello") }) // Perform the request and record the output @@ -401,8 +463,10 @@ func TestMiddleware(t *testing.T) { // Check the result require.NoError(t, err) - require.Equal(t, expectedStatusCode, responseStatusCode) require.Equal(t, expectedStatusCode, rec.Code) + require.Equal(t, expectedStatusCode, responseStatusCode) + require.Equal(t, expectedContentLength, responseContentLength) + require.Equal(t, expectedContentType, responseContentType) }) } diff --git a/sdk/middleware/sqecho/v4/README.md b/sdk/middleware/sqecho/v4/README.md new file mode 100644 index 00000000..f4a827ef --- /dev/null +++ b/sdk/middleware/sqecho/v4/README.md @@ -0,0 +1,36 @@ +

+Sqreen for Go +

+ +# [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go + +After performance monitoring (APM), error and log monitoring it’s time to add a +security component into your app. Sqreen’s microagent automatically monitors +sensitive app’s routines, blocks attacks and reports actionable infos to your +dashboard. + +

+Sqreen for Go +

+ +# Echo middleware function + +This package provides Sqreen's middleware function for Echo to monitor and +protect requests Echo receives. Simply setup the middleware function to have +your requests monitored and protected by Sqreen. + +Usage: + +```go +e := echo.New() +// Setup Sqreen's middleware +e.Use(sqecho.Middleware()) + +// Every router endpoint is now automatically monitored and protected by Sqreen +e.GET("/", func(c echo.Context) error { + // ... +} +``` + +Find more details on how to setup Sqreen for Go at + \ No newline at end of file diff --git a/sdk/middleware/sqgin/README.md b/sdk/middleware/sqgin/README.md new file mode 100644 index 00000000..cc703696 --- /dev/null +++ b/sdk/middleware/sqgin/README.md @@ -0,0 +1,36 @@ +

+Sqreen for Go +

+ +# [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go + +After performance monitoring (APM), error and log monitoring it’s time to add a +security component into your app. Sqreen’s microagent automatically monitors +sensitive app’s routines, blocks attacks and reports actionable infos to your +dashboard. + +

+Sqreen for Go +

+ +# Gin middleware function + +This package provides Sqreen's middleware function for Gin to monitor and +protect requests Gin receives. Simply setup the middleware function to have your +requests monitored and protected by Sqreen. + +Usage: + +```go +router := gin.Default() +// Setup Sqreen's middleware +router.Use(sqgin.Middleware()) + +// Every router endpoint is now automatically monitored and protected by Sqreen +router.GET("/", func(c *gin.Context) { + c.Status(http.StatusOK) +} +``` + +Find more details on how to setup Sqreen for Go at + \ No newline at end of file diff --git a/sdk/middleware/sqgin/gin.go b/sdk/middleware/sqgin/gin.go index e0dad9f4..38a5d22f 100644 --- a/sdk/middleware/sqgin/gin.go +++ b/sdk/middleware/sqgin/gin.go @@ -76,7 +76,7 @@ func middlewareHandler(agent protection_context.AgentFace, c *gingonic.Context) requestReader := &requestReaderImpl{c: c} responseWriter := &responseWriterImpl{c: c} - ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c, agent, responseWriter, requestReader) + ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c.Request.Context(), agent, responseWriter, requestReader) if ctx == nil { c.Next() return @@ -248,15 +248,22 @@ type observedResponse struct { } func newObservedResponse(r *responseWriterImpl) *observedResponse { + headers := r.c.Writer.Header() + // Content-Type will be not empty only when explicitly set. // It could be guessed as net/http does. Not implemented for now. - ct := r.Header().Get("Content-Type") + ct := headers.Get("Content-Type") - // Content-Length is either explicitly set or the amount of written data. + // Content-Length is either explicitly set or the amount of written data. It's + // less than 0 when not set by default with Gin. cl := int64(r.c.Writer.Size()) - if contentLength := r.Header().Get("Content-Length"); contentLength != "" { - if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil { - cl = l + if cl < 0 { + if contentLength := headers.Get("Content-Length"); contentLength != "" { + if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil { + cl = l + } else { + cl = 0 + } } } diff --git a/sdk/middleware/sqgin/gin_test.go b/sdk/middleware/sqgin/gin_test.go index 80a02b59..9eec7b94 100644 --- a/sdk/middleware/sqgin/gin_test.go +++ b/sdk/middleware/sqgin/gin_test.go @@ -5,6 +5,7 @@ package sqgin import ( + "context" "net/http" "net/http/httptest" "testing" @@ -91,7 +92,7 @@ func TestMiddleware(t *testing.T) { }) // Test how the control flows between middleware and handler functions - t.Run("control flow", func(t *testing.T) { + t.Run("data and control flow", func(t *testing.T) { middlewareResponseBody := testlib.RandUTF8String(4096) middlewareResponseStatus := 433 handlerResponseBody := testlib.RandUTF8String(4096) @@ -124,6 +125,11 @@ func TestMiddleware(t *testing.T) { handler func(*gin.Context) test func(t *testing.T, rec *httptest.ResponseRecorder) }{ + // + // Control flow tests + // When an handlers, including middlewares, block. + // + { name: "sqreen first/next middleware aborts before the handler", middlewares: []gin.HandlerFunc{ @@ -292,6 +298,47 @@ func TestMiddleware(t *testing.T) { require.Equal(t, middlewareResponseBody, rec.Body.String()) }, }, + + // + // Context data flow tests + // + { + name: "middleware1, sqreen, middleware2, handler", + middlewares: []gin.HandlerFunc{ + func(c *gin.Context) { + c.Set("m10", "v10") + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m11", "v11")) + }, + middleware(tc.agent), + func(c *gin.Context) { + c.Set("m20", "v20") + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m21", "v21")) + }, + }, + handler: func(c *gin.Context) { + // From Gin's context + if v, ok := c.Value("m10").(string); !ok || v != "v10" { + panic("couldn't get the context value m10") + } + if v, ok := c.Value("m20").(string); !ok || v != "v20" { + panic("couldn't get the context value m20") + } + + // From the request context + reqCtx := c.Request.Context() + if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" { + panic("couldn't get the context value m11") + } + if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" { + panic("couldn't get the context value m21") + } + + c.Status(http.StatusOK) + }, + test: func(t *testing.T, rec *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, rec.Code) + }, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -317,15 +364,23 @@ func TestMiddleware(t *testing.T) { t.Run("response observation", func(t *testing.T) { expectedStatusCode := 433 + expectedContentLength := int64(len(`"hello"`)) + expectedContentType := "application/json; charset=utf-8" agent := &mockups.AgentMockup{} agent.ExpectConfig().Return(&mockups.AgentConfigMockup{}).Once() agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once() agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once() - var responseStatusCode int + var ( + responseStatusCode int + responseContentType string + responseContentLength int64 + ) agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool { resp := recorded.Response() responseStatusCode = resp.Status() + responseContentType = resp.ContentType() + responseContentLength = resp.ContentLength() return true })).Return(nil) defer agent.AssertExpectations(t) @@ -334,7 +389,7 @@ func TestMiddleware(t *testing.T) { router := gin.New() router.Use(middleware(agent)) router.GET("/", func(c *gin.Context) { - c.Status(expectedStatusCode) + c.JSON(expectedStatusCode, "hello") }) // Perform the request and record the output @@ -343,8 +398,10 @@ func TestMiddleware(t *testing.T) { router.ServeHTTP(rec, req) // Check the result + require.Equal(t, expectedStatusCode, rec.Code) require.Equal(t, expectedStatusCode, responseStatusCode) - require.Equal(t, expectedStatusCode, responseStatusCode) + require.Equal(t, expectedContentLength, responseContentLength) + require.Equal(t, expectedContentType, responseContentType) }) } diff --git a/sdk/middleware/sqhttp/http_test.go b/sdk/middleware/sqhttp/http_test.go index 40b707a0..1e4f7544 100644 --- a/sdk/middleware/sqhttp/http_test.go +++ b/sdk/middleware/sqhttp/http_test.go @@ -5,6 +5,7 @@ package sqhttp import ( + "context" "io" "net/http" "net/http/httptest" @@ -70,7 +71,7 @@ func TestMiddleware(t *testing.T) { }) // Test how the control flows between middleware and handler functions - t.Run("control flow", func(t *testing.T) { + t.Run("data and control flow", func(t *testing.T) { middlewareResponseBody := testlib.RandUTF8String(4096) middlewareResponseStatus := 433 handlerResponseBody := testlib.RandUTF8String(4096) @@ -102,6 +103,11 @@ func TestMiddleware(t *testing.T) { handlers http.Handler test func(t *testing.T, rec *httptest.ResponseRecorder) }{ + // + // Control flow tests + // When an handlers, including middlewares, block. + // + { name: "sqreen first/handler writes the response", handlers: middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -163,6 +169,29 @@ func TestMiddleware(t *testing.T) { require.Equal(t, middlewareResponseBody, rec.Body.String()) }, }, + + // + // Context data flow tests + // + { + name: "middleware, sqreen, handler", + handlers: func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), "m", "v")) + next.ServeHTTP(w, r) + }) + }(middleware(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if v, ok := ctx.Value("m").(string); !ok || v != "v" { + panic("couldn't get the context value m") + } + + w.WriteHeader(http.StatusOK) + }, tc.agent)), + test: func(t *testing.T, rec *httptest.ResponseRecorder) { + require.Equal(t, http.StatusOK, rec.Code) + }, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -181,15 +210,23 @@ func TestMiddleware(t *testing.T) { t.Run("response observation", func(t *testing.T) { expectedStatusCode := 433 + expectedContentLength := int64(len(`"hello"`)) + expectedContentType := "application/json" agent := &mockups.AgentMockup{} agent.ExpectConfig().Return(&mockups.AgentConfigMockup{}).Once() agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once() agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once() - var responseStatusCode int + var ( + responseStatusCode int + responseContentType string + responseContentLength int64 + ) agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool { resp := recorded.Response() responseStatusCode = resp.Status() + responseContentLength = resp.ContentLength() + responseContentType = resp.ContentType() return true })).Return(nil) defer agent.AssertExpectations(t) @@ -198,15 +235,19 @@ func TestMiddleware(t *testing.T) { // Create a router router := http.NewServeMux() router.Handle("/", middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(expectedStatusCode) + w.Write([]byte(`"hello"`)) }), agent)) // Perform the request and record the output rec := httptest.NewRecorder() router.ServeHTTP(rec, req) // Check the request was performed as expected - require.Equal(t, expectedStatusCode, responseStatusCode) require.Equal(t, expectedStatusCode, rec.Code) + require.Equal(t, expectedStatusCode, responseStatusCode) + require.Equal(t, expectedContentLength, responseContentLength) + require.Equal(t, expectedContentType, responseContentType) }) }