diff --git a/api/api.go b/api/api.go index b8a881165..fab4a9c68 100644 --- a/api/api.go +++ b/api/api.go @@ -14,6 +14,7 @@ import ( "github.com/qri-io/qri/auth/token" "github.com/qri-io/qri/lib" qhttp "github.com/qri-io/qri/lib/http" + "github.com/qri-io/qri/lib/websocket" "github.com/qri-io/qri/version" ) @@ -39,7 +40,7 @@ func init() { type Server struct { *lib.Instance Mux *mux.Router - websocket lib.WebsocketHandler + websocket websocket.Handler } // New creates a new qri server from a p2p node & configuration @@ -56,7 +57,7 @@ func (s Server) Serve(ctx context.Context) (err error) { node.LocalStreams.Print(fmt.Sprintf("qri version v%s\nconnecting...\n", APIVersion)) - ws, err := lib.NewWebsocketHandler(ctx, s.Instance) + ws, err := websocket.NewHandler(ctx, s.Instance.Bus(), s.Instance.KeyStore()) if err != nil { return err } @@ -133,7 +134,7 @@ func readOnlyResponse(w http.ResponseWriter, endpoint string) { func (s *Server) HomeHandler(w http.ResponseWriter, r *http.Request) { upgrade := r.Header.Get("Upgrade") if upgrade == "websocket" { - s.websocket.WSConnectionHandler(w, r) + s.websocket.ConnectionHandler(w, r) } else { if r.URL.Path == "" || r.URL.Path == "/" { HealthCheckHandler(w, r) diff --git a/lib/lib.go b/lib/lib.go index b8db3684e..e29d407a6 100644 --- a/lib/lib.go +++ b/lib/lib.go @@ -1192,6 +1192,14 @@ func (inst *Instance) TokenProvider() token.Provider { return inst.tokenProvider } +// KeyStore exposes the instance key.Store +func (inst *Instance) KeyStore() key.Store { + if inst == nil { + return nil + } + return inst.keystore +} + // activeProfile tries to extract the current user from values embedded in the // passed-in context, falling back to the repo owner as a default active profile func (inst *Instance) activeProfile(ctx context.Context) (pro *profile.Profile, err error) { diff --git a/lib/websocket.go b/lib/websocket.go deleted file mode 100644 index a91ce8b6b..000000000 --- a/lib/websocket.go +++ /dev/null @@ -1,71 +0,0 @@ -package lib - -import ( - "context" - "net/http" - - "github.com/qri-io/qri/event" - "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" -) - -const qriWebsocketProtocol = "qri-websocket" - -// WebsocketHandler defines the handler interface -type WebsocketHandler interface { - WSConnectionHandler(w http.ResponseWriter, r *http.Request) -} - -// wsHandler is a concrete implementation of a websocket handler -// and serves to maintain the list of connections -type wsHandler struct { - // Collect all websocket connections - conns []*websocket.Conn -} - -var _ WebsocketHandler = (*wsHandler)(nil) - -// NewWebsocketHandler creates a new wsHandler instance that clients -// can connect to in order to get realtime events -func NewWebsocketHandler(ctx context.Context, inst *Instance) (WebsocketHandler, error) { - ws := &wsHandler{ - conns: []*websocket.Conn{}, - } - - inst.bus.SubscribeAll(ws.wsMessageHandler) - return ws, nil -} - -// WSConnectionHandler handles websocket upgrade requests and accepts the connection -func (h *wsHandler) WSConnectionHandler(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{qriWebsocketProtocol}, - InsecureSkipVerify: true, - }) - if err != nil { - log.Debugf("Websocket accept error: %s", err) - return - } - h.conns = append(h.conns, c) -} - -func (h *wsHandler) wsMessageHandler(_ context.Context, e event.Event) error { - ctx := context.Background() - evt := map[string]interface{}{ - "type": string(e.Type), - "ts": e.Timestamp, - "sessionID": e.SessionID, - "data": e.Payload, - } - - log.Debugf("sending event %q to %d websocket conns", e.Type, len(h.conns)) - for k, c := range h.conns { - go func(k int, c *websocket.Conn) { - err := wsjson.Write(ctx, c, evt) - if err != nil { - log.Errorf("connection %d: wsjson write error: %s", k, err) - } - }(k, c) - } - return nil -} diff --git a/lib/websocket/websocket.go b/lib/websocket/websocket.go new file mode 100644 index 000000000..17252a79e --- /dev/null +++ b/lib/websocket/websocket.go @@ -0,0 +1,321 @@ +package websocket + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sync" + + "github.com/google/uuid" + golog "github.com/ipfs/go-log" + "github.com/qri-io/qri/auth/key" + "github.com/qri-io/qri/auth/token" + "github.com/qri-io/qri/event" + "nhooyr.io/websocket" + "nhooyr.io/websocket/wsjson" +) + +const qriWebsocketProtocol = "qri-websocket" + +var ( + errNotFound = fmt.Errorf("connection not found") + + log = golog.Logger("websocket") +) + +// newID returns a new websocket connection ID +func newID() string { + return uuid.New().String() +} + +// setIDRand sets the random reader that NewID uses as a source of random bytes +// passing in nil will default to crypto.Rand. This can be used to make ID +// generation deterministic for tests. eg: +// myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..." +// lib.SetIDRand(strings.NewReader(myString)) +// a := NewID() +// lib.SetIDRand(strings.NewReader(myString)) +// b := NewID() +func setIDRand(r io.Reader) { + uuid.SetRand(r) +} + +// Handler defines the handler interface +type Handler interface { + ConnectionHandler(w http.ResponseWriter, r *http.Request) +} + +// connections maintains the set of active websocket connections & associated +// connection metadata +type connections struct { + conns map[string]*conn + connsLock sync.Mutex + keystore key.Store + // TODO(ramfox): using a `map[string]string` to track connections and + // profile.IDs means that each profile can only have one connection + // which will cause two browser tabs from the same profile to fail. + // we need to support multiple connections for the same profile, which + // will require an array of connection ID strings rather than a single + // string + subscriptions map[string]string + subsLock sync.Mutex +} + +type conn struct { + id string + profileID string + conn *websocket.Conn +} + +var _ Handler = (*connections)(nil) + +// NewHandler creates a new connections instance that clients +// can connect to in order to get realtime events +func NewHandler(ctx context.Context, bus event.Bus, keystore key.Store) (Handler, error) { + ws := &connections{ + conns: map[string]*conn{}, + connsLock: sync.Mutex{}, + keystore: keystore, + subscriptions: map[string]string{}, + subsLock: sync.Mutex{}, + } + + bus.SubscribeAll(ws.messageHandler) + return ws, nil +} + +// ConnectionHandler handles websocket upgrade requests and accepts the connection +func (h *connections) ConnectionHandler(w http.ResponseWriter, r *http.Request) { + wsc, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{qriWebsocketProtocol}, + InsecureSkipVerify: true, + }) + if err != nil { + log.Debugf("Websocket accept error: %s", err) + return + } + id := newID() + c := &conn{ + id: id, + conn: wsc, + } + h.connsLock.Lock() + defer h.connsLock.Unlock() + h.conns[id] = c + go h.read(r.Context(), id) +} + +func (h *connections) messageHandler(_ context.Context, e event.Event) error { + ctx := context.Background() + evt := map[string]interface{}{ + "type": string(e.Type), + "ts": e.Timestamp, + "sessionID": e.SessionID, + "data": e.Payload, + } + + profileIDString := e.ProfileID + if profileIDString == "" { + return nil + } + connID, err := h.getConnID(profileIDString) + if err != nil { + return fmt.Errorf("profile %q: %w", profileIDString, err) + } + + c, err := h.getConn(connID) + if err != nil { + h.unsubscribeConn(profileIDString) + return fmt.Errorf("connection %q, profile %q: %w", connID, profileIDString, err) + } + log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString) + if err := wsjson.Write(ctx, c.conn, evt); err != nil { + log.Errorf("connection %q: wsjson write error: %s", profileIDString, err) + } + return nil +} + +// getConn gets a *conn from the map of connections +func (h *connections) getConn(id string) (*conn, error) { + h.connsLock.Lock() + defer h.connsLock.Unlock() + c, ok := h.conns[id] + if !ok { + return nil, errNotFound + } + return c, nil +} + +// getConnID returns the connection ID associated with the given profile.ID string +func (h *connections) getConnID(profileID string) (string, error) { + h.subsLock.Lock() + defer h.subsLock.Unlock() + id, ok := h.subscriptions[profileID] + if !ok { + return "", errNotFound + } + return id, nil +} + +// subscribeConn authenticates the given token and adds the connID to the map +// of "subscribed" connections +func (h *connections) subscribeConn(connID, tokenString string) error { + ctx := context.TODO() + tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore) + if err != nil { + return err + } + + claims, ok := tok.Claims.(*token.Claims) + if !ok || claims.Subject == "" { + return fmt.Errorf("cannot get profile.ID from token") + } + // TODO(b5): at this point we have a valid signature of a profileID string + // but no proof that this profile is owned by the key that signed the + // token. We either need ProfileID == KeyID, or we need a UCAN. we need to + // check for those, ideally in a method within the profile package that + // abstracts over profile & key agreement + + c, err := h.getConn(connID) + if err != nil { + return fmt.Errorf("connection %q: %w", connID, err) + } + c.profileID = claims.Subject + + h.subsLock.Lock() + defer h.subsLock.Unlock() + h.subscriptions[claims.Subject] = connID + log.Debugw("subscribeConn", "id", connID) + return nil +} + +// unsubscribeConn remove the profileID and connID from the map of "subscribed" +// connections +func (h *connections) unsubscribeConn(profileID string) { + connID, err := h.getConnID(profileID) + if err != nil { + return + } + + c, err := h.getConn(connID) + if err != nil { + return + } + c.profileID = "" + + h.subsLock.Lock() + defer h.subsLock.Unlock() + delete(h.subscriptions, profileID) +} + +// removeConn removes the conn from the map of connections and subscriptions +// closing the connection if needed +func (h *connections) removeConn(id string) { + c, err := h.getConn(id) + if err != nil { + return + } + defer func() { + c.conn.Close(websocket.StatusNormalClosure, "pruning connection") + }() + if c.profileID != "" { + h.unsubscribeConn(c.profileID) + } + h.connsLock.Lock() + defer h.connsLock.Unlock() + delete(h.conns, id) +} + +// read listens to the given connection, handling any messages that come through +// stops listening if it encounters any error +func (h *connections) read(ctx context.Context, id string) error { + msg := &message{} + + c, err := h.getConn(id) + if err != nil { + return fmt.Errorf("connection %q: %w", id, err) + } + + for { + err = wsjson.Read(ctx, c.conn, msg) + if err != nil { + // all websocket methods that return w/ failure are closed + // we must prune the closed connection + h.removeConn(id) + return err + } + h.handleMessage(ctx, c, msg) + } +} + +// handleMessage handles each message based on msgType +func (h *connections) handleMessage(ctx context.Context, c *conn, msg *message) { + switch msg.Type { + case subscribeRequest: + subMsg := &subscribeMessage{} + if err := json.Unmarshal(msg.Payload, subMsg); err != nil { + log.Debugw("websocket unmarshal", "error", err, "connection id", c.id, "msg", msg) + h.write(ctx, c, &message{Type: subscribeFailure, Error: err}) + return + } + if err := h.subscribeConn(c.id, subMsg.Token); err != nil { + log.Debugw("subscribeConn", "error", err, "connection id", c.id, "msg", msg) + h.write(ctx, c, &message{Type: subscribeFailure, Error: err}) + return + } + h.write(ctx, c, &message{Type: subscribeSuccess}) + case unsubscribeRequest: + h.unsubscribeConn(c.profileID) + default: + log.Debug("unknown message type over websocket %s: %q", c.id, msg.Type) + } +} + +// write sends a json message over the connection +func (h *connections) write(ctx context.Context, c *conn, msg *message) { + log.Debugf("sending message %q to websocket conns %q", msg.Type, c.id) + if err := wsjson.Write(ctx, c.conn, msg); err != nil { + log.Errorf("connection %q: wsjson write error: %s", c.id, err) + // the connection will close if there is any `write` error + // we must remove it from our own stores, so as not to hold + // onto any dead connections + h.removeConn(c.id) + } +} + +// msgType is the type of message that we receive on the +type msgType string + +const ( + // subscribeRequest indicates the connection is trying to become + // an authenticated connection + // payload is a `subscribeMessage` + subscribeRequest = msgType("subscribe:request") + // subscribeSuccess indicates that the connection successfully + // upgraded to an authenticated connection + // payload is nil + subscribeSuccess = msgType("subscribe:success") + // subscribeFailure indicates that the connection did not + // upgrade to an authenticated connection + // payload is nil + subscribeFailure = msgType("subscribe:failure") + // unsubscribeRequest indicates the connection no longer wants + // to be authenticated + // payload is nil + unsubscribeRequest = msgType("unsubscribe:request") +) + +// message is the expected structure of an incoming websocket message +type message struct { + Type msgType `json:"type"` + Payload json.RawMessage `json:"payload"` + Error error `json:"error"` +} + +// subscribeMessage is the expected structure of an incoming "subscribe" +// message +type subscribeMessage struct { + Token string `json:"token"` +} diff --git a/lib/websocket/websocket_test.go b/lib/websocket/websocket_test.go new file mode 100644 index 000000000..72492be32 --- /dev/null +++ b/lib/websocket/websocket_test.go @@ -0,0 +1,123 @@ +package websocket + +import ( + "bufio" + "bytes" + "context" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/qri-io/qri/auth/key" + testkeys "github.com/qri-io/qri/auth/key/test" + "github.com/qri-io/qri/auth/token" + "github.com/qri-io/qri/event" +) + +func TestWebsocket(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // create key store & add test key + kd := testkeys.GetKeyData(0) + ks, err := key.NewMemStore() + if err != nil { + t.Fatal(err) + } + if err := ks.AddPubKey(context.Background(), kd.KeyID, kd.PrivKey.GetPublic()); err != nil { + t.Fatal(err) + } + + // create bus + bus := event.NewBus(ctx) + + subsCount := bus.NumSubscribers() + + // create Handler + websocketHandler, err := NewHandler(ctx, bus, ks) + if err != nil { + t.Fatal(err) + } + wsh := websocketHandler.(*connections) + + // websockets should subscribe the message handler + if bus.NumSubscribers() != subsCount+1 { + t.Fatalf("failed to subscribe websocket handlers") + } + + // add connection + randIDStr := "test_connection_id_str" + setIDRand(strings.NewReader(randIDStr)) + connID := newID() + setIDRand(strings.NewReader(randIDStr)) + + wsh.ConnectionHandler(mockWriterAndRequest()) + if _, err := wsh.getConn(connID); err != nil { + t.Fatal("ConnectionHandler did not create a connection") + } + + // create a token from a private key + kd = testkeys.GetKeyData(0) + tokenStr, err := token.NewPrivKeyAuthToken(kd.PrivKey, kd.KeyID.String(), 0) + if err != nil { + t.Fatal(err) + } + // upgrade connection w/ valid token + wsh.subscribeConn(connID, tokenStr) + proID := kd.KeyID.String() + gotConnID, err := wsh.getConnID(proID) + if err != nil { + t.Fatal("connections.subscribeConn did not add profileID or conn to subscriptions map") + } + if gotConnID != connID { + t.Fatalf("connections.subscribeConn added incorrect connID to subscriptions map, expected %q, got %q", connID, gotConnID) + } + + // unsubscribe connection via profileID + wsh.unsubscribeConn(proID) + if _, err := wsh.getConnID(proID); err == nil { + t.Fatal("connections.unsubscribeConn did not remove the profileID from the subscription map") + } + wsc, err := wsh.getConn(connID) + if err != nil { + t.Fatalf("connection %s not found", connID) + } + if wsc.profileID != "" { + t.Error("connections.unsubscribeConn did not remove the profileID from the conn") + } + + // remove the connection + wsh.removeConn(connID) + if _, err := wsh.getConn(connID); err == nil { + t.Fatal("connections.removeConn did not remove the connection from the map of conns") + } +} + +func mockWriterAndRequest() (http.ResponseWriter, *http.Request) { + w := mockHijacker{ + ResponseWriter: httptest.NewRecorder(), + } + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "keep-alive, Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "test_key") + return w, r +} + +type mockHijacker struct { + http.ResponseWriter +} + +var _ http.Hijacker = mockHijacker{} + +func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + c, _ := net.Pipe() + r := bufio.NewReader(strings.NewReader("test_reader")) + w := bufio.NewWriter(&bytes.Buffer{}) + rw := bufio.NewReadWriter(r, w) + return c, rw, nil +} diff --git a/lib/websocket_test.go b/lib/websocket_test.go deleted file mode 100644 index b60a8b239..000000000 --- a/lib/websocket_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package lib - -import ( - "context" - "testing" - - "github.com/qri-io/qfs" - testcfg "github.com/qri-io/qri/config/test" - repotest "github.com/qri-io/qri/repo/test" -) - -func TestWebsocket(t *testing.T) { - tr, err := repotest.NewTempRepo("foo", "websocket_test", repotest.NewTestCrypto()) - if err != nil { - t.Fatal(err) - } - defer tr.Delete() - - cfg := testcfg.DefaultConfigForTesting() - cfg.Filesystems = []qfs.Config{ - {Type: "mem"}, - {Type: "local"}, - } - cfg.Repo.Type = "mem" - - instCtx, instCancel := context.WithCancel(context.Background()) - defer instCancel() - - inst, err := NewInstance(instCtx, tr.QriPath, OptConfig(cfg)) - if err != nil { - t.Fatal(err) - } - - subsCount := inst.bus.NumSubscribers() - - wsCtx, wsCancel := context.WithCancel(context.Background()) - _, err = NewWebsocketHandler(wsCtx, inst) - if err != nil { - t.Fatal(err) - } - - // websockets should subscribe the WS message handler - if inst.bus.NumSubscribers() != subsCount+1 { - t.Fatalf("failed to subscribe websocket handlers") - } - - wsCancel() -}