Skip to content

Commit

Permalink
wip: add Subscribe & Unsubscribe process to websocket connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ramfox committed Sep 23, 2021
1 parent 0332399 commit a0dd49a
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 15 deletions.
208 changes: 196 additions & 12 deletions lib/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,39 @@ package lib

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"

"github.com/google/uuid"
"github.com/qri-io/qri/auth/key"
"github.com/qri-io/qri/auth/token"
"github.com/qri-io/qri/event"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)

const qriWebsocketProtocol = "qri-websocket"

// newID returns a new websocket connection ID
func newID() string {
return uuid.New().String()
}

// SetIDRand sets the random reader that NewID uses as a source of random bytes
// passing in nil will default to crypto.Rand. This can be used to make ID
// generation deterministic for tests. eg:
// myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..."
// workflow.SetIDRand(strings.NewReader(myString))
// a := NewID()
// workflow.SetIDRand(strings.NewReader(myString))
// b := NewID()
func SetIDRand(r io.Reader) {
uuid.SetRand(r)
}

// WebsocketHandler defines the handler interface
type WebsocketHandler interface {
WSConnectionHandler(w http.ResponseWriter, r *http.Request)
Expand All @@ -20,7 +44,16 @@ type WebsocketHandler interface {
// and serves to maintain the list of connections
type wsHandler struct {
// Collect all websocket connections
conns []*websocket.Conn
conns map[string]*wsConn
connsLock sync.Mutex
keystore key.Store
subscriptions map[string]string
subLock sync.Mutex
}

type wsConn struct {
profileID string
conn *websocket.Conn
}

var _ WebsocketHandler = (*wsHandler)(nil)
Expand All @@ -29,7 +62,11 @@ var _ WebsocketHandler = (*wsHandler)(nil)
// can connect to in order to get realtime events
func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler, error) {
ws := &wsHandler{
conns: []*websocket.Conn{},
conns: map[string]*wsConn{},
connsLock: sync.Mutex{},
keystore: inst.keystore,
subscriptions: map[string]string{},
subLock: sync.Mutex{},
}

inst.bus.SubscribeAll(ws.wsMessageHandler)
Expand All @@ -38,15 +75,22 @@ func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler,

// WSConnectionHandler handles websocket upgrade requests and accepts the connection
func (h *wsHandler) WSConnectionHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{qriWebsocketProtocol},
InsecureSkipVerify: true,
})
if err != nil {
log.Debugf("Websocket accept error: %s", err)
return
}
h.conns = append(h.conns, c)
connID := newID()
wsc := &wsConn{
conn: conn,
}
h.connsLock.Lock()
defer h.connsLock.Unlock()
h.conns[connID] = wsc
go h.read(connID)
}

func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error {
Expand All @@ -58,14 +102,154 @@ func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error {
"data": e.Payload,
}

log.Debugf("sending event %q to %d websocket conns", e.Type, len(h.conns))
for k, c := range h.conns {
go func(k int, c *websocket.Conn) {
err := wsjson.Write(ctx, c, evt)
if err != nil {
log.Errorf("connection %d: wsjson write error: %s", k, err)
}
}(k, c)
profileIDString := e.ProfileID
if profileIDString == "" {
log.Debugf("Event with SessionID %q has no scope. Not sending event over websocket.", e.SessionID)
return nil
}
connID, ok := h.subscriptions[profileIDString]
if !ok {
return fmt.Errorf("no websocket connection ID found for profile %q", profileIDString)
}
c, ok := h.conns[connID]
if !ok {
h.unsubscribeConn(profileIDString)
return fmt.Errorf("no websocket connection found for connection ID %q, profile %q", connID, profileIDString)
}
log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString)
err := wsjson.Write(ctx, c.conn, evt)
if err != nil {
log.Errorf("connection %q: wsjson write error: %s", profileIDString, err)
}
return nil
}

// subscribeConn authenticates the given token and adds the connID to the map
// of "subscribed" connections
func (h *wsHandler) subscribeConn(connID, tokenString string) error {
ctx := context.TODO()
tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore)
if err != nil {
return err
}

claims, ok := tok.Claims.(*token.Claims)
if !ok || claims.Subject == "" {
return fmt.Errorf("cannot get profile.ID from token")
}
// TODO(b5): at this point we have a valid signature of a profileID string
// but no proof that this profile is owned by the key that signed the
// token. We either need ProfileID == KeyID, or we need a UCAN. we need to
// check for those, ideally in a method within the profile package that
// abstracts over profile & key agreement

h.connsLock.Lock()
c, ok := h.conns[connID]
if !ok {
return fmt.Errorf("no connection for connection ID %q found", connID)
}
c.profileID = claims.Subject
h.connsLock.Unlock()

h.subLock.Lock()
defer h.subLock.Unlock()
h.subscriptions[claims.Subject] = connID
return nil
}

// unsubscribeConn remove the profileID and connID from the map of "subscribed"
// connections
func (h *wsHandler) unsubscribeConn(profileID string) {
h.subLock.Lock()
defer h.subLock.Unlock()
delete(h.subscriptions, profileID)
}

// removeConn removes the conn from the map of connections and subscriptions
// closing the connection if needed
func (h *wsHandler) removeConn(connID string) {
c, ok := h.conns[connID]
if !ok {
return
}
defer func() {
c.conn.Close(websocket.StatusNormalClosure, "pruning connection")
}()
if c.profileID != "" {
h.unsubscribeConn(c.profileID)
}
h.connsLock.Lock()
defer h.connsLock.Unlock()
delete(h.conns, connID)
}

// read listens to the given connection, handling any messages that come through
// stops listening if it encounters any error
func (h *wsHandler) read(id string) error {
ctx := context.Background()
msg := &message{}
var err error
wsc, ok := h.conns[id]
if !ok {
return fmt.Errorf("connection for connection ID %q not found", id)
}

for {
err = wsjson.Read(ctx, wsc.conn, msg)
if err != nil {
// all websocket methods that return w/ failure
// close the connection
h.removeConn(id)
return err
}
go h.handleMessage(id, msg)
}
}

// handleMessage handles each message based on msgType
func (h *wsHandler) handleMessage(id string, msg *message) {
switch msg.Type {
case wsSubscribe:
subMsg := &subscribeMessage{}
err := json.Unmarshal(msg.Payload, subMsg)
if err != nil {
log.Debugf("connection %q - error unmarshaling payload for subscribe message: %s", id, err)
}
h.subscribeConn(id, subMsg.Token)
case wsUnsubscribe:
c, ok := h.conns[id]
if !ok {
log.Errorf("conn not found %q", id)
return
}
h.unsubscribeConn(c.profileID)
default:
log.Debug("unknown message type over websocket %s: %q", id, msg.Type)
}
}

// msgType is the type of message that we receive on the
type msgType string

const (
// wsSubscribe indicates the connection is trying to become
// an authenticated connection
// payload is a `subscribeMessage`
wsSubscribe = msgType("subscribe")
// wsUnsubscribe indicates the connection no longer wants
// to be authenticated
// payload is nil
wsUnsubscribe = msgType("unsubscribe")
)

// message is the expected structure of an incoming websocket message
type message struct {
Type msgType `json:"type"`
Payload json.RawMessage `json:"payload"`
}

// subscribeMessage is the expected structure of an incoming "subscribe"
// message
type subscribeMessage struct {
Token string `json:"token"`
}
98 changes: 95 additions & 3 deletions lib/websocket_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
package lib

import (
"bufio"
"bytes"
"context"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/qri-io/qfs"
"github.com/qri-io/qri/auth/key"
testkeys "github.com/qri-io/qri/auth/key/test"
"github.com/qri-io/qri/auth/token"
testcfg "github.com/qri-io/qri/config/test"
repotest "github.com/qri-io/qri/repo/test"
)
Expand All @@ -26,23 +35,106 @@ func TestWebsocket(t *testing.T) {
instCtx, instCancel := context.WithCancel(context.Background())
defer instCancel()

inst, err := NewInstance(instCtx, tr.QriPath, OptConfig(cfg))
// create key store & add test key
kd := testkeys.GetKeyData(0)
ks, err := key.NewMemStore()
if err != nil {
t.Fatal(err)
}
if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil {
t.Fatal(err)
}

// create instance
inst, err := NewInstance(instCtx, tr.QriPath, OptConfig(cfg), OptKeyStore(ks))
if err != nil {
t.Fatal(err)
}

subsCount := inst.bus.NumSubscribers()

wsCtx, wsCancel := context.WithCancel(context.Background())
_, err = NewWebsocketHandler(wsCtx, inst)
defer wsCancel()

// create WebsocketHandler
websocketHandler, err := NewWebsocketHandler(wsCtx, inst)
if err != nil {
t.Fatal(err)
}
wsh := websocketHandler.(*wsHandler)

// websockets should subscribe the WS message handler
if inst.bus.NumSubscribers() != subsCount+1 {
t.Fatalf("failed to subscribe websocket handlers")
}

wsCancel()
// add connection
randIDStr := "test_connection_id_str"
SetIDRand(strings.NewReader(randIDStr))
connID := newID()
SetIDRand(strings.NewReader(randIDStr))

wsh.WSConnectionHandler(mockWebsocketWriterAndRequest())
_, ok := wsh.conns[connID]
if !ok {
t.Fatal("WSConnectionHandler did not create a connection")
}

// create a token from a private key
kd = testkeys.GetKeyData(0)
tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, kd.KeyID.String(), 0)
if err != nil {
t.Fatal(err)
}
// upgrade connection w/ valid token
wsh.subscribeConn(connID, tokenStr)
proID := kd.KeyID.String()
gotConnID, ok := wsh.subscriptions[proID]
if !ok {
t.Fatal("wsHandler.SubscribeConn did not add profileID or conn to subscriptions map")
}
if gotConnID != connID {
t.Fatalf("wsHandler.SubscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnID)
}

// unsubscribe connection via profileID
wsh.unsubscribeConn(proID)
_, ok = wsh.subscriptions[proID]
if ok {
t.Fatal("wsHandler.UnsubscribeConn did not remove the profileID from the subscription map")
}

// remove the connection
wsh.removeConn(connID)
_, ok = wsh.conns[connID]
if ok {
t.Fatal("wsHandler.Removeconn did not remove the connection from the map of conns")
}
}

func mockWebsocketWriterAndRequest() (http.ResponseWriter, *http.Request) {
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
}

r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "keep-alive, Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "test_key")
return w, r
}

type mockHijacker struct {
http.ResponseWriter
}

var _ http.Hijacker = mockHijacker{}

func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
c, _ := net.Pipe()
r := bufio.NewReader(strings.NewReader("test_reader"))
w := bufio.NewWriter(&bytes.Buffer{})
rw := bufio.NewReadWriter(r, w)
return c, rw, nil
}

0 comments on commit a0dd49a

Please sign in to comment.