Skip to content

Commit

Permalink
Limit size of encrypted packet queue
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt authored and Sean-Der committed Jun 1, 2024
1 parent fbbdf66 commit e11d429
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 34 deletions.
94 changes: 61 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ const (
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
// maxAppDataPacketQueueSize is the maximum number of app data packets we will
// enqueue before the handshake is completed
maxAppDataPacketQueueSize = 100
)

func invalidKeyingLabels() map[string]bool {
Expand Down Expand Up @@ -88,7 +91,7 @@ type Conn struct {
replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, initialState *State) (*Conn, error) {
func createConn(nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool) (*Conn, error) {
if err := validateConfig(config); err != nil {
return nil, err
}
Expand All @@ -97,21 +100,6 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
Expand Down Expand Up @@ -162,6 +150,28 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn

Check warning on line 158 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L158

Added line #L158 was not covered by tests
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err

Check warning on line 163 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L163

Added line #L163 was not covered by tests
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err

Check warning on line 168 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L168

Added line #L168 was not covered by tests
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

serverName := config.ServerName
// Do not allow the use of an IP address literal as an SNI value.
Expand Down Expand Up @@ -193,7 +203,7 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
log: conn.log,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand Down Expand Up @@ -222,30 +232,30 @@ func createConn(ctx context.Context, nextConn net.PacketConn, rAddr net.Addr, co
var initialFSMState handshakeState

if initialState != nil {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
conn.state = *initialState
} else {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}

c.log.Trace("Handshake Completed")
conn.log.Trace("Handshake Completed")

return c, nil
return conn, nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
Expand Down Expand Up @@ -301,16 +311,24 @@ func ClientWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr,
return nil, errPSKAndIdentityMustBeSetForClient
}

return createConn(ctx, conn, rAddr, config, true, nil)
dconn, err := createConn(conn, rAddr, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}

return createConn(ctx, conn, rAddr, config, false, nil)
dconn, err := createConn(conn, rAddr, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
}

// Read reads data from the connection.
Expand Down Expand Up @@ -738,6 +756,14 @@ func (c *Conn) handleQueuedPackets(ctx context.Context) error {
return nil
}

func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool {
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.encryptedPackets = append(c.encryptedPackets, packet)
return true
}
return false
}

func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit
h := &recordlayer.Header{}
// Set connection ID size so that records of content type tls12_cid will
Expand All @@ -751,7 +777,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}

// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
if h.Epoch > remoteEpoch {
Expand All @@ -762,8 +787,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("received packet of next epoch, queuing packet")
}
}
return false, nil, nil
}
Expand All @@ -790,8 +816,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
if h.Epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
c.log.Debug("handshake not finished, queuing packet")
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debug("handshake not finished, queuing packet")

Check warning on line 820 in conn.go

View check run for this annotation

Codecov / codecov/patch

conn.go#L819-L820

Added lines #L819 - L820 were not covered by tests
}
}
return false, nil, nil
}
Expand Down Expand Up @@ -883,8 +910,9 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, rAddr net.A
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, addrPkt{rAddr, buf})
c.log.Debugf("CipherSuite not initialized, queuing packet")
if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok {
c.log.Debugf("CipherSuite not initialized, queuing packet")
}
}
return false, nil, nil
}
Expand Down
84 changes: 84 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3279,3 +3279,87 @@ func (c *connWithCallback) Write(b []byte) (int, error) {
}
return c.Conn.Write(b)
}

func TestApplicationDataQueueLimited(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ca, cb := dpipe.Pipe()
defer ca.Close() //nolint:errcheck
defer cb.Close() //nolint:errcheck

done := make(chan struct{})
go func() {
serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Error(err)
return
}
cfg := &Config{}
cfg.Certificates = []tls.Certificate{serverCert}

dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false)
if err != nil {
t.Error(err)
return
}
go func() {
for i := 0; i < 5; i++ {
dconn.lock.RLock()
qlen := len(dconn.encryptedPackets)
dconn.lock.RUnlock()
if qlen > maxAppDataPacketQueueSize {
t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
}
t.Log(qlen)
time.Sleep(1 * time.Second)
}
}()
if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
t.Error("expected handshake to fail")
}
close(done)
}()
extensions := []extension.Extension{}

time.Sleep(50 * time.Millisecond)

err := sendClientHello([]byte{}, ca, 0, extensions)
if err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

for i := 0; i < 1000; i++ {
// Send an application data packet
packet, err := (&recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: uint64(3),
Epoch: 1, // use an epoch greater than 0
},
Content: &protocol.ApplicationData{
Data: []byte{1, 2, 3, 4},
},
}).Marshal()
if err != nil {
t.Fatal(err)
}
ca.Write(packet) // nolint
if i%100 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
time.Sleep(1 * time.Second)
ca.Close() // nolint
<-done
}
6 changes: 5 additions & 1 deletion resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ func Resume(state *State, conn net.PacketConn, rAddr net.Addr, config *Config) (
if err := state.initCipherSuite(); err != nil {
return nil, err
}
c, err := createConn(context.Background(), conn, rAddr, config, state.isClient, state)
dconn, err := createConn(conn, rAddr, config, state.isClient)
if err != nil {
return nil, err

Check warning on line 18 in resume.go

View check run for this annotation

Codecov / codecov/patch

resume.go#L18

Added line #L18 was not covered by tests
}
c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit e11d429

Please sign in to comment.