Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stcpr transport establishment issues #813

Merged
merged 10 commits into from
Jun 21, 2021
32 changes: 27 additions & 5 deletions pkg/snet/directtp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/skycoin/skywire/pkg/snet/directtp/tphandshake"
"github.com/skycoin/skywire/pkg/snet/directtp/tplistener"
"github.com/skycoin/skywire/pkg/snet/directtp/tptypes"
"github.com/skycoin/skywire/pkg/util/netutil"
)

const (
Expand Down Expand Up @@ -136,7 +137,14 @@ func (c *client) Serve() error {
c.log.Errorf("Failed to extract port from addr %v: %v", err)
return
}

hasPublic, err := netutil.HasPublicIP()
if err != nil {
c.log.Errorf("Failed to check for public IP: %v", err)
}
if !hasPublic {
c.log.Infof("Not binding STCPR: no public IP address found")
return
}
if err := c.conf.AddressResolver.BindSTCPR(context.Background(), port); err != nil {
c.log.Errorf("Failed to bind STCPR: %v", err)
return
Expand Down Expand Up @@ -265,7 +273,7 @@ func (c *client) Dial(ctx context.Context, rPK cipher.PubKey, rPort uint16) (*tp

c.log.Infof("Resolved PK %v to visor data %v", rPK, visorData)

conn, err := c.dialVisor(visorData)
conn, err := c.dialVisor(ctx, visorData)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -313,6 +321,20 @@ func (c *client) dial(addr string) (net.Conn, error) {
}
}

func (c *client) dialContext(ctx context.Context, addr string) (net.Conn, error) {
dialer := net.Dialer{}
switch c.conf.Type {
case tptypes.STCP, tptypes.STCPR:
return dialer.DialContext(ctx, "tcp", addr)

case tptypes.SUDPH:
return c.dialUDPWithTimeout(addr)

default:
return nil, ErrUnknownTransportType
}
}

func (c *client) listen(addr string) (net.Listener, error) {
switch c.conf.Type {
case tptypes.STCP, tptypes.STCPR:
Expand Down Expand Up @@ -405,7 +427,7 @@ func (c *client) dialUDPWithTimeout(addr string) (net.Conn, error) {
}
}

func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) {
func (c *client) dialVisor(ctx context.Context, visorData arclient.VisorData) (net.Conn, error) {
if visorData.IsLocal {
for _, host := range visorData.Addresses {
addr := net.JoinHostPort(host, visorData.Port)
Expand All @@ -416,7 +438,7 @@ func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) {
}
}

conn, err := c.dial(addr)
conn, err := c.dialContext(ctx, addr)
if err == nil {
return conn, nil
}
Expand All @@ -434,7 +456,7 @@ func (c *client) dialVisor(visorData arclient.VisorData) (net.Conn, error) {
}
}

return c.dial(addr)
return c.dialContext(ctx, addr)
}

// Listen creates a new listener for sudp.
Expand Down
23 changes: 10 additions & 13 deletions pkg/transport/entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ type Entry struct {
// ID is the Transport ID that uniquely identifies the Transport.
ID uuid.UUID `json:"t_id"`

// Edges contains the public keys of the Transport's edge nodes (should only have 2 edges and the least-significant edge should come first).
// Edges contains the public keys of the Transport's edge nodes
// (should only have 2 edges and the first edge is transport original initiator).
Edges [2]cipher.PubKey `json:"edges"`

// Type represents the transport type.
Expand All @@ -47,20 +48,16 @@ type Entry struct {
}

// MakeEntry creates a new transport entry
func MakeEntry(pk1, pk2 cipher.PubKey, tpType string, public bool, label Label) Entry {
return Entry{
ID: MakeTransportID(pk1, pk2, tpType),
Edges: SortEdges(pk1, pk2),
func MakeEntry(initiator, target cipher.PubKey, tpType string, public bool, label Label) Entry {
entry := Entry{
ID: MakeTransportID(initiator, target, tpType),
jdknives marked this conversation as resolved.
Show resolved Hide resolved
Type: tpType,
Public: public,
Label: label,
}
}

// SetEdges sets edges of Entry
func (e *Entry) SetEdges(localPK, remotePK cipher.PubKey) {
e.ID = MakeTransportID(localPK, remotePK, e.Type)
e.Edges = SortEdges(localPK, remotePK)
entry.Edges[0] = initiator
entry.Edges[1] = target
return entry
}

// RemoteEdge returns the remote edge's public key.
Expand Down Expand Up @@ -106,8 +103,8 @@ func (e *Entry) String() string {
res += fmt.Sprintf("\ttype: %s\n", e.Type)
res += fmt.Sprintf("\tid: %s\n", e.ID)
res += "\tedges:\n"
res += fmt.Sprintf("\t\tedge 1: %s\n", e.Edges[0])
res += fmt.Sprintf("\t\tedge 2: %s\n", e.Edges[1])
res += fmt.Sprintf("\t\tedge 1 (initiator): %s\n", e.Edges[0])
res += fmt.Sprintf("\t\tedge 2 (target): %s\n", e.Edges[1])
return res
}

Expand Down
15 changes: 0 additions & 15 deletions pkg/transport/entry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,6 @@ func TestNewEntry(t *testing.T) {
assert.NotNil(t, entryBA.ID)
}

func TestEntry_SetEdges(t *testing.T) {
pkA, _ := cipher.GenerateKeyPair()
pkB, _ := cipher.GenerateKeyPair()

entryAB, entryBA := transport.Entry{}, transport.Entry{}

entryAB.SetEdges(pkA, pkB)
entryBA.SetEdges(pkA, pkB)

assert.True(t, entryAB.Edges == entryBA.Edges)
assert.True(t, entryAB.ID == entryBA.ID)
assert.NotNil(t, entryAB.ID)
assert.NotNil(t, entryBA.ID)
}

func ExampleSignedEntry_Sign() {
pkA, skA := cipher.GenerateKeyPair()
pkB, skB := cipher.GenerateKeyPair()
Expand Down
49 changes: 37 additions & 12 deletions pkg/transport/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,21 @@ import (
"github.com/skycoin/skywire/pkg/snet"
)

func makeEntryFromTpConn(conn *snet.Conn) Entry {
return MakeEntry(conn.LocalPK(), conn.RemotePK(), conn.Network(), true, LabelUser)
type hsResponse byte

const (
responseFailure hsResponse = iota
responseOK
responseSignatureErr
responseInvalidEntry
)

func makeEntryFromTpConn(conn *snet.Conn, isInitiator bool) Entry {
initiator, target := conn.LocalPK(), conn.RemotePK()
if !isInitiator {
initiator, target = target, initiator
}
return MakeEntry(initiator, target, conn.Network(), true, LabelUser)
}

func compareEntries(expected, received *Entry) error {
Expand Down Expand Up @@ -86,7 +99,7 @@ func (hs SettlementHS) Do(ctx context.Context, dc DiscoveryClient, conn *snet.Co
func MakeSettlementHS(init bool) SettlementHS {
// initiating logic.
initHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) (err error) {
entry := makeEntryFromTpConn(conn)
entry := makeEntryFromTpConn(conn, true)

// TODO(evanlinjin): Probably not needed as this is called in mTp already. Need to double check.
//defer func() {
Expand All @@ -110,23 +123,33 @@ func MakeSettlementHS(init bool) SettlementHS {
if _, err := io.ReadFull(conn, accepted); err != nil {
return fmt.Errorf("failed to read response: %w", err)
}
if accepted[0] == 0 {
switch hsResponse(accepted[0]) {
jdknives marked this conversation as resolved.
Show resolved Hide resolved
case responseOK:
return nil
case responseFailure:
return fmt.Errorf("transport settlement rejected by remote")
case responseInvalidEntry:
return fmt.Errorf("invalid entry")
case responseSignatureErr:
return fmt.Errorf("signature error")
default:
return fmt.Errorf("invalid remote response")
}
return nil
}

// responding logic.
respHS := func(ctx context.Context, dc DiscoveryClient, conn *snet.Conn, sk cipher.SecKey) error {
entry := makeEntryFromTpConn(conn)
entry := makeEntryFromTpConn(conn, false)

// receive, verify and sign entry.
recvSE, err := receiveAndVerifyEntry(conn, &entry, conn.RemotePK())
if err != nil {
writeHsResponse(conn, responseInvalidEntry) //nolint:errcheck, gosec
return err
}

if err := recvSE.Sign(conn.LocalPK(), sk); err != nil {
writeHsResponse(conn, responseSignatureErr) //nolint:errcheck, gosec
return fmt.Errorf("failed to sign received entry: %w", err)
}

Expand All @@ -141,16 +164,18 @@ func MakeSettlementHS(init bool) SettlementHS {
log.WithError(err).Error("Failed to register transport.")
}
}

// inform initiating visor.
if _, err := conn.Write([]byte{1}); err != nil {
return fmt.Errorf("failed to accept transport settlement: write failed: %w", err)
}
return nil
return writeHsResponse(conn, responseOK)
}

if init {
return initHS
}
return respHS
}

func writeHsResponse(w io.Writer, response hsResponse) error {
if _, err := w.Write([]byte{byte(response)}); err != nil {
return fmt.Errorf("failed to accept transport settlement: write failed: %w", err)
}
return nil
}
30 changes: 21 additions & 9 deletions pkg/transport/managed_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ var (

// Constants associated with transport redial loop.
const (
tpInitBO = time.Millisecond * 500
tpMaxBO = time.Minute
tpTries = 0
tpFactor = 2
tpInitBO = time.Millisecond * 500
tpMaxBO = time.Minute
tpTries = 0
tpFactor = 2
tpTimeout = time.Second * 3 // timeout for a single try
)

// ManagedTransportConfig is a configuration for managed transport.
Expand Down Expand Up @@ -92,15 +93,19 @@ type ManagedTransport struct {
}

// NewManagedTransport creates a new ManagedTransport.
func NewManagedTransport(conf ManagedTransportConfig) *ManagedTransport {
func NewManagedTransport(conf ManagedTransportConfig, isInitiator bool) *ManagedTransport {
initiator, target := conf.Net.LocalPK(), conf.RemotePK
if !isInitiator {
initiator, target = target, initiator
}
mt := &ManagedTransport{
log: logging.MustGetLogger(fmt.Sprintf("tp:%s", conf.RemotePK.String()[:6])),
rPK: conf.RemotePK,
netName: conf.NetName,
n: conf.Net,
dc: conf.DC,
ls: conf.LS,
Entry: MakeEntry(conf.Net.LocalPK(), conf.RemotePK, conf.NetName, true, conf.TransportLabel),
Entry: MakeEntry(initiator, target, conf.NetName, true, conf.TransportLabel),
LogEntry: new(LogEntry),
connCh: make(chan struct{}, 1),
done: make(chan struct{}),
Expand Down Expand Up @@ -204,8 +209,8 @@ func (mt *ManagedTransport) Serve(readCh chan<- routing.Packet) {
continue
}

// Only least significant edge is responsible for redialing.
if !mt.isLeastSignificantEdge() {
// Only initiator is responsible for redialing.
if !mt.isInitiator() {
continue
}

Expand Down Expand Up @@ -371,16 +376,23 @@ func (mt *ManagedTransport) redialLoop(ctx context.Context) error {

// Only redial when there is no underlying conn.
return retry.Do(ctx, func() (err error) {
tryCtx, cancel := context.WithTimeout(ctx, tpTimeout)
defer cancel()
mt.connMx.Lock()
if mt.conn == nil {
err = mt.redial(ctx)
err = mt.redial(tryCtx)
}
mt.connMx.Unlock()
return err
})
}

func (mt *ManagedTransport) isLeastSignificantEdge() bool {
sorted := SortEdges(mt.Entry.Edges[0], mt.Entry.Edges[1])
return sorted[0] == mt.n.LocalPK()
}

func (mt *ManagedTransport) isInitiator() bool {
return mt.Entry.EdgeIndex(mt.n.LocalPK()) == 0
}

Expand Down
11 changes: 6 additions & 5 deletions pkg/transport/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ func (tm *Manager) initTransports(ctx context.Context) {
remote = entry.Entry.RemoteEdge(tm.Conf.PubKey)
tpID = entry.Entry.ID
)
if _, err := tm.saveTransport(remote, tpType, entry.Entry.Label); err != nil {
isInitiator := tm.n.LocalPK() == entry.Entry.Edges[0]
if _, err := tm.saveTransport(remote, isInitiator, tpType, entry.Entry.Label); err != nil {
tm.Logger.Warnf("INIT: failed to init tp: type(%s) remote(%s) tpID(%s)", tpType, remote, tpID)
} else {
tm.Logger.Debugf("Successfully initialized TP %v", *entry.Entry)
Expand Down Expand Up @@ -230,7 +231,7 @@ func (tm *Manager) acceptTransport(ctx context.Context, lis *snet.Listener) erro
NetName: lis.Network(),
AfterClosed: tm.afterTPClosed,
TransportLabel: LabelUser,
})
}, false)

go func() {
mTp.Serve(tm.readCh)
Expand Down Expand Up @@ -303,7 +304,7 @@ func (tm *Manager) SaveTransport(ctx context.Context, remote cipher.PubKey, tpTy
}

for {
mTp, err := tm.saveTransport(remote, tpType, label)
mTp, err := tm.saveTransport(remote, true, tpType, label)
if err != nil {
return nil, fmt.Errorf("save transport: %w", err)
}
Expand Down Expand Up @@ -346,7 +347,7 @@ func isSTCPTableError(remotePK cipher.PubKey, err error) bool {
return err.Error() == fmt.Sprintf("pk table: entry of %s does not exist", remotePK.String())
}

func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Label) (*ManagedTransport, error) {
func (tm *Manager) saveTransport(remote cipher.PubKey, initiator bool, netName string, label Label) (*ManagedTransport, error) {
tm.mx.Lock()
defer tm.mx.Unlock()
if !snet.IsKnownNetwork(netName) {
Expand All @@ -372,7 +373,7 @@ func (tm *Manager) saveTransport(remote cipher.PubKey, netName string, label Lab
NetName: netName,
AfterClosed: afterTPClosed,
TransportLabel: label,
})
}, initiator)

if mTp.netName == tptypes.STCPR {
ar := mTp.n.Conf().ARClient
Expand Down
3 changes: 2 additions & 1 deletion pkg/transport/tpdclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func newTestEntry() *transport.Entry {
Type: "dmsg",
Public: true,
}
entry.SetEdges(pk1, testPubKey)
entry.Edges[0] = pk1
entry.Edges[1] = testPubKey

return entry
}
Expand Down
15 changes: 15 additions & 0 deletions pkg/util/netutil/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,18 @@ func DefaultNetworkInterfaceIPs() ([]net.IP, error) {
}
return localIPs, nil
}

// HasPublicIP returns true if this machine has at least one
// publically available IP address
func HasPublicIP() (bool, error) {
localIPs, err := LocalNetworkInterfaceIPs()
if err != nil {
return false, err
}
for _, IP := range localIPs {
if IsPublicIP(IP) {
return true, nil
}
}
return false, nil
}