Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
ramfox committed Sep 24, 2021
1 parent 0332399 commit e7f976f
Show file tree
Hide file tree
Showing 2 changed files with 368 additions and 24 deletions.
290 changes: 269 additions & 21 deletions lib/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,107 @@ package lib

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"sync"

"github.com/google/uuid"
"github.com/qri-io/qri/auth/key"
"github.com/qri-io/qri/auth/token"
"github.com/qri-io/qri/event"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)

const qriWebsocketProtocol = "qri-websocket"

var (
errNotFound = fmt.Errorf("not found")
)

// newID returns a new websocket connection ID
func newID() string {
return uuid.New().String()
}

// setIDRand sets the random reader that NewID uses as a source of random bytes
// passing in nil will default to crypto.Rand. This can be used to make ID
// generation deterministic for tests. eg:
// myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..."
// lib.SetIDRand(strings.NewReader(myString))
// a := NewID()
// lib.SetIDRand(strings.NewReader(myString))
// b := NewID()
func setIDRand(r io.Reader) {
uuid.SetRand(r)
}

// WebsocketHandler defines the handler interface
type WebsocketHandler interface {
WSConnectionHandler(w http.ResponseWriter, r *http.Request)
}

// wsHandler is a concrete implementation of a websocket handler
// and serves to maintain the list of connections
type wsHandler struct {
// Collect all websocket connections
conns []*websocket.Conn
// connections maintains the set of active websocket connections & associated
// connection metadata
type connections struct {
conns map[string]*wsConn
connsLock sync.Mutex
keystore key.Store
// TODO(ramfox): using a `map[string]string` to track connections and
// profile.IDs means that each profile can only have one connection
// which will cause two browser tabs from the same profile to fail.
// we need to support multiple connections for the same profile, which
// will require an array of connection ID strings rather than a single
// string
subscriptions map[string]string
subsLock sync.Mutex
}

type wsConn struct {
profileID string
conn *websocket.Conn
}

var _ WebsocketHandler = (*wsHandler)(nil)
var _ WebsocketHandler = (*connections)(nil)

// NewWebsocketHandler creates a new wsHandler instance that clients
// NewWebsocketHandler creates a new connections instance that clients
// can connect to in order to get realtime events
func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler, error) {
ws := &wsHandler{
conns: []*websocket.Conn{},
ws := &connections{
conns: map[string]*wsConn{},
connsLock: sync.Mutex{},
keystore: inst.keystore,
subscriptions: map[string]string{},
subsLock: sync.Mutex{},
}

inst.bus.SubscribeAll(ws.wsMessageHandler)
return ws, nil
}

// WSConnectionHandler handles websocket upgrade requests and accepts the connection
func (h *wsHandler) WSConnectionHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
func (h *connections) WSConnectionHandler(w http.ResponseWriter, r *http.Request) {
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{qriWebsocketProtocol},
InsecureSkipVerify: true,
})
if err != nil {
log.Debugf("Websocket accept error: %s", err)
return
}
h.conns = append(h.conns, c)
connID := newID()
wsc := &wsConn{
conn: conn,
}
h.connsLock.Lock()
defer h.connsLock.Unlock()
h.conns[connID] = wsc
go h.read(connID)
}

func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error {
func (h *connections) wsMessageHandler(_ context.Context, e event.Event) error {
ctx := context.Background()
evt := map[string]interface{}{
"type": string(e.Type),
Expand All @@ -58,14 +111,209 @@ func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error {
"data": e.Payload,
}

log.Debugf("sending event %q to %d websocket conns", e.Type, len(h.conns))
for k, c := range h.conns {
go func(k int, c *websocket.Conn) {
err := wsjson.Write(ctx, c, evt)
if err != nil {
log.Errorf("connection %d: wsjson write error: %s", k, err)
}
}(k, c)
profileIDString := e.ProfileID
if profileIDString == "" {
// log.Debugf("Event with SessionID %q has no scope. Not sending event over websocket.", e.SessionID)
return nil
}
connID, err := h.getConnID(profileIDString)
if err != nil {
return fmt.Errorf("connection ID found for profile %q: %w", profileIDString, err)
}

c, err := h.getConn(connID)
if err != nil {
h.unsubscribeConn(profileIDString)
return fmt.Errorf("connection ID %q, profile %q: %w", connID, profileIDString, err)
}
log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString)
if err := wsjson.Write(ctx, c.conn, evt); err != nil {
log.Errorf("connection %q: wsjson write error: %s", profileIDString, err)
}
return nil
}

// getConn gets a wsConn from the map of connections
func (h *connections) getConn(id string) (*wsConn, error) {
h.connsLock.Lock()
defer h.connsLock.Unlock()
c, ok := h.conns[id]
if !ok {
return nil, errNotFound
}
return c, nil
}

// getConnID returns the connID associated with the given profile.ID string
func (h *connections) getConnID(profileID string) (string, error) {
h.subsLock.Lock()
defer h.subsLock.Unlock()
id, ok := h.subscriptions[profileID]
if !ok {
return "", errNotFound
}
return id, nil
}

// subscribeConn authenticates the given token and adds the connID to the map
// of "subscribed" connections
func (h *connections) subscribeConn(connID, tokenString string) error {
ctx := context.TODO()
tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore)
if err != nil {
return err
}

claims, ok := tok.Claims.(*token.Claims)
if !ok || claims.Subject == "" {
return fmt.Errorf("cannot get profile.ID from token")
}
// TODO(b5): at this point we have a valid signature of a profileID string
// but no proof that this profile is owned by the key that signed the
// token. We either need ProfileID == KeyID, or we need a UCAN. we need to
// check for those, ideally in a method within the profile package that
// abstracts over profile & key agreement

c, err := h.getConn(connID)
if err != nil {
return fmt.Errorf("connection ID %q: %w", connID, err)
}
c.profileID = claims.Subject

h.subsLock.Lock()
defer h.subsLock.Unlock()
h.subscriptions[claims.Subject] = connID
log.Debugw("subscribeConn", "id", connID)
return nil
}

// unsubscribeConn remove the profileID and connID from the map of "subscribed"
// connections
func (h *connections) unsubscribeConn(profileID string) {
connID, err := h.getConnID(profileID)
if err != nil {
return
}

conn, err := h.getConn(connID)
if err != nil {
return
}
conn.profileID = ""

h.subsLock.Lock()
defer h.subsLock.Unlock()
delete(h.subscriptions, profileID)
}

// removeConn removes the conn from the map of connections and subscriptions
// closing the connection if needed
func (h *connections) removeConn(id string) {
c, err := h.getConn(id)
if err != nil {
return
}
defer func() {
c.conn.Close(websocket.StatusNormalClosure, "pruning connection")
}()
if c.profileID != "" {
h.unsubscribeConn(c.profileID)
}
h.connsLock.Lock()
defer h.connsLock.Unlock()
delete(h.conns, id)
}

// read listens to the given connection, handling any messages that come through
// stops listening if it encounters any error
func (h *connections) read(id string) error {
ctx := context.Background()
msg := &message{}
var err error

wsc, err := h.getConn(id)
if err != nil {
return fmt.Errorf("connection ID %q: %w", id, err)
}

for {
err = wsjson.Read(ctx, wsc.conn, msg)
if err != nil {
// all websocket methods that return w/ failure are closed
// we must prune the closed connection
h.removeConn(id)
return err
}
go h.handleMessage(ctx, id, wsc, msg)
}
}

// handleMessage handles each message based on msgType
func (h *connections) handleMessage(ctx context.Context, id string, c *wsConn, msg *message) {
switch msg.Type {
case wsSubscribeRequest:
subMsg := &subscribeMessage{}
if err := json.Unmarshal(msg.Payload, subMsg); err != nil {
log.Debugw("websocket unmarshal", "error", err, "connection id", id, "msg", msg)
h.write(ctx, id, c.conn, &message{Type: wsSubscribeFailure, Error: err})
return
}
if err := h.subscribeConn(id, subMsg.Token); err != nil {
log.Debugw("subscribeConn", "error", err, "connection id", id, "msg", msg)
h.write(ctx, id, c.conn, &message{Type: wsSubscribeFailure, Error: err})
return
}
h.write(ctx, id, c.conn, &message{Type: wsSubscribeSuccess})
case wsUnsubscribeRequest:
h.unsubscribeConn(c.profileID)
default:
log.Debug("unknown message type over websocket %s: %q", id, msg.Type)
}
}

// write sends a json message over the connection
func (h *connections) write(ctx context.Context, id string, conn *websocket.Conn, msg *message) {
log.Debugf("sending message %q to websocket conns %q", msg.Type, id)
if err := wsjson.Write(ctx, conn, msg); err != nil {
log.Errorf("connection %q: wsjson write error: %s", id, err)
// the connection will close if there is any `write` error
// we must remove it from our own stores, so as not to hold
// onto any dead connections
h.removeConn(id)
}
}

// msgType is the type of message that we receive on the
type msgType string

const (
// wsSubscribeRequest indicates the connection is trying to become
// an authenticated connection
// payload is a `subscribeMessage`
wsSubscribeRequest = msgType("subscribe:request")
// wsSubscribeSuccess indicates that the connection successfully
// upgraded to an authenticated connection
// payload is nil
wsSubscribeSuccess = msgType("subscribe:success")
// wsSubscribeFailure indicates that the connection did not
// upgrade to an authenticated connection
// payload is nil
wsSubscribeFailure = msgType("subscribe:failure")
// wsUnsubscribeRequest indicates the connection no longer wants
// to be authenticated
// payload is nil
wsUnsubscribeRequest = msgType("unsubscribe:request")
)

// message is the expected structure of an incoming websocket message
type message struct {
Type msgType `json:"type"`
Payload json.RawMessage `json:"payload"`
Error error `json:"error"`
}

// subscribeMessage is the expected structure of an incoming "subscribe"
// message
type subscribeMessage struct {
Token string `json:"token"`
}
Loading

0 comments on commit e7f976f

Please sign in to comment.