Skip to content

Commit

Permalink
Fix stats interceptor deadlock
Browse files Browse the repository at this point in the history
When closing a PeerConnection that uses the stats recorder as an
interceptor, the process may block. After some digging, I found that
calls to GetStats(), QueueIncomingRTP(), QueueIncomingRTCP(),
QueueOutgoingRTP() and QueueOutgoingRTCP() block when the recorder
has not been started yet or has already been stopped. I've added the
TestGetStatsNotBlocking and TestQueueNotBlocking tests to cover this.
  • Loading branch information
asticode authored and Sean-Der committed Apr 25, 2023
1 parent 7355501 commit 1075999
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pkg/gcc/adaptive_threshold.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (a *adaptiveThreshold) update(estimate time.Duration) {
timeDelta := time.Duration(minInt(int(now.Sub(a.lastUpdate).Milliseconds()), int(maxTimeDelta.Milliseconds()))) * time.Millisecond
d := absEstimate - a.thresh
add := k * float64(d.Milliseconds()) * float64(timeDelta.Milliseconds())
a.thresh += time.Duration(add * 1000) * time.Microsecond
a.thresh += time.Duration(add*1000) * time.Microsecond
a.thresh = clampDuration(a.thresh, a.min, a.max)
a.lastUpdate = now
}
96 changes: 45 additions & 51 deletions pkg/stats/stats_recorder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package stats

import (
"sync"
"sync/atomic"
"time"

"github.com/pion/interceptor"
Expand Down Expand Up @@ -78,12 +80,9 @@ type recorder struct {
maxLastSenderReports int
maxLastReceiverReferenceTimes int

incomingRTPChan chan *incomingRTP
incomingRTCPChan chan *incomingRTCP
outgoingRTPChan chan *outgoingRTP
outgoingRTCPChan chan *outgoingRTCP
getStatsChan chan Stats
done chan struct{}
latestStats *internalStats
ms *sync.Mutex // Locks latestStats
running uint32
}

func newRecorder(ssrc uint32, clockRate float64) *recorder {
Expand All @@ -93,21 +92,24 @@ func newRecorder(ssrc uint32, clockRate float64) *recorder {
clockRate: clockRate,
maxLastSenderReports: 5,
maxLastReceiverReferenceTimes: 5,
incomingRTPChan: make(chan *incomingRTP),
incomingRTCPChan: make(chan *incomingRTCP),
outgoingRTPChan: make(chan *outgoingRTP),
outgoingRTCPChan: make(chan *outgoingRTCP),
getStatsChan: make(chan Stats),
done: make(chan struct{}),
latestStats: &internalStats{},
ms: &sync.Mutex{},
}
}

func (r *recorder) Stop() {
close(r.done)
atomic.StoreUint32(&r.running, 0)
}

func (r *recorder) GetStats() Stats {
return <-r.getStatsChan
r.ms.Lock()
defer r.ms.Unlock()
return Stats{
InboundRTPStreamStats: r.latestStats.InboundRTPStreamStats,
OutboundRTPStreamStats: r.latestStats.OutboundRTPStreamStats,
RemoteInboundRTPStreamStats: r.latestStats.RemoteInboundRTPStreamStats,
RemoteOutboundRTPStreamStats: r.latestStats.RemoteOutboundRTPStreamStats,
}
}

func (r *recorder) recordIncomingRTP(latestStats internalStats, v *incomingRTP) internalStats {
Expand Down Expand Up @@ -261,38 +263,13 @@ func (r *recorder) recordIncomingRTCP(latestStats internalStats, v *incomingRTCP
}

func (r *recorder) Start() {
latestStats := &internalStats{}
for {
select {
case <-r.done:
return
case v := <-r.incomingRTPChan:
s := r.recordIncomingRTP(*latestStats, v)
latestStats = &s

case v := <-r.outgoingRTCPChan:
s := r.recordOutgoingRTCP(*latestStats, v)
latestStats = &s

case v := <-r.outgoingRTPChan:
s := r.recordOutgoingRTP(*latestStats, v)
latestStats = &s

case v := <-r.incomingRTCPChan:
s := r.recordIncomingRTCP(*latestStats, v)
latestStats = &s

case r.getStatsChan <- Stats{
InboundRTPStreamStats: latestStats.InboundRTPStreamStats,
OutboundRTPStreamStats: latestStats.OutboundRTPStreamStats,
RemoteInboundRTPStreamStats: latestStats.RemoteInboundRTPStreamStats,
RemoteOutboundRTPStreamStats: latestStats.RemoteOutboundRTPStreamStats,
}:
}
}
atomic.StoreUint32(&r.running, 1)
}

func (r *recorder) QueueIncomingRTP(ts time.Time, buf []byte, attr interceptor.Attributes) {
if atomic.LoadUint32(&r.running) == 0 {
return
}
if attr == nil {
attr = make(interceptor.Attributes)
}
Expand All @@ -302,15 +279,20 @@ func (r *recorder) QueueIncomingRTP(ts time.Time, buf []byte, attr interceptor.A
return
}
hdr := header.Clone()
r.incomingRTPChan <- &incomingRTP{
r.ms.Lock()
*r.latestStats = r.recordIncomingRTP(*r.latestStats, &incomingRTP{
ts: ts,
header: hdr,
payloadLen: len(buf) - hdr.MarshalSize(),
attr: attr,
}
})
r.ms.Unlock()
}

func (r *recorder) QueueIncomingRTCP(ts time.Time, buf []byte, attr interceptor.Attributes) {
if atomic.LoadUint32(&r.running) == 0 {
return
}
if attr == nil {
attr = make(interceptor.Attributes)
}
Expand All @@ -319,29 +301,41 @@ func (r *recorder) QueueIncomingRTCP(ts time.Time, buf []byte, attr interceptor.
r.logger.Warnf("failed to get RTCP packets, skipping incoming RTCP packet in stats calculation: %v", err)
return
}
r.incomingRTCPChan <- &incomingRTCP{
r.ms.Lock()
*r.latestStats = r.recordIncomingRTCP(*r.latestStats, &incomingRTCP{
ts: ts,
pkts: pkts,
attr: attr,
}
})
r.ms.Unlock()
}

func (r *recorder) QueueOutgoingRTP(ts time.Time, header *rtp.Header, payload []byte, attr interceptor.Attributes) {
if atomic.LoadUint32(&r.running) == 0 {
return
}
hdr := header.Clone()
r.outgoingRTPChan <- &outgoingRTP{
r.ms.Lock()
*r.latestStats = r.recordOutgoingRTP(*r.latestStats, &outgoingRTP{
ts: ts,
header: hdr,
payloadLen: len(payload),
attr: attr,
}
})
r.ms.Unlock()
}

func (r *recorder) QueueOutgoingRTCP(ts time.Time, pkts []rtcp.Packet, attr interceptor.Attributes) {
r.outgoingRTCPChan <- &outgoingRTCP{
if atomic.LoadUint32(&r.running) == 0 {
return
}
r.ms.Lock()
*r.latestStats = r.recordOutgoingRTCP(*r.latestStats, &outgoingRTCP{
ts: ts,
pkts: pkts,
attr: attr,
}
})
r.ms.Unlock()
}

func min(a, b int) int {
Expand Down
80 changes: 78 additions & 2 deletions pkg/stats/stats_recorder_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package stats

import (
"context"
"errors"
"fmt"
"testing"
"time"

"github.com/pion/interceptor"
"github.com/pion/interceptor/internal/ntp"
"github.com/pion/rtcp"
"github.com/pion/rtp"
Expand Down Expand Up @@ -260,8 +263,7 @@ func TestStatsRecorder(t *testing.T) {
t.Run(fmt.Sprintf("%v:%v", i, cc.name), func(t *testing.T) {
r := newRecorder(0, 90_000)

go r.Start()
defer r.Stop()
r.Start()

for _, record := range cc.records {
switch v := record.content.(type) {
Expand All @@ -282,6 +284,8 @@ func TestStatsRecorder(t *testing.T) {

s := r.GetStats()

r.Stop()

assert.Equal(t, cc.expectedInboundRTPStreamStats, s.InboundRTPStreamStats)
assert.Equal(t, cc.expectedOutboundRTPStreamStats, s.OutboundRTPStreamStats)
assert.Equal(t, cc.expectedRemoteInboundRTPStreamStats, s.RemoteInboundRTPStreamStats)
Expand Down Expand Up @@ -313,3 +317,75 @@ func TestStatsRecorder_DLRR_Precision(t *testing.T) {

assert.Equal(t, int64(s.RemoteOutboundRTPStreamStats.RoundTripTime), int64(-9223372036854775808))
}

// TestGetStatsNotBlocking checks that GetStats returns within one second
// while Stop is being called concurrently on the same recorder.
func TestGetStatsNotBlocking(t *testing.T) {
	const deadline = time.Second

	rec := newRecorder(0, 90_000)

	// The context ends either because the worker goroutine cancels it once
	// GetStats has returned, or because the deadline expires first.
	ctx, cancel := context.WithTimeout(context.Background(), deadline)
	defer cancel()

	go func() {
		defer cancel()
		rec.Start()
		rec.GetStats()
	}()
	go rec.Stop()

	<-ctx.Done()

	// A plain cancellation is the success path; only a deadline expiry means
	// the worker never came back from its calls.
	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
		t.Error("it shouldn't block")
	}
}

// TestQueueNotBlocking checks, for each of the four Queue* methods, that the
// call returns within one second while Stop is being called concurrently on
// the same recorder.
func TestQueueNotBlocking(t *testing.T) {
	// queueCase pairs a subtest name with the Queue* call under test.
	type queueCase struct {
		name string
		f    func(r *recorder)
	}

	// Arguments are built inside each closure, so they are only evaluated
	// when the closure runs in the worker goroutine.
	cases := []queueCase{
		{
			name: "QueueIncomingRTP",
			f: func(r *recorder) {
				r.QueueIncomingRTP(time.Now(), mustMarshalRTP(t, rtp.Packet{}), interceptor.Attributes{})
			},
		},
		{
			name: "QueueOutgoingRTP",
			f: func(r *recorder) {
				r.QueueOutgoingRTP(time.Now(), &rtp.Header{}, mustMarshalRTP(t, rtp.Packet{}), interceptor.Attributes{})
			},
		},
		{
			name: "QueueIncomingRTCP",
			f: func(r *recorder) {
				r.QueueIncomingRTCP(time.Now(), mustMarshalRTCPs(t, &rtcp.CCFeedbackReport{}), interceptor.Attributes{})
			},
		},
		{
			name: "QueueOutgoingRTCP",
			f: func(r *recorder) {
				r.QueueOutgoingRTCP(time.Now(), []rtcp.Packet{}, interceptor.Attributes{})
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name+"NotBlocking", func(t *testing.T) {
			rec := newRecorder(0, 90_000)

			// The context ends either because the worker goroutine cancels
			// it after the queue call returns, or because the one-second
			// deadline expires first.
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			defer cancel()

			go func() {
				defer cancel()
				rec.Start()
				tc.f(rec)
			}()
			go rec.Stop()

			<-ctx.Done()

			// Only a deadline expiry indicates that the queue call got stuck.
			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
				t.Error("it shouldn't block")
			}
		})
	}
}

0 comments on commit 1075999

Please sign in to comment.