/
event_websocket_server.go
110 lines (98 loc) · 2.68 KB
/
event_websocket_server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package cltest
import (
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"github.com/gorilla/websocket"
)
// EventWebSocketServer is a test helper that wraps an httptest.Server whose
// handler upgrades incoming requests to websocket connections, keeps track of
// the open connections, and exposes channels that tests can observe.
type EventWebSocketServer struct {
// Embedded HTTP test server; its methods (e.g. Close) are promoted.
*httptest.Server
// mutex guards the connections slice.
mutex *sync.RWMutex // shared mutex for safe access to arrays/maps.
// t is the test used to report errors raised by the server.
t *testing.T
// connections holds the websocket connections that are currently tracked:
// one is appended after each successful upgrade and removed again when the
// peer closes it.
connections []*websocket.Conn
// Connected gets a (non-blocking) signal whenever a connection is added.
Connected chan struct{}
// Received carries the payload of every message read from a client, as a
// string. Sends are non-blocking, so messages are dropped when it is full.
Received chan string
// URL is the server address with the scheme set to "ws". Note that this
// field shadows the string URL field of the embedded httptest.Server.
URL *url.URL
}
// NewEventWebSocketServer starts an httptest.Server that serves websocket
// connections through EventWebSocketServer.handler and returns it together
// with a cleanup function that closes the server.
//
// The returned server's URL field holds the listener address with the "ws"
// scheme. The test is failed immediately if that address cannot be parsed.
func NewEventWebSocketServer(t *testing.T) (*EventWebSocketServer, func()) {
	wss := new(EventWebSocketServer)
	wss.mutex = new(sync.RWMutex)
	wss.t = t
	// Buffer of one so a test can assert on the event after it happened.
	wss.Connected = make(chan struct{}, 1)
	wss.Received = make(chan string, 100)

	// The handler is a method value, so the struct must exist before the
	// underlying HTTP server is created.
	wss.Server = httptest.NewServer(http.HandlerFunc(wss.handler))

	wsURL, err := url.Parse(wss.Server.URL)
	if err != nil {
		t.Fatal("EventWebSocketServer: ", err)
	}
	wsURL.Scheme = "ws"
	wss.URL = wsURL

	cleanup := func() {
		wss.Close()
	}
	return wss, cleanup
}
// upgrader turns incoming HTTP requests into websocket connections.
var upgrader = websocket.Upgrader{
	// Every origin is accepted; the request itself is never inspected.
	CheckOrigin: func(*http.Request) bool { return true },
}
// handler upgrades the request to a websocket connection, registers it, and
// then reads messages until the connection ends. Every payload read is
// forwarded to wss.Received as a string; the send is non-blocking, so a
// message is dropped when the channel buffer is full.
//
// The connection is always unregistered and closed when the handler returns,
// whatever the reason. A normal or abnormal closure by the peer ends the
// handler silently; any other upgrade or read error is reported on the test.
func (wss *EventWebSocketServer) handler(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// This runs on the HTTP server's goroutine, not the test goroutine,
		// so Fatal/FailNow must not be used here: report and bail out instead.
		wss.t.Error("EventWebSocketServer Upgrade: ", err)
		return
	}
	// The connection was hijacked from net/http, so nobody else closes it.
	// Deferred first so that it runs after the connection was unregistered.
	defer conn.Close()

	wss.addConnection(conn)
	defer wss.removeConnection(conn)

	for {
		_, payload, err := conn.ReadMessage() // we only read
		if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseAbnormalClosure) {
			return
		}
		if err != nil {
			wss.t.Error("EventWebSocketServer ReadMessage: ", err)
			return
		}

		select {
		case wss.Received <- string(payload):
		default: // buffer full: drop the message rather than block the reader
		}
	}
}
// addConnection records conn in the list of tracked connections and then
// signals wss.Connected. The signal is sent without blocking, so it is
// skipped when the channel buffer is already full.
func (wss *EventWebSocketServer) addConnection(conn *websocket.Conn) {
	// Keep the critical section limited to the slice update.
	func() {
		wss.mutex.Lock()
		defer wss.mutex.Unlock()
		wss.connections = append(wss.connections, conn)
	}()

	// Broadcast the connected event.
	select {
	case wss.Connected <- struct{}{}:
	default:
	}
}
// removeConnection drops every occurrence of conn from the list of tracked
// connections. The list is rebuilt into a fresh slice rather than edited in
// place.
func (wss *EventWebSocketServer) removeConnection(conn *websocket.Conn) {
	wss.mutex.Lock()
	defer wss.mutex.Unlock()

	remaining := make([]*websocket.Conn, 0, len(wss.connections))
	for _, tracked := range wss.connections {
		if tracked == conn {
			continue
		}
		remaining = append(remaining, tracked)
	}
	wss.connections = remaining
}
// WriteCloseMessage sends a normal-closure close frame to every tracked
// connection, asking the connected clients to disconnect.
//
// This makes it possible to emulate the websocket server shutting down while
// the underlying httptest.Server keeps running, which works around
// httptest.Server not being restartable on the same URL:port.
// A failed write is reported as a test error and does not stop the loop.
func (wss *EventWebSocketServer) WriteCloseMessage() {
	// The close frame is the same for every connection, so build it once.
	closeFrame := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")

	wss.mutex.RLock()
	defer wss.mutex.RUnlock()

	for _, conn := range wss.connections {
		if err := conn.WriteMessage(websocket.CloseMessage, closeFrame); err != nil {
			wss.t.Error(err)
		}
	}
}