Skip to content

Commit

Permalink
Handle Simulcast RepairStream
Browse files Browse the repository at this point in the history
Read + Discard packets from the Simulcast repair stream. When a
Simulcast stream is enabled the remote will send packets via the repair
stream for probing. We can't ignore these packets anymore because it
will cause gaps in the feedback reports

Resolves #1957
  • Loading branch information
Sean-Der committed Sep 15, 2021
1 parent f8fa792 commit 11b8873
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 70 deletions.
2 changes: 2 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ const (
incomingUnhandledRTPSsrc = "Incoming unhandled RTP ssrc(%d), OnTrack will not be fired. %v"

generatedCertificateOrigin = "WebRTC"

sdesRepairRTPStreamIDURI = "urn:ietf:params:rtp-hdrext:sdes:repaired-rtp-stream-id"
)

func defaultSrtpProtectionProfiles() []dtls.SRTPProtectionProfile {
Expand Down
35 changes: 35 additions & 0 deletions dtlstransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/pion/dtls/v2"
"github.com/pion/dtls/v2/pkg/crypto/fingerprint"
"github.com/pion/interceptor"
"github.com/pion/logging"
"github.com/pion/rtcp"
"github.com/pion/srtp/v2"
Expand Down Expand Up @@ -459,3 +460,37 @@ func (t *DTLSTransport) storeSimulcastStream(s *srtp.ReadStreamSRTP) {

t.simulcastStreams = append(t.simulcastStreams, s)
}

func (t *DTLSTransport) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
srtpSession, err := t.getSRTPSession()
if err != nil {
return nil, nil, nil, nil, err
}

rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}

rtpInterceptor := t.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtpReadStream.Read(in)
return n, a, err
}))

srtcpSession, err := t.getSRTCPSession()
if err != nil {
return nil, nil, nil, nil, err
}

rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}

rtcpInterceptor := t.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtcpReadStream.Read(in)
return n, a, err
}))

return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
}
1 change: 0 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ var (
errRTPReceiverDTLSTransportNil = errors.New("DTLSTransport must not be nil")
errRTPReceiverReceiveAlreadyCalled = errors.New("Receive has already been called")
errRTPReceiverWithSSRCTrackStreamNotFound = errors.New("unable to find stream for Track with SSRC")
errRTPReceiverForSSRCTrackStreamNotFound = errors.New("no trackStreams found for SSRC")
errRTPReceiverForRIDTrackStreamNotFound = errors.New("no trackStreams found for RID")

errRTPSenderTrackNil = errors.New("Track must not be nil")
Expand Down
4 changes: 2 additions & 2 deletions interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) {
return i.WriteRTP(&packet.Header, packet.Payload)
}

func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) interceptor.StreamInfo {
func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) *interceptor.StreamInfo {
headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions))
for _, h := range webrtcHeaderExtensions {
headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI})
Expand All @@ -128,7 +128,7 @@ func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCo
feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter})
}

return interceptor.StreamInfo{
return &interceptor.StreamInfo{
ID: id,
Attributes: interceptor.Attributes{},
SSRC: uint32(ssrc),
Expand Down
68 changes: 45 additions & 23 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1397,33 +1397,44 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
return errPeerConnSimulcastStreamIDRTPExtensionRequired
}

repairStreamIDExtensionID, _, _ := pc.api.mediaEngine.getHeaderExtensionID(RTPHeaderExtensionCapability{sdesRepairRTPStreamIDURI})

b := make([]byte, pc.api.settingEngine.getReceiveMTU())
var mid, rid string
for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
i, err := rtpStream.Read(b)
if err != nil {
return err
}

maybeMid, maybeRid, payloadType, err := handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID))
if err != nil {
return err
}
i, err := rtpStream.Read(b)
if err != nil {
return err
}

if maybeMid != "" {
mid = maybeMid
}
if maybeRid != "" {
rid = maybeRid
}
var mid, rid, rsid string
payloadType, err := handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID), uint8(repairStreamIDExtensionID), &mid, &rid, &rsid)
if err != nil {
return err
}

if mid == "" || rid == "" {
continue
}
params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType)
if err != nil {
return err
}

params, err := pc.api.mediaEngine.getRTPParametersByPayloadType(payloadType)
if err != nil {
return err
streamInfo := createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions)
readStream, interceptor, rtcpReadStream, rtcpInterceptor, err := pc.dtlsTransport.streamsForSSRC(ssrc, *streamInfo)
if err != nil {
return err
}

for readCount := 0; readCount <= simulcastProbeCount; readCount++ {
if mid == "" || (rid == "" && rsid == "") {
i, _, err := interceptor.Read(b, nil)
if err != nil {
return err
}

if _, err = handleUnknownRTPPacket(b[:i], uint8(midExtensionID), uint8(streamIDExtensionID), uint8(repairStreamIDExtensionID), &mid, &rid, &rsid); err != nil {
return err
}

continue
}

for _, t := range pc.GetTransceivers() {
Expand All @@ -1432,7 +1443,11 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
continue
}

track, err := receiver.receiveForRid(rid, params, ssrc)
if rsid != "" {
return receiver.receiveForRsid(rsid, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor)
}

track, err := receiver.receiveForRid(rid, params, streamInfo, readStream, interceptor, rtcpReadStream, rtcpInterceptor)
if err != nil {
return err
}
Expand All @@ -1441,6 +1456,13 @@ func (pc *PeerConnection) handleIncomingSSRC(rtpStream io.Reader, ssrc SSRC) err
}
}

if readStream != nil {
_ = readStream.Close()
}
if rtcpReadStream != nil {
_ = rtcpReadStream.Close()
}
pc.api.interceptor.UnbindRemoteStream(streamInfo)
return errPeerConnSimulcastIncomingSSRCFailed
}

Expand Down
9 changes: 8 additions & 1 deletion peerconnection_media_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1175,7 +1175,14 @@ func TestPeerConnection_Simulcast(t *testing.T) {
PayloadType: 96,
}
assert.NoError(t, header.SetExtension(1, []byte("0")))
assert.NoError(t, header.SetExtension(2, []byte(rid)))

// Send RSID for first 10 packets
if sequenceNumber >= 10 {
assert.NoError(t, header.SetExtension(2, []byte(rid)))
} else {
assert.NoError(t, header.SetExtension(3, []byte(rid)))
header.SSRC += 10
}

_, err := vp8Writer.bindings[0].writeStream.WriteRTP(header, []byte{0x00})
assert.NoError(t, err)
Expand Down
98 changes: 59 additions & 39 deletions rtpreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,19 @@ import (
type trackStreams struct {
track *TrackRemote

streamInfo interceptor.StreamInfo
streamInfo, repairStreamInfo *interceptor.StreamInfo

rtpReadStream *srtp.ReadStreamSRTP
rtpInterceptor interceptor.RTPReader

rtcpReadStream *srtp.ReadStreamSRTCP
rtcpInterceptor interceptor.RTCPReader

repairReadStream *srtp.ReadStreamSRTP
repairInterceptor interceptor.RTPReader

repairRtcpReadStream *srtp.ReadStreamSRTCP
repairRtcpInterceptor interceptor.RTCPReader
}

// RTPReceiver allows an application to inspect the receipt of a TrackRemote
Expand Down Expand Up @@ -146,7 +152,7 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
if parameters.Encodings[i].SSRC != 0 {
t.streamInfo = createStreamInfo("", parameters.Encodings[i].SSRC, 0, codec, globalParams.HeaderExtensions)
var err error
if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.streamsForSSRC(parameters.Encodings[i].SSRC, t.streamInfo); err != nil {
if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.transport.streamsForSSRC(parameters.Encodings[i].SSRC, *t.streamInfo); err != nil {
return err
}
}
Expand Down Expand Up @@ -245,8 +251,23 @@ func (r *RTPReceiver) Stop() error {
errs = append(errs, r.tracks[i].rtpReadStream.Close())
}

if r.tracks[i].repairReadStream != nil {
errs = append(errs, r.tracks[i].repairReadStream.Close())
}

if r.tracks[i].repairRtcpReadStream != nil {
errs = append(errs, r.tracks[i].repairRtcpReadStream.Close())
}

if r.tracks[i].streamInfo != nil {
r.api.interceptor.UnbindRemoteStream(r.tracks[i].streamInfo)
}

if r.tracks[i].repairStreamInfo != nil {
r.api.interceptor.UnbindRemoteStream(r.tracks[i].repairStreamInfo)
}

err = util.FlattenErrs(errs)
r.api.interceptor.UnbindRemoteStream(&r.tracks[i].streamInfo)
}
default:
}
Expand Down Expand Up @@ -276,7 +297,7 @@ func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a intercept

// receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs
// It populates all the internal state for the given RID
func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) (*TrackRemote, error) {
func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) (*TrackRemote, error) {
r.mu.Lock()
defer r.mu.Unlock()

Expand All @@ -286,54 +307,53 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC)
r.tracks[i].track.kind = r.kind
r.tracks[i].track.codec = params.Codecs[0]
r.tracks[i].track.params = params
r.tracks[i].track.ssrc = ssrc
r.tracks[i].streamInfo = createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions)
r.tracks[i].track.ssrc = SSRC(streamInfo.SSRC)
r.tracks[i].track.mu.Unlock()

var err error
if r.tracks[i].rtpReadStream, r.tracks[i].rtpInterceptor, r.tracks[i].rtcpReadStream, r.tracks[i].rtcpInterceptor, err = r.streamsForSSRC(ssrc, r.tracks[i].streamInfo); err != nil {
return nil, err
}
r.tracks[i].streamInfo = streamInfo
r.tracks[i].rtpReadStream = rtpReadStream
r.tracks[i].rtpInterceptor = rtpInterceptor
r.tracks[i].rtcpReadStream = rtcpReadStream
r.tracks[i].rtcpInterceptor = rtcpInterceptor

return r.tracks[i].track, nil
}
}

return nil, fmt.Errorf("%w: %d", errRTPReceiverForSSRCTrackStreamNotFound, ssrc)
return nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid)
}

func (r *RTPReceiver) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) {
srtpSession, err := r.transport.getSRTPSession()
if err != nil {
return nil, nil, nil, nil, err
}

rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
}

rtpInterceptor := r.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtpReadStream.Read(in)
return n, a, err
}))
// receiveForRsid starts a routine that processes the repair stream for a RID
// These packets aren't exposed to the user yet, but we need to process them for
// TWCC
func (r *RTPReceiver) receiveForRsid(rsid string, streamInfo *interceptor.StreamInfo, rtpReadStream *srtp.ReadStreamSRTP, rtpInterceptor interceptor.RTPReader, rtcpReadStream *srtp.ReadStreamSRTCP, rtcpInterceptor interceptor.RTCPReader) error {
r.mu.Lock()
defer r.mu.Unlock()

srtcpSession, err := r.transport.getSRTCPSession()
if err != nil {
return nil, nil, nil, nil, err
}
for i := range r.tracks {
if r.tracks[i].track.RID() == rsid {
var err error

rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc))
if err != nil {
return nil, nil, nil, nil, err
r.tracks[i].repairStreamInfo = streamInfo
r.tracks[i].repairReadStream = rtpReadStream
r.tracks[i].repairInterceptor = rtpInterceptor
r.tracks[i].repairRtcpReadStream = rtcpReadStream
r.tracks[i].repairRtcpInterceptor = rtcpInterceptor

go func() {
b := make([]byte, r.api.settingEngine.getReceiveMTU())
for {
if _, _, readErr := r.tracks[i].repairInterceptor.Read(b, nil); readErr != nil {
return
}
}
}()

return err
}
}

rtcpInterceptor := r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) {
n, err = rtcpReadStream.Read(in)
return n, a, err
}))

return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil
return fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rsid)
}

// SetReadDeadline sets the max amount of time the RTCP stream will block before returning. 0 is forever.
Expand Down
2 changes: 1 addition & 1 deletion rtpsender.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error {
}
r.context.params.Codecs = []RTPCodecParameters{codec}

r.streamInfo = createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions)
r.streamInfo = *createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions)
rtpInterceptor := r.api.interceptor.BindLocalStream(&r.streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) {
return r.srtpStream.WriteRTP(header, payload)
}))
Expand Down
10 changes: 7 additions & 3 deletions rtptransceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func satisfyTypeAndDirection(remoteKind RTPCodecType, remoteDirection RTPTransce

// handleUnknownRTPPacket consumes a single RTP Packet and returns information that is helpful
// for demuxing and handling an unknown SSRC (usually for Simulcast)
func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID uint8) (mid, rid string, payloadType PayloadType, err error) {
func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID, repairStreamIDExtensionID uint8, mid, rid, rsid *string) (payloadType PayloadType, err error) {
rp := &rtp.Packet{}
if err = rp.Unmarshal(buf); err != nil {
return
Expand All @@ -259,11 +259,15 @@ func handleUnknownRTPPacket(buf []byte, midExtensionID, streamIDExtensionID uint

payloadType = PayloadType(rp.PayloadType)
if payload := rp.GetExtension(midExtensionID); payload != nil {
mid = string(payload)
*mid = string(payload)
}

if payload := rp.GetExtension(streamIDExtensionID); payload != nil {
rid = string(payload)
*rid = string(payload)
}

if payload := rp.GetExtension(repairStreamIDExtensionID); payload != nil {
*rsid = string(payload)
}

return
Expand Down

0 comments on commit 11b8873

Please sign in to comment.