Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] return early in websocket upgrade handler #1315

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package api

import (
"time"

"github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api/client/accounts"
"github.com/superseriousbusiness/gotosocial/internal/api/client/admin"
Expand Down Expand Up @@ -122,7 +124,7 @@ func NewClient(db db.DB, p processing.Processor) *Client {
notifications: notifications.New(p),
search: search.New(p),
statuses: statuses.New(p),
streaming: streaming.New(p),
streaming: streaming.New(p, time.Second*30, 4096),
timelines: timelines.New(p),
user: user.New(p),
}
Expand Down
152 changes: 94 additions & 58 deletions internal/api/client/streaming/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
package streaming

import (
"context"
"errors"
"fmt"
"net/http"
"time"

"codeberg.org/gruf/go-kv"
Expand All @@ -32,16 +33,6 @@ import (
"github.com/gorilla/websocket"
)

var (
wsUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// we expect cors requests (via eg., pinafore.social) so be lenient
CheckOrigin: func(r *http.Request) bool { return true },
}
errNoToken = fmt.Errorf("no access token provided under query key %s or under header %s", AccessTokenQueryKey, AccessTokenHeader)
)

// StreamGETHandler swagger:operation GET /api/v1/streaming streamGet
//
// Initiate a websocket connection for live streaming of statuses and notifications.
Expand Down Expand Up @@ -150,21 +141,20 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return
}

var accessToken string
if t := c.Query(AccessTokenQueryKey); t != "" {
// try query param first
accessToken = t
} else if t := c.GetHeader(AccessTokenHeader); t != "" {
// fall back to Sec-Websocket-Protocol
accessToken = t
} else {
// no token
err := errNoToken
apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return
var token string

// First we check for a query param provided access token
if token = c.Query(AccessTokenQueryKey); token == "" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i learn so much from reading your code kimbe, this little bit is just an example of that: it hadn't occurred to me to nest the checks like this, and also didn't think of using const to hint to the compiler like this

// Else we check the HTTP header provided token
if token = c.GetHeader(AccessTokenHeader); token == "" {
const errStr = "no access token provided"
err := gtserror.NewErrorUnauthorized(errors.New(errStr), errStr)
apiutil.ErrorHandler(c, err, m.processor.InstanceGet)
return
}
}

account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), accessToken)
account, errWithCode := m.processor.AuthorizeStreamingRequest(c.Request.Context(), token)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGet)
return
Expand All @@ -178,51 +168,97 @@ func (m *Module) StreamGETHandler(c *gin.Context) {

l := log.WithFields(kv.Fields{
{"account", account.Username},
{"path", BasePath},
{"streamID", stream.ID},
{"streamType", streamType},
}...)

wsConn, err := wsUpgrader.Upgrade(c.Writer, c.Request, nil)
// Upgrade the incoming HTTP request, which hijacks the underlying
// connection and reuses it for the websocket (non-http) protocol.
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil {
// If the upgrade fails, then Upgrade replies to the client with an HTTP error response.
// Because websocket issues are a pretty common source of headaches, we should also log
// this at Error to make this plenty visible and help admins out a bit.
l.Errorf("error upgrading websocket connection: %s", err)
l.Errorf("error upgrading websocket connection: %v", err)
close(stream.Hangup)
return
}

defer func() {
// cleanup
wsConn.Close()
close(stream.Hangup)
}()
go func() {
// We perform the main websocket send loop in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
l.Info("opened websocket connection")
defer l.Info("closed websocket connection")

// Create new context for lifetime of the connection
ctx, cncl := context.WithCancel(context.Background())

// Create ticker to send alive pings
pinger := time.NewTicker(m.dTicker)

defer func() {
// Signal done
cncl()

streamTicker := time.NewTicker(m.tickDuration)
defer streamTicker.Stop()

// We want to stay in the loop as long as possible while the client is connected.
// The only thing that should break the loop is if the client leaves or the connection becomes unhealthy.
//
// If the loop does break, we expect the client to reattempt connection, so it's cheap to leave + try again
wsLoop:
for {
select {
case m := <-stream.Messages:
l.Trace("received message from stream")
if err := wsConn.WriteJSON(m); err != nil {
l.Debugf("error writing json to websocket connection; breaking off: %s", err)
break wsLoop
// Close websocket conn
_ = wsConn.Close()

// Close processor stream
close(stream.Hangup)

// Stop ping ticker
pinger.Stop()
}()

go func() {
// Signal done
defer cncl()

for {
// We have to listen for received websocket messages in
// order to trigger the underlying wsConn.PingHandler().
//
// So we wait on received messages but only act on errors.
_, _, err := wsConn.ReadMessage()
if err != nil {
if ctx.Err() == nil {
// Only log error if the connection was not closed
// by us. Uncanceled context indicates this is the case.
l.Errorf("error reading from websocket: %v", err)
}
return
}
}
l.Trace("wrote message into websocket connection")
case <-streamTicker.C:
l.Trace("received TICK from ticker")
if err := wsConn.WriteMessage(websocket.PingMessage, []byte(": ping")); err != nil {
l.Debugf("error writing ping to websocket connection; breaking off: %s", err)
break wsLoop
}()

for {
select {
// Connection closed
case <-ctx.Done():
return

// Received next stream message
case msg := <-stream.Messages:
l.Tracef("sending message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Errorf("error writing json to websocket: %v", err)
return
}

// Reset on each successful send.
pinger.Reset(m.dTicker)

// Send keep-alive "ping"
case <-pinger.C:
l.Trace("pinging websocket ...")
if err := wsConn.WriteMessage(
websocket.PingMessage,
[]byte{},
); err != nil {
l.Errorf("error writing ping to websocket: %v", err)
return
}
}
l.Trace("wrote ping message into websocket connection")
}
}
}()
}
24 changes: 13 additions & 11 deletions internal/api/client/streaming/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/superseriousbusiness/gotosocial/internal/processing"
)

Expand All @@ -41,21 +42,22 @@ const (
)

type Module struct {
processor processing.Processor
tickDuration time.Duration
processor processing.Processor
dTicker time.Duration
wsUpgrade websocket.Upgrader
}

func New(processor processing.Processor) *Module {
func New(processor processing.Processor, dTicker time.Duration, wsBuf int) *Module {
return &Module{
processor: processor,
tickDuration: 30 * time.Second,
}
}
processor: processor,
dTicker: dTicker,
wsUpgrade: websocket.Upgrader{
ReadBufferSize: wsBuf, // we don't expect reads
WriteBufferSize: wsBuf,

func NewWithTickDuration(processor processing.Processor, tickDuration time.Duration) *Module {
return &Module{
processor: processor,
tickDuration: tickDuration,
// we expect cors requests (via eg., pinafore.social) so be lenient
CheckOrigin: func(r *http.Request) bool { return true },
},
}
}

Expand Down
2 changes: 1 addition & 1 deletion internal/api/client/streaming/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (suite *StreamingTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker)
suite.streamingModule = streaming.NewWithTickDuration(suite.processor, 1)
suite.streamingModule = streaming.New(suite.processor, 1, 4096)
suite.NoError(suite.processor.Start())
}

Expand Down