-
Notifications
You must be signed in to change notification settings - Fork 0
/
websocket.go
331 lines (267 loc) · 10.3 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
package webbridge
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"sync"
"sync/atomic"
"github.com/lxzan/gws"
"github.com/renbou/grpcbridge/bridgelog"
"github.com/renbou/grpcbridge/grpcadapter"
"github.com/renbou/grpcbridge/internal/rpcutil"
"github.com/renbou/grpcbridge/routing"
"github.com/renbou/grpcbridge/transcoding"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// Errors produced when the opcode of a client's WebSocket message doesn't match the
// content type expected by the request transcoder: errExpectedText when text was expected,
// errExpectedBinary when binary was expected. They use codes.InvalidArgument.
var (
	errExpectedText   = status.Error(codes.InvalidArgument, "received binary message instead of text")
	errExpectedBinary = status.Error(codes.InvalidArgument, "received text message instead of binary")
)

// gwsStreamKey is the gws session key under which the complete per-connection stream state is stored,
// so that a single upgrader (and a single stateless event handler) can be reused for all requests.
const gwsStreamKey = "grpcbridge\x00request"
// TranscodedWebSocketBridgeOpts define all the optional settings which can be set for [TranscodedWebSocketBridge].
// The zero value is valid: [NewTranscodedWebSocketBridge] replaces every unset field with its default.
type TranscodedWebSocketBridgeOpts struct {
	// Logger receives the bridge's logs, which are tagged with the "grpcbridge.web" component.
	// Logs are discarded by default.
	Logger bridgelog.Logger

	// If not set, the default [transcoding.StandardTranscoder] is created with default options.
	Transcoder transcoding.HTTPTranscoder

	// Forwarder performs the forwarding of each WebSocket stream to the gRPC target selected by the router.
	// If not set, the default [grpcadapter.ProxyForwarder] is created with default options.
	Forwarder grpcadapter.Forwarder

	// MetadataParam specifies the name of the query parameter to be parsed as a map containing the metadata to be forwarded.
	// This is needed for WebSockets since there's no way to set headers on the WebSocket handshake request through the WebSocket web API.
	// Metadata from this parameter is joined with the metadata derived from the handshake request's headers.
	//
	// If not set, _metadata is used. For more info about the format, see [TranscodedWebSocketBridge.ServeHTTP].
	MetadataParam string
}
// withDefaults returns a copy of the options in which every unset field has been filled in:
// a discarding logger, a StandardTranscoder and a ProxyForwarder (both with default options),
// and defaultMetadataParam as the metadata query parameter name.
// Defaults are only constructed for the fields that actually need them.
func (o TranscodedWebSocketBridgeOpts) withDefaults() TranscodedWebSocketBridgeOpts {
	filled := o

	if filled.Logger == nil {
		filled.Logger = bridgelog.Discard()
	}

	if filled.Transcoder == nil {
		filled.Transcoder = transcoding.NewStandardTranscoder(transcoding.StandardTranscoderOpts{})
	}

	if filled.Forwarder == nil {
		filled.Forwarder = grpcadapter.NewProxyForwarder(grpcadapter.ProxyForwarderOpts{})
	}

	if filled.MetadataParam == "" {
		filled.MetadataParam = defaultMetadataParam
	}

	return filled
}
// TranscodedWebSocketBridge is an HTTP handler which routes incoming requests, upgrades them to
// WebSocket connections, and forwards each connection as a single stream to the routed gRPC method,
// transcoding WebSocket messages to and from gRPC messages.
//
// Construct it using [NewTranscodedWebSocketBridge].
type TranscodedWebSocketBridge struct {
	logger        bridgelog.Logger
	upgrader      *gws.Upgrader // shared by all requests; per-connection state lives in the socket session
	router        routing.HTTPRouter
	transcoder    transcoding.HTTPTranscoder
	forwarder     grpcadapter.Forwarder
	metadataParam string // name of the query parameter parsed into incoming metadata
}
// NewTranscodedWebSocketBridge initializes a new [TranscodedWebSocketBridge] using the specified router and options.
// The router isn't optional, because no routers in grpcbridge can be constructed without some form of required args.
//
// Unset options are replaced with their defaults. A single gws upgrader is created here and
// reused for every request served by the returned bridge.
func NewTranscodedWebSocketBridge(router routing.HTTPRouter, opts TranscodedWebSocketBridgeOpts) *TranscodedWebSocketBridge {
	opts = opts.withDefaults()

	bridge := &TranscodedWebSocketBridge{
		logger:        opts.Logger.WithComponent("grpcbridge.web"),
		router:        router,
		transcoder:    opts.Transcoder,
		forwarder:     opts.Forwarder,
		metadataParam: opts.MetadataParam,
	}

	// The handler itself is stateless: each connection's stream is kept in its socket session.
	bridge.upgrader = gws.NewUpgrader(new(gwsHandler), &gws.ServerOption{
		ParallelEnabled:  false, // No point in parallel message processing, since we're using a single stream per request.
		CheckUtf8Enabled: true,
		Logger:           gwsLogger{bridge.logger},
	})

	return bridge
}
// ServeHTTP handles a single WebSocket stream. It performs these steps in order:
//   - routes r using the bridge's router;
//   - upgrades the connection to a WebSocket;
//   - forwards messages between the socket and the routed gRPC method until forwarding returns;
//   - closes the socket with the code and reason that websocketError derives from the forwarding result.
//
// Forwarding is cancelled early if the socket's read loop exits, e.g. when the client closes the WebSocket.
//
// Incoming gRPC metadata is built by joining two sources:
//   - the metadata parsed from the query parameter named by [TranscodedWebSocketBridgeOpts.MetadataParam];
//   - the metadata derived from the handshake request's headers.
//
// TODO(review): the query parameter format is defined by parseMetadataQuery and should be described
// here, since TranscodedWebSocketBridgeOpts.MetadataParam points readers to this method.
func (b *TranscodedWebSocketBridge) ServeHTTP(unwrappedRW http.ResponseWriter, r *http.Request) {
	md := parseMetadataQuery(r, b.metadataParam)

	// A nil request means routing failed. Presumably routeTranscodedRequest has already
	// written an error response in that case (TODO confirm).
	req := routeTranscodedRequest(unwrappedRW, r, b.router, b.transcoder)
	if req == nil {
		return
	}

	// TODO(renbou): check that only streaming requests are handled here.
	socket, err := b.upgrader.Upgrade(unwrappedRW, r)
	if err != nil {
		// Upgrade() will write an error if Hijack() was successful, but if it wasn't,
		// then no error will be written, so we need to write one ourselves just in case.
		// This is a no-op if Upgrade() successfully Hijack()ed the request.
		writeError(req.w, req.r, req.resptc, status.Error(codes.Internal, err.Error()))
		return
	}

	// Always clean up the incoming socket to avoid any potential leaks.
	defer socket.NetConn().Close()

	// events is unbuffered, so OnMessage blocks until Recv consumes each event or until done is closed.
	stream := &gwsStream{
		socket: socket,
		req:    req,
		events: make(chan gwsReadEvent),
		done:   make(chan struct{}),
	}

	// Store the whole stream state in the session so that it can be retrieved by OnMessage.
	// ReadLoop() will exit when the stream exits due to client/server closure, or when the client closes the WebSocket.
	socket.Session().Store(gwsStreamKey, stream)

	// End of forwarding will notify ReadLoop() to exit via close(done) and WriteClose().
	// However, ReadLoop() must also have a way to notify the forwarding process to handle cases where the client closes the connection.
	// This context allows us to immediately cancel Forward() once we know no client is listening for any more responses.
	// NB: r.Context is valid here even after Hijack() in Upgrade()
	ctx, cancel := context.WithCancel(r.Context())

	// ReadLoop() dispatches client messages to gwsHandler.OnMessage. When it returns, ctx is cancelled.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		defer cancel()
		socket.ReadLoop()
	}()

	// Even though web clients aren't able to set metadata in headers, it's still useful to support it for other potential clients.
	ctx = metadata.NewIncomingContext(ctx, metadata.Join(md, headersToMD(r.Header)))

	logger := b.logger.With(
		"target", req.route.Target.Name,
		"grpc.method", req.route.Method.RPCName,
		"http.method", r.Method,
		"http.path", r.URL.Path,
		"http.params", req.route.PathParams,
	)

	logger.Debug("began handling WebSocket stream")
	defer logger.Debug("ended handling WebSocket stream")

	err = b.forwarder.Forward(ctx, grpcadapter.ForwardParams{
		Target:   req.route.Target,
		Service:  req.route.Service,
		Method:   req.route.Method,
		Incoming: stream,
		Outgoing: req.conn,
	})

	logger.Debug("WebSocket stream forwarding done", "error", err)

	// Close the WebSocket and notify ReadLoop() to stop processing OnMessage, if it hasn't already.
	code, reason := websocketError(err)
	socket.WriteClose(code, []byte(reason))
	close(stream.done) // this allows OnMessage to instantly exit
	wg.Wait()          // just a safety measure to avoid leaks
}
// websocketError converts the result of forwarding into a WebSocket close code and reason.
//
// Close codes:
//   - a nil error yields 1000 (normal closure) with an empty reason;
//   - errExpectedBinary/errExpectedText yield 1003 (unsupported data);
//   - every other error yields 1001 (going away).
//
// gRPC status errors are rendered in the compact form "code <Code>: <message>"; other errors use Error().
//
// The reason is truncated to fit the close frame. RFC 6455 limits control frame payloads to 125 bytes,
// 2 of which carry the code, and requires the reason to be valid UTF-8, so the cut never splits a
// multi-byte character.
func websocketError(err error) (code uint16, reason string) {
	if err == nil {
		return 1000, ""
	}

	code = 1001
	if errors.Is(err, errExpectedBinary) || errors.Is(err, errExpectedText) {
		code = 1003
	}

	if st, ok := status.FromError(err); ok {
		// more compact form because ws has ~123 bytes limit on the reason
		reason = fmt.Sprintf("code %s: %s", st.Code(), st.Message())
	} else {
		reason = err.Error()
	}

	return code, truncateCloseReason(reason)
}

// truncateCloseReason shortens reason to at most 123 bytes (the room left in a close frame after the
// 2-byte status code) without cutting a multi-byte UTF-8 sequence in half.
func truncateCloseReason(reason string) string {
	const maxLen = 123

	if len(reason) <= maxLen {
		return reason
	}

	cut := maxLen
	// Back off while reason[cut] is a UTF-8 continuation byte (10xxxxxx): the character it belongs to
	// started before the cut, so it must be dropped entirely rather than split.
	for cut > 0 && reason[cut]&0xC0 == 0x80 {
		cut--
	}

	return reason[:cut]
}
// gwsReadEvent is a single item passed from gwsHandler.OnMessage to gwsStream.Recv over the events channel.
// It carries either the payload of a received message in data, or an error in err.
// Examples of the error are errExpectedBinary and errExpectedText; Recv returns it as-is.
type gwsReadEvent struct {
	data []byte
	err  error
}
// gwsStream is the Incoming stream passed to Forward by ServeHTTP for a single WebSocket connection.
// Send writes response messages directly to socket. Recv consumes the request messages that
// gwsHandler.OnMessage delivers through events.
//
// TODO(renbou): document the various stream flow states to properly showcase how the various errors and closures are handled.
type gwsStream struct {
	socket *gws.Conn
	req    *transcodedRequest

	// done is needed separately to the events channel so that OnMessage can be safely notified
	// to stop trying to send any more events, so that the ReadLoop can exit successfully.
	done   chan struct{}
	events chan gwsReadEvent

	// sendActive and recvActive detect concurrent Send/Recv calls, which cause a panic.
	sendActive atomic.Bool
	recvActive atomic.Bool

	// alreadyRead lets only the first message through for non-client-streaming methods with a request body;
	// OnMessage ignores any later messages for such requests.
	alreadyRead atomic.Bool
}
// Send transcodes msg using the response transcoder and writes it to the client as one WebSocket message.
// The write is run through withCtx, which presumably stops waiting once ctx is done (TODO confirm).
//
// Send must not be called concurrently with itself; doing so panics.
func (s *gwsStream) Send(ctx context.Context, msg proto.Message) error {
	if wasActive := s.sendActive.Swap(true); wasActive {
		panic("grpcbridge: Send() called concurrently on gwsStream")
	}
	defer s.sendActive.Store(false)

	write := func() error {
		return s.send(msg)
	}

	return withCtx(ctx, write)
}
// send performs the actual write for Send. The message is transcoded with the response transcoder
// and sent as a binary frame when the transcoder reports a binary content type, as a text frame otherwise.
// Transcoding failures go through responseTranscodingError; write failures become codes.Internal.
func (s *gwsStream) send(msg proto.Message) error {
	payload, err := s.req.resptc.Transcode(msg)
	if err != nil {
		return responseTranscodingError(err)
	}

	opcode := gws.OpcodeText
	_, isBinary := s.req.resptc.ContentType(msg)
	if isBinary {
		opcode = gws.OpcodeBinary
	}

	// WriteMessage will be unblocked by socket.NetConn().Close() in ServeHTTP, if not before.
	if err = s.socket.WriteMessage(opcode, payload); err != nil {
		return status.Errorf(codes.Internal, "failed to write response message: %s", err)
	}

	return nil
}
// Recv fills msg with the next request message from the client.
//
// Client-streaming methods, and methods with a request body binding, wait for an event from
// gwsHandler.OnMessage and transcode its payload with the request transcoder. They return io.EOF once
// the events channel has been closed, which OnMessage does after the single message of a
// non-client-streaming method. Other methods don't wait for any message; a nil payload is transcoded
// instead (presumably leaving msg populated only from path/query parameters — confirm).
//
// Errors:
//   - an error carried by an event (e.g. errExpectedBinary/errExpectedText) is returned as-is;
//   - a done ctx yields rpcutil.ContextError(ctx.Err()).
//
// Recv must not be called concurrently with itself; doing so panics.
func (s *gwsStream) Recv(ctx context.Context, msg proto.Message) error {
	if !s.recvActive.CompareAndSwap(false, true) {
		panic("grpcbridge: Recv() called concurrently on gwsStream")
	}
	defer s.recvActive.Store(false)

	var event gwsReadEvent

	// Only wait for a message when an event/body is actually required.
	if s.req.route.Method.ClientStreaming || s.req.route.Binding.RequestBodyPath != "" {
		select {
		case ev, ok := <-s.events:
			if !ok {
				return io.EOF // events channel closed by OnMessage for unary requests
			}
			event = ev
		case <-ctx.Done():
			return rpcutil.ContextError(ctx.Err())
		}
	}

	// err written when the incoming side of the socket is closed by us due to other errors,
	// which can't really be counted as a proper client-side closure.
	if event.err != nil {
		return event.err
	}

	return requestTranscodingError(s.req.reqtc.Transcode(event.data, msg))
}
// SetHeader deliberately does nothing and drops the header metadata.
// WebSocket headers could only be returned in the upgrade response, and supporting that
// would be just way too tedious to implement.
func (s *gwsStream) SetHeader(md metadata.MD) {
	// Intentionally empty.
}

// SetTrailer deliberately does nothing and drops the trailer metadata.
// WebSockets don't support trailers AT ALL: there's no way to send them after the upgrade.
func (s *gwsStream) SetTrailer(md metadata.MD) {
	// Intentionally empty.
}
// gwsHandler is the gws event handler shared by every connection of a TranscodedWebSocketBridge.
// It holds no state of its own: each connection's *gwsStream is loaded from the socket session
// under gwsStreamKey. Events other than OnMessage use gws.BuiltinEventHandler's implementations.
type gwsHandler struct {
	gws.BuiltinEventHandler
}

// OnMessage hands a client message to the connection's gwsStream, where Recv consumes it.
//
// Handling depends on the routed method:
//   - client-streaming: every message is delivered;
//   - request body binding: only the first message is delivered (tracked by alreadyRead);
//   - otherwise: every message is dropped.
//
// A message whose opcode doesn't match the request transcoder's content type is delivered as an
// errExpectedBinary/errExpectedText event instead of data. For non-client-streaming methods the events
// channel is closed after the single delivery attempt, so later Recv calls return io.EOF.
func (b *gwsHandler) OnMessage(socket *gws.Conn, message *gws.Message) {
	// ServeHTTP stores the stream before starting ReadLoop, so it is present here.
	streamAny, _ := socket.Session().Load(gwsStreamKey)
	stream := streamAny.(*gwsStream)

	// Due to short-circuiting, the CompareAndSwap only runs for non-client-streaming methods with a body,
	// and it succeeds for exactly one message.
	if !(stream.req.route.Method.ClientStreaming || (stream.req.route.Binding.RequestBodyPath != "" && stream.alreadyRead.CompareAndSwap(false, true))) {
		return
	}

	// Any non-text opcode is treated as binary.
	_, expectBinary := stream.req.reqtc.ContentType()
	messageBinary := message.Opcode != gws.OpcodeText

	var event gwsReadEvent

	// send error as event, the stream/socket will be closed when the error is handled.
	if messageBinary != expectBinary {
		if expectBinary {
			event = gwsReadEvent{err: errExpectedBinary}
		} else {
			event = gwsReadEvent{err: errExpectedText}
		}
	} else {
		event = gwsReadEvent{data: message.Data.Bytes()}
	}

	// Block until Recv takes the event, or until ServeHTTP closes done after forwarding has finished.
	select {
	case stream.events <- event: // events closed only by OnMessage, so no panic will occur here
	case <-stream.done:
	}

	if !stream.req.route.Method.ClientStreaming {
		// only one instance of OnMessage can get to this statement due to the alreadyRead check above.
		close(stream.events)
	}
}
// gwsLogger adapts the bridge's bridgelog.Logger for use as gws.ServerOption.Logger,
// reporting each error from gws as one structured log entry.
type gwsLogger struct {
	bridgelog.Logger
}

// Error logs args, concatenated using fmt.Sprint, under the "error" key.
func (l gwsLogger) Error(args ...any) {
	text := fmt.Sprint(args...)
	l.Logger.Error("gws WebSocket error", "error", text)
}