Skip to content

Commit

Permalink
feat(dot/network): add propagate return bool to messageHandler func t…
Browse files Browse the repository at this point in the history
…ype to determine whether to propagate message or not (ChainSafe#1555)
  • Loading branch information
noot committed May 5, 2021
1 parent d7b87f1 commit 0d6f488
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 66 deletions.
8 changes: 4 additions & 4 deletions dot/network/block_announce.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,14 @@ func (s *Service) validateBlockAnnounceHandshake(peer peer.ID, hs Handshake) err
// handleBlockAnnounceMessage handles BlockAnnounce messages
// if some more blocks are required to sync the announced block, the node will open a sync stream
// with its peer and send a BlockRequest message
func (s *Service) handleBlockAnnounceMessage(peer peer.ID, msg NotificationsMessage) error {
func (s *Service) handleBlockAnnounceMessage(peer peer.ID, msg NotificationsMessage) (propagate bool, err error) {
if an, ok := msg.(*BlockAnnounceMessage); ok {
s.syncQueue.handleBlockAnnounce(an, peer)
err := s.syncer.HandleBlockAnnounce(an)
err = s.syncer.HandleBlockAnnounce(an)
if err != nil {
return err
return false, err
}
}

return nil
return true, nil
}
3 changes: 2 additions & 1 deletion dot/network/block_announce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ func TestHandleBlockAnnounceMessage(t *testing.T) {
Number: big.NewInt(10),
}

err := s.handleBlockAnnounceMessage(peerID, msg)
propagate, err := s.handleBlockAnnounceMessage(peerID, msg)
require.NoError(t, err)
require.True(t, propagate)
}

func TestValidateBlockAnnounceHandshake(t *testing.T) {
Expand Down
11 changes: 3 additions & 8 deletions dot/network/notifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type (
MessageDecoder = func([]byte) (NotificationsMessage, error)

// NotificationsMessageHandler is called when a (non-handshake) message is received over a notifications stream.
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) error
NotificationsMessageHandler = func(peer peer.ID, msg NotificationsMessage) (propagate bool, err error)
)

type notificationsProtocol struct {
Expand Down Expand Up @@ -180,17 +180,12 @@ func (s *Service) createNotificationsMessageHandler(info *notificationsProtocol,
"peer", stream.Conn().RemotePeer(),
)

err := messageHandler(peer, msg)
propagate, err := messageHandler(peer, msg)
if err != nil {
return err
}

if s.noGossip {
return nil
}

// TODO: we don't want to rebroadcast neighbour messages, so ignore all consensus messages for now
if _, isConsensus := msg.(*ConsensusMessage); isConsensus {
if !propagate || s.noGossip {
return nil
}

Expand Down
6 changes: 3 additions & 3 deletions dot/network/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,11 @@ func decodeTransactionMessage(in []byte) (NotificationsMessage, error) {
return msg, err
}

func (s *Service) handleTransactionMessage(_ peer.ID, msg NotificationsMessage) error {
func (s *Service) handleTransactionMessage(_ peer.ID, msg NotificationsMessage) (bool, error) {
txMsg, ok := msg.(*TransactionMessage)
if !ok {
return errors.New("invalid transaction type")
return false, errors.New("invalid transaction type")
}

return s.transactionHandler.HandleTransactionMessage(txMsg)
return true, s.transactionHandler.HandleTransactionMessage(txMsg)
}
12 changes: 1 addition & 11 deletions lib/grandpa/message_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,7 @@ func NewMessageHandler(grandpa *Service, blockState BlockState) *MessageHandler
// HandleMessage handles a GRANDPA consensus message
// if it is a CommitMessage, it updates the BlockState
// if it is a VoteMessage, it sends it to the GRANDPA service
func (h *MessageHandler) handleMessage(from peer.ID, msg *ConsensusMessage) (network.NotificationsMessage, error) {
if msg == nil || len(msg.Data) == 0 {
logger.Trace("received nil message or message with nil data")
return nil, nil
}

m, err := decodeMessage(msg)
if err != nil {
return nil, err
}

func (h *MessageHandler) handleMessage(from peer.ID, m GrandpaMessage) (network.NotificationsMessage, error) {
logger.Trace("handling grandpa message", "msg", m)

switch m.Type() {
Expand Down
43 changes: 10 additions & 33 deletions lib/grandpa/message_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,8 @@ func TestMessageHandler_VoteMessage(t *testing.T) {
vm, err := gs.createVoteMessage(v, precommit, gs.keypair)
require.NoError(t, err)

cm, err := vm.ToConsensusMessage()
require.NoError(t, err)

h := NewMessageHandler(gs, st.Block)
out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", vm)
require.NoError(t, err)
require.Nil(t, out)

Expand All @@ -194,10 +191,7 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) {
Number: 1,
}

cm, err := msg.ToConsensusMessage()
require.NoError(t, err)

_, err = h.handleMessage("", cm)
_, err := h.handleMessage("", msg)
require.NoError(t, err)

block := &types.Block{
Expand All @@ -211,7 +205,7 @@ func TestMessageHandler_NeighbourMessage(t *testing.T) {
err = st.Block.AddBlock(block)
require.NoError(t, err)

out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", msg)
require.NoError(t, err)
require.Nil(t, out)

Expand Down Expand Up @@ -244,11 +238,9 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_ValidSig(t *testing.T) {

fm := gs.newCommitMessage(gs.head, round)
fm.Vote = NewVote(testHash, uint32(round))
cm, err := fm.ToConsensusMessage()
require.NoError(t, err)

h := NewMessageHandler(gs, st.Block)
out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", fm)
require.NoError(t, err)
require.Nil(t, out)

Expand All @@ -270,11 +262,9 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_MinVoteError(t *testing.T
gs.justification[round] = buildTestJustification(t, int(gs.state.threshold()), round, gs.state.setID, kr, precommit)

fm := gs.newCommitMessage(gs.head, round)
cm, err := fm.ToConsensusMessage()
require.NoError(t, err)

h := NewMessageHandler(gs, st.Block)
out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", fm)
require.EqualError(t, err, ErrMinVotesNotMet.Error())
require.Nil(t, out)
}
Expand All @@ -291,12 +281,10 @@ func TestMessageHandler_CommitMessage_WithCatchUpRequest(t *testing.T) {
}

fm := gs.newCommitMessage(gs.head, 77)
cm, err := fm.ToConsensusMessage()
require.NoError(t, err)
gs.state.voters = gs.state.voters[:1]

h := NewMessageHandler(gs, st.Block)
out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", fm)
require.NoError(t, err)
require.NotNil(t, out)

Expand All @@ -308,25 +296,19 @@ func TestMessageHandler_CommitMessage_WithCatchUpRequest(t *testing.T) {

func TestMessageHandler_CatchUpRequest_InvalidRound(t *testing.T) {
gs, st := newTestService(t)

req := newCatchUpRequest(77, 0)
cm, err := req.ToConsensusMessage()
require.NoError(t, err)

h := NewMessageHandler(gs, st.Block)
_, err = h.handleMessage("", cm)
_, err := h.handleMessage("", req)
require.Equal(t, ErrInvalidCatchUpRound, err)
}

func TestMessageHandler_CatchUpRequest_InvalidSetID(t *testing.T) {
gs, st := newTestService(t)

req := newCatchUpRequest(1, 77)
cm, err := req.ToConsensusMessage()
require.NoError(t, err)

h := NewMessageHandler(gs, st.Block)
_, err = h.handleMessage("", cm)
_, err := h.handleMessage("", req)
require.Equal(t, ErrSetIDMismatch, err)
}

Expand Down Expand Up @@ -381,11 +363,9 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) {

// create and handle request
req := newCatchUpRequest(round, setID)
cm, err := req.ToConsensusMessage()
require.NoError(t, err)

h := NewMessageHandler(gs, st.Block)
out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", req)
require.NoError(t, err)
require.Equal(t, expected, out)
}
Expand Down Expand Up @@ -495,10 +475,7 @@ func TestMessageHandler_HandleCatchUpResponse(t *testing.T) {
Number: uint32(round),
}

cm, err := msg.ToConsensusMessage()
require.NoError(t, err)

out, err := h.handleMessage("", cm)
out, err := h.handleMessage("", msg)
require.NoError(t, err)
require.Nil(t, out)
require.Equal(t, round+1, gs.state.round)
Expand Down
29 changes: 24 additions & 5 deletions lib/grandpa/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,39 @@ func (s *Service) decodeMessage(in []byte) (NotificationsMessage, error) {
return msg, err
}

func (s *Service) handleNetworkMessage(from peer.ID, msg NotificationsMessage) error {
func (s *Service) handleNetworkMessage(from peer.ID, msg NotificationsMessage) (bool, error) {
if msg == nil {
logger.Trace("received nil message, ignoring")
return false, nil
}

cm, ok := msg.(*network.ConsensusMessage)
if !ok {
return ErrInvalidMessageType
return false, ErrInvalidMessageType
}

if len(cm.Data) == 0 {
logger.Trace("received message with nil data, ignoring")
return false, nil
}

resp, err := s.messageHandler.handleMessage(from, cm)
m, err := decodeMessage(cm)
if err != nil {
return err
return false, err
}

resp, err := s.messageHandler.handleMessage(from, m)
if err != nil {
return false, err
}

if resp != nil {
s.network.SendMessage(resp)
}

return nil
if m.Type() == neighbourType || m.Type() == catchUpResponseType {
return false, nil
}

return true, nil
}
11 changes: 10 additions & 1 deletion lib/grandpa/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,21 @@ func TestHandleNetworkMessage(t *testing.T) {
h := NewMessageHandler(gs, st.Block)
gs.messageHandler = h

err = gs.handleNetworkMessage(peer.ID(""), cm)
propagate, err := gs.handleNetworkMessage(peer.ID(""), cm)
require.NoError(t, err)
require.True(t, propagate)

select {
case <-gs.network.(*testNetwork).out:
case <-time.After(testTimeout):
t.Fatal("expected to send message")
}

neighbourMsg := &NeighbourMessage{}
cm, err = neighbourMsg.ToConsensusMessage()
require.NoError(t, err)

propagate, err = gs.handleNetworkMessage(peer.ID(""), cm)
require.NoError(t, err)
require.False(t, propagate)
}

0 comments on commit 0d6f488

Please sign in to comment.