Skip to content

Commit

Permalink
Support stateful session resumption
Browse files Browse the repository at this point in the history
This PR only implements the session ID based session
resumption which requires the server to save the session info
https://datatracker.ietf.org/doc/html/rfc5246#appendix-F.1.4

It doesn't implement the session ticket based session
resumption: https://datatracker.ietf.org/doc/html/rfc5077
  • Loading branch information
taoso committed Dec 21, 2021
1 parent f0a790c commit a8ce745
Show file tree
Hide file tree
Showing 21 changed files with 706 additions and 78 deletions.
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ type Config struct {
// Use of KeyLogWriter compromises security and should only be
// used for debugging.
KeyLogWriter io.Writer

// SessionStore is the container to store session for resumption.
SessionStore SessionStore
}

func defaultConnectContextMaker() (context.Context, func()) {
Expand Down
14 changes: 14 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
const (
initialTickerInterval = time.Second
cookieLength = 20
sessionLength = 32
defaultNamedCurve = elliptic.X25519
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
Expand Down Expand Up @@ -83,6 +84,10 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
return nil, err
}

if isClient && config.SessionStore != nil && config.ServerName == "" {
return nil, errSessionStoreNoServerName
}

if nextConn == nil {
return nil, errNilNextConn
}
Expand Down Expand Up @@ -172,6 +177,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
log: logger,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
}

var initialFlight flightVal
Expand Down Expand Up @@ -674,6 +680,14 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Aler
var err error
buf, err = c.state.cipherSuite.Decrypt(buf)
if err != nil {
if len(c.state.SessionID) > 0 {
// According to the RFC, we need to delete the stored session.
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
if delErr := c.fsm.cfg.sessionStore.Del(c.state.SessionID); delErr != nil {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, delErr
}
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecryptError}, err
}
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
}
Expand Down
204 changes: 156 additions & 48 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -1584,11 +1585,6 @@ func TestProtocolVersionValidation(t *testing.T) {
var rand [28]byte
random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}

localKeypair, err := elliptic.GenerateKeypair(elliptic.X25519)
if err != nil {
t.Fatal(err)
}

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
FlightInterval: 100 * time.Millisecond,
Expand Down Expand Up @@ -1740,49 +1736,6 @@ func TestProtocolVersionValidation(t *testing.T) {
},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 2,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 2,
},
Message: &handshake.MessageCertificate{},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 3,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 3,
},
Message: &handshake.MessageServerKeyExchange{
EllipticCurveType: elliptic.CurveTypeNamedCurve,
NamedCurve: elliptic.X25519,
PublicKey: localKeypair.PublicKey,
HashAlgorithm: hash.SHA256,
SignatureAlgorithm: signature.ECDSA,
Signature: make([]byte, 64),
},
},
},
{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: 4,
},
Content: &handshake.Handshake{
Header: handshake.Header{
MessageSequence: 4,
},
Message: &handshake.MessageServerHelloDone{},
},
},
},
},
}
Expand Down Expand Up @@ -2190,3 +2143,158 @@ func TestSupportedGroupsExtension(t *testing.T) {
}
})
}

func TestSessionResume(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

t.Run("session resumption old", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientRes := make(chan result, 1)

ss := &memSessStore{}

id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306")
secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7")

s := Session{ID: id, Secret: secret}

_ = ss.Set(id, s)
_ = ss.Set([]byte("example.com"), s)

ca, cb := dpipe.Pipe()
go func() {
config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ServerName: "example.com",
SessionStore: ss,
MTU: 100,
}
c, err := testClient(ctx, ca, config, false)
clientRes <- result{c, err}
}()

config := &Config{
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
ServerName: "example.com",
SessionStore: ss,
MTU: 100,
}
server, err := testServer(ctx, cb, config, true)
if err != nil {
t.Fatalf("TestSessionResume: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
if !bytes.Equal(actualSessionID, id) {
t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID)
}
if !bytes.Equal(actualMasterSecret, secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret)
}

defer func() {
_ = server.Close()
}()

res := <-clientRes
if res.err != nil {
t.Fatal(res.err)
}
_ = res.c.Close()
})

t.Run("session resumption new", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

type result struct {
c *Conn
err error
}
clientRes := make(chan result, 1)

s1 := &memSessStore{}
s2 := &memSessStore{}

ca, cb := dpipe.Pipe()
go func() {
config := &Config{
ServerName: "example.com",
SessionStore: s1,
}
c, err := testClient(ctx, ca, config, false)
clientRes <- result{c, err}
}()

config := &Config{
SessionStore: s2,
}
server, err := testServer(ctx, cb, config, true)
if err != nil {
t.Fatalf("TestSessionResumetion: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
ss, _ := s2.Get(actualSessionID)
if !bytes.Equal(actualMasterSecret, ss.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
}

defer func() {
_ = server.Close()
}()

res := <-clientRes
if res.err != nil {
t.Fatal(res.err)
}
cs, _ := s1.Get([]byte("example.com"))
if !bytes.Equal(actualMasterSecret, cs.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
}
_ = res.c.Close()
})
}

type memSessStore struct {
sync.Map
}

func (ms *memSessStore) Set(key []byte, s Session) error {
k := hex.EncodeToString(key)
ms.Store(k, s)

return nil
}

func (ms *memSessStore) Get(key []byte) (Session, error) {
k := hex.EncodeToString(key)

v, ok := ms.Load(k)
if !ok {
return Session{}, nil
}

return v.(Session), nil
}

func (ms *memSessStore) Del(key []byte) error {
k := hex.EncodeToString(key)
ms.Delete(k)

return nil
}
1 change: 1 addition & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ var (
errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113
errServerRequiredButNoClientEMS = &FatalError{Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113
errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113
errSessionStoreNoServerName = &FatalError{Err: errors.New("SessionStore must be set with ServerName")} //nolint:goerr113

errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113
errKeySignatureGenerateUnimplemented = &InternalError{Err: errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113
Expand Down
20 changes: 18 additions & 2 deletions flight.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ package dtls
of a number of messages, they should be viewed as monolithic for the
purpose of timeout and retransmission.
https://tools.ietf.org/html/rfc4347#section-4.2.4
Note: The flight4b and flight5b will be only used in session resumption.
Client Server
------ ------
Waiting Flight 0
Expand All @@ -22,10 +25,17 @@ package dtls
CertificateRequest* /
<-------- ServerHelloDone /
ServerHello \
[ChangeCipherSpec] Flight 4b
<-------- Finished /
Certificate* \
ClientKeyExchange \
CertificateVerify* Flight 5
[ChangeCipherSpec] /
Finished --------> /
[ChangeCipherSpec] \ Flight 5b
Finished --------> /
[ChangeCipherSpec] \ Flight 6
Expand All @@ -41,7 +51,9 @@ const (
flight2
flight3
flight4
flight4b
flight5
flight5b
flight6
)

Expand All @@ -57,8 +69,12 @@ func (f flightVal) String() string {
return "Flight 3"
case flight4:
return "Flight 4"
case flight4b:
return "Flight 4b"
case flight5:
return "Flight 5"
case flight5b:
return "Flight 5b"
case flight6:
return "Flight 6"
default:
Expand All @@ -67,9 +83,9 @@ func (f flightVal) String() string {
}

func (f flightVal) isLastSendFlight() bool {
return f == flight6
return f == flight6 || f == flight5b
}

func (f flightVal) isLastRecvFlight() bool {
return f == flight5
return f == flight5 || f == flight4b
}
25 changes: 24 additions & 1 deletion flight0handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,30 @@ func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handsh
}
}

return flight2, nil, nil
return handleHelloResume(clientHello.SessionID, state, cfg, flight2)
}

func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) {
if len(sessionID) > 0 && cfg.sessionStore != nil {
if s, err := cfg.sessionStore.Get(sessionID); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
} else if s.ID != nil {
cfg.log.Tracef("[handshake] resume session: %x", sessionID)

state.SessionID = sessionID
state.masterSecret = s.Secret

if err := state.initCipherSuite(); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}

clientRandom := state.localRandom.MarshalFixed()
cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)

return flight4b, nil, nil
}
}
return next, nil, nil
}

func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
Expand Down
Loading

0 comments on commit a8ce745

Please sign in to comment.