Skip to content

Commit

Permalink
Merge pull request #219 from peer-calls/jeremija/issue-217
Browse files Browse the repository at this point in the history
Close transport when clientID already exists
  • Loading branch information
jeremija committed May 17, 2021
2 parents 6ea4433 + 7efc849 commit 90a37a7
Show file tree
Hide file tree
Showing 19 changed files with 443 additions and 570 deletions.
3 changes: 2 additions & 1 deletion server/cli/play.go
Expand Up @@ -232,7 +232,7 @@ func (h *playHandler) handleMessages(ctx context.Context, wsClient *server.Clien
signaller *server.Signaller
}

messagesChan := wsClient.Subscribe(ctx)
messagesChan := wsClient.Messages()

err := wsClient.Write(message.NewReady(h.roomID, message.Ready{
Nickname: h.args.nickname,
Expand All @@ -252,6 +252,7 @@ func (h *playHandler) handleMessages(ctx context.Context, wsClient *server.Clien
for {
select {
case <-ctx.Done():
wsClient.Close(websocket.StatusNormalClosure, "")
return
case msg, ok := <-messagesChan:
if !ok {
Expand Down
4 changes: 2 additions & 2 deletions server/logformatter/log_formatter.go
Expand Up @@ -70,15 +70,15 @@ func (f *LogFormatter) Format(message logger.Message) ([]byte, error) {
message.Level,
namespace,
clientID,
message.Body,
strings.TrimRight(message.Body, "\n"),
b.String(),
)
} else {
ret = fmt.Sprintf("%s %5s [%20s] %s%s\n",
message.Timestamp.Format(timeLayout),
message.Level,
namespace,
message.Body,
strings.TrimRight(message.Body, "\n"),
b.String(),
)
}
Expand Down
2 changes: 1 addition & 1 deletion server/logger/formatter.go
Expand Up @@ -78,7 +78,7 @@ func (f *StringFormatter) Format(message Message) ([]byte, error) {
message.Timestamp.Format(f.params.DateLayout),
message.Level,
message.Namespace,
message.Body,
strings.TrimRight(message.Body, "\n"),
b.String(),
)

Expand Down
18 changes: 15 additions & 3 deletions server/memoryadapter.go
Expand Up @@ -24,18 +24,30 @@ func NewMemoryAdapter(room identifiers.RoomID) *MemoryAdapter {
}
}

// Add a client to the room
// Add a client to the room. Will return an error on duplicate client ID.
func (m *MemoryAdapter) Add(client ClientWriter) (err error) {
m.clientsMu.Lock()

clientID := client.ID()
m.clients[clientID] = client

if _, ok := m.clients[clientID]; ok {
err = errors.Annotatef(ErrDuplicateClientID, "%s", clientID)
} else {
m.clients[clientID] = client
}

m.clientsMu.Unlock()

if err != nil {
return errors.Trace(err)
}

err = m.broadcast(
message.NewRoomJoin(m.room, message.RoomJoin{
ClientID: clientID,
Metadata: client.Metadata(),
}),
)
m.clientsMu.Unlock()
return errors.Annotatef(err, "add client: %s", clientID)
}

Expand Down
35 changes: 26 additions & 9 deletions server/memoryadapter_test.go
Expand Up @@ -11,28 +11,38 @@ import (
"github.com/peer-calls/peer-calls/server/message"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"nhooyr.io/websocket"
)

func TestMemoryAdapter_add_remove_clients(t *testing.T) {
goleak.VerifyNone(t)
adapter := server.NewMemoryAdapter(room)
mockWriter := NewMockWriter()
client := server.NewClient(mockWriter)

defer client.Close(websocket.StatusNormalClosure, "")

client.SetMetadata("a")
clientID := client.ID()

err := adapter.Add(client)
assert.Nil(t, err)

clientIDs, err := adapter.Clients()
assert.Nil(t, err)
assert.Equal(t, map[identifiers.ClientID]string{clientID: "a"}, clientIDs)

size, err := adapter.Size()
assert.Nil(t, err)
assert.Equal(t, 1, size)

err = adapter.Remove(clientID)
assert.Nil(t, err)
clientIDs, err = adapter.Clients()

assert.Nil(t, err)
assert.Equal(t, map[identifiers.ClientID]string{}, clientIDs)

size, err = adapter.Size()
assert.Nil(t, err)
assert.Equal(t, 0, size)
Expand All @@ -45,11 +55,10 @@ func TestMemoryAdapter_emitFound(t *testing.T) {
defer close(mockWriter.out)
client := server.NewClient(mockWriter)
adapter.Add(client)
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
msgChan := client.Subscribe(ctx)
msgChan := client.Messages()
for range msgChan {
}
err := client.Err()
Expand All @@ -70,7 +79,9 @@ func TestMemoryAdapter_emitFound(t *testing.T) {

assert.Equal(t, joinMessage, msg1)
msg2 := <-mockWriter.out
cancel()

client.Close(websocket.StatusNormalClosure, "")

assert.Equal(t, serialize(t, msg), msg2)
wg.Wait()
}
Expand All @@ -93,24 +104,24 @@ func TestMemoryAdapter_Broadcast(t *testing.T) {
client1 := server.NewClient(mockWriter1)
mockWriter2 := NewMockWriter()
client2 := server.NewClient(mockWriter2)

defer close(mockWriter1.out)
defer close(mockWriter2.out)

assert.Nil(t, adapter.Add(client1))
assert.Nil(t, adapter.Add(client2))
ctx, cancel := context.WithCancel(context.Background())

var wg sync.WaitGroup
wg.Add(2)
go func() {
msgChan := client1.Subscribe(ctx)
for range msgChan {
for range client1.Messages() {
}
err := client1.Err()
assert.True(t, errIs(errors.Cause(err), context.Canceled), "expected context.Canceled, but got: %s", err)
wg.Done()
}()
go func() {
msgChan := client2.Subscribe(ctx)
for range msgChan {
for range client2.Messages() {
}
err := client2.Err()
assert.True(t, errIs(errors.Cause(err), context.Canceled), "expected context.Canceled, but got: %s", err)
Expand All @@ -129,6 +140,12 @@ func TestMemoryAdapter_Broadcast(t *testing.T) {
serializedMsg := serialize(t, msg)
assert.Equal(t, serializedMsg, <-mockWriter1.out)
assert.Equal(t, serializedMsg, <-mockWriter2.out)
cancel()

err := client1.Close(websocket.StatusNormalClosure, "")
assert.NoError(t, err, "closing websocket client1")

err = client2.Close(websocket.StatusNormalClosure, "")
assert.NoError(t, err, "closing websocket client2")

wg.Wait()
}
30 changes: 18 additions & 12 deletions server/mesh.go
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/peer-calls/peer-calls/server/identifiers"
"github.com/peer-calls/peer-calls/server/logger"
"github.com/peer-calls/peer-calls/server/message"
"nhooyr.io/websocket"
)

type ReadyMessage struct {
Expand All @@ -19,24 +20,29 @@ func NewMeshHandler(log logger.Logger, wss *WSS) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
log = log.WithNamespaceAppended("mesh")

sub, err := wss.Subscribe(w, r)
websocketCtx, err := wss.NewWebsocketContext(w, r)
if err != nil {
log.Error("Subscribe to websocket", errors.Trace(err), nil)
log.Error("Create websocket context", errors.Trace(err), nil)
return
}

for msg := range sub.Messages {
adapter := sub.Adapter
room := sub.Room
clientID := sub.ClientID
roomID := websocketCtx.RoomID()
clientID := websocketCtx.ClientID()

// Just in case. I'm actually not sure if this is necessary since if the
// reading stops, it most likely means the connection has already been
// closed.
defer websocketCtx.Close(websocket.StatusNormalClosure, "")

for msg := range websocketCtx.Messages() {
adapter := websocketCtx.Adapter()

log = log.WithCtx(logger.Ctx{
"client_id": clientID,
"room_id": room,
"room_id": roomID,
})

var (
err error
)
var err error

switch msg.Type {
case message.TypeHangUp:
Expand All @@ -54,7 +60,7 @@ func NewMeshHandler(log logger.Logger, wss *WSS) http.Handler {
log.Info(fmt.Sprintf("Got clients: %s", clients), nil)

err = adapter.Broadcast(
message.NewUsers(room, message.Users{
message.NewUsers(roomID, message.Users{
Initiator: clientID,
PeerIDs: clientsToPeerIDs(clients),
Nicknames: clients,
Expand All @@ -69,7 +75,7 @@ func NewMeshHandler(log logger.Logger, wss *WSS) http.Handler {
log.Info("Send signal to", logger.Ctx{
"target_client_id": targetClientID,
})
err = adapter.Emit(targetClientID, message.NewSignal(room, message.UserSignal{
err = adapter.Emit(targetClientID, message.NewSignal(roomID, message.UserSignal{
Signal: signal.Signal,
PeerID: clientID,
}))
Expand Down
1 change: 1 addition & 0 deletions server/nodemanager.go
Expand Up @@ -134,6 +134,7 @@ func (nm *NodeManager) handleTransport(transport *udptransport2.Transport) error

ch, err := nm.params.TracksManager.Add(streamID, transport)
if err != nil {
transport.Close()
return errors.Annotatef(err, "add transport: %s", streamID)
}

Expand Down
2 changes: 2 additions & 0 deletions server/pubsub/events.go
Expand Up @@ -67,11 +67,13 @@ func (s *events) start(in <-chan PubTrackEvent) {
}

case req := <-s.subRequestsChan:
// Unsubscribe existing subscription.
if out, ok := subs[req.clientID]; ok {
delete(subs, req.clientID)
close(out)
}

// Subscribe if necessary.
if req.typ == subRequestTypeSubscribe {
sub := make(chan PubTrackEvent, s.bufferSize)
subs[req.clientID] = sub
Expand Down
18 changes: 16 additions & 2 deletions server/redisadapter.go
Expand Up @@ -109,9 +109,23 @@ func (a *RedisAdapter) Add(client ClientWriter) (err error) {
})

a.clientsMu.Lock()
a.clients[clientID] = client

// TODO what if a client with the same ID has joined another
// node? We need to be smarter about this.
//
// Perhaps an easy solution is to add a node ID as a key prefix.
if _, ok := a.clients[clientID]; ok {
err = errors.Annotatef(ErrDuplicateClientID, "%s", clientID)
} else {
a.clients[clientID] = client
}

a.clientsMu.Unlock()

if err != nil {
return errors.Trace(err)
}

join := message.RoomJoin{
ClientID: clientID,
Metadata: client.Metadata(),
Expand Down Expand Up @@ -187,7 +201,7 @@ func (a *RedisAdapter) SetMetadata(clientID identifiers.ClientID, metadata strin
_, err := a.pubRedis.HSet(a.keys.roomClients, clientID.String(), metadata).Result()
if err != nil {
// FIXME return error
a.log.Error("Setmetadata", errors.Trace(err), logCtx)
a.log.Error("SetMetadata", errors.Trace(err), logCtx)
}

return err == nil
Expand Down
15 changes: 11 additions & 4 deletions server/redisadapter_test.go
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/peer-calls/peer-calls/server/test"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"nhooyr.io/websocket"
)

func errIs(err error, target error) bool {
Expand Down Expand Up @@ -50,21 +51,21 @@ func TestRedisAdapter_add_remove_client(t *testing.T) {
adapter1 := server.NewRedisAdapter(test.NewLogger(), pub, sub, "peercalls", room)
mockWriter1 := NewMockWriter()
defer close(mockWriter1.out)

client1 := server.NewClient(mockWriter1)
client1.SetMetadata("a")
mockWriter2 := NewMockWriter()
defer close(mockWriter2.out)

client2 := server.NewClient(mockWriter2)
client2.SetMetadata("b")
ctx, cancel := context.WithCancel(context.Background())

var wg sync.WaitGroup
wg.Add(2)

for _, client := range []*server.Client{client1, client2} {
go func(client *server.Client) {
msgChan := client.Subscribe(ctx)
for range msgChan {
for range client.Messages() {
}
err := client.Err()
assert.True(t, errIs(errors.Cause(err), context.Canceled), "expected error to be context.Canceled, but was: %s", err)
Expand Down Expand Up @@ -132,6 +133,12 @@ func TestRedisAdapter_add_remove_client(t *testing.T) {
err := stop()
assert.Equal(t, nil, err)
}
cancel()

err = client1.Close(websocket.StatusNormalClosure, "")
assert.NoError(t, err, "closing websocket client1")

err = client2.Close(websocket.StatusNormalClosure, "")
assert.NoError(t, err, "closing websocket client2")

wg.Wait()
}

0 comments on commit 90a37a7

Please sign in to comment.