diff --git a/dot/network/host.go b/dot/network/host.go index fe0f73ce1c..2082095a2c 100644 --- a/dot/network/host.go +++ b/dot/network/host.go @@ -368,6 +368,42 @@ func (h *host) peers() []peer.ID { return h.h.Network().Peers() } +// addReservedPeers adds the peers `addrs` to the protected peers list and connects to them +func (h *host) addReservedPeers(addrs ...string) error { + for _, addr := range addrs { + maddr, err := ma.NewMultiaddr(addr) + if err != nil { + return err + } + + addinfo, err := peer.AddrInfoFromP2pAddr(maddr) + if err != nil { + return err + } + + h.h.ConnManager().Protect(addinfo.ID, "") + if err := h.connect(*addinfo); err != nil { + return err + } + } + + return nil +} + +// removeReservedPeers will remove the given peers from the protected peers list +func (h *host) removeReservedPeers(ids ...string) error { + for _, id := range ids { + peerID, err := peer.Decode(id) + if err != nil { + return err + } + + h.h.ConnManager().Unprotect(peerID, "") + } + + return nil +} + // supportsProtocol checks if the protocol is supported by peerID // returns an error if could not get peer protocols func (h *host) supportsProtocol(peerID peer.ID, protocol protocol.ID) (bool, error) { diff --git a/dot/network/host_test.go b/dot/network/host_test.go index 4cb3895008..41c50ca53a 100644 --- a/dot/network/host_test.go +++ b/dot/network/host_test.go @@ -406,3 +406,73 @@ func Test_PeerSupportsProtocol(t *testing.T) { require.Equal(t, test.expect, output) } } + +func Test_AddReservedPeers(t *testing.T) { + basePathA := utils.NewTestBasePath(t, "nodeA") + configA := &Config{ + BasePath: basePathA, + Port: 7001, + NoBootstrap: true, + NoMDNS: true, + } + + nodeA := createTestService(t, configA) + nodeA.noGossip = true + + basePathB := utils.NewTestBasePath(t, "nodeB") + configB := &Config{ + BasePath: basePathB, + Port: 7002, + NoBootstrap: true, + NoMDNS: true, + } + + nodeB := createTestService(t, configB) + nodeB.noGossip = true + + nodeBPeerAddr := nodeB.host.multiaddrs()[0].String() + err := nodeA.host.addReservedPeers(nodeBPeerAddr) + require.NoError(t, err) + + isProtected := nodeA.host.h.ConnManager().IsProtected(nodeB.host.addrInfo().ID, "") + require.True(t, isProtected) +} + +func Test_RemoveReservedPeers(t *testing.T) { + basePathA := utils.NewTestBasePath(t, "nodeA") + configA := &Config{ + BasePath: basePathA, + Port: 7001, + NoBootstrap: true, + NoMDNS: true, + } + + nodeA := createTestService(t, configA) + nodeA.noGossip = true + + basePathB := utils.NewTestBasePath(t, "nodeB") + configB := &Config{ + BasePath: basePathB, + Port: 7002, + NoBootstrap: true, + NoMDNS: true, + } + + nodeB := createTestService(t, configB) + nodeB.noGossip = true + + nodeBPeerAddr := nodeB.host.multiaddrs()[0].String() + err := nodeA.host.addReservedPeers(nodeBPeerAddr) + require.NoError(t, err) + + pID := nodeB.host.addrInfo().ID.String() + + err = nodeA.host.removeReservedPeers(pID) + require.NoError(t, err) + + isProtected := nodeA.host.h.ConnManager().IsProtected(nodeB.host.addrInfo().ID, "") + require.False(t, isProtected) + + err = nodeA.host.removeReservedPeers("failing peer ID") + require.Error(t, err) +} diff --git a/dot/network/service.go b/dot/network/service.go index ec9c130758..e6db0e0bd5 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -705,6 +705,16 @@ func (s *Service) Peers() []common.PeerInfo { return peers } +// AddReservedPeers insert new peers to the peerstore with PermanentAddrTTL +func (s *Service) AddReservedPeers(addrs ...string) error { + return s.host.addReservedPeers(addrs...) +} + +// RemoveReservedPeers closes all connections with the target peers and remove it from the peerstore +func (s *Service) RemoveReservedPeers(addrs ...string) error { + return s.host.removeReservedPeers(addrs...) +} + // NodeRoles Returns the roles the node is running as. func (s *Service) NodeRoles() byte { return s.cfg.Roles diff --git a/dot/rpc/modules/api.go b/dot/rpc/modules/api.go index 48465bb676..023b65995e 100644 --- a/dot/rpc/modules/api.go +++ b/dot/rpc/modules/api.go @@ -53,6 +53,8 @@ type NetworkAPI interface { IsStopped() bool HighestBlock() int64 StartingBlock() int64 + AddReservedPeers(addrs ...string) error + RemoveReservedPeers(addrs ...string) error } // BlockProducerAPI is the interface for BlockProducer methods diff --git a/dot/rpc/modules/mocks/network_api.go b/dot/rpc/modules/mocks/network_api.go index 7802cdf6dd..7d8360dc5f 100644 --- a/dot/rpc/modules/mocks/network_api.go +++ b/dot/rpc/modules/mocks/network_api.go @@ -12,6 +12,26 @@ type MockNetworkAPI struct { mock.Mock } +// AddReservedPeers provides a mock function with given fields: addrs +func (_m *MockNetworkAPI) AddReservedPeers(addrs ...string) error { + _va := make([]interface{}, len(addrs)) + for _i := range addrs { + _va[_i] = addrs[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(...string) error); ok { + r0 = rf(addrs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Health provides a mock function with given fields: func (_m *MockNetworkAPI) Health() common.Health { ret := _m.Called() @@ -98,6 +118,26 @@ func (_m *MockNetworkAPI) Peers() []common.PeerInfo { return r0 } +// RemoveReservedPeers provides a mock function with given fields: addrs +func (_m *MockNetworkAPI) RemoveReservedPeers(addrs ...string) error { + _va := make([]interface{}, len(addrs)) + for _i := range addrs { + _va[_i] = addrs[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(...string) error); ok { + r0 = rf(addrs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Start provides a mock function with given fields: func (_m *MockNetworkAPI) Start() error { ret := _m.Called() diff --git a/dot/rpc/modules/system.go b/dot/rpc/modules/system.go index 80a6a46006..de70e93af9 100644 --- a/dot/rpc/modules/system.go +++ b/dot/rpc/modules/system.go @@ -21,6 +21,7 @@ import ( "errors" "math/big" "net/http" + "strings" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto" @@ -280,3 +281,21 @@ func (sm *SystemModule) LocalPeerId(r *http.Request, req *EmptyRequest, res *str *res = base58.Encode([]byte(netstate.PeerID)) return nil } + +// AddReservedPeer adds a reserved peer. The string parameter should encode a p2p multiaddr. +func (sm *SystemModule) AddReservedPeer(r *http.Request, req *StringRequest, res *[]byte) error { + if strings.TrimSpace(req.String) == "" { + return errors.New("cannot add an empty reserved peer") + } + + return sm.networkAPI.AddReservedPeers(req.String) +} + +// RemoveReservedPeer remove a reserved peer. The string should encode only the PeerId +func (sm *SystemModule) RemoveReservedPeer(r *http.Request, req *StringRequest, res *[]byte) error { + if strings.TrimSpace(req.String) == "" { + return errors.New("cannot remove an empty reserved peer") + } + + return sm.networkAPI.RemoveReservedPeers(req.String) +} diff --git a/dot/rpc/modules/system_test.go b/dot/rpc/modules/system_test.go index 7519925fd7..afd04255f3 100644 --- a/dot/rpc/modules/system_test.go +++ b/dot/rpc/modules/system_test.go @@ -402,7 +402,7 @@ func TestLocalListenAddresses(t *testing.T) { } mockNetAPI := new(mocks.MockNetworkAPI) - mockNetAPI.On("NetworkState").Return(mockedNetState) + mockNetAPI.On("NetworkState").Return(mockedNetState).Once() res := make([]string, 0) @@ -414,6 +414,10 @@ func TestLocalListenAddresses(t *testing.T) { require.Len(t, res, 1) require.Equal(t, res[0], ma.String()) + + mockNetAPI.On("NetworkState").Return(common.NetworkState{Multiaddrs: []multiaddr.Multiaddr{}}).Once() + err = sysmodule.LocalListenAddresses(nil, nil, &res) + require.Error(t, err, "multiaddress list is empty") } func TestLocalPeerId(t *testing.T) { @@ -441,3 +445,57 @@ func TestLocalPeerId(t *testing.T) { err = sysmodules.LocalPeerId(nil, nil, &res) require.Error(t, err) } + +func TestAddReservedPeer(t *testing.T) { + t.Run("Test Add and Remove reserved peers with success", func(t *testing.T) { + networkMock := new(mocks.MockNetworkAPI) + networkMock.On("AddReservedPeers", mock.AnythingOfType("string")).Return(nil).Once() + networkMock.On("RemoveReservedPeers", mock.AnythingOfType("string")).Return(nil).Once() + + multiAddrPeer := "/ip4/198.51.100.19/tcp/30333/p2p/QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV" + sysModule := &SystemModule{ + networkAPI: networkMock, + } + + var b *[]byte + err := sysModule.AddReservedPeer(nil, &StringRequest{String: multiAddrPeer}, b) + require.NoError(t, err) + require.Nil(t, b) + + peerID := "QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV" + err = sysModule.RemoveReservedPeer(nil, &StringRequest{String: peerID}, b) + require.NoError(t, err) + require.Nil(t, b) + }) + + t.Run("Test Add and Remove reserved peers without success", func(t *testing.T) { + networkMock := new(mocks.MockNetworkAPI) + networkMock.On("AddReservedPeers", mock.AnythingOfType("string")).Return(errors.New("some problems")).Once() + networkMock.On("RemoveReservedPeers", mock.AnythingOfType("string")).Return(errors.New("other problems")).Once() + + sysModule := &SystemModule{ + networkAPI: networkMock, + } + + var b *[]byte + err := sysModule.AddReservedPeer(nil, &StringRequest{String: ""}, b) + require.Error(t, err, "cannot add an empty reserved peer") + require.Nil(t, b) + + multiAddrPeer := "/ip4/198.51.100.19/tcp/30333/p2p/QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV" + err = sysModule.AddReservedPeer(nil, &StringRequest{String: multiAddrPeer}, b) + require.Error(t, err, "some problems") + require.Nil(t, b) + + peerID := "QmSk5HQbn6LhUwDiNMseVUjuRYhEtYj4aUZ6WfWoGURpdV" + err = sysModule.RemoveReservedPeer(nil, &StringRequest{String: peerID}, b) + require.Error(t, err, "other problems") + require.Nil(t, b) + }) + + t.Run("Test trying to add or remove peers with empty or white space request", func(t *testing.T) { + sysModule := &SystemModule{} + require.Error(t, sysModule.AddReservedPeer(nil, &StringRequest{String: ""}, nil)) + require.Error(t, sysModule.RemoveReservedPeer(nil, &StringRequest{String: " "}, nil)) + }) +} diff --git a/dot/rpc/service_test.go b/dot/rpc/service_test.go index 1661028388..4f943a7443 100644 --- a/dot/rpc/service_test.go +++ b/dot/rpc/service_test.go @@ -33,7 +33,7 @@ func TestNewService(t *testing.T) { } func TestService_Methods(t *testing.T) { - qtySystemMethods := 13 + qtySystemMethods := 15 qtyRPCMethods := 1 qtyAuthorMethods := 8 diff --git a/dot/rpc/subscription/websocket_test.go b/dot/rpc/subscription/websocket_test.go index 9f5d9d74af..f3475376c2 100644 --- a/dot/rpc/subscription/websocket_test.go +++ b/dot/rpc/subscription/websocket_test.go @@ -2,10 +2,7 @@ package subscription import ( "fmt" - "log" "math/big" - "net/http" - "os" "testing" "time" @@ -20,47 +17,15 @@ import ( "github.com/stretchr/testify/require" ) -var upgrader = websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, -} - -var wsconn = &WSConn{ - Subscriptions: make(map[uint32]Listener), -} - -func handler(w http.ResponseWriter, r *http.Request) { - c, err := upgrader.Upgrade(w, r, nil) - if err != nil { - log.Print("upgrade:", err) - return - } - defer c.Close() - - wsconn.Wsconn = c - wsconn.HandleComm() -} - -func TestMain(m *testing.M) { - http.HandleFunc("/", handler) - - go func() { - err := http.ListenAndServe("localhost:8546", nil) - if err != nil { - log.Fatal("error", err) - } - }() - time.Sleep(time.Millisecond * 100) +func TestWSConn_HandleComm(t *testing.T) { + wsconn, c, cancel := setupWSConn(t) + wsconn.Subscriptions = make(map[uint32]Listener) + defer cancel() - // Start all tests - os.Exit(m.Run()) -} + go wsconn.HandleComm() + time.Sleep(time.Second * 2) -func TestWSConn_HandleComm(t *testing.T) { - c, _, err := websocket.DefaultDialer.Dial("ws://localhost:8546", nil) //nolint - if err != nil { - log.Fatal("dial:", err) - } - defer c.Close() + fmt.Println("ws defined") // test storageChangeListener res, err := wsconn.initStorageChangeListener(1, nil)