From 118ddbe7287f7b0aaa3415793b32d435b7d6d96f Mon Sep 17 00:00:00 2001 From: Xabier Larrakoetxea Date: Fri, 7 Feb 2020 22:14:44 +0100 Subject: [PATCH] Internal http.ResponseWriter interceptor now satisifies http.Hijacker and http.Flusher interfaces Signed-off-by: Xabier Larrakoetxea --- CHANGELOG.md | 5 ++++ middleware/middleware.go | 27 ++++++++++++++++++++ middleware/middleware_test.go | 48 ++++++++++++++++------------------- 3 files changed, 54 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0061fb2..f05319f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## [Unreleased] +### Changed + +- Internal response writer interceptor implements `http.Hijacker` and `http.Flusher` interface. + + ## [0.6.0] - 2019-12-11 ### Breaking changes diff --git a/middleware/middleware.go b/middleware/middleware.go index ff30c46..80f5318 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -5,7 +5,10 @@ package middleware import ( + "bufio" + "errors" "fmt" + "net" "net/http" "strconv" "time" @@ -148,3 +151,27 @@ func (w *responseWriterInterceptor) Write(p []byte) (int, error) { w.bytesWritten += len(p) return w.ResponseWriter.Write(p) } + +func (w *responseWriterInterceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, errors.New("type assertion failed http.ResponseWriter not a http.Hijacker") + } + return h.Hijack() +} + +func (w *responseWriterInterceptor) Flush() { + f, ok := w.ResponseWriter.(http.Flusher) + if !ok { + return + } + + f.Flush() +} + +// Check interface implementations. +var ( + _ http.ResponseWriter = &responseWriterInterceptor{} + _ http.Hijacker = &responseWriterInterceptor{} + _ http.Flusher = &responseWriterInterceptor{} +) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 6b36502..7f8f060 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -19,8 +19,7 @@ func getFakeHandler(statusCode int, responseBody string) http.Handler { } func TestMiddlewareHandler(t *testing.T) { - tests := []struct { - name string + tests := map[string]struct { handlerID string body string statusCode int @@ -32,8 +31,7 @@ func TestMiddlewareHandler(t *testing.T) { expSize int64 expStatusCode string }{ - { - name: "A default HTTP middleware should call the recorder to measure.", + "A default HTTP middleware should call the recorder to measure.": { statusCode: http.StatusAccepted, body: "Я бэтмен", req: httptest.NewRequest(http.MethodGet, "/test", nil), @@ -43,8 +41,8 @@ func TestMiddlewareHandler(t *testing.T) { expMethod: http.MethodGet, expStatusCode: "202", }, - { - name: "Using custom ID in the middleware should call the recorder to measure with that ID.", + + "Using custom ID in the middleware should call the recorder to measure with that ID.": { handlerID: "customID", body: "I'm Batman", statusCode: http.StatusTeapot, @@ -55,8 +53,8 @@ func TestMiddlewareHandler(t *testing.T) { expMethod: http.MethodPost, expStatusCode: "418", }, - { - name: "Using grouped status code should group the status code.", + + "Using grouped status code should group the status code.": { config: middleware.Config{GroupedStatus: true}, statusCode: http.StatusGatewayTimeout, req: httptest.NewRequest(http.MethodPatch, "/test", nil), @@ -65,8 +63,8 @@ func TestMiddlewareHandler(t *testing.T) { expMethod: http.MethodPatch, expStatusCode: "5xx", }, - { - name: "Using the service middleware option should set the service on the metrics.", + + "Using the service middleware option should set the service on the metrics.": { config: middleware.Config{Service: "Yoda"}, statusCode: http.StatusContinue, body: "May the force be with you", @@ -79,8 +77,8 @@ func TestMiddlewareHandler(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { + for name, test := range tests { + t.Run(name, func(t *testing.T) { // Mocks. mr := &mmetrics.Recorder{} expHTTPReqProps := metrics.HTTPReqProperties{ @@ -112,45 +110,43 @@ func TestMiddlewareHandler(t *testing.T) { func BenchmarkMiddlewareHandler(b *testing.B) { b.StopTimer() - benchs := []struct { - name string + benchs := map[string]struct { handlerID string cfg middleware.Config }{ - { - name: "benchmark with default settings.", + "benchmark with default settings.": { handlerID: "", cfg: middleware.Config{}, }, - { - name: "benchmark disabling measuring size.", + + "benchmark disabling measuring size.": { handlerID: "", cfg: middleware.Config{ DisableMeasureSize: true, }, }, - { - name: "benchmark disabling inflights.", + + "benchmark disabling inflights.": { handlerID: "", cfg: middleware.Config{ DisableMeasureInflight: true, }, }, - { - name: "benchmark with grouped status code.", + + "benchmark with grouped status code.": { cfg: middleware.Config{ GroupedStatus: true, }, }, - { - name: "benchmark with predefined handler ID", + + "benchmark with predefined handler ID": { handlerID: "benchmark1", cfg: middleware.Config{}, }, } - for _, bench := range benchs { - b.Run(bench.name, func(b *testing.B) { + for name, bench := range benchs { + b.Run(name, func(b *testing.B) { // Prepare. bench.cfg.Recorder = metrics.Dummy m := middleware.New(bench.cfg)