Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Commit

Permalink
sdk/middleware/sqhttp: transparent http response writer wrapper
Browse files Browse the repository at this point in the history
Make the HTTP response writer wrapper transparent by implementing the same
*known* interfaces as the underlying HTTP response writer. The
list of interfaces is currently every optional `net/http` interfaces, and some
from `io` when relevant:

  - `http.Flusher`: to allow flushing any buffered to the client. This enables
    support for streaming handlers.
  - `http.Hijacker`: to allow handlers to takeover the HTTP connection. This
	  should enable the support for websocket servers, which are not officially
	  supported by Sqreen, but is now experimentally allowed.
  - `http.Pusher`: for HTTP2 server push.
  - `http.CloseNotifier`: the deprecated closed connection notifier.
  - `io.ReaderFrom`: for optimized copies (eg. `io.Copy(file, w)`)
  - `io.WriteString`: for optimized string write (which avoids a temporary string copy into a byte slice)

The transparent wrapper implementation has been generated with a tool that will
be released in the Go agent repository in the future.

Fixes #162 and #134
  • Loading branch information
Julio Guerra committed Nov 17, 2020
2 parents f87a508 + 58eeeae commit b4824cc
Show file tree
Hide file tree
Showing 17 changed files with 1,492 additions and 200 deletions.
7 changes: 5 additions & 2 deletions internal/protection/http/bindingaccessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@ func NewRequestBindingAccessorContext(r types.RequestReader) *RequestBindingAcce
//}

func (r *RequestBindingAccessorContext) FilteredParams() RequestParamMap {
queryForm := r.QueryForm()
postForm := r.PostForm()
// Careful: the types need to be changed to avoid types with aliases because
// their conversion to JS will make the `url.Values` method names take
// precedence over the field names
queryForm := map[string][]string(r.QueryForm())
postForm := map[string][]string(r.PostForm())
params := r.RequestReader.Params()

res := make(types.RequestParamMap, 2+len(params))
Expand Down
6 changes: 3 additions & 3 deletions internal/protection/http/bindingaccessors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ func TestRequestBindingAccessors(t *testing.T) {
`#.URL.RequestURI`: "/admin?user=uid&password=pwd",
`#.FilteredParams`: http_protection.RequestParamMap{
"QueryForm": []interface{}{
url.Values{
"user": []string{"uid"},
"password": []string{"pwd"},
map[string][]string{
"user": {"uid"},
"password": {"pwd"},
},
},
"json": types.RequestParamValueSlice{
Expand Down
11 changes: 10 additions & 1 deletion internal/protection/http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net"
"net/http"
"net/url"
"time"

"github.com/sqreen/go-agent/internal/actor"
Expand Down Expand Up @@ -253,7 +254,15 @@ func (p *ProtectionContext) wrapBody(body io.ReadCloser) io.ReadCloser {
// specify where it was taken from.
func (p *ProtectionContext) AddRequestParam(name string, param interface{}) {
params := p.requestReader.requestParams[name]
p.requestReader.requestParams[name] = append(params, param)
var v interface{}
switch actual := param.(type) {
default:
v = param
case url.Values:
// Bare Go type so that it doesn't have any method (for the JS conversion)
v = map[string][]string(actual)
}
p.requestReader.requestParams[name] = append(params, v)
}

func (p *ProtectionContext) ClientIP() net.IP {
Expand Down
6 changes: 0 additions & 6 deletions internal/protection/http/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package types

import (
"context"
"io"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -82,7 +81,6 @@ func (m *RequestParamMap) Add(key string, value interface{}) {
// ResponseWriter is the response writer interface.
type ResponseWriter interface {
http.ResponseWriter
io.StringWriter
}

// ResponseFace is the interface to the response that was sent by the handler.
Expand All @@ -100,7 +98,3 @@ type ClosedProtectionContextFace interface {
Duration() time.Duration
SqreenTime() time.Duration
}

type WriteAfterCloseError struct{}

func (WriteAfterCloseError) Error() string { return "response write after close" }
4 changes: 3 additions & 1 deletion internal/rule/callback/write-blocking-html-page.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package callback

import (
"io"

"github.com/sqreen/go-agent/internal/backend/api"
httpprotection "github.com/sqreen/go-agent/internal/protection/http"
"github.com/sqreen/go-agent/internal/sqlib/sqassert"
Expand Down Expand Up @@ -43,7 +45,7 @@ func newWriteBlockingHTMLPagePrologCallback(r RuleContext, statusCode int) httpp
// Write the blocking page. We ignore any return error as this is a best
// effort response attempt: we don't want to penalize the server any
// further with this request - so no logging, no counting, no retry.
_, _ = ctx.ResponseWriter.WriteString(blockedBySqreenPage)
_, _ = io.WriteString(ctx.ResponseWriter, blockedBySqreenPage)
return nil
})
return nil, nil
Expand Down
50 changes: 3 additions & 47 deletions sdk/middleware/sqecho/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package sqecho

import (
"bytes"
"io"
"net"
"net/http"
"net/textproto"
Expand Down Expand Up @@ -87,15 +86,14 @@ func Middleware() echo.MiddlewareFunc {
}

func middlewareHandlerFromRootProtectionContext(ctx types.RootProtectionContext, next echo.HandlerFunc, c echo.Context) (err error) {
w := &responseWriterImpl{c: c}
r := &requestReaderImpl{c: c}
p := http_protection.NewProtectionContext(ctx, w, r)
p := http_protection.NewProtectionContext(ctx, c.Response(), r)
if p == nil {
return next(c)
}

defer func() {
p.Close(w.closeResponseWriter(err))
p.Close(newObservedResponse(c.Response(), err))
}()

return middlewareHandlerFromProtectionContext(p, next, c)
Expand Down Expand Up @@ -210,56 +208,14 @@ func (r *requestReaderImpl) RemoteAddr() string {
return r.c.Request().RemoteAddr
}

type responseWriterImpl struct {
c echo.Context
closed bool
}

func (w *responseWriterImpl) closeResponseWriter(err error) types.ResponseFace {
if !w.closed {
w.closed = true
}
return newObservedResponse(w, err)
}

func (w *responseWriterImpl) Header() http.Header {
return w.c.Response().Header()
}

func (w *responseWriterImpl) Write(b []byte) (int, error) {
if w.closed {
return 0, types.WriteAfterCloseError{}
}
return w.c.Response().Write(b)
}

func (w *responseWriterImpl) WriteString(s string) (int, error) {
if w.closed {
return 0, types.WriteAfterCloseError{}
}
return io.WriteString(w.c.Response(), s)
}

// Static assert that the io.StringWriter is implemented
var _ io.StringWriter = (*responseWriterImpl)(nil)

func (w *responseWriterImpl) WriteHeader(statusCode int) {
if w.closed {
return
}
w.c.Response().WriteHeader(statusCode)
}

// response observed by the response writer
type observedResponse struct {
contentType string
contentLength int64
status int
}

func newObservedResponse(r *responseWriterImpl, err error) *observedResponse {
response := r.c.Response()

func newObservedResponse(response *echo.Response, err error) *observedResponse {
headers := response.Header()

// Content-Type will be not empty only when explicitly set.
Expand Down
45 changes: 45 additions & 0 deletions sdk/middleware/sqecho/echo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,51 @@ func TestMiddleware(t *testing.T) {
require.Equal(t, rec.Header().Get("Content-Type"), responseContentType)
})

t.Run("several handler responses", func(t *testing.T) {
var (
responseStatusCode int
responseContentType string
responseContentLength int64
)
root := mockups.NewRootHTTPProtectionContextMockup(context.Background(), mock.Anything, mock.Anything)
root.ExpectClose(mock.MatchedBy(func(closed types.ClosedProtectionContextFace) bool {
resp := closed.Response()
responseStatusCode = resp.Status()
responseContentLength = resp.ContentLength()
responseContentType = resp.ContentType()
return true
}))
defer root.AssertExpectations(t)

// Both responses are taken into account for now
expectedStatusCode := 42
expectedContentLength := int64(len("\"hello\"\n") + len("bonjour"))
expectedContentType := echo.MIMEApplicationJSONCharsetUTF8

h := func(c echo.Context) error {
if err := c.JSON(expectedStatusCode, "hello"); err != nil {
return err
}
return c.String(42, "bonjour")
}

// Perform the request and record the output
rec := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)

m := middleware(root)
c := echo.New().NewContext(req, rec)

// Wrap and call the handler
err := m(h)(c)

// Check the result
require.NoError(t, err)
require.Equal(t, expectedStatusCode, responseStatusCode)
require.Equal(t, expectedContentLength, responseContentLength)
require.Equal(t, expectedContentType, responseContentType)
})

t.Run("default response", func(t *testing.T) {
var (
responseStatusCode int
Expand Down
50 changes: 3 additions & 47 deletions sdk/middleware/sqecho/v4/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package sqecho

import (
"bytes"
"io"
"net"
"net/http"
"net/textproto"
Expand Down Expand Up @@ -87,15 +86,14 @@ func Middleware() echo.MiddlewareFunc {
}

func middlewareHandlerFromRootProtectionContext(ctx types.RootProtectionContext, next echo.HandlerFunc, c echo.Context) (err error) {
w := &responseWriterImpl{c: c}
r := &requestReaderImpl{c: c}
p := http_protection.NewProtectionContext(ctx, w, r)
p := http_protection.NewProtectionContext(ctx, c.Response(), r)
if p == nil {
return next(c)
}

defer func() {
p.Close(w.closeResponseWriter(err))
p.Close(newObservedResponse(c.Response(), err))
}()

return middlewareHandlerFromProtectionContext(p, next, c)
Expand Down Expand Up @@ -210,56 +208,14 @@ func (r *requestReaderImpl) RemoteAddr() string {
return r.c.Request().RemoteAddr
}

type responseWriterImpl struct {
c echo.Context
closed bool
}

func (w *responseWriterImpl) closeResponseWriter(err error) types.ResponseFace {
if !w.closed {
w.closed = true
}
return newObservedResponse(w, err)
}

func (w *responseWriterImpl) Header() http.Header {
return w.c.Response().Header()
}

func (w *responseWriterImpl) Write(b []byte) (int, error) {
if w.closed {
return 0, types.WriteAfterCloseError{}
}
return w.c.Response().Write(b)
}

func (w *responseWriterImpl) WriteString(s string) (int, error) {
if w.closed {
return 0, types.WriteAfterCloseError{}
}
return io.WriteString(w.c.Response(), s)
}

// Static assert that the io.StringWriter is implemented
var _ io.StringWriter = (*responseWriterImpl)(nil)

func (w *responseWriterImpl) WriteHeader(statusCode int) {
if w.closed {
return
}
w.c.Response().WriteHeader(statusCode)
}

// response observed by the response writer
type observedResponse struct {
contentType string
contentLength int64
status int
}

func newObservedResponse(r *responseWriterImpl, err error) *observedResponse {
response := r.c.Response()

func newObservedResponse(response *echo.Response, err error) *observedResponse {
headers := response.Header()

// Content-Type will be not empty only when explicitly set.
Expand Down

0 comments on commit b4824cc

Please sign in to comment.