From b0d56c67f39fdd05390335b90142abb926b9da28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antoine=20Bach=C3=A9?=
Date: Sat, 26 Jun 2021 14:04:12 +0200
Subject: [PATCH] Implement DTLS restart

Fixes #1636
---
 dtlstransport.go             |  29 ++++-
 peerconnection.go            |  38 +++++++
 peerconnection_media_test.go | 188 +++++++++++++++++++++++++++++++++++
 3 files changed, 254 insertions(+), 1 deletion(-)

diff --git a/dtlstransport.go b/dtlstransport.go
index 9b69b2d700..a13f2e4db7 100644
--- a/dtlstransport.go
+++ b/dtlstransport.go
@@ -213,6 +213,34 @@ func (t *DTLSTransport) startSRTP() error {
 		return fmt.Errorf("%w: %v", errDtlsKeyExtractionFailed, err)
 	}
 
+	// A successful non-blocking receive on srtpReady is treated as "SRTP is
+	// already running" (presumably the channel is closed after the first start).
+	isAlreadyRunning := func() bool {
+		select {
+		case <-t.srtpReady:
+			return true
+		default:
+			return false
+		}
+	}()
+
+	if isAlreadyRunning {
+		// Reuse the existing sessions and update their contexts from srtpConfig.
+		if sess, ok := t.srtpSession.Load().(*srtp.SessionSRTP); ok {
+			if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
+				return updateErr
+			}
+		}
+
+		if sess, ok := t.srtcpSession.Load().(*srtp.SessionSRTCP); ok {
+			if updateErr := sess.UpdateContext(srtpConfig); updateErr != nil {
+				return updateErr
+			}
+		}
+
+		return nil
+	}
+
 	srtpSession, err := srtp.NewSessionSRTP(t.srtpEndpoint, srtpConfig)
 	if err != nil {
 		return fmt.Errorf("%w: %v", errFailedToStartSRTP, err)
@@ -283,7 +311,7 @@ func (t *DTLSTransport) Start(remoteParameters DTLSParameters) error {
 			return DTLSRole(0), nil, err
 		}
 
-		if t.state != DTLSTransportStateNew {
+		if t.state != DTLSTransportStateNew && t.state != DTLSTransportStateClosed {
 			return DTLSRole(0), nil, &rtcerr.InvalidStateError{Err: fmt.Errorf("%w: %s", errInvalidDTLSStart, t.state)}
 		}
 
diff --git a/peerconnection.go b/peerconnection.go
index fbc334e9cc..809b4f82c8 100644
--- a/peerconnection.go
+++ b/peerconnection.go
@@ -1108,7 +1108,45 @@ func (pc *PeerConnection) SetRemoteDescription(desc SessionDescription) error {
 			pc.ops.Enqueue(func() {
 				pc.startRTP(true, &desc, currentTransceivers)
 			})
+		} else if pc.dtlsTransport.State() != DTLSTransportStateNew {
+			// The DTLS transport has left the New state: compare the fingerprint of this
+			// description with the ones stored in the transport's remote parameters.
+			fingerprint, fingerprintHash, fErr := extractFingerprint(desc.parsed)
+			if fErr != nil {
+				return fErr
+			}
+
+			fingerPrintDidChange := true
+
+			// NOTE(review): remoteParameters is read here without any visible locking;
+			// confirm this cannot race with DTLSTransport.Start.
+			for _, fp := range pc.dtlsTransport.remoteParameters.Fingerprints {
+				if fingerprint == fp.Value && fingerprintHash == fp.Algorithm {
+					fingerPrintDidChange = false
+					break
+				}
+			}
+
+			if fingerPrintDidChange {
+				pc.ops.Enqueue(func() {
+					if dErr := pc.dtlsTransport.Stop(); dErr != nil {
+						pc.log.Warnf("Failed to stop DTLS: %s", dErr)
+					}
+
+					// Restart the dtls transport with updated fingerprints
+					startErr := pc.dtlsTransport.Start(DTLSParameters{
+						Role:         dtlsRoleFromRemoteSDP(desc.parsed),
+						Fingerprints: []DTLSFingerprint{{Algorithm: fingerprintHash, Value: fingerprint}},
+					})
+					pc.updateConnectionState(pc.ICEConnectionState(), pc.dtlsTransport.State())
+					if startErr != nil {
+						pc.log.Warnf("Failed to restart DTLS: %s", startErr)
+						return
+					}
+				})
+			}
 		}
+
 		return nil
 	}
 
diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go
index 0ec4eeb059..9853049245 100644
--- a/peerconnection_media_test.go
+++ b/peerconnection_media_test.go
@@ -14,10 +14,12 @@ import (
 	"testing"
 	"time"
 
+	"github.com/pion/logging"
 	"github.com/pion/randutil"
 	"github.com/pion/rtcp"
 	"github.com/pion/rtp"
 	"github.com/pion/transport/test"
+	"github.com/pion/transport/vnet"
 	"github.com/pion/webrtc/v3/pkg/media"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -1052,3 +1054,189 @@ func TestPeerConnection_RaceReplaceTrack(t *testing.T) {
 
 	assert.NoError(t, pc.Close())
 }
+
+// TestPeerConnection_DTLS_Restart connects clientA1 to clientB, cuts clientA1's
+// network, then negotiates a brand-new PeerConnection (clientA2) against the same
+// clientB and checks that clientA2 still receives media. See issue #1636.
+func TestPeerConnection_DTLS_Restart(t *testing.T) {
+	lim := test.TimeOut(time.Second * 30)
+	defer lim.Stop()
+
+	// First prepare network configuration
+
+	router, err := vnet.NewRouter(&vnet.RouterConfig{
+		CIDR:          "0.0.0.0/0",
+		LoggerFactory: logging.NewDefaultLoggerFactory(),
+	})
+	require.NoError(t, err)
+
+	networkA1 := vnet.NewNet(&vnet.NetConfig{
+		NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
+	})
+
+	networkA2 := vnet.NewNet(&vnet.NetConfig{
+		NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
+	})
+
+	networkB := vnet.NewNet(&vnet.NetConfig{
+		NetworkConditioner: vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetNone),
+	})
+
+	assert.NoError(t, router.AddNet(networkA1))
+	assert.NoError(t, router.AddNet(networkA2))
+	assert.NoError(t, router.AddNet(networkB))
+
+	assert.NoError(t, router.Start())
+	defer func() { _ = router.Stop() }()
+
+	// ... then the clients
+
+	makeClient := func(network *vnet.Net) (*PeerConnection, *TrackLocalStaticSample) {
+		m := &MediaEngine{}
+		assert.NoError(t, m.RegisterDefaultCodecs())
+
+		s := SettingEngine{}
+		s.SetVNet(network)
+		s.SetICETimeouts(2*time.Second, 5*time.Second, 1*time.Second)
+
+		api := NewAPI(WithSettingEngine(s), WithMediaEngine(m))
+		pc, cliErr := api.NewPeerConnection(Configuration{})
+		require.NoError(t, cliErr)
+
+		track, cliErr := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: MimeTypeOpus}, "audio", "test-client")
+		assert.NoError(t, cliErr)
+
+		_, cliErr = pc.AddTrack(track)
+		assert.NoError(t, cliErr)
+
+		return pc, track
+	}
+
+	clientA1, _ := makeClient(networkA1)
+	defer func() { _ = clientA1.Close() }()
+
+	clientB, localClientBTrack := makeClient(networkB)
+	defer func() { _ = clientB.Close() }()
+
+	// ... clientB starts publishing media
+	publishClientBCtx, publishCancel := context.WithCancel(context.Background())
+	go func() {
+		ticker := time.NewTicker(20 * time.Millisecond)
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-publishClientBCtx.Done():
+				return
+			case <-ticker.C:
+				_ = localClientBTrack.WriteSample(media.Sample{
+					Data:      []byte{0xbb},
+					Timestamp: time.Now(),
+					Duration:  20 * time.Millisecond,
+				})
+			}
+		}
+	}()
+	defer publishCancel()
+
+	clientA1Tracks := make(chan *TrackRemote, 1)
+	clientA1.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) {
+		clientA1Tracks <- remote
+	})
+
+	// ClientA1 connects to ClientB
+
+	gatherCompletePromiseA1 := GatheringCompletePromise(clientA1)
+	offerA1, err := clientA1.CreateOffer(nil)
+	assert.NoError(t, err)
+	assert.NoError(t, clientA1.SetLocalDescription(offerA1))
+	<-gatherCompletePromiseA1
+
+	assert.NoError(t, clientB.SetRemoteDescription(*clientA1.LocalDescription()))
+
+	gatherCompletePromiseB := GatheringCompletePromise(clientB)
+	answerB, err := clientB.CreateAnswer(nil)
+	assert.NoError(t, err)
+	assert.NoError(t, clientB.SetLocalDescription(answerB))
+	<-gatherCompletePromiseB
+
+	clientA1Connected := make(chan struct{}, 1)
+	clientA1Disconnected := make(chan struct{}, 1)
+	clientA1.OnICEConnectionStateChange(func(s ICEConnectionState) {
+		if s == ICEConnectionStateConnected {
+			clientA1Connected <- struct{}{}
+		} else if s == ICEConnectionStateDisconnected {
+			clientA1Disconnected <- struct{}{}
+		}
+	})
+
+	assert.NoError(t, clientA1.SetRemoteDescription(answerB))
+
+	// Wait for connection
+	<-clientA1Connected
+
+	// At this point, clientA1 should have received a track, and some media
+	clientA1RemoteTrack := <-clientA1Tracks
+	pkt, _, err := clientA1RemoteTrack.ReadRTP()
+	assert.NotNil(t, pkt)
+	assert.NoError(t, err)
+
+	networkA1.SetNetworkConditioner(vnet.NewNetworkConditioner(vnet.NetworkConditionerPresetFullLoss))
+
+	<-clientA1Disconnected
+
+	// ClientA1 has been disconnected – in a mobile app context, this could be a switch to the background
+	// or a killed app.
+	//
+	// In these scenarios, the client will reconnect with a different PeerConnection – here ClientA2.
+
+	clientA2, _ := makeClient(networkA2)
+	defer func() { _ = clientA2.Close() }()
+
+	clientA2Connected := make(chan struct{}, 1)
+	clientA2.OnICEConnectionStateChange(func(s ICEConnectionState) {
+		if s == ICEConnectionStateConnected {
+			clientA2Connected <- struct{}{}
+		} else if s == ICEConnectionStateFailed {
+			// Callback may run off the test goroutine, where t.FailNow is not allowed.
+			assert.Fail(t, "should not fail")
+		}
+	})
+
+	clientA2Tracks := make(chan *TrackRemote, 1)
+	clientA2.OnTrack(func(remote *TrackRemote, receiver *RTPReceiver) {
+		clientA2Tracks <- remote
+	})
+
+	// ClientA2 connects to ClientB
+
+	gatherCompletePromiseA2 := GatheringCompletePromise(clientA2)
+	// We can't do an ICE Restart here, since it's a different PeerConnection
+	offerA2, err := clientA2.CreateOffer(nil)
+	assert.NoError(t, err)
+	assert.NoError(t, clientA2.SetLocalDescription(offerA2))
+	<-gatherCompletePromiseA2
+
+	assert.NoError(t, clientB.SetRemoteDescription(*clientA2.LocalDescription()))
+
+	gatherCompletePromiseB = GatheringCompletePromise(clientB)
+	answerB, err = clientB.CreateAnswer(nil)
+	assert.NoError(t, err)
+	assert.NoError(t, clientB.SetLocalDescription(answerB))
+	<-gatherCompletePromiseB
+
+	assert.NoError(t, clientA2.SetRemoteDescription(answerB))
+
+	// Wait for connection
+	<-clientA2Connected
+
+	// At this point, clientA2 should have received a track, and some media
+	clientA2RemoteTrack := <-clientA2Tracks
+
+	// Read a bunch of RTPs
+	for ndx := 0; ndx < 10; ndx++ {
+		pkt, _, err = clientA2RemoteTrack.ReadRTP()
+		assert.NotNil(t, pkt)
+		assert.NoError(t, err)
+	}
+}