diff --git a/agent.go b/agent.go index 144f9c21..761c07ca 100644 --- a/agent.go +++ b/agent.go @@ -119,6 +119,10 @@ type Agent struct { // 1:1 D-NAT IP address mapping extIPMapper *externalIPMapper + // Callback that allows user to implement custom behavior + // for STUN Binding Requests + userBindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool + gatherCandidateCancel func() gatherCandidateDone chan struct{} @@ -213,6 +217,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit includeLoopback: config.IncludeLoopback, disableActiveTCP: config.DisableActiveTCP, + + userBindingRequestHandler: config.BindingRequestHandler, } a.connectionStateNotifier = &handlerNotifier{connectionStateFunc: a.onConnectionStateChange} a.candidateNotifier = &handlerNotifier{candidateFunc: a.onCandidate} diff --git a/agent_config.go b/agent_config.go index 5edbea45..4c87fa1a 100644 --- a/agent_config.go +++ b/agent_config.go @@ -193,6 +193,13 @@ type AgentConfig struct { // DisableActiveTCP can be used to disable Active TCP candidates. Otherwise when TCP is enabled // Active TCP candidates will be created when a new passive TCP remote candidate is added. DisableActiveTCP bool + + // BindingRequestHandler allows applications to perform logic on incoming STUN Binding Requests + // This was implemented to allow users to + // * Log incoming Binding Requests for debugging + // * Implement draft-thatcher-ice-renomination + // * Implement custom CandidatePair switching logic + BindingRequestHandler func(m *stun.Message, local, remote Candidate, pair *CandidatePair) bool } // initWithDefaults populates an agent and falls back to defaults if fields are unset diff --git a/selection.go b/selection.go index 55ad20bf..09a0988d 100644 --- a/selection.go +++ b/selection.go @@ -111,6 +111,12 @@ func (s *controllingSelector) HandleBindingRequest(m *stun.Message, local, remot s.nominatePair(p) } } + + if s.agent.userBindingRequestHandler != nil { + if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch { + s.agent.setSelectedPair(p) + } + } } func (s *controllingSelector) HandleSuccessResponse(m *stun.Message, local, remote Candidate, remoteAddr net.Addr) { @@ -242,14 +248,12 @@ func (s *controlledSelector) HandleSuccessResponse(m *stun.Message, local, remot } func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote Candidate) { - useCandidate := m.Contains(stun.AttrUseCandidate) - p := s.agent.findPair(local, remote) if p == nil { p = s.agent.addPair(local, remote) } - if useCandidate { + if m.Contains(stun.AttrUseCandidate) { // https://tools.ietf.org/html/rfc8445#section-7.3.1.5 if p.state == CandidatePairStateSucceeded { @@ -257,8 +261,8 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote // previously sent by this pair produced a successful response and // generated a valid pair (Section 7.2.5.3.2). The agent sets the // nominated flag value of the valid pair to true. - if selectedPair := s.agent.getSelectedPair(); selectedPair == nil || - (selectedPair != p && selectedPair.priority() <= p.priority()) { + selectedPair := s.agent.getSelectedPair() + if selectedPair == nil || (selectedPair != p && selectedPair.priority() <= p.priority()) { s.agent.setSelectedPair(p) } else if selectedPair != p { s.log.Tracef("Ignore nominate new pair %s, already nominated pair %s", p, selectedPair) @@ -278,6 +282,12 @@ func (s *controlledSelector) HandleBindingRequest(m *stun.Message, local, remote s.agent.sendBindingSuccess(m, local, remote) s.PingCandidate(local, remote) + + if s.agent.userBindingRequestHandler != nil { + if shouldSwitch := s.agent.userBindingRequestHandler(m, local, remote, p); shouldSwitch { + s.agent.setSelectedPair(p) + } + } } type liteSelector struct { diff --git a/selection_test.go b/selection_test.go new file mode 100644 index 00000000..a5e17fc1 --- /dev/null +++ b/selection_test.go @@ -0,0 +1,156 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +//go:build !js +// +build !js + +package ice + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/pion/stun/v2" + "github.com/pion/transport/v3/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func sendUntilDone(t *testing.T, writingConn, readingConn net.Conn, maxAttempts int) bool { + testMessage := []byte("Hello World") + testBuffer := make([]byte, len(testMessage)) + + readDone, readDoneCancel := context.WithCancel(context.Background()) + go func() { + _, err := readingConn.Read(testBuffer) + if errors.Is(err, io.EOF) { + return + } + + require.NoError(t, err) + require.True(t, bytes.Equal(testMessage, testBuffer)) + + readDoneCancel() + }() + + attempts := 0 + for { + select { + case <-time.After(5 * time.Millisecond): + if attempts > maxAttempts { + return false + } + + _, err := writingConn.Write(testMessage) + require.NoError(t, err) + attempts++ + case <-readDone.Done(): + return true + } + } +} + +func TestBindingRequestHandler(t *testing.T) { + defer test.CheckRoutines(t)() + defer test.TimeOut(time.Second * 30).Stop() + + var switchToNewCandidatePair, controlledLoggingFired atomic.Value + oneHour := time.Hour + keepaliveInterval := time.Millisecond * 20 + + aNotifier, aConnected := onConnected() + bNotifier, bConnected := onConnected() + controllingAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4, NetworkTypeUDP6}, + MulticastDNSMode: MulticastDNSModeDisabled, + KeepaliveInterval: &keepaliveInterval, + CheckInterval: &oneHour, + BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool { + controlledLoggingFired.Store(true) + return false + }, + }) + require.NoError(t, err) + require.NoError(t, controllingAgent.OnConnectionStateChange(aNotifier)) + + controlledAgent, err := NewAgent(&AgentConfig{ + NetworkTypes: []NetworkType{NetworkTypeUDP4}, + MulticastDNSMode: MulticastDNSModeDisabled, + KeepaliveInterval: &keepaliveInterval, + CheckInterval: &oneHour, + BindingRequestHandler: func(_ *stun.Message, _, _ Candidate, _ *CandidatePair) bool { + // Don't switch candidate pair until we are ready + val, ok := switchToNewCandidatePair.Load().(bool) + return ok && val + }, + }) + require.NoError(t, err) + require.NoError(t, controlledAgent.OnConnectionStateChange(bNotifier)) + + controlledConn, controllingConn := connect(controlledAgent, controllingAgent) + <-aConnected + <-bConnected + + // Assert we have connected and can send data + require.True(t, sendUntilDone(t, controlledConn, controllingConn, 100)) + + // Take the lock on the controlling Agent and unset state + assert.NoError(t, controlledAgent.loop.Run(controlledAgent.loop, func(_ context.Context) { + for net, cs := range controlledAgent.remoteCandidates { + for _, c := range cs { + require.NoError(t, c.close()) + } + delete(controlledAgent.remoteCandidates, net) + } + + for _, c := range controlledAgent.localCandidates[NetworkTypeUDP4] { + cast, ok := c.(*CandidateHost) + require.True(t, ok) + cast.remoteCandidateCaches = map[AddrPort]Candidate{} + } + + controlledAgent.setSelectedPair(nil) + controlledAgent.checklist = make([]*CandidatePair, 0) + })) + + // Assert that Selected Candidate pair has only been unset on Controlled side + candidatePair, err := controlledAgent.GetSelectedCandidatePair() + assert.Nil(t, candidatePair) + assert.NoError(t, err) + + candidatePair, err = controllingAgent.GetSelectedCandidatePair() + assert.NotNil(t, candidatePair) + assert.NoError(t, err) + + // Sending will fail, we no longer have a selected candidate pair + require.False(t, sendUntilDone(t, controlledConn, controllingConn, 20)) + + // Send STUN Binding requests until a new Selected Candidate Pair has been set by BindingRequestHandler + switchToNewCandidatePair.Store(true) + for { + controllingAgent.requestConnectivityCheck() + + candidatePair, err = controlledAgent.GetSelectedCandidatePair() + require.NoError(t, err) + if candidatePair != nil { + break + } + + time.Sleep(time.Millisecond * 5) + } + + // We have a new selected candidate pair because of BindingRequestHandler, test that it works + require.True(t, sendUntilDone(t, controllingConn, controlledConn, 100)) + + fired, ok := controlledLoggingFired.Load().(bool) + require.True(t, ok) + require.True(t, fired) + + closePipe(t, controllingConn, controlledConn) +}