Skip to content

Commit

Permalink
Update PSK Config to pass PSK Identity hint
Browse files Browse the repository at this point in the history
Also simplify CipherSuite filtering relating to
PSK/non-PSK suites

Relates to #45
  • Loading branch information
Sean-Der committed May 23, 2019
1 parent a2d36ed commit bb439be
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 16 deletions.
9 changes: 5 additions & 4 deletions cipher_suite.go
Expand Up @@ -88,7 +88,7 @@ func encodeCipherSuites(c []cipherSuite) []byte {
return out
}

func parseCipherSuites(userSelectedSuites []CipherSuiteID, psk []byte) ([]cipherSuite, error) {
func parseCipherSuites(userSelectedSuites []CipherSuiteID, excludePSK, excludeNonPSK bool) ([]cipherSuite, error) {
cipherSuitesForIDs := func(ids []CipherSuiteID) ([]cipherSuite, error) {
cipherSuites := []cipherSuite{}
for _, id := range ids {
Expand Down Expand Up @@ -116,10 +116,11 @@ func parseCipherSuites(userSelectedSuites []CipherSuiteID, psk []byte) ([]cipher
}

for _, c := range cipherSuites {
if (psk != nil && c.isPSK()) || (psk == nil && !c.isPSK()) {
cipherSuites[i] = c
i++
if excludePSK && c.isPSK() || excludeNonPSK && !c.isPSK() {
continue
}
cipherSuites[i] = c
i++
}

cipherSuites = cipherSuites[:i]
Expand Down
7 changes: 6 additions & 1 deletion config.go
Expand Up @@ -39,11 +39,16 @@ type Config struct {

// PSK sets the pre-shared key used by this DTLS connection
// If PSK is non-nil only PSK CipherSuites will be used
PSK []byte
PSK PSKCallback
PSKIdentityHint []byte

LoggerFactory logging.LoggerFactory
}

// PSKCallback is called once we have the remote's PSKIdentityHint.
// If the remote provided none it will be nil
type PSKCallback func([]byte) ([]byte, error)

// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
type ClientAuthType int
Expand Down
12 changes: 8 additions & 4 deletions conn.go
Expand Up @@ -54,9 +54,11 @@ type Conn struct {
localCertificate *x509.Certificate
localPrivateKey crypto.PrivateKey
localKeypair, remoteKeypair *namedCurveKeypair
localPSK []byte
cookie []byte

localPSKCallback PSKCallback
localPSKIdentityHint []byte

localCertificateVerify []byte // cache CertificateVerify
localVerifyData []byte // cached VerifyData
localKeySignature []byte // cached keySignature
Expand All @@ -76,7 +78,7 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
return nil, errNoConfigProvided
case nextConn == nil:
return nil, errNilNextConn
case config.Certificate != nil && config.PSK != nil:
case config.Certificate != nil && (config.PSK != nil || config.PSKIdentityHint != nil):
return nil, errPSKAndCertificate
}

Expand All @@ -86,7 +88,7 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
}
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK)
cipherSuites, err := parseCipherSuites(config.CipherSuites, config.PSK == nil, config.PSK != nil)
if err != nil {
return nil, err
}
Expand All @@ -113,9 +115,11 @@ func createConn(nextConn net.Conn, flightHandler flightHandler, handshakeMessage
clientAuth: config.ClientAuth,
localSRTPProtectionProfiles: config.SRTPProtectionProfiles,
localCipherSuites: cipherSuites,
localPSK: config.PSK,
namedCurve: defaultNamedCurve,

localPSKCallback: config.PSK,
localPSKIdentityHint: config.PSKIdentityHint,

decrypted: make(chan []byte),
workerTicker: time.NewTicker(workerInterval),
handshakeCompleted: make(chan bool),
Expand Down
12 changes: 6 additions & 6 deletions conn_test.go
Expand Up @@ -371,26 +371,26 @@ func TestPSKConfiguration(t *testing.T) {
Name string
ClientHasCertificate bool
ServerHasCertificate bool
ClientPSK []byte
ServerPSK []byte
ClientPSK PSKCallback
ServerPSK PSKCallback
WantClientError error
WantServerError error
}{
{
Name: "PSK specified",
ClientHasCertificate: false,
ServerHasCertificate: false,
ClientPSK: []byte{0x00, 0x01, 0x02},
ServerPSK: []byte{0x00, 0x01, 0x02},
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
WantClientError: errNoAvailableCipherSuites, // TODO (should be nil when PSK CipherSuite is added)
WantServerError: errNoAvailableCipherSuites, // TODO (should be nil when PSK CipherSuite is added)
},
{
Name: "PSK and certificate specified",
ClientHasCertificate: true,
ServerHasCertificate: true,
ClientPSK: []byte{0x00, 0x01, 0x02},
ServerPSK: []byte{0x00, 0x01, 0x02},
ClientPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
ServerPSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
WantClientError: errPSKAndCertificate,
WantServerError: errPSKAndCertificate,
},
Expand Down
2 changes: 1 addition & 1 deletion errors.go
Expand Up @@ -45,6 +45,6 @@ var (
errUnableToMarshalFragmented = errors.New("dtls: unable to marshal fragmented handshakes")
errVerifyDataMismatch = errors.New("dtls: Expected and actual verify data does not match")
errNoConfigProvided = errors.New("dtls: No config provided")
errPSKAndCertificate = errors.New("dtls: Certificate and PSK provided")
errPSKAndCertificate = errors.New("dtls: Certificate and PSK or PSK Identity Hint provided")
errNoAvailableCipherSuites = errors.New("dtls: Connection can not be created, no CipherSuites satisfy this Config")
)

0 comments on commit bb439be

Please sign in to comment.