Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix PLI and FIR handling, wrongly triggering track.OnEnded #420

Merged
merged 4 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 37 additions & 26 deletions track.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

const (
rtpOutboundMTU = 1200
rtcpInboundMTU = 1500
)

var (
Expand Down Expand Up @@ -222,38 +223,48 @@ func (track *baseTrack) bind(ctx webrtc.TrackLocalContext, specializedTrack Trac
keyFrameController, ok := encodedReader.Controller().(codec.KeyFrameController)
if ok {
stopRead = make(chan struct{})
go func() {
reader := ctx.RTCPReader()
for {
select {
case <-stopRead:
return
default:
}
go track.rtcpReadLoop(ctx.RTCPReader(), keyFrameController, stopRead)
}

var readerBuffer []byte
_, _, err := reader.Read(readerBuffer, interceptor.Attributes{})
if err != nil {
track.onError(err)
return
}
return selectedCodec, nil
}

pkts, err := rtcp.Unmarshal(readerBuffer)
func (track *baseTrack) rtcpReadLoop(reader interceptor.RTCPReader, keyFrameController codec.KeyFrameController, stopRead chan struct{}) {
EmrysMyrddin marked this conversation as resolved.
Show resolved Hide resolved
readerBuffer := make([]byte, rtcpInboundMTU)

for _, pkt := range pkts {
switch pkt.(type) {
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
if err := keyFrameController.ForceKeyFrame(); err != nil {
track.onError(err)
return
}
}
readLoop:
for {
select {
case <-stopRead:
return
default:
}

readLength, _, err := reader.Read(readerBuffer, interceptor.Attributes{})
if err != nil {
if errors.Is(err, io.EOF) {
return
}
logger.Warnf("failed to read rtcp packet: %s", err)
continue
}

pkts, err := rtcp.Unmarshal(readerBuffer[:readLength])
if err != nil {
logger.Warnf("failed to unmarshal rtcp packet: %s", err)
continue
}

for _, pkt := range pkts {
switch pkt.(type) {
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
if err := keyFrameController.ForceKeyFrame(); err != nil {
logger.Warnf("failed to force key frame: %s", err)
continue readLoop
}
}
}()
}
}

return selectedCodec, nil
}

func (track *baseTrack) unbind(ctx webrtc.TrackLocalContext) error {
Expand Down
100 changes: 100 additions & 0 deletions track_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mediadevices

import (
"errors"
"github.com/pion/interceptor"
"io"
"testing"
"time"
)
Expand Down Expand Up @@ -53,3 +55,101 @@ func TestOnEnded(t *testing.T) {
}
})
}

type fakeRTCPReader struct {
mockReturn chan []byte
end chan struct{}
}

func (mock *fakeRTCPReader) Read(buffer []byte, attributes interceptor.Attributes) (int, interceptor.Attributes, error) {
select {
case <-mock.end:
return 0, nil, io.EOF
case mockReturn := <-mock.mockReturn:
if len(buffer) < len(mock.mockReturn) {
return 0, nil, io.ErrShortBuffer
}

return copy(buffer, mockReturn), attributes, nil
}
}

type fakeKeyFrameController struct {
called chan struct{}
}

func (mock *fakeKeyFrameController) ForceKeyFrame() error {
mock.called <- struct{}{}
return nil
}
EmrysMyrddin marked this conversation as resolved.
Show resolved Hide resolved

func TestRtcpHandler(t *testing.T) {

t.Run("ShouldStopReading", func(t *testing.T) {
tr := &baseTrack{}
stop := make(chan struct{}, 1)
stopped := make(chan struct{})
go func() {
tr.rtcpReadLoop(&fakeRTCPReader{end: stop}, &fakeKeyFrameController{}, stop)
stopped <- struct{}{}
}()

stop <- struct{}{}

select {
case <-time.After(100 * time.Millisecond):
t.Error("Timeout")
case <-stopped:
}
})

t.Run("ShouldForceKeyFrame", func(t *testing.T) {
for packetType, packet := range map[string][]byte{
"PLI": {
// v=2, p=0, FMT=1, PSFB, len=1
0x81, 0xce, 0x00, 0x02,
// ssrc=0x0
0x00, 0x00, 0x00, 0x00,
// ssrc=0x4bc4fcb4
0x4b, 0xc4, 0xfc, 0xb4,
},
"FIR": {
// v=2, p=0, FMT=4, PSFB, len=3
0x84, 0xce, 0x00, 0x04,
// ssrc=0x0
0x00, 0x00, 0x00, 0x00,
// ssrc=0x4bc4fcb4
0x4b, 0xc4, 0xfc, 0xb4,
// ssrc=0x12345678
0x12, 0x34, 0x56, 0x78,
// Seqno=0x42
0x42, 0x00, 0x00, 0x00,
},
} {
t.Run(packetType, func(t *testing.T) {
tr := &baseTrack{}
tr.OnEnded(func(err error) {
if err != io.EOF {
t.Error(err)
}
})
stop := make(chan struct{}, 1)
defer func() {
stop <- struct{}{}
}()
mockKeyFrameController := &fakeKeyFrameController{called: make(chan struct{}, 1)}
mockRTCPReader := &fakeRTCPReader{end: stop, mockReturn: make(chan []byte, 1)}

go tr.rtcpReadLoop(mockRTCPReader, mockKeyFrameController, stop)

mockRTCPReader.mockReturn <- packet

select {
case <-time.After(1000 * time.Millisecond):
t.Error("Timeout")
case <-mockKeyFrameController.called:
}
})
}
})
}