Skip to content

Commit

Permalink
Support IPv6 from mDNS
Browse files Browse the repository at this point in the history
  • Loading branch information
edaniels authored and Sean-Der committed Mar 27, 2024
1 parent ae1ba6f commit 39c0392
Show file tree
Hide file tree
Showing 32 changed files with 831 additions and 390 deletions.
14 changes: 10 additions & 4 deletions active_tcp.go
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"io"
"net"
"net/netip"
"sync/atomic"
"time"

Expand All @@ -20,7 +21,7 @@ type activeTCPConn struct {
closed int32
}

func newActiveTCPConn(ctx context.Context, localAddress, remoteAddress string, log logging.LeveledLogger) (a *activeTCPConn) {
func newActiveTCPConn(ctx context.Context, localAddress string, remoteAddress netip.AddrPort, log logging.LeveledLogger) (a *activeTCPConn) {
a = &activeTCPConn{
readBuffer: packetio.NewBuffer(),
writeBuffer: packetio.NewBuffer(),
Expand All @@ -42,12 +43,11 @@ func newActiveTCPConn(ctx context.Context, localAddress, remoteAddress string, l
dialer := &net.Dialer{
LocalAddr: laddr,
}
conn, err := dialer.DialContext(ctx, "tcp", remoteAddress)
conn, err := dialer.DialContext(ctx, "tcp", remoteAddress.String())
if err != nil {
log.Infof("Failed to dial TCP address %s: %v", remoteAddress, err)
return
}

a.remoteAddr.Store(conn.RemoteAddr())

go func() {
Expand Down Expand Up @@ -95,8 +95,9 @@ func (a *activeTCPConn) ReadFrom(buff []byte) (n int, srcAddr net.Addr, err erro
return 0, nil, io.ErrClosedPipe
}

srcAddr = a.RemoteAddr()
n, err = a.readBuffer.Read(buff)
// RemoteAddr is assuredly set *after* we can read from the buffer
srcAddr = a.RemoteAddr()
return
}

Expand All @@ -123,6 +124,11 @@ func (a *activeTCPConn) LocalAddr() net.Addr {
return &net.TCPAddr{}
}

// RemoteAddr returns the remote address of the connection which is only
// set once a background goroutine has successfully dialed. That means
// this may return ":0" for the address prior to that happening. If this
// becomes an issue, we can introduce a synchronization point between Dial
// and these methods.
func (a *activeTCPConn) RemoteAddr() net.Addr {
if v, ok := a.remoteAddr.Load().(*net.TCPAddr); ok {
return v
Expand Down
39 changes: 26 additions & 13 deletions active_tcp_test.go
Expand Up @@ -8,6 +8,7 @@ package ice

import (
"net"
"net/netip"
"testing"
"time"

Expand All @@ -17,21 +18,21 @@ import (
"github.com/stretchr/testify/require"
)

func getLocalIPAddress(t *testing.T, networkType NetworkType) net.IP {
func getLocalIPAddress(t *testing.T, networkType NetworkType) netip.Addr {
net, err := stdnet.NewNet()
require.NoError(t, err)
localIPs, err := localInterfaces(net, nil, nil, []NetworkType{networkType}, false)
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{networkType}, false)
require.NoError(t, err)
require.NotEmpty(t, localIPs)
return localIPs[0]
require.NotEmpty(t, localAddrs)
return localAddrs[0]
}

func ipv6Available(t *testing.T) bool {
net, err := stdnet.NewNet()
require.NoError(t, err)
localIPs, err := localInterfaces(net, nil, nil, []NetworkType{NetworkTypeTCP6}, false)
_, localAddrs, err := localInterfaces(net, problematicNetworkInterfaces, nil, []NetworkType{NetworkTypeTCP6}, false)
require.NoError(t, err)
return len(localIPs) > 0
return len(localAddrs) > 0
}

func TestActiveTCP(t *testing.T) {
Expand All @@ -43,8 +44,9 @@ func TestActiveTCP(t *testing.T) {
type testCase struct {
name string
networkTypes []NetworkType
listenIPAddress net.IP
listenIPAddress netip.Addr
selectedPairNetworkType string
useMDNS bool
}

testCases := []testCase{
Expand All @@ -69,12 +71,16 @@ func TestActiveTCP(t *testing.T) {
networkTypes: []NetworkType{NetworkTypeTCP6},
listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6),
selectedPairNetworkType: tcp,
// if we don't use mDNS, we will very liekly be filtering out location tracked ips.
useMDNS: true,
},
testCase{
name: "UDP is preferred over TCP6", // This fails some time
name: "UDP is preferred over TCP6",
networkTypes: supportedNetworkTypes(),
listenIPAddress: getLocalIPAddress(t, NetworkTypeTCP6),
selectedPairNetworkType: udp,
// if we don't use mDNS, we will very liekly be filtering out location tracked ips.
useMDNS: true,
},
)
}
Expand All @@ -84,8 +90,9 @@ func TestActiveTCP(t *testing.T) {
r := require.New(t)

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: testCase.listenIPAddress,
IP: testCase.listenIPAddress.AsSlice(),
Port: listenPort,
Zone: testCase.listenIPAddress.Zone(),
})
r.NoError(err)
defer func() {
Expand All @@ -107,14 +114,18 @@ func TestActiveTCP(t *testing.T) {
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")

hostAcceptanceMinWait := 100 * time.Millisecond
passiveAgent, err := NewAgent(&AgentConfig{
cfg := &AgentConfig{
TCPMux: tcpMux,
CandidateTypes: []CandidateType{CandidateTypeHost},
NetworkTypes: testCase.networkTypes,
LoggerFactory: loggerFactory,
IncludeLoopback: true,
HostAcceptanceMinWait: &hostAcceptanceMinWait,
})
InterfaceFilter: problematicNetworkInterfaces,
}
if testCase.useMDNS {
cfg.MulticastDNSMode = MulticastDNSModeQueryAndGather
}
passiveAgent, err := NewAgent(cfg)
r.NoError(err)
r.NotNil(passiveAgent)

Expand All @@ -123,6 +134,7 @@ func TestActiveTCP(t *testing.T) {
NetworkTypes: testCase.networkTypes,
LoggerFactory: loggerFactory,
HostAcceptanceMinWait: &hostAcceptanceMinWait,
InterfaceFilter: problematicNetworkInterfaces,
})
r.NoError(err)
r.NotNil(activeAgent)
Expand Down Expand Up @@ -166,7 +178,8 @@ func TestActiveTCP_NonBlocking(t *testing.T) {
defer test.TimeOut(time.Second * 5).Stop()

cfg := &AgentConfig{
NetworkTypes: supportedNetworkTypes(),
NetworkTypes: supportedNetworkTypes(),
InterfaceFilter: problematicNetworkInterfaces,
}

aAgent, err := NewAgent(cfg)
Expand Down
114 changes: 94 additions & 20 deletions addr.go
Expand Up @@ -4,52 +4,126 @@
package ice

import (
"fmt"
"net"
"net/netip"
)

func parseMulticastAnswerAddr(in net.Addr) (net.IP, bool) {
func addrWithOptionalZone(addr netip.Addr, zone string) netip.Addr {
if zone == "" {
return addr
}
if addr.Is6() && (addr.IsLinkLocalUnicast() || addr.IsLinkLocalMulticast()) {
return addr.WithZone(zone)
}
return addr
}

// parseAddrFromIface should only be used when it's known the address belongs to that interface.
// e.g. it's LocalAddress on a listener.
func parseAddrFromIface(in net.Addr, ifcName string) (netip.Addr, int, NetworkType, error) {
addr, port, nt, err := parseAddr(in)
if err != nil {
return netip.Addr{}, 0, 0, err
}
if _, ok := in.(*net.IPNet); ok {
// net.IPNet does not have a Zone but we provide it from the interface
addr = addrWithOptionalZone(addr, ifcName)
}
return addr, port, nt, nil
}

func parseAddr(in net.Addr) (netip.Addr, int, NetworkType, error) {
switch addr := in.(type) {
case *net.IPNet:
ipAddr, err := ipAddrToNetIP(addr.IP, "")
if err != nil {
return netip.Addr{}, 0, 0, err
}
return ipAddr, 0, 0, nil
case *net.IPAddr:
return addr.IP, true
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
if err != nil {
return netip.Addr{}, 0, 0, err
}
return ipAddr, 0, 0, nil
case *net.UDPAddr:
return addr.IP, true
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
if err != nil {
return netip.Addr{}, 0, 0, err
}
var nt NetworkType
if ipAddr.Is4() {
nt = NetworkTypeUDP4
} else {
nt = NetworkTypeUDP6
}
return ipAddr, addr.Port, nt, nil
case *net.TCPAddr:
return addr.IP, true
ipAddr, err := ipAddrToNetIP(addr.IP, addr.Zone)
if err != nil {
return netip.Addr{}, 0, 0, err
}
var nt NetworkType
if ipAddr.Is4() {
nt = NetworkTypeTCP4
} else {
nt = NetworkTypeTCP6
}
return ipAddr, addr.Port, nt, nil
default:
return netip.Addr{}, 0, 0, addrParseError{in}
}
return nil, false
}

func parseAddr(in net.Addr) (net.IP, int, NetworkType, bool) {
switch addr := in.(type) {
case *net.UDPAddr:
return addr.IP, addr.Port, NetworkTypeUDP4, true
case *net.TCPAddr:
return addr.IP, addr.Port, NetworkTypeTCP4, true
type addrParseError struct {
addr net.Addr
}

func (e addrParseError) Error() string {
return fmt.Sprintf("do not know how to parse address type %T", e.addr)
}

type ipConvertError struct {
ip []byte
}

func (e ipConvertError) Error() string {
return fmt.Sprintf("failed to convert IP '%s' to netip.Addr", e.ip)
}

func ipAddrToNetIP(ip []byte, zone string) (netip.Addr, error) {
netIPAddr, ok := netip.AddrFromSlice(ip)
if !ok {
return netip.Addr{}, ipConvertError{ip}
}
return nil, 0, 0, false
// we'd rather have an IPv4-mapped IPv6 become IPv4 so that it is usable.
netIPAddr = netIPAddr.Unmap()
netIPAddr = addrWithOptionalZone(netIPAddr, zone)
return netIPAddr, nil
}

func createAddr(network NetworkType, ip net.IP, port int) net.Addr {
func createAddr(network NetworkType, ip netip.Addr, port int) net.Addr {
switch {
case network.IsTCP():
return &net.TCPAddr{IP: ip, Port: port}
return &net.TCPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()}
default:
return &net.UDPAddr{IP: ip, Port: port}
return &net.UDPAddr{IP: ip.AsSlice(), Port: port, Zone: ip.Zone()}
}
}

func addrEqual(a, b net.Addr) bool {
aIP, aPort, aType, aOk := parseAddr(a)
if !aOk {
aIP, aPort, aType, aErr := parseAddr(a)
if aErr != nil {
return false
}

bIP, bPort, bType, bOk := parseAddr(b)
if !bOk {
bIP, bPort, bType, bErr := parseAddr(b)
if bErr != nil {
return false
}

return aType == bType && aIP.Equal(bIP) && aPort == bPort
return aType == bType && aIP.Compare(bIP) == 0 && aPort == bPort
}

// AddrPort is an IP and a port number.
Expand Down

0 comments on commit 39c0392

Please sign in to comment.