Skip to content

Commit

Permalink
Limit size of encrypted packet queue
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt authored and MarcoPolo committed May 2, 2024
1 parent ebdb8bd commit cfeb9ca
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 34 deletions.
105 changes: 72 additions & 33 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const (
inboundBufferSize = 8192
// Default replay protection window is specified by RFC 6347 Section 4.1.2.6
defaultReplayProtectionWindow = 64
// maxAppDataPacketQueueSize is the maximum number of app data packets we will
// enqueue before the handshake is completed
maxAppDataPacketQueueSize = 100
)

func invalidKeyingLabels() map[string]bool {
Expand Down Expand Up @@ -81,7 +84,7 @@ type Conn struct {
replayProtectionWindow uint
}

func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
func createConn(nextConn net.Conn, config *Config, isClient bool) (*Conn, error) {
err := validateConfig(config)
if err != nil {
return nil, err
Expand All @@ -91,21 +94,6 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

loggerFactory := config.LoggerFactory
if loggerFactory == nil {
loggerFactory = logging.NewDefaultLoggerFactory()
Expand Down Expand Up @@ -149,6 +137,38 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient

c.setRemoteEpoch(0)
c.setLocalEpoch(0)
return c, nil
}

func handshakeConn(ctx context.Context, conn *Conn, config *Config, isClient bool, initialState *State) (*Conn, error) {
if conn == nil {
return nil, errNilNextConn
}

cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil)
if err != nil {
return nil, err
}

signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes)
if err != nil {
return nil, err
}

workerInterval := initialTickerInterval
if config.FlightInterval != 0 {
workerInterval = config.FlightInterval
}

mtu := config.MTU
if mtu <= 0 {
mtu = defaultMTU
}

replayProtectionWindow := config.ReplayProtectionWindow
if replayProtectionWindow <= 0 {
replayProtectionWindow = defaultReplayProtectionWindow
}

serverName := config.ServerName
// Do not allow the use of an IP address literal as an SNI value.
Expand Down Expand Up @@ -180,7 +200,7 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
clientCAs: config.ClientCAs,
customCipherSuites: config.CustomCipherSuites,
retransmitInterval: workerInterval,
log: logger,
log: conn.log,
initialEpoch: 0,
keyLogWriter: config.KeyLogWriter,
sessionStore: config.SessionStore,
Expand All @@ -205,30 +225,30 @@ func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient
var initialFSMState handshakeState

if initialState != nil {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight5
} else {
initialFlight = flight6
}
initialFSMState = handshakeFinished

c.state = *initialState
conn.state = *initialState
} else {
if c.state.isClient {
if conn.state.isClient {
initialFlight = flight1
} else {
initialFlight = flight0
}
initialFSMState = handshakePreparing
}
// Do handshake
if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
if err := conn.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil {
return nil, err
}

c.log.Trace("Handshake Completed")
conn.log.Trace("Handshake Completed")

return c, nil
return conn, nil
}

// Dial connects to the given network address and establishes a DTLS connection on top.
Expand Down Expand Up @@ -279,16 +299,24 @@ func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Con
return nil, errPSKAndIdentityMustBeSetForClient
}

return createConn(ctx, conn, config, true, nil)
dconn, err := createConn(conn, config, true)
if err != nil {
return nil, err
}

return handshakeConn(ctx, dconn, config, true, nil)
}

// ServerWithContext listens for incoming DTLS connections.
func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) {
if config == nil {
return nil, errNoConfigProvided
}

return createConn(ctx, conn, config, false, nil)
dconn, err := createConn(conn, config, false)
if err != nil {
return nil, err
}
return handshakeConn(ctx, dconn, config, false, nil)
}

// Read reads data from the connection.
Expand Down Expand Up @@ -662,7 +690,6 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
c.log.Debugf("discarded broken packet: %v", err)
return false, nil, nil
}

// Validate epoch
remoteEpoch := c.state.getRemoteEpoch()
if h.Epoch > remoteEpoch {
Expand All @@ -673,8 +700,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
return false, nil, nil
}
if enqueue {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.log.Debug("received packet of next epoch, queuing packet")
c.encryptedPackets = append(c.encryptedPackets, buf)
} else {
c.log.Debug("app data packet queue full, dropping packet")
}
}
return false, nil, nil
}
Expand All @@ -697,8 +728,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
if h.Epoch != 0 {
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debug("handshake not finished, queuing packet")
} else {
c.log.Debug("app data packet queue full, dropping packet")
}
}
return false, nil, nil
}
Expand Down Expand Up @@ -749,8 +784,12 @@ func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue boo
case *protocol.ChangeCipherSpec:
if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() {
if enqueue {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
if len(c.encryptedPackets) < maxAppDataPacketQueueSize {
c.encryptedPackets = append(c.encryptedPackets, buf)
c.log.Debugf("CipherSuite not initialized, queuing packet")
} else {
c.log.Debug("app data packet queue full. dropping packet")
}
}
return false, nil, nil
}
Expand Down
85 changes: 85 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3050,3 +3050,88 @@ func (c *connWithCallback) Write(b []byte) (int, error) {
}
return c.Conn.Write(b)
}

func TestApplicationDataQueueLimited(t *testing.T) {
// Limit runtime in case of deadlocks
lim := test.TimeOut(time.Second * 20)
defer lim.Stop()

// Check for leaking routines
report := test.CheckRoutines(t)
defer report()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

ca, cb := dpipe.Pipe()
defer ca.Close()
defer cb.Close()

done := make(chan struct{})
go func() {
serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Error(err)
return
}
cfg := &Config{}
cfg.Certificates = []tls.Certificate{serverCert}

dconn, err := createConn(cb, cfg, false)
if err != nil {
t.Error(err)
return
}
go func() {
for i := 0; i < 5; i++ {
dconn.lock.RLock()
qlen := len(dconn.encryptedPackets)
dconn.lock.RUnlock()
if qlen > maxAppDataPacketQueueSize {
t.Error("too many encrypted packets enqueued", len(dconn.encryptedPackets))
}
t.Log(qlen)
time.Sleep(1 * time.Second)
}

}()
if _, err := handshakeConn(ctx, dconn, cfg, false, nil); err == nil {
t.Error("expected handshake to fail")
}
close(done)
}()
extensions := []extension.Extension{}

time.Sleep(50 * time.Millisecond)

err := sendClientHello([]byte{}, ca, 0, extensions)
if err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

for i := 0; i < 1000; i++ {
// Send an application data packet
packet, err := (&recordlayer.RecordLayer{
Header: recordlayer.Header{
Version: protocol.Version1_2,
SequenceNumber: uint64(3),
Epoch: 1, // use an epoch greater than 0
},
Content: &protocol.ApplicationData{
Data: []byte{1, 2, 3, 4},
},
}).Marshal()
if err != nil {
t.Fatal(err)
}
ca.Write(packet)
if i%100 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
time.Sleep(1 * time.Second)
ca.Close()
<-done
}
6 changes: 5 additions & 1 deletion resume.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ func Resume(state *State, conn net.Conn, config *Config) (*Conn, error) {
if err := state.initCipherSuite(); err != nil {
return nil, err
}
c, err := createConn(context.Background(), conn, config, state.isClient, state)
dconn, err := createConn(conn, config, state.isClient)
if err != nil {
return nil, err
}
c, err := handshakeConn(context.Background(), dconn, config, state.isClient, state)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit cfeb9ca

Please sign in to comment.