/
context.go
205 lines (192 loc) · 5.31 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
package requestContext
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/pelletier/go-toml"
"github.com/runar-rkmedia/go-common/logger"
"github.com/runar-rkmedia/skiver/types"
"gopkg.in/yaml.v2"
)
// Context holds the application-wide dependencies that every ReqContext
// is built on top of.
type Context struct {
	// L is the application logger; NewReqContext derives a per-request
	// logger from it.
	L logger.AppLogger
	// DB is the storage used by request handlers.
	DB types.Storage
	// StructValidater is called by ReqContext.ValidateStruct (and therefore
	// by ValidateBody/ValidateBytes) without a nil check, so it must be set
	// before any validation is performed.
	StructValidater func(interface{}) error
}
// ReqContext bundles the state of a single HTTP request (request, response
// writer, request-scoped logger and negotiated content kinds) with the shared
// Context.
//
// Deprecated: the useful methods here should be moved into dedicated structs.
type ReqContext struct {
	Context *Context
	Req *http.Request
	// L is annotated with the request method, path and headers (minus
	// Cookie and Authorization); see NewReqContext.
	L logger.AppLogger
	Rw http.ResponseWriter
	// ContentKind is derived from the request's Content-Type header and
	// selects how request bodies are decoded (GetDecoder, Unmarshal).
	ContentKind OutputKind
	// Accept is the output format requested by the client, as determined
	// by WantedOutputFormat.
	Accept OutputKind
	// RemoteIP is the unparsed value of the first non-empty header among
	// Forwarded, X-Forwarded-For and X-Originating-IP, falling back to
	// Request.RemoteAddr. Header values may be supplied by the client.
	RemoteIP string
}
// NewReqContext builds the per-request context for req and rw on top of the
// shared context.
//
// RemoteIP is taken verbatim from the first non-empty of the Forwarded,
// X-Forwarded-For and X-Originating-IP headers, falling back to
// req.RemoteAddr. The request logger is annotated with the method, path and
// every request header except Cookie and Authorization.
func NewReqContext(context *Context, req *http.Request, rw http.ResponseWriter) ReqContext {
	// TODO: parse this value into a ip. For now, we do not actually need it.
	// (we only use the ip for reducing session-duplications if there are lots of logins.)
	// NOTE(review): these headers can be set by the client unless a trusted
	// proxy overwrites them, so RemoteIP should not drive security decisions.
	var remoteIP string
	for _, name := range []string{"Forwarded", "X-Forwarded-For", "X-Originating-IP"} {
		if remoteIP = req.Header.Get(name); remoteIP != "" {
			break
		}
	}
	if remoteIP == "" {
		remoteIP = req.RemoteAddr
	}

	// Copy the headers for logging, dropping Cookie and Authorization so that
	// credentials are not written to the logs.
	loggedHeaders := make(http.Header)
	for name, values := range req.Header {
		if lower := strings.ToLower(name); lower == "cookie" || lower == "authorization" {
			continue
		}
		for _, value := range values {
			loggedHeaders.Add(name, value)
		}
	}

	reqLogger := context.L.With().
		Str("method", req.Method).
		Str("path", req.URL.Path).
		Interface("headers", loggedHeaders).
		Logger()

	return ReqContext{
		Context:     context,
		L:           logger.With(reqLogger),
		Req:         req,
		Rw:          rw,
		ContentKind: contentType(req.Header.Get("Content-Type")),
		Accept:      WantedOutputFormat(req),
		RemoteIP:    remoteIP,
	}
}
// NewReqContext builds a ReqContext for this request using c as the shared
// Context. Unlike the package-level NewReqContext, the response writer comes
// first, matching the argument order of an http.HandlerFunc.
func (c *Context) NewReqContext(rw http.ResponseWriter, r *http.Request) ReqContext {
	rc := NewReqContext(c, r, rw)
	return rc
}
// WriteAuto writes output, or srcErr with errCode, to the response via the
// package-level WriteAuto.
//
// Failures while writing are logged (including srcErr and errCode when srcErr
// is non-nil) and not returned.
func (rc ReqContext) WriteAuto(output interface{}, srcErr error, errCode ErrorCodes) {
	// The parameter is deliberately not named `error`: that would shadow the
	// predeclared error type for the whole method body.
	err := WriteAuto(output, srcErr, errCode, rc.Req, rc.Rw)
	if err == nil {
		return
	}
	l := rc.L.Error().
		Err(err).
		Str("path", rc.Req.URL.String()).
		Str("method", rc.Req.Method)
	if srcErr != nil {
		l = l.
			Str("for-error-code", string(errCode)).
			Str("for-error", srcErr.Error())
	}
	l.Msg("Failure during WriteAuto")
}
// WriteError writes an error response with message msg and code errCode via
// the package-level WriteError, forwarding details unchanged.
//
// Failures while writing are logged, not returned.
func (rc ReqContext) WriteError(msg string, errCode ErrorCodes, details ...interface{}) {
	// TODO: get error-code from above
	// NOTE(review): the status code is passed as 0; presumably the
	// package-level WriteError substitutes a default — confirm.
	if err := WriteError(msg, 0, errCode, rc.Req, rc.Rw, details...); err != nil {
		// Name the method that actually failed so the log line is accurate.
		rc.L.Error().Err(err).Msg("Failure in WriteError")
	}
}
// WriteNotFound responds with HTTP 404 and a generic "Not found" error tagged
// with errCode. Failures while writing are logged, not returned.
func (rc ReqContext) WriteNotFound(errCode ErrorCodes) {
	notFound := errors.New("Not found")
	if writeErr := WriteErr(notFound, http.StatusNotFound, errCode, rc.Req, rc.Rw); writeErr != nil {
		rc.L.Error().Err(writeErr).Msg("Failure in WriteErr")
	}
}
// WriteErr writes err to the response.
//
// If err is, or wraps, an APIError, the response uses its message, status
// code and details. Its code is written as "<errCode>: <code>" when errCode is
// non-empty. Any other error is written via the package-level WriteErr with
// errCode. Failures while writing are logged, not returned.
func (rc ReqContext) WriteErr(err error, errCode ErrorCodes) {
	// errors.As (rather than a plain type assertion) also matches APIErrors
	// that were wrapped with %w further down the call stack.
	var apiErr APIError
	if errors.As(err, &apiErr) {
		code := apiErr.Err.Code
		if errCode != "" {
			code = errCode + ": " + code
		}
		// Pass the prefixed code; previously it was computed and then ignored.
		// NOTE(review): Details is passed as a single variadic element; if it is
		// a slice meant to be spread, this should be apiErr.Details... — confirm.
		if writeErr := WriteError(apiErr.Err.Message, apiErr.StatusCode, code, rc.Req, rc.Rw, apiErr.Details); writeErr != nil {
			rc.L.Error().Err(writeErr).Msg("Failure in WriteErr")
		}
		return
	}
	if writeErr := WriteErr(err, 0, errCode, rc.Req, rc.Rw); writeErr != nil {
		rc.L.Error().Err(writeErr).Msg("Failure in WriteErr")
	}
}
// WriteOutput writes output to the response with the given HTTP status code
// via the package-level WriteOutput.
//
// Failures while writing are logged, not returned.
func (rc ReqContext) WriteOutput(output interface{}, statusCode int) {
	if err := WriteOutput(false, statusCode, output, rc.Req, rc.Rw); err != nil {
		// Name the method that actually failed so the log line is accurate.
		rc.L.Error().Err(err).Msg("Failure in WriteOutput")
	}
}
// ValidateStruct runs the shared Context's StructValidater on input and
// returns its result unchanged. It panics if rc.Context or its
// StructValidater is nil.
func (rc ReqContext) ValidateStruct(input interface{}) error {
	validate := rc.Context.StructValidater
	return validate(input)
}
// Unmarshal decodes body into j according to rc.ContentKind, via
// UnmarshalWithKind.
//
// A nil body is rejected with an error. When debug logging is enabled, a
// nil body and any unmarshalling failure (together with the raw body) are
// logged at debug level.
func (rc ReqContext) Unmarshal(body []byte, j interface{}) error {
	if body == nil {
		if rc.L.HasDebug() {
			rc.L.Debug().Msg("Body was nil")
		}
		return fmt.Errorf("Body was nil")
	}
	if err := UnmarshalWithKind(rc.ContentKind, body, j); err != nil {
		// NOTE(review): the raw body may contain credentials (e.g. login
		// payloads); consider redacting it even at debug level.
		if rc.L.HasDebug() {
			rc.L.Debug().Bytes("body", body).Err(err).Msg("unmarshalling failed with input")
		}
		return err
	}
	return nil
}
// decoder is the minimal interface shared by the JSON, TOML and YAML stream
// decoders that GetDecoder returns.
type decoder interface {
	Decode(v interface{}) error
}
// GetDecoder returns a streaming decoder over the request body that matches
// rc.ContentKind: TOML, YAML, or JSON (JSON is also used for any other kind).
//
// NOTE(review): the body is read without a size limit; consider wrapping it
// with http.MaxBytesReader where requests come from untrusted clients.
func (rc ReqContext) GetDecoder() decoder {
	body := rc.Req.Body
	switch rc.ContentKind {
	case OutputToml:
		return toml.NewDecoder(body)
	case OutputYaml:
		return yaml.NewDecoder(body)
	default:
		// Covers OutputJson explicitly and acts as the fallback for unknown kinds.
		return json.NewDecoder(body)
	}
}
// ValidateBody decodes the request body into j, using the decoder selected by
// rc.ContentKind, and then validates j with ValidateStruct.
//
// The returned error is:
//   - ErrEmptyBody when the request declares a zero Content-Length (code
//     CodeErrMissingBody);
//   - the decode error, on decode failure (code CodeErrMarshal);
//   - the validation error, on validation failure (code CodeErrInputValidation).
//
// With writeErr = true the error is also written to the response using the
// code above; in that case the caller should simply return when the result is
// non-nil.
func (rc ReqContext) ValidateBody(j interface{}, writeErr bool) error {
	// fail optionally writes err to the response and returns it unchanged.
	fail := func(err error, code ErrorCodes) error {
		if writeErr {
			rc.WriteErr(err, code)
		}
		return err
	}

	// NOTE(review): ContentLength is -1 when the length is unknown (e.g.
	// chunked encoding); an empty body in that case surfaces as a decode
	// error rather than ErrEmptyBody.
	if rc.Req.ContentLength == 0 {
		return fail(ErrEmptyBody, CodeErrMissingBody)
	}
	if err := rc.GetDecoder().Decode(j); err != nil {
		return fail(err, CodeErrMarshal)
	}
	if err := rc.ValidateStruct(j); err != nil {
		return fail(err, CodeErrInputValidation)
	}
	return nil
}
// ValidateBytes unmarshals body into j (see Unmarshal) and validates it with
// ValidateStruct.
//
// On failure the error is written to the response, with CodeErrMarshal or
// CodeErrInputValidation respectively, and then returned. If the result is
// non-nil, the caller should simply return.
func (rc ReqContext) ValidateBytes(body []byte, j interface{}) error {
	if err := rc.Unmarshal(body, j); err != nil {
		rc.WriteErr(err, CodeErrMarshal)
		return err
	}
	if err := rc.ValidateStruct(j); err != nil {
		rc.WriteErr(err, CodeErrInputValidation)
		return err
	}
	return nil
}