Skip to content

Commit

Permalink
fix handling of ACK frames serialized after CRYPTO frames
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 5, 2023
1 parent 26c6fcc commit c239066
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 51 deletions.
32 changes: 26 additions & 6 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -718,20 +718,22 @@ func (s *connection) idleTimeoutStartTime() time.Time {
}

func (s *connection) handleHandshakeComplete() error {
s.handshakeComplete = true
defer s.handshakeCtxCancel()
// Once the handshake completes, we have derived 1-RTT keys.
// There's no point in queueing undecryptable packets for later decryption any more.
// There's no point in queueing undecryptable packets for later decryption anymore.
s.undecryptablePackets = nil

s.connIDManager.SetHandshakeComplete()
s.connIDGenerator.SetHandshakeComplete()

// The server applies transport parameters right away, but the client side has to wait for handshake completion.
// During a 0-RTT connection, the client is only allowed to use the new transport parameters for 1-RTT packets.
if s.perspective == protocol.PerspectiveClient {
s.applyTransportParameters()
return nil
}

// All these only apply to the server side.
if err := s.handleHandshakeConfirmed(); err != nil {
return err
}
Expand Down Expand Up @@ -1229,6 +1231,7 @@ func (s *connection) handleFrames(
if log != nil {
frames = make([]logging.Frame, 0, 4)
}
handshakeWasComplete := s.handshakeComplete
var handleErr error
for len(data) > 0 {
l, frame, err := s.frameParser.ParseNext(data, encLevel, s.version)
Expand Down Expand Up @@ -1265,6 +1268,17 @@ func (s *connection) handleFrames(
return false, handleErr
}
}

// Handle completion of the handshake after processing all the frames.
// This ensures that we correctly handle the following case on the server side:
// We receive a Handshake packet that contains the CRYPTO frame that allows us to complete the handshake,
// and an ACK serialized after that CRYPTO frame. In this case, we still want to process the ACK frame.
if !handshakeWasComplete && s.handshakeComplete {
if err := s.handleHandshakeComplete(); err != nil {
return false, err
}
}

return
}

Expand Down Expand Up @@ -1360,7 +1374,9 @@ func (s *connection) handleHandshakeEvents() error {
case handshake.EventNoEvent:
return nil
case handshake.EventHandshakeComplete:
err = s.handleHandshakeComplete()
// Don't call handleHandshakeComplete yet.
// It's advantageous to process ACK frames that might be serialized after the CRYPTO frame first.
s.handshakeComplete = true
case handshake.EventReceivedTransportParameters:
err = s.handleTransportParameters(ev.TransportParameters)
case handshake.EventRestoredTransportParameters:
Expand Down Expand Up @@ -1488,6 +1504,9 @@ func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.Encr
if !acked1RTTPacket {
return nil
}
// On the client side: If the packet acknowledged a 1-RTT packet, this confirms the handshake.
// This is only possible if the ACK was sent in a 1-RTT packet.
// This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001.
if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed {
if err := s.handleHandshakeConfirmed(); err != nil {
return err
Expand Down Expand Up @@ -1659,6 +1678,9 @@ func (s *connection) restoreTransportParameters(params *wire.TransportParameters
}

func (s *connection) handleTransportParameters(params *wire.TransportParameters) error {
if s.tracer != nil {
s.tracer.ReceivedTransportParameters(params)
}
if err := s.checkTransportParameters(params); err != nil {
return &qerr.TransportError{
ErrorCode: qerr.TransportParameterError,
Expand Down Expand Up @@ -1693,9 +1715,6 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters)
if s.logger.Debug() {
s.logger.Debugf("Processed Transport Parameters: %s", params)
}
if s.tracer != nil {
s.tracer.ReceivedTransportParameters(params)
}

// check the initial_source_connection_id
if params.InitialSourceConnectionID != s.handshakeDestConnID {
Expand Down Expand Up @@ -1724,6 +1743,7 @@ func (s *connection) checkTransportParameters(params *wire.TransportParameters)

func (s *connection) applyTransportParameters() {
params := s.peerParams
fmt.Println("apply transport parameters", s.config.MaxIdleTimeout, params.MaxIdleTimeout)
// Our local idle timeout will always be > 0.
s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout)
s.keepAliveInterval = utils.Min(s.config.KeepAlivePeriod, utils.Min(s.idleTimeout/2, protocol.MaxKeepAliveInterval))
Expand Down
57 changes: 12 additions & 45 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,6 @@ var _ = Describe("Connection", func() {

It("cancels the HandshakeComplete context when the handshake completes", func() {
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes()
finishHandshake := make(chan struct{})
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
conn.sentPacketHandler = sph
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
Expand All @@ -1901,53 +1900,26 @@ var _ = Describe("Connection", func() {
sph.EXPECT().DropPackets(protocol.EncryptionHandshake)
sph.EXPECT().SetHandshakeConfirmed()
connRunner.EXPECT().Retire(clientDestConnID)
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
conn.run()
}()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx).ShouldNot(BeClosed())
close(finishHandshake)
Expect(conn.handleHandshakeComplete()).To(Succeed())
Eventually(handshakeCtx).Should(BeClosed())
// make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed())
})

It("sends a session ticket when the handshake completes", func() {
const size = protocol.MaxPostHandshakeCryptoFrameSize * 3 / 2
packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), conn.version).AnyTimes()
finishHandshake := make(chan struct{})
connRunner.EXPECT().Retire(clientDestConnID)
conn.sentPacketHandler.DropPackets(protocol.EncryptionInitial)
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionHandshake)
go func() {
defer GinkgoRecover()
<-finishHandshake
cryptoSetup.EXPECT().StartHandshake()
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)
conn.run()
}()
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil)

handshakeCtx := conn.HandshakeComplete()
Consistently(handshakeCtx).ShouldNot(BeClosed())
close(finishHandshake)
Expect(conn.handleHandshakeComplete()).To(Succeed())
var frames []ackhandler.Frame
Eventually(func() []ackhandler.Frame {
frames, _ = conn.framer.AppendControlFrames(nil, protocol.MaxByteCount, protocol.Version1)
Expand All @@ -1963,16 +1935,6 @@ var _ = Describe("Connection", func() {
}
}
Expect(size).To(BeEquivalentTo(s))
// make sure the go routine returns
streamManager.EXPECT().CloseWithError(gomock.Any())
expectReplaceWithClosed()
packer.EXPECT().PackApplicationClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
cryptoSetup.EXPECT().Close()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
conn.shutdown()
Eventually(conn.Context().Done()).Should(BeClosed())
})

It("doesn't cancel the HandshakeComplete context when the handshake fails", func() {
Expand Down Expand Up @@ -2027,6 +1989,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().SetHandshakeConfirmed()
cryptoSetup.EXPECT().GetSessionTicket()
mconn.EXPECT().Write(gomock.Any(), gomock.Any())
Expect(conn.handleHandshakeComplete()).To(Succeed())
conn.run()
}()
Eventually(done).Should(BeClosed())
Expand Down Expand Up @@ -2350,6 +2313,7 @@ var _ = Describe("Connection", func() {
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1)
cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1)
Expect(conn.handleHandshakeComplete()).To(Succeed())
err := conn.run()
nerr, ok := err.(net.Error)
Expect(ok).To(BeTrue())
Expand Down Expand Up @@ -2867,7 +2831,10 @@ var _ = Describe("Client Connection", func() {
TransportParameters: params,
})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}).MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}).MaxTimes(1)
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}).MaxTimes(1).Do(func() {
defer GinkgoRecover()
Expect(conn.handleHandshakeComplete()).To(Succeed())
})
errChan <- conn.run()
close(errChan)
}()
Expand Down

0 comments on commit c239066

Please sign in to comment.