Skip to content

Commit

Permalink
Make sure clean the stored session
Browse files Browse the repository at this point in the history
We need to delete the stored session when any fatal errors occurs.
This operation should be taken in the Conn.notify function.
  • Loading branch information
taoso committed Dec 23, 2021
1 parent 961026b commit 0068ae0
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 16 deletions.
24 changes: 16 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,14 +676,6 @@ func (c *Conn) handleIncomingPacket(buf []byte, enqueue bool) (bool, *alert.Aler
var err error
buf, err = c.state.cipherSuite.Decrypt(buf)
if err != nil {
if len(c.state.SessionID) > 0 {
// According to the RFC, we need to delete the stored session.
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
if delErr := c.fsm.cfg.sessionStore.Del(c.state.SessionID); delErr != nil {
return false, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, delErr
}
return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecryptError}, err
}
c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err)
return false, nil, nil
}
Expand Down Expand Up @@ -764,6 +756,15 @@ func (c *Conn) recvHandshake() <-chan chan struct{} {
}

func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error {
if level == alert.Fatal && len(c.state.SessionID) > 0 {
// According to the RFC, we need to delete the stored session.
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2
if ss := c.fsm.cfg.sessionStore; ss != nil {
if err := ss.Del(c.sessionKey()); err != nil {
return err
}
}
}
return c.writePackets(ctx, []*packet{
{
record: &recordlayer.RecordLayer{
Expand Down Expand Up @@ -960,6 +961,13 @@ func (c *Conn) RemoteAddr() net.Addr {
return c.nextConn.RemoteAddr()
}

func (c *Conn) sessionKey() []byte {
if c.state.isClient {
return []byte(c.nextConn.RemoteAddr().String() + c.fsm.cfg.serverName)
}
return c.state.SessionID
}

// SetDeadline implements net.Conn.SetDeadline
func (c *Conn) SetDeadline(t time.Time) error {
c.readDeadline.Set(t)
Expand Down
3 changes: 1 addition & 2 deletions flight1handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *han

if cfg.sessionStore != nil {
cfg.log.Tracef("[handshake] try to resume session")
key := []byte(c.RemoteAddr().String() + cfg.serverName)
if s, err := cfg.sessionStore.Get(key); err != nil {
if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil {
return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
} else if s.ID != nil {
cfg.log.Tracef("[handshake] get saved session: %x", s.ID)
Expand Down
3 changes: 1 addition & 2 deletions flight5handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handsh
Secret: state.masterSecret,
}
cfg.log.Tracef("[handshake] save new session: %x", s.ID)
key := []byte(c.RemoteAddr().String() + cfg.serverName)
if err := cfg.sessionStore.Set(key, s); err != nil {
if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
}
}
Expand Down
3 changes: 1 addition & 2 deletions handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"crypto/x509"
"fmt"
"io"
"net"
"sync"
"time"

Expand Down Expand Up @@ -122,7 +121,7 @@ type flightConn interface {
recvHandshake() <-chan chan struct{}
setLocalEpoch(epoch uint16)
handleQueuedPackets(context.Context) error
RemoteAddr() net.Addr
sessionKey() []byte
}

func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) {
Expand Down
3 changes: 1 addition & 2 deletions handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"crypto/tls"
"net"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -277,6 +276,6 @@ func (c *flightTestConn) handleQueuedPackets(ctx context.Context) error {
return nil
}

func (c *flightTestConn) RemoteAddr() net.Addr {
func (c *flightTestConn) sessionKey() []byte {
return nil
}

0 comments on commit 0068ae0

Please sign in to comment.