Skip to content

Commit

Permalink
Allow user to specify PSK
Browse files Browse the repository at this point in the history
Add configuration option for user to pass PSK. When a user
passes a PSK then we only allow CipherSuites that do PSK.

If user passes a PSK and a certificate return an error.

Relates to #45
  • Loading branch information
Sean-Der committed May 23, 2019
1 parent 1f6822a commit c065462
Show file tree
Hide file tree
Showing 7 changed files with 198 additions and 82 deletions.
44 changes: 44 additions & 0 deletions cipher_suite.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package dtls

import (
"encoding/binary"
"fmt"
"hash"
)

Expand All @@ -24,6 +25,7 @@ type cipherSuite interface {
ID() CipherSuiteID
certificateType() clientCertificateType
hashFunc() func() hash.Hash
isPSK() bool

// Generate the internal encryption state
init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error
Expand Down Expand Up @@ -85,3 +87,45 @@ func encodeCipherSuites(c []cipherSuite) []byte {

return out
}

func parseCipherSuites(userSelectedSuites []CipherSuiteID, psk []byte) ([]cipherSuite, error) {
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]cipherSuite, error) {
cipherSuites := []cipherSuite{}
for _, id := range ids {
c := cipherSuiteForID(id)
if c == nil {
return nil, fmt.Errorf("CipherSuite with id(%d) is not valid", id)
}
cipherSuites = append(cipherSuites, c)
}
return cipherSuites, nil
}

var (
cipherSuites []cipherSuite
err error
i int
)
if len(userSelectedSuites) != 0 {
cipherSuites, err = cipherSuitesForIDs(userSelectedSuites)
if err != nil {
return nil, err
}
} else {
cipherSuites = defaultCipherSuites()
}

for _, c := range cipherSuites {
if (psk != nil && c.isPSK()) || (psk == nil && !c.isPSK()) {
cipherSuites[i] = c
i++
}
}

cipherSuites = cipherSuites[:i]
if len(cipherSuites) == 0 {
return nil, errNoAvailableCipherSuites
}

return cipherSuites, nil
}
4 changes: 4 additions & 0 deletions cipher_suite_tls_ecdhe_ecdsa_with_aes_128_gcm_sha256.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func (c cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) hashFunc() func() hash.Hash
return sha256.New
}

func (c cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) isPSK() bool {
return false
}

func (c *cipherSuiteTLSEcdheEcdsaWithAes128GcmSha256) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 0
Expand Down
4 changes: 4 additions & 0 deletions cipher_suite_tls_ecdhe_ecdsa_with_aes_256_cbc_sha.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func (c cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) hashFunc() func() hash.Hash {
return sha256.New
}

func (c cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) isPSK() bool {
return false
}

func (c *cipherSuiteTLSEcdheEcdsaWithAes256CbcSha) init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error {
const (
prfMacLen = 20
Expand Down
11 changes: 7 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import (
// Config is used to configure a DTLS client or server.
// After a Config is passed to a DTLS function it must not be modified.
type Config struct {
// Certificates contains certificate chain to present to
// the other side of the connection. Server MUST set this,
// client SHOULD sets this so CertificateRequests
// can be handled
// Certificates contains certificate chain to present to the other side of the connection.
// Server MUST set this if PSK is non-nil
// client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil
Certificate *x509.Certificate

// PrivateKey contains matching private key for the certificate
Expand All @@ -38,6 +37,10 @@ type Config struct {
// defaults to time.Second
FlightInterval time.Duration

// PSK sets the pre-shared key used by this DTLS connection
// If PSK is non-nil only PSK CipherSuites will be used
PSK []byte

LoggerFactory logging.LoggerFactory
}

Expand Down
113 changes: 55 additions & 58 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/ecdsa"
"crypto/rand"
"crypto/x509"
"errors"
"fmt"
"net"
"sync"
Expand All @@ -15,9 +14,12 @@ import (
"github.com/pion/logging"
)

const initialTickerInterval = time.Second
const cookieLength = 20
const defaultNamedCurve = namedCurveX25519
const (
initialTickerInterval = time.Second
cookieLength = 20
defaultNamedCurve = namedCurveX25519
inboundBufferSize = 8192
)

var invalidKeyingLabels = map[string]bool{
"client finished": true,
Expand Down Expand Up @@ -52,6 +54,7 @@ type Conn struct {
localCertificate *x509.Certificate
localPrivateKey crypto.PrivateKey
localKeypair, remoteKeypair *namedCurveKeypair
localPSK []byte
cookie []byte

localCertificateVerify []byte // cache CertificateVerify
Expand All @@ -68,41 +71,36 @@ type Conn struct {
}

func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessageHandler handshakeMessageHandler, config *Config, isClient bool) (*Conn, error) {
if config == nil {
return nil, errors.New("no config provided")
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
switch {
case config == nil:
return nil, errNoConfigProvided
case nextConn == nil:
return nil, errNilNextConn
case config.Certificate != nil && config.PSK != nil:
return nil, errPSKAndCertificate
}

if config.PrivateKey != nil {
if _, ok := config.PrivateKey.(*ecdsa.PrivateKey); !ok {
return nil, errInvalidPrivateKey
}
} else if nextConn == nil {
return nil, errNilNextConn
}

cipherSuites := []cipherSuite{}
if len(config.CipherSuites) != 0 {
for _, id := range config.CipherSuites {
c := cipherSuiteForID(id)
if c == nil {
return nil, fmt.Errorf("CipherSuite with id(%d) is not valid", id)
}
cipherSuites = append(cipherSuites, c)
}
} else {
cipherSuites = defaultCipherSuites()
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
}

c := &Conn{
nextConn: nextConn,
currFlight: newFlight(isClient),
Expand All @@ -115,21 +113,21 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
clientAuth: config.ClientAuth,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
localCipherSuites: cipherSuites,
localPSK: config.PSK,
namedCurve: defaultNamedCurve,

decrypted: make(chan []byte),
workerTicker: time.NewTicker(workerInterval),
handshakeCompleted: make(chan bool),
log: loggerFactory.NewLogger("dtls"),
}
c.state.isClient = isClient

var zeroEpoch uint16
c.state.localEpoch.Store(zeroEpoch)
c.state.remoteEpoch.Store(zeroEpoch)
c.state.isClient = isClient

err := c.state.localRandom.populate()
if err != nil {
if err = c.state.localRandom.populate(); err != nil {
return nil, err
}
if !isClient {
Expand All @@ -143,27 +141,7 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
c.startHandshakeOutbound()

// Handle inbound
go func() {
defer func() {
close(c.decrypted)
}()

b := make([]byte, 8192)
for {
i, err := c.nextConn.Read(b)
if err != nil {
c.stopWithError(err)
return
} else if c.getConnErr() != nil {
return
}

if err := c.handleIncoming(b[:i]); err != nil {
c.stopWithError(err)
return
}
}
}()
go c.inboundLoop()

<-c.handshakeCompleted
c.log.Trace("Handshake Completed")
Expand All @@ -186,9 +164,12 @@ func Client(conn net.Conn, config *Config) (*Conn, error) {

// Server listens for incoming DTLS connections
func Server(conn net.Conn, config *Config) (*Conn, error) {
if config == nil || config.Certificate == nil {
if config == nil {
return nil, errNoConfigProvided
} else if config.PSK == nil && config.Certificate == nil {
return nil, errServerMustHaveCertificate
}

return createConn(conn, serverFlightHandler, serverHandshakeHandler, config, false)
}

Expand Down Expand Up @@ -320,19 +301,35 @@ func (c *Conn) internalSend(pkt *recordLayer, shouldEncrypt bool) {
}
}

func (c *Conn) handleIncoming(buf []byte) error {
pkts, err := unpackDatagram(buf)
if err != nil {
return err
}
func (c *Conn) inboundLoop() {
defer func() {
close(c.decrypted)
}()

for _, p := range pkts {
err := c.handleIncomingPacket(p)
b := make([]byte, inboundBufferSize)
for {
i, err := c.nextConn.Read(b)
if err != nil {
return err
c.stopWithError(err)
return
} else if c.getConnErr() != nil {
return
}

pkts, err := unpackDatagram(b[:i])
if err != nil {
c.stopWithError(err)
return
}

for _, p := range pkts {
err := c.handleIncomingPacket(p)
if err != nil {
c.stopWithError(err)
return
}
}
}
return nil
}

func (c *Conn) handleIncomingPacket(buf []byte) error {
Expand Down

0 comments on commit c065462

Please sign in to comment.