diff --git a/log_middleware.go b/log_middleware.go index 4ba9a17..6b5089c 100644 --- a/log_middleware.go +++ b/log_middleware.go @@ -67,8 +67,11 @@ func buildSkipCache(cfg LogMiddlewareConfig) ([]*regexp.Regexp, error) { func (mw *LogMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { start := time.Now() + lrw := &logResponseWriter{ResponseWriter: w} + defer func() { if rec := recover(); rec != nil { + lrw.WriteHeader(http.StatusInternalServerError) // See: https://pkg.go.dev/net/http#ErrAbortHandler if recErr, ok := rec.(error); ok && errors.Is(recErr, http.ErrAbortHandler) { AccessAborted(r, start) @@ -78,7 +81,6 @@ func (mw *LogMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() - lrw := &logResponseWriter{ResponseWriter: w} mw.Next.ServeHTTP(lrw, r) level := logrus.InfoLevel diff --git a/log_middleware_test.go b/log_middleware_test.go index 280fdb4..5f137ab 100644 --- a/log_middleware_test.go +++ b/log_middleware_test.go @@ -79,8 +79,11 @@ func Test_LogMiddleware_Panic(t *testing.T) { })) r, _ := http.NewRequest(http.MethodGet, "http://www.example.org/foo", nil) + resp := httptest.NewRecorder() - lm.ServeHTTP(httptest.NewRecorder(), r) + lm.ServeHTTP(resp, r) + + assert.Equal(t, http.StatusInternalServerError, resp.Result().StatusCode) data := logRecordFromBuffer(b) @@ -101,8 +104,11 @@ func Test_LogMiddleware_Panic_ErrAbortHandler(t *testing.T) { })) r, _ := http.NewRequest(http.MethodGet, "http://www.example.org/foo", nil) + resp := httptest.NewRecorder() - lm.ServeHTTP(httptest.NewRecorder(), r) + lm.ServeHTTP(resp, r) + + assert.Equal(t, http.StatusInternalServerError, resp.Result().StatusCode) data := logRecordFromBuffer(b)