-
Notifications
You must be signed in to change notification settings - Fork 1
/
broker.go
380 lines (331 loc) · 10.1 KB
/
broker.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
package stomp
import (
"errors"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"github.com/cenkalti/backoff"
"github.com/go-co-op/gocron"
"github.com/google/uuid"
)
// Session handles the STOMP client session on connection.
// It holds the per-connection state used by the broker while talking to one client.
type Session struct {
// conn is the client connection that frames are read from and written to.
conn net.Conn
// sessionID identifies this session; it is reported to the client in the
// CONNECTED frame and used to tag the heartbeat job.
sessionID string
// loginFunc, when non-nil, is called to authenticate CONNECT frames.
loginFunc LoginFunc
// wgSessions is marked Done when the session is cleaned up.
wgSessions *sync.WaitGroup
// hbSendIntervalMsec is the interval in milliseconds at which the broker
// sends heartbeats to the client; 0 means no heartbeats are sent.
hbSendIntervalMsec int
// hbRecvIntervalMsec is the negotiated interval in milliseconds for
// heartbeats coming from the client; 0 means none are expected.
hbRecvIntervalMsec int
// hbJob is the scheduled heartbeat job; nil when no job has been scheduled.
hbJob *gocron.Job
}
// NewSession builds the Session for a freshly accepted client connection.
// A new UUID is generated as the session ID, and the supplied heartbeat
// intervals (in milliseconds) become the broker-side starting values for the
// heartbeat negotiation. wg is the wait group that the session signals when it
// is cleaned up.
func NewSession(conn net.Conn, loginFunc LoginFunc, wg *sync.WaitGroup,
	heartbeatSendIntervalMsec, heartbeatReceiveIntervalMsec int,
) *Session {
	sess := new(Session)

	sess.conn = conn
	sess.loginFunc = loginFunc
	sess.sessionID = uuid.NewString()
	sess.wgSessions = wg
	sess.hbSendIntervalMsec = heartbeatSendIntervalMsec
	sess.hbRecvIntervalMsec = heartbeatReceiveIntervalMsec

	return sess
}
// LoginFunc represents the user-defined authentication function.
// It receives the login and passcode headers of the CONNECT frame; a non-nil
// error rejects the connection attempt.
type LoginFunc func(login, passcode string) error
// Start begins the STOMP session with the client.
//
// It consumes raw frames produced by frameScanner for the session connection,
// parses and validates each one as a client frame, and hands it to the state
// machine. The loop stops when the frame stream ends or at the first error:
// parse and validation failures are reported to the client with an ERROR
// frame, state machine failures are only logged. cleanup always runs on exit.
func (sess *Session) Start() {
	defer sess.cleanup()

	for raw := range frameScanner(sess.conn) {
		frame, err := NewFrameFromBytes(raw)
		if err != nil {
			// A failed parse may not yield a usable frame, so only append the
			// frame dump when there is one; calling String on nil could panic.
			payload := "Frame serialization error:"
			if frame != nil {
				payload += frame.String()
			}
			// Best effort: the session is terminated regardless of the outcome.
			_ = sess.sendError(err, payload)
			return
		}

		if err = frame.Validate(ClientFrame); err != nil {
			// Best effort: the session is terminated regardless of the outcome.
			_ = sess.sendError(err, "Frame validation error:"+frame.String())
			return
		}

		if err = sess.stateMachine(frame); err != nil {
			log.Println(err)
			return
		}
	}
}
// cleanup releases the resources held by the session: it unschedules the
// heartbeat job (if any), closes the connection and finally marks the session
// as done on the wait group.
func (sess *Session) cleanup() {
	// Signal completion last, so whoever waits on the wait group only proceeds
	// once the heartbeat job is gone and the connection is closed.
	defer sess.wgSessions.Done()

	// Unschedule the heartbeat before closing the connection so that no new
	// heartbeat write is started against a closed connection.
	if sess.hbJob != nil {
		sched.RemoveByReference(sess.hbJob)
	}

	// The close error is ignored on purpose: the session is going away anyway.
	_ = sess.conn.Close()
}
// sendError sends an ERROR frame to the client. The text of err goes into the
// message header and payload becomes the plain-text body, with the
// content-length header set to the body size in bytes.
func (sess *Session) sendError(err error, payload string) error {
	body := []byte(payload)

	hdrs := map[Header]string{
		HdrKeyMessage:       err.Error(),
		HdrKeyContentType:   "text/plain",
		HdrKeyContentLength: strconv.Itoa(len(body)),
	}

	return sess.send(CmdError, hdrs, body)
}
// stateMachine is the brain of the protocol: it dispatches a validated client
// frame to the handler matching its command and returns that handler's error.
//
// ACK and NACK frames are accepted but not processed yet. Commands not listed
// here are ignored and yield nil.
func (sess *Session) stateMachine(frame *Frame) error {
	switch frame.command {
	case CmdConnect, CmdStomp:
		return sess.handleConnect(frame)

	case CmdSend:
		// A message that is part of an ongoing transaction is buffered until
		// the transaction is committed.
		if txID := frame.getHeader(HdrKeyTransaction); txID != "" {
			return bufferTxMessage(txID, frame)
		}
		// Not part of a transaction: deliver right away.
		return publish(frame, "")

	case CmdSubscribe:
		// The ack mode falls back to auto when the client does not specify one.
		ack := HdrValAckAuto
		if mode, ok := frame.headers[HdrKeyAck]; ok {
			ack = AckMode(mode)
		}
		return addSubscription(frame.headers[HdrKeyDestination], frame.headers[HdrKeyID], ack, sess)

	case CmdUnsubscribe:
		return removeSubscription(frame.headers[HdrKeyID])

	case CmdAck:
		// Not implemented yet.
		// if err := processAck(frame.headers[HdrKeyID]); err != nil {
		// 	return err
		// }

	case CmdNack:
		// Not implemented yet.
		// if err := processNack(frame.headers[HdrKeyID]); err != nil {
		// 	return err
		// }

	case CmdBegin:
		return startTx(frame.headers[HdrKeyTransaction])

	case CmdCommit:
		txID := frame.headers[HdrKeyTransaction]
		// Publish every message buffered for the transaction.
		return foreachTx(txID, func(frameTx *Frame) error {
			return publish(frameTx, txID)
		})

	case CmdAbort:
		return dropTx(frame.headers[HdrKeyTransaction])

	case CmdDisconnect:
		// Everything here is best effort: the client is leaving anyway.
		_ = cleanupSubscriptions(sess.sessionID)
		// A RECEIPT is only owed when the client asked for one; without the
		// receipt header there is no receipt-id to echo back.
		if receiptID := frame.headers[HdrKeyReceipt]; receiptID != "" {
			_ = sess.send(CmdReceipt, map[Header]string{HdrKeyReceiptID: receiptID}, nil)
		}
		_ = sess.conn.Close()
	}

	return nil
}
// sendMessage delivers a MESSAGE frame to this session's client.
//
// dest and subsID fill the destination and subscription headers, ackNum is
// formatted together with them into the ack header, and txID (when non-empty)
// is sent as the transaction header. The extra headers are copied with their
// names lower-cased. A fresh message-id is generated for every call.
func (sess *Session) sendMessage(dest, subsID string, ackNum uint32, txID string, headers map[Header]string,
	body []byte,
) error {
	h := make(map[Header]string, len(headers)+5)

	h[HdrKeyDestination] = dest
	if txID != "" {
		h[HdrKeyTransaction] = txID
	}

	for k, v := range headers {
		h[Header(strings.ToLower(string(k)))] = v
	}

	// Broker-generated headers are written last so that caller-supplied
	// headers can never replace the message id, the subscription id or the
	// ack token of the frame.
	h[HdrKeyMessageID] = uuid.NewString()
	h[HdrKeySubscription] = subsID
	h[HdrKeyAck] = fmtAckNum(dest, subsID, ackNum)

	return sess.send(CmdMessage, h, body)
}
// sendRaw writes body to the client connection as is, without any framing.
// A failed write is logged and retried under an exponential backoff policy;
// the error returned is whatever the retry loop ends with.
func (sess *Session) sendRaw(body []byte) error {
	writeOnce := func() error {
		_, err := sess.conn.Write(body)
		if err != nil {
			log.Println(err)
		}
		return err
	}

	return backoff.Retry(writeOnce, backoff.NewExponentialBackOff())
}
// send builds a frame from cmd, headers and body, checks that it is a valid
// server frame and writes it to the client.
//
// A validation failure is returned without anything being written. The write
// itself goes through sendRaw, which retries on error.
func (sess *Session) send(cmd Command, headers map[Header]string, body []byte) error {
	f := NewFrame(cmd, headers, body)

	// Make this check optional later
	if err := f.Validate(ServerFrame); err != nil {
		return err
	}

	// Serialize once and reuse the session's retrying writer instead of
	// duplicating the retry logic (and re-serializing on every attempt).
	return sess.sendRaw(f.Serialize())
}
// handleConnect responds to the CONNECT (or STOMP) frame from the client.
//
// It authenticates the client when a login function is configured, checks
// that the client accepts protocol version 1.2, negotiates the heartbeat
// intervals and finally answers with a CONNECTED frame. Authentication and
// version failures are reported to the client with an ERROR frame; every
// failure is returned to the caller.
func (sess *Session) handleConnect(f *Frame) error {
	// Authentication
	if sess.loginFunc != nil {
		login, passcode := f.getHeader(HdrKeyLogin), f.getHeader(HdrKeyPassCode)
		if err := sess.loginFunc(login, passcode); err != nil {
			// Best effort: the connection is rejected either way.
			_ = sess.sendError(errors.New("login failed"), "Authentication failed:\n"+err.Error())
			return errorMsg(errBrokerStateMachine, "Login error: "+err.Error())
		}
	}

	// Version negotiation: 1.2 is the only version spoken by this broker.
	acceptVersion := f.headers[HdrKeyAcceptVersion]
	ver := ""
	for _, v := range strings.Split(acceptVersion, ",") {
		if v == "1.2" {
			ver = v
			break
		}
	}
	if ver == "" {
		// Tell the client why it is being rejected (best effort), and report
		// the header the client actually sent: accept-version, not version.
		_ = sess.sendError(errors.New("unsupported protocol version"), "Supported protocol versions are 1.2")
		return errorMsg(errBrokerStateMachine, "Invalid client version received: "+acceptVersion)
	}

	// Heartbeat negotiation. A missing heart-beat header is equivalent to
	// "0,0", i.e. the client neither sends nor wants heartbeats.
	hbVal := f.getHeader(HdrKeyHeartBeat)
	if hbVal == "" {
		hbVal = "0,0"
	}
	if err := sess.negotiateHeartbeats(hbVal); err != nil {
		return errorMsg(errBrokerStateMachine, "Heartbeat negotiation: "+err.Error())
	}

	// Respond with CONNECTED, advertising the negotiated heartbeat intervals.
	return sess.send(CmdConnected, map[Header]string{
		HdrKeyVersion:   ver,
		HdrKeySession:   sess.sessionID,
		HdrKeyServer:    "go-proto-stomp/" + releaseVersion,
		HdrKeyHeartBeat: fmt.Sprintf("%d,%d", sess.hbSendIntervalMsec, sess.hbRecvIntervalMsec),
	}, nil)
}
// negotiateHeartbeats settles the heartbeat intervals of the session from the
// client's heart-beat header value, which has the form "<send>,<receive>" in
// milliseconds.
//
// For each direction the result is 0 (disabled) when either side uses 0, and
// otherwise the larger of the client's and the broker's interval. When the
// resulting send interval is non-zero, a job is scheduled that writes a
// newline heartbeat to the client at that interval.
//
// An error is returned for a malformed header, for a non-numeric or negative
// interval, or when the heartbeat job cannot be scheduled. The session's
// intervals are only modified once the whole header has been parsed.
func (sess *Session) negotiateHeartbeats(hbVal string) error {
	intervals := strings.Split(hbVal, ",")
	if len(intervals) != 2 {
		return errorMsg(errBrokerStateMachine, "Invalid heartbeat header: "+hbVal)
	}

	// Parse both values up front so that a bad header leaves the session untouched.
	clientSendInterval, err := strconv.Atoi(intervals[0])
	if err != nil || clientSendInterval < 0 {
		return errorMsg(errBrokerStateMachine,
			"Invalid heartbeat header send interval from client: "+hbVal)
	}
	clientRecvInterval, err := strconv.Atoi(intervals[1])
	if err != nil || clientRecvInterval < 0 {
		return errorMsg(errBrokerStateMachine,
			"Invalid heartbeat header receive interval from client: "+hbVal)
	}

	// Client-to-broker direction: what the client sends vs. what the broker accepts.
	switch {
	case clientSendInterval == 0 || sess.hbRecvIntervalMsec == 0:
		sess.hbRecvIntervalMsec = 0
	case clientSendInterval > sess.hbRecvIntervalMsec:
		sess.hbRecvIntervalMsec = clientSendInterval
	}

	// Broker-to-client direction: what the client wants vs. what the broker sends.
	switch {
	case clientRecvInterval == 0 || sess.hbSendIntervalMsec == 0:
		sess.hbSendIntervalMsec = 0
	case clientRecvInterval > sess.hbSendIntervalMsec:
		sess.hbSendIntervalMsec = clientRecvInterval
	}

	// A repeated negotiation must not leave an earlier heartbeat job behind,
	// since cleanup only knows about the most recent one.
	if sess.hbJob != nil {
		sched.RemoveByReference(sess.hbJob)
		sess.hbJob = nil
	}

	// Schedule sending heartbeats by hbSendIntervalMsec
	if sess.hbSendIntervalMsec == 0 { // no heartbeats to be sent
		return nil
	}

	job, err := sched.Every(sess.hbSendIntervalMsec).Milliseconds().Tag(sess.sessionID).Do(
		func() {
			// Best effort: a failed heartbeat write is not reported anywhere.
			_ = sess.sendRaw([]byte("\n"))
		})
	if err != nil {
		return errorMsg(errBrokerStateMachine, "Heartbeat setup error: "+err.Error())
	}
	sess.hbJob = job

	sched.StartAsync()
	return nil
}
// Broker lists the methods supported by the STOMP brokers.
// StartBroker returns one implementation per supported transport.
type Broker interface {
// ListenAndServe is a blocking method that keeps accepting the client connections and handles the STOMP messages.
ListenAndServe()
// Shutdown should be called to bring down the underlying server gracefully.
Shutdown()
}
// BrokerOpts is passed as an argument to StartBroker.
// Fields left at their zero value are replaced with the documented defaults by StartBroker.
type BrokerOpts struct {
// Transport refers to the underlying protocol for STOMP.
// Choices: TransportTCP, TransportWebsocket. Default: TransportTCP
Transport Transport
// Host is the name of the host or IP to bind the server to. Default: localhost
Host string
// Port is the port number for the server to listen on. Default: 61613 (DefaultPort)
Port string
// LoginFunc is a user defined function for authenticating the user. Default: nil
// It is of the form `func(login, passcode string) error`
LoginFunc LoginFunc
// HeartbeatSendIntervalMsec is the interval in milliseconds by which the broker can send heartbeats.
// The broker will negotiate using this value with the client. Default: 0 (no heartbeats)
// It will not send the heartbeats by an interval any smaller than this value.
// StartBroker treats a negative value as 0.
HeartbeatSendIntervalMsec int
// HeartbeatReceiveIntervalMsec is the interval in milliseconds by which the broker can receive heartbeats.
// The broker will negotiate using this value with the client. Default: 0 (no heartbeats)
// This is to tell the client that the broker cannot receive heartbeats by any shorter interval than this value.
// StartBroker treats a negative value as 0.
HeartbeatReceiveIntervalMsec int
}
// StartBroker is the entry point for the STOMP broker.
//
// It fills in the defaults for every option left unset (host "localhost",
// port DefaultPort, transport TransportTCP, negative heartbeat intervals
// clamped to 0) and starts the broker for the selected transport. A nil opts
// is treated as an all-defaults configuration. Note that the defaults are
// written back into the caller's opts value.
//
// It returns the started broker, or an error when the broker cannot be
// started or the transport is not one of the supported choices.
func StartBroker(opts *BrokerOpts) (Broker, error) {
	if opts == nil {
		opts = &BrokerOpts{}
	}

	// Set default values
	if opts.Host == "" {
		opts.Host = "localhost"
	}
	if opts.Port == "" {
		opts.Port = DefaultPort
	}
	if opts.Transport == "" {
		opts.Transport = TransportTCP
	}
	if opts.HeartbeatSendIntervalMsec < 0 {
		opts.HeartbeatSendIntervalMsec = 0
	}
	if opts.HeartbeatReceiveIntervalMsec < 0 {
		opts.HeartbeatReceiveIntervalMsec = 0
	}

	switch opts.Transport {
	case TransportTCP:
		tcp, err := startTcpBroker(opts)
		if err != nil {
			return nil, err
		}
		return tcp, nil

	case TransportWebsocket:
		wss, err := startWebsocketBroker(opts)
		if err != nil {
			return nil, err
		}
		return wss, nil

	default:
		// Never hand back a nil broker together with a nil error.
		return nil, fmt.Errorf("unsupported transport: %v", opts.Transport)
	}
}