diff --git a/socketmode/deadman.go b/socketmode/deadman.go deleted file mode 100644 index 7aeea760e..000000000 --- a/socketmode/deadman.go +++ /dev/null @@ -1,31 +0,0 @@ -package socketmode - -import "time" - -type deadmanTimer struct { - timeout time.Duration - timer *time.Timer -} - -func newDeadmanTimer(timeout time.Duration) *deadmanTimer { - return &deadmanTimer{ - timeout: timeout, - timer: time.NewTimer(timeout), - } -} - -func (smc *deadmanTimer) Elapsed() <-chan time.Time { - return smc.timer.C -} - -func (smc *deadmanTimer) Reset() { - // Note that this is the correct way to Reset a non-expired timer - if !smc.timer.Stop() { - select { - case <-smc.timer.C: - default: - } - } - - smc.timer.Reset(smc.timeout) -} diff --git a/socketmode/socket_mode_managed_conn.go b/socketmode/socket_mode_managed_conn.go index 5a98f1ef4..b94456f49 100644 --- a/socketmode/socket_mode_managed_conn.go +++ b/socketmode/socket_mode_managed_conn.go @@ -1,6 +1,7 @@ package socketmode import ( + "bytes" "context" "encoding/json" "errors" @@ -53,13 +54,14 @@ func (smc *Client) RunContext(ctx context.Context) error { } func (smc *Client) run(ctx context.Context, connectionCount int) error { - messages := make(chan json.RawMessage) - defer close(messages) - - deadmanTimer := newDeadmanTimer(smc.maxPingInterval) + messages := make(chan json.RawMessage, 1) + pingChan := make(chan time.Time, 1) pingHandler := func(_ string) error { - deadmanTimer.Reset() + select { + case pingChan <- time.Now(): + default: + } return nil } @@ -81,20 +83,24 @@ func (smc *Client) run(ctx context.Context, connectionCount int) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - smc.Events <- newEvent(EventTypeConnected, &ConnectedEvent{ + smc.sendEvent(ctx, newEvent(EventTypeConnected, &ConnectedEvent{ ConnectionCount: connectionCount, Info: info, - }) + })) smc.Debugf("WebSocket connection succeeded on try %d", connectionCount) // We're now connected so we can set up listeners - var ( - wg sync.WaitGroup - firstErr error - firstErrOnce sync.Once - ) + wg := new(sync.WaitGroup) + // sendErr relies on the buffer of 1 here + errc := make(chan error, 1) + sendErr := func(err error) { + select { + case errc <- err: + default: + } + } wg.Add(1) go func() { @@ -103,9 +109,7 @@ func (smc *Client) run(ctx context.Context, connectionCount int) error { // The response sender sends Socket Mode responses over the WebSocket conn if err := smc.runResponseSender(ctx, conn); err != nil { - firstErrOnce.Do(func() { - firstErr = err - }) + sendErr(err) } }() @@ -116,56 +120,79 @@ func (smc *Client) run(ctx context.Context, connectionCount int) error { // The handler reads Socket Mode requests, and enqueues responses for sending by the response sender if err := smc.runRequestHandler(ctx, messages); err != nil { - firstErrOnce.Do(func() { - firstErr = err - }) + sendErr(err) } }() - // Need to wait for runMessageReceiver to avoid panic described in https://github.com/slack-go/slack/issues/1125 - wg.Add(1) go func() { - defer wg.Done() defer cancel() + // We close messages here as it is the producer for the channel. + defer close(messages) // The receiver reads WebSocket messages, and enqueues parsed Socket Mode requests to be handled by // the request handler if err := smc.runMessageReceiver(ctx, conn, messages); err != nil { - firstErrOnce.Do(func() { - firstErr = err - }) + sendErr(err) } }() wg.Add(1) - go func() { + go func(pingInterval time.Duration) { defer wg.Done() - - select { - case <-ctx.Done(): + defer func() { // Detect when the connection is dead and try close connection. - if err = conn.Close(); err != nil { + if err := conn.Close(); err != nil { smc.Debugf("Failed to close connection: %v", err) } - case <-deadmanTimer.Elapsed(): - firstErrOnce.Do(func() { - firstErr = errors.New("ping timeout: Slack did not send us WebSocket PING for more than Client.maxInterval") - }) + }() + + done := ctx.Done() + var lastPing time.Time + + // More efficient than constantly resetting a timer w/ Stop+Reset + ticker := time.NewTicker(pingInterval) + defer ticker.Stop() - cancel() + for { + select { + case <-done: + return + + case lastPing = <-pingChan: + // This case gets the time of the last ping. + // If this case never fires then the pingHandler was never called + // in which case lastPing is the zero time.Time value, and will 'fail' + // the next tick, causing us to exit. + + case now := <-ticker.C: + // Our last ping is older than our interval + if now.Sub(lastPing) > pingInterval { + sendErr(errors.New("ping timeout: Slack did not send us WebSocket PING for more than Client.maxInterval")) + + cancel() + return + } + } } - }() + }(smc.maxPingInterval) wg.Wait() - if firstErr == context.Canceled { - return firstErr + select { + case err = <-errc: + // Get buffered error + default: + // Or nothing if they all exited nil + } + + if errors.Is(err, context.Canceled) { + return err } // wg.Wait() finishes only after any of the above go routines finishes and cancels the // context, allowing the other threads to shut down gracefully. - // Also, we can expect firstErr to be not nil, as goroutines can finish only on error. - smc.Debugf("Reconnecting due to %v", firstErr) + // Also, we can expect our (first)err to be not nil, as goroutines can finish only on error. + smc.Debugf("Reconnecting due to %v", err) return nil } @@ -193,10 +220,10 @@ func (smc *Client) connect(ctx context.Context, connectionCount int, additionalP ) // send connecting event - smc.Events <- newEvent(EventTypeConnecting, &slack.ConnectingEvent{ + smc.sendEvent(ctx, newEvent(EventTypeConnecting, &slack.ConnectingEvent{ Attempt: boff.Attempts() + 1, ConnectionCount: connectionCount, - }) + })) // attempt to start the connection info, conn, err := smc.openAndDial(ctx, additionalPingHandler) @@ -212,26 +239,32 @@ func (smc *Client) connect(ctx context.Context, connectionCount int, additionalP default: } - switch actual := err.(type) { - case slack.StatusCodeError: - if actual.Code == http.StatusNotFound { - smc.Debugf("invalid auth when connecting with Socket Mode: %s", err) - smc.Events <- newEvent(EventTypeInvalidAuth, &slack.InvalidAuthEvent{}) - return nil, nil, err - } - case *slack.RateLimitedError: - backoff = actual.RetryAfter - default: + var ( + actual slack.StatusCodeError + rlError *slack.RateLimitedError + ) + + if errors.As(err, &actual) && actual.Code == http.StatusNotFound { + smc.Debugf("invalid auth when connecting with Socket Mode: %s", err) + smc.sendEvent(ctx, newEvent(EventTypeInvalidAuth, &slack.InvalidAuthEvent{})) + + return nil, nil, err + } else if errors.As(err, &rlError) { + backoff = rlError.RetryAfter } + // If we check for errors.Is(err, context.Canceled) here and + // return early then we don't send the Event below that some users + // may already rely on; ie a behavior change. + backoff = timex.Max(backoff, boff.Duration()) // any other errors are treated as recoverable and we try again after // sending the event along the Events channel - smc.Events <- newEvent(EventTypeConnectionError, &slack.ConnectionErrorEvent{ + smc.sendEvent(ctx, newEvent(EventTypeConnectionError, &slack.ConnectionErrorEvent{ Attempt: boff.Attempts(), Backoff: backoff, ErrorObj: err, - }) + })) // get time we should wait before attempting to connect again smc.Debugf("reconnection %d failed: %s reconnecting in %v\n", boff.Attempts(), err, backoff) @@ -239,9 +272,11 @@ func (smc *Client) connect(ctx context.Context, connectionCount int, additionalP // wait for one of the following to occur, // backoff duration has elapsed, disconnectCh is signalled, or // the smc finishes disconnecting. + timer := time.NewTimer(backoff) select { - case <-time.After(backoff): // retry after the backoff. + case <-timer.C: // retry after the backoff. case <-ctx.Done(): + timer.Stop() return nil, nil, ctx.Err() } } @@ -276,12 +311,13 @@ func (smc *Client) openAndDial(ctx context.Context, additionalPingHandler func(s smc.Debugf("Failed to dial to the websocket: %s", err) return nil, nil, err } + if additionalPingHandler == nil { + additionalPingHandler = func(_ string) error { return nil } + } conn.SetPingHandler(func(appData string) error { - if additionalPingHandler != nil { - if err := additionalPingHandler(appData); err != nil { - return err - } + if err := additionalPingHandler(appData); err != nil { + return err } smc.handlePing(conn, appData) @@ -312,10 +348,10 @@ func (smc *Client) runResponseSender(ctx context.Context, conn *websocket.Conn) smc.Debugf("Sending Socket Mode response with envelope ID %q: %v", res.EnvelopeID, res) if err := unsafeWriteSocketModeResponse(conn, res); err != nil { - smc.Events <- newEvent(EventTypeErrorWriteFailed, &ErrorWriteFailed{ + smc.sendEvent(ctx, newEvent(EventTypeErrorWriteFailed, &ErrorWriteFailed{ Cause: err, Response: res, - }) + })) } smc.Debugf("Finished sending Socket Mode response with envelope ID %q", res.EnvelopeID) @@ -332,16 +368,22 @@ func (smc *Client) runRequestHandler(ctx context.Context, websocket chan json.Ra select { case <-ctx.Done(): return ctx.Err() - case message := <-websocket: + case message, ok := <-websocket: + if !ok { + // The producer closed the channel because it encountered an error (or panic), + // we need only return. + return nil + } + smc.Debugf("Received WebSocket message: %s", message) // listen for incoming messages that need to be parsed evt, err := smc.parseEvent(message) if err != nil { - smc.Events <- newEvent(EventTypeErrorBadMessage, &ErrorBadMessage{ + smc.sendEvent(ctx, newEvent(EventTypeErrorBadMessage, &ErrorBadMessage{ Cause: err, Message: message, - }) + })) } else if evt != nil { if evt.Type == EventTypeDisconnect { // We treat the `disconnect` request from Slack as an error internally, @@ -349,7 +391,7 @@ func (smc *Client) runRequestHandler(ctx context.Context, websocket chan json.Ra return errorRequestedDisconnect{} } - smc.Events <- *evt + smc.sendEvent(ctx, *evt) } } } @@ -385,11 +427,7 @@ func unsafeWriteSocketModeResponse(conn *websocket.Conn, res *Response) error { // Remove write deadline regardless of WriteJSON succeeds or not defer conn.SetWriteDeadline(time.Time{}) - if err := conn.WriteJSON(res); err != nil { - return err - } - - return nil + return conn.WriteJSON(res) } func newEvent(tpe EventType, data interface{}, req ...*Request) Event { @@ -407,29 +445,54 @@ func newEvent(tpe EventType, data interface{}, req ...*Request) Event { // This tells Slack that the we have received the request denoted by the envelope ID, // by sending back the envelope ID over the WebSocket connection. func (smc *Client) Ack(req Request, payload ...interface{}) { - res := Response{ - EnvelopeID: req.EnvelopeID, - } - + var pld interface{} if len(payload) > 0 { - res.Payload = payload[0] + pld = payload[0] } - smc.Send(res) + smc.AckCtx(context.TODO(), req.EnvelopeID, pld) +} + +// AckCtx acknowledges the Socket Mode request envelope ID with the payload. +// +// This tells Slack that the we have received the request denoted by the request (envelope) ID, +// by sending back the ID over the WebSocket connection. +func (smc *Client) AckCtx(ctx context.Context, reqID string, payload interface{}) error { + return smc.SendCtx(ctx, Response{ + EnvelopeID: reqID, + Payload: payload, + }) } // Send sends the Socket Mode response over a WebSocket connection. // This is usually used for acknowledging requests, but if you need more control over Client.Ack(). // It's normally recommended to use Client.Ack() instead of this. func (smc *Client) Send(res Response) { - js, err := json.Marshal(res) - if err != nil { - panic(err) + smc.SendCtx(context.TODO(), res) +} + +// SendCtx sends the Socket Mode response over a WebSocket connection. +// This is usually used for acknowledging requests, but if you need more control +// it's normally recommended to use Client.AckCtx() instead of this. +func (smc *Client) SendCtx(ctx context.Context, res Response) error { + if smc.debug { + js, err := json.Marshal(res) + + // Log the error so users of `Send` don't see it entirely disappear as that method + // does not return an error and used to panic on failure (with or without debug) + smc.Debugf("Scheduling Socket Mode response (error: %v) for envelope ID %s: %s", err, res.EnvelopeID, js) + if err != nil { + return err + } } - smc.Debugf("Scheduling Socket Mode response for envelope ID %s: %s", res.EnvelopeID, js) + select { + case <-ctx.Done(): + return ctx.Err() + case smc.socketModeResponses <- &res: + } - smc.socketModeResponses <- &res + return nil } // receiveMessagesInto attempts to receive an event from the WebSocket connection for Socket Mode. @@ -439,75 +502,58 @@ func (smc *Client) receiveMessagesInto(ctx context.Context, conn *websocket.Conn smc.Debugf("Starting to receive message") defer smc.Debugf("Finished to receive message") - var err error - event := json.RawMessage{} - - readJsonErr := make(chan error) - go func() { - select { - case readJsonErr <- conn.ReadJSON(&event): - // if conn.ReadJSON method returns after ctx.Done(), no one will read the unbuffered channel - // so, to avoid goroutines leak we need to check for ctx.Done() here too. - // need to say here, that conn.ReadJSON will really return after ctx.Done(), because in that case - // conn is closed - break - case <-ctx.Done(): - // just need to listen ctx.Done, nothing has to be done here - break - } - }() - - select { - case err = <-readJsonErr: - // we have awaited response from conn.ReadJSON, so, we can handle error now - break - case <-ctx.Done(): - // context cancellation signal got, closing connection and returning - cerr := conn.Close() - if cerr != nil { - smc.Debugf("Failed to close connection: %v", cerr) + err := conn.ReadJSON(&event) + if err != nil { + // check if the connection was closed. + // This version of the gorilla/websocket package also does a type assertion + // on the error, rather than unwrapping it, so we'll do the unwrapping then pass + // the unwrapped error + var wsErr *websocket.CloseError + if errors.As(err, &wsErr) && websocket.IsUnexpectedCloseError(wsErr) { + return err } - return ctx.Err() - } - - // after select above, the err will be set to result from conn.ReadJSON, handling - - // check if the connection was closed. - if websocket.IsUnexpectedCloseError(err) { - return err - } + if errors.Is(err, io.ErrUnexpectedEOF) { + // EOF's don't seem to signify a failed connection so instead we ignore + // them here and detect a failed connection upon attempting to send a + // 'PING' message - switch { - case err == io.ErrUnexpectedEOF: - // EOF's don't seem to signify a failed connection so instead we ignore - // them here and detect a failed connection upon attempting to send a - // 'PING' message + // Unlike RTM, we don't ping from the our end as there seem to have no client ping. + // We just continue to the next loop so that we `smc.disconnected` should be received if + // this EOF error was actually due to disconnection. - // Unlike RTM, we don't ping from the our end as there seem to have no client ping. - // We just continue to the next loop so that we `smc.disconnected` should be received if - // this EOF error was actually due to disconnection. + return nil + } - return nil - case err != nil: // All other errors from ReadJSON come from NextReader, and should // kill the read loop and force a reconnect. - smc.Events <- newEvent(EventTypeIncomingError, &slack.IncomingEventError{ + // TODO: Unless it's a JSON unmarshal-type error in which case maybe reconnecting isn't needed... + smc.sendEvent(ctx, newEvent(EventTypeIncomingError, &slack.IncomingEventError{ ErrorObj: err, - }) + })) return err - case len(event) == 0: - smc.Debugln("Received empty event") - default: - select { - case sink <- event: - case <-ctx.Done(): - smc.Debugln("cancelled while attempting to send raw event") + } - return ctx.Err() + if smc.debug { + buf := &bytes.Buffer{} + d := json.NewEncoder(buf) + d.SetIndent("", " ") + if err := d.Encode(event); err != nil { + smc.Debugln("Failed encoding decoded json:", err) } + reencoded := buf.String() + + smc.Debugln("Incoming WebSocket message:", reencoded) + } + + select { + case sink <- event: + case <-ctx.Done(): + smc.Debugln("cancelled while attempting to send raw event") + + return ctx.Err() } return nil @@ -520,7 +566,7 @@ func (smc *Client) parseEvent(wsMsg json.RawMessage) (*Event, error) { req := &Request{} err := json.Unmarshal(wsMsg, req) if err != nil { - return nil, fmt.Errorf("unmarshalling WebSocket message: %v", err) + return nil, fmt.Errorf("unmarshalling WebSocket message: %w", err) } var evt Event @@ -536,7 +582,7 @@ func (smc *Client) parseEvent(wsMsg json.RawMessage) (*Event, error) { eventsAPIEvent, err := slackevents.ParseEvent(payloadEvent, slackevents.OptionNoVerifyToken()) if err != nil { - return nil, fmt.Errorf("parsing Events API event: %v", err) + return nil, fmt.Errorf("parsing Events API event: %w", err) } evt = newEvent(EventTypeEventsAPI, eventsAPIEvent, req) @@ -549,7 +595,7 @@ func (smc *Client) parseEvent(wsMsg json.RawMessage) (*Event, error) { var cmd slack.SlashCommand if err := json.Unmarshal(req.Payload, &cmd); err != nil { - return nil, fmt.Errorf("parsing slash command: %v", err) + return nil, fmt.Errorf("parsing slash command: %w", err) } evt = newEvent(EventTypeSlashCommand, cmd, req) @@ -563,7 +609,7 @@ func (smc *Client) parseEvent(wsMsg json.RawMessage) (*Event, error) { var callback slack.InteractionCallback if err := json.Unmarshal(req.Payload, &callback); err != nil { - return nil, fmt.Errorf("parsing interaction callback: %v", err) + return nil, fmt.Errorf("parsing interaction callback: %w", err) } evt = newEvent(EventTypeInteractive, callback, req) diff --git a/socketmode/socketmode.go b/socketmode/socketmode.go index 1871e6763..6ca8f487c 100644 --- a/socketmode/socketmode.go +++ b/socketmode/socketmode.go @@ -119,3 +119,15 @@ func New(api *slack.Client, options ...Option) *Client { return result } + +// sendEvent safely sends an event into the Clients Events channel +// and blocks until buffer space is had, or the context is canceled. +// This prevents deadlocking in the event that Events buffer is full, +// other goroutines are waiting, and/or timing allows receivers to exit +// before all senders are finished. +func (smc *Client) sendEvent(ctx context.Context, event Event) { + select { + case smc.Events <- event: + case <-ctx.Done(): + } +}