diff --git a/context.go b/context.go index 4dac5c2..b021377 100644 --- a/context.go +++ b/context.go @@ -7,6 +7,7 @@ import ( "crypto/hmac" "crypto/sha1" // #nosec "encoding/binary" + "hash" "github.com/pkg/errors" ) @@ -56,11 +57,13 @@ type Context struct { ssrcStates map[uint32]*ssrcState srtpSessionKey []byte srtpSessionSalt []byte + srtpSessionAuth hash.Hash srtpSessionAuthTag []byte srtpBlock cipher.Block srtcpSessionKey []byte srtcpSessionSalt []byte + srtcpSessionAuth hash.Hash srtcpSessionAuthTag []byte srtcpIndex uint32 srtcpBlock cipher.Block @@ -90,6 +93,8 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile) (c * return nil, err } + c.srtpSessionAuth = hmac.New(sha1.New, c.srtpSessionAuthTag) + if c.srtcpSessionKey, err = c.generateSessionKey(labelSRTCPEncryption); err != nil { return nil, err } else if c.srtcpSessionSalt, err = c.generateSessionSalt(labelSRTCPSalt); err != nil { @@ -100,6 +105,8 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile) (c * return nil, err } + c.srtcpSessionAuth = hmac.New(sha1.New, c.srtcpSessionAuthTag) + return c, nil } @@ -151,6 +158,7 @@ func (c *Context) generateSessionSalt(label byte) ([]byte, error) { block.Encrypt(sessionSalt, sessionSalt) return sessionSalt[0:saltLen], nil } + func (c *Context) generateSessionAuthTag(label byte) ([]byte, error) { // https://tools.ietf.org/html/rfc3711#appendix-B.3 // We now show how the auth key is generated. The input block for AES- @@ -199,7 +207,32 @@ func (c *Context) generateCounter(sequenceNumber uint16, rolloverCounter uint32, return counter } -func (c *Context) generateAuthTag(buf, sessionAuthTag []byte) ([]byte, error) { +func (c *Context) generateSrtpAuthTag(buf []byte) ([]byte, error) { + // https://tools.ietf.org/html/rfc3711#section-4.2 + // In the case of SRTP, M SHALL consist of the Authenticated + // Portion of the packet (as specified in Figure 1) concatenated with + // the ROC, M = Authenticated Portion || ROC; + // + // The pre-defined authentication transform for SRTP is HMAC-SHA1 + // [RFC2104]. With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL + // be 0. For SRTP (respectively SRTCP), the HMAC SHALL be applied to + // the session authentication key and M as specified above, i.e., + // HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag + // left-most bits. + // - Authenticated portion of the packet is everything BEFORE MKI + // - k_a is the session message authentication key + // - n_tag is the bit-length of the output authentication tag + // - ROC is already added by caller (to allow RTP + RTCP support) + c.srtpSessionAuth.Reset() + + if _, err := c.srtpSessionAuth.Write(buf); err != nil { + return nil, err + } + + return c.srtpSessionAuth.Sum(nil)[0:10], nil +} + +func (c *Context) generateSrtcpAuthTag(buf []byte) ([]byte, error) { // https://tools.ietf.org/html/rfc3711#section-4.2 // In the case of SRTP, M SHALL consist of the Authenticated // Portion of the packet (as specified in Figure 1) concatenated with @@ -215,17 +248,17 @@ func (c *Context) generateAuthTag(buf, sessionAuthTag []byte) ([]byte, error) { // - k_a is the session message authentication key // - n_tag is the bit-length of the output authentication tag // - ROC is already added by caller (to allow RTP + RTCP support) - mac := hmac.New(sha1.New, sessionAuthTag) + c.srtcpSessionAuth.Reset() - if _, err := mac.Write(buf); err != nil { + if _, err := c.srtcpSessionAuth.Write(buf); err != nil { return nil, err } - return mac.Sum(nil)[0:10], nil + return c.srtcpSessionAuth.Sum(nil)[0:10], nil } -func (c *Context) verifyAuthTag(buf, actualAuthTag []byte) (bool, error) { - expectedAuthTag, err := c.generateAuthTag(buf, c.srtpSessionAuthTag) +func (c *Context) verifySrtpAuthTag(buf, actualAuthTag []byte) (bool, error) { + expectedAuthTag, err := c.generateSrtpAuthTag(buf) if err != nil { return false, err } diff --git a/srtcp.go b/srtcp.go index 425a43e..d6a1ce6 100644 --- a/srtcp.go +++ b/srtcp.go @@ -62,7 +62,7 @@ func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { binary.BigEndian.PutUint32(out[len(out)-4:], c.srtcpIndex) out[len(out)-4] |= 0x80 - authTag, err := c.generateAuthTag(out, c.srtcpSessionAuthTag) + authTag, err := c.generateSrtcpAuthTag(out) if err != nil { return nil, err } diff --git a/srtp.go b/srtp.go index 2ed2f31..e19067e 100644 --- a/srtp.go +++ b/srtp.go @@ -18,7 +18,7 @@ func (c *Context) decryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte, binary.BigEndian.PutUint32(pktWithROC[len(pktWithROC)-4:], s.rolloverCounter) actualAuthTag := dst[len(dst)-authTagSize:] - verified, err := c.verifyAuthTag(pktWithROC, actualAuthTag) + verified, err := c.verifySrtpAuthTag(pktWithROC, actualAuthTag) if err != nil { return nil, err } else if !verified { @@ -90,7 +90,7 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ( offset += 4 // Generate the auth tag. - authTag, err := c.generateAuthTag(dst[:offset], c.srtpSessionAuthTag) + authTag, err := c.generateSrtpAuthTag(dst[:offset]) if err != nil { return nil, err }