/
responses.go
146 lines (124 loc) · 3.73 KB
/
responses.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package httpx
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sync"

	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/rest/internal/errcode"
	"github.com/zeromicro/go-zero/rest/internal/header"
)
var (
errorHandler func(error) (int, interface{})
errorHandlerCtx func(context.Context, error) (int, interface{})
lock sync.RWMutex
)
// Error writes err into w.
//
// When a handler has been installed with SetErrorHandler, it decides the
// status code and body. Otherwise the functions in fns are applied, and if
// fns is empty a plain-text error response is written.
func Error(w http.ResponseWriter, err error, fns ...func(w http.ResponseWriter, err error)) {
	// Snapshot the handler under the read lock so it is not held while the
	// response is being written.
	custom := func() func(error) (int, interface{}) {
		lock.RLock()
		defer lock.RUnlock()
		return errorHandler
	}()

	doHandleError(w, err, custom, WriteJson, fns...)
}
// ErrorCtx writes err into w, using ctx for the error handler and for logging
// JSON write failures.
//
// Handler selection:
//   - a handler installed with SetErrorHandlerCtx takes precedence and
//     receives ctx;
//   - otherwise a handler installed with SetErrorHandler is used, so switching
//     a call site from Error to ErrorCtx does not silently drop custom error
//     formatting;
//   - with neither installed, the functions in fns are applied, and if fns is
//     empty a plain-text error response is written.
func ErrorCtx(ctx context.Context, w http.ResponseWriter, err error,
	fns ...func(w http.ResponseWriter, err error)) {
	lock.RLock()
	handlerCtx := errorHandlerCtx
	plainHandler := errorHandler
	lock.RUnlock()

	var handler func(error) (int, interface{})
	switch {
	case handlerCtx != nil:
		handler = func(err error) (int, interface{}) {
			return handlerCtx(ctx, err)
		}
	case plainHandler != nil:
		// Without this fallback, a handler registered via SetErrorHandler
		// would be ignored for every ErrorCtx caller.
		handler = plainHandler
	}

	writeJson := func(w http.ResponseWriter, code int, v interface{}) {
		WriteJsonCtx(ctx, w, code, v)
	}
	doHandleError(w, err, handler, writeJson, fns...)
}
// Ok sends a 200 OK status line to the client. It writes no body itself.
func Ok(w http.ResponseWriter) {
	const status = http.StatusOK
	w.WriteHeader(status)
}
// OkJson serializes v as JSON and writes it to w with status 200 OK.
// Write failures are handled (logged) by WriteJson.
func OkJson(w http.ResponseWriter, v interface{}) {
	const status = http.StatusOK
	WriteJson(w, status, v)
}
// OkJsonCtx serializes v as JSON and writes it to w with status 200 OK.
// ctx is forwarded to WriteJsonCtx, which uses it when logging failures.
func OkJsonCtx(ctx context.Context, w http.ResponseWriter, v interface{}) {
	const status = http.StatusOK
	WriteJsonCtx(ctx, w, status, v)
}
// SetErrorHandler installs handler as the hook Error uses to turn an error
// into a status code and response body. Passing nil removes the hook, so
// Error goes back to its default rendering.
func SetErrorHandler(handler func(error) (int, interface{})) {
	lock.Lock()
	errorHandler = handler
	lock.Unlock()
}
// SetErrorHandlerCtx installs handlerCtx as the context-aware hook that
// ErrorCtx uses to turn an error into a status code and response body. The
// ctx passed to ErrorCtx is forwarded to handlerCtx. Passing nil removes the
// hook.
func SetErrorHandlerCtx(handlerCtx func(context.Context, error) (int, interface{})) {
	lock.Lock()
	defer lock.Unlock()
	errorHandlerCtx = handlerCtx
}
// WriteJson serializes v as JSON and writes it to w with the given status
// code. Any failure is logged through logx rather than returned.
func WriteJson(w http.ResponseWriter, code int, v interface{}) {
	err := doWriteJson(w, code, v)
	if err == nil {
		return
	}
	logx.Error(err)
}
// WriteJsonCtx serializes v as JSON and writes it to w with the given status
// code. Any failure is logged through a logger bound to ctx rather than
// returned.
func WriteJsonCtx(ctx context.Context, w http.ResponseWriter, code int, v interface{}) {
	err := doWriteJson(w, code, v)
	if err == nil {
		return
	}
	logx.WithContext(ctx).Error(err)
}
// doHandleError renders err into w.
//
// With a non-nil handler, the handler picks the status code and body:
//   - a nil body writes only the status code;
//   - a body that is an error is written as plain text via http.Error;
//   - any other body is passed to writeJson.
//
// Without a handler:
//   - every fn in fns is called in order;
//   - if fns is empty, gRPC errors (per errcode.IsGrpcError) get the status
//     from errcode.CodeFromGrpcError;
//   - all other errors get 400 Bad Request.
//
// In both of the last two cases err.Error() is the plain-text body.
func doHandleError(w http.ResponseWriter, err error, handler func(error) (int, interface{}),
	writeJson func(w http.ResponseWriter, code int, v interface{}),
	fns ...func(w http.ResponseWriter, err error)) {
	if handler != nil {
		code, body := handler(err)
		switch b := body.(type) {
		case nil:
			w.WriteHeader(code)
		case error:
			http.Error(w, b.Error(), code)
		default:
			writeJson(w, code, b)
		}
		return
	}

	switch {
	case len(fns) > 0:
		for _, fn := range fns {
			fn(w, err)
		}
	case errcode.IsGrpcError(err):
		// Keep the full error text instead of unwrapping to status.Message():
		// unwrapping would hide the rpc error headers.
		http.Error(w, err.Error(), errcode.CodeFromGrpcError(err))
	default:
		http.Error(w, err.Error(), http.StatusBadRequest)
	}
}
// doWriteJson marshals v to JSON and writes it to w with the given status code
// and the JSON content type.
//
// If marshaling fails, a 500 response carrying the marshal error text is
// written instead, and the wrapped error is returned.
//
// Once headers are sent:
//   - a write error is returned wrapped;
//   - a short write is returned as an error;
//   - http.ErrHandlerTimeout, wrapped or not, is the exception and is ignored.
func doWriteJson(w http.ResponseWriter, code int, v interface{}) error {
	bs, err := json.Marshal(v)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return fmt.Errorf("marshal json failed, error: %w", err)
	}

	w.Header().Set(ContentType, header.JsonContentType)
	w.WriteHeader(code)

	n, err := w.Write(bs)
	if err != nil {
		// http.ErrHandlerTimeout has been handled by http.TimeoutHandler,
		// so it's ignored here. errors.Is (rather than ==) also matches it
		// when a wrapping ResponseWriter returns it wrapped.
		if errors.Is(err, http.ErrHandlerTimeout) {
			return nil
		}
		return fmt.Errorf("write response failed, error: %w", err)
	}
	if n < len(bs) {
		return fmt.Errorf("actual bytes: %d, written bytes: %d", len(bs), n)
	}

	return nil
}