Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiple issues found by issue-32 #37

Merged
merged 2 commits into from
Feb 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions client_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ func clientHandshakeHandler(c *Conn) error {
}
fallthrough
case flight3:
for _, extension := range h.extensions {
if e, ok := extension.(*extensionUseSRTP); ok {
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, c.localSRTPProtectionProfiles)
if !ok {
return fmt.Errorf("Server responded with SRTP Profile we do not support")
}
c.srtpProtectionProfile = profile
}
}
if len(c.localSRTPProtectionProfiles) > 0 && c.srtpProtectionProfile == 0 {
return fmt.Errorf("SRTP support was requested but server did not respond with use_srtp extension")
}

c.cipherSuite = h.cipherSuite
c.remoteRandom = h.random
}
Expand Down Expand Up @@ -86,7 +99,7 @@ func clientHandshakeHandler(c *Conn) error {
return err
}

expectedHash := valueKeySignature(clientRandom, serverRandom, h.publicKey, c.namedCurve, h.hashAlgorithm)
expectedHash := valueKeySignature(clientRandom, serverRandom, h.publicKey, h.namedCurve, h.hashAlgorithm)
if err := verifyKeySignature(expectedHash, h.signature, c.remoteCertificate); err != nil {
return err
}
Expand Down Expand Up @@ -132,6 +145,28 @@ func clientFlightHandler(c *Conn) (bool, error) {
fallthrough
case flight3:
c.lock.RLock()

extensions := []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
&extensionSupportedSignatureAlgorithms{
signatureHashAlgorithms: []signatureHashAlgorithm{
{HashAlgorithmSHA256, signatureAlgorithmECDSA},
{HashAlgorithmSHA384, signatureAlgorithmECDSA},
{HashAlgorithmSHA512, signatureAlgorithmECDSA},
},
},
}
if len(c.localSRTPProtectionProfiles) > 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: c.localSRTPProtectionProfiles,
})
}

c.internalSend(&recordLayer{
recordLayerHeader: recordLayerHeader{
sequenceNumber: c.localSequenceNumber,
Expand All @@ -148,24 +183,7 @@ func clientFlightHandler(c *Conn) (bool, error) {
random: c.localRandom,
cipherSuites: clientCipherSuites(),
compressionMethods: defaultCompressionMethods,
extensions: []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256},
},
&extensionUseSRTP{
protectionProfiles: []srtpProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
&extensionSupportedSignatureAlgorithms{
signatureHashAlgorithms: []signatureHashAlgorithm{
{HashAlgorithmSHA256, signatureAlgorithmECDSA},
{HashAlgorithmSHA384, signatureAlgorithmECDSA},
{HashAlgorithmSHA512, signatureAlgorithmECDSA},
},
},
},
extensions: extensions,
}},
}, false)
c.lock.RUnlock()
Expand Down
14 changes: 13 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ import (
// Config is used to configure a DTLS client or server.
// After a Config is passed to a DTLS function it must not be modified.
type Config struct {
// Certificates contains certificate chain to present to
// the other side of the connection. Server MUST set this,
// client SHOULD sets this so CertificateRequests
// can be handled
Certificate *x509.Certificate
PrivateKey crypto.PrivateKey

// PrivateKey contains matching private key for the certificate
// only ECDSA is supported
PrivateKey crypto.PrivateKey

// SRTPProtectionProfiles are the supported protection profiles
// Clients will send this via use_srtp and assert that the server properly responds
// Servers will assert that clients send one of these profiles and will respond as needed
SRTPProtectionProfiles []SRTPProtectionProfile
}
36 changes: 26 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ type Conn struct {
localEpoch, remoteEpoch atomic.Value
localSequenceNumber uint64 // uint48

localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support
srtpProtectionProfile SRTPProtectionProfile // Negotiated SRTPProtectionProfile

currFlight *flight
cipherSuite cipherSuite // nil if a cipherSuite hasn't been chosen
namedCurve namedCurve
Expand Down Expand Up @@ -76,16 +79,17 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
}

c := &Conn{
isClient: isClient,
nextConn: nextConn,
currFlight: newFlight(isClient),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
handshakeMessageHandler: handshakeMessageHandler,
flightHandler: flightHandler,
localCertificate: config.Certificate,
localPrivateKey: config.PrivateKey,
namedCurve: defaultNamedCurve,
isClient: isClient,
nextConn: nextConn,
currFlight: newFlight(isClient),
fragmentBuffer: newFragmentBuffer(),
handshakeCache: newHandshakeCache(),
handshakeMessageHandler: handshakeMessageHandler,
flightHandler: flightHandler,
localCertificate: config.Certificate,
localPrivateKey: config.PrivateKey,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
namedCurve: defaultNamedCurve,

decrypted: make(chan []byte),
workerTicker: time.NewTicker(initialTickerInterval),
Expand Down Expand Up @@ -216,6 +220,18 @@ func (c *Conn) RemoteCertificate() *x509.Certificate {
return c.remoteCertificate
}

// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) {
c.lock.RLock()
defer c.lock.RUnlock()

if c.srtpProtectionProfile == 0 {
return 0, false
}

return c.srtpProtectionProfile, true
}

// ExportKeyingMaterial from https://tools.ietf.org/html/rfc5705
// This allows protocols to use DTLS for key establishment, but
// then use some of the keying material for their own purposes
Expand Down
109 changes: 91 additions & 18 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dtls

import (
"bytes"
"fmt"
"net"
"testing"
"time"
Expand Down Expand Up @@ -81,12 +82,12 @@ func pipeMemory() (*Conn, *Conn, error) {

// Setup client
go func() {
client, err := testClient(ca)
client, err := testClient(ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}})
c <- result{client, err}
}()

// Setup server
server, err := testServer(cb)
server, err := testServer(cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}})
if err != nil {
return nil, nil, err
}
Expand All @@ -100,32 +101,24 @@ func pipeMemory() (*Conn, *Conn, error) {
return res.c, server, nil
}

func testClient(c net.Conn) (*Conn, error) {
func testClient(c net.Conn, cfg *Config) (*Conn, error) {
clientCert, clientKey, err := GenerateSelfSigned()
if err != nil {
return nil, err
}

client, err := Client(c, &Config{clientCert, clientKey})
if err != nil {
return nil, err
}

return client, nil
cfg.PrivateKey = clientKey
cfg.Certificate = clientCert
return Client(c, cfg)
}

func testServer(c net.Conn) (*Conn, error) {
func testServer(c net.Conn, cfg *Config) (*Conn, error) {
serverCert, serverKey, err := GenerateSelfSigned()
if err != nil {
return nil, err
}

server, err := Server(c, &Config{serverCert, serverKey})
if err != nil {
return nil, err
}

return server, nil
cfg.PrivateKey = serverKey
cfg.Certificate = serverCert
return Server(c, cfg)
}

func TestExportKeyingMaterial(t *testing.T) {
Expand Down Expand Up @@ -175,3 +168,83 @@ func TestExportKeyingMaterial(t *testing.T) {
t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial)
}
}

func TestSRTPConfiguration(t *testing.T) {
for _, test := range []struct {
Name string
ClientSRTP []SRTPProtectionProfile
ServerSRTP []SRTPProtectionProfile
ExpectedProfile SRTPProtectionProfile
WantClientError error
WantServerError error
}{
{
Name: "No SRTP in use",
ClientSRTP: nil,
ServerSRTP: nil,
ExpectedProfile: 0,
WantClientError: nil,
WantServerError: nil,
},
{
Name: "SRTP both ends",
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
WantClientError: nil,
WantServerError: nil,
},
{
Name: "SRTP client only",
ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ServerSRTP: nil,
ExpectedProfile: 0,
WantClientError: fmt.Errorf("Client requested SRTP but we have no matching profiles"),
WantServerError: fmt.Errorf("Client requested SRTP but we have no matching profiles"),
},
{
Name: "SRTP server only",
ClientSRTP: nil,
ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
ExpectedProfile: 0,
WantClientError: nil,
WantServerError: nil,
},
} {
ca, cb := net.Pipe()
type result struct {
c *Conn
err error
}
c := make(chan result)

go func() {
client, err := testClient(ca, &Config{SRTPProtectionProfiles: test.ClientSRTP})
c <- result{client, err}
}()

server, err := testServer(cb, &Config{SRTPProtectionProfiles: test.ServerSRTP})
if err != nil || test.WantServerError != nil {
if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
}
}

res := <-c
if res.err != nil || test.WantClientError != nil {
if !(res.err != nil && test.WantClientError != nil && err.Error() == test.WantClientError.Error()) {
t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, err)
}
}

actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile()
if actualClientSRTP != test.ExpectedProfile {
t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP)
}

actualServerSRTP, _ := server.SelectedSRTPProtectionProfile()
if actualServerSRTP != test.ExpectedProfile {
t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
}
}
}
4 changes: 2 additions & 2 deletions extension_use_srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const (

// https://tools.ietf.org/html/rfc8422
type extensionUseSRTP struct {
protectionProfiles []srtpProtectionProfile
protectionProfiles []SRTPProtectionProfile
}

func (e extensionUseSRTP) extensionValue() extensionValue {
Expand Down Expand Up @@ -44,7 +44,7 @@ func (e *extensionUseSRTP) Unmarshal(data []byte) error {
}

for i := 0; i < profileCount; i++ {
supportedProfile := srtpProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):]))
supportedProfile := SRTPProtectionProfile(binary.BigEndian.Uint16(data[(extensionUseSRTPHeaderSize + (i * 2)):]))
if _, ok := srtpProtectionProfiles[supportedProfile]; ok {
e.protectionProfiles = append(e.protectionProfiles, supportedProfile)
}
Expand Down
2 changes: 1 addition & 1 deletion extension_use_srtp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
func TestExtensionUseSRTP(t *testing.T) {
rawUseSRTP := []byte{0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x01, 0x00}
parsedUseSRTP := &extensionUseSRTP{
protectionProfiles: []srtpProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
protectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
}

raw, err := parsedUseSRTP.Marshal()
Expand Down
33 changes: 21 additions & 12 deletions server_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ func serverHandshakeHandler(c *Conn) error {
case *extensionSupportedEllipticCurves:
c.namedCurve = e.ellipticCurves[0]
case *extensionUseSRTP:
// TODO expose to API
profile, ok := findMatchingSRTPProfile(e.protectionProfiles, c.localSRTPProtectionProfiles)
if !ok {
return fmt.Errorf("Client requested SRTP but we have no matching profiles")
}
c.srtpProtectionProfile = profile
}
}

Expand Down Expand Up @@ -148,6 +152,21 @@ func serverFlightHandler(c *Conn) (bool, error) {

case flight4:
c.lock.RLock()

extensions := []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
}
if c.srtpProtectionProfile != 0 {
extensions = append(extensions, &extensionUseSRTP{
protectionProfiles: []SRTPProtectionProfile{c.srtpProtectionProfile},
})
}

c.internalSend(&recordLayer{
recordLayerHeader: recordLayerHeader{
sequenceNumber: c.localSequenceNumber,
Expand All @@ -163,17 +182,7 @@ func serverFlightHandler(c *Conn) (bool, error) {
random: c.localRandom,
cipherSuite: c.cipherSuite,
compressionMethod: defaultCompressionMethods[0],
extensions: []extension{
&extensionSupportedEllipticCurves{
ellipticCurves: []namedCurve{namedCurveX25519, namedCurveP256},
},
&extensionUseSRTP{
protectionProfiles: []srtpProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
},
&extensionSupportedPointFormats{
pointFormats: []ellipticCurvePointFormat{ellipticCurvePointFormatUncompressed},
},
},
extensions: extensions,
}},
}, false)

Expand Down
Loading