Skip to content

Commit

Permalink
Add options to set custom replay detector
Browse files Browse the repository at this point in the history
Support using custom implementation of replay detector.
  • Loading branch information
at-wat committed Jul 26, 2023
1 parent def59cc commit d8652d4
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 0 deletions.
16 changes: 16 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ func SRTCPNoReplayProtection() ContextOption {
}
}

// SRTPReplayDetectorFactory sets custom SRTP replay detector.
func SRTPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption { // nolint:revive
return func(c *Context) error {
c.newSRTPReplayDetector = fn
return nil
}
}

// SRTCPReplayDetectorFactory sets custom SRTCP replay detector.
func SRTCPReplayDetectorFactory(fn func() replaydetector.ReplayDetector) ContextOption {
return func(c *Context) error {
c.newSRTCPReplayDetector = fn
return nil
}
}

type nopReplayDetector struct{}

func (s *nopReplayDetector) Check(uint64) (func(), bool) {
Expand Down
24 changes: 24 additions & 0 deletions srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"

"github.com/pion/rtcp"
"github.com/pion/transport/v2/replaydetector"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -570,3 +571,26 @@ func TestRTCPMaxPackets(t *testing.T) {
})
}
}

func TestRTCPReplayDetectorFactory(t *testing.T) {
assert := assert.New(t)
testCase := rtcpTestCases()["AEAD_AES_128_GCM"]
data := testCase.packets[0]

var cntFactory int
decryptContext, err := CreateContext(
testCase.masterKey, testCase.masterSalt, testCase.algo,
SRTCPReplayDetectorFactory(func() replaydetector.ReplayDetector {
cntFactory++
return &nopReplayDetector{}
}),
)
if err != nil {
t.Fatal(err)
}

if _, err := decryptContext.DecryptRTCP(nil, data.encrypted, nil); err != nil {
t.Fatal(err)
}
assert.Equal(1, cntFactory)
}
32 changes: 32 additions & 0 deletions srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"testing"

"github.com/pion/rtp"
"github.com/pion/transport/v2/replaydetector"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -462,6 +463,37 @@ func TestRTPReplayProtection(t *testing.T) {
t.Run("GCM", func(t *testing.T) { testRTPReplayProtection(t, profileGCM) })
}

func TestRTPReplayDetectorFactory(t *testing.T) {
assert := assert.New(t)
profile := profileCTR
data := rtpTestCases()[0]

var cntFactory int
decryptContext, err := buildTestContext(
profile, SRTPReplayDetectorFactory(func() replaydetector.ReplayDetector {
cntFactory++
return &nopReplayDetector{}
}),
)
if err != nil {
t.Fatal(err)
}

pkt := &rtp.Packet{
Payload: data.encrypted(profile),
Header: rtp.Header{SequenceNumber: data.sequenceNumber},
}
in, err := pkt.Marshal()
if err != nil {
t.Fatal(err)
}

if _, err := decryptContext.DecryptRTP(nil, in, nil); err != nil {
t.Fatal(err)
}
assert.Equal(1, cntFactory)
}

func benchmarkEncryptRTP(b *testing.B, profile ProtectionProfile, size int) {
encryptContext, err := buildTestContext(profile)
if err != nil {
Expand Down

0 comments on commit d8652d4

Please sign in to comment.