/
middleware.go
152 lines (125 loc) · 3.49 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package http
import (
"bufio"
"io"
"mime"
"net/http"
"net/textproto"
"strings"
"github.com/go-chi/chi/v5/middleware"
"github.com/phogolabs/log"
)
var (
	// HeaderContentType is the canonical MIME header key of the Content-Type
	// header ("Content-Type").
	HeaderContentType = textproto.CanonicalMIMEHeaderKey("Content-Type")
	// HeaderAccept is the canonical MIME header key of the Accept header
	// ("Accept").
	HeaderAccept = textproto.CanonicalMIMEHeaderKey("Accept")
)

const (
	// ContentTypeAll is the "*/*" wildcard media type.
	ContentTypeAll = "*/*"
	// ContentTypeForm is the "application/x-www-form-urlencoded" media type
	// used for URL-encoded form bodies.
	ContentTypeForm = "application/x-www-form-urlencoded"
	// ContentTypeGRPCProto is the "application/grpc" media type, i.e. gRPC
	// without an explicit encoding subtype (protobuf per the gRPC spec).
	ContentTypeGRPCProto = "application/grpc"
	// ContentTypeGRPCJSON is the "application/grpc+json" media type, i.e.
	// gRPC with JSON-encoded messages.
	ContentTypeGRPCJSON = "application/grpc+json"
	// ContentTypeJSON is the "application/json" media type.
	ContentTypeJSON = "application/json"
)
// PrepareMediaType normalizes the media-type list carried by the header
// identified by name (for example Accept) on r.
//
// Each value of the header is split on commas. Each entry is parsed with
// mime.ParseMediaType, and only the bare, lower-cased media type is kept;
// parameters such as q-values or charset are dropped. Blank entries and the
// ContentTypeAll wildcard ("*/*") are discarded. Entries that cannot be
// parsed are logged at info level and skipped. An entry whose only problem
// is a malformed optional parameter is still kept by its media type.
//
// The header is then replaced by the surviving media types in their
// original order. If none survive, the header is removed entirely.
//
// name is canonicalized first, so "accept" and "Accept" address the same
// header.
func PrepareMediaType(name string, r *http.Request) {
	const separator = ","

	var (
		logger = log.GetContext(r.Context())
		key    = textproto.CanonicalMIMEHeaderKey(name)
		types  []string
	)

	for _, content := range r.Header[key] {
		for _, item := range strings.Split(content, separator) {
			// skip empty or whitespace-only entries, e.g. "a/b, ,c/d"
			item = strings.TrimSpace(item)
			if item == "" {
				continue
			}

			value, _, err := mime.ParseMediaType(item)
			// A malformed optional parameter still yields a usable media type:
			// the standard library returns it together with the (unwrapped)
			// sentinel mime.ErrInvalidMediaParameter, so only other errors
			// cause the entry to be skipped.
			if err != nil && err != mime.ErrInvalidMediaParameter {
				logger.WithError(err).Infof("skip unsupported media type '%v'", item)
				continue
			}

			// Skip the wildcard. When nothing else survives, the header is
			// removed, and the Accept middleware then defaults it via
			// SetMediaType.
			if strings.EqualFold(value, ContentTypeAll) {
				continue
			}

			types = append(types, value)
		}
	}

	// replace the original values with the normalized ones (if any)
	r.Header.Del(key)
	if len(types) > 0 {
		r.Header[key] = types
	}
}
// SetMediaType ensures the header identified by name has a value on r.
//
// If r carries no non-blank value for that header, the header is set to
// ContentTypeJSON. Otherwise the request is left untouched. The name is
// canonicalized, so "accept" and "Accept" refer to the same header. This
// means an existing value is never overwritten just because of how the
// caller spelled the key.
func SetMediaType(name string, r *http.Request) {
	for _, value := range r.Header.Values(name) {
		if strings.TrimSpace(value) != "" {
			return
		}
	}

	r.Header.Set(name, ContentTypeJSON)
}
// Accept is a middleware that normalizes the Accept header of each request
// before passing it to next. It runs PrepareMediaType (which cleans up the
// listed media types) and then SetMediaType (which supplies a default when
// the header is left empty).
func Accept(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		PrepareMediaType(HeaderAccept, r)
		SetMediaType(HeaderAccept, r)

		next.ServeHTTP(w, r)
	})
}
// ContentType is a middleware that defaults the Content-Type header via
// SetMediaType before passing the request to next. Unlike Accept, the
// existing value is not normalized.
func ContentType(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		SetMediaType(HeaderContentType, r)

		next.ServeHTTP(w, r)
	})
}
// Metadata is a middleware that copies request metadata into X-Plex-*
// request headers before calling next:
//
//   - X-Plex-Real-Ip:    r.RemoteAddr
//   - X-Plex-Request-Id: the chi request ID stored in the request context
//   - X-Plex-User-Agent: r.UserAgent()
//   - X-Plex-Url:        r.URL.String()
//
// Header.Set is used, so any client-supplied values for these headers are
// replaced. NOTE(review): the request ID is presumably empty unless chi's
// RequestID middleware ran earlier in the chain — confirm.
func Metadata(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		headers := r.Header

		headers.Set("X-Plex-Real-Ip", r.RemoteAddr)
		headers.Set("X-Plex-Request-Id", middleware.GetReqID(ctx))
		headers.Set("X-Plex-User-Agent", r.UserAgent())
		headers.Set("X-Plex-Url", r.URL.String())

		next.ServeHTTP(w, r)
	})
}
// Match reads an HTTP request from body and reports whether it should be
// treated as a plain HTTP (non-native-gRPC) request.
//
// It returns false if body does not begin with a request that
// http.ReadRequest can parse. Otherwise the Content-Type values are scanned
// in order, and the first decisive one wins:
//   - a value containing ContentTypeGRPCJSON returns true;
//   - a value containing ContentTypeGRPCProto returns false.
//
// Because the check is a substring test, the gRPC+JSON check must come
// first ("application/grpc+json" also contains "application/grpc"). A
// request with no gRPC content type returns true.
func Match(body io.Reader) bool {
	req, err := http.ReadRequest(bufio.NewReader(body))
	if err != nil {
		return false
	}

	for _, contentType := range req.Header[HeaderContentType] {
		switch {
		case strings.Contains(contentType, ContentTypeGRPCJSON):
			return true
		case strings.Contains(contentType, ContentTypeGRPCProto):
			return false
		}
	}

	return true
}