diff --git a/CHANGELOG.md b/CHANGELOG.md
index f6ad9156..2272559a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,18 @@
+# v0.16.1 - 30 September 2020
+
+## Fixes
+
+- (#158) PII: make the PII scrubbing of In-App WAF attack events
+ case-insensitive in order to correctly scrub transformed request parameters.
+
+- (#159) Monitoring: fix the content type and length monitoring of HTTP
+ responses.
+
+- (#157) Gin middleware: use the request Go context instead of Gin's so that the
+ agent can properly manage the request execution context, but also to correctly
+ propagate values stored in the Go context before the middleware function.
+
+
# v0.16.0 - 22 September 2020
## New Feature
diff --git a/README.md b/README.md
index da8142df..0e638ffe 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
-
+
# [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go
@@ -9,7 +9,9 @@ security component into your app. Sqreen’s microagent automatically monitors
sensitive app’s routines, blocks attacks and reports actionable infos to your
dashboard.
-![Dashboard](https://sqreen-assets.s3-eu-west-1.amazonaws.com/miscellaneous/dashboard.gif)
+
+
+
Sqreen provides automatic defense against attacks:
@@ -90,4 +92,4 @@ Congratulations, your Go web application is now protected by Sqreen!
Optionally, use the SDK to perform [user monitoring](https://docs.sqreen.com/go/user-monitoring/)
or [custom security events](https://docs.sqreen.com/go/custom-events/) you would
-like to track and possibly block.
\ No newline at end of file
+like to track and possibly block.
diff --git a/internal/backend/api/api.go b/internal/backend/api/api.go
index ba138ae4..36d55f31 100644
--- a/internal/backend/api/api.go
+++ b/internal/backend/api/api.go
@@ -6,9 +6,10 @@ package api
import (
"encoding/json"
- "strings"
+ "regexp"
"time"
+ "github.com/sqreen/go-agent/internal/sqlib/sqerrors"
"github.com/sqreen/go-agent/internal/sqlib/sqsanitize"
)
@@ -451,16 +452,22 @@ func (i *WAFAttackInfo) Scrub(scrubber *sqsanitize.Scrubber, info sqsanitize.Inf
// were scrubbed. The caller must have stored into info the values scrubbed
// from the request.
redactedString := scrubber.RedactedValueMask()
- for e := range wafInfo {
+ for sanitized := range info {
+ re, err := regexp.Compile(`(?i)`+regexp.QuoteMeta(sanitized))
+ if err != nil {
+ return false, sqerrors.Wrapf(err, "could not ")
+ }
+
+ for e := range wafInfo {
for f := range wafInfo[e].Filter {
- for v := range info {
resolvedValue := wafInfo[e].Filter[f].ResolvedValue
- newStr := strings.ReplaceAll(resolvedValue, v, redactedString)
+
+ newStr := re.ReplaceAllString(resolvedValue, redactedString)
if newStr != resolvedValue {
// The string was changed
wafInfo[e].Filter[f].ResolvedValue = newStr
if wafInfo[e].Filter[f].MatchStatus != "" {
- wafInfo[e].Filter[f].MatchStatus = strings.ReplaceAll(wafInfo[e].Filter[f].MatchStatus, v, redactedString)
+ wafInfo[e].Filter[f].MatchStatus = re.ReplaceAllString(wafInfo[e].Filter[f].MatchStatus, redactedString)
}
scrubbed = true
}
diff --git a/internal/backend/api/api_test.go b/internal/backend/api/api_test.go
new file mode 100644
index 00000000..14c1ff7e
--- /dev/null
+++ b/internal/backend/api/api_test.go
@@ -0,0 +1,111 @@
+package api_test
+
+import (
+ "encoding/json"
+ "regexp"
+ "strings"
+ "testing"
+
+ "github.com/sqreen/go-agent/internal/backend/api"
+ "github.com/sqreen/go-agent/internal/sqlib/sqsanitize"
+ "github.com/stretchr/testify/require"
+)
+
+func TestWAFAttackInfo_Scrub(t *testing.T) {
+ t.Run("AGO-137", func(t *testing.T) {
+ // Test case covering issue https://sqreen.atlassian.net/browse/AGO-137
+
+ // The idea of the following mess is to mock a WAF attack with a lowercase
+ // transformation of the request parameters. The fix should correctly scrub
+ // the request parameters in the attack.
+
+ // Set of PII values used in this test. They are uppercase here while they
+ // are lowercased in the attack information.
+ pii := map[string]string{
+ "access_token": "PASSWORD_1",
+ "api_key": "PASSWORD_2",
+ "apikey": "PASSWORD_3",
+ "authorization": "PASSWORD_4",
+ }
+
+ // Create test data including the resolved values of the WAF using lowercase
+ // parameters and the request parameters including a separate attack
+ // parameter.
+ resolvedValues := make(map[string]string, len(pii))
+ params := make(map[string][]interface{}, len(pii))
+ for k, v := range pii {
+ resolvedValues[k] = strings.ToLower(v)
+ params[k] = []interface{}{v}
+ }
+ resolvedValues["attack"] = "java.lang.processbuilder"
+
+ // Prepare the WAF data json string
+ resolvedValuesJSON, err := json.Marshal(resolvedValues)
+ require.NoError(t, err)
+ resolvedValuesJSONStr, err := json.Marshal(string(resolvedValuesJSON))
+ require.NoError(t, err)
+ wafData := []byte(`[
+ {
+ "ret_code": 1,
+ "flow": "shell_injection-monitoring",
+ "step": "start",
+ "rule": "rule_944100",
+ "filter": [
+ {
+ "operator": "@rx",
+ "operator_value": "java\\.lang\\.(?:runtime|processbuilder)",
+ "binding_accessor": "#.Request.Body.String",
+ "resolved_value": ` + string(resolvedValuesJSONStr) + `,
+ "match_status": "java.lang.processbuilder"
+ }
+ ]
+ }
+ ]`)
+
+ // Create a fake request record with the interesting parts for this test
+ // only
+ record := api.RequestRecord{
+ Request: api.RequestRecord_Request{
+ Parameters: api.RequestRecord_Request_Parameters{
+ Params: params,
+ },
+ },
+ Observed: api.RequestRecord_Observed{
+ Attacks: []*api.RequestRecord_Observed_Attack{
+ {Info: api.WAFAttackInfo{WAFData: wafData}},
+ },
+ },
+ }
+
+ // Create a scrubber of the PII values
+ keyRE := regexp.MustCompile(`(?i)(passw(((or)?d))|(phrase))|(secret)|(authorization)|(api_?key)|((access_?)?token)`)
+ valueRE := regexp.MustCompile(`(?:\d[ -]*?){13,16}`)
+ redactionString := "Redacted by Test"
+ scrubber := sqsanitize.NewScrubber(keyRE, valueRE, redactionString)
+
+ // Scrub the request record
+ info := sqsanitize.Info{}
+ scrubbed, err := record.Scrub(scrubber, info)
+
+ // It shouldn't fail and it should have scrubbed
+ require.NoError(t, err)
+ require.True(t, scrubbed)
+
+
+ scrubbedWAFData := string(record.Observed.Attacks[0].Info.(api.WAFAttackInfo).WAFData)
+
+ // Check that the count of redactions in the WAF info string is correct:
+ // one per PII value.
+ require.Equal(t, len(pii), strings.Count(scrubbedWAFData, redactionString))
+
+ // For each PII value
+ for k, v := range pii {
+ // Check that the returned scrubbed values contain the PII value
+ require.Contains(t, info, v)
+ // Check that the WAF data has been scrubbed
+ require.NotContains(t, scrubbedWAFData, v)
+ // Check that the request parameter has been scrubbed
+ require.Equal(t, redactionString, record.Request.Parameters.Params[k][0].(string))
+ }
+ })
+}
diff --git a/internal/sqlib/sqsanitize/sanitize.go b/internal/sqlib/sqsanitize/sanitize.go
index 619cb946..977bdadb 100644
--- a/internal/sqlib/sqsanitize/sanitize.go
+++ b/internal/sqlib/sqsanitize/sanitize.go
@@ -100,9 +100,7 @@ walk:
v = v.Elem()
goto walk
- case reflect.Array:
- fallthrough
- case reflect.Slice:
+ case reflect.Array, reflect.Slice:
return s.scrubSlice(v, info)
case reflect.Map:
@@ -212,6 +210,8 @@ func (s *Scrubber) scrubSlice(v reflect.Value, info Info) (scrubbed bool) {
}
func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) {
+ var scrubEverything *Scrubber
+
vt := v.Type().Elem()
hasInterfaceValueType := vt.Kind() == reflect.Interface
hasStringKeyType := v.Type().Key().Kind() == reflect.String
@@ -223,9 +223,12 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) {
// value regular expression.
key := iter.Key()
if hasStringKeyType && !s.scrubEveryString && matchString(s.keyRegexp, key.String()) {
- scrubber = new(Scrubber)
- *scrubber = *s
- scrubber.scrubEveryString = true
+ if scrubEverything == nil {
+ scrubEverything = new(Scrubber)
+ *scrubEverything = *scrubber
+ scrubEverything.scrubEveryString = true
+ }
+ scrubber = scrubEverything
}
// Map entries cannot be set. We therefore create a new value in order
@@ -247,9 +250,10 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) {
// the scrubber.
newVal := reflect.New(valT).Elem()
newVal.Set(val)
+
// Scrub it
- if scrubbedElement := scrubber.scrubValue(newVal, info); scrubbedElement {
- // Set it
+ if scrubber.scrubValue(newVal, info) {
+ // Replace it
v.SetMapIndex(key, newVal)
scrubbed = true
}
@@ -258,6 +262,8 @@ func (s *Scrubber) scrubMap(v reflect.Value, info Info) (scrubbed bool) {
}
func (s *Scrubber) scrubStruct(v reflect.Value, info Info) (scrubbed bool) {
+ var scrubEverything *Scrubber
+
l := v.NumField()
vt := v.Type()
for i := 0; i < l; i++ {
@@ -269,9 +275,12 @@ func (s *Scrubber) scrubStruct(v reflect.Value, info Info) (scrubbed bool) {
scrubber := s
if !s.scrubEveryString && matchString(s.keyRegexp, ft.Name) {
- scrubber = new(Scrubber)
- *scrubber = *s
- scrubber.scrubEveryString = true
+ if scrubEverything == nil {
+ scrubEverything = new(Scrubber)
+ *scrubEverything = *scrubber
+ scrubEverything.scrubEveryString = true
+ }
+ scrubber = scrubEverything
}
f := v.Field(i)
diff --git a/internal/sqlib/sqsanitize/sanitize_test.go b/internal/sqlib/sqsanitize/sanitize_test.go
index fe5541a9..9f3a296f 100644
--- a/internal/sqlib/sqsanitize/sanitize_test.go
+++ b/internal/sqlib/sqsanitize/sanitize_test.go
@@ -1247,8 +1247,16 @@ func TestScrubber(t *testing.T) {
fuzzer.Fuzz(&multipartForm)
// Insert some values forbidden by the regular expression
- postForm.Add("password", "1234")
- postForm.Add("password", "5678")
+ postForm.Add("password", "password10")
+ postForm.Add("password", "password11")
+ postForm.Add("password", "password12")
+ postForm.Add("passwd", "password1")
+ postForm.Add("api_key", "password2")
+ postForm.Add("apikey", "password3")
+ postForm.Add("authorization", "password4")
+ postForm.Add("access_token", "password5")
+ postForm.Add("secret", "password6")
+
messageFormat := "here is my credit card number %s."
stringWithCreditCardNb := fmt.Sprintf(messageFormat, "4533-3432-3234-3334")
form.Add("message", stringWithCreditCardNb)
@@ -1276,12 +1284,23 @@ func TestScrubber(t *testing.T) {
require.True(t, scrubbed)
// Check values were scrubbed
- require.Equal(t, []string{expectedMask, expectedMask}, req.PostForm["password"])
+ require.Equal(t, []string{expectedMask, expectedMask, expectedMask}, req.PostForm["password"])
+ require.Equal(t, []string{expectedMask}, req.PostForm["passwd"])
+ require.Equal(t, []string{expectedMask}, req.PostForm["api_key"])
+ require.Equal(t, []string{expectedMask}, req.PostForm["apikey"])
+ require.Equal(t, []string{expectedMask}, req.PostForm["authorization"])
+ require.Equal(t, []string{expectedMask}, req.PostForm["access_token"])
+ require.Equal(t, []string{expectedMask}, req.PostForm["secret"])
require.Equal(t, []string{fmt.Sprintf(messageFormat, expectedMask)}, req.Form["message"])
require.Contains(t, info, stringWithCreditCardNb)
- require.Contains(t, info, "1234")
- require.Contains(t, info, "5678")
+ require.Contains(t, info, "password10")
+ require.Contains(t, info, "password11")
+ require.Contains(t, info, "password2")
+ require.Contains(t, info, "password3")
+ require.Contains(t, info, "password4")
+ require.Contains(t, info, "password5")
+ require.Contains(t, info, "password6")
})
})
})
diff --git a/internal/version/version.go b/internal/version/version.go
index be419b56..9d9eb8a4 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -4,6 +4,6 @@
package version
-const version = "0.16.0"
+const version = "0.16.1"
func Version() string { return version }
diff --git a/sdk/middleware/sqecho/echo.go b/sdk/middleware/sqecho/echo.go
index a1ace8d1..dc07447f 100644
--- a/sdk/middleware/sqecho/echo.go
+++ b/sdk/middleware/sqecho/echo.go
@@ -255,17 +255,22 @@ type observedResponse struct {
}
func newObservedResponse(r *responseWriterImpl) *observedResponse {
+ response := r.c.Response()
+
+ headers := response.Header()
+
// Content-Type will be not empty only when explicitly set.
// It could be guessed as net/http does. Not implemented for now.
- ct := r.Header().Get("Content-Type")
-
- response := r.c.Response()
+ ct := headers.Get("Content-Type")
- // Content-Length is either explicitly set or the amount of written data.
+ // Content-Length is either explicitly set or the amount of written data. It's
+ // 0 by default with Echo.
cl := response.Size
- if contentLength := r.Header().Get("Content-Length"); contentLength != "" {
- if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil {
- cl = l
+ if cl == 0 {
+ if contentLength := headers.Get("Content-Length"); contentLength != "" {
+ if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil {
+ cl = l
+ }
}
}
diff --git a/sdk/middleware/sqecho/echo_test.go b/sdk/middleware/sqecho/echo_test.go
index b0c2d626..8647c9b3 100644
--- a/sdk/middleware/sqecho/echo_test.go
+++ b/sdk/middleware/sqecho/echo_test.go
@@ -5,6 +5,7 @@
package sqecho
import (
+ "context"
"errors"
"io/ioutil"
"net/http"
@@ -100,7 +101,7 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, body, rec.Body.String())
})
- t.Run("control flow", func(t *testing.T) {
+ t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
@@ -133,6 +134,11 @@ func TestMiddleware(t *testing.T) {
handler func(echo.Context) error
test func(t *testing.T, rec *httptest.ResponseRecorder, err error)
}{
+ //
+ // Control flow tests
+ // When an handlers, including middlewares, block.
+ //
+
{
name: "sqreen first/the middleware aborts before the handler",
middlewares: []echo.MiddlewareFunc{
@@ -341,6 +347,54 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody+handlerResponseBody+middlewareResponseBody, rec.Body.String())
},
},
+
+ //
+ // Context data flow tests
+ //
+ {
+ name: "middleware1, sqreen, middleware2, handler",
+ middlewares: []echo.MiddlewareFunc{
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ c.Set("m10", "v10")
+ c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m11", "v11")))
+ return next(c)
+ }
+ },
+ middleware(tc.agent),
+ func(next echo.HandlerFunc) echo.HandlerFunc {
+ return func(c echo.Context) error {
+ c.Set("m20", "v20")
+ c.SetRequest(c.Request().WithContext(context.WithValue(c.Request().Context(), "m21", "v21")))
+ return next(c)
+ }
+ },
+ },
+ handler: func(c echo.Context) error {
+ // From Gin's context
+ if v, ok := c.Get("m10").(string); !ok || v != "v10" {
+ panic("couldn't get the context value m10")
+ }
+ if v, ok := c.Get("m20").(string); !ok || v != "v20" {
+ panic("couldn't get the context value m20")
+ }
+
+ // From the request context
+ reqCtx := c.Request().Context()
+ if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" {
+ panic("couldn't get the context value m11")
+ }
+ if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" {
+ panic("couldn't get the context value m21")
+ }
+
+ return c.NoContent(http.StatusOK)
+ },
+ test: func(t *testing.T, rec *httptest.ResponseRecorder, err error) {
+ require.NoError(t, err)
+ require.Equal(t, http.StatusOK, rec.Code)
+ },
+ },
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@@ -370,15 +424,23 @@ func TestMiddleware(t *testing.T) {
t.Run("response observation", func(t *testing.T) {
expectedStatusCode := 433
+ expectedContentLength := int64(len("\"hello\"\n"))
+ expectedContentType := echo.MIMEApplicationJSONCharsetUTF8
agent := &mockups.AgentMockup{}
agent.ExpectConfig().Return(&mockups.AgentConfigMockup{}).Once()
agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once()
agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once()
- var responseStatusCode int
+ var (
+ responseStatusCode int
+ responseContentType string
+ responseContentLength int64
+ )
agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool {
resp := recorded.Response()
responseStatusCode = resp.Status()
+ responseContentLength = resp.ContentLength()
+ responseContentType = resp.ContentType()
return true
})).Return(nil)
defer agent.AssertExpectations(t)
@@ -387,7 +449,7 @@ func TestMiddleware(t *testing.T) {
router := echo.New()
router.Use(middleware(agent))
router.GET("/", func(c echo.Context) error {
- return c.NoContent(expectedStatusCode)
+ return c.JSON(expectedStatusCode, "hello")
})
// Perform the request and record the output
@@ -401,8 +463,10 @@ func TestMiddleware(t *testing.T) {
// Check the result
require.NoError(t, err)
- require.Equal(t, expectedStatusCode, responseStatusCode)
require.Equal(t, expectedStatusCode, rec.Code)
+ require.Equal(t, expectedStatusCode, responseStatusCode)
+ require.Equal(t, expectedContentLength, responseContentLength)
+ require.Equal(t, expectedContentType, responseContentType)
})
}
diff --git a/sdk/middleware/sqecho/v4/README.md b/sdk/middleware/sqecho/v4/README.md
new file mode 100644
index 00000000..f4a827ef
--- /dev/null
+++ b/sdk/middleware/sqecho/v4/README.md
@@ -0,0 +1,36 @@
+
+
+
+
+# [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go
+
+After performance monitoring (APM), error and log monitoring it’s time to add a
+security component into your app. Sqreen’s microagent automatically monitors
+sensitive app’s routines, blocks attacks and reports actionable infos to your
+dashboard.
+
+
+
+
+
+# Echo middleware function
+
+This package provides Sqreen's middleware function for Echo to monitor and
+protect requests Echo receives. Simply setup the middleware function to have
+your requests monitored and protected by Sqreen.
+
+Usage:
+
+```go
+e := echo.New()
+// Setup Sqreen's middleware
+e.Use(sqecho.Middleware())
+
+// Every router endpoint is now automatically monitored and protected by Sqreen
+e.GET("/", func(c echo.Context) error {
+ // ...
+}
+```
+
+Find more details on how to setup Sqreen for Go at
+
\ No newline at end of file
diff --git a/sdk/middleware/sqgin/README.md b/sdk/middleware/sqgin/README.md
new file mode 100644
index 00000000..cc703696
--- /dev/null
+++ b/sdk/middleware/sqgin/README.md
@@ -0,0 +1,36 @@
+
+
+
+
+# [Sqreen](https://www.sqreen.com/)'s Application Security Management for Go
+
+After performance monitoring (APM), error and log monitoring it’s time to add a
+security component into your app. Sqreen’s microagent automatically monitors
+sensitive app’s routines, blocks attacks and reports actionable infos to your
+dashboard.
+
+
+
+
+
+# Gin middleware function
+
+This package provides Sqreen's middleware function for Gin to monitor and
+protect requests Gin receives. Simply setup the middleware function to have your
+requests monitored and protected by Sqreen.
+
+Usage:
+
+```go
+router := gin.Default()
+// Setup Sqreen's middleware
+router.Use(sqgin.Middleware())
+
+// Every router endpoint is now automatically monitored and protected by Sqreen
+router.GET("/", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+}
+```
+
+Find more details on how to setup Sqreen for Go at
+
\ No newline at end of file
diff --git a/sdk/middleware/sqgin/gin.go b/sdk/middleware/sqgin/gin.go
index e0dad9f4..38a5d22f 100644
--- a/sdk/middleware/sqgin/gin.go
+++ b/sdk/middleware/sqgin/gin.go
@@ -76,7 +76,7 @@ func middlewareHandler(agent protection_context.AgentFace, c *gingonic.Context)
requestReader := &requestReaderImpl{c: c}
responseWriter := &responseWriterImpl{c: c}
- ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c, agent, responseWriter, requestReader)
+ ctx, reqCtx, cancelHandlerContext := http_protection.NewRequestContext(c.Request.Context(), agent, responseWriter, requestReader)
if ctx == nil {
c.Next()
return
@@ -248,15 +248,22 @@ type observedResponse struct {
}
func newObservedResponse(r *responseWriterImpl) *observedResponse {
+ headers := r.c.Writer.Header()
+
// Content-Type will be not empty only when explicitly set.
// It could be guessed as net/http does. Not implemented for now.
- ct := r.Header().Get("Content-Type")
+ ct := headers.Get("Content-Type")
- // Content-Length is either explicitly set or the amount of written data.
+ // Content-Length is either explicitly set or the amount of written data. It's
+ // less than 0 when not set by default with Gin.
cl := int64(r.c.Writer.Size())
- if contentLength := r.Header().Get("Content-Length"); contentLength != "" {
- if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil {
- cl = l
+ if cl < 0 {
+ if contentLength := headers.Get("Content-Length"); contentLength != "" {
+ if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil {
+ cl = l
+ } else {
+ cl = 0
+ }
}
}
diff --git a/sdk/middleware/sqgin/gin_test.go b/sdk/middleware/sqgin/gin_test.go
index 80a02b59..9eec7b94 100644
--- a/sdk/middleware/sqgin/gin_test.go
+++ b/sdk/middleware/sqgin/gin_test.go
@@ -5,6 +5,7 @@
package sqgin
import (
+ "context"
"net/http"
"net/http/httptest"
"testing"
@@ -91,7 +92,7 @@ func TestMiddleware(t *testing.T) {
})
// Test how the control flows between middleware and handler functions
- t.Run("control flow", func(t *testing.T) {
+ t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
@@ -124,6 +125,11 @@ func TestMiddleware(t *testing.T) {
handler func(*gin.Context)
test func(t *testing.T, rec *httptest.ResponseRecorder)
}{
+ //
+ // Control flow tests
+ // When an handlers, including middlewares, block.
+ //
+
{
name: "sqreen first/next middleware aborts before the handler",
middlewares: []gin.HandlerFunc{
@@ -292,6 +298,47 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody, rec.Body.String())
},
},
+
+ //
+ // Context data flow tests
+ //
+ {
+ name: "middleware1, sqreen, middleware2, handler",
+ middlewares: []gin.HandlerFunc{
+ func(c *gin.Context) {
+ c.Set("m10", "v10")
+ c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m11", "v11"))
+ },
+ middleware(tc.agent),
+ func(c *gin.Context) {
+ c.Set("m20", "v20")
+ c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "m21", "v21"))
+ },
+ },
+ handler: func(c *gin.Context) {
+ // From Gin's context
+ if v, ok := c.Value("m10").(string); !ok || v != "v10" {
+ panic("couldn't get the context value m10")
+ }
+ if v, ok := c.Value("m20").(string); !ok || v != "v20" {
+ panic("couldn't get the context value m20")
+ }
+
+ // From the request context
+ reqCtx := c.Request.Context()
+ if v, ok := reqCtx.Value("m11").(string); !ok || v != "v11" {
+ panic("couldn't get the context value m11")
+ }
+ if v, ok := reqCtx.Value("m21").(string); !ok || v != "v21" {
+ panic("couldn't get the context value m21")
+ }
+
+ c.Status(http.StatusOK)
+ },
+ test: func(t *testing.T, rec *httptest.ResponseRecorder) {
+ require.Equal(t, http.StatusOK, rec.Code)
+ },
+ },
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@@ -317,15 +364,23 @@ func TestMiddleware(t *testing.T) {
t.Run("response observation", func(t *testing.T) {
expectedStatusCode := 433
+ expectedContentLength := int64(len(`"hello"`))
+ expectedContentType := "application/json; charset=utf-8"
agent := &mockups.AgentMockup{}
agent.ExpectConfig().Return(&mockups.AgentConfigMockup{}).Once()
agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once()
agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once()
- var responseStatusCode int
+ var (
+ responseStatusCode int
+ responseContentType string
+ responseContentLength int64
+ )
agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool {
resp := recorded.Response()
responseStatusCode = resp.Status()
+ responseContentType = resp.ContentType()
+ responseContentLength = resp.ContentLength()
return true
})).Return(nil)
defer agent.AssertExpectations(t)
@@ -334,7 +389,7 @@ func TestMiddleware(t *testing.T) {
router := gin.New()
router.Use(middleware(agent))
router.GET("/", func(c *gin.Context) {
- c.Status(expectedStatusCode)
+ c.JSON(expectedStatusCode, "hello")
})
// Perform the request and record the output
@@ -343,8 +398,10 @@ func TestMiddleware(t *testing.T) {
router.ServeHTTP(rec, req)
// Check the result
+ require.Equal(t, expectedStatusCode, rec.Code)
require.Equal(t, expectedStatusCode, responseStatusCode)
- require.Equal(t, expectedStatusCode, responseStatusCode)
+ require.Equal(t, expectedContentLength, responseContentLength)
+ require.Equal(t, expectedContentType, responseContentType)
})
}
diff --git a/sdk/middleware/sqhttp/http_test.go b/sdk/middleware/sqhttp/http_test.go
index 40b707a0..1e4f7544 100644
--- a/sdk/middleware/sqhttp/http_test.go
+++ b/sdk/middleware/sqhttp/http_test.go
@@ -5,6 +5,7 @@
package sqhttp
import (
+ "context"
"io"
"net/http"
"net/http/httptest"
@@ -70,7 +71,7 @@ func TestMiddleware(t *testing.T) {
})
// Test how the control flows between middleware and handler functions
- t.Run("control flow", func(t *testing.T) {
+ t.Run("data and control flow", func(t *testing.T) {
middlewareResponseBody := testlib.RandUTF8String(4096)
middlewareResponseStatus := 433
handlerResponseBody := testlib.RandUTF8String(4096)
@@ -102,6 +103,11 @@ func TestMiddleware(t *testing.T) {
handlers http.Handler
test func(t *testing.T, rec *httptest.ResponseRecorder)
}{
+ //
+ // Control flow tests
+ // When an handlers, including middlewares, block.
+ //
+
{
name: "sqreen first/handler writes the response",
handlers: middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
@@ -163,6 +169,29 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, middlewareResponseBody, rec.Body.String())
},
},
+
+ //
+ // Context data flow tests
+ //
+ {
+ name: "middleware, sqreen, handler",
+ handlers: func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ r = r.WithContext(context.WithValue(r.Context(), "m", "v"))
+ next.ServeHTTP(w, r)
+ })
+ }(middleware(func(w http.ResponseWriter, r *http.Request) {
+ ctx := r.Context()
+ if v, ok := ctx.Value("m").(string); !ok || v != "v" {
+ panic("couldn't get the context value m")
+ }
+
+ w.WriteHeader(http.StatusOK)
+ }, tc.agent)),
+ test: func(t *testing.T, rec *httptest.ResponseRecorder) {
+ require.Equal(t, http.StatusOK, rec.Code)
+ },
+ },
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
@@ -181,15 +210,23 @@ func TestMiddleware(t *testing.T) {
t.Run("response observation", func(t *testing.T) {
expectedStatusCode := 433
+ expectedContentLength := int64(len(`"hello"`))
+ expectedContentType := "application/json"
agent := &mockups.AgentMockup{}
agent.ExpectConfig().Return(&mockups.AgentConfigMockup{}).Once()
agent.ExpectIsIPAllowed(mock.Anything).Return(false).Once()
agent.ExpectIsPathAllowed(mock.Anything).Return(false).Once()
- var responseStatusCode int
+ var (
+ responseStatusCode int
+ responseContentType string
+ responseContentLength int64
+ )
agent.ExpectSendClosedRequestContext(mock.MatchedBy(func(recorded types.ClosedRequestContextFace) bool {
resp := recorded.Response()
responseStatusCode = resp.Status()
+ responseContentLength = resp.ContentLength()
+ responseContentType = resp.ContentType()
return true
})).Return(nil)
defer agent.AssertExpectations(t)
@@ -198,15 +235,19 @@ func TestMiddleware(t *testing.T) {
// Create a router
router := http.NewServeMux()
router.Handle("/", middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
w.WriteHeader(expectedStatusCode)
+ w.Write([]byte(`"hello"`))
}), agent))
// Perform the request and record the output
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
// Check the request was performed as expected
- require.Equal(t, expectedStatusCode, responseStatusCode)
require.Equal(t, expectedStatusCode, rec.Code)
+ require.Equal(t, expectedStatusCode, responseStatusCode)
+ require.Equal(t, expectedContentLength, responseContentLength)
+ require.Equal(t, expectedContentType, responseContentType)
})
}