Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

sdk/middleware/sqgin: response's content length monitoring #166

Merged
merged 2 commits into from Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
134 changes: 106 additions & 28 deletions sdk/middleware/sqecho/echo_test.go
Expand Up @@ -426,7 +426,7 @@ func TestMiddleware(t *testing.T) {
})

t.Run("response observation", func(t *testing.T) {
t.Run("direct http header write", func(t *testing.T) {
t.Run("handler response", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
Expand Down Expand Up @@ -465,43 +465,121 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, expectedStatusCode, rec.Code)
require.Equal(t, expectedStatusCode, responseStatusCode)
require.Equal(t, expectedContentLength, responseContentLength)
require.Equal(t, int64(rec.Body.Len()), responseContentLength)
require.Equal(t, expectedContentType, responseContentType)
require.Equal(t, rec.Header().Get("Content-Type"), responseContentType)
})
})

t.Run("echo handler error", func(t *testing.T) {
var (
responseStatusCode int
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
return true
}))
defer root.AssertExpectations(t)
t.Run("default response", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

expectedError := echo.ErrNotFound
h := func(c echo.Context) error {
// Do nothing, so that Echo's response fields have their default values
return nil
}

h := func(c echo.Context) error {
return expectedError
}
// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
m := middleware(root)
c := echo.New().NewContext(req, rec)

m := middleware(root)
c := echo.New().NewContext(req, rec)
// Wrap and call the handler
err := m(h)(c)

// Wrap and call the handler
err := m(h)(c)
// Check the result
require.NoError(t, err)
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, http.StatusOK, responseStatusCode)
require.Equal(t, int64(rec.Body.Len()), responseContentLength)
require.Equal(t, rec.Header().Get("Content-Type"), responseContentType)
})

// Check the result
require.Error(t, err)
require.Equal(t, expectedError, err)
require.Equal(t, expectedError.Code, responseStatusCode)
t.Run("not found endpoint", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)

e := echo.New()
m := middleware(root)
e.Use(m)
e.ServeHTTP(rec, req)

// Check the result
require.Equal(t, http.StatusNotFound, rec.Code)
require.Equal(t, http.StatusNotFound, responseStatusCode)

// Echo bypasses the middleware when handling an error
require.Equal(t, int64(0), responseContentLength)
//require.Equal(t, int64(rec.Body.Len()), responseContentLength)
require.Equal(t, "", responseContentType)
//require.Equal(t, rec.Header().Get("Content-Type"), responseContentType)
})

t.Run("echo handler error", func(t *testing.T) {
var (
responseStatusCode int
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
return true
}))
defer root.AssertExpectations(t)

expectedError := echo.ErrNotFound

h := func(c echo.Context) error {
return expectedError
}

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)

m := middleware(root)
c := echo.New().NewContext(req, rec)

// Wrap and call the handler
err := m(h)(c)

// Check the result
require.Error(t, err)
require.Equal(t, expectedError, err)
require.Equal(t, expectedError.Code, responseStatusCode)
})
})

}

func middleware(p types.RootProtectionContext) echo.MiddlewareFunc {
Expand Down
3 changes: 1 addition & 2 deletions sdk/middleware/sqgin/gin.go
Expand Up @@ -266,11 +266,10 @@ func newObservedResponse(r *responseWriterImpl) *observedResponse {
// less than 0 when not set by default with Gin.
cl := int64(r.c.Writer.Size())
if cl < 0 {
cl = 0
if contentLength := headers.Get("Content-Length"); contentLength != "" {
if l, err := strconv.ParseInt(contentLength, 10, 0); err == nil {
cl = l
} else {
cl = 0
}
}
}
Expand Down
142 changes: 109 additions & 33 deletions sdk/middleware/sqgin/gin_test.go
Expand Up @@ -389,43 +389,119 @@ func TestMiddleware(t *testing.T) {
})

t.Run("response observation", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

expectedStatusCode := 433
expectedContentLength := int64(len(`"hello"`))
expectedContentType := "application/json; charset=utf-8"

// Create a route
router := gin.New()
router.Use(middleware(root))
router.GET("/", func(c *gin.Context) {
c.JSON(expectedStatusCode, "hello")
t.Run("handler response", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

expectedStatusCode := 433
expectedContentLength := int64(len(`"hello"`))
expectedContentType := "application/json; charset=utf-8"

// Create a route
router := gin.New()
router.Use(middleware(root))
router.GET("/", func(c *gin.Context) {
c.JSON(expectedStatusCode, "hello")
})

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
router.ServeHTTP(rec, req)

// Check the result
require.Equal(t, expectedStatusCode, rec.Code)
require.Equal(t, expectedStatusCode, responseStatusCode)
require.Equal(t, expectedContentLength, responseContentLength)
require.Equal(t, expectedContentLength, int64(rec.Body.Len()))
require.Equal(t, expectedContentType, responseContentType)
require.Equal(t, expectedContentType, rec.Header().Get("Content-Type"))
})

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
router.ServeHTTP(rec, req)
t.Run("default response", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

// Create a route
router := gin.New()
router.Use(middleware(root))
router.GET("/", func(c *gin.Context) {
// Do nothing, so that Gin's response fields have their default values
})

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
router.ServeHTTP(rec, req)

// Check the result
require.Equal(t, expectedStatusCode, rec.Code)
require.Equal(t, expectedStatusCode, responseStatusCode)
require.Equal(t, expectedContentLength, responseContentLength)
require.Equal(t, expectedContentType, responseContentType)
// Check the result
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, http.StatusOK, responseStatusCode)
require.Equal(t, int64(rec.Body.Len()), responseContentLength)
require.Equal(t, rec.Header().Get("Content-Type"), responseContentType)
})

t.Run("not found endpoint", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

// Create a route
router := gin.New()
router.Use(middleware(root))

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
router.ServeHTTP(rec, req)

// Check the result
require.Equal(t, http.StatusNotFound, rec.Code)
require.Equal(t, http.StatusNotFound, responseStatusCode)
require.Equal(t, int64(0), responseContentLength)
// Gin writes the response after the middleware so we can't observe the
// content type and length
//require.Equal(t, int64(0), int64(rec.Body.Len()))
require.Equal(t, "", responseContentType)
//require.Equal(t, "", rec.Header().Get("Content-Type"))
})
})

}

func middleware(p types.RootProtectionContext) gin.HandlerFunc {
Expand Down