/
encoder.go
221 lines (181 loc) · 4.24 KB
/
encoder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
package http
import (
	"context"
	"fmt"
	"io"
	"mime"
	net_http "net/http"
	"sync"
	"time"

	kit_http "github.com/go-kit/kit/transport/http"
	"github.com/oxtoacart/bpool"
)
// Encoder writes a response value to a net_http.ResponseWriter, returning an
// error if it cannot. Its underlying function type is the same as go-kit's
// kit_http.EncodeResponseFunc, so values convert between the two directly
// (e.g. Encoder(fn)).
type Encoder func(context.Context, net_http.ResponseWriter, interface{}) error
// pool hands out reusable byte slices. copyResponse borrows one with Get,
// uses it as the scratch buffer for copyBuffer, and returns it with Put when
// the copy finishes. newDefaultEncoder supplies a bpool byte pool here.
type pool interface {
	Get() []byte
	Put([]byte)
}
// flusher is a writer whose buffered output can be pushed out on demand.
// A net_http.ResponseWriter that also implements net_http.Flusher
// satisfies it.
type flusher interface {
	io.Writer
	net_http.Flusher
}

// latencyWriter wraps a flusher and limits how long written bytes may sit
// unflushed in dst.
//
//   - latency < 0: every Write is followed immediately by a Flush.
//   - otherwise: the first Write after a flush arms a one-shot timer that
//     calls delayedFlush once latency has elapsed. Writes that arrive while
//     that flush is still pending do not re-arm the timer.
//
// Write, delayedFlush and stop all hold mu. delayedFlush is invoked by
// time.AfterFunc on its own goroutine.
type latencyWriter struct {
	dst     flusher
	latency time.Duration

	mu      sync.Mutex
	timer   *time.Timer // created on first use, later re-armed with Reset
	pending bool        // a timer-driven flush is outstanding
}

// Write forwards p to dst. It then either flushes right away (negative
// latency) or makes sure a delayed flush is scheduled. The result of the
// underlying Write is returned unchanged.
func (lw *latencyWriter) Write(p []byte) (n int, err error) {
	lw.mu.Lock()
	defer lw.mu.Unlock()

	n, err = lw.dst.Write(p)
	switch {
	case lw.latency < 0:
		lw.dst.Flush()
	case lw.pending:
		// A flush is already scheduled and will cover these bytes too.
	case lw.timer == nil:
		lw.timer = time.AfterFunc(lw.latency, lw.delayedFlush)
		lw.pending = true
	default:
		lw.timer.Reset(lw.latency)
		lw.pending = true
	}
	return n, err
}

// delayedFlush is the timer callback. It flushes dst only if a flush is
// still wanted; stop may have cancelled it in the meantime.
func (lw *latencyWriter) delayedFlush() {
	lw.mu.Lock()
	defer lw.mu.Unlock()
	if lw.pending {
		lw.dst.Flush()
		lw.pending = false
	}
}

// stop cancels any outstanding delayed flush without flushing.
func (lw *latencyWriter) stop() {
	lw.mu.Lock()
	defer lw.mu.Unlock()
	lw.pending = false
	if t := lw.timer; t != nil {
		t.Stop()
	}
}
// copyResponse streams src into dst and returns the error reported by
// copyBuffer: the first write error or non-EOF read error, or nil.
//
// Flushing: when flushdur is non-zero and dst implements flusher, writes
// are routed through a latencyWriter. A negative flushdur flushes after
// every write; a positive one flushes at most flushdur after data is
// written. An initial flush is armed immediately, so dst is flushed after
// flushdur even if no body bytes have arrived yet.
//
// Buffering: if bp is non-nil, the scratch buffer is borrowed from it and
// handed back when the copy finishes. Otherwise copyBuffer allocates its own.
func copyResponse(
	bp pool,
	dst io.Writer,
	src io.Reader,
	flushdur time.Duration,
) error {
	out := dst
	if wf, canFlush := dst.(flusher); canFlush && flushdur != 0 {
		lw := &latencyWriter{
			dst:     wf,
			latency: flushdur,
			pending: true,
		}
		lw.timer = time.AfterFunc(flushdur, lw.delayedFlush)
		defer lw.stop()
		out = lw
	}

	var scratch []byte
	if bp != nil {
		scratch = bp.Get()
		defer bp.Put(scratch)
	}

	_, err := copyBuffer(out, src, scratch)
	return err
}
// copyBuffer copies src to dst using buf as scratch space. A 32 KiB buffer
// is allocated when buf is empty. It returns the number of bytes written to
// dst together with:
//   - nil once src reports io.EOF;
//   - dst's write error, or io.ErrShortWrite if dst accepted fewer bytes
//     than it was given;
//   - context.Canceled, unchanged, if src returns it;
//   - any other read error, wrapped as "read error during body copy: ..."
//     with %w so callers can still match it via errors.Is/errors.As.
//
// Bytes returned together with a read error are written before the error
// is reported, as the io.Reader contract requires.
func copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}
	var written int64
	for {
		nr, rerr := src.Read(buf)
		// Consume any data first: a Read may return n > 0 alongside an error.
		if nr > 0 {
			nw, werr := dst.Write(buf[:nr])
			if nw > 0 {
				written += int64(nw)
			}
			if werr != nil {
				return written, werr
			}
			if nr != nw {
				return written, io.ErrShortWrite
			}
		}
		switch rerr {
		case nil:
			// Keep copying.
		case io.EOF:
			return written, nil
		case context.Canceled:
			return written, rerr
		default:
			return written, fmt.Errorf("read error during body copy: %w", rerr)
		}
	}
}
// flushInterval chooses how often copyResponse should flush res's body to
// the client.
//
// Server-Sent Events (media type text/event-stream, defined in
// https://www.w3.org/TR/eventsource/#text-event-stream) are flushed after
// every write, which is signalled by a negative duration. The Content-Type
// is parsed as a media type, so parameters such as "; charset=utf-8" and
// letter case do not defeat the match. All other responses are flushed at
// most every 10ms.
func flushInterval(res *net_http.Response) time.Duration {
	// ParseMediaType still returns the base type when only the parameters
	// are malformed, so its error can safely be ignored here.
	mediaType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type"))
	if mediaType == "text/event-stream" {
		return -1 // negative means immediately
	}
	// TODO: more specific cases?
	return 10 * time.Millisecond
}
// copyHeader appends every value of every field in src to dst. Values
// already present in dst are kept, and keys are canonicalized by
// Header.Add on the way in.
func copyHeader(dst, src net_http.Header) {
	for key := range src {
		values := src[key]
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
// newDefaultEncoder builds an Encoder that relays an upstream
// *net_http.Response to rw. It copies the headers, writes the status code,
// and streams the body through a byte-buffer pool shared by every call of
// the returned Encoder, flushing as chosen by flushInterval.
//
// Outcomes of the returned Encoder:
//   - res is nil (untyped, or a nil *net_http.Response): 204 No Content.
//   - res is not a *net_http.Response: ErrNotHTTPResponse, nothing written.
//   - StatusCode 0 is sent as 200 OK.
//   - a StatusCode outside [100, 999] is returned as an error before
//     anything is written to rw. The standard library's ResponseWriter
//     implementations would otherwise panic in WriteHeader.
//   - a nil Body is treated as an empty body.
//
// Whenever res is a non-nil *net_http.Response, its Body (if any) is closed
// and its Close field is set before the Encoder returns.
func newDefaultEncoder() Encoder {
	// Reusable copy buffers: at most 100 idle buffers, each 1,000,000 bytes
	// long (bpool.NewBytePool(maxSize, width)).
	bufferPool := bpool.NewBytePool(100, 1000000)
	return func(ctx context.Context, rw net_http.ResponseWriter, res interface{}) (err error) {
		// Check the untyped nil first: the type assertion below would fail
		// for it and misreport "no response" as ErrNotHTTPResponse.
		if res == nil {
			rw.WriteHeader(net_http.StatusNoContent)
			return nil
		}
		rr, ok := res.(*net_http.Response)
		if !ok {
			return ErrNotHTTPResponse
		}
		// A typed nil pointer makes res non-nil, so check it separately
		// rather than dereferencing it below.
		if rr == nil {
			rw.WriteHeader(net_http.StatusNoContent)
			return nil
		}
		// Release the upstream body on every exit path, including errors.
		defer func() {
			if rr.Body != nil {
				_ = rr.Body.Close() // best effort; the copy result is what is reported
			}
			rr.Close = true
		}()

		status := rr.StatusCode
		if status == 0 {
			status = net_http.StatusOK
		}
		if status < 100 || status > 999 {
			return fmt.Errorf("invalid upstream status code %d", rr.StatusCode)
		}

		copyHeader(rw.Header(), rr.Header)
		rw.WriteHeader(status)

		if rr.Body == nil {
			return nil
		}
		return copyResponse(bufferPool, rw, rr.Body, flushInterval(rr))
	}
}
// NewDefaultEncoder returns the package's default Encoder. It accepts a
// *net_http.Response and copies its status code, headers and body onto the
// ResponseWriter.
func NewDefaultEncoder() Encoder {
	enc := newDefaultEncoder()
	return enc
}
// NewDefaultJSONEncoder returns an Encoder that serializes the response
// value as JSON by delegating to go-kit's kit_http.EncodeJSONResponse.
func NewDefaultJSONEncoder() Encoder {
	var enc Encoder = kit_http.EncodeJSONResponse
	return enc
}
// NewGoKitEncoderHandlerOption returns a HandlerOption that makes the
// handler encode its responses with the given go-kit EncodeResponseFunc.
// Passing a nil fn leaves the handler with a nil encoder.
func NewGoKitEncoderHandlerOption(fn kit_http.EncodeResponseFunc) HandlerOption {
	enc := Encoder(fn)
	return func(h *handler) {
		h.encoder = enc
	}
}