-
Notifications
You must be signed in to change notification settings - Fork 12
/
gw.go
245 lines (211 loc) · 5.58 KB
/
gw.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
package gw
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"github.com/gorilla/websocket"
)
// ErrorHandler is the callback type used in Settings for handling errors
// raised by a Gateway. When Settings.ErrorHandler is nil, the gateway falls
// back to a default handler that prints the error to standard output.
type ErrorHandler func(error)
// ConnectHandler is the callback type used in Settings for handling the
// initial CONNECT of a nats connection.
//
// It is invoked once per websocket client, after the NATS connection has been
// opened and the server INFO command has been read into NatsConn.ServerInfo.
// It receives that NATS connection, the HTTP request that was upgraded, and
// the resulting websocket connection. Returning a non-nil error aborts the
// set-up of the gateway connection.
type ConnectHandler func(*NatsConn, *http.Request, *websocket.Conn) error
// NatsServerInfo is the information returned by the INFO nats message: the
// text that follows the "INFO " prefix, without the two trailing
// line-terminator bytes.
type NatsServerInfo string
// Settings configures a Gateway
type Settings struct {
// NatsAddr is the address of the NATS server, dialed over TCP for every
// websocket client.
NatsAddr string
// EnableTLS wraps the NATS connection in a TLS client once the server INFO
// command has been read.
EnableTLS bool
// TLSConfig is the TLS configuration used when EnableTLS is set. When nil, a
// configuration with InsecureSkipVerify set to true is used instead.
TLSConfig *tls.Config
// ConnectHandler handles the CONNECT phase of each new NATS connection. When
// nil, the default handler forwards the server INFO command to the websocket
// client.
ConnectHandler ConnectHandler
// ErrorHandler receives the errors met by the gateway. When nil, errors are
// printed to standard output.
ErrorHandler ErrorHandler
// WSUpgrader is the websocket upgrader used by Handler. When nil, a default
// upgrader is used.
WSUpgrader *websocket.Upgrader
// Trace prints the forwarded traffic to standard output, prefixed with
// "[TRACE]".
Trace bool
}
// Gateway is a HTTP handler that acts as a websocket gateway to a NATS server
type Gateway struct {
// settings is the configuration passed to NewGateway.
settings Settings
// onError is the effective error handler: Settings.ErrorHandler, or the
// default handler when that field is nil.
onError ErrorHandler
// handleConnect is the effective connect handler: Settings.ConnectHandler, or
// the default handler when that field is nil.
handleConnect ConnectHandler
}
// defaultUpgrader is the websocket upgrader used by Handler when
// Settings.WSUpgrader is nil. It only sets the read and write buffer sizes
// and leaves every other option at its zero value.
var defaultUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
// NatsConn holds a NATS TCP connection
type NatsConn struct {
// Conn is the connection to the NATS server: the raw TCP connection, or the
// TLS connection wrapping it when Settings.EnableTLS is set.
Conn net.Conn
// CmdReader reads NATS commands from Conn.
CmdReader CommandsReader
// ServerInfo is the payload of the INFO command read right after the
// connection was opened.
ServerInfo NatsServerInfo
}
// defaultConnectHandler is the ConnectHandler used when none is configured.
//
// It does not send any CONNECT itself: it rebuilds the server 'INFO' command
// from natsConn.ServerInfo and forwards it to the websocket client as a text
// message, leaving the CONNECT to the client on the other side. The HTTP
// request is not used. The error of the websocket write, if any, is returned.
func (gw *Gateway) defaultConnectHandler(natsConn *NatsConn, r *http.Request, wsConn *websocket.Conn) error {
	payload := []byte("INFO " + string(natsConn.ServerInfo) + "\r\n")

	if gw.settings.Trace {
		fmt.Println("[TRACE] <--", string(payload))
	}

	return wsConn.WriteMessage(websocket.TextMessage, payload)
}
// defaultErrorHandler is the ErrorHandler used when none is configured: it
// prints the error on standard output, prefixed with "[ERROR]".
func defaultErrorHandler(err error) {
	fmt.Printf("[ERROR] %v\n", err)
}
// copyAndTrace copies src to dst until src is exhausted, printing every chunk
// it forwards on standard output as "[TRACE] <prefix> <chunk>".
//
// buf is the scratch buffer used for reading; when it is empty a 32 KiB buffer
// is allocated, since reading into a zero-length buffer could never make
// progress.
//
// It returns the total number of bytes written to dst. As with io.Copy,
// reaching the end of src is the normal termination and is reported as a nil
// error, so an empty source yields (0, nil). A write error is returned as is,
// and io.ErrShortWrite is returned when dst accepts fewer bytes than it was
// given without reporting an error.
func copyAndTrace(prefix string, dst io.Writer, src io.Reader, buf []byte) (int64, error) {
	if len(buf) == 0 {
		buf = make([]byte, 32*1024)
	}

	var total int64
	for {
		read, readErr := src.Read(buf)

		// A reader may return data together with an error: forward the data
		// before looking at the error.
		if read > 0 {
			fmt.Println("[TRACE]", prefix, string(buf[:read]))

			written, writeErr := dst.Write(buf[:read])
			total += int64(written)
			if writeErr != nil {
				return total, writeErr
			}
			if written != read {
				return total, io.ErrShortWrite
			}
		}

		if readErr == io.EOF {
			return total, nil
		}
		if readErr != nil {
			return total, readErr
		}
	}
}
// NewGateway instantiates a Gateway from the given settings. The error and
// connect handlers found in settings are installed, each one being replaced
// by the gateway's default handler when nil.
func NewGateway(settings Settings) *Gateway {
	gateway := &Gateway{settings: settings}

	gateway.setErrorHandler(settings.ErrorHandler)
	gateway.setConnectHandler(settings.ConnectHandler)

	return gateway
}
// setErrorHandler installs handler as the gateway's error handler, falling
// back to defaultErrorHandler when handler is nil.
func (gw *Gateway) setErrorHandler(handler ErrorHandler) {
	if handler != nil {
		gw.onError = handler
		return
	}
	gw.onError = defaultErrorHandler
}
// setConnectHandler installs handler as the gateway's connect handler,
// falling back to the gateway's own defaultConnectHandler when handler is nil.
func (gw *Gateway) setConnectHandler(handler ConnectHandler) {
	if handler != nil {
		gw.handleConnect = handler
		return
	}
	gw.handleConnect = gw.defaultConnectHandler
}
// natsToWsWorker forwards NATS commands to the websocket client.
//
// It loops reading the next command from src and writing it to ws as one
// websocket message of type messageType, tracing it first when tracing is
// enabled. The first read or write error is handed to the gateway's error
// handler and stops the worker. A value is always sent on doneCh when the
// worker returns.
func (gw *Gateway) natsToWsWorker(messageType int, ws *websocket.Conn, src CommandsReader, doneCh chan<- bool) {
	defer func() {
		doneCh <- true
	}()

	for {
		command, err := src.nextCommand()
		if err == nil {
			if gw.settings.Trace {
				fmt.Println("[TRACE] <--", string(command))
			}
			err = ws.WriteMessage(messageType, command)
		}
		if err != nil {
			gw.onError(err)
			return
		}
	}
}
// wsToNatsWorker forwards the websocket messages to the NATS connection.
//
// It loops taking the reader of the next websocket message and copying it to
// nats, through copyAndTrace when tracing is enabled and through io.Copy
// otherwise. The first error is handed to the gateway's error handler and
// stops the worker. A value is always sent on doneCh when the worker returns.
// messageType is not used by this worker.
func (gw *Gateway) wsToNatsWorker(messageType int, nats net.Conn, ws *websocket.Conn, doneCh chan<- bool) {
	defer func() {
		doneCh <- true
	}()

	// The scratch buffer is only needed by the tracing copy.
	var traceBuf []byte
	if gw.settings.Trace {
		traceBuf = make([]byte, 1024*1024)
	}

	for {
		_, message, err := ws.NextReader()
		if err == nil {
			switch {
			case gw.settings.Trace:
				_, err = copyAndTrace("-->", nats, message, traceBuf)
			default:
				_, err = io.Copy(nats, message)
			}
		}
		if err != nil {
			gw.onError(err)
			return
		}
	}
}
// Handler is a HTTP handler function.
//
// It upgrades the request to a websocket connection (with
// Settings.WSUpgrader, or the default upgrader when nil), opens a dedicated
// NATS connection for it, then relays the traffic in both directions until
// one side fails or closes. Messages sent to the client are text messages
// unless the request carries exactly one "mode" query parameter whose value
// is "binary".
//
// Every error is reported through the gateway's error handler. Both
// connections are closed before the handler returns, including when the NATS
// connection cannot be initialized after a successful upgrade.
func (gw *Gateway) Handler(w http.ResponseWriter, r *http.Request) {
	upgrader := defaultUpgrader
	if gw.settings.WSUpgrader != nil {
		upgrader = *gw.settings.WSUpgrader
	}

	wsConn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		gw.onError(err)
		return
	}

	natsConn, err := gw.initNatsConnectionForWSConn(r, wsConn)
	if err != nil {
		gw.onError(err)
		// The upgrade succeeded, so the websocket connection is ours to
		// release: without this it would stay open with nobody serving it.
		wsConn.Close()
		return
	}

	mode := websocket.TextMessage
	if value, ok := r.URL.Query()["mode"]; ok {
		if len(value) == 1 && value[0] == "binary" {
			mode = websocket.BinaryMessage
		}
	}

	doneCh := make(chan bool)
	go gw.natsToWsWorker(mode, wsConn, natsConn.CmdReader, doneCh)
	go gw.wsToNatsWorker(mode, natsConn.Conn, wsConn, doneCh)

	// As soon as one worker stops, closing both connections makes the other
	// one fail and stop too; wait for it before returning.
	<-doneCh
	wsConn.Close()
	natsConn.Conn.Close()
	<-doneCh
}
// readInfo extracts the server information from a raw NATS 'INFO' command.
//
// cmd is expected to look like "INFO <payload>\r\n". The result is the
// payload: what follows the "INFO " prefix, minus the two trailing
// line-terminator bytes. An error is returned when cmd does not start with
// "INFO " or is too short to hold both the prefix and the terminator; such
// inputs never cause an out-of-range panic.
func readInfo(cmd []byte) (NatsServerInfo, error) {
	const (
		prefix        = "INFO "
		terminatorLen = 2 // trailing "\r\n"
	)

	if len(cmd) < len(prefix)+terminatorLen || !bytes.HasPrefix(cmd, []byte(prefix)) {
		return "", fmt.Errorf("Invalid 'INFO' command: %s", string(cmd))
	}
	return NatsServerInfo(cmd[len(prefix) : len(cmd)-terminatorLen]), nil
}
// initNatsConnectionForWSConn opens a connection to the nats server, consumes
// its INFO message, optionally sets up the TLS layer, and finally hands the
// CONNECT over to the gateway's connect handler.
//
// r and wsConn are the HTTP request and the websocket connection the NATS
// connection is created for; they are only passed on to the connect handler.
//
// On success the returned NatsConn holds the connection to use (the TLS one
// when TLS is enabled), its command reader and the server information. On
// failure the NATS connection is closed and the error is returned, including
// a failed TLS handshake.
func (gw *Gateway) initNatsConnectionForWSConn(r *http.Request, wsConn *websocket.Conn) (*NatsConn, error) {
	conn, err := net.Dial("tcp", gw.settings.NatsAddr)
	if err != nil {
		return nil, err
	}

	natsConn := NatsConn{Conn: conn, CmdReader: NewCommandsReader(conn)}

	// Whatever goes wrong from here on, do not leak the NATS connection.
	// natsConn.Conn is the TLS connection once the handshake has succeeded,
	// and closing it closes the underlying TCP connection as well.
	initialized := false
	defer func() {
		if !initialized {
			natsConn.Conn.Close()
		}
	}()

	// read the INFO, keep it
	infoCmd, err := natsConn.CmdReader.nextCommand()
	if err != nil {
		return nil, err
	}
	info, err := readInfo(infoCmd)
	if err != nil {
		return nil, err
	}
	natsConn.ServerInfo = info

	// optionally initialize the TLS layer
	// TODO check if the server requires TLS, which overrides the 'enableTls' setting
	if gw.settings.EnableTLS {
		tlsConfig := gw.settings.TLSConfig
		if tlsConfig == nil {
			// NOTE(review): this fallback disables the verification of the
			// server certificate. It is kept for backward compatibility;
			// provide Settings.TLSConfig to get a verified connection.
			tlsConfig = &tls.Config{
				InsecureSkipVerify: true,
			}
		}
		tlsConn := tls.Client(conn, tlsConfig)
		if err := tlsConn.Handshake(); err != nil {
			return nil, err
		}
		natsConn.Conn = tlsConn
		natsConn.CmdReader = NewCommandsReader(tlsConn)
	}

	if err := gw.handleConnect(&natsConn, r, wsConn); err != nil {
		return nil, err
	}

	initialized = true
	return &natsConn, nil
}