diff --git a/dot/network/block_announce.go b/dot/network/block_announce.go index 63d02e4479..bf8e655195 100644 --- a/dot/network/block_announce.go +++ b/dot/network/block_announce.go @@ -251,14 +251,14 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err // handleBlockAnnounceMessage handles BlockAnnounce messages // if some more blocks are required to sync the announced block, the node will open a sync stream // with its peer and send a BlockRequest message -func (s *Service) handleBlockAnnounceMessage(peer peer.ID, msg NotificationsMessage) error { +func (s *Service) handleBlockAnnounceMessage(peer peer.ID, msg NotificationsMessage) (propagate bool, err error) { if an, ok := msg.(*BlockAnnounceMessage); ok { s.syncQueue.handleBlockAnnounce(an, peer) - err := s.syncer.HandleBlockAnnounce(an) + err = s.syncer.HandleBlockAnnounce(an) if err != nil { - return err + return false, err } } - return nil + return true, nil } diff --git a/dot/network/block_announce_test.go b/dot/network/block_announce_test.go index 0eac5903cf..4010e002b3 100644 --- a/dot/network/block_announce_test.go +++ b/dot/network/block_announce_test.go @@ -101,8 +101,9 @@ func TestHandleBlockAnnounceMessage(t *testing.T) { Number: big.NewInt(10), } - err := s.handleBlockAnnounceMessage(peerID, msg) + propagate, err := s.handleBlockAnnounceMessage(peerID, msg) require.NoError(t, err) + require.True(t, propagate) } func TestValidateBlockAnnounceHandshake(t *testing.T) { diff --git a/dot/network/notifications.go b/dot/network/notifications.go index a6c2fcd2b7..47792d22a7 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -50,7 +50,7 @@ type ( MessageDecoder = func([]byte) (NotificationsMessage, error) // NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream. - NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error + NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) (propagate bool, err error) ) type notificationsProtocol struct { @@ -180,17 +180,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol, "peer", stream.Conn().RemotePeer(), ) - err := messageHandler(peer, msg) + propagate, err := messageHandler(peer, msg) if err != nil { return err } - if s.noGossip { - return nil - } - - // TODO: we don't want to rebroadcast neighbour messages, so ignore all consensus messages for now - if _, isConsensus := msg.(*ConsensusMessage); isConsensus { + if !propagate || s.noGossip { return nil } diff --git a/dot/network/transaction.go b/dot/network/transaction.go index 5131bc23ef..2a2c13f9a2 100644 --- a/dot/network/transaction.go +++ b/dot/network/transaction.go @@ -156,11 +156,11 @@ func decodeTransactionMessage(in []byte) (NotificationsMessage, error) { return msg, err } -func (s *Service) handleTransactionMessage(_ peer.ID, msg NotificationsMessage) error { +func (s *Service) handleTransactionMessage(_ peer.ID, msg NotificationsMessage) (bool, error) { txMsg, ok := msg.(*TransactionMessage) if !ok { - return errors.New("invalid transaction type") + return false, errors.New("invalid transaction type") } - return s.transactionHandler.HandleTransactionMessage(txMsg) + return true, s.transactionHandler.HandleTransactionMessage(txMsg) } diff --git a/lib/grandpa/message_handler.go b/lib/grandpa/message_handler.go index 95d0778b68..31654edc27 100644 --- a/lib/grandpa/message_handler.go +++ b/lib/grandpa/message_handler.go @@ -48,17 +48,7 @@ func NewMessageHandler(grandpa *Service, blockState BlockState) *MessageHandler // HandleMessage handles a GRANDPA consensus message // if it is a CommitMessage, it updates the BlockState // if it is a VoteMessage, it sends it to the GRANDPA service -func (h *MessageHandler) handleMessage(from peer.ID, msg *ConsensusMessage) (network.NotificationsMessage, error) { - if msg == nil || len(msg.Data) == 0 { - logger.Trace("received nil message or message with nil data") - return nil, nil - } - - m, err := decodeMessage(msg) - if err != nil { - return nil, err - } - +func (h *MessageHandler) handleMessage(from peer.ID, m GrandpaMessage) (network.NotificationsMessage, error) { logger.Trace("handling grandpa message", "msg", m) switch m.Type() { diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go index 05a60547d1..04e7f30dca 100644 --- a/lib/grandpa/message_handler_test.go +++ b/lib/grandpa/message_handler_test.go @@ -167,11 +167,8 @@ func TestMessageHandler_VoteMessage(t *testing.T) { vm, err := gs.createVoteMessage(v, precommit, gs.keypair) require.NoError(t, err) - cm, err := vm.ToConsensusMessage() - require.NoError(t, err) - h := NewMessageHandler(gs, st.Block) - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", vm) require.NoError(t, err) require.Nil(t, out) @@ -194,10 +191,7 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) { Number: 1, } - cm, err := msg.ToConsensusMessage() - require.NoError(t, err) - - _, err = h.handleMessage("", cm) + _, err := h.handleMessage("", msg) require.NoError(t, err) block := &types.Block{ @@ -211,7 +205,7 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) { err = st.Block.AddBlock(block) require.NoError(t, err) - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", msg) require.NoError(t, err) require.Nil(t, out) @@ -244,11 +238,9 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_ValidSig(t *testing.T) { fm := gs.newCommitMessage(gs.head, round) fm.Vote = NewVote(testHash, uint32(round)) - cm, err := fm.ToConsensusMessage() - require.NoError(t, err) h := NewMessageHandler(gs, st.Block) - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", fm) require.NoError(t, err) require.Nil(t, out) @@ -270,11 +262,9 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_MinVoteError(t *testing.T gs.justification[round] = buildTestJustification(t, int(gs.state.threshold()), round, gs.state.setID, kr, precommit) fm := gs.newCommitMessage(gs.head, round) - cm, err := fm.ToConsensusMessage() - require.NoError(t, err) h := NewMessageHandler(gs, st.Block) - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", fm) require.EqualError(t, err, ErrMinVotesNotMet.Error()) require.Nil(t, out) } @@ -291,12 +281,10 @@ func TestMessageHandler_CommitMessage_WithCatchUpRequest(t *testing.T) { } fm := gs.newCommitMessage(gs.head, 77) - cm, err := fm.ToConsensusMessage() - require.NoError(t, err) gs.state.voters = gs.state.voters[:1] h := NewMessageHandler(gs, st.Block) - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", fm) require.NoError(t, err) require.NotNil(t, out) @@ -308,25 +296,19 @@ func TestMessageHandler_CommitMessage_WithCatchUpRequest(t *testing.T) { func TestMessageHandler_CatchUpRequest_InvalidRound(t *testing.T) { gs, st := newTestService(t) - req := newCatchUpRequest(77, 0) - cm, err := req.ToConsensusMessage() - require.NoError(t, err) h := NewMessageHandler(gs, st.Block) - _, err = h.handleMessage("", cm) + _, err := h.handleMessage("", req) require.Equal(t, ErrInvalidCatchUpRound, err) } func TestMessageHandler_CatchUpRequest_InvalidSetID(t *testing.T) { gs, st := newTestService(t) - req := newCatchUpRequest(1, 77) - cm, err := req.ToConsensusMessage() - require.NoError(t, err) h := NewMessageHandler(gs, st.Block) - _, err = h.handleMessage("", cm) + _, err := h.handleMessage("", req) require.Equal(t, ErrSetIDMismatch, err) } @@ -381,11 +363,9 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) { // create and handle request req := newCatchUpRequest(round, setID) - cm, err := req.ToConsensusMessage() - require.NoError(t, err) h := NewMessageHandler(gs, st.Block) - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", req) require.NoError(t, err) require.Equal(t, expected, out) } @@ -495,10 +475,7 @@ func TestMessageHandler_HandleCatchUpResponse(t *testing.T) { Number: uint32(round), } - cm, err := msg.ToConsensusMessage() - require.NoError(t, err) - - out, err := h.handleMessage("", cm) + out, err := h.handleMessage("", msg) require.NoError(t, err) require.Nil(t, out) require.Equal(t, round+1, gs.state.round) diff --git a/lib/grandpa/network.go b/lib/grandpa/network.go index 5fd7e31b1c..d52677db77 100644 --- a/lib/grandpa/network.go +++ b/lib/grandpa/network.go @@ -124,20 +124,39 @@ func (s *Service) decodeMessage(in []byte) (NotificationsMessage, error) { return msg, err } -func (s *Service) handleNetworkMessage(from peer.ID, msg NotificationsMessage) error { +func (s *Service) handleNetworkMessage(from peer.ID, msg NotificationsMessage) (bool, error) { + if msg == nil { + logger.Trace("received nil message, ignoring") + return false, nil + } + cm, ok := msg.(*network.ConsensusMessage) if !ok { - return ErrInvalidMessageType + return false, ErrInvalidMessageType + } + + if len(cm.Data) == 0 { + logger.Trace("received message with nil data, ignoring") + return false, nil } - resp, err := s.messageHandler.handleMessage(from, cm) + m, err := decodeMessage(cm) if err != nil { - return err + return false, err + } + + resp, err := s.messageHandler.handleMessage(from, m) + if err != nil { + return false, err } if resp != nil { s.network.SendMessage(resp) } - return nil + if m.Type() == neighbourType || m.Type() == catchUpResponseType { + return false, nil + } + + return true, nil } diff --git a/lib/grandpa/network_test.go b/lib/grandpa/network_test.go index 1bcd81ed43..466e92f953 100644 --- a/lib/grandpa/network_test.go +++ b/lib/grandpa/network_test.go @@ -62,12 +62,21 @@ func TestHandleNetworkMessage(t *testing.T) { h := NewMessageHandler(gs, st.Block) gs.messageHandler = h - err = gs.handleNetworkMessage(peer.ID(""), cm) + propagate, err := gs.handleNetworkMessage(peer.ID(""), cm) require.NoError(t, err) + require.True(t, propagate) select { case <-gs.network.(*testNetwork).out: case <-time.After(testTimeout): t.Fatal("expected to send message") } + + neighbourMsg := &NeighbourMessage{} + cm, err = neighbourMsg.ToConsensusMessage() + require.NoError(t, err) + + propagate, err = gs.handleNetworkMessage(peer.ID(""), cm) + require.NoError(t, err) + require.False(t, propagate) }