diff --git a/client_handlers.go b/client_handlers.go index 4bdbbb6f3..e0a5742c9 100644 --- a/client_handlers.go +++ b/client_handlers.go @@ -45,6 +45,19 @@ func clientHandshakeHandler(c *Conn) error { } fallthrough case flight3: + for _, extension := range h.extensions { + if e, ok := extension.(*extensionUseSRTP); ok { + profile, ok := findMatchingSRTPProfile(e.protectionProfiles, c.localSRTPProtectionProfiles) + if !ok { + return fmt.Errorf("Server responded with SRTP Profile we do not support") + } + c.srtpProtectionProfile = profile + } + } + if len(c.localSRTPProtectionProfiles) > 0 && c.srtpProtectionProfile == 0 { + return fmt.Errorf("SRTP support was requested but server did not respond with use_srtp extension") + } + c.cipherSuite = h.cipherSuite c.remoteRandom = h.random } @@ -86,7 +99,7 @@ func clientHandshakeHandler(c *Conn) error { return err } - expectedHash := valueKeySignature(clientRandom, serverRandom, h.publicKey, c.namedCurve, h.hashAlgorithm) + expectedHash := valueKeySignature(clientRandom, serverRandom, h.publicKey, h.namedCurve, h.hashAlgorithm) if err := verifyKeySignature(expectedHash, h.signature, c.remoteCertificate); err != nil { return err } @@ -132,6 +145,28 @@ func clientFlightHandler(c *Conn) (bool, error) { fallthrough case flight3: c.lock.RLock() + + extensions := []extension{ + &extensionSupportedEllipticCurves{ + ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256}, + }, + &extensionSupportedPointFormats{ + pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed}, + }, + &extensionSupportedSignatureAlgorithms{ + signatureHashAlgorithms: []signatureHashAlgorithm{ + {HashAlgorithmSHA256, signatureAlgorithmECDSA}, + {HashAlgorithmSHA384, signatureAlgorithmECDSA}, + {HashAlgorithmSHA512, signatureAlgorithmECDSA}, + }, + }, + } + if len(c.localSRTPProtectionProfiles) > 0 { + extensions = append(extensions, &extensionUseSRTP{ + protectionProfiles: c.localSRTPProtectionProfiles, + }) + } + c.internalSend(&recordLayer{ recordLayerHeader: recordLayerHeader{ sequenceNumber: c.localSequenceNumber, @@ -148,24 +183,7 @@ func clientFlightHandler(c *Conn) (bool, error) { random: c.localRandom, cipherSuites: clientCipherSuites(), compressionMethods: defaultCompressionMethods, - extensions: []extension{ - &extensionSupportedEllipticCurves{ - ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256}, - }, - &extensionUseSRTP{ - protectionProfiles: []srtpProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - }, - &extensionSupportedPointFormats{ - pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed}, - }, - &extensionSupportedSignatureAlgorithms{ - signatureHashAlgorithms: []signatureHashAlgorithm{ - {HashAlgorithmSHA256, signatureAlgorithmECDSA}, - {HashAlgorithmSHA384, signatureAlgorithmECDSA}, - {HashAlgorithmSHA512, signatureAlgorithmECDSA}, - }, - }, - }, + extensions: extensions, }}, }, false) c.lock.RUnlock() diff --git a/config.go b/config.go index e18d73429..2da1be3d5 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,18 @@ import ( // Config is used to configure a DTLS client or server. // After a Config is passed to a DTLS function it must not be modified. type Config struct { + // Certificates contains certificate chain to present to + // the other side of the connection. Server MUST set this, + // client SHOULD sets this so CertificateRequests + // can be handled Certificate *x509.Certificate - PrivateKey crypto.PrivateKey + + // PrivateKey contains matching private key for the certificate + // only ECDSA is supported + PrivateKey crypto.PrivateKey + + // SRTPProtectionProfiles are the supported protection profiles + // Clients will send this via use_srtp and assert that the server properly responds + // Servers will assert that clients send one of these profiles and will respond as needed + SRTPProtectionProfiles []SRTPProtectionProfile } diff --git a/conn.go b/conn.go index 00afe0a1e..ed224505d 100644 --- a/conn.go +++ b/conn.go @@ -41,6 +41,9 @@ type Conn struct { localEpoch, remoteEpoch atomic.Value localSequenceNumber uint64 // uint48 + localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support + srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile + currFlight *flight cipherSuite cipherSuite // nil if a cipherSuite hasn't been chosen namedCurve namedCurve @@ -76,16 +79,17 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage } c := &Conn{ - isClient: isClient, - nextConn: nextConn, - currFlight: newFlight(isClient), - fragmentBuffer: newFragmentBuffer(), - handshakeCache: newHandshakeCache(), - handshakeMessageHandler: handshakeMessageHandler, - flightHandler: flightHandler, - localCertificate: config.Certificate, - localPrivateKey: config.PrivateKey, - namedCurve: defaultNamedCurve, + isClient: isClient, + nextConn: nextConn, + currFlight: newFlight(isClient), + fragmentBuffer: newFragmentBuffer(), + handshakeCache: newHandshakeCache(), + handshakeMessageHandler: handshakeMessageHandler, + flightHandler: flightHandler, + localCertificate: config.Certificate, + localPrivateKey: config.PrivateKey, + localSRTPProtectionProfiles: config.SRTPProtectionProfiles, + namedCurve: defaultNamedCurve, decrypted: make(chan []byte), workerTicker: time.NewTicker(initialTickerInterval), @@ -216,6 +220,18 @@ func (c *Conn) RemoteCertificate() *x509.Certificate { return c.remoteCertificate } +// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile +func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + + if c.srtpProtectionProfile == 0 { + return 0, false + } + + return c.srtpProtectionProfile, true +} + // ExportKeyingMaterial from https://tools.ietf.org/html/rfc5705 // This allows protocols to use DTLS for key establishment, but // then use some of the keying material for their own purposes diff --git a/conn_test.go b/conn_test.go index 111b84865..ec913249b 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,6 +2,7 @@ package dtls import ( "bytes" + "fmt" "net" "testing" "time" @@ -81,12 +82,12 @@ func pipeMemory() (*Conn, *Conn, error) { // Setup client go func() { - client, err := testClient(ca) + client, err := testClient(ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}) c <- result{client, err} }() // Setup server - server, err := testServer(cb) + server, err := testServer(cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}) if err != nil { return nil, nil, err } @@ -100,32 +101,24 @@ func pipeMemory() (*Conn, *Conn, error) { return res.c, server, nil } -func testClient(c net.Conn) (*Conn, error) { +func testClient(c net.Conn, cfg *Config) (*Conn, error) { clientCert, clientKey, err := GenerateSelfSigned() if err != nil { return nil, err } - - client, err := Client(c, &Config{clientCert, clientKey}) - if err != nil { - return nil, err - } - - return client, nil + cfg.PrivateKey = clientKey + cfg.Certificate = clientCert + return Client(c, cfg) } -func testServer(c net.Conn) (*Conn, error) { +func testServer(c net.Conn, cfg *Config) (*Conn, error) { serverCert, serverKey, err := GenerateSelfSigned() if err != nil { return nil, err } - - server, err := Server(c, &Config{serverCert, serverKey}) - if err != nil { - return nil, err - } - - return server, nil + cfg.PrivateKey = serverKey + cfg.Certificate = serverCert + return Server(c, cfg) } func TestExportKeyingMaterial(t *testing.T) { @@ -175,3 +168,83 @@ func TestExportKeyingMaterial(t *testing.T) { t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial) } } + +func TestSRTPConfiguration(t *testing.T) { + for _, test := range []struct { + Name string + ClientSRTP []SRTPProtectionProfile + ServerSRTP []SRTPProtectionProfile + ExpectedProfile SRTPProtectionProfile + WantClientError error + WantServerError error + }{ + { + Name: "No SRTP in use", + ClientSRTP: nil, + ServerSRTP: nil, + ExpectedProfile: 0, + WantClientError: nil, + WantServerError: nil, + }, + { + Name: "SRTP both ends", + ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, + WantClientError: nil, + WantServerError: nil, + }, + { + Name: "SRTP client only", + ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ServerSRTP: nil, + ExpectedProfile: 0, + WantClientError: fmt.Errorf("Client requested SRTP but we have no matching profiles"), + WantServerError: fmt.Errorf("Client requested SRTP but we have no matching profiles"), + }, + { + Name: "SRTP server only", + ClientSRTP: nil, + ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + ExpectedProfile: 0, + WantClientError: nil, + WantServerError: nil, + }, + } { + ca, cb := net.Pipe() + type result struct { + c *Conn + err error + } + c := make(chan result) + + go func() { + client, err := testClient(ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}) + c <- result{client, err} + }() + + server, err := testServer(cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}) + if err != nil || test.WantServerError != nil { + if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { + t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + } + } + + res := <-c + if res.err != nil || test.WantClientError != nil { + if !(res.err != nil && test.WantClientError != nil && err.Error() == test.WantClientError.Error()) { + t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, err) + } + } + + actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() + if actualClientSRTP != test.ExpectedProfile { + t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP) + } + + actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() + if actualServerSRTP != test.ExpectedProfile { + t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP) + } + } +} diff --git a/extension_use_srtp.go b/extension_use_srtp.go index 64847c982..f89a671e2 100644 --- a/extension_use_srtp.go +++ b/extension_use_srtp.go @@ -8,7 +8,7 @@ const ( // https://tools.ietf.org/html/rfc8422 type extensionUseSRTP struct { - protectionProfiles []srtpProtectionProfile + protectionProfiles []SRTPProtectionProfile } func (e extensionUseSRTP) extensionValue() extensionValue { @@ -44,7 +44,7 @@ func (e *extensionUseSRTP) Unmarshal(data []byte) error { } for i := 0; i < profileCount; i++ { - supportedProfile := srtpProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):])) + supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):])) if _, ok := srtpProtectionProfiles[supportedProfile]; ok { e.protectionProfiles = append(e.protectionProfiles, supportedProfile) } diff --git a/extension_use_srtp_test.go b/extension_use_srtp_test.go index e16a24f2b..4846068b2 100644 --- a/extension_use_srtp_test.go +++ b/extension_use_srtp_test.go @@ -8,7 +8,7 @@ import ( func TestExtensionUseSRTP(t *testing.T) { rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00} parsedUseSRTP := &extensionUseSRTP{ - protectionProfiles: []srtpProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, + protectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, } raw, err := parsedUseSRTP.Marshal() diff --git a/server_handlers.go b/server_handlers.go index c81d96c72..750919cf9 100644 --- a/server_handlers.go +++ b/server_handlers.go @@ -50,7 +50,11 @@ func serverHandshakeHandler(c *Conn) error { case *extensionSupportedEllipticCurves: c.namedCurve = e.ellipticCurves[0] case *extensionUseSRTP: - // TODO expose to API + profile, ok := findMatchingSRTPProfile(e.protectionProfiles, c.localSRTPProtectionProfiles) + if !ok { + return fmt.Errorf("Client requested SRTP but we have no matching profiles") + } + c.srtpProtectionProfile = profile } } @@ -148,6 +152,21 @@ func serverFlightHandler(c *Conn) (bool, error) { case flight4: c.lock.RLock() + + extensions := []extension{ + &extensionSupportedEllipticCurves{ + ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256}, + }, + &extensionSupportedPointFormats{ + pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed}, + }, + } + if c.srtpProtectionProfile != 0 { + extensions = append(extensions, &extensionUseSRTP{ + protectionProfiles: []SRTPProtectionProfile{c.srtpProtectionProfile}, + }) + } + c.internalSend(&recordLayer{ recordLayerHeader: recordLayerHeader{ sequenceNumber: c.localSequenceNumber, @@ -163,17 +182,7 @@ func serverFlightHandler(c *Conn) (bool, error) { random: c.localRandom, cipherSuite: c.cipherSuite, compressionMethod: defaultCompressionMethods[0], - extensions: []extension{ - &extensionSupportedEllipticCurves{ - ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256}, - }, - &extensionUseSRTP{ - protectionProfiles: []srtpProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, - }, - &extensionSupportedPointFormats{ - pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed}, - }, - }, + extensions: extensions, }}, }, false) diff --git a/srtp_protection_profile.go b/srtp_protection_profile.go index 9f269c312..3ae3c28c8 100644 --- a/srtp_protection_profile.go +++ b/srtp_protection_profile.go @@ -1,11 +1,13 @@ package dtls -type srtpProtectionProfile uint16 +// SRTPProtectionProfile defines the parameters and options that are in effect for the SRTP processing +// https://tools.ietf.org/html/rfc5764#section-4.1.2 +type SRTPProtectionProfile uint16 const ( - SRTP_AES128_CM_HMAC_SHA1_80 srtpProtectionProfile = 0x0001 // nolint + SRTP_AES128_CM_HMAC_SHA1_80 SRTPProtectionProfile = 0x0001 // nolint ) -var srtpProtectionProfiles = map[srtpProtectionProfile]bool{ +var srtpProtectionProfiles = map[SRTPProtectionProfile]bool{ SRTP_AES128_CM_HMAC_SHA1_80: true, } diff --git a/util.go b/util.go index f2b8a5755..045a10a8f 100644 --- a/util.go +++ b/util.go @@ -130,3 +130,14 @@ func examinePadding(payload []byte) (toRemove int, good byte) { return toRemove, good } + +func findMatchingSRTPProfile(a, b []SRTPProtectionProfile) (SRTPProtectionProfile, bool) { + for _, aProfile := range a { + for _, bProfile := range b { + if aProfile == bProfile { + return aProfile, true + } + } + } + return 0, false +}