Skip to content

Commit

Permalink
Make sure clean the stored session
Browse files Browse the repository at this point in the history
We need to delete the stored session when any fatal errors occurs.
This operation should be taken in the Conn.notify function.
  • Loading branch information
taoso authored and daenney committed Jan 11, 2022
1 parent 04c1634 commit fe3a675
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 30 deletions.
28 changes: 20 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,14 +676,6 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Aler
var err error
buf, err = c.state.cipherSuite.Decrypt(buf)
if err != nil {
if len(c.state.SessionID) > 0 {
// According to the RFC, we need to delete the stored session.
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
if delErr := c.fsm.cfg.sessionStore.Del(c.state.SessionID); delErr != nil {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, delErr
}
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecryptError}, err
}
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
}
Expand Down Expand Up @@ -764,6 +756,16 @@ func (c *Conn) recvHandshake() <-chan chan struct{} {
}

func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
if level == alert.Fatal && len(c.state.SessionID) > 0 {
// According to the RFC, we need to delete the stored session.
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
if ss := c.fsm.cfg.sessionStore; ss != nil {
c.log.Tracef("clean invalid session: %s", c.state.SessionID)
if err := ss.Del(c.sessionKey()); err != nil {
return err
}
}
}
return c.writePackets(ctx, []*packet{
{
record: &recordlayer.RecordLayer{
Expand Down Expand Up @@ -960,6 +962,16 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}

func (c *Conn) sessionKey() []byte {
if c.state.isClient {
// As ServerName can be like 0.example.com, it's better to add
// delimiter character which is not allowed to be in
// neither address or domain name.
return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName)
}
return c.state.SessionID
}

// SetDeadline implements net.Conn.SetDeadline
func (c *Conn) SetDeadline(t time.Time) error {
c.readDeadline.Set(t)
Expand Down
8 changes: 4 additions & 4 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2153,7 +2153,7 @@ func TestSessionResume(t *testing.T) {
report := test.CheckRoutines(t)
defer report()

t.Run("session resumption old", func(t *testing.T) {
t.Run("resumed", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

Expand All @@ -2173,7 +2173,7 @@ func TestSessionResume(t *testing.T) {
ca, cb := dpipe.Pipe()

_ = ss.Set(id, s)
_ = ss.Set([]byte(ca.RemoteAddr().String()+"example.com"), s)
_ = ss.Set([]byte(ca.RemoteAddr().String()+"_example.com"), s)

go func() {
config := &Config{
Expand Down Expand Up @@ -2217,7 +2217,7 @@ func TestSessionResume(t *testing.T) {
_ = res.c.Close()
})

t.Run("session resumption new", func(t *testing.T) {
t.Run("new session", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

Expand Down Expand Up @@ -2263,7 +2263,7 @@ func TestSessionResume(t *testing.T) {
if res.err != nil {
t.Fatal(res.err)
}
cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "example.com"))
cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com"))
if !bytes.Equal(actualMasterSecret, cs.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
}
Expand Down
3 changes: 1 addition & 2 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *han

if cfg.sessionStore != nil {
cfg.log.Tracef("[handshake] try to resume session")
key := []byte(c.RemoteAddr().String() + cfg.serverName)
if s, err := cfg.sessionStore.Get(key); err != nil {
if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil {
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
} else if s.ID != nil {
cfg.log.Tracef("[handshake] get saved session: %x", s.ID)
Expand Down
5 changes: 0 additions & 5 deletions flight3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,6 @@ func handleResumption(ctx context.Context, c flightConn, state *State, cache *ha
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
cfg.log.Tracef("[handshake] clean invalid session: %s", state.SessionID)
if err := cfg.sessionStore.Del(state.SessionID); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}

return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
}

Expand Down
5 changes: 0 additions & 5 deletions flight4bhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ func flight4bParse(ctx context.Context, c flightConn, state *State, cache *hands

expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc())
if err != nil {
cfg.log.Tracef("[handshake] clean invalid session: %s", state.SessionID)
if delErr := cfg.sessionStore.Del(state.SessionID); delErr != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, delErr
}

return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
Expand Down
3 changes: 1 addition & 2 deletions flight5handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handsh
Secret: state.masterSecret,
}
cfg.log.Tracef("[handshake] save new session: %x", s.ID)
key := []byte(c.RemoteAddr().String() + cfg.serverName)
if err := cfg.sessionStore.Set(key, s); err != nil {
if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
Expand Down
3 changes: 1 addition & 2 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"crypto/x509"
"fmt"
"io"
"net"
"sync"
"time"

Expand Down Expand Up @@ -122,7 +121,7 @@ type flightConn interface {
recvHandshake() <-chan chan struct{}
setLocalEpoch(epoch uint16)
handleQueuedPackets(context.Context) error
RemoteAddr() net.Addr
sessionKey() []byte
}

func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
Expand Down
3 changes: 1 addition & 2 deletions handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/tls"
"net"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -277,6 +276,6 @@ func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error {
return nil
}

func (c *flightTestConn) RemoteAddr() net.Addr {
func (c *flightTestConn) sessionKey() []byte {
return nil
}

0 comments on commit fe3a675

Please sign in to comment.