/
requests.go
133 lines (109 loc) · 2.91 KB
/
requests.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package httpx
import (
"io"
"net/http"
"net/textproto"
"strings"
"github.com/tal-tech/go-zero/core/mapping"
"github.com/tal-tech/go-zero/rest/pathvar"
)
// Keys handed to mapping.NewUnmarshaler for the form, path and header
// unmarshalers declared in this file.
const (
	formKey   = "form"
	pathKey   = "path"
	headerKey = "header"
)

// Limits and defaults applied while reading request bodies.
const (
	// emptyJson is unmarshaled by ParseJsonBody in place of the request body
	// when the request is not considered to carry a JSON body.
	emptyJson = "{}"
	// maxMemory is the argument passed to (*http.Request).ParseMultipartForm
	// by ParseForm: 32MB.
	maxMemory = 32 << 20
	// maxBodyLen caps how many bytes ParseJsonBody reads from the body: 8MB.
	maxBodyLen = 8 << 20
)

// Settings used by ParseHeader when splitting a header value.
const (
	// separator divides the individual fields of a header value.
	separator = ";"
	// tokensInAttribute is the number of parts ("key" and "value") a field
	// must split into around "=" to be kept.
	tokensInAttribute = 2
)
// Package-level unmarshalers, one per request source. Each is built once with
// mapping.NewUnmarshaler and the key of its source, plus
// mapping.WithStringValues.
var (
	// formUnmarshaler is used by ParseForm.
	formUnmarshaler = mapping.NewUnmarshaler(
		formKey,
		mapping.WithStringValues(),
	)

	// pathUnmarshaler is used by ParsePath.
	pathUnmarshaler = mapping.NewUnmarshaler(
		pathKey,
		mapping.WithStringValues(),
	)

	// headerUnmarshaler is used by ParseHeaders. It is additionally configured
	// with textproto.CanonicalMIMEHeaderKey as its canonical key function.
	headerUnmarshaler = mapping.NewUnmarshaler(
		headerKey,
		mapping.WithStringValues(),
		mapping.WithCanonicalKeyFunc(textproto.CanonicalMIMEHeaderKey),
	)
)
// Parse populates v from the request r by running ParsePath, ParseForm,
// ParseHeaders and ParseJsonBody in that order. It stops at the first parser
// that fails and returns its error; it returns nil when all four succeed.
func Parse(r *http.Request, v interface{}) error {
	parsers := []func(*http.Request, interface{}) error{
		ParsePath,
		ParseForm,
		ParseHeaders,
		ParseJsonBody,
	}

	for _, parse := range parsers {
		if err := parse(r, v); err != nil {
			return err
		}
	}

	return nil
}
// ParseHeaders copies the headers of r into a map and unmarshals that map
// into v with headerUnmarshaler. A header holding exactly one value is stored
// as that single string; a header with any other number of values is stored
// as the full []string.
func ParseHeaders(r *http.Request, v interface{}) error {
	headers := make(map[string]interface{})

	for name, values := range r.Header {
		switch len(values) {
		case 1:
			headers[name] = values[0]
		default:
			headers[name] = values
		}
	}

	return headerUnmarshaler.Unmarshal(headers, v)
}
// ParseForm parses the form data of r and unmarshals it into v with
// formUnmarshaler.
//
// Both r.ParseForm and r.ParseMultipartForm(maxMemory) are invoked; any error
// is returned, except http.ErrNotMultipart, which is ignored. Only the first
// value of each form key is used, and keys whose first value is empty are
// left out entirely.
func ParseForm(r *http.Request, v interface{}) error {
	if err := r.ParseForm(); err != nil {
		return err
	}

	// A request that simply is not multipart is fine; anything else is fatal.
	if err := r.ParseMultipartForm(maxMemory); err != nil && err != http.ErrNotMultipart {
		return err
	}

	params := make(map[string]interface{}, len(r.Form))
	for name, values := range r.Form {
		// Equivalent to r.Form.Get(name): take the first value, if any.
		if len(values) == 0 {
			continue
		}
		if first := values[0]; first != "" {
			params[name] = first
		}
	}

	return formUnmarshaler.Unmarshal(params, v)
}
// ParseHeader splits a single header value into its key/value attributes and
// returns them as a map.
//
// The value is split on separator; each field is trimmed of surrounding
// whitespace and then split at its first "=". Empty fields and fields without
// an "=" are skipped. The key and value themselves are stored as-is: they are
// not trimmed again and quotes are not removed. When a key repeats, the last
// occurrence wins.
func ParseHeader(headerValue string) map[string]string {
	attrs := make(map[string]string)

	for _, raw := range strings.Split(headerValue, separator) {
		field := strings.TrimSpace(raw)
		if field == "" {
			continue
		}

		if pair := strings.SplitN(field, "=", tokensInAttribute); len(pair) == tokensInAttribute {
			attrs[pair[0]] = pair[1]
		}
	}

	return attrs
}
// ParseJsonBody unmarshals the JSON body of r into v.
//
// When withJsonBody reports that r carries a JSON body, the body is read
// through an io.LimitReader capped at maxBodyLen bytes. Otherwise the
// emptyJson literal is unmarshaled instead of the body.
func ParseJsonBody(r *http.Request, v interface{}) error {
	if !withJsonBody(r) {
		return mapping.UnmarshalJsonReader(strings.NewReader(emptyJson), v)
	}

	return mapping.UnmarshalJsonReader(io.LimitReader(r.Body, maxBodyLen), v)
}
// ParsePath unmarshals the path variables of r, as returned by pathvar.Vars,
// into v with pathUnmarshaler.
// The variables are the named symbols of the route, such as :name in
// http://localhost/bag/:name.
func ParsePath(r *http.Request, v interface{}) error {
	pathVars := pathvar.Vars(r)

	// The unmarshaler is given a map[string]interface{}, so copy the entries over.
	params := make(map[string]interface{}, len(pathVars))
	for name, value := range pathVars {
		params[name] = value
	}

	return pathUnmarshaler.Unmarshal(params, v)
}
// withJsonBody reports whether r should be treated as carrying a JSON body:
// its ContentLength must be greater than zero and its ContentType header
// value must contain ApplicationJson.
//
// NOTE(review): net/http uses ContentLength -1 for a body of unknown length,
// so such requests are reported as having no JSON body — confirm this is
// intended.
func withJsonBody(r *http.Request) bool {
	if r.ContentLength <= 0 {
		return false
	}

	contentType := r.Header.Get(ContentType)
	return strings.Contains(contentType, ApplicationJson)
}