diff --git a/pkg/snet/directtp/client.go b/pkg/snet/directtp/client.go index 30764e38af..b95de6d149 100644 --- a/pkg/snet/directtp/client.go +++ b/pkg/snet/directtp/client.go @@ -24,6 +24,7 @@ import ( "github.com/skycoin/skywire/pkg/snet/directtp/tphandshake" "github.com/skycoin/skywire/pkg/snet/directtp/tplistener" "github.com/skycoin/skywire/pkg/snet/directtp/tptypes" + "github.com/skycoin/skywire/pkg/util/netutil" ) const ( @@ -136,7 +137,14 @@ func (c *client) Serve() error { c.log.Errorf("Failed to extract port from addr %v: %v", err) return } - + hasPublic, err := netutil.HasPublicIP() + if err != nil { + c.log.Errorf("Failed to check for public IP: %v", err) + } + if !hasPublic { + c.log.Infof("Not binding STCPR: no public IP address found") + return + } if err := c.conf.AddressResolver.BindSTCPR(context.Background(), port); err != nil { c.log.Errorf("Failed to bind STCPR: %v", err) return @@ -265,7 +273,7 @@ func (c *client) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*tp c.log.Infof("Resolved PK %v to visor data %v", rPK, visorData) - conn, err := c.dialVisor(visorData) + conn, err := c.dialVisor(ctx, visorData) if err != nil { return nil, err } @@ -313,6 +321,20 @@ func (c *client) dial(addr string) (net.Conn, error) { } } +func (c *client) dialContext(ctx context.Context, addr string) (net.Conn, error) { + dialer := net.Dialer{} + switch c.conf.Type { + case tptypes.STCP, tptypes.STCPR: + return dialer.DialContext(ctx, "tcp", addr) + + case tptypes.SUDPH: + return c.dialUDPWithTimeout(addr) + + default: + return nil, ErrUnknownTransportType + } +} + func (c *client) listen(addr string) (net.Listener, error) { switch c.conf.Type { case tptypes.STCP, tptypes.STCPR: @@ -405,7 +427,7 @@ func (c *client) dialUDPWithTimeout(addr string) (net.Conn, error) { } } -func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) { +func (c *client) dialVisor(ctx context.Context, visorData arclient.VisorData) (net.Conn, error) { if visorData.IsLocal { for _, host := range visorData.Addresses { addr := net.JoinHostPort(host, visorData.Port) @@ -416,7 +438,7 @@ func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) { } } - conn, err := c.dial(addr) + conn, err := c.dialContext(ctx, addr) if err == nil { return conn, nil } @@ -434,7 +456,7 @@ func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) { } } - return c.dial(addr) + return c.dialContext(ctx, addr) } // Listen creates a new listener for sudp. diff --git a/pkg/transport/entry.go b/pkg/transport/entry.go index b7ed47036a..f474e6fb12 100644 --- a/pkg/transport/entry.go +++ b/pkg/transport/entry.go @@ -33,7 +33,8 @@ type Entry struct { // ID is the Transport ID that uniquely identifies the Transport. ID uuid.UUID `json:"t_id"` - // Edges contains the public keys of the Transport's edge nodes (should only have 2 edges and the least-significant edge should come first). + // Edges contains the public keys of the Transport's edge nodes + // (should only have 2 edges and the first edge is transport original initiator). Edges [2]cipher.PubKey `json:"edges"` // Type represents the transport type. @@ -47,20 +48,16 @@ type Entry struct { } // MakeEntry creates a new transport entry -func MakeEntry(pk1, pk2 cipher.PubKey, tpType string, public bool, label Label) Entry { - return Entry{ - ID: MakeTransportID(pk1, pk2, tpType), - Edges: SortEdges(pk1, pk2), +func MakeEntry(initiator, target cipher.PubKey, tpType string, public bool, label Label) Entry { + entry := Entry{ + ID: MakeTransportID(initiator, target, tpType), Type: tpType, Public: public, Label: label, } -} - -// SetEdges sets edges of Entry -func (e *Entry) SetEdges(localPK, remotePK cipher.PubKey) { - e.ID = MakeTransportID(localPK, remotePK, e.Type) - e.Edges = SortEdges(localPK, remotePK) + entry.Edges[0] = initiator + entry.Edges[1] = target + return entry } // RemoteEdge returns the remote edge's public key. @@ -106,8 +103,8 @@ func (e *Entry) String() string { res += fmt.Sprintf("\ttype: %s\n", e.Type) res += fmt.Sprintf("\tid: %s\n", e.ID) res += "\tedges:\n" - res += fmt.Sprintf("\t\tedge 1: %s\n", e.Edges[0]) - res += fmt.Sprintf("\t\tedge 2: %s\n", e.Edges[1]) + res += fmt.Sprintf("\t\tedge 1 (initiator): %s\n", e.Edges[0]) + res += fmt.Sprintf("\t\tedge 2 (target): %s\n", e.Edges[1]) return res } diff --git a/pkg/transport/entry_test.go b/pkg/transport/entry_test.go index 4cddd36805..862cbad6ba 100644 --- a/pkg/transport/entry_test.go +++ b/pkg/transport/entry_test.go @@ -23,21 +23,6 @@ func TestNewEntry(t *testing.T) { assert.NotNil(t, entryBA.ID) } -func TestEntry_SetEdges(t *testing.T) { - pkA, _ := cipher.GenerateKeyPair() - pkB, _ := cipher.GenerateKeyPair() - - entryAB, entryBA := transport.Entry{}, transport.Entry{} - - entryAB.SetEdges(pkA, pkB) - entryBA.SetEdges(pkA, pkB) - - assert.True(t, entryAB.Edges == entryBA.Edges) - assert.True(t, entryAB.ID == entryBA.ID) - assert.NotNil(t, entryAB.ID) - assert.NotNil(t, entryBA.ID) -} - func ExampleSignedEntry_Sign() { pkA, skA := cipher.GenerateKeyPair() pkB, skB := cipher.GenerateKeyPair() diff --git a/pkg/transport/handshake.go b/pkg/transport/handshake.go index 819f9245d1..8994471604 100644 --- a/pkg/transport/handshake.go +++ b/pkg/transport/handshake.go @@ -14,8 +14,21 @@ import ( "github.com/skycoin/skywire/pkg/snet" ) -func makeEntryFromTpConn(conn *snet.Conn) Entry { - return MakeEntry(conn.LocalPK(), conn.RemotePK(), conn.Network(), true, LabelUser) +type hsResponse byte + +const ( + responseFailure hsResponse = iota + responseOK + responseSignatureErr + responseInvalidEntry +) + +func makeEntryFromTpConn(conn *snet.Conn, isInitiator bool) Entry { + initiator, target := conn.LocalPK(), conn.RemotePK() + if !isInitiator { + initiator, target = target, initiator + } + return MakeEntry(initiator, target, conn.Network(), true, LabelUser) } func compareEntries(expected, received *Entry) error { @@ -86,7 +99,7 @@ func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *snet.Co func MakeSettlementHS(init bool) SettlementHS { // initiating logic. initHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) (err error) { - entry := makeEntryFromTpConn(conn) + entry := makeEntryFromTpConn(conn, true) // TODO(evanlinjin): Probably not needed as this is called in mTp already. Need to double check. //defer func() { @@ -110,23 +123,33 @@ func MakeSettlementHS(init bool) SettlementHS { if _, err := io.ReadFull(conn, accepted); err != nil { return fmt.Errorf("failed to read response: %w", err) } - if accepted[0] == 0 { + switch hsResponse(accepted[0]) { + case responseOK: + return nil + case responseFailure: return fmt.Errorf("transport settlement rejected by remote") + case responseInvalidEntry: + return fmt.Errorf("invalid entry") + case responseSignatureErr: + return fmt.Errorf("signature error") + default: + return fmt.Errorf("invalid remote response") } - return nil } // responding logic. respHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) error { - entry := makeEntryFromTpConn(conn) + entry := makeEntryFromTpConn(conn, false) // receive, verify and sign entry. recvSE, err := receiveAndVerifyEntry(conn, &entry, conn.RemotePK()) if err != nil { + writeHsResponse(conn, responseInvalidEntry) //nolint:errcheck, gosec return err } if err := recvSE.Sign(conn.LocalPK(), sk); err != nil { + writeHsResponse(conn, responseSignatureErr) //nolint:errcheck, gosec return fmt.Errorf("failed to sign received entry: %w", err) } @@ -141,12 +164,7 @@ func MakeSettlementHS(init bool) SettlementHS { log.WithError(err).Error("Failed to register transport.") } } - - // inform initiating visor. - if _, err := conn.Write([]byte{1}); err != nil { - return fmt.Errorf("failed to accept transport settlement: write failed: %w", err) - } - return nil + return writeHsResponse(conn, responseOK) } if init { @@ -154,3 +172,10 @@ func MakeSettlementHS(init bool) SettlementHS { } return respHS } + +func writeHsResponse(w io.Writer, response hsResponse) error { + if _, err := w.Write([]byte{byte(response)}); err != nil { + return fmt.Errorf("failed to accept transport settlement: write failed: %w", err) + } + return nil +} diff --git a/pkg/transport/managed_transport.go b/pkg/transport/managed_transport.go index 63dc345e9d..2cbcf4874e 100644 --- a/pkg/transport/managed_transport.go +++ b/pkg/transport/managed_transport.go @@ -36,10 +36,11 @@ var ( // Constants associated with transport redial loop. const ( - tpInitBO = time.Millisecond * 500 - tpMaxBO = time.Minute - tpTries = 0 - tpFactor = 2 + tpInitBO = time.Millisecond * 500 + tpMaxBO = time.Minute + tpTries = 0 + tpFactor = 2 + tpTimeout = time.Second * 3 // timeout for a single try ) // ManagedTransportConfig is a configuration for managed transport. @@ -92,7 +93,11 @@ type ManagedTransport struct { } // NewManagedTransport creates a new ManagedTransport. -func NewManagedTransport(conf ManagedTransportConfig) *ManagedTransport { +func NewManagedTransport(conf ManagedTransportConfig, isInitiator bool) *ManagedTransport { + initiator, target := conf.Net.LocalPK(), conf.RemotePK + if !isInitiator { + initiator, target = target, initiator + } mt := &ManagedTransport{ log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])), rPK: conf.RemotePK, @@ -100,7 +105,7 @@ func NewManagedTransport(conf ManagedTransportConfig) *ManagedTransport { n: conf.Net, dc: conf.DC, ls: conf.LS, - Entry: MakeEntry(conf.Net.LocalPK(), conf.RemotePK, conf.NetName, true, conf.TransportLabel), + Entry: MakeEntry(initiator, target, conf.NetName, true, conf.TransportLabel), LogEntry: new(LogEntry), connCh: make(chan struct{}, 1), done: make(chan struct{}), @@ -204,8 +209,8 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) { continue } - // Only least significant edge is responsible for redialing. - if !mt.isLeastSignificantEdge() { + // Only initiator is responsible for redialing. + if !mt.isInitiator() { continue } @@ -371,9 +376,11 @@ func (mt *ManagedTransport) redialLoop(ctx context.Context) error { // Only redial when there is no underlying conn. return retry.Do(ctx, func() (err error) { + tryCtx, cancel := context.WithTimeout(ctx, tpTimeout) + defer cancel() mt.connMx.Lock() if mt.conn == nil { - err = mt.redial(ctx) + err = mt.redial(tryCtx) } mt.connMx.Unlock() return err @@ -381,6 +388,11 @@ func (mt *ManagedTransport) redialLoop(ctx context.Context) error { } func (mt *ManagedTransport) isLeastSignificantEdge() bool { + sorted := SortEdges(mt.Entry.Edges[0], mt.Entry.Edges[1]) + return sorted[0] == mt.n.LocalPK() +} + +func (mt *ManagedTransport) isInitiator() bool { return mt.Entry.EdgeIndex(mt.n.LocalPK()) == 0 } diff --git a/pkg/transport/manager.go b/pkg/transport/manager.go index 37f1bc8e56..2486d484b1 100644 --- a/pkg/transport/manager.go +++ b/pkg/transport/manager.go @@ -191,7 +191,8 @@ func (tm *Manager) initTransports(ctx context.Context) { remote = entry.Entry.RemoteEdge(tm.Conf.PubKey) tpID = entry.Entry.ID ) - if _, err := tm.saveTransport(remote, tpType, entry.Entry.Label); err != nil { + isInitiator := tm.n.LocalPK() == entry.Entry.Edges[0] + if _, err := tm.saveTransport(remote, isInitiator, tpType, entry.Entry.Label); err != nil { tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID) } else { tm.Logger.Debugf("Successfully initialized TP %v", *entry.Entry) @@ -230,7 +231,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro NetName: lis.Network(), AfterClosed: tm.afterTPClosed, TransportLabel: LabelUser, - }) + }, false) go func() { mTp.Serve(tm.readCh) @@ -303,7 +304,7 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy } for { - mTp, err := tm.saveTransport(remote, tpType, label) + mTp, err := tm.saveTransport(remote, true, tpType, label) if err != nil { return nil, fmt.Errorf("save transport: %w", err) } @@ -346,7 +347,7 @@ func isSTCPTableError(remotePK cipher.PubKey, err error) bool { return err.Error() == fmt.Sprintf("pk table: entry of %s does not exist", remotePK.String()) } -func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Label) (*ManagedTransport, error) { +func (tm *Manager) saveTransport(remote cipher.PubKey, initiator bool, netName string, label Label) (*ManagedTransport, error) { tm.mx.Lock() defer tm.mx.Unlock() if !snet.IsKnownNetwork(netName) { @@ -372,7 +373,7 @@ func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Lab NetName: netName, AfterClosed: afterTPClosed, TransportLabel: label, - }) + }, initiator) if mTp.netName == tptypes.STCPR { ar := mTp.n.Conf().ARClient diff --git a/pkg/transport/tpdclient/client_test.go b/pkg/transport/tpdclient/client_test.go index c557c25779..d081df4512 100644 --- a/pkg/transport/tpdclient/client_test.go +++ b/pkg/transport/tpdclient/client_test.go @@ -45,7 +45,8 @@ func newTestEntry() *transport.Entry { Type: "dmsg", Public: true, } - entry.SetEdges(pk1, testPubKey) + entry.Edges[0] = pk1 + entry.Edges[1] = testPubKey return entry } diff --git a/pkg/util/netutil/net.go b/pkg/util/netutil/net.go index f01856f33b..e35c88b30c 100644 --- a/pkg/util/netutil/net.go +++ b/pkg/util/netutil/net.go @@ -103,3 +103,18 @@ func DefaultNetworkInterfaceIPs() ([]net.IP, error) { } return localIPs, nil } + +// HasPublicIP returns true if this machine has at least one +// publically available IP address +func HasPublicIP() (bool, error) { + localIPs, err := LocalNetworkInterfaceIPs() + if err != nil { + return false, err + } + for _, IP := range localIPs { + if IsPublicIP(IP) { + return true, nil + } + } + return false, nil +}