Skip to content

Commit

Permalink
Add Context to clients / sessions.
Browse files Browse the repository at this point in the history
The Context will be closed when the client disconnects / the session is removed,
so any pending requests can be cancelled.
  • Loading branch information
fancycode committed May 14, 2024
1 parent 94a8f0f commit ad54f75
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 75 deletions.
10 changes: 9 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ package signaling

import (
"bytes"
"context"
"encoding/json"
"log"
"strconv"
Expand Down Expand Up @@ -93,6 +94,7 @@ type WritableClientMessage interface {
}

type HandlerClient interface {
Context() context.Context
RemoteAddr() string
Country() string
UserAgent() string
Expand Down Expand Up @@ -121,6 +123,7 @@ type ClientGeoIpHandler interface {
}

type Client struct {
ctx context.Context
conn *websocket.Conn
addr string
agent string
Expand All @@ -142,7 +145,7 @@ type Client struct {
messageChan chan *bytes.Buffer
}

func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
remoteAddress = strings.TrimSpace(remoteAddress)
if remoteAddress == "" {
remoteAddress = "unknown remote address"
Expand All @@ -153,6 +156,7 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler
}

client := &Client{
ctx: ctx,
agent: agent,
logRTT: true,
}
Expand Down Expand Up @@ -181,6 +185,10 @@ func (c *Client) getHandler() ClientHandler {
return c.handler
}

func (c *Client) Context() context.Context {
return c.ctx
}

func (c *Client) IsConnected() bool {
return c.closed.Load() == 0
}
Expand Down
22 changes: 16 additions & 6 deletions clientsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ type ClientSession struct {
privateId string
publicId string
data *SessionIdData
ctx context.Context
closeFunc context.CancelFunc

clientType string
features []string
Expand Down Expand Up @@ -91,12 +93,15 @@ type ClientSession struct {
}

func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) {
ctx, closeFunc := context.WithCancel(context.Background())
s := &ClientSession{
hub: hub,
events: hub.events,
privateId: privateId,
publicId: publicId,
data: data,
ctx: ctx,
closeFunc: closeFunc,

clientType: hello.Auth.Type,
features: hello.Features,
Expand Down Expand Up @@ -140,6 +145,10 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session
return s, nil
}

func (s *ClientSession) Context() context.Context {
return s.ctx
}

func (s *ClientSession) PrivateId() string {
return s.privateId
}
Expand Down Expand Up @@ -337,7 +346,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time {
func (s *ClientSession) releaseMcuObjects() {
if len(s.publishers) > 0 {
go func(publishers map[StreamType]McuPublisher) {
ctx := context.TODO()
ctx := context.Background()
for _, publisher := range publishers {
publisher.Close(ctx)
}
Expand All @@ -346,7 +355,7 @@ func (s *ClientSession) releaseMcuObjects() {
}
if len(s.subscribers) > 0 {
go func(subscribers map[string]McuSubscriber) {
ctx := context.TODO()
ctx := context.Background()
for _, subscriber := range subscribers {
subscriber.Close(ctx)
}
Expand All @@ -360,6 +369,7 @@ func (s *ClientSession) Close() {
}

func (s *ClientSession) closeAndWait(wait bool) {
s.closeFunc()
s.hub.removeSession(s)

s.mu.Lock()
Expand Down Expand Up @@ -885,7 +895,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
if prev, found := s.publishers[streamType]; found {
// Another thread created the publisher while we were waiting.
go func(pub McuPublisher) {
closeCtx := context.TODO()
closeCtx := context.Background()
pub.Close(closeCtx)
}(publisher)
publisher = prev
Expand Down Expand Up @@ -962,7 +972,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
if prev, found := s.subscribers[getStreamId(id, streamType)]; found {
// Another thread created the subscriber while we were waiting.
go func(sub McuSubscriber) {
closeCtx := context.TODO()
closeCtx := context.Background()
sub.Close(closeCtx)
}(subscriber)
subscriber = prev
Expand Down Expand Up @@ -1036,7 +1046,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
case "sendoffer":
// Process asynchronously to not block other messages received.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout)
ctx, cancel := context.WithTimeout(s.Context(), s.hub.mcuTimeout)
defer cancel()

mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, StreamType(message.SendOffer.Data.RoomType))
Expand Down Expand Up @@ -1068,7 +1078,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
return
}

mc.SendMessage(context.TODO(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
mc.SendMessage(s.Context(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", message.SendOffer.Data, message.SendOffer.SessionId, s.PublicId(), err)
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{
Expand Down
2 changes: 1 addition & 1 deletion grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (p *SessionProxy) recvPump() {
for {
msg, err := p.client.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled {
break
}

Expand Down
4 changes: 4 additions & 0 deletions grpc_remote_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ func (c *remoteGrpcClient) readPump() {
}
}

func (c *remoteGrpcClient) Context() context.Context {
return c.client.Context()
}

func (c *remoteGrpcClient) RemoteAddr() string {
return c.remoteAddr
}
Expand Down

0 comments on commit ad54f75

Please sign in to comment.