Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OnConnectionStateChange wait for handler to finish #2702

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ type PeerConnection struct {
pendingRemoteDescription *SessionDescription
signalingState SignalingState
iceConnectionState atomic.Value // ICEConnectionState
connectionState atomic.Value // PeerConnectionState

connStateMx sync.RWMutex
connectionState PeerConnectionState // PeerConnectionState
// connectionStateNotifications runs OnConnectionStateUpdate callbacks sequentially
// guaranteeing the ordering of the updates.
connectionStateNotifications *operations

idpLoginURL *string

Expand Down Expand Up @@ -125,20 +130,21 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
Certificates: []Certificate{},
ICECandidatePoolSize: 0,
},
ops: newOperations(),
isClosed: &atomicBool{},
isNegotiationNeeded: &atomicBool{},
negotiationNeededState: negotiationNeededStateEmpty,
lastOffer: "",
lastAnswer: "",
greaterMid: -1,
signalingState: SignalingStateStable,
ops: newOperations(),
isClosed: &atomicBool{},
isNegotiationNeeded: &atomicBool{},
negotiationNeededState: negotiationNeededStateEmpty,
lastOffer: "",
lastAnswer: "",
greaterMid: -1,
signalingState: SignalingStateStable,
connectionStateNotifications: newOperations(),
connectionState: PeerConnectionStateNew,

api: api,
log: api.settingEngine.LoggerFactory.NewLogger("pc"),
}
pc.iceConnectionState.Store(ICEConnectionStateNew)
pc.connectionState.Store(PeerConnectionStateNew)

i, err := api.interceptorRegistry.Build("")
if err != nil {
Expand Down Expand Up @@ -507,10 +513,10 @@ func (pc *PeerConnection) OnConnectionStateChange(f func(PeerConnectionState)) {
}

func (pc *PeerConnection) onConnectionStateChange(cs PeerConnectionState) {
pc.connectionState.Store(cs)
pc.log.Infof("peer connection state changed: %s", cs)
if handler, ok := pc.onConnectionStateChangeHandler.Load().(func(PeerConnectionState)); ok && handler != nil {
go handler(cs)
pc.connectionStateNotifications.Enqueue(func() {
handler(cs)
})
}
}

Expand Down Expand Up @@ -746,7 +752,20 @@ func (pc *PeerConnection) createICEGatherer() (*ICEGatherer, error) {

// Update the PeerConnectionState given the state of relevant transports
// https://www.w3.org/TR/webrtc/#rtcpeerconnectionstate-enum
func (pc *PeerConnection) updateConnectionState(iceConnectionState ICEConnectionState, dtlsTransportState DTLSTransportState) {
func (pc *PeerConnection) updateConnectionState() {
pc.connStateMx.Lock()
defer pc.connStateMx.Unlock()

cs := pc.getConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
if cs == pc.connectionState {
return
}
pc.connectionState = cs
pc.log.Infof("peer connection state changed: %s", cs)
pc.onConnectionStateChange(cs)
}

func (pc *PeerConnection) getConnectionState(iceConnectionState ICEConnectionState, dtlsTransportState DTLSTransportState) PeerConnectionState {
connectionState := PeerConnectionStateNew
switch {
// The RTCPeerConnection object's [[IsClosed]] slot is true.
Expand Down Expand Up @@ -780,12 +799,7 @@ func (pc *PeerConnection) updateConnectionState(iceConnectionState ICEConnection
(dtlsTransportState == DTLSTransportStateConnected || dtlsTransportState == DTLSTransportStateClosed):
connectionState = PeerConnectionStateConnected
}

if pc.connectionState.Load() == connectionState {
return
}

pc.onConnectionStateChange(connectionState)
return connectionState
}

func (pc *PeerConnection) createICETransport() *ICETransport {
Expand All @@ -812,7 +826,7 @@ func (pc *PeerConnection) createICETransport() *ICETransport {
return
}
pc.onICEConnectionStateChange(cs)
pc.updateConnectionState(cs, pc.dtlsTransport.State())
pc.updateConnectionState()
})

return t
Expand Down Expand Up @@ -2113,7 +2127,7 @@ func (pc *PeerConnection) Close() error {
}

// https://www.w3.org/TR/webrtc/#dom-rtcpeerconnection-close (step #11)
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
pc.updateConnectionState()

return util.FlattenErrs(closeErrs)
}
Expand Down Expand Up @@ -2201,10 +2215,9 @@ func (pc *PeerConnection) ICEGatheringState() ICEGatheringState {
// ConnectionState attribute returns the connection state of the
// PeerConnection instance.
func (pc *PeerConnection) ConnectionState() PeerConnectionState {
if state, ok := pc.connectionState.Load().(PeerConnectionState); ok {
return state
}
return PeerConnectionState(0)
pc.connStateMx.RLock()
defer pc.connStateMx.RUnlock()
return pc.connectionState
}

// GetStats return data providing statistics about the overall connection
Expand Down Expand Up @@ -2300,7 +2313,7 @@ func (pc *PeerConnection) startTransports(iceRole ICERole, dtlsRole DTLSRole, re
Role: dtlsRole,
Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
})
pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
pc.updateConnectionState()
if err != nil {
pc.log.Warnf("Failed to start manager: %s", err)
return
Expand Down
33 changes: 13 additions & 20 deletions peerconnection_go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,17 +388,17 @@
currentRemoteDescription: &SessionDescription{},
pendingRemoteDescription: &SessionDescription{},
signalingState: SignalingStateHaveLocalOffer,
connectionState: PeerConnectionStateConnecting,
}
pc.iceConnectionState.Store(ICEConnectionStateChecking)
pc.connectionState.Store(PeerConnectionStateConnecting)

assert.Equal(t, pc.currentLocalDescription, pc.CurrentLocalDescription(), "should match")
assert.Equal(t, pc.pendingLocalDescription, pc.PendingLocalDescription(), "should match")
assert.Equal(t, pc.currentRemoteDescription, pc.CurrentRemoteDescription(), "should match")
assert.Equal(t, pc.pendingRemoteDescription, pc.PendingRemoteDescription(), "should match")
assert.Equal(t, pc.signalingState, pc.SignalingState(), "should match")
assert.Equal(t, pc.iceConnectionState.Load(), pc.ICEConnectionState(), "should match")
assert.Equal(t, pc.connectionState.Load(), pc.ConnectionState(), "should match")
assert.Equal(t, pc.connectionState, pc.ConnectionState(), "should match")
}

func TestPeerConnection_AnswerWithoutOffer(t *testing.T) {
Expand Down Expand Up @@ -1576,33 +1576,26 @@
assert.NoError(t, err)
assert.Equal(t, PeerConnectionStateNew, pc.ConnectionState())

pc.updateConnectionState(ICEConnectionStateChecking, DTLSTransportStateNew)
assert.Equal(t, PeerConnectionStateConnecting, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateConnecting, pc.getConnectionState(ICEConnectionStateChecking, DTLSTransportStateNew))

pc.updateConnectionState(ICEConnectionStateConnected, DTLSTransportStateNew)
assert.Equal(t, PeerConnectionStateConnecting, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateConnecting, pc.getConnectionState(ICEConnectionStateConnected, DTLSTransportStateNew))

pc.updateConnectionState(ICEConnectionStateConnected, DTLSTransportStateConnecting)
assert.Equal(t, PeerConnectionStateConnecting, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateConnecting, pc.getConnectionState(ICEConnectionStateConnected, DTLSTransportStateConnecting))

pc.updateConnectionState(ICEConnectionStateConnected, DTLSTransportStateConnected)
assert.Equal(t, PeerConnectionStateConnected, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateConnected, pc.getConnectionState(ICEConnectionStateConnected, DTLSTransportStateConnected))

pc.updateConnectionState(ICEConnectionStateCompleted, DTLSTransportStateConnected)
assert.Equal(t, PeerConnectionStateConnected, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateConnected, pc.getConnectionState(ICEConnectionStateCompleted, DTLSTransportStateConnected))

pc.updateConnectionState(ICEConnectionStateConnected, DTLSTransportStateClosed)
assert.Equal(t, PeerConnectionStateConnected, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateConnected, pc.getConnectionState(ICEConnectionStateConnected, DTLSTransportStateClosed))

pc.updateConnectionState(ICEConnectionStateDisconnected, DTLSTransportStateConnected)
assert.Equal(t, PeerConnectionStateDisconnected, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateDisconnected, pc.getConnectionState(ICEConnectionStateDisconnected, DTLSTransportStateConnected))

pc.updateConnectionState(ICEConnectionStateFailed, DTLSTransportStateConnected)
assert.Equal(t, PeerConnectionStateFailed, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateFailed, pc.getConnectionState(ICEConnectionStateFailed, DTLSTransportStateConnected))

pc.updateConnectionState(ICEConnectionStateConnected, DTLSTransportStateFailed)
assert.Equal(t, PeerConnectionStateFailed, pc.ConnectionState())
assert.Equal(t, PeerConnectionStateFailed, pc.getConnectionState(ICEConnectionStateConnected, DTLSTransportStateFailed))

assert.NoError(t, pc.Close())
assert.Equal(t, PeerConnectionStateClosed, pc.ConnectionState())
}

Check failure on line 1600 in peerconnection_go_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gci`-ed with --skip-generated -s standard -s default (gci)

Check failure on line 1601 in peerconnection_go_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

File is not `gofumpt`-ed (gofumpt)
22 changes: 22 additions & 0 deletions peerconnection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"github.com/pion/transport/v3/test"
"github.com/pion/webrtc/v4/pkg/rtcerr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// newPair creates two new peer connections (an offerer and an answerer)
Expand Down Expand Up @@ -792,3 +793,24 @@
peerConnectionsConnected.Wait()
assert.NoError(t, pcOffer.Close())
}

func TestPeerConnectionStateIceAgentUpdate(t *testing.T) {
offer, answer, err := newPair()
require.NoError(t, err)
defer answer.Close()

Check failure on line 800 in peerconnection_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `answer.Close` is not checked (errcheck)

closed := make(chan struct{})
offer.OnConnectionStateChange(func(cs PeerConnectionState) {
if cs == PeerConnectionStateConnecting || cs == PeerConnectionStateConnected {
offer.Close()

Check failure on line 805 in peerconnection_test.go

View workflow job for this annotation

GitHub Actions / lint / Go

Error return value of `offer.Close` is not checked (errcheck)
close(closed)
}
})
require.NoError(t, signalPair(offer, answer))
select {
case <-closed:
case <-time.After(2 * time.Second):
t.Fatal("didn't receive close connection update")
}
require.Equal(t, PeerConnectionStateClosed, offer.ConnectionState())
}
17 changes: 4 additions & 13 deletions sctptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ ACCEPT:
return
}

<-r.onDataChannel(rtcDC)
r.onDataChannel(rtcDC)
rtcDC.handleOpen(dc, true, dc.Config.Negotiated)

r.lock.Lock()
Expand Down Expand Up @@ -283,27 +283,18 @@ func (r *SCTPTransport) OnDataChannelOpened(f func(*DataChannel)) {
r.onDataChannelOpenedHandler = f
}

func (r *SCTPTransport) onDataChannel(dc *DataChannel) (done chan struct{}) {
func (r *SCTPTransport) onDataChannel(dc *DataChannel) {
r.lock.Lock()
r.dataChannels = append(r.dataChannels, dc)
r.dataChannelsAccepted++
handler := r.onDataChannelHandler
r.lock.Unlock()

done = make(chan struct{})
if handler == nil || dc == nil {
close(done)
return
}

// Run this synchronously to allow setup done in onDataChannelFn()
// to complete before datachannel event handlers might be called.
go func() {
if handler != nil {
handler(dc)
close(done)
}()

return
}
}

func (r *SCTPTransport) updateMessageSize() {
Expand Down
Loading