From ef8110ed0f2af183a8f4df5e05ba214f83461dde Mon Sep 17 00:00:00 2001 From: ramfox Date: Thu, 23 Sep 2021 23:33:16 -0400 Subject: [PATCH] feat(websocket): scope websocket connections Websocket connections need to be scoped so we only send the relevant events to the correct connections. To accomplish this, we keep track of the profile.ID of each connections using the new `websocket.conn` struct, as well as keeping a map of profile.IDs to connections. To allow us to authenticate tokens, the `websocket.Handler` must have access to the `key.Store`. We've added an authentication handshake. Once the connection has been established, the client can request to "subscribe" to a connection, and send over a token as part of the message payload. If that token is valid, we upgrade the connection, and send over a success message. If not, we send over a failure message. If the client wants to unsubscribe, they send over an "unsubscribe" request. This will remove the association between that connection and a profile.ID. It will not close the connection. We've also defined a `websocket.message` struct that the websocket expects to receive for communication about the state of the websocket connection (namely the authentication handshake). --- api/api.go | 7 +- lib/lib.go | 8 + lib/websocket.go | 71 ------- lib/websocket/websocket.go | 321 ++++++++++++++++++++++++++++++++ lib/websocket/websocket_test.go | 123 ++++++++++++ lib/websocket_test.go | 48 ----- 6 files changed, 456 insertions(+), 122 deletions(-) delete mode 100644 lib/websocket.go create mode 100644 lib/websocket/websocket.go create mode 100644 lib/websocket/websocket_test.go delete mode 100644 lib/websocket_test.go diff --git a/api/api.go b/api/api.go index b8a881165..fab4a9c68 100644 --- a/api/api.go +++ b/api/api.go @@ -14,6 +14,7 @@ import ( "github.com/qri-io/qri/auth/token" "github.com/qri-io/qri/lib" qhttp "github.com/qri-io/qri/lib/http" + "github.com/qri-io/qri/lib/websocket" "github.com/qri-io/qri/version" ) @@ -39,7 +40,7 @@ func init() { type Server struct { *lib.Instance Mux *mux.Router - websocket lib.WebsocketHandler + websocket websocket.Handler } // New creates a new qri server from a p2p node & configuration @@ -56,7 +57,7 @@ func (s Server) Serve(ctx context.Context) (err error) { node.LocalStreams.Print(fmt.Sprintf("qri version v%s\nconnecting...\n", APIVersion)) - ws, err := lib.NewWebsocketHandler(ctx, s.Instance) + ws, err := websocket.NewHandler(ctx, s.Instance.Bus(), s.Instance.KeyStore()) if err != nil { return err } @@ -133,7 +134,7 @@ func readOnlyResponse(w http.ResponseWriter, endpoint string) { func (s *Server) HomeHandler(w http.ResponseWriter, r *http.Request) { upgrade := r.Header.Get("Upgrade") if upgrade == "websocket" { - s.websocket.WSConnectionHandler(w, r) + s.websocket.ConnectionHandler(w, r) } else { if r.URL.Path == "" || r.URL.Path == "/" { HealthCheckHandler(w, r) diff --git a/lib/lib.go b/lib/lib.go index b8db3684e..e29d407a6 100644 --- a/lib/lib.go +++ b/lib/lib.go @@ -1192,6 +1192,14 @@ func (inst *Instance) TokenProvider() token.Provider { return inst.tokenProvider } +// KeyStore exposes the instance key.Store +func (inst *Instance) KeyStore() key.Store { + if inst == nil { + return nil + } + return inst.keystore +} + // activeProfile tries to extract the current user from values embedded in the // passed-in context, falling back to the repo owner as a default active profile func (inst *Instance) activeProfile(ctx context.Context) (pro *profile.Profile, err error) { diff --git a/lib/websocket.go b/lib/websocket.go deleted file mode 100644 index a91ce8b6b..000000000 --- a/lib/websocket.go +++ /dev/null @@ -1,71 +0,0 @@ -package lib - -import ( - "context" - "net/http" - - "github.com/qri-io/qri/event" - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" -) - -const qriWebsocketProtocol = "qri-websocket" - -// WebsocketHandler defines the handler interface -type WebsocketHandler interface { - WSConnectionHandler(w http.ResponseWriter, r *http.Request) -} - -// wsHandler is a concrete implementation of a websocket handler -// and serves to maintain the list of connections -type wsHandler struct { - // Collect all websocket connections - conns []*websocket.Conn -} - -var _ WebsocketHandler = (*wsHandler)(nil) - -// NewWebsocketHandler creates a new wsHandler instance that clients -// can connect to in order to get realtime events -func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler, error) { - ws := &wsHandler{ - conns: []*websocket.Conn{}, - } - - inst.bus.SubscribeAll(ws.wsMessageHandler) - return ws, nil -} - -// WSConnectionHandler handles websocket upgrade requests and accepts the connection -func (h *wsHandler) WSConnectionHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{qriWebsocketProtocol}, - InsecureSkipVerify: true, - }) - if err != nil { - log.Debugf("Websocket accept error: %s", err) - return - } - h.conns = append(h.conns, c) -} - -func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error { - ctx := context.Background() - evt := map[string]interface{}{ - "type": string(e.Type), - "ts": e.Timestamp, - "sessionID": e.SessionID, - "data": e.Payload, - } - - log.Debugf("sending event %q to %d websocket conns", e.Type, len(h.conns)) - for k, c := range h.conns { - go func(k int, c *websocket.Conn) { - err := wsjson.Write(ctx, c, evt) - if err != nil { - log.Errorf("connection %d: wsjson write error: %s", k, err) - } - }(k, c) - } - return nil -} diff --git a/lib/websocket/websocket.go b/lib/websocket/websocket.go new file mode 100644 index 000000000..17252a79e --- /dev/null +++ b/lib/websocket/websocket.go @@ -0,0 +1,321 @@ +package websocket + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + + "github.com/google/uuid" + golog "github.com/ipfs/go-log" + "github.com/qri-io/qri/auth/key" + "github.com/qri-io/qri/auth/token" + "github.com/qri-io/qri/event" + "nhooyr.io/websocket" + "nhooyr.io/websocket/wsjson" +) + +const qriWebsocketProtocol = "qri-websocket" + +var ( + errNotFound = fmt.Errorf("connection not found") + + log = golog.Logger("websocket") +) + +// newID returns a new websocket connection ID +func newID() string { + return uuid.New().String() +} + +// setIDRand sets the random reader that NewID uses as a source of random bytes +// passing in nil will default to crypto.Rand. This can be used to make ID +// generation deterministic for tests. eg: +// myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..." +// lib.SetIDRand(strings.NewReader(myString)) +// a := NewID() +// lib.SetIDRand(strings.NewReader(myString)) +// b := NewID() +func setIDRand(r io.Reader) { + uuid.SetRand(r) +} + +// Handler defines the handler interface +type Handler interface { + ConnectionHandler(w http.ResponseWriter, r *http.Request) +} + +// connections maintains the set of active websocket connections & associated +// connection metadata +type connections struct { + conns map[string]*conn + connsLock sync.Mutex + keystore key.Store + // TODO(ramfox): using a `map[string]string` to track connections and + // profile.IDs means that each profile can only have one connection + // which will cause two browser tabs from the same profile to fail. + // we need to support multiple connections for the same profile, which + // will require an array of connection ID strings rather than a single + // string + subscriptions map[string]string + subsLock sync.Mutex +} + +type conn struct { + id string + profileID string + conn *websocket.Conn +} + +var _ Handler = (*connections)(nil) + +// NewHandler creates a new connections instance that clients +// can connect to in order to get realtime events +func NewHandler(ctx context.Context, bus event.Bus, keystore key.Store) (Handler, error) { + ws := &connections{ + conns: map[string]*conn{}, + connsLock: sync.Mutex{}, + keystore: keystore, + subscriptions: map[string]string{}, + subsLock: sync.Mutex{}, + } + + bus.SubscribeAll(ws.messageHandler) + return ws, nil +} + +// ConnectionHandler handles websocket upgrade requests and accepts the connection +func (h *connections) ConnectionHandler(w http.ResponseWriter, r *http.Request) { + wsc, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{qriWebsocketProtocol}, + InsecureSkipVerify: true, + }) + if err != nil { + log.Debugf("Websocket accept error: %s", err) + return + } + id := newID() + c := &conn{ + id: id, + conn: wsc, + } + h.connsLock.Lock() + defer h.connsLock.Unlock() + h.conns[id] = c + go h.read(r.Context(), id) +} + +func (h *connections) messageHandler(_ context.Context, e event.Event) error { + ctx := context.Background() + evt := map[string]interface{}{ + "type": string(e.Type), + "ts": e.Timestamp, + "sessionID": e.SessionID, + "data": e.Payload, + } + + profileIDString := e.ProfileID + if profileIDString == "" { + return nil + } + connID, err := h.getConnID(profileIDString) + if err != nil { + return fmt.Errorf("profile %q: %w", profileIDString, err) + } + + c, err := h.getConn(connID) + if err != nil { + h.unsubscribeConn(profileIDString) + return fmt.Errorf("connection %q, profile %q: %w", connID, profileIDString, err) + } + log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString) + if err := wsjson.Write(ctx, c.conn, evt); err != nil { + log.Errorf("connection %q: wsjson write error: %s", profileIDString, err) + } + return nil +} + +// getConn gets a *conn from the map of connections +func (h *connections) getConn(id string) (*conn, error) { + h.connsLock.Lock() + defer h.connsLock.Unlock() + c, ok := h.conns[id] + if !ok { + return nil, errNotFound + } + return c, nil +} + +// getConnID returns the connection ID associated with the given profile.ID string +func (h *connections) getConnID(profileID string) (string, error) { + h.subsLock.Lock() + defer h.subsLock.Unlock() + id, ok := h.subscriptions[profileID] + if !ok { + return "", errNotFound + } + return id, nil +} + +// subscribeConn authenticates the given token and adds the connID to the map +// of "subscribed" connections +func (h *connections) subscribeConn(connID, tokenString string) error { + ctx := context.TODO() + tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore) + if err != nil { + return err + } + + claims, ok := tok.Claims.(*token.Claims) + if !ok || claims.Subject == "" { + return fmt.Errorf("cannot get profile.ID from token") + } + // TODO(b5): at this point we have a valid signature of a profileID string + // but no proof that this profile is owned by the key that signed the + // token. We either need ProfileID == KeyID, or we need a UCAN. we need to + // check for those, ideally in a method within the profile package that + // abstracts over profile & key agreement + + c, err := h.getConn(connID) + if err != nil { + return fmt.Errorf("connection %q: %w", connID, err) + } + c.profileID = claims.Subject + + h.subsLock.Lock() + defer h.subsLock.Unlock() + h.subscriptions[claims.Subject] = connID + log.Debugw("subscribeConn", "id", connID) + return nil +} + +// unsubscribeConn remove the profileID and connID from the map of "subscribed" +// connections +func (h *connections) unsubscribeConn(profileID string) { + connID, err := h.getConnID(profileID) + if err != nil { + return + } + + c, err := h.getConn(connID) + if err != nil { + return + } + c.profileID = "" + + h.subsLock.Lock() + defer h.subsLock.Unlock() + delete(h.subscriptions, profileID) +} + +// removeConn removes the conn from the map of connections and subscriptions +// closing the connection if needed +func (h *connections) removeConn(id string) { + c, err := h.getConn(id) + if err != nil { + return + } + defer func() { + c.conn.Close(websocket.StatusNormalClosure, "pruning connection") + }() + if c.profileID != "" { + h.unsubscribeConn(c.profileID) + } + h.connsLock.Lock() + defer h.connsLock.Unlock() + delete(h.conns, id) +} + +// read listens to the given connection, handling any messages that come through +// stops listening if it encounters any error +func (h *connections) read(ctx context.Context, id string) error { + msg := &message{} + + c, err := h.getConn(id) + if err != nil { + return fmt.Errorf("connection %q: %w", id, err) + } + + for { + err = wsjson.Read(ctx, c.conn, msg) + if err != nil { + // all websocket methods that return w/ failure are closed + // we must prune the closed connection + h.removeConn(id) + return err + } + h.handleMessage(ctx, c, msg) + } +} + +// handleMessage handles each message based on msgType +func (h *connections) handleMessage(ctx context.Context, c *conn, msg *message) { + switch msg.Type { + case subscribeRequest: + subMsg := &subscribeMessage{} + if err := json.Unmarshal(msg.Payload, subMsg); err != nil { + log.Debugw("websocket unmarshal", "error", err, "connection id", c.id, "msg", msg) + h.write(ctx, c, &message{Type: subscribeFailure, Error: err}) + return + } + if err := h.subscribeConn(c.id, subMsg.Token); err != nil { + log.Debugw("subscribeConn", "error", err, "connection id", c.id, "msg", msg) + h.write(ctx, c, &message{Type: subscribeFailure, Error: err}) + return + } + h.write(ctx, c, &message{Type: subscribeSuccess}) + case unsubscribeRequest: + h.unsubscribeConn(c.profileID) + default: + log.Debug("unknown message type over websocket %s: %q", c.id, msg.Type) + } +} + +// write sends a json message over the connection +func (h *connections) write(ctx context.Context, c *conn, msg *message) { + log.Debugf("sending message %q to websocket conns %q", msg.Type, c.id) + if err := wsjson.Write(ctx, c.conn, msg); err != nil { + log.Errorf("connection %q: wsjson write error: %s", c.id, err) + // the connection will close if there is any `write` error + // we must remove it from our own stores, so as not to hold + // onto any dead connections + h.removeConn(c.id) + } +} + +// msgType is the type of message that we receive on the +type msgType string + +const ( + // subscribeRequest indicates the connection is trying to become + // an authenticated connection + // payload is a `subscribeMessage` + subscribeRequest = msgType("subscribe:request") + // subscribeSuccess indicates that the connection successfully + // upgraded to an authenticated connection + // payload is nil + subscribeSuccess = msgType("subscribe:success") + // subscribeFailure indicates that the connection did not + // upgrade to an authenticated connection + // payload is nil + subscribeFailure = msgType("subscribe:failure") + // unsubscribeRequest indicates the connection no longer wants + // to be authenticated + // payload is nil + unsubscribeRequest = msgType("unsubscribe:request") +) + +// message is the expected structure of an incoming websocket message +type message struct { + Type msgType `json:"type"` + Payload json.RawMessage `json:"payload"` + Error error `json:"error"` +} + +// subscribeMessage is the expected structure of an incoming "subscribe" +// message +type subscribeMessage struct { + Token string `json:"token"` +} diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go new file mode 100644 index 000000000..72492be32 --- /dev/null +++ b/lib/websocket/websocket_test.go @@ -0,0 +1,123 @@ +package websocket + +import ( + "bufio" + "bytes" + "context" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/qri-io/qri/auth/key" + testkeys "github.com/qri-io/qri/auth/key/test" + "github.com/qri-io/qri/auth/token" + "github.com/qri-io/qri/event" +) + +func TestWebsocket(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // create key store & add test key + kd := testkeys.GetKeyData(0) + ks, err := key.NewMemStore() + if err != nil { + t.Fatal(err) + } + if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil { + t.Fatal(err) + } + + // create bus + bus := event.NewBus(ctx) + + subsCount := bus.NumSubscribers() + + // create Handler + websocketHandler, err := NewHandler(ctx, bus, ks) + if err != nil { + t.Fatal(err) + } + wsh := websocketHandler.(*connections) + + // websockets should subscribe the message handler + if bus.NumSubscribers() != subsCount+1 { + t.Fatalf("failed to subscribe websocket handlers") + } + + // add connection + randIDStr := "test_connection_id_str" + setIDRand(strings.NewReader(randIDStr)) + connID := newID() + setIDRand(strings.NewReader(randIDStr)) + + wsh.ConnectionHandler(mockWriterAndRequest()) + if _, err := wsh.getConn(connID); err != nil { + t.Fatal("ConnectionHandler did not create a connection") + } + + // create a token from a private key + kd = testkeys.GetKeyData(0) + tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, kd.KeyID.String(), 0) + if err != nil { + t.Fatal(err) + } + // upgrade connection w/ valid token + wsh.subscribeConn(connID, tokenStr) + proID := kd.KeyID.String() + gotConnID, err := wsh.getConnID(proID) + if err != nil { + t.Fatal("connections.subscribeConn did not add profileID or conn to subscriptions map") + } + if gotConnID != connID { + t.Fatalf("connections.subscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnID) + } + + // unsubscribe connection via profileID + wsh.unsubscribeConn(proID) + if _, err := wsh.getConnID(proID); err == nil { + t.Fatal("connections.unsubscribeConn did not remove the profileID from the subscription map") + } + wsc, err := wsh.getConn(connID) + if err != nil { + t.Fatalf("connection %s not found", connID) + } + if wsc.profileID != "" { + t.Error("connections.unsubscribeConn did not remove the profileID from the conn") + } + + // remove the connection + wsh.removeConn(connID) + if _, err := wsh.getConn(connID); err == nil { + t.Fatal("connections.removeConn did not remove the connection from the map of conns") + } +} + +func mockWriterAndRequest() (http.ResponseWriter, *http.Request) { + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "keep-alive, Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "test_key") + return w, r +} + +type mockHijacker struct { + http.ResponseWriter +} + +var _ http.Hijacker = mockHijacker{} + +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + c, _ := net.Pipe() + r := bufio.NewReader(strings.NewReader("test_reader")) + w := bufio.NewWriter(&bytes.Buffer{}) + rw := bufio.NewReadWriter(r, w) + return c, rw, nil +} diff --git a/lib/websocket_test.go b/lib/websocket_test.go deleted file mode 100644 index b60a8b239..000000000 --- a/lib/websocket_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package lib - -import ( - "context" - "testing" - - "github.com/qri-io/qfs" - testcfg "github.com/qri-io/qri/config/test" - repotest "github.com/qri-io/qri/repo/test" -) - -func TestWebsocket(t *testing.T) { - tr, err := repotest.NewTempRepo("foo", "websocket_test", repotest.NewTestCrypto()) - if err != nil { - t.Fatal(err) - } - defer tr.Delete() - - cfg := testcfg.DefaultConfigForTesting() - cfg.Filesystems = []qfs.Config{ - {Type: "mem"}, - {Type: "local"}, - } - cfg.Repo.Type = "mem" - - instCtx, instCancel := context.WithCancel(context.Background()) - defer instCancel() - - inst, err := NewInstance(instCtx, tr.QriPath, OptConfig(cfg)) - if err != nil { - t.Fatal(err) - } - - subsCount := inst.bus.NumSubscribers() - - wsCtx, wsCancel := context.WithCancel(context.Background()) - _, err = NewWebsocketHandler(wsCtx, inst) - if err != nil { - t.Fatal(err) - } - - // websockets should subscribe the WS message handler - if inst.bus.NumSubscribers() != subsCount+1 { - t.Fatalf("failed to subscribe websocket handlers") - } - - wsCancel() -}