diff --git a/dot/core/interface.go b/dot/core/interface.go index a27db00227..c9760ae0ce 100644 --- a/dot/core/interface.go +++ b/dot/core/interface.go @@ -41,7 +41,6 @@ type BlockState interface { GetSlotForBlock(common.Hash) (uint64, error) GetFinalisedHeader(uint64, uint64) (*types.Header, error) GetFinalisedHash(uint64, uint64) (common.Hash, error) - SetFinalisedHash(common.Hash, uint64, uint64) error RegisterImportedChannel(ch chan<- *types.Block) (byte, error) UnregisterImportedChannel(id byte) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) diff --git a/dot/network/mock_block_state.go b/dot/network/mock_block_state.go index f93de890dc..54a9af2a6f 100644 --- a/dot/network/mock_block_state.go +++ b/dot/network/mock_block_state.go @@ -11,13 +11,13 @@ import ( types "github.com/ChainSafe/gossamer/dot/types" ) -// MockBlockState is an autogenerated mock type for the BlockState type -type MockBlockState struct { +// mockBlockState is an autogenerated mock type for the BlockState type +type mockBlockState struct { mock.Mock } // BestBlockHeader provides a mock function with given fields: -func (_m *MockBlockState) BestBlockHeader() (*types.Header, error) { +func (_m *mockBlockState) BestBlockHeader() (*types.Header, error) { ret := _m.Called() var r0 *types.Header @@ -40,7 +40,7 @@ func (_m *MockBlockState) BestBlockHeader() (*types.Header, error) { } // BestBlockNumber provides a mock function with given fields: -func (_m *MockBlockState) BestBlockNumber() (*big.Int, error) { +func (_m *mockBlockState) BestBlockNumber() (*big.Int, error) { ret := _m.Called() var r0 *big.Int @@ -63,7 +63,7 @@ func (_m *MockBlockState) BestBlockNumber() (*big.Int, error) { } // GenesisHash provides a mock function with given fields: -func (_m *MockBlockState) GenesisHash() common.Hash { +func (_m *mockBlockState) GenesisHash() common.Hash { ret := _m.Called() var r0 common.Hash @@ -78,22 +78,22 @@ func (_m *MockBlockState) GenesisHash() common.Hash { return r0 } -// GetFinalisedHeader provides a mock function with given fields: round, setID -func (_m *MockBlockState) GetFinalisedHeader(round uint64, setID uint64) (*types.Header, error) { - ret := _m.Called(round, setID) +// GetHashByNumber provides a mock function with given fields: num +func (_m *mockBlockState) GetHashByNumber(num *big.Int) (common.Hash, error) { + ret := _m.Called(num) - var r0 *types.Header - if rf, ok := ret.Get(0).(func(uint64, uint64) *types.Header); ok { - r0 = rf(round, setID) + var r0 common.Hash + if rf, ok := ret.Get(0).(func(*big.Int) common.Hash); ok { + r0 = rf(num) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*types.Header) + r0 = ret.Get(0).(common.Hash) } } var r1 error - if rf, ok := ret.Get(1).(func(uint64, uint64) error); ok { - r1 = rf(round, setID) + if rf, ok := ret.Get(1).(func(*big.Int) error); ok { + r1 = rf(num) } else { r1 = ret.Error(1) } @@ -101,22 +101,22 @@ func (_m *MockBlockState) GetFinalisedHeader(round uint64, setID uint64) (*types return r0, r1 } -// GetHashByNumber provides a mock function with given fields: num -func (_m *MockBlockState) GetHashByNumber(num *big.Int) (common.Hash, error) { - ret := _m.Called(num) +// GetHighestFinalisedHeader provides a mock function with given fields: +func (_m *mockBlockState) GetHighestFinalisedHeader() (*types.Header, error) { + ret := _m.Called() - var r0 common.Hash - if rf, ok := ret.Get(0).(func(*big.Int) common.Hash); ok { - r0 = rf(num) + var r0 *types.Header + if rf, ok := ret.Get(0).(func() *types.Header); ok { + r0 = rf() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(common.Hash) + r0 = ret.Get(0).(*types.Header) } } var r1 error - if rf, ok := ret.Get(1).(func(*big.Int) error); ok { - r1 = rf(num) + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() } else { r1 = ret.Error(1) } @@ -125,7 +125,7 @@ func (_m *MockBlockState) GetHashByNumber(num *big.Int) (common.Hash, error) { } // HasBlockBody provides a mock function with given fields: _a0 -func (_m *MockBlockState) HasBlockBody(_a0 common.Hash) (bool, error) { +func (_m *mockBlockState) HasBlockBody(_a0 common.Hash) (bool, error) { ret := _m.Called(_a0) var r0 bool diff --git a/dot/network/service.go b/dot/network/service.go index fa89abf502..c85a682bff 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -343,7 +343,7 @@ func (s *Service) sentBlockIntervalTelemetry() { } bestHash := best.Hash() - finalized, err := s.blockState.GetFinalisedHeader(0, 0) //nolint + finalized, err := s.blockState.GetHighestFinalisedHeader() //nolint if err != nil { continue } @@ -519,16 +519,21 @@ func (s *Service) GossipMessage(msg NotificationsMessage) { func (s *Service) SendMessage(to peer.ID, msg NotificationsMessage) error { s.notificationsMu.Lock() defer s.notificationsMu.Unlock() + for msgID, prtl := range s.notificationsProtocols { - if msg.Type() != msgID || prtl == nil { + if msg.Type() != msgID { continue } + hs, err := prtl.getHandshake() if err != nil { return err } + s.sendData(to, hs, prtl, msg) + return nil } + return errors.New("message not supported by any notifications protocol") } diff --git a/dot/network/state.go b/dot/network/state.go index ea471ac354..e71fd1f229 100644 --- a/dot/network/state.go +++ b/dot/network/state.go @@ -29,7 +29,7 @@ type BlockState interface { BestBlockNumber() (*big.Int, error) GenesisHash() common.Hash HasBlockBody(common.Hash) (bool, error) - GetFinalisedHeader(round, setID uint64) (*types.Header, error) + GetHighestFinalisedHeader() (*types.Header, error) GetHashByNumber(num *big.Int) (common.Hash, error) } diff --git a/dot/network/sync.go b/dot/network/sync.go index 79fcaa90fd..9e93311952 100644 --- a/dot/network/sync.go +++ b/dot/network/sync.go @@ -351,7 +351,7 @@ func (q *syncQueue) benchmark() { goal := atomic.LoadInt64(&q.goal) if before.Number.Int64() >= goal { - finalised, err := q.s.blockState.GetFinalisedHeader(0, 0) //nolint + finalised, err := q.s.blockState.GetHighestFinalisedHeader() //nolint if err != nil { continue } @@ -767,7 +767,7 @@ func (q *syncQueue) handleBlockJustification(data []*types.BlockData) { } func (q *syncQueue) handleBlockData(data []*types.BlockData) { - finalised, err := q.s.blockState.GetFinalisedHeader(0, 0) + finalised, err := q.s.blockState.GetHighestFinalisedHeader() if err != nil { panic(err) // this should never happen } @@ -815,7 +815,7 @@ func (q *syncQueue) handleBlockDataFailure(idx int, err error, data []*types.Blo logger.Warn("failed to handle block data", "failed on block", q.currStart+int64(idx), "error", err) if errors.Is(err, chaindb.ErrKeyNotFound) || errors.Is(err, blocktree.ErrParentNotFound) { - finalised, err := q.s.blockState.GetFinalisedHeader(0, 0) + finalised, err := q.s.blockState.GetHighestFinalisedHeader() if err != nil { panic(err) } diff --git a/dot/network/test_helpers.go b/dot/network/test_helpers.go index 78e1d513bc..f1912e25ff 100644 --- a/dot/network/test_helpers.go +++ b/dot/network/test_helpers.go @@ -15,7 +15,7 @@ import ( ) // NewMockBlockState create and return a network BlockState interface mock -func NewMockBlockState(n *big.Int) *MockBlockState { +func NewMockBlockState(n *big.Int) *mockBlockState { parentHash, _ := common.HexToHash("0x4545454545454545454545454545454545454545454545454545454545454545") stateRoot, _ := common.HexToHash("0xb3266de137d20a5d0ff3a6401eb57127525fd9b2693701f0bf5a8a853fa3ebe0") extrinsicsRoot, _ := common.HexToHash("0x03170a2e7597b7b7e3d84c05391d139a62b157e78786d8c082f29dcf4c111314") @@ -31,13 +31,12 @@ func NewMockBlockState(n *big.Int) *MockBlockState { Digest: types.Digest{}, } - m := new(MockBlockState) + m := new(mockBlockState) m.On("BestBlockHeader").Return(header, nil) - + m.On("GetHighestFinalisedHeader").Return(header, nil) m.On("GenesisHash").Return(common.NewHash([]byte{})) m.On("BestBlockNumber").Return(big.NewInt(1), nil) m.On("HasBlockBody", mock.AnythingOfType("common.Hash")).Return(false, nil) - m.On("GetFinalisedHeader", mock.AnythingOfType("uint64"), mock.AnythingOfType("uint64")).Return(header, nil) m.On("GetHashByNumber", mock.AnythingOfType("*big.Int")).Return(common.Hash{}, nil) return m diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 5dc1e3328c..833ad231cc 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -29,8 +29,9 @@ type BlockAPI interface { GetHeader(hash common.Hash) (*types.Header, error) BestBlockHash() common.Hash GetBlockByHash(hash common.Hash) (*types.Block, error) - GetBlockHash(blockNumber *big.Int) (*common.Hash, error) + GetBlockHash(blockNumber *big.Int) (common.Hash, error) GetFinalisedHash(uint64, uint64) (common.Hash, error) + GetHighestFinalisedHash() (common.Hash, error) HasJustification(hash common.Hash) (bool, error) GetJustification(hash common.Hash) ([]byte, error) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) diff --git a/dot/rpc/modules/api_mocks.go b/dot/rpc/modules/api_mocks.go index 46bbfdf0db..b7ebcb21df 100644 --- a/dot/rpc/modules/api_mocks.go +++ b/dot/rpc/modules/api_mocks.go @@ -20,13 +20,14 @@ func NewMockStorageAPI() *modulesmocks.MockStorageAPI { } // NewMockBlockAPI creates and return an rpc BlockAPI interface mock -func NewMockBlockAPI() *modulesmocks.MockBlockAPI { - m := new(modulesmocks.MockBlockAPI) +func NewMockBlockAPI() *modulesmocks.BlockAPI { + m := new(modulesmocks.BlockAPI) m.On("GetHeader", mock.AnythingOfType("common.Hash")).Return(nil, nil) m.On("BestBlockHash").Return(common.Hash{}) m.On("GetBlockByHash", mock.AnythingOfType("common.Hash")).Return(nil, nil) m.On("GetBlockHash", mock.AnythingOfType("*big.Int")).Return(nil, nil) m.On("GetFinalisedHash", mock.AnythingOfType("uint64"), mock.AnythingOfType("uint64")).Return(common.Hash{}, nil) + m.On("GetHighestFinalisedHash").Return(common.Hash{}, nil) m.On("RegisterImportedChannel", mock.AnythingOfType("chan<- *types.Block")).Return(byte(0), nil) m.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) m.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")).Return(byte(0), nil) diff --git a/dot/rpc/modules/chain.go b/dot/rpc/modules/chain.go index b36713bb84..17c12ec2e6 100644 --- a/dot/rpc/modules/chain.go +++ b/dot/rpc/modules/chain.go @@ -135,7 +135,7 @@ func (cm *ChainModule) GetHead(r *http.Request, req *ChainBlockNumberRequest, re // GetFinalizedHead returns the most recently finalised block hash func (cm *ChainModule) GetFinalizedHead(r *http.Request, req *EmptyRequest, res *ChainHashResponse) error { - h, err := cm.blockAPI.GetFinalisedHash(0, 0) + h, err := cm.blockAPI.GetHighestFinalisedHash() if err != nil { return err } diff --git a/dot/rpc/modules/mocks/block_api.go b/dot/rpc/modules/mocks/BlockAPI.go similarity index 72% rename from dot/rpc/modules/mocks/block_api.go rename to dot/rpc/modules/mocks/BlockAPI.go index 1bb565bfa0..982c91ccc7 100644 --- a/dot/rpc/modules/mocks/block_api.go +++ b/dot/rpc/modules/mocks/BlockAPI.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.8.0. DO NOT EDIT. +// Code generated by mockery v0.0.0-dev. DO NOT EDIT. package mocks @@ -11,13 +11,13 @@ import ( types "github.com/ChainSafe/gossamer/dot/types" ) -// MockBlockAPI is an autogenerated mock type for the BlockAPI type -type MockBlockAPI struct { +// BlockAPI is an autogenerated mock type for the BlockAPI type +type BlockAPI struct { mock.Mock } // BestBlockHash provides a mock function with given fields: -func (_m *MockBlockAPI) BestBlockHash() common.Hash { +func (_m *BlockAPI) BestBlockHash() common.Hash { ret := _m.Called() var r0 common.Hash @@ -33,7 +33,7 @@ func (_m *MockBlockAPI) BestBlockHash() common.Hash { } // GetBlockByHash provides a mock function with given fields: hash -func (_m *MockBlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { +func (_m *BlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { ret := _m.Called(hash) var r0 *types.Block @@ -56,15 +56,15 @@ func (_m *MockBlockAPI) GetBlockByHash(hash common.Hash) (*types.Block, error) { } // GetBlockHash provides a mock function with given fields: blockNumber -func (_m *MockBlockAPI) GetBlockHash(blockNumber *big.Int) (*common.Hash, error) { +func (_m *BlockAPI) GetBlockHash(blockNumber *big.Int) (common.Hash, error) { ret := _m.Called(blockNumber) - var r0 *common.Hash - if rf, ok := ret.Get(0).(func(*big.Int) *common.Hash); ok { + var r0 common.Hash + if rf, ok := ret.Get(0).(func(*big.Int) common.Hash); ok { r0 = rf(blockNumber) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*common.Hash) + r0 = ret.Get(0).(common.Hash) } } @@ -79,7 +79,7 @@ func (_m *MockBlockAPI) GetBlockHash(blockNumber *big.Int) (*common.Hash, error) } // GetFinalisedHash provides a mock function with given fields: _a0, _a1 -func (_m *MockBlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, error) { +func (_m *BlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, error) { ret := _m.Called(_a0, _a1) var r0 common.Hash @@ -102,7 +102,7 @@ func (_m *MockBlockAPI) GetFinalisedHash(_a0 uint64, _a1 uint64) (common.Hash, e } // GetHeader provides a mock function with given fields: hash -func (_m *MockBlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { +func (_m *BlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { ret := _m.Called(hash) var r0 *types.Header @@ -124,8 +124,31 @@ func (_m *MockBlockAPI) GetHeader(hash common.Hash) (*types.Header, error) { return r0, r1 } +// GetHighestFinalisedHash provides a mock function with given fields: +func (_m *BlockAPI) GetHighestFinalisedHash() (common.Hash, error) { + ret := _m.Called() + + var r0 common.Hash + if rf, ok := ret.Get(0).(func() common.Hash); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(common.Hash) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetJustification provides a mock function with given fields: hash -func (_m *MockBlockAPI) GetJustification(hash common.Hash) ([]byte, error) { +func (_m *BlockAPI) GetJustification(hash common.Hash) ([]byte, error) { ret := _m.Called(hash) var r0 []byte @@ -148,7 +171,7 @@ func (_m *MockBlockAPI) GetJustification(hash common.Hash) ([]byte, error) { } // HasJustification provides a mock function with given fields: hash -func (_m *MockBlockAPI) HasJustification(hash common.Hash) (bool, error) { +func (_m *BlockAPI) HasJustification(hash common.Hash) (bool, error) { ret := _m.Called(hash) var r0 bool @@ -169,7 +192,7 @@ func (_m *MockBlockAPI) HasJustification(hash common.Hash) (bool, error) { } // RegisterFinalizedChannel provides a mock function with given fields: ch -func (_m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { +func (_m *BlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationInfo) (byte, error) { ret := _m.Called(ch) var r0 byte @@ -190,7 +213,7 @@ func (_m *MockBlockAPI) RegisterFinalizedChannel(ch chan<- *types.FinalisationIn } // RegisterImportedChannel provides a mock function with given fields: ch -func (_m *MockBlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) { +func (_m *BlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, error) { ret := _m.Called(ch) var r0 byte @@ -211,7 +234,7 @@ func (_m *MockBlockAPI) RegisterImportedChannel(ch chan<- *types.Block) (byte, e } // SubChain provides a mock function with given fields: start, end -func (_m *MockBlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.Hash, error) { +func (_m *BlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.Hash, error) { ret := _m.Called(start, end) var r0 []common.Hash @@ -234,11 +257,11 @@ func (_m *MockBlockAPI) SubChain(start common.Hash, end common.Hash) ([]common.H } // UnregisterFinalisedChannel provides a mock function with given fields: id -func (_m *MockBlockAPI) UnregisterFinalisedChannel(id byte) { +func (_m *BlockAPI) UnregisterFinalisedChannel(id byte) { _m.Called(id) } // UnregisterImportedChannel provides a mock function with given fields: id -func (_m *MockBlockAPI) UnregisterImportedChannel(id byte) { +func (_m *BlockAPI) UnregisterImportedChannel(id byte) { _m.Called(id) } diff --git a/dot/rpc/modules/state_test.go b/dot/rpc/modules/state_test.go index 39fd5ede14..ca8661e250 100644 --- a/dot/rpc/modules/state_test.go +++ b/dot/rpc/modules/state_test.go @@ -453,5 +453,5 @@ func setupStateModule(t *testing.T) (*StateModule, *common.Hash, *common.Hash) { hash, _ := chain.Block.GetBlockHash(big.NewInt(2)) core := newCoreService(t, chain) - return NewStateModule(net, chain.Storage, core), hash, &sr1 + return NewStateModule(net, chain.Storage, core), &hash, &sr1 } diff --git a/dot/rpc/modules/system_test.go b/dot/rpc/modules/system_test.go index d7819b0261..7519925fd7 100644 --- a/dot/rpc/modules/system_test.go +++ b/dot/rpc/modules/system_test.go @@ -363,7 +363,7 @@ func TestSyncState(t *testing.T) { Number: big.NewInt(int64(49)), } - blockapiMock := new(mocks.MockBlockAPI) + blockapiMock := new(mocks.BlockAPI) blockapiMock.On("BestBlockHash").Return(fakeCommonHash) blockapiMock.On("GetHeader", fakeCommonHash).Return(fakeHeader, nil).Once() diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index 1fb1400267..7be7580550 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -83,10 +83,10 @@ func TestBlockListener_Listen(t *testing.T) { wsconn, ws, cancel := setupWSConn(t) defer cancel() - mockBlockAPI := new(mocks.MockBlockAPI) - mockBlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI := new(mocks.BlockAPI) + BlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) - wsconn.BlockAPI = mockBlockAPI + wsconn.BlockAPI = BlockAPI notifyChan := make(chan *types.Block) bl := BlockListener{ @@ -104,7 +104,7 @@ func TestBlockListener_Listen(t *testing.T) { defer func() { require.NoError(t, bl.Stop()) time.Sleep(time.Millisecond * 10) - mockBlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) }() notifyChan <- block @@ -130,10 +130,10 @@ func TestBlockFinalizedListener_Listen(t *testing.T) { wsconn, ws, cancel := setupWSConn(t) defer cancel() - mockBlockAPI := new(mocks.MockBlockAPI) - mockBlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI := new(mocks.BlockAPI) + BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) - wsconn.BlockAPI = mockBlockAPI + wsconn.BlockAPI = BlockAPI notifyChan := make(chan *types.FinalisationInfo) bfl := BlockFinalizedListener{ @@ -150,7 +150,7 @@ func TestBlockFinalizedListener_Listen(t *testing.T) { defer func() { require.NoError(t, bfl.Stop()) time.Sleep(time.Millisecond * 10) - mockBlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) }() notifyChan <- &types.FinalisationInfo{ @@ -181,11 +181,11 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { notifyImportedChan := make(chan *types.Block, 100) notifyFinalizedChan := make(chan *types.FinalisationInfo, 100) - mockBlockAPI := new(mocks.MockBlockAPI) - mockBlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) - mockBlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI := new(mocks.BlockAPI) + BlockAPI.On("UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) - wsconn.BlockAPI = mockBlockAPI + wsconn.BlockAPI = BlockAPI esl := ExtrinsicSubmitListener{ importedChan: notifyImportedChan, @@ -212,8 +212,8 @@ func TestExtrinsicSubmitListener_Listen(t *testing.T) { require.NoError(t, esl.Stop()) time.Sleep(time.Millisecond * 10) - mockBlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) - mockBlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "UnregisterImportedChannel", mock.AnythingOfType("uint8")) + BlockAPI.AssertCalled(t, "UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) }() notifyImportedChan <- block @@ -256,7 +256,7 @@ func TestGrandpaJustification_Listen(t *testing.T) { mockedJustBytes, err := mockedJust.Encode() require.NoError(t, err) - blockStateMock := new(mocks.MockBlockAPI) + blockStateMock := new(mocks.BlockAPI) blockStateMock.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) blockStateMock.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) wsconn.BlockAPI = blockStateMock diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 3e72d83808..9f5d9d74af 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -266,18 +266,18 @@ func TestWSConn_HandleComm(t *testing.T) { mockedJustBytes, err := mockedJust.Encode() require.NoError(t, err) - mockBlockAPI := new(modulesmocks.MockBlockAPI) - mockBlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). + BlockAPI := new(modulesmocks.BlockAPI) + BlockAPI.On("RegisterFinalizedChannel", mock.AnythingOfType("chan<- *types.FinalisationInfo")). Run(func(args mock.Arguments) { ch := args.Get(0).(chan<- *types.FinalisationInfo) fCh = ch }). Return(uint8(4), nil) - mockBlockAPI.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) - mockBlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) + BlockAPI.On("GetJustification", mock.AnythingOfType("common.Hash")).Return(mockedJustBytes, nil) + BlockAPI.On("UnregisterFinalisedChannel", mock.AnythingOfType("uint8")) - wsconn.BlockAPI = mockBlockAPI + wsconn.BlockAPI = BlockAPI listener, err := wsconn.initGrandpaJustificationListener(0, nil) require.NoError(t, err) require.NotNil(t, listener) diff --git a/dot/state/block.go b/dot/state/block.go index 1d53e88f8a..1d0aa41bb5 100644 --- a/dot/state/block.go +++ b/dot/state/block.go @@ -88,9 +88,7 @@ func NewBlockState(db chaindb.Database, bt *blocktree.BlockTree) (*BlockState, e } bs.importedBytePool = common.NewBytePool256() - bs.finalisedBytePool = common.NewBytePool256() - return bs, nil } @@ -123,15 +121,17 @@ func NewBlockStateFromGenesis(db chaindb.Database, header *types.Header) (*Block bs.genesisHash = header.Hash() + if err := bs.db.Put(highestRoundAndSetIDKey, roundAndSetIDToBytes(0, 0)); err != nil { + return nil, err + } + // set the latest finalised head to the genesis header if err := bs.SetFinalisedHash(bs.genesisHash, 0, 0); err != nil { return nil, err } bs.importedBytePool = common.NewBytePool256() - bs.finalisedBytePool = common.NewBytePool256() - return bs, nil } @@ -320,14 +320,13 @@ func (bs *BlockState) GetBlockByNumber(num *big.Int) (*types.Block, error) { } // GetBlockHash returns block hash for a given blockNumber -func (bs *BlockState) GetBlockHash(blockNumber *big.Int) (*common.Hash, error) { - // First retrieve the block hash in a byte array based on the block number from the database +func (bs *BlockState) GetBlockHash(blockNumber *big.Int) (common.Hash, error) { byteHash, err := bs.db.Get(headerHashKey(blockNumber.Uint64())) if err != nil { - return nil, fmt.Errorf("cannot get block %d: %w", blockNumber, err) + return common.Hash{}, fmt.Errorf("cannot get block %d: %w", blockNumber, err) } - hash := common.NewHash(byteHash) - return &hash, nil + + return common.NewHash(byteHash), nil } // SetHeader will set the header into DB diff --git a/dot/state/block_finalisation.go b/dot/state/block_finalisation.go index 9dab50d68a..46c751703f 100644 --- a/dot/state/block_finalisation.go +++ b/dot/state/block_finalisation.go @@ -17,6 +17,7 @@ package state import ( + "encoding/binary" "fmt" "math/big" @@ -24,9 +25,11 @@ import ( "github.com/ChainSafe/gossamer/lib/common" ) +var highestRoundAndSetIDKey = []byte("hrs") + // finalisedHashKey = FinalizedBlockHashKey + round + setID (LE encoded) func finalisedHashKey(round, setID uint64) []byte { - return append(common.FinalizedBlockHashKey, roundSetIDKey(round, setID)...) + return append(common.FinalizedBlockHashKey, roundAndSetIDToBytes(round, setID)...) } // HasFinalisedBlock returns true if there is a finalised block for a given round and setID, false otherwise @@ -46,6 +49,9 @@ func (bs *BlockState) NumberIsFinalised(num *big.Int) (bool, error) { // GetFinalisedHeader returns the finalised block header by round and setID func (bs *BlockState) GetFinalisedHeader(round, setID uint64) (*types.Header, error) { + bs.Lock() + defer bs.Unlock() + h, err := bs.GetFinalisedHash(round, setID) if err != nil { return nil, err @@ -69,8 +75,58 @@ func (bs *BlockState) GetFinalisedHash(round, setID uint64) (common.Hash, error) return common.NewHash(h), nil } +func (bs *BlockState) setHighestRoundAndSetID(round, setID uint64) error { + currRound, currSetID, err := bs.GetHighestRoundAndSetID() + if err != nil { + return err + } + + // higher setID takes precedence over round + if setID < currSetID || setID == currSetID && round <= currRound { + return nil + } + + return bs.db.Put(highestRoundAndSetIDKey, roundAndSetIDToBytes(round, setID)) +} + +// GetHighestRoundAndSetID gets the highest round and setID that have been finalised +func (bs *BlockState) GetHighestRoundAndSetID() (uint64, uint64, error) { + b, err := bs.db.Get(highestRoundAndSetIDKey) + if err != nil { + return 0, 0, err + } + + round := binary.LittleEndian.Uint64(b[:8]) + setID := binary.LittleEndian.Uint64(b[8:16]) + return round, setID, nil +} + +// GetHighestFinalisedHash returns the highest finalised block hash +func (bs *BlockState) GetHighestFinalisedHash() (common.Hash, error) { + round, setID, err := bs.GetHighestRoundAndSetID() + if err != nil { + return common.Hash{}, err + } + + return bs.GetFinalisedHash(round, setID) +} + +// GetHighestFinalisedHeader returns the highest finalised block header +func (bs *BlockState) GetHighestFinalisedHeader() (*types.Header, error) { + h, err := bs.GetHighestFinalisedHash() + if err != nil { + return nil, err + } + + header, err := bs.GetHeader(h) + if err != nil { + return nil, err + } + + return header, nil +} + // SetFinalisedHash sets the latest finalised block header -// Note that using round=0 and setID=0 would refer to the latest finalised hash func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) error { bs.Lock() defer bs.Unlock() @@ -94,23 +150,32 @@ func (bs *BlockState) SetFinalisedHash(hash common.Hash, round, setID uint64) er } pruned := bs.bt.Prune(hash) - for _, rem := range pruned { - header, err := bs.GetHeader(rem) + for _, hash := range pruned { + header, err := bs.GetHeader(hash) if err != nil { - return err + logger.Debug("failed to get pruned header", "hash", hash, "error", err) + continue } - err = bs.DeleteBlock(rem) + err = bs.DeleteBlock(hash) if err != nil { - return err + logger.Debug("failed to delete block", "hash", hash, "error", err) + continue } - logger.Trace("pruned block", "hash", rem, "number", header.Number) - bs.pruneKeyCh <- header + logger.Trace("pruned block", "hash", hash, "number", header.Number) + go func(header *types.Header) { + bs.pruneKeyCh <- header + }(header) } bs.lastFinalised = hash - return bs.db.Put(finalisedHashKey(round, setID), hash[:]) + + if err := bs.db.Put(finalisedHashKey(round, setID), hash[:]); err != nil { + return err + } + + return bs.setHighestRoundAndSetID(round, setID) } func (bs *BlockState) setFirstSlotOnFinalisation() error { diff --git a/dot/state/block_finalisation_test.go b/dot/state/block_finalisation_test.go new file mode 100644 index 0000000000..9d048354d7 --- /dev/null +++ b/dot/state/block_finalisation_test.go @@ -0,0 +1,71 @@ +// Copyright 2019 ChainSafe Systems (ON) Corp. +// This file is part of gossamer. +// +// The gossamer library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The gossamer library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the gossamer library. If not, see . + +package state + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHighestRoundAndSetID(t *testing.T) { + bs := newTestBlockState(t, testGenesisHeader) + round, setID, err := bs.GetHighestRoundAndSetID() + require.NoError(t, err) + require.Equal(t, uint64(0), round) + require.Equal(t, uint64(0), setID) + + err = bs.setHighestRoundAndSetID(1, 0) + require.NoError(t, err) + + round, setID, err = bs.GetHighestRoundAndSetID() + require.NoError(t, err) + require.Equal(t, uint64(1), round) + require.Equal(t, uint64(0), setID) + + err = bs.setHighestRoundAndSetID(10, 0) + require.NoError(t, err) + + round, setID, err = bs.GetHighestRoundAndSetID() + require.NoError(t, err) + require.Equal(t, uint64(10), round) + require.Equal(t, uint64(0), setID) + + err = bs.setHighestRoundAndSetID(9, 0) + require.NoError(t, err) + + round, setID, err = bs.GetHighestRoundAndSetID() + require.NoError(t, err) + require.Equal(t, uint64(10), round) + require.Equal(t, uint64(0), setID) + + err = bs.setHighestRoundAndSetID(0, 1) + require.NoError(t, err) + + round, setID, err = bs.GetHighestRoundAndSetID() + require.NoError(t, err) + require.Equal(t, uint64(0), round) + require.Equal(t, uint64(1), setID) + + err = bs.setHighestRoundAndSetID(100000, 0) + require.NoError(t, err) + + round, setID, err = bs.GetHighestRoundAndSetID() + require.NoError(t, err) + require.Equal(t, uint64(0), round) + require.Equal(t, uint64(1), setID) +} diff --git a/dot/state/grandpa.go b/dot/state/grandpa.go index 9f66064842..af1585f538 100644 --- a/dot/state/grandpa.go +++ b/dot/state/grandpa.go @@ -287,17 +287,17 @@ func (s *GrandpaState) GetNextResume() (*big.Int, error) { func prevotesKey(round, setID uint64) []byte { prevotesPrefix := []byte("pv") - k := roundSetIDKey(round, setID) + k := roundAndSetIDToBytes(round, setID) return append(prevotesPrefix, k...) } func precommitsKey(round, setID uint64) []byte { precommitsPrefix := []byte("pc") - k := roundSetIDKey(round, setID) + k := roundAndSetIDToBytes(round, setID) return append(precommitsPrefix, k...) } -func roundSetIDKey(round, setID uint64) []byte { +func roundAndSetIDToBytes(round, setID uint64) []byte { buf := make([]byte, 8) binary.LittleEndian.PutUint64(buf, round) buf2 := make([]byte, 8) diff --git a/dot/sync/interface.go b/dot/sync/interface.go index 987dbc8648..f99fd4d86f 100644 --- a/dot/sync/interface.go +++ b/dot/sync/interface.go @@ -78,7 +78,7 @@ type Verifier interface { // FinalityGadget implements justification verification functionality type FinalityGadget interface { - VerifyBlockJustification([]byte) error + VerifyBlockJustification(common.Hash, []byte) error } // BlockImportHandler is the interface for the handler of newly imported blocks diff --git a/dot/sync/mocks/FinalityGadget.go b/dot/sync/mocks/FinalityGadget.go new file mode 100644 index 0000000000..6c79912932 --- /dev/null +++ b/dot/sync/mocks/FinalityGadget.go @@ -0,0 +1,27 @@ +// Code generated by mockery v2.8.0. DO NOT EDIT. + +package sync + +import ( + common "github.com/ChainSafe/gossamer/lib/common" + mock "github.com/stretchr/testify/mock" +) + +// FinalityGadget is an autogenerated mock type for the FinalityGadget type +type FinalityGadget struct { + mock.Mock +} + +// VerifyBlockJustification provides a mock function with given fields: _a0, _a1 +func (_m *FinalityGadget) VerifyBlockJustification(_a0 common.Hash, _a1 []byte) error { + ret := _m.Called(_a0, _a1) + + var r0 error + if rf, ok := ret.Get(0).(func(common.Hash, []byte) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/dot/sync/mocks/finality_gadget.go b/dot/sync/mocks/finality_gadget.go deleted file mode 100644 index 5386205968..0000000000 --- a/dot/sync/mocks/finality_gadget.go +++ /dev/null @@ -1,24 +0,0 @@ -// Code generated by mockery v2.8.0. DO NOT EDIT. - -package sync - -import mock "github.com/stretchr/testify/mock" - -// MockFinalityGadget is an autogenerated mock type for the FinalityGadget type -type MockFinalityGadget struct { - mock.Mock -} - -// VerifyBlockJustification provides a mock function with given fields: _a0 -func (_m *MockFinalityGadget) VerifyBlockJustification(_a0 []byte) error { - ret := _m.Called(_a0) - - var r0 error - if rf, ok := ret.Get(0).(func([]byte) error); ok { - r0 = rf(_a0) - } else { - r0 = ret.Error(0) - } - - return r0 -} diff --git a/dot/sync/syncer.go b/dot/sync/syncer.go index 7089099df3..5c4de5135a 100644 --- a/dot/sync/syncer.go +++ b/dot/sync/syncer.go @@ -366,18 +366,12 @@ func (s *Service) handleJustification(header *types.Header, justification []byte return } - err := s.finalityGadget.VerifyBlockJustification(justification) + err := s.finalityGadget.VerifyBlockJustification(header.Hash(), justification) if err != nil { logger.Warn("failed to verify block justification", "hash", header.Hash(), "number", header.Number, "error", err) return } - err = s.blockState.SetFinalisedHash(header.Hash(), 0, 0) - if err != nil { - logger.Error("failed to set finalised hash", "error", err) - return - } - err = s.blockState.SetJustification(header.Hash(), justification) if err != nil { logger.Error("failed tostore justification", "error", err) diff --git a/dot/sync/test_helpers.go b/dot/sync/test_helpers.go index 67dc43f47c..ae8eacf7fc 100644 --- a/dot/sync/test_helpers.go +++ b/dot/sync/test_helpers.go @@ -42,10 +42,10 @@ import ( ) // NewMockFinalityGadget create and return sync FinalityGadget interface mock -func NewMockFinalityGadget() *syncmocks.MockFinalityGadget { - m := new(syncmocks.MockFinalityGadget) +func NewMockFinalityGadget() *syncmocks.FinalityGadget { + m := new(syncmocks.FinalityGadget) // using []uint8 instead of []byte: https://github.com/stretchr/testify/pull/969 - m.On("VerifyBlockJustification", mock.AnythingOfType("[]uint8")).Return(nil) + m.On("VerifyBlockJustification", mock.AnythingOfType("common.Hash"), mock.AnythingOfType("[]uint8")).Return(nil) return m } diff --git a/lib/grandpa/errors.go b/lib/grandpa/errors.go index 3b7cab645c..23a6f4ae53 100644 --- a/lib/grandpa/errors.go +++ b/lib/grandpa/errors.go @@ -18,10 +18,16 @@ package grandpa import ( "errors" + "fmt" "github.com/ChainSafe/gossamer/lib/blocktree" ) +// errRoundMismatch is returned when trying to validate a vote message that isn't for the current round +func errRoundMismatch(got, want uint64) error { + return fmt.Errorf("rounds do not match: got %d, want %d", got, want) +} + //nolint var ( ErrNilBlockState = errors.New("cannot have nil BlockState") @@ -39,9 +45,6 @@ var ( // ErrSetIDMismatch is returned when trying to validate a vote message with an invalid voter set ID, or when receiving a catch up message with a different set ID ErrSetIDMismatch = errors.New("set IDs do not match") - // ErrRoundMismatch is returned when trying to validate a vote message that isn't for the current round - ErrRoundMismatch = errors.New("rounds do not match") - // ErrEquivocation is returned when trying to validate a vote for that is equivocatory ErrEquivocation = errors.New("vote is equivocatory") @@ -92,11 +95,8 @@ var ( // ErrPrecommitSignatureMismatch is returned when the number of precommits and signatures in a CommitMessage do not match ErrPrecommitSignatureMismatch = errors.New("number of precommits does not match number of signatures") - // ErrJustificationHashMismatch is returned when a precommit hash within a justification does not match the justification hash - ErrJustificationHashMismatch = errors.New("precommit hash does not match justification hash") - - // ErrJustificationNumberMismatch is returned when a precommit number within a justification does not match the justification number - ErrJustificationNumberMismatch = errors.New("precommit number does not match justification number") + // ErrPrecommitBlockMismatch is returned when a precommit hash within a justification is not a descendant of the committed block + ErrPrecommitBlockMismatch = errors.New("precommit block is not descendant of committed block") // ErrAuthorityNotInSet is returned when a precommit within a justification is signed by a key not in the authority set ErrAuthorityNotInSet = errors.New("authority is not in set") diff --git a/lib/grandpa/grandpa.go b/lib/grandpa/grandpa.go index c826d5ab00..15c40799fc 100644 --- a/lib/grandpa/grandpa.go +++ b/lib/grandpa/grandpa.go @@ -166,7 +166,7 @@ func NewService(cfg *Config) (*Service, error) { preVotedBlock: make(map[uint64]*Vote), bestFinalCandidate: make(map[uint64]*Vote), head: head, - in: make(chan *networkVoteMessage, 128), + in: make(chan *networkVoteMessage, 1024), resumed: make(chan struct{}), network: cfg.Network, finalisedCh: finalisedCh, @@ -278,9 +278,29 @@ func (s *Service) initiateRound() error { return err } + round, setID, err := s.blockState.GetHighestRoundAndSetID() + if err != nil { + return err + } + + if round > s.state.round && setID == s.state.setID { + logger.Debug("found block finalised in higher round, updating our round...", "new round", round) + s.state.round = round + err = s.grandpaState.SetLatestRound(round) + if err != nil { + return err + } + } + + if setID > s.state.setID { + logger.Debug("found block finalised in higher setID, updating our setID...", "new setID", setID) + s.state.setID = setID + s.state.round = round + } + s.head, err = s.blockState.GetFinalisedHeader(s.state.round, s.state.setID) if err != nil { - logger.Crit("failed to get finalised header", "error", err) + logger.Crit("failed to get finalised header", "round", s.state.round, "error", err) return err } @@ -345,7 +365,8 @@ func (s *Service) initiate() error { } if err != nil { - return err + logger.Warn("failed to play grandpa round", "error", err) + continue } if s.ctx.Err() != nil { @@ -479,6 +500,7 @@ func (s *Service) playGrandpaRound() error { logger.Debug("sending pre-vote message...", "vote", pv) roundComplete := make(chan struct{}) + defer close(roundComplete) // continue to send prevote messages until round is done go s.sendVoteMessage(prevote, vm, roundComplete) @@ -513,7 +535,6 @@ func (s *Service) playGrandpaRound() error { return err } - close(roundComplete) return nil } @@ -561,6 +582,17 @@ func (s *Service) attemptToFinalize() error { return nil // a block was finalised, seems like we missed some messages } + highestRound, highestSetID, _ := s.blockState.GetHighestRoundAndSetID() + if highestRound > s.state.round { + logger.Debug("block was finalised!", "round", highestRound, "setID", highestSetID) + return nil // a block was finalised, seems like we missed some messages + } + + if highestSetID > s.state.setID { + logger.Debug("block was finalised!", "round", highestRound, "setID", highestSetID) + return nil // a block was finalised, seems like we missed some messages + } + bfc, err := s.getBestFinalCandidate() if err != nil { return err @@ -799,12 +831,7 @@ func (s *Service) finalise() error { return err } - if err = s.grandpaState.SetLatestRound(s.state.round); err != nil { - return err - } - - // set latest finalised head in db - return s.blockState.SetFinalisedHash(bfc.Hash, 0, 0) + return s.grandpaState.SetLatestRound(s.state.round) } // createJustification collects the signed precommits received for this round and turns them into @@ -1012,8 +1039,10 @@ func (s *Service) getPreVotedBlock() (Vote, error) { func (s *Service) getGrandpaGHOST() (Vote, error) { threshold := s.state.threshold() - var blocks map[common.Hash]uint32 - var err error + var ( + blocks map[common.Hash]uint32 + err error + ) for { blocks, err = s.getPossibleSelectedBlocks(prevote, threshold) diff --git a/lib/grandpa/grandpa_test.go b/lib/grandpa/grandpa_test.go index 645faeda0e..4dc0213882 100644 --- a/lib/grandpa/grandpa_test.go +++ b/lib/grandpa/grandpa_test.go @@ -71,7 +71,7 @@ func newTestState(t *testing.T) *state.Service { t.Cleanup(func() { db.Close() }) gen, genTrie, _ := genesis.NewTestGenesisWithTrieAndHeader(t) - block, err := state.NewBlockStateFromGenesis(db, testHeader) + block, err := state.NewBlockStateFromGenesis(db, testGenesisHeader) require.NoError(t, err) rtCfg := &wasmer.Config{} @@ -313,7 +313,7 @@ func TestGetPossibleSelectedAncestors_SameAncestor(t *testing.T) { require.NoError(t, err) } - expected, err := common.HexToHash("0x4c897e75b7bf836ed5508bb0f1d04b396ae0bba3a1f902a1ac4195728bec35d9") + expected, err := st.Block.GetBlockHash(big.NewInt(6)) require.NoError(t, err) // this should return the highest common ancestor of (a, b, c) with >=2/3 votes, @@ -369,10 +369,10 @@ func TestGetPossibleSelectedAncestors_VaryingAncestor(t *testing.T) { require.NoError(t, err) } - expectedAt6, err := common.HexToHash("0x4c897e75b7bf836ed5508bb0f1d04b396ae0bba3a1f902a1ac4195728bec35d9") + expectedAt6, err := st.Block.GetBlockHash(big.NewInt(6)) require.NoError(t, err) - expectedAt7, err := common.HexToHash("0x3820aa85743cd534dd1cdb309fe8543f3a6fb5818119b7b857ffafa2ae18ab1b") + expectedAt7, err := st.Block.GetBlockHash(big.NewInt(7)) require.NoError(t, err) // this should return the highest common ancestor of (a, b) and (b, c) with >=2/3 votes, @@ -437,10 +437,10 @@ func TestGetPossibleSelectedAncestors_VaryingAncestor_MoreBranches(t *testing.T) require.NoError(t, err) } - expectedAt6, err := common.HexToHash("0x4c897e75b7bf836ed5508bb0f1d04b396ae0bba3a1f902a1ac4195728bec35d9") + expectedAt6, err := st.Block.GetBlockHash(big.NewInt(6)) require.NoError(t, err) - expectedAt7, err := common.HexToHash("0x3820aa85743cd534dd1cdb309fe8543f3a6fb5818119b7b857ffafa2ae18ab1b") + expectedAt7, err := st.Block.GetBlockHash(big.NewInt(7)) require.NoError(t, err) // this should return the highest common ancestor of (a, b) and (b, c) with >=2/3 votes, @@ -523,7 +523,7 @@ func TestGetPossibleSelectedBlocks_EqualVotes_SameAncestor(t *testing.T) { blocks, err := gs.getPossibleSelectedBlocks(prevote, gs.state.threshold()) require.NoError(t, err) - expected, err := common.HexToHash("0x4c897e75b7bf836ed5508bb0f1d04b396ae0bba3a1f902a1ac4195728bec35d9") + expected, err := st.Block.GetBlockHash(big.NewInt(6)) require.NoError(t, err) // this should return the highest common ancestor of (a, b, c) @@ -572,10 +572,10 @@ func TestGetPossibleSelectedBlocks_EqualVotes_VaryingAncestor(t *testing.T) { blocks, err := gs.getPossibleSelectedBlocks(prevote, gs.state.threshold()) require.NoError(t, err) - expectedAt6, err := common.HexToHash("0x4c897e75b7bf836ed5508bb0f1d04b396ae0bba3a1f902a1ac4195728bec35d9") + expectedAt6, err := st.Block.GetBlockHash(big.NewInt(6)) require.NoError(t, err) - expectedAt7, err := common.HexToHash("0x3820aa85743cd534dd1cdb309fe8543f3a6fb5818119b7b857ffafa2ae18ab1b") + expectedAt7, err := st.Block.GetBlockHash(big.NewInt(7)) require.NoError(t, err) // this should return the highest common ancestor of (a, b) and (b, c) with >=2/3 votes, @@ -743,7 +743,7 @@ func TestGetPreVotedBlock_MultipleCandidates(t *testing.T) { } // expected block is that with the highest number ie. at depth 7 - expected, err := common.HexToHash("0x3820aa85743cd534dd1cdb309fe8543f3a6fb5818119b7b857ffafa2ae18ab1b") + expected, err := st.Block.GetBlockHash(big.NewInt(7)) require.NoError(t, err) block, err := gs.getPreVotedBlock() @@ -817,14 +817,14 @@ func TestGetPreVotedBlock_EvenMoreCandidates(t *testing.T) { t.Log(st.Block.BlocktreeAsString()) - // expected block is at depth 4 - expected, err := common.HexToHash("0x951e6e1a529692b1e6cbfbf00ae6bb39386e7b883c42d92a4672780e769f8a51") + // expected block is at depth 5 + expected, err := st.Block.GetBlockHash(big.NewInt(5)) require.NoError(t, err) block, err := gs.getPreVotedBlock() require.NoError(t, err) require.Equal(t, expected, block.Hash) - require.Equal(t, uint32(4), block.Number) + require.Equal(t, uint32(5), block.Number) } func TestIsCompletable(t *testing.T) { @@ -1259,7 +1259,7 @@ func TestGetGrandpaGHOST_MultipleCandidates(t *testing.T) { t.Log(st.Block.BlocktreeAsString()) // expected block is that with the most votes ie. block 3 - expected, err := common.HexToHash("0x7b8d506f0977136fcb9ba630bc179d30d698d1247dd64f08df976205ad2cc04d") + expected, err := st.Block.GetBlockHash(big.NewInt(3)) require.NoError(t, err) block, err := gs.getGrandpaGHOST() diff --git a/lib/grandpa/message_handler.go b/lib/grandpa/message_handler.go index 3df956c873..4b5a918325 100644 --- a/lib/grandpa/message_handler.go +++ b/lib/grandpa/message_handler.go @@ -64,7 +64,7 @@ func (h *MessageHandler) handleMessage(from peer.ID, m GrandpaMessage) (network. return nil, nil case commitType: if fm, ok := m.(*CommitMessage); ok { - return h.handleCommitMessage(fm) + return nil, h.handleCommitMessage(fm) } case neighbourType: nm, ok := m.(*NeighbourMessage) @@ -116,50 +116,34 @@ func (h *MessageHandler) handleNeighbourMessage(from peer.ID, msg *NeighbourMess return nil } -func (h *MessageHandler) handleCommitMessage(msg *CommitMessage) (*ConsensusMessage, error) { - logger.Debug("received finalisation message", "msg", msg) +func (h *MessageHandler) handleCommitMessage(msg *CommitMessage) error { + logger.Debug("received commit message", "msg", msg) if has, _ := h.blockState.HasFinalisedBlock(msg.Round, h.grandpa.state.setID); has { - return nil, nil + return nil } // check justification here if err := h.verifyCommitMessageJustification(msg); err != nil { - return nil, err + return err } // set finalised head for round in db if err := h.blockState.SetFinalisedHash(msg.Vote.Hash, msg.Round, h.grandpa.state.setID); err != nil { - return nil, err + return err } pcs, err := compactToJustification(msg.Precommits, msg.AuthData) if err != nil { - return nil, err + return err } if err = h.grandpa.grandpaState.SetPrecommits(msg.Round, msg.SetID, pcs); err != nil { - return nil, err - } - - if msg.Round >= h.grandpa.state.round { - // set latest finalised head in db - err = h.blockState.SetFinalisedHash(msg.Vote.Hash, 0, 0) - if err != nil { - return nil, err - } - } - - // check if msg has same setID but is 2 or more rounds ahead of us, if so, return catch-up request to send - if msg.Round > h.grandpa.state.round+1 && !h.grandpa.paused.Load().(bool) { // TODO: CommitMessage does not have setID, confirm this is correct - h.grandpa.paused.Store(true) - h.grandpa.state.round = msg.Round + 1 - req := newCatchUpRequest(msg.Round, h.grandpa.state.setID) - logger.Debug("sending catch-up request; paused service", "round", msg.Round) - return req.ToConsensusMessage() + return err } - return nil, nil + // TODO: re-add catch-up logic + return nil } func (h *MessageHandler) handleCatchUpRequest(msg *catchUpRequest) (*ConsensusMessage, error) { @@ -193,8 +177,13 @@ func (h *MessageHandler) handleCatchUpResponse(msg *catchUpResponse) error { logger.Debug("received catch up response", "round", msg.Round, "setID", msg.SetID, "hash", msg.Hash) + // TODO: re-add catch-up logic + if true { + return nil + } + // if we aren't currently expecting a catch up response, return - if !h.grandpa.paused.Load().(bool) { + if !h.grandpa.paused.Load().(bool) { //nolint logger.Debug("not currently paused, ignoring catch up response") return nil } @@ -249,7 +238,7 @@ func (h *MessageHandler) handleCatchUpResponse(msg *catchUpResponse) error { } // verifyCatchUpResponseCompletability verifies that the pre-commit block is a descendant of, or is, the pre-voted block -func (h *MessageHandler) verifyCatchUpResponseCompletability(prevote, precommit common.Hash) error { +func (h *MessageHandler) verifyCatchUpResponseCompletability(prevote, precommit common.Hash) error { //nolint if prevote == precommit { return nil } @@ -288,6 +277,7 @@ func (h *MessageHandler) verifyCommitMessageJustification(fm *CommitMessage) err isDescendant, err := h.blockState.IsDescendantOf(fm.Vote.Hash, just.Vote.Hash) if err != nil { logger.Warn("verifyCommitMessageJustification", "error", err) + continue } if isDescendant { @@ -400,7 +390,7 @@ func (h *MessageHandler) verifyJustification(just *SignedVote, round, setID uint } // VerifyBlockJustification verifies the finality justification for a block -func (s *Service) VerifyBlockJustification(justification []byte) error { +func (s *Service) VerifyBlockJustification(hash common.Hash, justification []byte) error { r := &bytes.Buffer{} _, _ = r.Write(justification) fj := new(Justification) @@ -414,6 +404,15 @@ func (s *Service) VerifyBlockJustification(justification []byte) error { return fmt.Errorf("cannot get set ID from block number: %w", err) } + has, err := s.blockState.HasFinalisedBlock(fj.Round, setID) + if err != nil { + return err + } + + if has { + return fmt.Errorf("already have finalised block with setID=%d and round=%d", setID, fj.Round) + } + auths, err := s.grandpaState.GetAuthorities(setID) if err != nil { return fmt.Errorf("cannot get authorities for set ID: %w", err) @@ -432,12 +431,14 @@ func (s *Service) VerifyBlockJustification(justification []byte) error { } for _, just := range fj.Commit.Precommits { - if just.Vote.Hash != fj.Commit.Hash { - return ErrJustificationHashMismatch + // check if vote was for descendant of committed block + isDescendant, err := s.blockState.IsDescendantOf(hash, just.Vote.Hash) //nolint + if err != nil { + return err } - if just.Vote.Number != fj.Commit.Number { - return ErrJustificationNumberMismatch + if !isDescendant { + return ErrPrecommitBlockMismatch } pk, err := ed25519.NewPublicKey(just.AuthorityID[:]) @@ -471,6 +472,12 @@ func (s *Service) VerifyBlockJustification(justification []byte) error { } } + err = s.blockState.SetFinalisedHash(hash, fj.Round, setID) + if err != nil { + return err + } + + logger.Debug("set finalised block", "hash", hash, "round", fj.Round, "setID", setID) return nil } diff --git a/lib/grandpa/message_handler_test.go b/lib/grandpa/message_handler_test.go index d2999a8ec0..a3c595cbd4 100644 --- a/lib/grandpa/message_handler_test.go +++ b/lib/grandpa/message_handler_test.go @@ -252,7 +252,7 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_ValidSig(t *testing.T) { block := &types.Block{ Header: &types.Header{ - ParentHash: testHeader.Hash(), + ParentHash: testGenesisHeader.Hash(), Number: big.NewInt(1), Digest: types.Digest{ types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest(), @@ -269,11 +269,7 @@ func TestMessageHandler_CommitMessage_NoCatchUpRequest_ValidSig(t *testing.T) { require.NoError(t, err) require.Nil(t, out) - hash, err := st.Block.GetFinalisedHash(0, 0) - require.NoError(t, err) - require.Equal(t, fm.Vote.Hash, hash) - - hash, err = st.Block.GetFinalisedHash(fm.Round, gs.state.setID) + hash, err := st.Block.GetFinalisedHash(fm.Round, gs.state.setID) require.NoError(t, err) require.Equal(t, fm.Vote.Hash, hash) } @@ -316,14 +312,8 @@ func TestMessageHandler_CommitMessage_WithCatchUpRequest(t *testing.T) { gs.state.voters = gs.state.voters[:1] h := NewMessageHandler(gs, st.Block) - out, err := h.handleMessage("", fm) - require.NoError(t, err) - require.NotNil(t, out) - - req := newCatchUpRequest(77, gs.state.setID) - expected, err := req.ToConsensusMessage() + _, err = h.handleMessage("", fm) require.NoError(t, err) - require.Equal(t, expected, out) } func TestMessageHandler_CatchUpRequest_InvalidRound(t *testing.T) { @@ -354,7 +344,7 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) { block := &types.Block{ Header: &types.Header{ - ParentHash: testHeader.Hash(), + ParentHash: testGenesisHeader.Hash(), Number: big.NewInt(2), Digest: types.Digest{ types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest(), @@ -366,7 +356,7 @@ func TestMessageHandler_CatchUpRequest_WithResponse(t *testing.T) { err := st.Block.AddBlock(block) require.NoError(t, err) - err = gs.blockState.SetFinalisedHash(testHeader.Hash(), round, setID) + err = gs.blockState.SetFinalisedHash(testGenesisHeader.Hash(), round, setID) require.NoError(t, err) err = gs.blockState.(*state.BlockState).SetHeader(block.Header) require.NoError(t, err) @@ -535,6 +525,14 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { err := st.Grandpa.SetNextChange(auths, big.NewInt(1)) require.NoError(t, err) + block := &types.Block{ + Header: testHeader, + Body: &types.Body{0}, + } + + err = st.Block.AddBlock(block) + require.NoError(t, err) + err = st.Grandpa.IncrementSetID() require.NoError(t, err) @@ -542,52 +540,49 @@ func TestMessageHandler_VerifyBlockJustification(t *testing.T) { require.NoError(t, err) require.Equal(t, uint64(1), setID) + genhash := st.Block.GenesisHash() + round := uint64(2) number := uint32(2) precommits := buildTestJustification(t, 2, round, setID, kr, precommit) just := newJustification(round, testHash, number, precommits) data, err := just.Encode() require.NoError(t, err) - err = gs.VerifyBlockJustification(data) + err = gs.VerifyBlockJustification(testHash, data) require.NoError(t, err) // use wrong hash, shouldn't verify - just = newJustification(round, common.Hash{}, number, precommits) - data, err = just.Encode() - require.NoError(t, err) - err = gs.VerifyBlockJustification(data) - require.NotNil(t, err) - require.Equal(t, ErrJustificationHashMismatch, err) - - // use wrong number, shouldn't verify - just = newJustification(round, testHash, number+1, precommits) + precommits = buildTestJustification(t, 2, round+1, setID, kr, precommit) + just = newJustification(round+1, testHash, number, precommits) + just.Commit.Precommits[0].Vote.Hash = genhash data, err = just.Encode() require.NoError(t, err) - err = gs.VerifyBlockJustification(data) + err = gs.VerifyBlockJustification(testHash, data) require.NotNil(t, err) - require.Equal(t, ErrJustificationNumberMismatch, err) + require.Equal(t, ErrPrecommitBlockMismatch, err) // use wrong round, shouldn't verify - just = newJustification(round+1, testHash, number, precommits) + precommits = buildTestJustification(t, 2, round+1, setID, kr, precommit) + just = newJustification(round+2, testHash, number, precommits) data, err = just.Encode() require.NoError(t, err) - err = gs.VerifyBlockJustification(data) + err = gs.VerifyBlockJustification(testHash, data) require.NotNil(t, err) require.Equal(t, ErrInvalidSignature, err) // add authority not in set, shouldn't verify - precommits = buildTestJustification(t, len(auths)+1, round, setID, kr, precommit) - just = newJustification(round, testHash, number, precommits) + precommits = buildTestJustification(t, len(auths)+1, round+1, setID, kr, precommit) + just = newJustification(round+1, testHash, number, precommits) data, err = just.Encode() require.NoError(t, err) - err = gs.VerifyBlockJustification(data) + err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrAuthorityNotInSet, err) // not enough signatures, shouldn't verify - precommits = buildTestJustification(t, 1, round, setID, kr, precommit) - just = newJustification(round, testHash, number, precommits) + precommits = buildTestJustification(t, 1, round+1, setID, kr, precommit) + just = newJustification(round+1, testHash, number, precommits) data, err = just.Encode() require.NoError(t, err) - err = gs.VerifyBlockJustification(data) + err = gs.VerifyBlockJustification(testHash, data) require.Equal(t, ErrMinVotesNotMet, err) } diff --git a/lib/grandpa/message_test.go b/lib/grandpa/message_test.go index cf5c27aa6c..6dc280fd96 100644 --- a/lib/grandpa/message_test.go +++ b/lib/grandpa/message_test.go @@ -106,7 +106,7 @@ func TestNewCatchUpResponse(t *testing.T) { block := &types.Block{ Header: &types.Header{ - ParentHash: testHeader.Hash(), + ParentHash: testGenesisHeader.Hash(), Number: big.NewInt(1), Digest: types.Digest{ types.NewBabeSecondaryPlainPreDigest(0, 1).ToPreRuntimeDigest(), diff --git a/lib/grandpa/message_tracker.go b/lib/grandpa/message_tracker.go index 9101859a3c..c0006a218f 100644 --- a/lib/grandpa/message_tracker.go +++ b/lib/grandpa/message_tracker.go @@ -69,6 +69,7 @@ func (t *tracker) add(v *networkVoteMessage) { } t.mapLock.Lock() + // TODO: change to map of maps, this allows duplicates t.messages[v.msg.Message.Hash] = append(t.messages[v.msg.Message.Hash], v) t.mapLock.Unlock() } diff --git a/lib/grandpa/network_test.go b/lib/grandpa/network_test.go index efd5fa9282..abe523fc33 100644 --- a/lib/grandpa/network_test.go +++ b/lib/grandpa/network_test.go @@ -73,12 +73,6 @@ func TestHandleNetworkMessage(t *testing.T) { require.NoError(t, err) require.True(t, propagate) - select { - case <-gs.network.(*testNetwork).out: - case <-time.After(testTimeout): - t.Fatal("expected to send message") - } - neighbourMsg := &NeighbourMessage{} cm, err = neighbourMsg.ToConsensusMessage() require.NoError(t, err) diff --git a/lib/grandpa/state.go b/lib/grandpa/state.go index 423378cd0c..14bcbbab44 100644 --- a/lib/grandpa/state.go +++ b/lib/grandpa/state.go @@ -50,6 +50,7 @@ type BlockState interface { GetJustification(hash common.Hash) ([]byte, error) GetHashByNumber(num *big.Int) (common.Hash, error) BestBlockNumber() (*big.Int, error) + GetHighestRoundAndSetID() (uint64, uint64, error) } // GrandpaState is the interface required by grandpa into the grandpa state diff --git a/lib/grandpa/vote_message.go b/lib/grandpa/vote_message.go index 4c56c604cd..45bc37b037 100644 --- a/lib/grandpa/vote_message.go +++ b/lib/grandpa/vote_message.go @@ -170,12 +170,12 @@ func (s *Service) validateMessage(from peer.ID, m *VoteMessage) (*Vote, error) { } if err = s.network.SendMessage(from, msg); err != nil { - return nil, err + logger.Warn("failed to send CommitMessage", "error", err) } } // TODO: get justification if your round is lower, or just do catch-up? - return nil, ErrRoundMismatch + return nil, errRoundMismatch(m.Round, s.state.round) } // check for equivocation ie. multiple votes within one subround @@ -193,7 +193,7 @@ func (s *Service) validateMessage(from peer.ID, m *VoteMessage) (*Vote, error) { } err = s.validateVote(vote) - if errors.Is(err, ErrBlockDoesNotExist) || errors.Is(err, blocktree.ErrEndNodeNotFound) { + if errors.Is(err, ErrBlockDoesNotExist) || errors.Is(err, blocktree.ErrDescendantNotFound) || errors.Is(err, blocktree.ErrEndNodeNotFound) || errors.Is(err, blocktree.ErrStartNodeNotFound) { // TODO: cancel if block is imported; if we refactor the syncing this will likely become cleaner // as we can have an API to synchronously sync and import a block go s.network.SendBlockReqestByHash(vote.Hash) diff --git a/lib/grandpa/vote_message_test.go b/lib/grandpa/vote_message_test.go index 182c4f533e..f70239f985 100644 --- a/lib/grandpa/vote_message_test.go +++ b/lib/grandpa/vote_message_test.go @@ -371,6 +371,8 @@ func TestValidateMessage_IsNotDescendant(t *testing.T) { gs, err := NewService(cfg) require.NoError(t, err) + gs.tracker, err = newTracker(gs.blockState, gs.in) + require.NoError(t, err) var branches []*types.Header for {