Skip to content

Commit

Permalink
Remove global state for ICE TCP
Browse files Browse the repository at this point in the history
This addresses a few points issue of #245:

 - Take a net.Listener instead of having global state
 - Expose a net.TCPMux based API

Also, the unused closeChannel was removed from tcp_mux.go

Closes #253.
  • Loading branch information
jeremija committed Jul 21, 2020
1 parent 1f59642 commit 62e1d37
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 172 deletions.
16 changes: 8 additions & 8 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ type Agent struct {
loggerFactory logging.LoggerFactory
log logging.LeveledLogger

net *vnet.Net
tcp *tcpIPMux
net *vnet.Net
tcpMux *TCPMux

interfaceFilter func(string) bool

Expand Down Expand Up @@ -306,11 +306,7 @@ func NewAgent(config *AgentConfig) (*Agent, error) {
insecureSkipVerify: config.InsecureSkipVerify,
}

a.tcp = newTCPIPMux(tcpIPMuxParams{
ListenPort: config.TCPListenPort,
Logger: log,
ReadBufferSize: 8,
})
a.tcpMux = config.TCPMux

if a.net == nil {
a.net = vnet.NewNet(nil)
Expand Down Expand Up @@ -887,7 +883,11 @@ func (a *Agent) Close() error {

a.gatherCandidateCancel()
a.err.Store(ErrClosed)
a.tcp.RemoveUfrag(a.localUfrag)

if a.tcpMux != nil {
a.tcpMux.RemoveConnByUfrag(a.localUfrag)
}

close(a.done)

<-done
Expand Down
8 changes: 4 additions & 4 deletions agent_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ type AgentConfig struct {
// to TURN servers via TLS or DTLS
InsecureSkipVerify bool

// TCPListenPort will be used to start a TCP listener on all allowed interfaces for
// ICE TCP. Currently only passive candidates are supported. This functionality is
// experimental and this API will likely change in the future.
TCPListenPort int
// TCPMux will be used for multiplexing incoming TCP connections for ICE TCP.
// Currently only passive candidates are supported. This functionality is
// experimental and the API might change in the future.
TCPMux *TCPMux
}

// initWithDefaults populates an agent and falls back to defaults if fields are unset
Expand Down
13 changes: 4 additions & 9 deletions gather.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,28 +161,23 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
var tcpType TCPType
switch network {
case tcp:
if a.tcp == nil {
if a.tcpMux == nil {
continue
}

// below is for passive mode
// TODO active mode
// TODO S-O mode

mux, muxErr := a.tcp.Listen(ip)
if muxErr != nil {
a.log.Warnf("could not listen %s %s\n", network, ip)
continue
}

a.log.Debugf("GetConn by ufrag: %s\n", a.localUfrag)
conn, err = mux.GetConn(a.localUfrag)
conn, err = a.tcpMux.GetConnByUfrag(a.localUfrag)
if err != nil {
a.log.Warnf("error getting tcp conn by ufrag: %s %s\n", network, ip, a.localUfrag)
continue
}
port = conn.LocalAddr().(*net.TCPAddr).Port
tcpType = TCPTypePassive
// TODO is there a way to verify that the listen address is even
// accessible from the current interface.
case udp:
conn, err = listenUDPInPortRange(a.net, a.log, int(a.portmax), int(a.portmin), network, &net.UDPAddr{IP: ip, Port: 0})
if err != nil {
Expand Down
18 changes: 17 additions & 1 deletion gather_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/selfsign"
"github.com/pion/logging"
"github.com/pion/transport/test"
"github.com/pion/turn/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestListenUDP(t *testing.T) {
Expand Down Expand Up @@ -116,11 +118,25 @@ func TestSTUNConcurrency(t *testing.T) {
Port: serverPort,
})

listener, err := net.ListenTCP("tcp", &net.TCPAddr{
IP: net.IP{127, 0, 0, 1},
})
require.NoError(t, err)
defer func() {
_ = listener.Close()
}()

a, err := NewAgent(&AgentConfig{
NetworkTypes: supportedNetworkTypes,
Urls: urls,
CandidateTypes: []CandidateType{CandidateTypeHost, CandidateTypeServerReflexive},
TCPListenPort: 9999,
TCPMux: NewTCPMux(
TCPMuxParams{
Listener: listener,
Logger: logging.NewDefaultLoggerFactory().NewLogger("ice"),
ReadBufferSize: 8,
},
),
})
assert.NoError(t, err)

Expand Down
102 changes: 0 additions & 102 deletions tcp_ip_mux.go

This file was deleted.

71 changes: 34 additions & 37 deletions tcp_mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,35 @@ import (
"github.com/pion/stun"
)

type tcpMux struct {
params *tcpMuxParams
// TCPMux muxes TCP net.Conns into net.PacketConns and groups them by Ufrag.
type TCPMux struct {
params *TCPMuxParams
closed bool

// conns is a map of all tcpPacketConns indexed by ufrag
conns map[string]*tcpPacketConn

mu sync.Mutex
wg sync.WaitGroup
closedChan chan struct{}
closeOnce sync.Once
mu sync.Mutex
wg sync.WaitGroup
}

type tcpMuxParams struct {
// TCPMuxParams are parameters for TCPMux.
type TCPMuxParams struct {
Listener net.Listener
Logger logging.LeveledLogger
ReadBufferSize int
}

func newTCPMux(params tcpMuxParams) *tcpMux {
m := &tcpMux{
// NewTCPMux creates a new instance of TCPMux.
func NewTCPMux(params TCPMuxParams) *TCPMux {
if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
}

m := &TCPMux{
params: &params,

conns: map[string]*tcpPacketConn{},

closedChan: make(chan struct{}),
}

m.wg.Add(1)
Expand All @@ -47,7 +51,7 @@ func newTCPMux(params tcpMuxParams) *tcpMux {
return m
}

func (m *tcpMux) start() {
func (m *TCPMux) start() {
m.params.Logger.Infof("Listening TCP on %s\n", m.params.Listener.Addr())
for {
conn, err := m.params.Listener.Accept()
Expand All @@ -66,11 +70,13 @@ func (m *tcpMux) start() {
}
}

func (m *tcpMux) LocalAddr() net.Addr {
// LocalAddr returns the listening address of this TCPMux.
func (m *TCPMux) LocalAddr() net.Addr {
return m.params.Listener.Addr()
}

func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) {
// GetConnByUfrag retrieves an existing or creates a new net.PacketConn.
func (m *TCPMux) GetConnByUfrag(ufrag string) (net.PacketConn, error) {
m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -86,7 +92,7 @@ func (m *tcpMux) GetConn(ufrag string) (net.PacketConn, error) {
return conn, nil
}

func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
func (m *TCPMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
conn := newTCPPacketConn(tcpPacketParams{
ReadBuffer: m.params.ReadBufferSize,
LocalAddr: localAddr,
Expand All @@ -98,20 +104,20 @@ func (m *tcpMux) createConn(ufrag string, localAddr net.Addr) *tcpPacketConn {
go func() {
defer m.wg.Done()
<-conn.CloseChannel()
m.RemoveConn(ufrag)
m.RemoveConnByUfrag(ufrag)
}()

return conn
}

func (m *tcpMux) closeAndLogError(closer io.Closer) {
func (m *TCPMux) closeAndLogError(closer io.Closer) {
err := closer.Close()
if err != nil {
m.params.Logger.Warnf("Error closing connection: %s", err)
}
}

func (m *tcpMux) handleConn(conn net.Conn) {
func (m *TCPMux) handleConn(conn net.Conn) {
buf := make([]byte, receiveMTU)

n, err := readStreamingPacket(conn, buf)
Expand Down Expand Up @@ -169,43 +175,34 @@ func (m *tcpMux) handleConn(conn net.Conn) {
}
}

func (m *tcpMux) Close() error {
// Close closes the listener and waits for all goroutines to exit.
func (m *TCPMux) Close() error {
m.mu.Lock()
m.closed = true

m.closeOnce.Do(func() {
close(m.closedChan)
})

for _, conn := range m.conns {
m.closeAndLogError(conn)
}
m.conns = map[string]*tcpPacketConn{}
m.mu.Unlock()

err := m.params.Listener.Close()

m.mu.Unlock()

m.wg.Wait()

return err
}

func (m *tcpMux) CloseChannel() <-chan struct{} {
return m.closedChan
}

func (m *tcpMux) RemoveConn(ufrag string) {
// RemoveConnByUfrag closes and removes a net.PacketConn by Ufrag.
func (m *TCPMux) RemoveConnByUfrag(ufrag string) {
m.mu.Lock()
defer m.mu.Unlock()

if conn, ok := m.conns[ufrag]; ok {
m.closeAndLogError(conn)
delete(m.conns, ufrag)
}

if len(m.conns) == 0 {
m.closeOnce.Do(func() {
close(m.closedChan)
})

m.closeAndLogError(m.params.Listener)
}
}

const streamingPacketHeaderLen = 2
Expand Down
Loading

0 comments on commit 62e1d37

Please sign in to comment.