diff --git a/cipher_suite.go b/cipher_suite.go index 53ae3a633..1d736b3aa 100644 --- a/cipher_suite.go +++ b/cipher_suite.go @@ -2,6 +2,7 @@ package dtls import ( "encoding/binary" + "fmt" "hash" ) @@ -24,6 +25,7 @@ type cipherSuite interface { ID() CipherSuiteID certificateType() clientCertificateType hashFunc() func() hash.Hash + isPSK() bool // Generate the internal encryption state init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error @@ -85,3 +87,45 @@ func encodeCipherSuites(c []cipherSuite) []byte { return out } + +func parseCipherSuites(userSelectedSuites []CipherSuiteID, psk []byte) ([]cipherSuite, error) { + cipherSuitesForIDs := func(ids []CipherSuiteID) ([]cipherSuite, error) { + cipherSuites := []cipherSuite{} + for _, id := range ids { + c := cipherSuiteForID(id) + if c == nil { + return nil, fmt.Errorf("CipherSuite with id(%d) is not valid", id) + } + cipherSuites = append(cipherSuites, c) + } + return cipherSuites, nil + } + + var ( + cipherSuites []cipherSuite + err error + i int + ) + if len(userSelectedSuites) != 0 { + cipherSuites, err = cipherSuitesForIDs(userSelectedSuites) + if err != nil { + return nil, err + } + } else { + cipherSuites = defaultCipherSuites() + } + + for _, c := range cipherSuites { + if (psk != nil && c.isPSK()) || (psk == nil && !c.isPSK()) { + cipherSuites[i] = c + i++ + } + } + + cipherSuites = cipherSuites[:i] + if len(cipherSuites) == 0 { + return nil, errNoAvailableCipherSuites + } + + return cipherSuites, nil +} diff --git a/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go b/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go index 566453d9d..7dd7486c6 100644 --- a/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go +++ b/cipher_suite_tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go @@ -26,6 +26,10 @@ func (c cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) hashFunc() func() hash.Hash return sha256.New } +func (c cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isPSK() bool { + return false +} + func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 0 diff --git a/cipher_suite_tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go b/cipher_suite_tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go index 028013f75..5d5771f82 100644 --- a/cipher_suite_tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go +++ b/cipher_suite_tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go @@ -26,6 +26,10 @@ func (c cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) hashFunc() func() hash.Hash { return sha256.New } +func (c cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isPSK() bool { + return false +} + func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error { const ( prfMacLen = 20 diff --git a/config.go b/config.go index aabb7ecfc..b54490bc8 100644 --- a/config.go +++ b/config.go @@ -11,10 +11,9 @@ import ( // Config is used to configure a DTLS client or server. // After a Config is passed to a DTLS function it must not be modified. type Config struct { - // Certificates contains certificate chain to present to - // the other side of the connection. Server MUST set this, - // client SHOULD sets this so CertificateRequests - // can be handled + // Certificates contains certificate chain to present to the other side of the connection. + // Server MUST set this if PSK is non-nil + // client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil Certificate *x509.Certificate // PrivateKey contains matching private key for the certificate @@ -38,6 +37,10 @@ type Config struct { // defaults to time.Second FlightInterval time.Duration + // PSK sets the pre-shared key used by this DTLS connection + // If PSK is non-nil only PSK CipherSuites will be used + PSK []byte + LoggerFactory logging.LoggerFactory } diff --git a/conn.go b/conn.go index 3f85a96d2..5acf136a0 100644 --- a/conn.go +++ b/conn.go @@ -5,7 +5,6 @@ import ( "crypto/ecdsa" "crypto/rand" "crypto/x509" - "errors" "fmt" "net" "sync" @@ -15,9 +14,12 @@ import ( "github.com/pion/logging" ) -const initialTickerInterval = time.Second -const cookieLength = 20 -const defaultNamedCurve = namedCurveX25519 +const ( + initialTickerInterval = time.Second + cookieLength = 20 + defaultNamedCurve = namedCurveX25519 + inboundBufferSize = 8192 +) var invalidKeyingLabels = map[string]bool{ "client finished": true, @@ -52,6 +54,7 @@ type Conn struct { localCertificate *x509.Certificate localPrivateKey crypto.PrivateKey localKeypair, remoteKeypair *namedCurveKeypair + localPSK []byte cookie []byte localCertificateVerify []byte // cache CertificateVerify @@ -68,34 +71,24 @@ type Conn struct { } func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessageHandler handshakeMessageHandler, config *Config, isClient bool) (*Conn, error) { - if config == nil { - return nil, errors.New("no config provided") - } - - loggerFactory := config.LoggerFactory - if loggerFactory == nil { - loggerFactory = logging.NewDefaultLoggerFactory() + switch { + case config == nil: + return nil, errNoConfigProvided + case nextConn == nil: + return nil, errNilNextConn + case config.Certificate != nil && config.PSK != nil: + return nil, errPSKAndCertificate } if config.PrivateKey != nil { if _, ok := config.PrivateKey.(*ecdsa.PrivateKey); !ok { return nil, errInvalidPrivateKey } - } else if nextConn == nil { - return nil, errNilNextConn } - cipherSuites := []cipherSuite{} - if len(config.CipherSuites) != 0 { - for _, id := range config.CipherSuites { - c := cipherSuiteForID(id) - if c == nil { - return nil, fmt.Errorf("CipherSuite with id(%d) is not valid", id) - } - cipherSuites = append(cipherSuites, c) - } - } else { - cipherSuites = defaultCipherSuites() + cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK) + if err != nil { + return nil, err } workerInterval := initialTickerInterval @@ -103,6 +96,11 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage workerInterval = config.FlightInterval } + loggerFactory := config.LoggerFactory + if loggerFactory == nil { + loggerFactory = logging.NewDefaultLoggerFactory() + } + c := &Conn{ nextConn: nextConn, currFlight: newFlight(isClient), @@ -115,6 +113,7 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage clientAuth: config.ClientAuth, localSRTPProtectionProfiles: config.SRTPProtectionProfiles, localCipherSuites: cipherSuites, + localPSK: config.PSK, namedCurve: defaultNamedCurve, decrypted: make(chan []byte), @@ -122,14 +121,13 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage handshakeCompleted: make(chan bool), log: loggerFactory.NewLogger("dtls"), } - c.state.isClient = isClient var zeroEpoch uint16 c.state.localEpoch.Store(zeroEpoch) c.state.remoteEpoch.Store(zeroEpoch) + c.state.isClient = isClient - err := c.state.localRandom.populate() - if err != nil { + if err = c.state.localRandom.populate(); err != nil { return nil, err } if !isClient { @@ -143,27 +141,7 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage c.startHandshakeOutbound() // Handle inbound - go func() { - defer func() { - close(c.decrypted) - }() - - b := make([]byte, 8192) - for { - i, err := c.nextConn.Read(b) - if err != nil { - c.stopWithError(err) - return - } else if c.getConnErr() != nil { - return - } - - if err := c.handleIncoming(b[:i]); err != nil { - c.stopWithError(err) - return - } - } - }() + go c.inboundLoop() <-c.handshakeCompleted c.log.Trace("Handshake Completed") @@ -186,9 +164,12 @@ func Client(conn net.Conn, config *Config) (*Conn, error) { // Server listens for incoming DTLS connections func Server(conn net.Conn, config *Config) (*Conn, error) { - if config == nil || config.Certificate == nil { + if config == nil { + return nil, errNoConfigProvided + } else if config.PSK == nil && config.Certificate == nil { return nil, errServerMustHaveCertificate } + return createConn(conn, serverFlightHandler, serverHandshakeHandler, config, false) } @@ -320,19 +301,35 @@ func (c *Conn) internalSend(pkt *recordLayer, shouldEncrypt bool) { } } -func (c *Conn) handleIncoming(buf []byte) error { - pkts, err := unpackDatagram(buf) - if err != nil { - return err - } +func (c *Conn) inboundLoop() { + defer func() { + close(c.decrypted) + }() - for _, p := range pkts { - err := c.handleIncomingPacket(p) + b := make([]byte, inboundBufferSize) + for { + i, err := c.nextConn.Read(b) if err != nil { - return err + c.stopWithError(err) + return + } else if c.getConnErr() != nil { + return + } + + pkts, err := unpackDatagram(b[:i]) + if err != nil { + c.stopWithError(err) + return + } + + for _, p := range pkts { + err := c.handleIncomingPacket(p) + if err != nil { + c.stopWithError(err) + return + } } } - return nil } func (c *Conn) handleIncomingPacket(buf []byte) error { diff --git a/conn_test.go b/conn_test.go index 2539aade8..dcfa26b06 100644 --- a/conn_test.go +++ b/conn_test.go @@ -83,12 +83,12 @@ func pipeMemory() (*Conn, *Conn, error) { // Setup client go func() { - client, err := testClient(ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}) + client, err := testClient(ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) c <- result{client, err} }() // Setup server - server, err := testServer(cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}) + server, err := testServer(cb, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) if err != nil { return nil, nil, err } @@ -102,23 +102,27 @@ func pipeMemory() (*Conn, *Conn, error) { return res.c, server, nil } -func testClient(c net.Conn, cfg *Config) (*Conn, error) { - clientCert, clientKey, err := GenerateSelfSigned() - if err != nil { - return nil, err +func testClient(c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { + if generateCertificate { + clientCert, clientKey, err := GenerateSelfSigned() + if err != nil { + return nil, err + } + cfg.PrivateKey = clientKey + cfg.Certificate = clientCert } - cfg.PrivateKey = clientKey - cfg.Certificate = clientCert return Client(c, cfg) } -func testServer(c net.Conn, cfg *Config) (*Conn, error) { - serverCert, serverKey, err := GenerateSelfSigned() - if err != nil { - return nil, err +func testServer(c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { + if generateCertificate { + serverCert, serverKey, err := GenerateSelfSigned() + if err != nil { + return nil, err + } + cfg.PrivateKey = serverKey + cfg.Certificate = serverCert } - cfg.PrivateKey = serverKey - cfg.Certificate = serverCert return Server(c, cfg) } @@ -222,11 +226,11 @@ func TestSRTPConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}) + client, err := testClient(ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true) c <- result{client, err} }() - server, err := testServer(cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}) + server, err := testServer(cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) @@ -263,12 +267,12 @@ func TestClientCertificate(t *testing.T) { go func() { conf := &Config{ClientAuth: RequireAnyClientCert} - client, err := testClient(ca, conf) + client, err := testClient(ca, conf, true) c <- result{client, conf, err} }() serverCfg := &Config{ClientAuth: RequireAnyClientCert} - server, err := testServer(cb, serverCfg) + server, err := testServer(cb, serverCfg, true) if err != nil { t.Errorf("TestClientCertificate: Server failed(%v)", err) } @@ -342,11 +346,68 @@ func TestCipherSuiteConfiguration(t *testing.T) { c := make(chan result) go func() { - client, err := testClient(ca, &Config{CipherSuites: test.ClientCipherSuites}) + client, err := testClient(ca, &Config{CipherSuites: test.ClientCipherSuites}, true) + c <- result{client, err} + }() + + _, err := testServer(cb, &Config{CipherSuites: test.ServerCipherSuites}, true) + if err != nil || test.WantServerError != nil { + if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { + t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) + } + } + + res := <-c + if res.err != nil || test.WantClientError != nil { + if !(res.err != nil && test.WantClientError != nil && err.Error() == test.WantClientError.Error()) { + t.Errorf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, err) + } + } + } +} + +func TestPSKConfiguration(t *testing.T) { + for _, test := range []struct { + Name string + ClientHasCertificate bool + ServerHasCertificate bool + ClientPSK []byte + ServerPSK []byte + WantClientError error + WantServerError error + }{ + { + Name: "PSK specified", + ClientHasCertificate: false, + ServerHasCertificate: false, + ClientPSK: []byte{0x00, 0x01, 0x02}, + ServerPSK: []byte{0x00, 0x01, 0x02}, + WantClientError: errNoAvailableCipherSuites, // TODO (should be nil when PSK CipherSuite is added) + WantServerError: errNoAvailableCipherSuites, // TODO (should be nil when PSK CipherSuite is added) + }, + { + Name: "PSK and certificate specified", + ClientHasCertificate: true, + ServerHasCertificate: true, + ClientPSK: []byte{0x00, 0x01, 0x02}, + ServerPSK: []byte{0x00, 0x01, 0x02}, + WantClientError: errPSKAndCertificate, + WantServerError: errPSKAndCertificate, + }, + } { + ca, cb := net.Pipe() + type result struct { + c *Conn + err error + } + c := make(chan result) + + go func() { + client, err := testClient(ca, &Config{PSK: test.ClientPSK}, test.ClientHasCertificate) c <- result{client, err} }() - _, err := testServer(cb, &Config{CipherSuites: test.ServerCipherSuites}) + _, err := testServer(cb, &Config{PSK: test.ServerPSK}, test.ServerHasCertificate) if err != nil || test.WantServerError != nil { if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) { t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err) diff --git a/errors.go b/errors.go index ace7a8aec..ae361d2a2 100644 --- a/errors.go +++ b/errors.go @@ -44,4 +44,7 @@ var ( errServerMustHaveCertificate = errors.New("dtls: Certificate is mandatory for server") errUnableToMarshalFragmented = errors.New("dtls: unable to marshal fragmented handshakes") errVerifyDataMismatch = errors.New("dtls: Expected and actual verify data does not match") + errNoConfigProvided = errors.New("dtls: No config provided") + errPSKAndCertificate = errors.New("dtls: Certificate and PSK provided") + errNoAvailableCipherSuites = errors.New("dtls: Connection can not be created, no CipherSuites satisfy this Config") )