Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent heap allocations in udp muxed conn addr #683

Merged
merged 1 commit into from
Apr 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions udp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type UDPMuxDefault struct {
connsIPv4, connsIPv6 map[string]*udpMuxedConn

addressMapMu sync.RWMutex
addressMap map[string]*udpMuxedConn
addressMap map[udpMuxedConnAddr]*udpMuxedConn

// Buffer pool to recycle buffers for net.UDPAddr encodes/decodes
pool *sync.Pool
Expand Down Expand Up @@ -105,7 +105,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
}

m := &UDPMuxDefault{
addressMap: map[string]*udpMuxedConn{},
addressMap: map[udpMuxedConnAddr]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
Expand Down Expand Up @@ -246,7 +246,7 @@ func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}

func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr udpMuxedConnAddr) {
if m.IsClosed() {
return
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func (m *UDPMuxDefault) connWorker() {

// If we have already seen this address dispatch to the appropriate destination
m.addressMapMu.Lock()
destinationConn := m.addressMap[addr.String()]
destinationConn := m.addressMap[newUDPMuxedConnAddr(udpAddr)]
m.addressMapMu.Unlock()

// If we haven't seen this address before but is a STUN packet lookup by ufrag
Expand Down
27 changes: 19 additions & 8 deletions udp_muxed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ import (
"github.com/pion/transport/v2/packetio"
)

type udpMuxedConnAddr struct {
ip [16]byte
port uint16
}

func newUDPMuxedConnAddr(addr *net.UDPAddr) (a udpMuxedConnAddr) {
copy(a.ip[:], addr.IP.To16())
a.port = uint16(addr.Port)
return a
}

type udpMuxedConnParams struct {
Mux *UDPMuxDefault
AddrPool *sync.Pool
Expand All @@ -26,7 +37,7 @@ type udpMuxedConnParams struct {
type udpMuxedConn struct {
params *udpMuxedConnParams
// Remote addresses that we have sent to on this conn
addresses []string
addresses []udpMuxedConnAddr

// Channel holding incoming packets
buf *packetio.Buffer
Expand Down Expand Up @@ -81,7 +92,7 @@ func (c *udpMuxedConn) WriteTo(buf []byte, rAddr net.Addr) (n int, err error) {
return 0, io.ErrClosedPipe
}
// Each time we write to a new address, we'll register it with the mux
addr := rAddr.String()
addr := newUDPMuxedConnAddr(rAddr.(*net.UDPAddr))
if !c.containsAddress(addr) {
c.addAddress(addr)
}
Expand Down Expand Up @@ -127,15 +138,15 @@ func (c *udpMuxedConn) isClosed() bool {
}
}

func (c *udpMuxedConn) getAddresses() []string {
func (c *udpMuxedConn) getAddresses() []udpMuxedConnAddr {
c.mu.Lock()
defer c.mu.Unlock()
addresses := make([]string, len(c.addresses))
addresses := make([]udpMuxedConnAddr, len(c.addresses))
copy(addresses, c.addresses)
return addresses
}

func (c *udpMuxedConn) addAddress(addr string) {
func (c *udpMuxedConn) addAddress(addr udpMuxedConnAddr) {
c.mu.Lock()
c.addresses = append(c.addresses, addr)
c.mu.Unlock()
Expand All @@ -144,11 +155,11 @@ func (c *udpMuxedConn) addAddress(addr string) {
c.params.Mux.registerConnForAddress(c, addr)
}

func (c *udpMuxedConn) removeAddress(addr string) {
func (c *udpMuxedConn) removeAddress(addr udpMuxedConnAddr) {
c.mu.Lock()
defer c.mu.Unlock()

newAddresses := make([]string, 0, len(c.addresses))
newAddresses := make([]udpMuxedConnAddr, 0, len(c.addresses))
for _, a := range c.addresses {
if a != addr {
newAddresses = append(newAddresses, a)
Expand All @@ -158,7 +169,7 @@ func (c *udpMuxedConn) removeAddress(addr string) {
c.addresses = newAddresses
}

func (c *udpMuxedConn) containsAddress(addr string) bool {
func (c *udpMuxedConn) containsAddress(addr udpMuxedConnAddr) bool {
c.mu.Lock()
defer c.mu.Unlock()
for _, a := range c.addresses {
Expand Down