Skip to content

Commit

Permalink
Move taskloop into dedicated package
Browse files Browse the repository at this point in the history
Reduce size of Agent and simplify code
  • Loading branch information
stv0g authored and Sean-Der committed Mar 23, 2024
1 parent b36d332 commit fdca6c4
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 215 deletions.
181 changes: 55 additions & 126 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
"sync/atomic"
"time"

atomicx "github.com/pion/ice/v3/internal/atomic"
stunx "github.com/pion/ice/v3/internal/stun"
"github.com/pion/ice/v3/internal/taskloop"
"github.com/pion/logging"
"github.com/pion/mdns/v2"
"github.com/pion/stun/v2"
Expand All @@ -36,13 +36,12 @@ type bindingRequest struct {

// Agent represents the ICE agent
type Agent struct {
chanTask chan task
loop *taskloop.Loop

onConnectionStateChangeHdlr atomic.Value // func(ConnectionState)
onSelectedCandidatePairChangeHdlr atomic.Value // func(Candidate, Candidate)
onCandidateHdlr atomic.Value // func(Candidate)

// State owned by the taskLoop
onConnected chan struct{}
onConnectedOnce sync.Once

Expand Down Expand Up @@ -118,11 +117,6 @@ type Agent struct {
// 1:1 D-NAT IP address mapping
extIPMapper *externalIPMapper

// State for closing
done chan struct{}
taskLoopDone chan struct{}
err atomicx.Error

gatherCandidateCancel func()
gatherCandidateDone chan struct{}

Expand All @@ -147,74 +141,6 @@ type Agent struct {
proxyDialer proxy.Dialer
}

type task struct {
fn func(context.Context, *Agent)
done chan struct{}
}

func (a *Agent) ok() error {
select {
case <-a.done:
return a.getErr()
default:
}
return nil
}

func (a *Agent) getErr() error {
if err := a.err.Load(); err != nil {
return err
}
return ErrClosed
}

// Run task in serial. Blocking tasks must be cancelable by context.
func (a *Agent) run(ctx context.Context, t func(context.Context, *Agent)) error {
if err := a.ok(); err != nil {
return err
}
done := make(chan struct{})
select {
case <-ctx.Done():
return ctx.Err()
case a.chanTask <- task{t, done}:
<-done
return nil
}
}

// taskLoop handles registered tasks and agent close.
func (a *Agent) taskLoop() {
defer func() {
a.deleteAllCandidates()
a.startedFn()

if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}

a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)

a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}

close(a.taskLoopDone)
}()

for {
select {
case <-a.done:
return
case t := <-a.chanTask:
t.fn(a.context(), a)
close(t.done)
}
}
}

// NewAgent creates a new Agent
func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
var err error
Expand Down Expand Up @@ -247,7 +173,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
startedCtx, startedFn := context.WithCancel(context.Background())

a := &Agent{
chanTask: make(chan task),
tieBreaker: globalMathRandomGenerator.Uint64(),
lite: config.Lite,
gatheringState: GatheringStateNew,
Expand All @@ -258,8 +183,6 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
networkTypes: config.NetworkTypes,
onConnected: make(chan struct{}),
buf: packetio.NewBuffer(),
done: make(chan struct{}),
taskLoopDone: make(chan struct{}),
startedCh: startedCtx.Done(),
startedFn: startedFn,
portMin: config.PortMin,
Expand Down Expand Up @@ -333,7 +256,23 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
return nil, err
}

go a.taskLoop()
a.loop = taskloop.New(func() {
a.removeUfragFromMux()
a.deleteAllCandidates()
a.startedFn()

if err := a.buf.Close(); err != nil {
a.log.Warnf("Failed to close buffer: %v", err)
}

a.closeMulticastConn()
a.updateConnectionState(ConnectionStateClosed)

a.gatherCandidateCancel()
if a.gatherCandidateDone != nil {
<-a.gatherCandidateDone
}
})

// Restart is also used to initialize the agent for the first time
if err := a.Restart(config.LocalUfrag, config.LocalPwd); err != nil {
Expand All @@ -359,10 +298,10 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP

a.log.Debugf("Started agent: isControlling? %t, remoteUfrag: %q, remotePwd: %q", isControlling, remoteUfrag, remotePwd)

return a.run(a.context(), func(_ context.Context, agent *Agent) {
agent.isControlling = isControlling
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
return a.loop.Run(a.loop, func(_ context.Context) {
a.isControlling = isControlling
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd

if isControlling {
a.selector = &controllingSelector{agent: a, log: a.log}
Expand All @@ -377,7 +316,7 @@ func (a *Agent) startConnectivityChecks(isControlling bool, remoteUfrag, remoteP
a.selector.Start()
a.startedFn()

agent.updateConnectionState(ConnectionStateChecking)
a.updateConnectionState(ConnectionStateChecking)

a.requestConnectivityCheck()
go a.connectivityChecks() //nolint:contextcheck
Expand All @@ -389,7 +328,7 @@ func (a *Agent) connectivityChecks() {
checkingDuration := time.Time{}

contact := func() {
if err := a.run(a.context(), func(_ context.Context, a *Agent) {
if err := a.loop.Run(a.loop, func(_ context.Context) {
defer func() {
lastConnectionState = a.connectionState
}()
Expand Down Expand Up @@ -446,7 +385,7 @@ func (a *Agent) connectivityChecks() {
contact()
case <-t.C:
contact()
case <-a.done:
case <-a.loop.Done():
t.Stop()
return
}
Expand Down Expand Up @@ -638,9 +577,9 @@ func (a *Agent) AddRemoteCandidate(c Candidate) error {
}

go func() {
if err := a.run(a.context(), func(_ context.Context, agent *Agent) {
if err := a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
agent.addRemoteCandidate(c)
a.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add remote candidate %s: %v", c.Address(), err)
return
Expand Down Expand Up @@ -670,9 +609,9 @@ func (a *Agent) resolveAndAddMulticastCandidate(c *CandidateHost) {
return
}

if err = a.run(a.context(), func(_ context.Context, agent *Agent) {
if err = a.loop.Run(a.loop, func(_ context.Context) {
// nolint: contextcheck
agent.addRemoteCandidate(c)
a.addRemoteCandidate(c)
}); err != nil {
a.log.Warnf("Failed to add mDNS candidate %s: %v", c.Address(), err)
return
Expand All @@ -695,7 +634,7 @@ func (a *Agent) addRemotePassiveTCPCandidate(remoteCandidate Candidate) {

for i := range localIPs {
conn := newActiveTCPConn(
a.context(),
a.loop,
net.JoinHostPort(localIPs[i].String(), "0"),
net.JoinHostPort(remoteCandidate.Address(), strconv.Itoa(remoteCandidate.Port())),
a.log,
Expand Down Expand Up @@ -763,7 +702,7 @@ func (a *Agent) addRemoteCandidate(c Candidate) {
}

func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net.PacketConn) error {
return a.run(ctx, func(context.Context, *Agent) {
return a.loop.Run(ctx, func(context.Context) {
set := a.localCandidates[c.NetworkType()]
for _, candidate := range set {
if candidate.Equal(c) {
Expand Down Expand Up @@ -799,9 +738,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
var res []Candidate

err := a.run(a.context(), func(_ context.Context, agent *Agent) {
err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate
for _, set := range agent.remoteCandidates {
for _, set := range a.remoteCandidates {
candidates = append(candidates, set...)
}
res = candidates
Expand All @@ -817,9 +756,9 @@ func (a *Agent) GetRemoteCandidates() ([]Candidate, error) {
func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
var res []Candidate

err := a.run(a.context(), func(_ context.Context, agent *Agent) {
err := a.loop.Run(a.loop, func(_ context.Context) {
var candidates []Candidate
for _, set := range agent.localCandidates {
for _, set := range a.localCandidates {
candidates = append(candidates, set...)
}
res = candidates
Expand All @@ -834,9 +773,9 @@ func (a *Agent) GetLocalCandidates() ([]Candidate, error) {
// GetLocalUserCredentials returns the local user credentials
func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) {
frag = agent.localUfrag
pwd = agent.localPwd
err = a.loop.Run(a.loop, func(_ context.Context) {
frag = a.localUfrag
pwd = a.localPwd
close(valSet)
})

Expand All @@ -849,9 +788,9 @@ func (a *Agent) GetLocalUserCredentials() (frag string, pwd string, err error) {
// GetRemoteUserCredentials returns the remote user credentials
func (a *Agent) GetRemoteUserCredentials() (frag string, pwd string, err error) {
valSet := make(chan struct{})
err = a.run(a.context(), func(_ context.Context, agent *Agent) {
frag = agent.remoteUfrag
pwd = agent.remotePwd
err = a.loop.Run(a.loop, func(_ context.Context) {
frag = a.remoteUfrag
pwd = a.remotePwd
close(valSet)
})

Expand All @@ -875,17 +814,7 @@ func (a *Agent) removeUfragFromMux() {

// Close cleans up the Agent
func (a *Agent) Close() error {
if err := a.ok(); err != nil {
return err
}

a.err.Store(ErrClosed)

a.removeUfragFromMux()

close(a.done)
<-a.taskLoopDone
return nil
return a.loop.Close()
}

// Remove all candidates. This closes any listening sockets
Expand Down Expand Up @@ -1092,7 +1021,7 @@ func (a *Agent) handleInbound(m *stun.Message, local Candidate, remote net.Addr)
// and returns true if it is an actual remote candidate
func (a *Agent) validateNonSTUNTraffic(local Candidate, remote net.Addr) (Candidate, bool) {
var remoteCandidate Candidate
if err := a.run(local.context(), func(context.Context, *Agent) {
if err := a.loop.Run(local.context(), func(context.Context) {
remoteCandidate = a.findRemoteCandidate(local.NetworkType(), remote)
if remoteCandidate != nil {
remoteCandidate.seen(false)
Expand Down Expand Up @@ -1149,9 +1078,9 @@ func (a *Agent) SetRemoteCredentials(remoteUfrag, remotePwd string) error {
return ErrRemotePwdEmpty
}

return a.run(a.context(), func(_ context.Context, agent *Agent) {
agent.remoteUfrag = remoteUfrag
agent.remotePwd = remotePwd
return a.loop.Run(a.loop, func(_ context.Context) {
a.remoteUfrag = remoteUfrag
a.remotePwd = remotePwd
})
}

Expand Down Expand Up @@ -1186,17 +1115,17 @@ func (a *Agent) Restart(ufrag, pwd string) error {
}

var err error
if runErr := a.run(a.context(), func(_ context.Context, agent *Agent) {
if agent.gatheringState == GatheringStateGathering {
agent.gatherCandidateCancel()
if runErr := a.loop.Run(a.loop, func(_ context.Context) {
if a.gatheringState == GatheringStateGathering {
a.gatherCandidateCancel()
}

// Clear all agent needed to take back to fresh state
a.removeUfragFromMux()
agent.localUfrag = ufrag
agent.localPwd = pwd
agent.remoteUfrag = ""
agent.remotePwd = ""
a.localUfrag = ufrag
a.localPwd = pwd
a.remoteUfrag = ""
a.remotePwd = ""
a.gatheringState = GatheringStateNew
a.checklist = make([]*CandidatePair, 0)
a.pendingBindingRequests = make([]bindingRequest, 0)
Expand All @@ -1219,7 +1148,7 @@ func (a *Agent) Restart(ufrag, pwd string) error {

func (a *Agent) setGatheringState(newState GatheringState) error {
done := make(chan struct{})
if err := a.run(a.context(), func(context.Context, *Agent) {
if err := a.loop.Run(a.loop, func(context.Context) {
if a.gatheringState != newState && newState == GatheringStateComplete {
a.candidateNotifier.EnqueueCandidate(nil)
}
Expand Down
2 changes: 1 addition & 1 deletion agent_on_selected_candidate_pair_change_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestOnSelectedCandidatePairChange(t *testing.T) {
})
require.NoError(t, err)

err = agent.run(context.Background(), func(_ context.Context, agent *Agent) {
err = agent.loop.Run(context.Background(), func(_ context.Context) {
agent.setSelectedPair(candidatePair)
})
require.NoError(t, err)
Expand Down

0 comments on commit fdca6c4

Please sign in to comment.