Skip to content

Commit

Permalink
Improve performance of UDPMux map lookups
Browse files Browse the repository at this point in the history
UDPMux is using a map to lookup addresses of each packets.
Unfortunately the key is based on a string and each time we
want to check the map, a conversion of the UDP address to string
is made (.String()) which is expensive.

This CR replace the string key by a binary key called ipPort. This
structure contains a netip.Addr field and ipPort could be used as
a map key
  • Loading branch information
sebapeti authored and Sean-Der committed Mar 25, 2024
1 parent 52f2075 commit 66051b6
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 22 deletions.
2 changes: 2 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ var (
errWriteSTUNMessage = errors.New("failed to send STUN message")
errWriteSTUNMessageToIceConn = errors.New("failed to write STUN message to ICE connection")
errXORMappedAddrTimeout = errors.New("timeout while waiting for XORMappedAddr")
errFailedToCastUDPAddr = errors.New("failed to cast net.Addr to net.UDPAddr")
errInvalidIPAddress = errors.New("invalid ip address")

// UDPMuxDefault should not listen on unspecified address, but to keep backward compatibility, don't return error now.
// will be used in the future.
Expand Down
52 changes: 40 additions & 12 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"io"
"net"
"net/netip"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -36,7 +37,7 @@ type UDPMuxDefault struct {
connsIPv4, connsIPv6 map[string]*udpMuxedConn

addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
addressMap map[ipPort]*udpMuxedConn

// Buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
Expand All @@ -51,8 +52,9 @@ const maxAddrSize = 512

// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
Logger logging.LeveledLogger
UDPConn net.PacketConn
UDPConnString string

// Required for gathering local addresses
// in case a un UDPConn is passed which does not
Expand Down Expand Up @@ -103,9 +105,10 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}
}
}
params.UDPConnString = params.UDPConn.LocalAddr().String()

m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
addressMap: map[ipPort]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
Expand Down Expand Up @@ -142,7 +145,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConnString != addr.String() {
return nil, errInvalidAddress
}

Expand Down Expand Up @@ -246,7 +249,7 @@ func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}

func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr ipPort) {
if m.IsClosed() {
return
}
Expand All @@ -260,7 +263,7 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
}
m.addressMap[addr] = conn

m.params.Logger.Debugf("Registered %s for %s", addr, conn.params.Key)
m.params.Logger.Debugf("Registered %s for %s", addr.addr.String(), conn.params.Key)
}

func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
Expand Down Expand Up @@ -296,15 +299,20 @@ func (m *UDPMuxDefault) connWorker() {
return
}

udpAddr, ok := addr.(*net.UDPAddr)
netUDPAddr, ok := addr.(*net.UDPAddr)
if !ok {
logger.Errorf("Underlying PacketConn did not return a UDPAddr")
return
}
udpAddr, err := newIPPort(netUDPAddr.IP, uint16(netUDPAddr.Port))
if err != nil {
logger.Errorf("Failed to create a new IP/Port host pair")
return
}

// If we have already seen this address dispatch to the appropriate destination
m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
destinationConn := m.addressMap[udpAddr]
m.addressMapMu.Unlock()

// If we haven't seen this address before but is a STUN packet lookup by ufrag
Expand All @@ -325,19 +333,19 @@ func (m *UDPMuxDefault) connWorker() {
}

ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := udpAddr.IP.To4() == nil
isIPv6 := netUDPAddr.IP.To4() == nil

m.mu.Lock()
destinationConn, _ = m.getConn(ufrag, isIPv6)
m.mu.Unlock()
}

if destinationConn == nil {
m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.String(), addr.String())
m.params.Logger.Tracef("Dropping packet from %s, addr: %s", udpAddr.addr.String(), addr.String())
continue
}

if err = destinationConn.writePacket(buf[:n], udpAddr); err != nil {
if err = destinationConn.writePacket(buf[:n], netUDPAddr); err != nil {
m.params.Logger.Errorf("Failed to write packet: %v", err)
}
}
Expand All @@ -361,3 +369,23 @@ func newBufferHolder(size int) *bufferHolder {
buf: make([]byte, size),
}
}

type ipPort struct {
addr netip.Addr
port uint16
}

// newIPPort create a custom type of address based on netip.Addr and
// port. The underlying ip address passed is converted to IPv6 format
// to simplify ip address handling
func newIPPort(ip net.IP, port uint16) (ipPort, error) {
n, ok := netip.AddrFromSlice(ip.To16())
if !ok {
return ipPort{}, errInvalidIPAddress
}

return ipPort{
addr: n,
port: port,
}, nil
}
28 changes: 18 additions & 10 deletions udp_muxed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type udpMuxedConnParams struct {
type udpMuxedConn struct {
params *udpMuxedConnParams
// Remote addresses that we have sent to on this conn
addresses []string
addresses []ipPort

// Channel holding incoming packets
buf *packetio.Buffer
Expand Down Expand Up @@ -81,9 +81,17 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe
}
// Each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
if !c.containsAddress(addr) {
c.addAddress(addr)
netUDPAddr, ok := rAddr.(*net.UDPAddr)
if !ok {
return 0, errFailedToCastUDPAddr
}

ipAndPort, err := newIPPort(netUDPAddr.IP, uint16(netUDPAddr.Port))
if err != nil {
return 0, err
}
if !c.containsAddress(ipAndPort) {
c.addAddress(ipAndPort)
}

return c.params.Mux.writeTo(buf, rAddr)
Expand Down Expand Up @@ -127,15 +135,15 @@ func (c *udpMuxedConn) isClosed() bool {
}
}

func (c *udpMuxedConn) getAddresses() []string {
func (c *udpMuxedConn) getAddresses() []ipPort {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
addresses := make([]ipPort, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}

func (c *udpMuxedConn) addAddress(addr string) {
func (c *udpMuxedConn) addAddress(addr ipPort) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
Expand All @@ -144,11 +152,11 @@ func (c *udpMuxedConn) addAddress(addr string) {
c.params.Mux.registerConnForAddress(c, addr)
}

func (c *udpMuxedConn) removeAddress(addr string) {
func (c *udpMuxedConn) removeAddress(addr ipPort) {
c.mu.Lock()
defer c.mu.Unlock()

newAddresses := make([]string, 0, len(c.addresses))
newAddresses := make([]ipPort, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
newAddresses = append(newAddresses, a)
Expand All @@ -158,7 +166,7 @@ func (c *udpMuxedConn) removeAddress(addr string) {
c.addresses = newAddresses
}

func (c *udpMuxedConn) containsAddress(addr string) bool {
func (c *udpMuxedConn) containsAddress(addr ipPort) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
Expand Down

0 comments on commit 66051b6

Please sign in to comment.