/
middleware.go
137 lines (114 loc) · 4.3 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package webhttp
import (
"mime"
"net/http"
"strings"
"sync/atomic"
"github.com/go-chi/chi/middleware"
"github.com/pavelmemory/jobtome/internal/logging"
)
// InjectLogger builds a middleware that stores a per-request logger in the
// request context. Every request gets the base logger enriched with a
// "req_seq" field holding a monotonically increasing sequence number, which
// makes it possible to group all log lines that belong to one request.
func InjectLogger(logger logging.Logger) func(http.Handler) http.Handler {
	// counter is shared by every handler produced from this InjectLogger call
	// and is only ever accessed atomically.
	var counter int64
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			seq := atomic.AddInt64(&counter, 1)
			reqLogger := logger.WithInt64("req_seq", seq)
			next.ServeHTTP(w, r.WithContext(logging.ToContext(r.Context(), reqLogger)))
		})
	}
}
// LogRequest builds a middleware that writes two debug records per request
// using the logger found in the request context: one before the wrapped
// handler runs (URL, method, referer, user agent, content length) and one
// after it returns (final status code and number of body bytes written).
// TODO: make logging level configurable so we could control log severity for each baseHandler
func LogRequest() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			reqLogger := logging.FromContext(r.Context())

			incoming := reqLogger.
				WithString("url", r.URL.String()).
				WithString("method", r.Method).
				WithString("referer", r.Referer()).
				WithString("user_agent", r.UserAgent()).
				WithInt64("content_length", r.ContentLength)
			incoming.Debug("incoming request")

			// The wrapper records the status code and byte count of the response.
			recorder := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
			next.ServeHTTP(recorder, r)

			// A status that was never set explicitly is reported as 0 by the
			// wrapper, but net/http sends such a response as 200 OK.
			code := recorder.Status()
			if code == 0 {
				code = http.StatusOK
			}
			reqLogger.WithInt("status", code).
				WithInt("bytes_written", recorder.BytesWritten()).
				Debug("outgoing response")
		})
	}
}
// AcceptsJSON is a RequestContentType middleware for JSON endpoints.
// It responds with 415 Unsupported Media Type unless the request's
// 'content-type' header carries the 'application/json' media type.
// A charset parameter is optional but, when present, must be utf-8
// (case-insensitive); other parameters with a value cause a rejection.
var AcceptsJSON = RequestContentType("application/json; charset=utf-8")
// RequestContentType builds a middleware that only lets a request through when
// its `content-type` header parses as a media type equal (case-insensitively)
// to the one in contentType.
//
// Every parameter carried by the request header must match the parameter of
// the same name in contentType; parameters absent from contentType therefore
// only match an empty value. Rejected requests get 415 Unsupported Media Type
// and are not passed to the next handler.
//
// It panics if contentType itself is not a valid media type.
func RequestContentType(contentType string) func(http.Handler) http.Handler {
	expectedType, expectedParams, parseErr := mime.ParseMediaType(contentType)
	if parseErr != nil {
		// A malformed expectation is a programming error; panicking here makes
		// it surface while the router is assembled at startup.
		panic(parseErr)
	}

	// accepted reports whether a raw header value names the expected media
	// type and all of its parameters agree with the expected ones.
	accepted := func(header string) bool {
		gotType, gotParams, err := mime.ParseMediaType(header)
		if err != nil || !strings.EqualFold(expectedType, gotType) {
			return false
		}
		for name, value := range gotParams {
			if !strings.EqualFold(value, expectedParams[name]) {
				return false
			}
		}
		return true
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !accepted(r.Header.Get("content-type")) {
				w.WriteHeader(http.StatusUnsupportedMediaType)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// ProducesJSON is a ResponseContentType middleware that defaults the response
// 'content-type' header to 'application/json; charset=utf-8' when the handler
// has not set one itself. See ResponseContentType for exactly when the default
// is applied and which responses are exempt.
var ProducesJSON = ResponseContentType("application/json; charset=utf-8")
// ResponseContentType returns a middleware function that sets contentType as
// the `content-type` header of the HTTP response unless the wrapped handler
// has already set one.
//
// The default is applied at the moment the response headers are committed:
// on an explicit WriteHeader call or on the first Write, whichever comes first.
// Redirect responses (300-308) and 204 No Content are left untouched; 1xx
// informational headers do not receive it either.
//
// It panics if contentType cannot be parsed as a media type.
func ResponseContentType(contentType string) func(http.Handler) http.Handler {
	_, _, err := mime.ParseMediaType(contentType)
	if err != nil {
		// this is fair enough as it will blow up at startup time
		panic(err)
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ww := &responseContentTypeWrapper{ResponseWriter: w, contentType: contentType}
			next.ServeHTTP(ww, r)
		})
	}
}

// responseContentTypeWrapper is an http.ResponseWriter that fills in a default
// content-type header right before the response headers are sent.
type responseContentTypeWrapper struct {
	// contentType is the default value; an empty string disables it.
	contentType string
	http.ResponseWriter
}

// setDefaultContentType sets the default content-type unless it is disabled
// or the handler has already chosen one.
func (rw *responseContentTypeWrapper) setDefaultContentType() {
	if rw.contentType == "" {
		return
	}
	h := rw.ResponseWriter.Header()
	if h.Get("content-type") == "" {
		h.Set("content-type", rw.contentType)
	}
}

// Write applies the default content-type (relevant when the handler never
// called WriteHeader, i.e. an implicit 200) and forwards the data.
func (rw *responseContentTypeWrapper) Write(d []byte) (int, error) {
	rw.setDefaultContentType()
	return rw.ResponseWriter.Write(d)
}

// WriteHeader applies the default content-type before the headers are
// committed, except for redirects and responses that carry no body.
func (rw *responseContentTypeWrapper) WriteHeader(statusCode int) {
	switch {
	case statusCode >= http.StatusMultipleChoices && statusCode <= http.StatusPermanentRedirect:
		// redirects (300-308) keep whatever content-type the handler set;
		// disable the default so a later Write does not inject it either
		rw.contentType = ""
	case statusCode >= http.StatusOK && statusCode != http.StatusNoContent:
		// WriteHeader sends the headers, so a default applied only in Write
		// would come too late for handlers that set the status first.
		rw.setDefaultContentType()
	}
	rw.ResponseWriter.WriteHeader(statusCode)
}

// Unwrap returns the underlying ResponseWriter so that http.ResponseController
// can reach optional interfaces (Flusher, Hijacker, ...) it may implement.
func (rw *responseContentTypeWrapper) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}