From c840b87db13712eeddac3fa1617128651a2ca87c Mon Sep 17 00:00:00 2001 From: Stefan Kaes Date: Wed, 9 Aug 2023 14:38:37 +0200 Subject: [PATCH] set read and write deadlines web sockets so nothing blocks forever --- go/client.go | 5 ++++- go/notification_mailer.go | 2 ++ go/server_state.go | 18 +++++++++++++----- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/go/client.go b/go/client.go index 62239a0..16db12a 100644 --- a/go/client.go +++ b/go/client.go @@ -88,6 +88,7 @@ func (s *ClientState) Connect() (err error) { // Close sends a Close message to the server and closed the connection. func (s *ClientState) Close() { defer s.ws.Close() + s.ws.SetWriteDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) err := s.ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logError("writing websocket close failed: %s", err) @@ -102,6 +103,7 @@ func (s *ClientState) send(msg MsgBody) error { return err } logDebug("sending message") + s.ws.SetWriteDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) err = s.ws.WriteMessage(websocket.TextMessage, b) if err != nil { logError("could not send message: %s", err) @@ -253,6 +255,7 @@ func (s *ClientState) Reader() { default: } logDebug("reading message") + s.ws.SetReadDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) msgType, bytes, err := s.ws.ReadMessage() atomic.AddInt64(&processed, 1) if err != nil || msgType != websocket.TextMessage { @@ -273,7 +276,7 @@ func (s *ClientState) Reader() { // Writer reads messages from an internal channel and dispatches them. It // periodically sends a HEARTBEAT message to the server. It if receives a config // change message, it replaces the current config with the new one. If the -// config change implies that the server URL has changed it exits, relying on +// config change implies that the server URL has changed, it exits, relying on // the outer loop to restart the client. func (s *ClientState) Writer() { ticker := time.NewTicker(1 * time.Second) diff --git a/go/notification_mailer.go b/go/notification_mailer.go index 763b10c..02d18ec 100644 --- a/go/notification_mailer.go +++ b/go/notification_mailer.go @@ -44,6 +44,7 @@ func (s *MailerState) Connect() (err error) { // Close sends a Close message on the websocket and closes it. func (s *MailerState) Close() { defer s.ws.Close() + s.ws.SetWriteDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) err := s.ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { logError("writing websocket close failed: %s", err) @@ -89,6 +90,7 @@ func (s *MailerState) SendMail(text string) { func (s *MailerState) Reader() { for !interrupted { logDebug("reading message") + s.ws.SetReadDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) msgType, bytes, err := s.ws.ReadMessage() if err != nil || msgType != websocket.TextMessage { logError("error reading from server socket: %s", err) diff --git a/go/server_state.go b/go/server_state.go index 68cc2aa..a0752ea 100644 --- a/go/server_state.go +++ b/go/server_state.go @@ -550,6 +550,7 @@ func (s *ServerState) notificationReader(ws *websocket.Conn) { s.wsChannel <- &WsMsg{body: MsgBody{Name: START_NOTIFY}, channel: dispatcherInput} go s.notificationWriter(ws, dispatcherInput) for !interrupted { + ws.SetReadDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) msgType, bytes, err := ws.ReadMessage() if err != nil || msgType != websocket.TextMessage { logError("notificationReader: could not read msg: %s", err) @@ -570,6 +571,7 @@ func (s *ServerState) notificationWriter(ws *websocket.Conn, inputFromDispatcher logInfo("Terminating notification websocket writer") return } + ws.SetWriteDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) ws.WriteMessage(websocket.TextMessage, []byte(data)) case <-time.After(100 * time.Millisecond): // give the outer loop a chance to detect interrupts (without doing a busy wait) @@ -694,6 +696,7 @@ func (s *ServerState) wsReader(ws *websocket.Conn) { var body MsgBody for !interrupted { + ws.SetReadDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) msgType, bytes, err := ws.ReadMessage() atomic.AddInt64(&processed, 1) if err != nil || msgType != websocket.TextMessage { @@ -720,7 +723,10 @@ func (s *ServerState) wsReader(ws *websocket.Conn) { func (s *ServerState) wsWriter(clientID string, ws *websocket.Conn, inputFromDispatcher chan string) { s.waitGroup.Add(1) defer s.waitGroup.Done() - defer ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(1000, "good bye")) + defer func() { + ws.SetWriteDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) + ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(1000, "good bye")) + }() for !interrupted { select { case data, ok := <-inputFromDispatcher: @@ -728,6 +734,7 @@ func (s *ServerState) wsWriter(clientID string, ws *websocket.Conn, inputFromDis logInfo("Closed channel for %s", clientID) return } + ws.SetWriteDeadline(time.Now().Add(websocket.DefaultDialer.HandshakeTimeout)) ws.WriteMessage(websocket.TextMessage, []byte(data)) case <-time.After(100 * time.Millisecond): // give the outer loop a chance to detect interrupts @@ -738,11 +745,12 @@ func (s *ServerState) wsWriter(clientID string, ws *websocket.Conn, inputFromDis // Initialize completes the state initialization by checking redis connectivity // and loading saved state. func (s *ServerState) Initialize() { - path := s.GetConfig().RedisMasterFile - VerifyMasterFileString(path) + config := s.GetConfig() + websocket.DefaultDialer.HandshakeTimeout = time.Duration(config.DialTimeout) * time.Second + VerifyMasterFileString(config.RedisMasterFile) var masters map[string]string - if MasterFileExists(path) { - masters = RedisMastersFromMasterFile(path) + if MasterFileExists(config.RedisMasterFile) { + masters = RedisMastersFromMasterFile(config.RedisMasterFile) } else if s.opts.ConsulClient != nil { kv, err := s.opts.ConsulClient.GetState() if err != nil {