diff --git a/node/handler.go b/node/handler.go index cfd12e2..07a56a6 100644 --- a/node/handler.go +++ b/node/handler.go @@ -18,6 +18,7 @@ package node import ( "encoding/json" + "errors" "fmt" "io" "math/big" @@ -36,6 +37,10 @@ const ( localIPAddress = "127.0.0.1" ) +var ( + errQuestionUnsupport = errors.New("question unsupport") +) + func (n *Node) handleMessages(ws *websocket.Conn, w galaxy.Wave) (common.Hash, error) { wm := w.(*galaxy.WaveMessages) for _, wmsg := range wm.Msgs { @@ -110,104 +115,105 @@ func (n *Node) handleErr(ws *websocket.Conn, w galaxy.Wave) (common.Hash, error) return wm.WaveID, nil } -func (n Node) handleQuestion(ws *websocket.Conn, w galaxy.Wave) (common.Hash, error) { - wm := w.(*galaxy.WaveQuestion) +func (n Node) handleQuestionRoots(ws *websocket.Conn, wq *galaxy.WaveQuestion) (common.Hash, error) { p := peer.Peer{Conn: ws} - //log.Debug("Received question", wm.Cmd) - switch wm.Cmd { - case galaxy.CmdRoots: - user0, user1, err := db.GetRootUsers(n.udb) - if err != nil { - return wm.WaveID, err - } + user0, user1, err := db.GetRootUsers(n.udb) + if err != nil { + return wq.WaveID, err + } + if err = p.SendRoots(wq.WaveID, user0, user1); err != nil { + return wq.WaveID, err + } + return wq.WaveID, nil +} - if err = p.SendRoots(wm.WaveID, user0, user1); err != nil { - return wm.WaveID, err - } - case galaxy.CmdPeers: - if err := p.SendPeers(wm.WaveID, n.peers, n.localPeer()); err != nil { - return wm.WaveID, err - } +func (n Node) handleQuestionPeers(ws *websocket.Conn, wq *galaxy.WaveQuestion) (common.Hash, error) { + p := peer.Peer{Conn: ws} + if err := p.SendPeers(wq.WaveID, n.peers, n.localPeer()); err != nil { + return wq.WaveID, err + } + // add request peer to node.peers + var remotePeer peer.Peer + if err := json.Unmarshal(wq.Args[0], &remotePeer); err != nil { + return wq.WaveID, err + } + // get remote ip address + remoteAddr := strings.Split(ws.Request().RemoteAddr, ":") + remotePeer.IP = remoteAddr[0] + if err := n.AddPeer(&remotePeer); err != nil { + return wq.WaveID, err + } + return wq.WaveID, nil +} - // add request peer to node.peers - var remotePeer peer.Peer - if err := json.Unmarshal(wm.Args[0], &remotePeer); err != nil { - return wm.WaveID, err - } - // get remote ip address - remoteAddr := strings.Split(ws.Request().RemoteAddr, ":") - remotePeer.IP = remoteAddr[0] - if err := n.AddPeer(&remotePeer); err != nil { - return wm.WaveID, err - } +func (n Node) handleQuestionMsg(ws *websocket.Conn, wq *galaxy.WaveQuestion) (common.Hash, error) { + p := peer.Peer{Conn: ws} - case galaxy.CmdMessages: - var order, count *big.Int - var err error - var msgs []*core.Message - msgID := common.Bytes2Hash(wm.Args[0]) + var order, count *big.Int + var err error + var msgs []*core.Message + msgID := common.Bytes2Hash(wq.Args[0]) - if msgID != common.Bytes2Hash([]byte{}) { - order, count, err = db.GetOrderCntByMsg(n.udb, msgID) - if err != nil { - return wm.WaveID, err - } - order = order.Add(order, big.NewInt(1)) - } else { - order = big.NewInt(0) - count, err = db.GetMsgCount(n.udb) - if err != nil { - return wm.WaveID, err - } - } - - if order != nil && count != nil && count.Uint64()-order.Uint64() > peer.MaxMsgCountPerWave { - //log.Debug("Send msg from order", order, "size", peer.MaxMsgCountPerWave) - msgs = db.GetMsgByOrder(n.udb, order, peer.MaxMsgCountPerWave) + if msgID != common.Bytes2Hash([]byte{}) { + order, count, err = db.GetOrderCntByMsg(n.udb, msgID) + if err != nil { + return wq.WaveID, err } - if err = p.SendMsgs(wm.WaveID, msgs); err != nil { - return wm.WaveID, err + order = order.Add(order, big.NewInt(1)) + } else { + order = big.NewInt(0) + count, err = db.GetMsgCount(n.udb) + if err != nil { + return wq.WaveID, err } } - return wm.WaveID, nil + + if order != nil && count != nil && count.Uint64()-order.Uint64() > peer.MaxMsgCountPerWave { + //log.Debug("Send msg from order", order, "size", peer.MaxMsgCountPerWave) + msgs = db.GetMsgByOrder(n.udb, order, peer.MaxMsgCountPerWave) + } + if err = p.SendMsgs(wq.WaveID, msgs); err != nil { + return wq.WaveID, err + } + return wq.WaveID, nil +} +func (n Node) handleQuestion(ws *websocket.Conn, w galaxy.Wave) (waveID common.Hash, err error) { + waveQuestion := w.(*galaxy.WaveQuestion) + switch waveQuestion.Cmd { + case galaxy.CmdRoots: + waveID, err = n.handleQuestionRoots(ws, waveQuestion) + case galaxy.CmdPeers: + waveID, err = n.handleQuestionPeers(ws, waveQuestion) + case galaxy.CmdMessages: + waveID, err = n.handleQuestionMsg(ws, waveQuestion) + default: + waveID, err = waveQuestion.WaveID, errQuestionUnsupport + } + return waveQuestion.WaveID, err } func (n *Node) handleWave(ws *websocket.Conn, w galaxy.Wave, alwaysTrue bool) (waveID common.Hash, err error) { switch w.Command() { case galaxy.CmdMessages: if !alwaysTrue && !n.wsAcceptMsg { - waveID = w.(*galaxy.WaveMessages).WaveID + waveID, err = w.(*galaxy.WaveMessages).WaveID, nil } else { - if waveID, err = n.handleMessages(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handleMessages(ws, w) } case galaxy.CmdQuestion: - if waveID, err = n.handleQuestion(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handleQuestion(ws, w) case galaxy.CmdPing: - if waveID, err = n.handlePing(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handlePing(ws, w) case galaxy.CmdPong: - if waveID, err = n.handlePong(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handlePong(ws, w) case galaxy.CmdRoots: - if waveID, err = n.handleRoots(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handleRoots(ws, w) case galaxy.CmdPeers: - if waveID, err = n.handlePeers(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handlePeers(ws, w) case galaxy.CmdErr: - if waveID, err = n.handleErr(ws, w); err != nil { - return waveID, err - } + waveID, err = n.handleErr(ws, w) default: - return common.Hash{}, fmt.Errorf("unhandled command [%s]", w.Command()) + waveID, err = common.Hash{}, fmt.Errorf("unhandled command [%s]", w.Command()) } return waveID, nil }