-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into anton/p2p-priorities
- Loading branch information
Showing
2 changed files
with
515 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,394 @@ | ||
package p2p | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"io" | ||
"sync" | ||
|
||
"github.com/tendermint/tendermint/crypto" | ||
"github.com/tendermint/tendermint/crypto/ed25519" | ||
"github.com/tendermint/tendermint/libs/log" | ||
"github.com/tendermint/tendermint/p2p/conn" | ||
) | ||
|
||
const ( | ||
MemoryProtocol Protocol = "memory" | ||
) | ||
|
||
// MemoryNetwork is an in-memory "network" that uses Go channels to communicate | ||
// between endpoints. Transport endpoints are created with CreateTransport. It | ||
// is primarily used for testing. | ||
type MemoryNetwork struct { | ||
logger log.Logger | ||
|
||
mtx sync.RWMutex | ||
transports map[ID]*MemoryTransport | ||
} | ||
|
||
// NewMemoryNetwork creates a new in-memory network. | ||
func NewMemoryNetwork(logger log.Logger) *MemoryNetwork { | ||
return &MemoryNetwork{ | ||
logger: logger, | ||
transports: map[ID]*MemoryTransport{}, | ||
} | ||
} | ||
|
||
// CreateTransport creates a new memory transport and endpoint for the given | ||
// NodeInfo and private key. Use GenerateTransport() to autogenerate a random | ||
// key and node info. | ||
// | ||
// The transport immediately begins listening on the endpoint "memory:<id>", and | ||
// can be accessed by other transports in the same memory network. | ||
func (n *MemoryNetwork) CreateTransport( | ||
nodeInfo NodeInfo, | ||
privKey crypto.PrivKey, | ||
) (*MemoryTransport, error) { | ||
nodeID := nodeInfo.DefaultNodeID | ||
if nodeID == "" { | ||
return nil, errors.New("no node ID") | ||
} | ||
t := newMemoryTransport(n, nodeInfo, privKey) | ||
|
||
n.mtx.Lock() | ||
defer n.mtx.Unlock() | ||
if _, ok := n.transports[nodeID]; ok { | ||
return nil, fmt.Errorf("transport with node ID %q already exists", nodeID) | ||
} | ||
n.transports[nodeID] = t | ||
return t, nil | ||
} | ||
|
||
// GenerateTransport generates a new transport endpoint by generating a random | ||
// private key and node info. The endpoint address can be obtained via | ||
// Transport.Endpoints(). | ||
func (n *MemoryNetwork) GenerateTransport() *MemoryTransport { | ||
privKey := ed25519.GenPrivKey() | ||
nodeID := PubKeyToID(privKey.PubKey()) | ||
nodeInfo := NodeInfo{ | ||
DefaultNodeID: nodeID, | ||
ListenAddr: fmt.Sprintf("%v:%v", MemoryProtocol, nodeID), | ||
} | ||
t, err := n.CreateTransport(nodeInfo, privKey) | ||
if err != nil { | ||
// GenerateTransport is only used for testing, and the likelihood of | ||
// generating a duplicate node ID is very low, so we'll panic. | ||
panic(err) | ||
} | ||
return t | ||
} | ||
|
||
// GetTransport looks up a transport in the network, returning nil if not found. | ||
func (n *MemoryNetwork) GetTransport(id ID) *MemoryTransport { | ||
n.mtx.RLock() | ||
defer n.mtx.RUnlock() | ||
return n.transports[id] | ||
} | ||
|
||
// RemoveTransport removes a transport from the network and closes it. | ||
func (n *MemoryNetwork) RemoveTransport(id ID) error { | ||
n.mtx.Lock() | ||
t, ok := n.transports[id] | ||
delete(n.transports, id) | ||
n.mtx.Unlock() | ||
|
||
if ok { | ||
// Close may recursively call RemoveTransport() again, but this is safe | ||
// because we've already removed the transport from the map above. | ||
return t.Close() | ||
} | ||
return nil | ||
} | ||
|
||
// MemoryTransport is an in-memory transport that's primarily meant for testing. | ||
// It communicates between endpoints using Go channels. To dial a different | ||
// endpoint, both endpoints/transports must be in the same MemoryNetwork. | ||
type MemoryTransport struct { | ||
network *MemoryNetwork | ||
nodeInfo NodeInfo | ||
privKey crypto.PrivKey | ||
logger log.Logger | ||
|
||
acceptCh chan *MemoryConnection | ||
closeCh chan struct{} | ||
closeOnce sync.Once | ||
} | ||
|
||
// newMemoryTransport creates a new in-memory transport in the given network. | ||
// Callers should use MemoryNetwork.CreateTransport() or GenerateTransport() | ||
// to create transports, this is for internal use by MemoryNetwork. | ||
func newMemoryTransport( | ||
network *MemoryNetwork, | ||
nodeInfo NodeInfo, | ||
privKey crypto.PrivKey, | ||
) *MemoryTransport { | ||
return &MemoryTransport{ | ||
network: network, | ||
nodeInfo: nodeInfo, | ||
privKey: privKey, | ||
logger: network.logger.With("local", | ||
fmt.Sprintf("%v:%v", MemoryProtocol, nodeInfo.DefaultNodeID)), | ||
|
||
acceptCh: make(chan *MemoryConnection), | ||
closeCh: make(chan struct{}), | ||
} | ||
} | ||
|
||
// Accept implements Transport. | ||
func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) { | ||
select { | ||
case conn := <-t.acceptCh: | ||
t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint()) | ||
return conn, nil | ||
case <-t.closeCh: | ||
return nil, ErrTransportClosed{} | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
} | ||
|
||
// Dial implements Transport. | ||
func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { | ||
if endpoint.Protocol != MemoryProtocol { | ||
return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol) | ||
} | ||
if endpoint.Path == "" { | ||
return nil, errors.New("no path") | ||
} | ||
if endpoint.PeerID == "" { | ||
return nil, errors.New("no peer ID") | ||
} | ||
t.logger.Info("dialing peer", "remote", endpoint) | ||
|
||
peerTransport := t.network.GetTransport(ID(endpoint.Path)) | ||
if peerTransport == nil { | ||
return nil, fmt.Errorf("unknown peer %q", endpoint.Path) | ||
} | ||
inCh := make(chan memoryMessage, 1) | ||
outCh := make(chan memoryMessage, 1) | ||
closeCh := make(chan struct{}) | ||
closeOnce := sync.Once{} | ||
closer := func() bool { | ||
closed := false | ||
closeOnce.Do(func() { | ||
close(closeCh) | ||
closed = true | ||
}) | ||
return closed | ||
} | ||
|
||
outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closeCh, closer) | ||
inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closeCh, closer) | ||
|
||
select { | ||
case peerTransport.acceptCh <- inConn: | ||
return outConn, nil | ||
case <-peerTransport.closeCh: | ||
return nil, ErrTransportClosed{} | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
} | ||
} | ||
|
||
// DialAccept is a convenience function that dials a peer MemoryTransport and | ||
// returns both ends of the connection (A to B and B to A). | ||
func (t *MemoryTransport) DialAccept( | ||
ctx context.Context, | ||
peer *MemoryTransport, | ||
) (Connection, Connection, error) { | ||
endpoints := peer.Endpoints() | ||
if len(endpoints) == 0 { | ||
return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeInfo.DefaultNodeID) | ||
} | ||
|
||
acceptCh := make(chan Connection, 1) | ||
errCh := make(chan error, 1) | ||
go func() { | ||
conn, err := peer.Accept(ctx) | ||
errCh <- err | ||
acceptCh <- conn | ||
}() | ||
|
||
outConn, err := t.Dial(ctx, endpoints[0]) | ||
if err != nil { | ||
return nil, nil, err | ||
} | ||
if err = <-errCh; err != nil { | ||
return nil, nil, err | ||
} | ||
inConn := <-acceptCh | ||
|
||
return outConn, inConn, nil | ||
} | ||
|
||
// Close implements Transport. | ||
func (t *MemoryTransport) Close() error { | ||
err := t.network.RemoveTransport(t.nodeInfo.DefaultNodeID) | ||
t.closeOnce.Do(func() { | ||
close(t.closeCh) | ||
}) | ||
t.logger.Info("stopped accepting connections") | ||
return err | ||
} | ||
|
||
// Endpoints implements Transport. | ||
func (t *MemoryTransport) Endpoints() []Endpoint { | ||
select { | ||
case <-t.closeCh: | ||
return []Endpoint{} | ||
default: | ||
return []Endpoint{{ | ||
Protocol: MemoryProtocol, | ||
PeerID: t.nodeInfo.DefaultNodeID, | ||
Path: string(t.nodeInfo.DefaultNodeID), | ||
}} | ||
} | ||
} | ||
|
||
// SetChannelDescriptors implements Transport. | ||
func (t *MemoryTransport) SetChannelDescriptors(chDescs []*conn.ChannelDescriptor) { | ||
} | ||
|
||
// MemoryConnection is an in-memory connection between two transports (nodes). | ||
type MemoryConnection struct { | ||
logger log.Logger | ||
local *MemoryTransport | ||
remote *MemoryTransport | ||
|
||
receiveCh <-chan memoryMessage | ||
sendCh chan<- memoryMessage | ||
closeCh <-chan struct{} | ||
close func() bool | ||
} | ||
|
||
// memoryMessage is used to pass messages internally in the connection. | ||
type memoryMessage struct { | ||
channel byte | ||
message []byte | ||
} | ||
|
||
// newMemoryConnection creates a new MemoryConnection. It takes all channels | ||
// (including the closeCh signal channel) on construction, such that they can be | ||
// shared between both ends of the connection. | ||
func newMemoryConnection( | ||
local *MemoryTransport, | ||
remote *MemoryTransport, | ||
receiveCh <-chan memoryMessage, | ||
sendCh chan<- memoryMessage, | ||
closeCh <-chan struct{}, | ||
close func() bool, | ||
) *MemoryConnection { | ||
c := &MemoryConnection{ | ||
local: local, | ||
remote: remote, | ||
receiveCh: receiveCh, | ||
sendCh: sendCh, | ||
closeCh: closeCh, | ||
close: close, | ||
} | ||
c.logger = c.local.logger.With("remote", c.RemoteEndpoint()) | ||
return c | ||
} | ||
|
||
// ReceiveMessage implements Connection. | ||
func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) { | ||
// check close first, since channels are buffered | ||
select { | ||
case <-c.closeCh: | ||
return 0, nil, io.EOF | ||
default: | ||
} | ||
|
||
select { | ||
case msg := <-c.receiveCh: | ||
c.logger.Debug("received message", "channel", msg.channel, "message", msg.message) | ||
return msg.channel, msg.message, nil | ||
case <-c.closeCh: | ||
return 0, nil, io.EOF | ||
} | ||
} | ||
|
||
// SendMessage implements Connection. | ||
func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) { | ||
// check close first, since channels are buffered | ||
select { | ||
case <-c.closeCh: | ||
return false, io.EOF | ||
default: | ||
} | ||
|
||
select { | ||
case c.sendCh <- memoryMessage{channel: chID, message: msg}: | ||
c.logger.Debug("sent message", "channel", chID, "message", msg) | ||
return true, nil | ||
case <-c.closeCh: | ||
return false, io.EOF | ||
} | ||
} | ||
|
||
// TrySendMessage implements Connection. | ||
func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) { | ||
// check close first, since channels are buffered | ||
select { | ||
case <-c.closeCh: | ||
return false, io.EOF | ||
default: | ||
} | ||
|
||
select { | ||
case c.sendCh <- memoryMessage{channel: chID, message: msg}: | ||
c.logger.Debug("sent message", "channel", chID, "message", msg) | ||
return true, nil | ||
case <-c.closeCh: | ||
return false, io.EOF | ||
default: | ||
return false, nil | ||
} | ||
} | ||
|
||
// Close closes the connection. | ||
func (c *MemoryConnection) Close() error { | ||
if c.close() { | ||
c.logger.Info("closed connection") | ||
} | ||
return nil | ||
} | ||
|
||
// FlushClose flushes all pending sends and then closes the connection. | ||
func (c *MemoryConnection) FlushClose() error { | ||
return c.Close() | ||
} | ||
|
||
// LocalEndpoint returns the local endpoint for the connection. | ||
func (c *MemoryConnection) LocalEndpoint() Endpoint { | ||
return Endpoint{ | ||
PeerID: c.local.nodeInfo.DefaultNodeID, | ||
Protocol: MemoryProtocol, | ||
Path: string(c.local.nodeInfo.DefaultNodeID), | ||
} | ||
} | ||
|
||
// RemoteEndpoint returns the remote endpoint for the connection. | ||
func (c *MemoryConnection) RemoteEndpoint() Endpoint { | ||
return Endpoint{ | ||
PeerID: c.remote.nodeInfo.DefaultNodeID, | ||
Protocol: MemoryProtocol, | ||
Path: string(c.remote.nodeInfo.DefaultNodeID), | ||
} | ||
} | ||
|
||
// PubKey returns the remote peer's public key. | ||
func (c *MemoryConnection) PubKey() crypto.PubKey { | ||
return c.remote.privKey.PubKey() | ||
} | ||
|
||
// NodeInfo returns the remote peer's node info. | ||
func (c *MemoryConnection) NodeInfo() NodeInfo { | ||
return c.remote.nodeInfo | ||
} | ||
|
||
// Status returns the current connection status. | ||
func (c *MemoryConnection) Status() conn.ConnectionStatus { | ||
return conn.ConnectionStatus{} | ||
} |
Oops, something went wrong.