Skip to content

Commit

Permalink
Prevent heap allocations from udp muxed conn addr keys
Browse files Browse the repository at this point in the history
  • Loading branch information
paulwe committed Apr 19, 2024
1 parent eb30993 commit 5dc51d2
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 12 deletions.
8 changes: 4 additions & 4 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type UDPMuxDefault struct {
connsIPv4, connsIPv6 map[string]*udpMuxedConn

addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
addressMap map[udpMuxedConnAddr]*udpMuxedConn

// Buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
Expand Down Expand Up @@ -105,7 +105,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}

m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
addressMap: map[udpMuxedConnAddr]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
Expand Down Expand Up @@ -246,7 +246,7 @@ func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}

func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr udpMuxedConnAddr) {
if m.IsClosed() {
return
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func (m *UDPMuxDefault) connWorker() {

// If we have already seen this address dispatch to the appropriate destination
m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
destinationConn := m.addressMap[newUDPMuxedConnAddr(udpAddr)]
m.addressMapMu.Unlock()

// If we haven't seen this address before but is a STUN packet lookup by ufrag
Expand Down
27 changes: 19 additions & 8 deletions udp_muxed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ import (
"github.com/pion/transport/v2/packetio"
)

type udpMuxedConnAddr struct {
ip [16]byte
port uint16
}

func newUDPMuxedConnAddr(addr *net.UDPAddr) (a udpMuxedConnAddr) {
copy(a.ip[:], addr.IP)
a.port = uint16(addr.Port)
return a
}

type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Expand All @@ -26,7 +37,7 @@ type udpMuxedConnParams struct {
type udpMuxedConn struct {
params *udpMuxedConnParams
// Remote addresses that we have sent to on this conn
addresses []string
addresses []udpMuxedConnAddr

// Channel holding incoming packets
buf *packetio.Buffer
Expand Down Expand Up @@ -81,7 +92,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe
}
// Each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
addr := newUDPMuxedConnAddr(rAddr.(*net.UDPAddr))
if !c.containsAddress(addr) {
c.addAddress(addr)
}
Expand Down Expand Up @@ -127,15 +138,15 @@ func (c *udpMuxedConn) isClosed() bool {
}
}

func (c *udpMuxedConn) getAddresses() []string {
func (c *udpMuxedConn) getAddresses() []udpMuxedConnAddr {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
addresses := make([]udpMuxedConnAddr, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}

func (c *udpMuxedConn) addAddress(addr string) {
func (c *udpMuxedConn) addAddress(addr udpMuxedConnAddr) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
Expand All @@ -144,11 +155,11 @@ func (c *udpMuxedConn) addAddress(addr string) {
c.params.Mux.registerConnForAddress(c, addr)
}

func (c *udpMuxedConn) removeAddress(addr string) {
func (c *udpMuxedConn) removeAddress(addr udpMuxedConnAddr) {
c.mu.Lock()
defer c.mu.Unlock()

newAddresses := make([]string, 0, len(c.addresses))
newAddresses := make([]udpMuxedConnAddr, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
newAddresses = append(newAddresses, a)
Expand All @@ -158,7 +169,7 @@ func (c *udpMuxedConn) removeAddress(addr string) {
c.addresses = newAddresses
}

func (c *udpMuxedConn) containsAddress(addr string) bool {
func (c *udpMuxedConn) containsAddress(addr udpMuxedConnAddr) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
Expand Down

0 comments on commit 5dc51d2

Please sign in to comment.