Skip to content

Commit

Permalink
Use new pion/transport Net interface
Browse files Browse the repository at this point in the history
This change adapts pion/ice to use a new interface for most network
related operations. The interface was formerly a simple struct vnet.Net
which was originally intended to facilicate testing. By replacing it
with an interface we have greater flexibility and allow users to hook
into the networking stack by providing their own implementation of
the interface.
  • Loading branch information
stv0g committed Feb 8, 2023
1 parent c8cff3a commit 0194bd6
Show file tree
Hide file tree
Showing 26 changed files with 186 additions and 111 deletions.
50 changes: 27 additions & 23 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package ice

import (
"context"
"fmt"
"net"
"strings"
"sync"
Expand All @@ -13,8 +14,10 @@ import (
"github.com/pion/logging"
"github.com/pion/mdns"
"github.com/pion/stun"
"github.com/pion/transport/packetio"
"github.com/pion/transport/vnet"
"github.com/pion/transport/v2"
"github.com/pion/transport/v2/packetio"
"github.com/pion/transport/v2/stdnet"
"github.com/pion/transport/v2/vnet"
"golang.org/x/net/proxy"
)

Expand Down Expand Up @@ -123,7 +126,7 @@ type Agent struct {
loggerFactory logging.LoggerFactory
log logging.LeveledLogger

net *vnet.Net
net transport.Net
tcpMux TCPMux
udpMux UDPMux
udpMuxSrflx UniversalUDPMux
Expand Down Expand Up @@ -262,21 +265,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
}
log := loggerFactory.NewLogger("ice")

var mDNSConn *mdns.Conn
mDNSConn, mDNSMode, err = createMulticastDNS(mDNSMode, mDNSName, log)
// Opportunistic mDNS: If we can't open the connection, that's ok: we
// can continue without it.
if err != nil {
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
}
closeMDNSConn := func() {
if mDNSConn != nil {
if mdnsCloseErr := mDNSConn.Close(); mdnsCloseErr != nil {
log.Warnf("Failed to close mDNS: %v", mdnsCloseErr)
}
}
}

startedCtx, startedFn := context.WithCancel(context.Background())

a := &Agent{
Expand Down Expand Up @@ -307,7 +295,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit

mDNSMode: mDNSMode,
mDNSName: mDNSName,
mDNSConn: mDNSConn,

gatherCandidateCancel: func() {},

Expand All @@ -330,11 +317,28 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
a.udpMuxSrflx = config.UDPMuxSrflx

if a.net == nil {
a.net = vnet.NewNet(nil)
} else if a.net.IsVirtual() {
a.log.Warn("vnet is enabled")
a.net, err = stdnet.NewNet()
if err != nil {
return nil, fmt.Errorf("failed to create network: %w", err)
}
} else if _, isVirtual := a.net.(*vnet.Net); isVirtual {
a.log.Warn("virtual network is enabled")
if a.mDNSMode != MulticastDNSModeDisabled {
a.log.Warn("vnet does not support mDNS yet")
a.log.Warn("virtual network does not support mDNS yet")
}
}

a.mDNSConn, mDNSMode, err = createMulticastDNS(a.net, mDNSMode, mDNSName, log)
// Opportunistic mDNS: If we can't open the connection, that's ok: we
// can continue without it.
if err != nil {
log.Warnf("Failed to initialize mDNS %s: %v", mDNSName, err)
}
closeMDNSConn := func() {
if a.mDNSConn != nil {
if mdnsCloseErr := a.mDNSConn.Close(); mdnsCloseErr != nil {
log.Warnf("Failed to close mDNS: %v", mdnsCloseErr)
}
}
}

Expand Down
6 changes: 3 additions & 3 deletions agent_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/transport/vnet"
"github.com/pion/transport/v2"
"golang.org/x/net/proxy"
)

Expand Down Expand Up @@ -130,8 +130,8 @@ type AgentConfig struct {
RelayAcceptanceMinWait *time.Duration

// Net is the our abstracted network interface for internal development purpose only
// (see github.com/pion/transport/vnet)
Net *vnet.Net
// (see https://github.com/pion/transport)
Net transport.Net

// InterfaceFilter is a function that you can use in order to whitelist or blacklist
// the interfaces which are used to gather ICE candidates.
Expand Down
26 changes: 15 additions & 11 deletions agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (

"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/test"
"github.com/pion/transport/vnet"
"github.com/pion/transport/v2/test"
"github.com/pion/transport/v2/vnet"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -349,14 +349,16 @@ func TestConnectivityOnStartup(t *testing.T) {
})
assert.NoError(t, err)

net0 := vnet.NewNet(&vnet.NetConfig{
net0, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.1"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net0))

net1 := vnet.NewNet(&vnet.NetConfig{
net1, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.2"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net1))

assert.NoError(t, wan.Start())
Expand All @@ -366,10 +368,9 @@ func TestConnectivityOnStartup(t *testing.T) {

KeepaliveInterval := time.Hour
cfg0 := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled,
Net: net0,

NetworkTypes: supportedNetworkTypes(),
MulticastDNSMode: MulticastDNSModeDisabled,
Net: net0,
KeepaliveInterval: &KeepaliveInterval,
CheckInterval: &KeepaliveInterval,
}
Expand Down Expand Up @@ -1707,9 +1708,10 @@ func TestGetSelectedCandidatePair(t *testing.T) {
})
assert.NoError(t, err)

net := vnet.NewNet(&vnet.NetConfig{
net, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.1"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net))

assert.NoError(t, wan.Start())
Expand Down Expand Up @@ -1765,14 +1767,16 @@ func TestAcceptAggressiveNomination(t *testing.T) {
})
assert.NoError(t, err)

net0 := vnet.NewNet(&vnet.NetConfig{
net0, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.1"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net0))

net1 := vnet.NewNet(&vnet.NetConfig{
net1, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.2", "192.168.0.3", "192.168.0.4"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net1))

assert.NoError(t, wan.Start())
Expand Down
2 changes: 1 addition & 1 deletion agent_udpmux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"time"

"github.com/pion/logging"
"github.com/pion/transport/test"
"github.com/pion/transport/v2/test"
"github.com/stretchr/testify/require"
)

Expand Down
2 changes: 1 addition & 1 deletion candidate_relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"testing"
"time"

"github.com/pion/transport/test"
"github.com/pion/transport/v2/test"
"github.com/pion/turn/v2"
"github.com/stretchr/testify/assert"
)
Expand Down
2 changes: 1 addition & 1 deletion candidate_server_reflexive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"testing"
"time"

"github.com/pion/transport/test"
"github.com/pion/transport/v2/test"
"github.com/pion/turn/v2"
"github.com/stretchr/testify/assert"
)
Expand Down
33 changes: 24 additions & 9 deletions connectivity_vnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (

"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/test"
"github.com/pion/transport/vnet"
"github.com/pion/transport/v2/test"
"github.com/pion/transport/v2/vnet"
"github.com/pion/turn/v2"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -54,9 +54,12 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
return nil, err
}

wanNet := vnet.NewNet(&vnet.NetConfig{
wanNet, err := vnet.NewNet(&vnet.NetConfig{
StaticIP: vnetSTUNServerIP, // will be assigned to eth0
})
if err != nil {
return nil, err
}

err = wan.AddNet(wanNet)
if err != nil {
Expand All @@ -83,9 +86,13 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
return nil, err
}

net0 := vnet.NewNet(&vnet.NetConfig{
net0, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{vnetLocalIPA},
})
if err != nil {
return nil, err
}

err = lan0.AddNet(net0)
if err != nil {
return nil, err
Expand Down Expand Up @@ -116,9 +123,13 @@ func buildVNet(natType0, natType1 *vnet.NATType) (*virtualNet, error) {
return nil, err
}

net1 := vnet.NewNet(&vnet.NetConfig{
net1, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{vnetLocalIPB},
})
if err != nil {
return nil, err
}

err = lan1.AddNet(net1)
if err != nil {
return nil, err
Expand Down Expand Up @@ -475,14 +486,16 @@ func TestDisconnectedToConnected(t *testing.T) {
return atomic.LoadUint64(&dropAllData) != 1
})

net0 := vnet.NewNet(&vnet.NetConfig{
net0, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.1"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net0))

net1 := vnet.NewNet(&vnet.NetConfig{
net1, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.2"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net1))

assert.NoError(t, wan.Start())
Expand Down Expand Up @@ -581,14 +594,16 @@ func TestWriteUseValidPair(t *testing.T) {
return true
})

net0 := vnet.NewNet(&vnet.NetConfig{
net0, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.1"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net0))

net1 := vnet.NewNet(&vnet.NetConfig{
net1, err := vnet.NewNet(&vnet.NetConfig{
StaticIPs: []string{"192.168.0.2"},
})
assert.NoError(t, err)
assert.NoError(t, wan.AddNet(net1))

assert.NoError(t, wan.Start())
Expand Down
37 changes: 29 additions & 8 deletions gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,13 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli
locConn = turn.NewSTUNConn(conn)

case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURN:
tcpAddr, connectErr := net.ResolveTCPAddr(NetworkTypeTCP4.String(), TURNServerAddr)
tcpAddr, connectErr := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), TURNServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve TCP Addr %s: %v", TURNServerAddr, connectErr)
return
}

conn, connectErr := net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
conn, connectErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if connectErr != nil {
a.log.Warnf("Failed to Dial TCP Addr %s: %v", TURNServerAddr, connectErr)
return
Expand All @@ -596,18 +596,24 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli
relayProtocol = tcp
locConn = turn.NewSTUNConn(conn)
case url.Proto == ProtoTypeUDP && url.Scheme == SchemeTypeTURNS:
udpAddr, connectErr := net.ResolveUDPAddr(network, TURNServerAddr)
udpAddr, connectErr := a.net.ResolveUDPAddr(network, TURNServerAddr)
if connectErr != nil {
a.log.Warnf("Failed to resolve UDP Addr %s: %v", TURNServerAddr, connectErr)
return
}

conn, connectErr := dtls.Dial(network, udpAddr, &dtls.Config{ //nolint:contextcheck
udpConn, dialErr := a.net.DialUDP("udp", nil, udpAddr)
if dialErr != nil {
a.log.Warnf("Failed to dial DTLS Address %s: %v", TURNServerAddr, connectErr)
return
}

conn, connectErr := dtls.ClientWithContext(ctx, udpConn, &dtls.Config{
ServerName: url.Host,
InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec
})
if connectErr != nil {
a.log.Warnf("Failed to Dial DTLS Addr %s: %v", TURNServerAddr, connectErr)
a.log.Warnf("Failed to create DTLS client: %v", TURNServerAddr, connectErr)
return
}

Expand All @@ -616,13 +622,28 @@ func (a *Agent) gatherCandidatesRelay(ctx context.Context, urls []*URL) { //noli
relayProtocol = "dtls"
locConn = &fakePacketConn{conn}
case url.Proto == ProtoTypeTCP && url.Scheme == SchemeTypeTURNS:
conn, connectErr := tls.Dial(NetworkTypeTCP4.String(), TURNServerAddr, &tls.Config{
tcpAddr, err := a.net.ResolveTCPAddr(NetworkTypeTCP4.String(), TURNServerAddr)
if err != nil {
a.log.Warnf("Failed to resolve relay address %s: %v", TURNServerAddr, err)
return
}

tcpConn, dialErr := a.net.DialTCP(NetworkTypeTCP4.String(), nil, tcpAddr)
if dialErr != nil {
a.log.Warnf("Failed to connect to relay: %v", dialErr)
return
}

conn := tls.Client(tcpConn, &tls.Config{
InsecureSkipVerify: a.insecureSkipVerify, //nolint:gosec
})
if connectErr != nil {
a.log.Warnf("Failed to Dial TLS Addr %s: %v", TURNServerAddr, connectErr)

if err := conn.HandshakeContext(ctx); err != nil {
tcpConn.Close()
a.log.Warnf("Failed to connect to relay: %v", dialErr)
return
}

RelAddr = conn.LocalAddr().(*net.TCPAddr).IP.String() //nolint:forcetypeassert
RelPort = conn.LocalAddr().(*net.TCPAddr).Port //nolint:forcetypeassert
relayProtocol = "tls"
Expand Down
2 changes: 1 addition & 1 deletion gather_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/logging"
"github.com/pion/stun"
"github.com/pion/transport/test"
"github.com/pion/transport/v2/test"
"github.com/pion/turn/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down

0 comments on commit 0194bd6

Please sign in to comment.