diff --git a/.circleci/config.yml b/.circleci/config.yml index 2ea17cc7c68..e39fa12f527 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,11 +1,5 @@ version: 2.1 executors: - test-go119: - docker: - - image: "cimg/go:1.19" - environment: - runrace: true - TIMESCALE_FACTOR: 3 test-go120: docker: - image: "cimg/go:1.20" @@ -15,7 +9,7 @@ executors: jobs: "test": &test - executor: test-go119 + executor: test-go120 steps: - checkout - run: @@ -39,14 +33,10 @@ jobs: - run: name: "Run version negotiation tests with qlog" command: go run github.com/onsi/ginkgo/v2/ginkgo -v -randomize-all -trace integrationtests/versionnegotiation -- -qlog - go119: - <<: *test go120: <<: *test - executor: test-go120 workflows: workflow: jobs: - - go119 - go120 diff --git a/.github/workflows/cross-compile.yml b/.github/workflows/cross-compile.yml index e9f9211fb89..e08558210b5 100644 --- a/.github/workflows/cross-compile.yml +++ b/.github/workflows/cross-compile.yml @@ -4,7 +4,7 @@ jobs: strategy: fail-fast: false matrix: - go: [ "1.19.x", "1.20.x" ] + go: [ "1.20.x", "1.21.0-rc.2" ] runs-on: ${{ fromJSON(vars['CROSS_COMPILE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} name: "Cross Compilation (Go ${{matrix.go}})" steps: diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index b973ef9ccc9..e73883ddcf1 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -5,7 +5,7 @@ jobs: strategy: fail-fast: false matrix: - go: [ "1.19.x", "1.20.x" ] + go: [ "1.20.x", "1.21.0-rc.2" ] runs-on: ${{ fromJSON(vars['INTEGRATION_RUNNER_UBUNTU'] || '"ubuntu-latest"') }} env: DEBUG: false # set this to true to export qlogs and save them as artifacts diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index eb4ce725655..f436d74834e 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -7,7 +7,7 @@ jobs: fail-fast: false matrix: os: [ "ubuntu", "windows", "macos" ] - go: [ "1.19.x", "1.20.x" ] + go: [ "1.20.x", "1.21.0-rc.2" ] runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }} name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }}) steps: diff --git a/.golangci.yml b/.golangci.yml index 7820be8c975..1315759bc1f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,4 +1,6 @@ run: + skip-files: + - internal/handshake/cipher_suite.go linters-settings: depguard: type: blacklist diff --git a/README.md b/README.md index ad29d8650f9..3cc6b5a9c43 100644 --- a/README.md +++ b/README.md @@ -220,7 +220,8 @@ quic-go always aims to support the latest two Go releases. ### Dependency on forked crypto/tls -Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20) and [qtls for Go 1.19](https://github.com/quic-go/qtls-go1-19). This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward. +Since the standard library didn't provide any QUIC APIs before the Go 1.21 release, we had to fork crypto/tls to add the required APIs ourselves: [qtls for Go 1.20](https://github.com/quic-go/qtls-go1-20). +This had led to a lot of pain in the Go ecosystem, and we're happy that we can rely on Go 1.21 going forward. ## Contributing diff --git a/codecov.yml b/codecov.yml index 074d9832523..69435150974 100644 --- a/codecov.yml +++ b/codecov.yml @@ -8,6 +8,7 @@ coverage: - http3/gzip_reader.go - interop/ - internal/ackhandler/packet_linkedlist.go + - internal/handshake/cipher_suite.go - internal/utils/byteinterval_linkedlist.go - internal/utils/newconnectionid_linkedlist.go - internal/utils/packetinterval_linkedlist.go diff --git a/connection.go b/connection.go index c544b591af9..81863b34e83 100644 --- a/connection.go +++ b/connection.go @@ -52,7 +52,7 @@ type streamManager interface { } type cryptoStreamHandler interface { - RunHandshake() + StartHandshake() error ChangeConnectionID(protocol.ConnectionID) SetLargest1RTTAcked(protocol.PacketNumber) error SetHandshakeConfirmed() @@ -98,15 +98,15 @@ type connRunner interface { type handshakeRunner struct { onReceivedParams func(*wire.TransportParameters) - onError func(error) + onReceivedReadKeys func() dropKeys func(protocol.EncryptionLevel) onHandshakeComplete func() } func (r *handshakeRunner) OnReceivedParams(tp *wire.TransportParameters) { r.onReceivedParams(tp) } -func (r *handshakeRunner) OnError(e error) { r.onError(e) } func (r *handshakeRunner) DropKeys(el protocol.EncryptionLevel) { r.dropKeys(el) } func (r *handshakeRunner) OnHandshakeComplete() { r.onHandshakeComplete() } +func (r *handshakeRunner) OnReceivedReadKeys() { r.onReceivedReadKeys() } type closeError struct { err error @@ -329,14 +329,13 @@ var newConnection = func( cs := handshake.NewCryptoSetupServer( initialStream, handshakeStream, + s.oneRTTStream, clientDestConnID, - conn.LocalAddr(), - conn.RemoteAddr(), params, &handshakeRunner{ - onReceivedParams: s.handleTransportParameters, - onError: s.closeLocal, - dropKeys: s.dropEncryptionLevel, + onReceivedParams: s.handleTransportParameters, + dropKeys: s.dropEncryptionLevel, + onReceivedReadKeys: s.receivedReadKeys, onHandshakeComplete: func() { runner.Retire(clientDestConnID) close(s.handshakeCompleteChan) @@ -418,6 +417,7 @@ var newClientConnection = func( s.mtuDiscoverer = newMTUDiscoverer(s.rttStats, getMaxPacketSize(s.conn.RemoteAddr()), s.sentPacketHandler.SetMaxDatagramSize) initialStream := newCryptoStream() handshakeStream := newCryptoStream() + oneRTTStream := newCryptoStream() params := &wire.TransportParameters{ InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow), InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow), @@ -448,14 +448,13 @@ var newClientConnection = func( cs, clientHelloWritten := handshake.NewCryptoSetupClient( initialStream, handshakeStream, + oneRTTStream, destConnID, - conn.LocalAddr(), - conn.RemoteAddr(), params, &handshakeRunner{ onReceivedParams: s.handleTransportParameters, - onError: s.closeLocal, dropKeys: s.dropEncryptionLevel, + onReceivedReadKeys: s.receivedReadKeys, onHandshakeComplete: func() { close(s.handshakeCompleteChan) }, }, tlsConf, @@ -467,7 +466,7 @@ var newClientConnection = func( ) s.clientHelloWritten = clientHelloWritten s.cryptoStreamHandler = cs - s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, newCryptoStream()) + s.cryptoStreamManager = newCryptoStreamManager(cs, initialStream, handshakeStream, oneRTTStream) s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen) s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, initialStream, handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective) if len(tlsConf.ServerName) > 0 { @@ -530,11 +529,9 @@ func (s *connection) run() error { s.timer = *newTimer() - handshaking := make(chan struct{}) - go func() { - defer close(handshaking) - s.cryptoStreamHandler.RunHandshake() - }() + if err := s.cryptoStreamHandler.StartHandshake(); err != nil { + return err + } go func() { if err := s.sendQueue.Run(); err != nil { s.destroyImpl(err) @@ -686,7 +683,6 @@ runLoop: } s.cryptoStreamHandler.Close() - <-handshaking s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE s.handleCloseError(&closeErr) if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) && s.tracer != nil { @@ -717,7 +713,9 @@ func (s *connection) supportsDatagrams() bool { func (s *connection) ConnectionState() ConnectionState { s.connStateMutex.Lock() defer s.connStateMutex.Unlock() - s.connState.TLS = s.cryptoStreamHandler.ConnectionState() + cs := s.cryptoStreamHandler.ConnectionState() + s.connState.TLS = cs.ConnectionState + s.connState.Used0RTT = cs.Used0RTT return s.connState } @@ -786,7 +784,7 @@ func (s *connection) handleHandshakeComplete() { if err != nil { s.closeLocal(err) } - if ticket != nil { + if ticket != nil { // may be nil if session tickets are disabled via tls.Config.SessionTicketsDisabled s.oneRTTStream.Write(ticket) for s.oneRTTStream.HasData() { s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize)) @@ -1378,16 +1376,13 @@ func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame } func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { - encLevelChanged, err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) - if err != nil { - return err - } - if encLevelChanged { - // Queue all packets for decryption that have been undecryptable so far. - s.undecryptablePacketsToProcess = s.undecryptablePackets - s.undecryptablePackets = nil - } - return nil + return s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel) +} + +func (s *connection) receivedReadKeys() { + // Queue all packets for decryption that have been undecryptable so far. + s.undecryptablePacketsToProcess = s.undecryptablePackets + s.undecryptablePackets = nil } func (s *connection) handleStreamFrame(frame *wire.StreamFrame) error { @@ -1629,11 +1624,15 @@ func (s *connection) handleCloseError(closeErr *closeError) { } func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel) { - s.sentPacketHandler.DropPackets(encLevel) - s.receivedPacketHandler.DropPackets(encLevel) if s.tracer != nil { s.tracer.DroppedEncryptionLevel(encLevel) } + s.sentPacketHandler.DropPackets(encLevel) + s.receivedPacketHandler.DropPackets(encLevel) + if err := s.cryptoStreamManager.Drop(encLevel); err != nil { + s.closeLocal(err) + return + } if encLevel == protocol.Encryption0RTT { s.streamsMap.ResetFor0RTT() if err := s.connFlowController.Reset(); err != nil { @@ -1817,6 +1816,9 @@ func (s *connection) sendPackets(now time.Time) error { s.framer.QueueControlFrame(&wire.DataBlockedFrame{MaximumData: offset}) } s.windowUpdateQueue.QueueAll() + if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil { + s.queueControlFrame(cf) + } if !s.handshakeConfirmed { packet, err := s.packer.PackCoalescedPacket(false, s.mtuDiscoverer.CurrentSize(), s.version) diff --git a/connection_test.go b/connection_test.go index 177aa44143f..c8e7edb155b 100644 --- a/connection_test.go +++ b/connection_test.go @@ -119,7 +119,7 @@ var _ = Describe("Connection", func() { &protocol.DefaultConnectionIDGenerator{}, protocol.StatelessResetToken{}, populateServerConfig(&Config{DisablePathMTUDiscovery: true}), - nil, // tls.Config + &tls.Config{}, tokenGenerator, false, tracer, @@ -357,7 +357,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) Expect(conn.run()).To(MatchError(expectedErr)) }() Expect(conn.handleFrame(&wire.ConnectionCloseFrame{ @@ -385,7 +385,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) Expect(conn.run()).To(MatchError(testErr)) }() ccf := &wire.ConnectionCloseFrame{ @@ -432,7 +432,7 @@ var _ = Describe("Connection", func() { runConn := func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) runErr <- conn.run() }() Eventually(areConnsRunning).Should(BeTrue()) @@ -811,7 +811,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), conn.version).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() expectReplaceWithClosed() @@ -853,7 +853,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -888,7 +888,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -913,7 +913,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) @@ -937,7 +937,7 @@ var _ = Describe("Connection", func() { runErr := make(chan error) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) runErr <- conn.run() }() expectReplaceWithClosed() @@ -961,7 +961,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() Expect(err).To(HaveOccurred()) Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) @@ -1197,7 +1197,7 @@ var _ = Describe("Connection", func() { runConn := func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() close(connDone) }() @@ -1415,7 +1415,7 @@ var _ = Describe("Connection", func() { }) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1439,7 +1439,7 @@ var _ = Describe("Connection", func() { }) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1463,7 +1463,7 @@ var _ = Describe("Connection", func() { }) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1479,7 +1479,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1496,7 +1496,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1514,7 +1514,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any(), gomock.Any()) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1540,7 +1540,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(2) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1562,7 +1562,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }).Times(3) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1580,7 +1580,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Available().Return(available) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1602,7 +1602,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().WouldBlock().AnyTimes() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() @@ -1633,7 +1633,7 @@ var _ = Describe("Connection", func() { sender.EXPECT().Send(gomock.Any(), gomock.Any()).DoAndReturn(func(*packetBuffer, protocol.ByteCount) { written <- struct{}{} }) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() available := make(chan struct{}, 1) @@ -1664,7 +1664,7 @@ var _ = Describe("Connection", func() { // don't EXPECT any calls to mconn.Write() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() // no packet will get sent @@ -1687,7 +1687,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().PackMTUProbePacket(ping, protocol.ByteCount(1234), conn.version).Return(shortHeaderPacket{PacketNumber: 1}, getPacketBuffer(), nil) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() conn.scheduleSending() @@ -1734,7 +1734,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() // don't EXPECT any calls to mconn.Write() @@ -1768,7 +1768,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Eventually(written).Should(BeClosed()) @@ -1832,7 +1832,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() @@ -1864,7 +1864,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() <-finishHandshake - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() close(conn.handshakeCompleteChan) @@ -1894,7 +1894,7 @@ var _ = Describe("Connection", func() { go func() { defer GinkgoRecover() <-finishHandshake - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket().Return(make([]byte, size), nil) close(conn.handshakeCompleteChan) @@ -1941,7 +1941,7 @@ var _ = Describe("Connection", func() { tracer.EXPECT().Close() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() conn.run() }() handshakeCtx := conn.HandshakeComplete() @@ -1974,7 +1974,7 @@ var _ = Describe("Connection", func() { packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), conn.version).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes() go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake() + cryptoSetup.EXPECT().StartHandshake() cryptoSetup.EXPECT().SetHandshakeConfirmed() cryptoSetup.EXPECT().GetSessionTicket() mconn.EXPECT().Write(gomock.Any(), gomock.Any()) @@ -1997,7 +1997,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) Expect(conn.run()).To(Succeed()) close(done) }() @@ -2017,7 +2017,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() Expect(err).To(MatchError(&qerr.ApplicationError{ ErrorCode: 0x1337, @@ -2069,7 +2069,7 @@ var _ = Describe("Connection", func() { runConn := func() { go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() } @@ -2171,7 +2171,7 @@ var _ = Describe("Connection", func() { ) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2196,7 +2196,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) Expect(ok).To(BeTrue()) @@ -2229,7 +2229,7 @@ var _ = Describe("Connection", func() { // and not on the last network activity go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -2256,7 +2256,7 @@ var _ = Describe("Connection", func() { conn.handshakeComplete = false go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) err := conn.run() nerr, ok := err.(net.Error) @@ -2285,7 +2285,7 @@ var _ = Describe("Connection", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) close(conn.handshakeCompleteChan) @@ -2305,7 +2305,7 @@ var _ = Describe("Connection", func() { conn.idleTimeout = 30 * time.Second go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() Consistently(conn.Context().Done()).ShouldNot(BeClosed()) @@ -2336,7 +2336,7 @@ var _ = Describe("Connection", func() { pto := conn.rttStats.PTO(true) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) cryptoSetup.EXPECT().GetSessionTicket().MaxTimes(1) cryptoSetup.EXPECT().SetHandshakeConfirmed().MaxTimes(1) close(conn.handshakeCompleteChan) @@ -2508,7 +2508,7 @@ var _ = Describe("Client Connection", func() { conn.unpacker = unpacker go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) conn.run() }() newConnID := protocol.ParseConnectionID([]byte{1, 3, 3, 7, 1, 3, 3, 7}) @@ -2588,7 +2588,7 @@ var _ = Describe("Client Connection", func() { tracer.EXPECT().ClosedConnection(gomock.Any()) tracer.EXPECT().Close() running := make(chan struct{}) - cryptoSetup.EXPECT().RunHandshake().Do(func() { + cryptoSetup.EXPECT().StartHandshake().Do(func() { close(running) conn.closeLocal(errors.New("early error")) }) @@ -2641,7 +2641,7 @@ var _ = Describe("Client Connection", func() { errChan := make(chan error, 1) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID) @@ -2666,7 +2666,7 @@ var _ = Describe("Client Connection", func() { errChan := make(chan error, 1) go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) errChan <- conn.run() }() connRunner.EXPECT().Remove(srcConnID).MaxTimes(1) @@ -2774,7 +2774,7 @@ var _ = Describe("Client Connection", func() { closed = false go func() { defer GinkgoRecover() - cryptoSetup.EXPECT().RunHandshake().MaxTimes(1) + cryptoSetup.EXPECT().StartHandshake().MaxTimes(1) errChan <- conn.run() close(errChan) }() diff --git a/crypto_stream.go b/crypto_stream.go index f10e91202fa..4be2a07ae1a 100644 --- a/crypto_stream.go +++ b/crypto_stream.go @@ -71,17 +71,9 @@ func (s *cryptoStreamImpl) HandleCryptoFrame(f *wire.CryptoFrame) error { // GetCryptoData retrieves data that was received in CRYPTO frames func (s *cryptoStreamImpl) GetCryptoData() []byte { - if len(s.msgBuf) < 4 { - return nil - } - msgLen := 4 + int(s.msgBuf[1])<<16 + int(s.msgBuf[2])<<8 + int(s.msgBuf[3]) - if len(s.msgBuf) < msgLen { - return nil - } - msg := make([]byte, msgLen) - copy(msg, s.msgBuf[:msgLen]) - s.msgBuf = s.msgBuf[msgLen:] - return msg + b := s.msgBuf + s.msgBuf = nil + return b } func (s *cryptoStreamImpl) Finish() error { diff --git a/crypto_stream_manager.go b/crypto_stream_manager.go index 91946acfa52..8961965d072 100644 --- a/crypto_stream_manager.go +++ b/crypto_stream_manager.go @@ -8,7 +8,7 @@ import ( ) type cryptoDataHandler interface { - HandleMessage([]byte, protocol.EncryptionLevel) bool + HandleMessage([]byte, protocol.EncryptionLevel) error } type cryptoStreamManager struct { @@ -33,7 +33,7 @@ func newCryptoStreamManager( } } -func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) (bool /* encryption level changed */, error) { +func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error { var str cryptoStream //nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets. switch encLevel { @@ -44,18 +44,39 @@ func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLeve case protocol.Encryption1RTT: str = m.oneRTTStream default: - return false, fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) + return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel) } if err := str.HandleCryptoFrame(frame); err != nil { - return false, err + return err } for { data := str.GetCryptoData() if data == nil { - return false, nil + return nil } - if encLevelFinished := m.cryptoHandler.HandleMessage(data, encLevel); encLevelFinished { - return true, str.Finish() + if err := m.cryptoHandler.HandleMessage(data, encLevel); err != nil { + return err } } } + +func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame { + if !m.oneRTTStream.HasData() { + return nil + } + return m.oneRTTStream.PopCryptoFrame(maxSize) +} + +func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error { + //nolint:exhaustive // 1-RTT keys should never get dropped. + switch encLevel { + case protocol.EncryptionInitial: + return m.initialStream.Finish() + case protocol.EncryptionHandshake: + return m.handshakeStream.Finish() + case protocol.Encryption0RTT: + return nil + default: + panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel)) + } +} diff --git a/crypto_stream_manager_test.go b/crypto_stream_manager_test.go index e2f46c8cc97..c5b59b92780 100644 --- a/crypto_stream_manager_test.go +++ b/crypto_stream_manager_test.go @@ -1,12 +1,9 @@ package quic import ( - "errors" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/wire" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) @@ -35,9 +32,7 @@ var _ = Describe("Crypto Stream Manager", func() { initialStream.EXPECT().GetCryptoData().Return([]byte("foobar")) initialStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionInitial) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionInitial) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionInitial)).To(Succeed()) }) It("passes messages to the handshake stream", func() { @@ -46,9 +41,7 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")) handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("passes messages to the 1-RTT stream", func() { @@ -57,9 +50,7 @@ var _ = Describe("Crypto Stream Manager", func() { oneRTTStream.EXPECT().GetCryptoData().Return([]byte("foobar")) oneRTTStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foobar"), protocol.Encryption1RTT) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.Encryption1RTT) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.Encryption1RTT)).To(Succeed()) }) It("doesn't call the message handler, if there's no message", func() { @@ -67,9 +58,7 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().HandleCryptoFrame(cf) handshakeStream.EXPECT().GetCryptoData() // don't return any data to handle // don't EXPECT any calls to HandleMessage() - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("processes all messages", func() { @@ -80,39 +69,11 @@ var _ = Describe("Crypto Stream Manager", func() { handshakeStream.EXPECT().GetCryptoData() cs.EXPECT().HandleMessage([]byte("foo"), protocol.EncryptionHandshake) cs.EXPECT().HandleMessage([]byte("bar"), protocol.EncryptionHandshake) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeFalse()) - }) - - It("finishes the crypto stream, when the crypto setup is done with this encryption level", func() { - cf := &wire.CryptoFrame{Data: []byte("foobar")} - gomock.InOrder( - handshakeStream.EXPECT().HandleCryptoFrame(cf), - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), - handshakeStream.EXPECT().Finish(), - ) - encLevelChanged, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).ToNot(HaveOccurred()) - Expect(encLevelChanged).To(BeTrue()) - }) - - It("returns errors that occur when finishing a stream", func() { - testErr := errors.New("test error") - cf := &wire.CryptoFrame{Data: []byte("foobar")} - gomock.InOrder( - handshakeStream.EXPECT().HandleCryptoFrame(cf), - handshakeStream.EXPECT().GetCryptoData().Return([]byte("foobar")), - cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake).Return(true), - handshakeStream.EXPECT().Finish().Return(testErr), - ) - _, err := csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake) - Expect(err).To(MatchError(err)) + Expect(csm.HandleCryptoFrame(cf, protocol.EncryptionHandshake)).To(Succeed()) }) It("errors for unknown encryption levels", func() { - _, err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42) + err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, 42) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("received CRYPTO frame with unexpected encryption level")) }) diff --git a/crypto_stream_test.go b/crypto_stream_test.go index 100498ebe94..9a4a2ee57f9 100644 --- a/crypto_stream_test.go +++ b/crypto_stream_test.go @@ -1,7 +1,6 @@ package quic import ( - "crypto/rand" "fmt" "github.com/quic-go/quic-go/internal/protocol" @@ -12,16 +11,6 @@ import ( . "github.com/onsi/gomega" ) -func createHandshakeMessage(len int) []byte { - msg := make([]byte, 4+len) - rand.Read(msg[:1]) // random message type - msg[1] = uint8(len >> 16) - msg[2] = uint8(len >> 8) - msg[3] = uint8(len) - rand.Read(msg[4:]) - return msg -} - var _ = Describe("Crypto Stream", func() { var str cryptoStream @@ -31,21 +20,11 @@ var _ = Describe("Crypto Stream", func() { Context("handling incoming data", func() { It("handles in-order CRYPTO frames", func() { - msg := createHandshakeMessage(6) - err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal([]byte("foo"))) Expect(str.GetCryptoData()).To(BeNil()) - }) - - It("handles multiple messages in one CRYPTO frame", func() { - msg1 := createHandshakeMessage(6) - msg2 := createHandshakeMessage(10) - msg := append(append([]byte{}, msg1...), msg2...) - err := str.HandleCryptoFrame(&wire.CryptoFrame{Data: msg}) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg1)) - Expect(str.GetCryptoData()).To(Equal(msg2)) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("bar"), Offset: 3})).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal([]byte("bar"))) Expect(str.GetCryptoData()).To(BeNil()) }) @@ -59,42 +38,17 @@ var _ = Describe("Crypto Stream", func() { })) }) - It("handles messages split over multiple CRYPTO frames", func() { - msg := createHandshakeMessage(6) - err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: msg[:4], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(BeNil()) - err = str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: 4, - Data: msg[4:], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg)) - Expect(str.GetCryptoData()).To(BeNil()) - }) - It("handles out-of-order CRYPTO frames", func() { - msg := createHandshakeMessage(6) - err := str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: 4, - Data: msg[4:], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(BeNil()) - err = str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: msg[:4], - }) - Expect(err).ToNot(HaveOccurred()) - Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Offset: 3, Data: []byte("bar")})).To(Succeed()) + Expect(str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")})).To(Succeed()) + Expect(str.GetCryptoData()).To(Equal([]byte("foobar"))) Expect(str.GetCryptoData()).To(BeNil()) }) Context("finishing", func() { It("errors if there's still data to read after finishing", func() { Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: createHandshakeMessage(5), + Data: []byte("foobar"), Offset: 10, })).To(Succeed()) Expect(str.Finish()).To(MatchError(&qerr.TransportError{ @@ -120,7 +74,7 @@ var _ = Describe("Crypto Stream", func() { It("rejects new crypto data after finishing", func() { Expect(str.Finish()).To(Succeed()) Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: createHandshakeMessage(5), + Data: []byte("foo"), })).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.ProtocolViolation, ErrorMessage: "received crypto data after change of encryption level", @@ -128,15 +82,14 @@ var _ = Describe("Crypto Stream", func() { }) It("ignores crypto data below the maximum offset received before finishing", func() { - msg := createHandshakeMessage(15) Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Data: msg, + Data: []byte("foobar"), })).To(Succeed()) - Expect(str.GetCryptoData()).To(Equal(msg)) + Expect(str.GetCryptoData()).To(Equal([]byte("foobar"))) Expect(str.Finish()).To(Succeed()) Expect(str.HandleCryptoFrame(&wire.CryptoFrame{ - Offset: protocol.ByteCount(len(msg) - 6), - Data: []byte("foobar"), + Offset: 2, + Data: []byte("foo"), })).To(Succeed()) }) }) diff --git a/fuzzing/handshake/cmd/corpus.go b/fuzzing/handshake/cmd/corpus.go index 1142bea617c..3963fc1d418 100644 --- a/fuzzing/handshake/cmd/corpus.go +++ b/fuzzing/handshake/cmd/corpus.go @@ -43,26 +43,24 @@ func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) type handshakeRunner interface { OnReceivedParams(*wire.TransportParameters) OnHandshakeComplete() - OnError(error) + OnReceivedReadKeys() DropKeys(protocol.EncryptionLevel) } type runner struct { - client, server *handshake.CryptoSetup + handshakeComplete chan<- struct{} } var _ handshakeRunner = &runner{} -func newRunner(client, server *handshake.CryptoSetup) *runner { - return &runner{client: client, server: server} +func newRunner(handshakeComplete chan<- struct{}) *runner { + return &runner{handshakeComplete: handshakeComplete} } func (r *runner) OnReceivedParams(*wire.TransportParameters) {} -func (r *runner) OnHandshakeComplete() {} -func (r *runner) OnError(err error) { - (*r.client).Close() - (*r.server).Close() - log.Fatal("runner error:", err) +func (r *runner) OnReceivedReadKeys() {} +func (r *runner) OnHandshakeComplete() { + close(r.handshakeComplete) } func (r *runner) DropKeys(protocol.EncryptionLevel) {} @@ -71,16 +69,16 @@ const alpn = "fuzz" func main() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() var client, server handshake.CryptoSetup - runner := newRunner(&client, &server) + clientHandshakeCompleted := make(chan struct{}) client, _ = handshake.NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, - runner, + newRunner(clientHandshakeCompleted), &tls.Config{ + MinVersion: tls.VersionTLS13, ServerName: "localhost", NextProtos: []string{alpn}, RootCAs: testdata.GetRootCA(), @@ -96,14 +94,14 @@ func main() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() config := testdata.GetTLSConfig() config.NextProtos = []string{alpn} + serverHandshakeCompleted := make(chan struct{}) server = handshake.NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, - runner, + newRunner(serverHandshakeCompleted), config, false, utils.NewRTTStats(), @@ -112,17 +110,13 @@ func main() { protocol.Version1, ) - serverHandshakeCompleted := make(chan struct{}) - go func() { - defer close(serverHandshakeCompleted) - server.RunHandshake() - }() + if err := client.StartHandshake(); err != nil { + log.Fatal(err) + } - clientHandshakeCompleted := make(chan struct{}) - go func() { - defer close(clientHandshakeCompleted) - client.RunHandshake() - }() + if err := server.StartHandshake(); err != nil { + log.Fatal(err) + } done := make(chan struct{}) go func() { @@ -137,10 +131,14 @@ messageLoop: select { case c := <-cChunkChan: messages = append(messages, c.data) - server.HandleMessage(c.data, c.encLevel) + if err := server.HandleMessage(c.data, c.encLevel); err != nil { + log.Fatal(err) + } case c := <-sChunkChan: messages = append(messages, c.data) - client.HandleMessage(c.data, c.encLevel) + if err := client.HandleMessage(c.data, c.encLevel); err != nil { + log.Fatal(err) + } case <-done: break messageLoop } diff --git a/fuzzing/handshake/fuzz.go b/fuzzing/handshake/fuzz.go index 2d73e6056bf..5c8fcb1b5ad 100644 --- a/fuzzing/handshake/fuzz.go +++ b/fuzzing/handshake/fuzz.go @@ -11,7 +11,6 @@ import ( "log" "math" mrand "math/rand" - "sync" "time" "github.com/quic-go/quic-go/fuzzing/internal/helper" @@ -157,39 +156,24 @@ func initStreams() (chan chunk, *stream /* initial */, *stream /* handshake */) type handshakeRunner interface { OnReceivedParams(*wire.TransportParameters) OnHandshakeComplete() - OnError(error) + OnReceivedReadKeys() DropKeys(protocol.EncryptionLevel) } type runner struct { - sync.Mutex - errored bool - client, server *handshake.CryptoSetup + handshakeComplete chan<- struct{} } var _ handshakeRunner = &runner{} -func newRunner(client, server *handshake.CryptoSetup) *runner { - return &runner{client: client, server: server} +func newRunner(handshakeComplete chan<- struct{}) *runner { + return &runner{handshakeComplete: handshakeComplete} } func (r *runner) OnReceivedParams(*wire.TransportParameters) {} -func (r *runner) OnHandshakeComplete() {} -func (r *runner) OnError(err error) { - r.Lock() - defer r.Unlock() - if r.errored { - return - } - r.errored = true - (*r.client).Close() - (*r.server).Close() -} - -func (r *runner) Errored() bool { - r.Lock() - defer r.Unlock() - return r.errored +func (r *runner) OnReceivedReadKeys() {} +func (r *runner) OnHandshakeComplete() { + close(r.handshakeComplete) } func (r *runner) DropKeys(protocol.EncryptionLevel) {} @@ -270,6 +254,7 @@ func Fuzz(data []byte) int { } clientConf := &tls.Config{ + MinVersion: tls.VersionTLS13, ServerName: "localhost", NextProtos: []string{alpn}, RootCAs: certPool, @@ -287,6 +272,7 @@ func Fuzz(data []byte) int { func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls.Config, data []byte) int { serverConf := &tls.Config{ + MinVersion: tls.VersionTLS13, Certificates: []tls.Certificate{*cert}, NextProtos: []string{alpn}, SessionTicketKey: sessionTicketKey, @@ -373,15 +359,14 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. cChunkChan, cInitialStream, cHandshakeStream := initStreams() var client, server handshake.CryptoSetup - runner := newRunner(&client, &server) + clientHandshakeCompleted := make(chan struct{}) client, _ = handshake.NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, clientTP, - runner, + newRunner(clientHandshakeCompleted), clientConf, enable0RTTClient, utils.NewRTTStats(), @@ -390,15 +375,15 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. protocol.Version1, ) + serverHandshakeCompleted := make(chan struct{}) sChunkChan, sInitialStream, sHandshakeStream := initStreams() server = handshake.NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, serverTP, - runner, + newRunner(serverHandshakeCompleted), serverConf, enable0RTTServer, utils.NewRTTStats(), @@ -411,17 +396,13 @@ func runHandshake(runConfig [confLen]byte, messageConfig uint8, clientConf *tls. return -1 } - serverHandshakeCompleted := make(chan struct{}) - go func() { - defer close(serverHandshakeCompleted) - server.RunHandshake() - }() + if err := client.StartHandshake(); err != nil { + log.Fatal(err) + } - clientHandshakeCompleted := make(chan struct{}) - go func() { - defer close(clientHandshakeCompleted) - client.RunHandshake() - }() + if err := server.StartHandshake(); err != nil { + log.Fatal(err) + } done := make(chan struct{}) go func() { @@ -441,7 +422,9 @@ messageLoop: b = data encLevel = maxEncLevel(server, messageToReplaceEncLevel) } - server.HandleMessage(b, encLevel) + if err := server.HandleMessage(b, encLevel); err != nil { + break messageLoop + } case c := <-sChunkChan: b := c.data encLevel := c.encLevel @@ -450,21 +433,17 @@ messageLoop: b = data encLevel = maxEncLevel(client, messageToReplaceEncLevel) } - client.HandleMessage(b, encLevel) + if err := client.HandleMessage(b, encLevel); err != nil { + break messageLoop + } case <-done: // test done break messageLoop } - if runner.Errored() { - break messageLoop - } } <-done _ = client.ConnectionState() _ = server.ConnectionState() - if runner.Errored() { - return 0 - } sealer, err := client.Get1RTTSealer() if err != nil { diff --git a/go.mod b/go.mod index 9cdfd0d6f65..88caed06928 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,7 @@ require ( github.com/onsi/ginkgo/v2 v2.9.5 github.com/onsi/gomega v1.27.6 github.com/quic-go/qpack v0.4.0 - github.com/quic-go/qtls-go1-19 v0.3.2 - github.com/quic-go/qtls-go1-20 v0.2.2 + github.com/quic-go/qtls-go1-20 v0.0.0-20230530152335-bf7bdeef8902 golang.org/x/crypto v0.4.0 golang.org/x/exp v0.0.0-20221205204356-47842c84f3db golang.org/x/net v0.10.0 diff --git a/go.sum b/go.sum index 3a777327959..88e0a6e06af 100644 --- a/go.sum +++ b/go.sum @@ -90,10 +90,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= -github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= -github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/qtls-go1-20 v0.0.0-20230530152335-bf7bdeef8902 h1:+ir8isKnADr1GYr/DmIg1NJ/ncNu2arZScXPDuRSC48= +github.com/quic-go/qtls-go1-20 v0.0.0-20230530152335-bf7bdeef8902/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= diff --git a/http3/client.go b/http3/client.go index c54de5ea959..b0e6546cc48 100644 --- a/http3/client.go +++ b/http3/client.go @@ -15,7 +15,6 @@ import ( "github.com/quic-go/quic-go" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/quicvarint" @@ -402,7 +401,7 @@ func (c *client) doRequest(req *http.Request, conn quic.EarlyConnection, str qui return nil, newConnError(ErrCodeGeneralProtocolError, err) } - connState := qtls.ToTLSConnectionState(conn.ConnectionState().TLS) + connState := conn.ConnectionState().TLS res := &http.Response{ Proto: "HTTP/3.0", ProtoMajor: 3, diff --git a/http3/server.go b/http3/server.go index d77501824eb..f2027429d66 100644 --- a/http3/server.go +++ b/http3/server.go @@ -577,7 +577,7 @@ func (s *Server) handleRequest(conn quic.Connection, str quic.Stream, decoder *q return newStreamError(ErrCodeGeneralProtocolError, err) } - connState := conn.ConnectionState().TLS.ConnectionState + connState := conn.ConnectionState().TLS req.TLS = &connState req.RemoteAddr = conn.RemoteAddr().String() body := newRequestBody(newStream(str, onFrameError)) diff --git a/http3/server_test.go b/http3/server_test.go index 91e1df8dab9..9a362c6ce64 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -926,7 +926,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("sets the GetConfigForClient callback if no tls.Config is given", func() { @@ -954,7 +954,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("works if GetConfigForClient returns a nil tls.Config", func() { @@ -967,7 +967,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) }) It("sets the ALPN for tls.Configs returned by the tls.GetConfigForClient, if it returns a static tls.Config", func() { @@ -985,7 +985,7 @@ var _ = Describe("Server", func() { c, err := quic.DialAddr(context.Background(), ln.Addr().String(), &tls.Config{InsecureSkipVerify: true, NextProtos: []string{NextProtoH3}}, nil) Expect(err).ToNot(HaveOccurred()) defer c.CloseWithError(0, "") - Expect(c.ConnectionState().TLS.ConnectionState.NegotiatedProtocol).To(Equal(NextProtoH3)) + Expect(c.ConnectionState().TLS.NegotiatedProtocol).To(Equal(NextProtoH3)) // check that the original config was not modified Expect(tlsClientConf.NextProtos).To(Equal([]string{"foo", "bar"})) }) diff --git a/integrationtests/gomodvendor/go.sum b/integrationtests/gomodvendor/go.sum index c7b644680d3..ae823e0fee2 100644 --- a/integrationtests/gomodvendor/go.sum +++ b/integrationtests/gomodvendor/go.sum @@ -136,10 +136,8 @@ github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7q github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0/go.mod h1:UZVnYIfi5GRk+zI9UMaCPsmZ2xKJP7XBUvVyT1Knj9A= -github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= -github.com/quic-go/qtls-go1-19 v0.3.2/go.mod h1:ySOI96ew8lnoKPtSqx2BlI5wCpUVPT05RMAlajtnyOI= -github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= -github.com/quic-go/qtls-go1-20 v0.2.2/go.mod h1:JKtK6mjbAVcUTN/9jZpvLbGxvdWIKS8uT7EiStoU1SM= +github.com/quic-go/qtls-go1-20 v0.0.0-20230530152335-bf7bdeef8902 h1:+ir8isKnADr1GYr/DmIg1NJ/ncNu2arZScXPDuRSC48= +github.com/quic-go/qtls-go1-20 v0.0.0-20230530152335-bf7bdeef8902/go.mod h1:X9Nh97ZL80Z+bX/gUXMbipO6OxdiDi58b/fMC9mAL+k= github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY= @@ -185,7 +183,6 @@ golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190313024323-a1f597ede03a/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= diff --git a/integrationtests/self/handshake_test.go b/integrationtests/self/handshake_test.go index 0f1c3678608..2ef4dd20768 100644 --- a/integrationtests/self/handshake_test.go +++ b/integrationtests/self/handshake_test.go @@ -198,7 +198,10 @@ var _ = Describe("Handshake tests", func() { var transportErr *quic.TransportError Expect(errors.As(err, &transportErr)).To(BeTrue()) Expect(transportErr.ErrorCode.IsCryptoError()).To(BeTrue()) - Expect(transportErr.Error()).To(ContainSubstring("tls: bad certificate")) + Expect(transportErr.Error()).To(Or( + ContainSubstring("tls: certificate required"), + ContainSubstring("tls: bad certificate"), + )) }) It("uses the ServerName in the tls.Config", func() { diff --git a/integrationtests/self/resumption_test.go b/integrationtests/self/resumption_test.go index f2c6ec891c5..264f832ebf5 100644 --- a/integrationtests/self/resumption_test.go +++ b/integrationtests/self/resumption_test.go @@ -5,7 +5,7 @@ import ( "crypto/tls" "fmt" "net" - "sync" + "time" "github.com/quic-go/quic-go" @@ -14,16 +14,15 @@ import ( ) type clientSessionCache struct { - mutex sync.Mutex - cache map[string]*tls.ClientSessionState + cache tls.ClientSessionCache gets chan<- string puts chan<- string } -func newClientSessionCache(gets, puts chan<- string) *clientSessionCache { +func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache { return &clientSessionCache{ - cache: make(map[string]*tls.ClientSessionState), + cache: cache, gets: gets, puts: puts, } @@ -32,29 +31,25 @@ func newClientSessionCache(gets, puts chan<- string) *clientSessionCache { var _ tls.ClientSessionCache = &clientSessionCache{} func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) { + session, ok := c.cache.Get(sessionKey) c.gets <- sessionKey - c.mutex.Lock() - session, ok := c.cache[sessionKey] - c.mutex.Unlock() return session, ok } func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { + c.cache.Put(sessionKey, cs) c.puts <- sessionKey - c.mutex.Lock() - c.cache[sessionKey] = cs - c.mutex.Unlock() } var _ = Describe("TLS session resumption", func() { It("uses session resumption", func() { - server, err := quic.ListenAddr("localhost:0", getTLSConfig(), nil) + server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil)) Expect(err).ToNot(HaveOccurred()) defer server.Close() gets := make(chan string, 100) puts := make(chan string, 100) - cache := newClientSessionCache(gets, puts) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache conn, err := quic.DialAddr( @@ -96,7 +91,7 @@ var _ = Describe("TLS session resumption", func() { gets := make(chan string, 100) puts := make(chan string, 100) - cache := newClientSessionCache(gets, puts) + cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts) tlsConf := getTLSClientConfig() tlsConf.ClientSessionCache = cache conn, err := quic.DialAddr( @@ -109,7 +104,9 @@ var _ = Describe("TLS session resumption", func() { Consistently(puts).ShouldNot(Receive()) Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse()) - serverConn, err := server.Accept(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + serverConn, err := server.Accept(ctx) Expect(err).ToNot(HaveOccurred()) Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse()) diff --git a/integrationtests/self/zero_rtt_oldgo_test.go b/integrationtests/self/zero_rtt_oldgo_test.go new file mode 100644 index 00000000000..beaf351e249 --- /dev/null +++ b/integrationtests/self/zero_rtt_oldgo_test.go @@ -0,0 +1,804 @@ +//go:build !go1.21 + +package self_test + +import ( + "context" + "crypto/tls" + "fmt" + "io" + mrand "math/rand" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/quic-go/quic-go" + quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/quic-go/internal/wire" + "github.com/quic-go/quic-go/logging" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("0-RTT", func() { + rtt := scaleDuration(5 * time.Millisecond) + + runCountingProxy := func(serverPort int) (*quicproxy.QuicProxy, *uint32) { + var num0RTTPackets uint32 // to be used as an atomic + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", serverPort), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + for len(data) > 0 { + if !wire.IsLongHeaderPacket(data[0]) { + break + } + hdr, _, rest, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) + break + } + data = rest + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + + return proxy, &num0RTTPackets + } + + dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { + tlsConf := getTLSConfig() + if serverConf == nil { + serverConf = getQuicConfig(nil) + } + serverConf.Allow0RTT = true + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + serverConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { return rtt / 2 }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + // dial the first connection in order to receive a session ticket + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + <-conn.Context().Done() + }() + + clientConf := getTLSClientConfig() + gets := make(chan string, 100) + puts := make(chan string, 100) + clientConf.ClientSessionCache = newClientSessionCache(tls.NewLRUClientSessionCache(100), gets, puts) + conn, err := quic.DialAddr( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + Eventually(puts).Should(Receive()) + // received the session ticket. We're done here. + Expect(conn.CloseWithError(0, "")).To(Succeed()) + Eventually(done).Should(BeClosed()) + return tlsConf, clientConf + } + + transfer0RTTData := func( + ln *quic.EarlyListener, + proxyPort int, + connIDLen int, + clientTLSConf *tls.Config, + clientConf *quic.Config, + testdata []byte, // data to transfer + ) { + // accept the second connection, and receive the data sent in 0-RTT + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(testdata)) + Expect(str.Close()).To(Succeed()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + <-conn.Context().Done() + close(done) + }() + + if clientConf == nil { + clientConf = getQuicConfig(nil) + } + var conn quic.EarlyConnection + if connIDLen == 0 { + var err error + conn, err = quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxyPort), + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } else { + addr, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + udpConn, err := net.ListenUDP("udp", addr) + Expect(err).ToNot(HaveOccurred()) + defer udpConn.Close() + tr := &quic.Transport{ + Conn: udpConn, + ConnectionIDLength: connIDLen, + } + defer tr.Close() + conn, err = tr.DialEarly( + context.Background(), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: proxyPort}, + clientTLSConf, + clientConf, + ) + Expect(err).ToNot(HaveOccurred()) + } + defer conn.CloseWithError(0, "") + str, err := conn.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(testdata) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + <-conn.HandshakeComplete() + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn + conn.CloseWithError(0, "") + Eventually(done).Should(BeClosed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + } + + check0RTTRejected := func( + ln *quic.EarlyListener, + proxyPort int, + clientConf *tls.Config, + ) { + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxyPort), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(make([]byte, 3000)) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) + + // make sure the server doesn't process the data + ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) + defer cancel() + serverConn, err := ln.Accept(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) + _, err = serverConn.AcceptUniStream(ctx) + Expect(err).To(Equal(context.DeadlineExceeded)) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + } + + // can be used to extract 0-RTT from a packetTracer + get0RTTPackets := func(packets []packet) []protocol.PacketNumber { + var zeroRTTPackets []protocol.PacketNumber + for _, p := range packets { + if p.hdr.Type == protocol.PacketType0RTT { + zeroRTTPackets = append(zeroRTTPackets, p.hdr.PacketNumber) + } + } + return zeroRTTPackets + } + + for _, l := range []int{0, 15} { + connIDLen := l + + It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { + tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + transfer0RTTData( + ln, + proxy.LocalPort(), + connIDLen, + clientTLSConf, + getQuicConfig(nil), + PRData, + ) + + var numNewConnIDs int + for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, f := range p.frames { + if _, ok := f.(*logging.NewConnectionIDFrame); ok { + numNewConnIDs++ + } + } + } + if connIDLen == 0 { + Expect(numNewConnIDs).To(BeZero()) + } else { + Expect(numNewConnIDs).ToNot(BeZero()) + } + + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + Expect(zeroRTTPackets).To(ContainElement(protocol.PacketNumber(0))) + }) + } + + // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. + It("waits for a connection until the handshake is done", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + zeroRTTData := GeneratePRData(5 << 10) + oneRTTData := PRData + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + // now accept the second connection, and receive the 0-RTT data + go func() { + defer GinkgoRecover() + conn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(zeroRTTData)) + str, err = conn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err = io.ReadAll(str) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal(oneRTTData)) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }() + + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = firstStr.Write(zeroRTTData) + Expect(err).ToNot(HaveOccurred()) + Expect(firstStr.Close()).To(Succeed()) + + // wait for the handshake to complete + Eventually(conn.HandshakeComplete()).Should(BeClosed()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(PRData) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + <-conn.Context().Done() + + // check that 0-RTT packets only contain STREAM frames for the first stream + var num0RTT int + for _, p := range tracer.getRcvdLongHeaderPackets() { + if p.hdr.Header.Type != protocol.PacketType0RTT { + continue + } + for _, f := range p.frames { + sf, ok := f.(*logging.StreamFrame) + if !ok { + continue + } + num0RTT++ + Expect(sf.StreamID).To(Equal(firstStr.StreamID())) + } + } + fmt.Fprintf(GinkgoWriter, "received %d STREAM frames in 0-RTT packets\n", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + }) + + It("transfers 0-RTT data, when 0-RTT packets are lost", func() { + var ( + num0RTTPackets uint32 // to be used as an atomic + num0RTTDropped uint32 + ) + + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(_ quicproxy.Direction, data []byte) time.Duration { + if wire.IsLongHeaderPacket(data[0]) { + hdr, _, _, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + atomic.AddUint32(&num0RTTPackets, 1) + } + } + return rtt / 2 + }, + DropPacket: func(_ quicproxy.Direction, data []byte) bool { + if !wire.IsLongHeaderPacket(data[0]) { + return false + } + hdr, _, _, err := wire.ParsePacket(data) + Expect(err).ToNot(HaveOccurred()) + if hdr.Type == protocol.PacketType0RTT { + // drop 25% of the 0-RTT packets + drop := mrand.Intn(4) == 0 + if drop { + atomic.AddUint32(&num0RTTDropped, 1) + } + return drop + } + return false + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) + + num0RTT := atomic.LoadUint32(&num0RTTPackets) + numDropped := atomic.LoadUint32(&num0RTTDropped) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets. Dropped %d of those.", num0RTT, numDropped) + Expect(numDropped).ToNot(BeZero()) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).ToNot(BeEmpty()) + }) + + It("retransmits all 0-RTT data when the server performs a Retry", func() { + var mutex sync.Mutex + var firstConnID, secondConnID *protocol.ConnectionID + var firstCounter, secondCounter protocol.ByteCount + + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { + for len(data) > 0 { + hdr, _, rest, err := wire.ParsePacket(data) + if err != nil { + return + } + data = rest + if hdr.Type == protocol.PacketType0RTT { + n += hdr.Length - 16 /* AEAD tag */ + } + } + return + } + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + RequireAddressValidation: func(net.Addr) bool { return true }, + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: fmt.Sprintf("localhost:%d", ln.Addr().(*net.UDPAddr).Port), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + connID, err := wire.ParseConnectionID(data, 0) + Expect(err).ToNot(HaveOccurred()) + + mutex.Lock() + defer mutex.Unlock() + + if zeroRTTBytes := countZeroRTTBytes(data); zeroRTTBytes > 0 { + if firstConnID == nil { + firstConnID = &connID + firstCounter += zeroRTTBytes + } else if firstConnID != nil && *firstConnID == connID { + Expect(secondConnID).To(BeNil()) + firstCounter += zeroRTTBytes + } else if secondConnID == nil { + secondConnID = &connID + secondCounter += zeroRTTBytes + } else if secondConnID != nil && *secondConnID == connID { + secondCounter += zeroRTTBytes + } else { + Fail("received 3 connection IDs on 0-RTT packets") + } + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, GeneratePRData(5000)) // ~5 packets + + mutex.Lock() + defer mutex.Unlock() + Expect(firstCounter).To(BeNumerically("~", 5000+100 /* framing overhead */, 100)) // the FIN bit might be sent extra + Expect(secondCounter).To(BeNumerically("~", firstCounter, 20)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">=", 5)) + Expect(zeroRTTPackets[0]).To(BeNumerically(">=", protocol.PacketNumber(5))) + }) + + It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { + const maxStreams = 1 + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: maxStreams, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: maxStreams + 1, + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + // The client remembers the old limit and refuses to open a new stream. + _, err = conn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.OpenUniStreamSync(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + }) + + It("rejects 0-RTT when the server's stream limit decreased", func() { + const maxStreams = 42 + tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + MaxIncomingStreams: maxStreams, + })) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingStreams: maxStreams - 1, + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the ALPN changed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + clientConf.NextProtos = []string{"new-alpn"} + tlsConf.NextProtos = []string{"new-alpn"} + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + It("rejects 0-RTT when the application doesn't allow it", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + // now close the listener and dial new connection with a different ALPN + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: false, // application rejects 0-RTT + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + check0RTTRejected(ln, proxy.LocalPort(), clientConf) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + + DescribeTable("flow control limits", + func(addFlowControlLimit func(*quic.Config, uint64)) { + tracer := newPacketTracer() + firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) + addFlowControlLimit(firstConf, 3) + tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) + + secondConf := getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }) + addFlowControlLimit(secondConf, 100) + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + secondConf, + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, _ := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + str, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(written) + _, err := str.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + }() + + Eventually(written).Should(BeClosed()) + + serverConn, err := ln.Accept(context.Background()) + Expect(err).ToNot(HaveOccurred()) + rstr, err := serverConn.AcceptUniStream(context.Background()) + Expect(err).ToNot(HaveOccurred()) + data, err := io.ReadAll(rstr) + Expect(err).ToNot(HaveOccurred()) + Expect(data).To(Equal([]byte("foobar"))) + Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue()) + Expect(serverConn.CloseWithError(0, "")).To(Succeed()) + Eventually(conn.Context().Done()).Should(BeClosed()) + + var processedFirst bool + for _, p := range tracer.getRcvdLongHeaderPackets() { + for _, f := range p.frames { + if sf, ok := f.(*logging.StreamFrame); ok { + if !processedFirst { + // The first STREAM should have been sent in a 0-RTT packet. + // Due to the flow control limit, the STREAM frame was limit to the first 3 bytes. + Expect(p.hdr.Type).To(Equal(protocol.PacketType0RTT)) + Expect(sf.Length).To(BeEquivalentTo(3)) + processedFirst = true + } else { + Fail("STREAM was shouldn't have been sent in 0-RTT") + } + } + } + } + }, + Entry("doesn't reject 0-RTT when the server's transport stream flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialStreamReceiveWindow = limit }), + Entry("doesn't reject 0-RTT when the server's transport connection flow control limit increased", func(c *quic.Config, limit uint64) { c.InitialConnectionReceiveWindow = limit }), + ) + + for _, l := range []int{0, 15} { + connIDLen := l + + It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + // now dial new connection with different transport parameters + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + MaxIncomingUniStreams: 1, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) + defer proxy.Close() + + conn, err := quic.DialAddrEarly( + context.Background(), + fmt.Sprintf("localhost:%d", proxy.LocalPort()), + clientConf, + getQuicConfig(nil), + ) + Expect(err).ToNot(HaveOccurred()) + // The client remembers that it was allowed to open 2 uni-directional streams. + firstStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + written := make(chan struct{}, 2) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := firstStr.Write([]byte("first flight")) + Expect(err).ToNot(HaveOccurred()) + }() + secondStr, err := conn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + go func() { + defer GinkgoRecover() + defer func() { written <- struct{}{} }() + _, err := secondStr.Write([]byte("first flight")) + Expect(err).ToNot(HaveOccurred()) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + _, err = conn.AcceptStream(ctx) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + Eventually(written).Should(Receive()) + Eventually(written).Should(Receive()) + _, err = firstStr.Write([]byte("foobar")) + Expect(err).To(MatchError(quic.Err0RTTRejected)) + _, err = conn.OpenUniStream() + Expect(err).To(MatchError(quic.Err0RTTRejected)) + + _, err = conn.AcceptStream(ctx) + Expect(err).To(Equal(quic.Err0RTTRejected)) + + newConn := conn.NextConnection() + str, err := newConn.OpenUniStream() + Expect(err).ToNot(HaveOccurred()) + _, err = newConn.OpenUniStream() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("too many open streams")) + _, err = str.Write([]byte("second flight")) + Expect(err).ToNot(HaveOccurred()) + Expect(str.Close()).To(Succeed()) + Expect(conn.CloseWithError(0, "")).To(Succeed()) + + // The client should send 0-RTT packets, but the server doesn't process them. + num0RTT := atomic.LoadUint32(num0RTTPackets) + fmt.Fprintf(GinkgoWriter, "Sent %d 0-RTT packets.", num0RTT) + Expect(num0RTT).ToNot(BeZero()) + Expect(get0RTTPackets(tracer.getRcvdLongHeaderPackets())).To(BeEmpty()) + }) + } + + It("queues 0-RTT packets, if the Initial is delayed", func() { + tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + proxy, err := quicproxy.NewQuicProxy("localhost:0", &quicproxy.Opts{ + RemoteAddr: ln.Addr().String(), + DelayPacket: func(dir quicproxy.Direction, data []byte) time.Duration { + if dir == quicproxy.DirectionIncoming && wire.IsLongHeaderPacket(data[0]) && data[0]&0x30>>4 == 0 { // Initial packet from client + return rtt/2 + rtt + } + return rtt / 2 + }, + }) + Expect(err).ToNot(HaveOccurred()) + defer proxy.Close() + + transfer0RTTData(ln, proxy.LocalPort(), protocol.DefaultConnectionIDLength, clientConf, nil, PRData) + + Expect(tracer.getRcvdLongHeaderPackets()[0].hdr.Type).To(Equal(protocol.PacketTypeInitial)) + zeroRTTPackets := get0RTTPackets(tracer.getRcvdLongHeaderPackets()) + Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) + Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) + }) +}) diff --git a/integrationtests/self/zero_rtt_test.go b/integrationtests/self/zero_rtt_test.go index 001ea5da252..4e8d1d052be 100644 --- a/integrationtests/self/zero_rtt_test.go +++ b/integrationtests/self/zero_rtt_test.go @@ -1,3 +1,5 @@ +//go:build go1.21 + package self_test import ( @@ -21,6 +23,36 @@ import ( . "github.com/onsi/gomega" ) +type metadataClientSessionCache struct { + toAdd []byte + restored func([]byte) + + cache tls.ClientSessionCache +} + +func (m metadataClientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { + session, ok := m.cache.Get(key) + if !ok || session == nil { + return session, ok + } + ticket, state, err := session.ResumptionState() + Expect(err).ToNot(HaveOccurred()) + Expect(state.Extra).To(HaveLen(2)) // ours, and the quic-go's + m.restored(state.Extra[1]) + session, err = tls.NewResumptionState(ticket, state) + Expect(err).ToNot(HaveOccurred()) + return session, true +} + +func (m metadataClientSessionCache) Put(key string, session *tls.ClientSessionState) { + ticket, state, err := session.ResumptionState() + Expect(err).ToNot(HaveOccurred()) + state.Extra = append(state.Extra, m.toAdd) + session, err = tls.NewResumptionState(ticket, state) + Expect(err).ToNot(HaveOccurred()) + m.cache.Put(key, session) +} + var _ = Describe("0-RTT", func() { rtt := scaleDuration(5 * time.Millisecond) @@ -49,15 +81,14 @@ var _ = Describe("0-RTT", func() { return proxy, &num0RTTPackets } - dialAndReceiveSessionTicket := func(serverConf *quic.Config) (*tls.Config, *tls.Config) { - tlsConf := getTLSConfig() + dialAndReceiveSessionTicket := func(serverTLSConf *tls.Config, serverConf *quic.Config, clientTLSConf *tls.Config) { if serverConf == nil { serverConf = getQuicConfig(nil) } serverConf.Allow0RTT = true ln, err := quic.ListenAddrEarly( "localhost:0", - tlsConf, + serverTLSConf, serverConf, ) Expect(err).ToNot(HaveOccurred()) @@ -80,14 +111,16 @@ var _ = Describe("0-RTT", func() { <-conn.Context().Done() }() - clientConf := getTLSClientConfig() - gets := make(chan string, 100) puts := make(chan string, 100) - clientConf.ClientSessionCache = newClientSessionCache(gets, puts) + cache := clientTLSConf.ClientSessionCache + if cache == nil { + cache = tls.NewLRUClientSessionCache(100) + } + clientTLSConf.ClientSessionCache = newClientSessionCache(cache, make(chan string, 100), puts) conn, err := quic.DialAddr( context.Background(), fmt.Sprintf("localhost:%d", proxy.LocalPort()), - clientConf, + clientTLSConf, getQuicConfig(nil), ) Expect(err).ToNot(HaveOccurred()) @@ -95,7 +128,6 @@ var _ = Describe("0-RTT", func() { // received the session ticket. We're done here. Expect(conn.CloseWithError(0, "")).To(Succeed()) Eventually(done).Should(BeClosed()) - return tlsConf, clientConf } transfer0RTTData := func( @@ -118,7 +150,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal(testdata)) Expect(str.Close()).To(Succeed()) - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) <-conn.Context().Done() close(done) }() @@ -162,7 +194,7 @@ var _ = Describe("0-RTT", func() { Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) <-conn.HandshakeComplete() - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) io.ReadAll(str) // wait for the EOF from the server to arrive before closing the conn conn.CloseWithError(0, "") Eventually(done).Should(BeClosed()) @@ -186,14 +218,14 @@ var _ = Describe("0-RTT", func() { _, err = str.Write(make([]byte, 3000)) Expect(err).ToNot(HaveOccurred()) Expect(str.Close()).To(Succeed()) - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeFalse()) + Expect(conn.ConnectionState().Used0RTT).To(BeFalse()) // make sure the server doesn't process the data ctx, cancel := context.WithTimeout(context.Background(), scaleDuration(50*time.Millisecond)) defer cancel() serverConn, err := ln.Accept(ctx) Expect(err).ToNot(HaveOccurred()) - Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeFalse()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeFalse()) _, err = serverConn.AcceptUniStream(ctx) Expect(err).To(Equal(context.DeadlineExceeded)) Expect(serverConn.CloseWithError(0, "")).To(Succeed()) @@ -215,7 +247,9 @@ var _ = Describe("0-RTT", func() { connIDLen := l It(fmt.Sprintf("transfers 0-RTT data, with %d byte connection IDs", connIDLen), func() { - tlsConf, clientTLSConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( @@ -266,7 +300,9 @@ var _ = Describe("0-RTT", func() { // Test that data intended to be sent with 1-RTT protection is not sent in 0-RTT packets. It("waits for a connection until the handshake is done", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) zeroRTTData := GeneratePRData(5 << 10) oneRTTData := PRData @@ -351,7 +387,9 @@ var _ = Describe("0-RTT", func() { num0RTTDropped uint32 ) - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( @@ -412,7 +450,9 @@ var _ = Describe("0-RTT", func() { var firstConnID, secondConnID *protocol.ConnectionID var firstCounter, secondCounter protocol.ByteCount - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) countZeroRTTBytes := func(data []byte) (n protocol.ByteCount) { for len(data) > 0 { @@ -485,9 +525,11 @@ var _ = Describe("0-RTT", func() { It("doesn't reject 0-RTT when the server's transport stream limit increased", func() { const maxStreams = 1 - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ MaxIncomingUniStreams: maxStreams, - })) + }), clientConf) tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( @@ -524,15 +566,17 @@ var _ = Describe("0-RTT", func() { defer cancel() _, err = conn.OpenUniStreamSync(ctx) Expect(err).ToNot(HaveOccurred()) - Expect(conn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(conn.ConnectionState().Used0RTT).To(BeTrue()) Expect(conn.CloseWithError(0, "")).To(Succeed()) }) It("rejects 0-RTT when the server's stream limit decreased", func() { const maxStreams = 42 - tlsConf, clientConf := dialAndReceiveSessionTicket(getQuicConfig(&quic.Config{ + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, getQuicConfig(&quic.Config{ MaxIncomingStreams: maxStreams, - })) + }), clientConf) tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( @@ -548,6 +592,7 @@ var _ = Describe("0-RTT", func() { defer ln.Close() proxy, num0RTTPackets := runCountingProxy(ln.Addr().(*net.UDPAddr).Port) defer proxy.Close() + check0RTTRejected(ln, proxy.LocalPort(), clientConf) // The client should send 0-RTT packets, but the server doesn't process them. @@ -558,11 +603,15 @@ var _ = Describe("0-RTT", func() { }) It("rejects 0-RTT when the ALPN changed", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) - // now close the listener and dial new connection with a different ALPN - clientConf.NextProtos = []string{"new-alpn"} + // switch to different ALPN on the server side tlsConf.NextProtos = []string{"new-alpn"} + // Append to the client's ALPN. + // crypto/tls will attempt to resume with the ALPN from the original connection + clientConf.NextProtos = append(clientConf.NextProtos, "new-alpn") tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( "localhost:0", @@ -587,7 +636,9 @@ var _ = Describe("0-RTT", func() { }) It("rejects 0-RTT when the application doesn't allow it", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now close the listener and dial new connection with a different ALPN tracer := newPacketTracer() @@ -618,7 +669,9 @@ var _ = Describe("0-RTT", func() { tracer := newPacketTracer() firstConf := getQuicConfig(&quic.Config{Allow0RTT: true}) addFlowControlLimit(firstConf, 3) - tlsConf, clientConf := dialAndReceiveSessionTicket(firstConf) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, firstConf, clientConf) secondConf := getQuicConfig(&quic.Config{ Allow0RTT: true, @@ -662,7 +715,7 @@ var _ = Describe("0-RTT", func() { data, err := io.ReadAll(rstr) Expect(err).ToNot(HaveOccurred()) Expect(data).To(Equal([]byte("foobar"))) - Expect(serverConn.ConnectionState().TLS.Used0RTT).To(BeTrue()) + Expect(serverConn.ConnectionState().Used0RTT).To(BeTrue()) Expect(serverConn.CloseWithError(0, "")).To(Succeed()) Eventually(conn.Context().Done()).Should(BeClosed()) @@ -691,7 +744,9 @@ var _ = Describe("0-RTT", func() { connIDLen := l It(fmt.Sprintf("correctly deals with 0-RTT rejections, for %d byte connection IDs", connIDLen), func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) // now dial new connection with different transport parameters tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( @@ -767,7 +822,9 @@ var _ = Describe("0-RTT", func() { } It("queues 0-RTT packets, if the Initial is delayed", func() { - tlsConf, clientConf := dialAndReceiveSessionTicket(nil) + tlsConf := getTLSConfig() + clientConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientConf) tracer := newPacketTracer() ln, err := quic.ListenAddrEarly( @@ -799,4 +856,87 @@ var _ = Describe("0-RTT", func() { Expect(len(zeroRTTPackets)).To(BeNumerically(">", 10)) Expect(zeroRTTPackets[0]).To(Equal(protocol.PacketNumber(0))) }) + + It("allows the application to attach data to the session ticket, for the server", func() { + tlsConf := getTLSConfig() + tlsConf.WrapSession = func(cs tls.ConnectionState, ss *tls.SessionState) ([]byte, error) { + ss.Extra = append(ss.Extra, []byte("foobar")) + return tlsConf.EncryptTicket(cs, ss) + } + var unwrapped bool + tlsConf.UnwrapSession = func(identity []byte, cs tls.ConnectionState) (*tls.SessionState, error) { + defer GinkgoRecover() + state, err := tlsConf.DecryptTicket(identity, cs) + if err != nil { + return nil, err + } + Expect(len(state.Extra)).To(BeNumerically(">=", 6)) + Expect(state.Extra[len(state.Extra)-6:]).To(Equal([]byte("foobar"))) + state.Extra = state.Extra[:len(state.Extra)-6] // remove the foobar + unwrapped = true + return state, nil + } + clientTLSConf := getTLSClientConfig() + dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + transfer0RTTData( + ln, + ln.Addr().(*net.UDPAddr).Port, + 10, + clientTLSConf, + getQuicConfig(nil), + PRData, + ) + Expect(unwrapped).To(BeTrue()) + }) + + It("allows the application to attach data to the session ticket, for the client", func() { + tlsConf := getTLSConfig() + clientTLSConf := getTLSClientConfig() + var restored bool + clientTLSConf.ClientSessionCache = &metadataClientSessionCache{ + toAdd: []byte("foobar"), + restored: func(b []byte) { + defer GinkgoRecover() + Expect(b).To(Equal([]byte("foobar"))) + restored = true + }, + cache: tls.NewLRUClientSessionCache(100), + } + dialAndReceiveSessionTicket(tlsConf, nil, clientTLSConf) + + tracer := newPacketTracer() + ln, err := quic.ListenAddrEarly( + "localhost:0", + tlsConf, + getQuicConfig(&quic.Config{ + Allow0RTT: true, + Tracer: newTracer(tracer), + }), + ) + Expect(err).ToNot(HaveOccurred()) + defer ln.Close() + + transfer0RTTData( + ln, + ln.Addr().(*net.UDPAddr).Port, + 10, + clientTLSConf, + getQuicConfig(nil), + PRData, + ) + Expect(restored).To(BeTrue()) + }) }) diff --git a/interface.go b/interface.go index 8e6213bfcbf..7c6f892d97e 100644 --- a/interface.go +++ b/interface.go @@ -2,6 +2,7 @@ package quic import ( "context" + "crypto/tls" "errors" "io" "net" @@ -336,12 +337,14 @@ type ClientHelloInfo struct { // ConnectionState records basic details about a QUIC connection type ConnectionState struct { // TLS contains information about the TLS connection state, incl. the tls.ConnectionState. - TLS handshake.ConnectionState + TLS tls.ConnectionState // SupportsDatagrams says if support for QUIC datagrams (RFC 9221) was negotiated. // This requires both nodes to support and enable the datagram extensions (via Config.EnableDatagrams). // If datagram support was negotiated, datagrams can be sent and received using the // SendMessage and ReceiveMessage methods on the Connection. SupportsDatagrams bool + // Used0RTT says if 0-RTT resumption was used. + Used0RTT bool // Version is the QUIC version of the QUIC connection. Version VersionNumber } diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 410745f1a81..ccda43ca5eb 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -5,11 +5,10 @@ import ( "encoding/binary" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/internal/utils" ) -func createAEAD(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { +func createAEAD(suite *cipherSuite, trafficSecret []byte, v protocol.VersionNumber) cipher.AEAD { keyLabel := hkdfLabelKeyV1 ivLabel := hkdfLabelIVV1 if v == protocol.Version2 { diff --git a/internal/handshake/cipher_suite.go b/internal/handshake/cipher_suite.go new file mode 100644 index 00000000000..608d5ea00e6 --- /dev/null +++ b/internal/handshake/cipher_suite.go @@ -0,0 +1,104 @@ +package handshake + +import ( + "crypto" + "crypto/aes" + "crypto/cipher" + "crypto/tls" + "fmt" + + "golang.org/x/crypto/chacha20poly1305" +) + +// These cipher suite implementations are copied from the standard library crypto/tls package. + +const aeadNonceLength = 12 + +type cipherSuite struct { + ID uint16 + Hash crypto.Hash + KeyLen int + AEAD func(key, nonceMask []byte) cipher.AEAD +} + +func (s cipherSuite) IVLen() int { return aeadNonceLength } + +func getCipherSuite(id uint16) *cipherSuite { + switch id { + case tls.TLS_AES_128_GCM_SHA256: + return &cipherSuite{ID: tls.TLS_AES_128_GCM_SHA256, Hash: crypto.SHA256, KeyLen: 16, AEAD: aeadAESGCMTLS13} + case tls.TLS_CHACHA20_POLY1305_SHA256: + return &cipherSuite{ID: tls.TLS_CHACHA20_POLY1305_SHA256, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadChaCha20Poly1305} + case tls.TLS_AES_256_GCM_SHA384: + return &cipherSuite{ID: tls.TLS_AES_256_GCM_SHA384, Hash: crypto.SHA256, KeyLen: 32, AEAD: aeadAESGCMTLS13} + default: + panic(fmt.Sprintf("unknown cypher suite: %d", id)) + } +} + +func aeadAESGCMTLS13(key, nonceMask []byte) cipher.AEAD { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +func aeadChaCha20Poly1305(key, nonceMask []byte) cipher.AEAD { + if len(nonceMask) != aeadNonceLength { + panic("tls: internal error: wrong nonce length") + } + aead, err := chacha20poly1305.New(key) + if err != nil { + panic(err) + } + + ret := &xorNonceAEAD{aead: aead} + copy(ret.nonceMask[:], nonceMask) + return ret +} + +// xorNonceAEAD wraps an AEAD by XORing in a fixed pattern to the nonce +// before each call. +type xorNonceAEAD struct { + nonceMask [aeadNonceLength]byte + aead cipher.AEAD +} + +func (f *xorNonceAEAD) NonceSize() int { return 8 } // 64-bit sequence number +func (f *xorNonceAEAD) Overhead() int { return f.aead.Overhead() } +func (f *xorNonceAEAD) explicitNonceLen() int { return 0 } + +func (f *xorNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result := f.aead.Seal(out, f.nonceMask[:], plaintext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result +} + +func (f *xorNonceAEAD) Open(out, nonce, ciphertext, additionalData []byte) ([]byte, error) { + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + result, err := f.aead.Open(out, f.nonceMask[:], ciphertext, additionalData) + for i, b := range nonce { + f.nonceMask[4+i] ^= b + } + + return result, err +} diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index a11c0d2366c..29e66f57f4c 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -7,9 +7,8 @@ import ( "errors" "fmt" "io" - "math" - "net" "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go/internal/protocol" @@ -25,98 +24,22 @@ type quicVersionContextKey struct{} var QUICVersionContextKey = &quicVersionContextKey{} -// TLS unexpected_message alert -const alertUnexpectedMessage uint8 = 10 - -type messageType uint8 - -// TLS handshake message types. -const ( - typeClientHello messageType = 1 - typeServerHello messageType = 2 - typeNewSessionTicket messageType = 4 - typeEncryptedExtensions messageType = 8 - typeCertificate messageType = 11 - typeCertificateRequest messageType = 13 - typeCertificateVerify messageType = 15 - typeFinished messageType = 20 -) - -func (m messageType) String() string { - switch m { - case typeClientHello: - return "ClientHello" - case typeServerHello: - return "ServerHello" - case typeNewSessionTicket: - return "NewSessionTicket" - case typeEncryptedExtensions: - return "EncryptedExtensions" - case typeCertificate: - return "Certificate" - case typeCertificateRequest: - return "CertificateRequest" - case typeCertificateVerify: - return "CertificateVerify" - case typeFinished: - return "Finished" - default: - return fmt.Sprintf("unknown message type: %d", m) - } -} - const clientSessionStateRevision = 3 -type conn struct { - localAddr, remoteAddr net.Addr -} - -var _ net.Conn = &conn{} - -func newConn(local, remote net.Addr) net.Conn { - return &conn{ - localAddr: local, - remoteAddr: remote, - } -} - -func (c *conn) Read([]byte) (int, error) { return 0, nil } -func (c *conn) Write([]byte) (int, error) { return 0, nil } -func (c *conn) Close() error { return nil } -func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *conn) LocalAddr() net.Addr { return c.localAddr } -func (c *conn) SetReadDeadline(time.Time) error { return nil } -func (c *conn) SetWriteDeadline(time.Time) error { return nil } -func (c *conn) SetDeadline(time.Time) error { return nil } - type cryptoSetup struct { - tlsConf *tls.Config - extraConf *qtls.ExtraConfig - conn *qtls.Conn + tlsConf *tls.Config + conn *qtls.QUICConn version protocol.VersionNumber - messageChan chan []byte - isReadingHandshakeMessage chan struct{} - readFirstHandshakeMessage bool - ourParams *wire.TransportParameters peerParams *wire.TransportParameters - paramsChan <-chan []byte runner handshakeRunner - alertChan chan uint8 - // handshakeDone is closed as soon as the go routine running qtls.Handshake() returns - handshakeDone chan struct{} - // is closed when Close() is called - closeChan chan struct{} - - zeroRTTParameters *wire.TransportParameters - clientHelloWritten bool - clientHelloWrittenChan chan struct{} // is closed as soon as the ClientHello is written - zeroRTTParametersChan chan<- *wire.TransportParameters - allow0RTT bool + zeroRTTParameters *wire.TransportParameters + zeroRTTParametersChan chan<- *wire.TransportParameters + allow0RTT bool rttStats *utils.RTTStats @@ -129,9 +52,6 @@ type cryptoSetup struct { handshakeCompleteTime time.Time - readEncLevel protocol.EncryptionLevel - writeEncLevel protocol.EncryptionLevel - zeroRTTOpener LongHeaderOpener // only set for the server zeroRTTSealer LongHeaderSealer // only set for the client @@ -143,23 +63,20 @@ type cryptoSetup struct { handshakeOpener LongHeaderOpener handshakeSealer LongHeaderSealer + used0RTT atomic.Bool + + oneRTTStream io.Writer aead *updatableAEAD has1RTTSealer bool has1RTTOpener bool } -var ( - _ qtls.RecordLayer = &cryptoSetup{} - _ CryptoSetup = &cryptoSetup{} -) +var _ CryptoSetup = &cryptoSetup{} // NewCryptoSetupClient creates a new crypto setup for the client func NewCryptoSetupClient( - initialStream io.Writer, - handshakeStream io.Writer, + initialStream, handshakeStream, oneRTTStream io.Writer, connID protocol.ConnectionID, - localAddr net.Addr, - remoteAddr net.Addr, tp *wire.TransportParameters, runner handshakeRunner, tlsConf *tls.Config, @@ -172,28 +89,33 @@ func NewCryptoSetupClient( cs, clientHelloWritten := newCryptoSetup( initialStream, handshakeStream, + oneRTTStream, connID, tp, runner, - tlsConf, - enable0RTT, rttStats, tracer, logger, protocol.PerspectiveClient, version, ) - cs.conn = qtls.Client(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) + + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 + quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForClient(quicConf, cs.marshalDataForSessionState, cs.handleDataFromSessionState) + cs.tlsConf = tlsConf + + cs.conn = qtls.QUICClient(quicConf) + cs.conn.SetTransportParameters(cs.ourParams.Marshal(protocol.PerspectiveClient)) + return cs, clientHelloWritten } // NewCryptoSetupServer creates a new crypto setup for the server func NewCryptoSetupServer( - initialStream io.Writer, - handshakeStream io.Writer, + initialStream, handshakeStream, oneRTTStream io.Writer, connID protocol.ConnectionID, - localAddr net.Addr, - remoteAddr net.Addr, tp *wire.TransportParameters, runner handshakeRunner, tlsConf *tls.Config, @@ -206,29 +128,32 @@ func NewCryptoSetupServer( cs, _ := newCryptoSetup( initialStream, handshakeStream, + oneRTTStream, connID, tp, runner, - tlsConf, - allow0RTT, rttStats, tracer, logger, protocol.PerspectiveServer, version, ) - cs.conn = qtls.Server(newConn(localAddr, remoteAddr), cs.tlsConf, cs.extraConf) + cs.allow0RTT = allow0RTT + + quicConf := &qtls.QUICConfig{TLSConfig: tlsConf} + qtls.SetupConfigForServer(quicConf, cs.allow0RTT, cs.getDataForSessionTicket, cs.accept0RTT) + + cs.tlsConf = quicConf.TLSConfig + cs.conn = qtls.QUICServer(quicConf) + return cs } func newCryptoSetup( - initialStream io.Writer, - handshakeStream io.Writer, + initialStream, handshakeStream, oneRTTStream io.Writer, connID protocol.ConnectionID, tp *wire.TransportParameters, runner handshakeRunner, - tlsConf *tls.Config, - enable0RTT bool, rttStats *utils.RTTStats, tracer logging.ConnectionTracer, logger utils.Logger, @@ -240,51 +165,23 @@ func newCryptoSetup( tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveClient) tracer.UpdatedKeyFromTLS(protocol.EncryptionInitial, protocol.PerspectiveServer) } - extHandler := newExtensionHandler(tp.Marshal(perspective), perspective) zeroRTTParametersChan := make(chan *wire.TransportParameters, 1) - cs := &cryptoSetup{ - tlsConf: tlsConf, - initialStream: initialStream, - initialSealer: initialSealer, - initialOpener: initialOpener, - handshakeStream: handshakeStream, - aead: newUpdatableAEAD(rttStats, tracer, logger, version), - readEncLevel: protocol.EncryptionInitial, - writeEncLevel: protocol.EncryptionInitial, - runner: runner, - allow0RTT: enable0RTT, - ourParams: tp, - paramsChan: extHandler.TransportParameters(), - rttStats: rttStats, - tracer: tracer, - logger: logger, - perspective: perspective, - handshakeDone: make(chan struct{}), - alertChan: make(chan uint8), - clientHelloWrittenChan: make(chan struct{}), - zeroRTTParametersChan: zeroRTTParametersChan, - messageChan: make(chan []byte, 1), - isReadingHandshakeMessage: make(chan struct{}), - closeChan: make(chan struct{}), - version: version, - } - var maxEarlyData uint32 - if enable0RTT { - maxEarlyData = math.MaxUint32 - } - cs.extraConf = &qtls.ExtraConfig{ - GetExtensions: extHandler.GetExtensions, - ReceivedExtensions: extHandler.ReceivedExtensions, - AlternativeRecordLayer: cs, - EnforceNextProtoSelection: true, - MaxEarlyData: maxEarlyData, - Accept0RTT: cs.accept0RTT, - Rejected0RTT: cs.rejected0RTT, - Enable0RTT: enable0RTT, - GetAppDataForSessionState: cs.marshalDataForSessionState, - SetAppDataFromSessionState: cs.handleDataFromSessionState, - } - return cs, zeroRTTParametersChan + return &cryptoSetup{ + initialStream: initialStream, + initialSealer: initialSealer, + initialOpener: initialOpener, + handshakeStream: handshakeStream, + oneRTTStream: oneRTTStream, + aead: newUpdatableAEAD(rttStats, tracer, logger, version), + runner: runner, + ourParams: tp, + rttStats: rttStats, + tracer: tracer, + logger: logger, + perspective: perspective, + zeroRTTParametersChan: zeroRTTParametersChan, + version: version, + }, zeroRTTParametersChan } func (h *cryptoSetup) ChangeConnectionID(id protocol.ConnectionID) { @@ -301,142 +198,100 @@ func (h *cryptoSetup) SetLargest1RTTAcked(pn protocol.PacketNumber) error { return h.aead.SetLargestAcked(pn) } -func (h *cryptoSetup) RunHandshake() { - // Handle errors that might occur when HandleData() is called. - handshakeComplete := make(chan struct{}) - handshakeErrChan := make(chan error, 1) - go func() { - defer close(h.handshakeDone) - if err := h.conn.HandshakeContext(context.WithValue(context.Background(), QUICVersionContextKey, h.version)); err != nil { - handshakeErrChan <- err - return +func (h *cryptoSetup) StartHandshake() error { + err := h.conn.Start(context.WithValue(context.Background(), QUICVersionContextKey, h.version)) + if err != nil { + return wrapError(err) + } + for { + ev := h.conn.NextEvent() + done, err := h.handleEvent(ev) + if err != nil { + return wrapError(err) } - close(handshakeComplete) - }() - - if h.perspective == protocol.PerspectiveClient { - select { - case err := <-handshakeErrChan: - h.onError(0, err.Error()) - return - case <-h.clientHelloWrittenChan: + if done { + break } } - - select { - case <-handshakeComplete: // return when the handshake is done - h.mutex.Lock() - h.handshakeCompleteTime = time.Now() - h.mutex.Unlock() - h.runner.OnHandshakeComplete() - case <-h.closeChan: - // wait until the Handshake() go routine has returned - <-h.handshakeDone - case alert := <-h.alertChan: - handshakeErr := <-handshakeErrChan - h.onError(alert, handshakeErr.Error()) - } -} - -func (h *cryptoSetup) onError(alert uint8, message string) { - var err error - if alert == 0 { - err = &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: message} - } else { - err = qerr.NewLocalCryptoError(alert, message) + if h.perspective == protocol.PerspectiveClient { + if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { + h.logger.Debugf("Doing 0-RTT.") + h.zeroRTTParametersChan <- h.zeroRTTParameters + } else { + h.logger.Debugf("Not doing 0-RTT. Has sealer: %t, has params: %t", h.zeroRTTSealer != nil, h.zeroRTTParameters != nil) + h.zeroRTTParametersChan <- nil + } } - h.runner.OnError(err) + return nil } // Close closes the crypto setup. // It aborts the handshake, if it is still running. -// It must only be called once. func (h *cryptoSetup) Close() error { - close(h.closeChan) - // wait until qtls.Handshake() actually returned - <-h.handshakeDone - return nil + return h.conn.Close() } -// handleMessage handles a TLS handshake message. +// HandleMessage handles a TLS handshake message. // It is called by the crypto streams when a new message is available. -// It returns if it is done with messages on the same encryption level. -func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) bool /* stream finished */ { - msgType := messageType(data[0]) - h.logger.Debugf("Received %s message (%d bytes, encryption level: %s)", msgType, len(data), encLevel) - if err := h.checkEncryptionLevel(msgType, encLevel); err != nil { - h.onError(alertUnexpectedMessage, err.Error()) - return false +func (h *cryptoSetup) HandleMessage(data []byte, encLevel protocol.EncryptionLevel) error { + if err := h.handleMessage(data, encLevel); err != nil { + return wrapError(err) } - if encLevel != protocol.Encryption1RTT { - select { - case h.messageChan <- data: - case <-h.handshakeDone: // handshake errored, nobody is going to consume this message - return false - } - } - if encLevel == protocol.Encryption1RTT { - h.messageChan <- data - h.handlePostHandshakeMessage() - return false + return nil +} + +func (h *cryptoSetup) handleMessage(data []byte, encLevel protocol.EncryptionLevel) error { + if err := h.conn.HandleData(qtls.ToTLSEncryptionLevel(encLevel), data); err != nil { + return err } -readLoop: for { - select { - case data := <-h.paramsChan: - if data == nil { - h.onError(0x6d, "missing quic_transport_parameters extension") - } else { - h.handleTransportParameters(data) - } - case <-h.isReadingHandshakeMessage: - break readLoop - case <-h.handshakeDone: - break readLoop - case <-h.closeChan: - break readLoop + ev := h.conn.NextEvent() + done, err := h.handleEvent(ev) + if err != nil { + return err + } + if done { + return nil } } - // We're done with the Initial encryption level after processing a ClientHello / ServerHello, - // but only if a handshake opener and sealer was created. - // Otherwise, a HelloRetryRequest was performed. - // We're done with the Handshake encryption level after processing the Finished message. - return ((msgType == typeClientHello || msgType == typeServerHello) && h.handshakeOpener != nil && h.handshakeSealer != nil) || - msgType == typeFinished -} - -func (h *cryptoSetup) checkEncryptionLevel(msgType messageType, encLevel protocol.EncryptionLevel) error { - var expected protocol.EncryptionLevel - switch msgType { - case typeClientHello, typeServerHello: - expected = protocol.EncryptionInitial - case typeEncryptedExtensions, - typeCertificate, - typeCertificateRequest, - typeCertificateVerify, - typeFinished: - expected = protocol.EncryptionHandshake - case typeNewSessionTicket: - expected = protocol.Encryption1RTT +} + +func (h *cryptoSetup) handleEvent(ev qtls.QUICEvent) (done bool, err error) { + switch ev.Kind { + case qtls.QUICNoEvent: + return true, nil + case qtls.QUICSetReadSecret: + h.SetReadKey(ev.Level, ev.Suite, ev.Data) + return false, nil + case qtls.QUICSetWriteSecret: + h.SetWriteKey(ev.Level, ev.Suite, ev.Data) + return false, nil + case qtls.QUICTransportParameters: + return false, h.handleTransportParameters(ev.Data) + case qtls.QUICTransportParametersRequired: + h.conn.SetTransportParameters(h.ourParams.Marshal(h.perspective)) + return false, nil + case qtls.QUICRejectedEarlyData: + h.rejected0RTT() + return false, nil + case qtls.QUICWriteData: + return false, h.WriteRecord(ev.Level, ev.Data) + case qtls.QUICHandshakeDone: + h.handshakeComplete() + return false, nil default: - return fmt.Errorf("unexpected handshake message: %d", msgType) - } - if encLevel != expected { - return fmt.Errorf("expected handshake message %s to have encryption level %s, has %s", msgType, expected, encLevel) + return false, fmt.Errorf("unexpected event: %d", ev.Kind) } - return nil } -func (h *cryptoSetup) handleTransportParameters(data []byte) { +func (h *cryptoSetup) handleTransportParameters(data []byte) error { var tp wire.TransportParameters if err := tp.Unmarshal(data, h.perspective.Opposite()); err != nil { - h.runner.OnError(&qerr.TransportError{ - ErrorCode: qerr.TransportParameterError, - ErrorMessage: err.Error(), - }) + return err } h.peerParams = &tp h.runner.OnReceivedParams(h.peerParams) + return nil } // must be called after receiving the transport parameters @@ -477,17 +332,32 @@ func (h *cryptoSetup) handleDataFromSessionStateImpl(data []byte) (*wire.Transpo return &tp, nil } -// only valid for the server +func (h *cryptoSetup) getDataForSessionTicket() []byte { + return (&sessionTicket{ + Parameters: h.ourParams, + RTT: h.rttStats.SmoothedRTT(), + }).Marshal() +} + +// GetSessionTicket generates a new session ticket. +// Due to limitations in crypto/tls, it's only possible to generate a single session ticket per connection. +// It is only valid for the server. func (h *cryptoSetup) GetSessionTicket() ([]byte, error) { - var appData []byte - // Save transport parameters to the session ticket if we're allowing 0-RTT. - if h.extraConf.MaxEarlyData > 0 { - appData = (&sessionTicket{ - Parameters: h.ourParams, - RTT: h.rttStats.SmoothedRTT(), - }).Marshal() + if h.tlsConf.SessionTicketsDisabled { + return nil, nil + } + if err := h.conn.SendSessionTicket(h.allow0RTT); err != nil { + return nil, err } - return h.conn.GetSessionTicket(appData) + ev := h.conn.NextEvent() + if ev.Kind != qtls.QUICWriteData || ev.Level != qtls.QUICEncryptionLevelApplication { + panic("crypto/tls bug: where's my session ticket?") + } + ticket := ev.Data + if ev := h.conn.NextEvent(); ev.Kind != qtls.QUICNoEvent { + panic("crypto/tls bug: why more than one ticket?") + } + return ticket, nil } // accept0RTT is called for the server when receiving the client's session ticket. @@ -526,60 +396,12 @@ func (h *cryptoSetup) rejected0RTT() { } } -func (h *cryptoSetup) handlePostHandshakeMessage() { - // make sure the handshake has already completed - <-h.handshakeDone - - done := make(chan struct{}) - defer close(done) - - // h.alertChan is an unbuffered channel. - // If an error occurs during conn.HandlePostHandshakeMessage, - // it will be sent on this channel. - // Read it from a go-routine so that HandlePostHandshakeMessage doesn't deadlock. - alertChan := make(chan uint8, 1) - go func() { - <-h.isReadingHandshakeMessage - select { - case alert := <-h.alertChan: - alertChan <- alert - case <-done: - } - }() - - if err := h.conn.HandlePostHandshakeMessage(); err != nil { - select { - case <-h.closeChan: - case alert := <-alertChan: - h.onError(alert, err.Error()) - } - } -} - -// ReadHandshakeMessage is called by TLS. -// It blocks until a new handshake message is available. -func (h *cryptoSetup) ReadHandshakeMessage() ([]byte, error) { - if !h.readFirstHandshakeMessage { - h.readFirstHandshakeMessage = true - } else { - select { - case h.isReadingHandshakeMessage <- struct{}{}: - case <-h.closeChan: - return nil, errors.New("error while handling the handshake message") - } - } - select { - case msg := <-h.messageChan: - return msg, nil - case <-h.closeChan: - return nil, errors.New("error while handling the handshake message") - } -} - -func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { +func (h *cryptoSetup) SetReadKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { + suite := getCipherSuite(suiteID) h.mutex.Lock() - switch encLevel { - case qtls.Encryption0RTT: + //nolint:exhaustive // The TLS stack doesn't export Initial keys. + switch el { + case qtls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveClient { panic("Received 0-RTT read key for the client") } @@ -587,16 +409,11 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), ) - h.mutex.Unlock() + h.used0RTT.Store(true) if h.logger.Debug() { h.logger.Debugf("Installed 0-RTT Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective.Opposite()) - } - return - case qtls.EncryptionHandshake: - h.readEncLevel = protocol.EncryptionHandshake + case qtls.QUICEncryptionLevelHandshake: h.handshakeOpener = newHandshakeOpener( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -606,8 +423,7 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph if h.logger.Debug() { h.logger.Debugf("Installed Handshake Read keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.EncryptionApplication: - h.readEncLevel = protocol.Encryption1RTT + case qtls.QUICEncryptionLevelApplication: h.aead.SetReadKey(suite, trafficSecret) h.has1RTTOpener = true if h.logger.Debug() { @@ -617,15 +433,18 @@ func (h *cryptoSetup) SetReadKey(encLevel qtls.EncryptionLevel, suite *qtls.Ciph panic("unexpected read encryption level") } h.mutex.Unlock() + h.runner.OnReceivedReadKeys() if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(h.readEncLevel, h.perspective.Opposite()) + h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective.Opposite()) } } -func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { +func (h *cryptoSetup) SetWriteKey(el qtls.QUICEncryptionLevel, suiteID uint16, trafficSecret []byte) { + suite := getCipherSuite(suiteID) h.mutex.Lock() - switch encLevel { - case qtls.Encryption0RTT: + //nolint:exhaustive // The TLS stack doesn't export Initial keys. + switch el { + case qtls.QUICEncryptionLevelEarly: if h.perspective == protocol.PerspectiveServer { panic("Received 0-RTT write key for the server") } @@ -640,9 +459,9 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip if h.tracer != nil { h.tracer.UpdatedKeyFromTLS(protocol.Encryption0RTT, h.perspective) } + // don't set used0RTT here. 0-RTT might still get rejected. return - case qtls.EncryptionHandshake: - h.writeEncLevel = protocol.EncryptionHandshake + case qtls.QUICEncryptionLevelHandshake: h.handshakeSealer = newHandshakeSealer( createAEAD(suite, trafficSecret, h.version), newHeaderProtector(suite, trafficSecret, true, h.version), @@ -652,14 +471,15 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip if h.logger.Debug() { h.logger.Debugf("Installed Handshake Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } - case qtls.EncryptionApplication: - h.writeEncLevel = protocol.Encryption1RTT + case qtls.QUICEncryptionLevelApplication: h.aead.SetWriteKey(suite, trafficSecret) h.has1RTTSealer = true if h.logger.Debug() { h.logger.Debugf("Installed 1-RTT Write keys (using %s)", tls.CipherSuiteName(suite.ID)) } if h.zeroRTTSealer != nil { + // Once we receive handshake keys, we know that 0-RTT was not rejected. + h.used0RTT.Store(true) h.zeroRTTSealer = nil h.logger.Debugf("Dropping 0-RTT keys.") if h.tracer != nil { @@ -671,45 +491,30 @@ func (h *cryptoSetup) SetWriteKey(encLevel qtls.EncryptionLevel, suite *qtls.Cip } h.mutex.Unlock() if h.tracer != nil { - h.tracer.UpdatedKeyFromTLS(h.writeEncLevel, h.perspective) + h.tracer.UpdatedKeyFromTLS(qtls.FromTLSEncryptionLevel(el), h.perspective) } } // WriteRecord is called when TLS writes data -func (h *cryptoSetup) WriteRecord(p []byte) (int, error) { +func (h *cryptoSetup) WriteRecord(encLevel qtls.QUICEncryptionLevel, p []byte) error { h.mutex.Lock() defer h.mutex.Unlock() - //nolint:exhaustive // LS records can only be written for Initial and Handshake. - switch h.writeEncLevel { - case protocol.EncryptionInitial: + var str io.Writer + //nolint:exhaustive // handshake records can only be written for Initial and Handshake. + switch encLevel { + case qtls.QUICEncryptionLevelInitial: // assume that the first WriteRecord call contains the ClientHello - n, err := h.initialStream.Write(p) - if !h.clientHelloWritten && h.perspective == protocol.PerspectiveClient { - h.clientHelloWritten = true - close(h.clientHelloWrittenChan) - if h.zeroRTTSealer != nil && h.zeroRTTParameters != nil { - h.logger.Debugf("Doing 0-RTT.") - h.zeroRTTParametersChan <- h.zeroRTTParameters - } else { - h.logger.Debugf("Not doing 0-RTT.") - h.zeroRTTParametersChan <- nil - } - } - return n, err - case protocol.EncryptionHandshake: - return h.handshakeStream.Write(p) + str = h.initialStream + case qtls.QUICEncryptionLevelHandshake: + str = h.handshakeStream + case qtls.QUICEncryptionLevelApplication: + str = h.oneRTTStream default: - panic(fmt.Sprintf("unexpected write encryption level: %s", h.writeEncLevel)) - } -} - -func (h *cryptoSetup) SendAlert(alert uint8) { - select { - case h.alertChan <- alert: - case <-h.closeChan: - // no need to send an alert when we've already closed + panic(fmt.Sprintf("unexpected write encryption level: %s", encLevel)) } + _, err := str.Write(p) + return err } // used a callback in the handshakeSealer and handshakeOpener @@ -722,6 +527,11 @@ func (h *cryptoSetup) dropInitialKeys() { h.logger.Debugf("Dropping Initial keys.") } +func (h *cryptoSetup) handshakeComplete() { + h.handshakeCompleteTime = time.Now() + h.runner.OnHandshakeComplete() +} + func (h *cryptoSetup) SetHandshakeConfirmed() { h.aead.SetHandshakeConfirmed() // drop Handshake keys @@ -839,5 +649,15 @@ func (h *cryptoSetup) Get1RTTOpener() (ShortHeaderOpener, error) { } func (h *cryptoSetup) ConnectionState() ConnectionState { - return qtls.GetConnectionState(h.conn) + return ConnectionState{ + ConnectionState: h.conn.ConnectionState(), + Used0RTT: h.used0RTT.Load(), + } +} + +func wrapError(err error) error { + if alertErr := qtls.AlertError(0); errors.As(err, &alertErr) && alertErr != 80 { + return qerr.NewLocalCryptoError(uint8(alertErr), err.Error()) + } + return &qerr.TransportError{ErrorCode: qerr.InternalError, ErrorMessage: err.Error()} } diff --git a/internal/handshake/crypto_setup_test.go b/internal/handshake/crypto_setup_test.go index 51d1980a01e..f6d8ae3f220 100644 --- a/internal/handshake/crypto_setup_test.go +++ b/internal/handshake/crypto_setup_test.go @@ -1,7 +1,6 @@ package handshake import ( - "bytes" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -23,12 +22,10 @@ import ( . "github.com/onsi/gomega" ) -var helloRetryRequestRandom = []byte{ // See RFC 8446, Section 4.1.3. - 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, - 0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91, - 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E, - 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C, -} +const ( + typeClientHello = 1 + typeNewSessionTicket = 4 +) type chunk struct { data []byte @@ -80,54 +77,7 @@ var _ = Describe("Crypto Setup TLS", func() { } }) - It("returns Handshake() when an error occurs in qtls", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "local error: tls: unexpected message", - }))) - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - handledMessage := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.HandleMessage(fakeCH, protocol.EncryptionInitial) - close(handledMessage) - }() - Eventually(handledMessage).Should(BeClosed()) - Eventually(done).Should(BeClosed()) - }) - It("handles qtls errors occurring before during ClientHello generation", func() { - sErrChan := make(chan error, 1) - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) _, sInitialStream, sHandshakeStream := initStreams() tlsConf := testdata.GetTLSConfig() tlsConf.InsecureSkipVerify = true @@ -135,11 +85,10 @@ var _ = Describe("Crypto Setup TLS", func() { cl, _ := NewCryptoSetupClient( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{}, - runner, + NewMockHandshakeRunner(mockCtrl), tlsConf, false, &utils.RTTStats{}, @@ -148,76 +97,24 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - cl.RunHandshake() - close(done) - }() - - Eventually(done).Should(BeClosed()) - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ + Expect(cl.StartHandshake()).To(MatchError(&qerr.TransportError{ ErrorCode: qerr.InternalError, ErrorMessage: "tls: invalid NextProtos value", - }))) + })) }) It("errors when a message is received at the wrong encryption level", func() { - sErrChan := make(chan error, 1) _, sInitialStream, sHandshakeStream := initStreams() runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) var token protocol.StatelessResetToken server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - runner, - testdata.GetTLSConfig(), - false, - &utils.RTTStats{}, nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) - - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - - fakeCH := append([]byte{byte(typeClientHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionHandshake) // wrong encryption level - Expect(sErrChan).To(Receive(MatchError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message ClientHello to have encryption level Initial, has Handshake", - }))) - - // make the go routine return - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when handling a message fails", func() { - sErrChan := make(chan error, 1) - _, sInitialStream, sHandshakeStream := initStreams() - runner := NewMockHandshakeRunner(mockCtrl) - runner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }) - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, protocol.ConnectionID{}, - nil, - nil, &wire.TransportParameters{StatelessResetToken: &token}, runner, - serverConf, + testdata.GetTLSConfig(), false, &utils.RTTStats{}, nil, @@ -225,49 +122,13 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - var err error - Expect(sErrChan).To(Receive(&err)) - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode).To(BeEquivalentTo(0x100 + int(alertUnexpectedMessage))) - close(done) - }() - - fakeCH := append([]byte{byte(typeServerHello), 0, 0, 6}, []byte("foobar")...) - server.HandleMessage(fakeCH, protocol.EncryptionInitial) // wrong encryption level - Eventually(done).Should(BeClosed()) - }) - - It("returns Handshake() when it is closed", func() { - _, sInitialStream, sHandshakeStream := initStreams() - var token protocol.StatelessResetToken - server := NewCryptoSetupServer( - sInitialStream, - sHandshakeStream, - protocol.ConnectionID{}, - nil, - nil, - &wire.TransportParameters{StatelessResetToken: &token}, - NewMockHandshakeRunner(mockCtrl), - serverConf, - false, - &utils.RTTStats{}, - nil, - utils.DefaultLogger.WithPrefix("server"), - protocol.Version1, - ) + Expect(server.StartHandshake()).To(Succeed()) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - server.RunHandshake() - close(done) - }() - Expect(server.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) + fakeCH := append([]byte{typeClientHello, 0, 0, 6}, []byte("foobar")...) + // wrong encryption level + err := server.HandleMessage(fakeCH, protocol.EncryptionHandshake) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) }) Context("doing the handshake", func() { @@ -297,55 +158,32 @@ var _ = Describe("Crypto Setup TLS", func() { return rttStats } - handshake := func(client CryptoSetup, cChunkChan <-chan chunk, - server CryptoSetup, sChunkChan <-chan chunk, - ) { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - for { - select { - case c := <-cChunkChan: - msgType := messageType(c.data[0]) - finished := server.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeClientHello { - // If this ClientHello didn't elicit a HelloRetryRequest, we're done with Initial keys. - _, err := server.GetHandshakeOpener() - Expect(finished).To(Equal(err == nil)) - } else { - Expect(finished).To(BeFalse()) - } - case c := <-sChunkChan: - msgType := messageType(c.data[0]) - finished := client.HandleMessage(c.data, c.encLevel) - if msgType == typeFinished { - Expect(finished).To(BeTrue()) - } else if msgType == typeServerHello { - Expect(finished).To(Equal(!bytes.Equal(c.data[6:6+32], helloRetryRequestRandom))) - } else { - Expect(finished).To(BeFalse()) - } - case <-done: // handshake complete - return - } - } - }() + handshake := func(client CryptoSetup, cChunkChan <-chan chunk, server CryptoSetup, sChunkChan <-chan chunk) { + Expect(client.StartHandshake()).To(Succeed()) + Expect(server.StartHandshake()).To(Succeed()) - go func() { - defer GinkgoRecover() - defer close(done) - server.RunHandshake() - ticket, err := server.GetSessionTicket() - Expect(err).ToNot(HaveOccurred()) - if ticket != nil { - client.HandleMessage(ticket, protocol.Encryption1RTT) + for { + select { + case c := <-cChunkChan: + Expect(server.HandleMessage(c.data, c.encLevel)).To(Succeed()) + continue + default: } - }() + select { + case c := <-sChunkChan: + Expect(client.HandleMessage(c.data, c.encLevel)).To(Succeed()) + continue + default: + } + // no more messages to send from client and server. Handshake complete? + break + } - client.RunHandshake() - Eventually(done).Should(BeClosed()) + ticket, err := server.GetSessionTicket() + Expect(err).ToNot(HaveOccurred()) + if ticket != nil { + Expect(client.HandleMessage(ticket, protocol.Encryption1RTT)).To(Succeed()) + } } handshakeWithTLSConf := func( @@ -359,15 +197,14 @@ var _ = Describe("Crypto Setup TLS", func() { cErrChan := make(chan error, 1) cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) - cRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { cErrChan <- e }).MaxTimes(1) + cRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise cRunner.EXPECT().OnHandshakeComplete().Do(func() { cHandshakeComplete = true }).MaxTimes(1) cRunner.EXPECT().DropKeys(gomock.Any()).MaxTimes(1) client, clientHelloWrittenChan := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, clientTransportParameters, cRunner, clientConf, @@ -383,7 +220,7 @@ var _ = Describe("Crypto Setup TLS", func() { sErrChan := make(chan error, 1) sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) - sRunner.EXPECT().OnError(gomock.Any()).Do(func(e error) { sErrChan <- e }).MaxTimes(1) + sRunner.EXPECT().OnReceivedReadKeys().MinTimes(2).MaxTimes(3) // 3 if using 0-RTT, 2 otherwise sRunner.EXPECT().OnHandshakeComplete().Do(func() { sHandshakeComplete = true }).MaxTimes(1) if serverTransportParameters.StatelessResetToken == nil { var token protocol.StatelessResetToken @@ -392,9 +229,8 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, serverTransportParameters, sRunner, serverConf, @@ -462,9 +298,8 @@ var _ = Describe("Crypto Setup TLS", func() { client, chChan := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{}, runner, &tls.Config{InsecureSkipVerify: true}, @@ -475,24 +310,15 @@ var _ = Describe("Crypto Setup TLS", func() { protocol.Version1, ) - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - client.RunHandshake() - close(done) - }() + Expect(client.StartHandshake()).To(Succeed()) var ch chunk Eventually(cChunkChan).Should(Receive(&ch)) Eventually(chChan).Should(Receive(BeNil())) // make sure the whole ClientHello was written Expect(len(ch.data)).To(BeNumerically(">=", 4)) - Expect(messageType(ch.data[0])).To(Equal(typeClientHello)) + Expect(ch.data[0]).To(BeEquivalentTo(typeClientHello)) length := int(ch.data[1])<<16 | int(ch.data[2])<<8 | int(ch.data[3]) Expect(len(ch.data) - 4).To(Equal(length)) - - // make the go routine return - Expect(client.Close()).To(Succeed()) - Eventually(done).Should(BeClosed()) }) It("receives transport parameters", func() { @@ -500,14 +326,14 @@ var _ = Describe("Crypto Setup TLS", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cTransportParameters := &wire.TransportParameters{ActiveConnectionIDLimit: 2, MaxIdleTimeout: 0x42 * time.Second} cRunner := NewMockHandshakeRunner(mockCtrl) + cRunner.EXPECT().OnReceivedReadKeys().Times(2) cRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { sTransportParametersRcvd = tp }) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, cTransportParameters, cRunner, clientConf, @@ -521,6 +347,7 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() var token protocol.StatelessResetToken sRunner := NewMockHandshakeRunner(mockCtrl) + sRunner.EXPECT().OnReceivedReadKeys().Times(2) sRunner.EXPECT().OnReceivedParams(gomock.Any()).Do(func(tp *wire.TransportParameters) { cTransportParametersRcvd = tp }) sRunner.EXPECT().OnHandshakeComplete() sTransportParameters := &wire.TransportParameters{ @@ -531,9 +358,8 @@ var _ = Describe("Crypto Setup TLS", func() { server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, sTransportParameters, sRunner, serverConf, @@ -561,13 +387,13 @@ var _ = Describe("Crypto Setup TLS", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnReceivedReadKeys().Times(2) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, cRunner, clientConf, @@ -581,14 +407,14 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnReceivedReadKeys().Times(2) sRunner.EXPECT().OnHandshakeComplete() var token protocol.StatelessResetToken server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, sRunner, serverConf, @@ -608,25 +434,23 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) // inject an invalid session ticket - cRunner.EXPECT().OnError(&qerr.TransportError{ - ErrorCode: 0x100 + qerr.TransportErrorCode(alertUnexpectedMessage), - ErrorMessage: "expected handshake message NewSessionTicket to have encryption level 1-RTT, has Handshake", - }) b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.EncryptionHandshake) + err := client.HandleMessage(b, protocol.EncryptionHandshake) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("tls: handshake data received at wrong level")) }) It("errors when handling the NewSessionTicket fails", func() { cChunkChan, cInitialStream, cHandshakeStream := initStreams() cRunner := NewMockHandshakeRunner(mockCtrl) cRunner.EXPECT().OnReceivedParams(gomock.Any()) + cRunner.EXPECT().OnReceivedReadKeys().Times(2) cRunner.EXPECT().OnHandshakeComplete() client, _ := NewCryptoSetupClient( cInitialStream, cHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2}, cRunner, clientConf, @@ -640,14 +464,14 @@ var _ = Describe("Crypto Setup TLS", func() { sChunkChan, sInitialStream, sHandshakeStream := initStreams() sRunner := NewMockHandshakeRunner(mockCtrl) sRunner.EXPECT().OnReceivedParams(gomock.Any()) + sRunner.EXPECT().OnReceivedReadKeys().Times(2) sRunner.EXPECT().OnHandshakeComplete() var token protocol.StatelessResetToken server := NewCryptoSetupServer( sInitialStream, sHandshakeStream, - protocol.ConnectionID{}, - nil, nil, + protocol.ConnectionID{}, &wire.TransportParameters{ActiveConnectionIDLimit: 2, StatelessResetToken: &token}, sRunner, serverConf, @@ -667,12 +491,10 @@ var _ = Describe("Crypto Setup TLS", func() { Eventually(done).Should(BeClosed()) // inject an invalid session ticket - cRunner.EXPECT().OnError(gomock.Any()).Do(func(err error) { - Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) - Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) - }) b := append([]byte{uint8(typeNewSessionTicket), 0, 0, 6}, []byte("foobar")...) - client.HandleMessage(b, protocol.Encryption1RTT) + err := client.HandleMessage(b, protocol.Encryption1RTT) + Expect(err).To(BeAssignableToTypeOf(&qerr.TransportError{})) + Expect(err.(*qerr.TransportError).ErrorCode.IsCryptoError()).To(BeTrue()) }) It("uses session resumption", func() { @@ -785,7 +607,6 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} @@ -840,7 +661,6 @@ var _ = Describe("Crypto Setup TLS", func() { Expect(clientHelloWrittenChan).To(Receive(BeNil())) csc.EXPECT().Get(gomock.Any()).Return(state, true) - csc.EXPECT().Put(gomock.Any(), nil) csc.EXPECT().Put(gomock.Any(), gomock.Any()).MaxTimes(1) clientRTTStats := &utils.RTTStats{} diff --git a/internal/handshake/handshake_suite_test.go b/internal/handshake/handshake_suite_test.go index a68f2b9589a..3289928ea07 100644 --- a/internal/handshake/handshake_suite_test.go +++ b/internal/handshake/handshake_suite_test.go @@ -6,8 +6,6 @@ import ( "strings" "testing" - "github.com/quic-go/quic-go/internal/qtls" - "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" @@ -41,8 +39,8 @@ func splitHexString(s string) (slice []byte) { return } -var cipherSuites = []*qtls.CipherSuiteTLS13{ - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_128_GCM_SHA256), - qtls.CipherSuiteTLS13ByID(tls.TLS_AES_256_GCM_SHA384), - qtls.CipherSuiteTLS13ByID(tls.TLS_CHACHA20_POLY1305_SHA256), +var cipherSuites = []*cipherSuite{ + getCipherSuite(tls.TLS_AES_128_GCM_SHA256), + getCipherSuite(tls.TLS_AES_256_GCM_SHA384), + getCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256), } diff --git a/internal/handshake/header_protector.go b/internal/handshake/header_protector.go index 274fb30cbd4..fb6092e040a 100644 --- a/internal/handshake/header_protector.go +++ b/internal/handshake/header_protector.go @@ -10,7 +10,6 @@ import ( "golang.org/x/crypto/chacha20" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" ) type headerProtector interface { @@ -25,7 +24,7 @@ func hkdfHeaderProtectionLabel(v protocol.VersionNumber) string { return "quic hp" } -func newHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { +func newHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, v protocol.VersionNumber) headerProtector { hkdfLabel := hkdfHeaderProtectionLabel(v) switch suite.ID { case tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384: @@ -45,7 +44,7 @@ type aesHeaderProtector struct { var _ headerProtector = &aesHeaderProtector{} -func newAESHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { +func newAESHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) block, err := aes.NewCipher(hpKey) if err != nil { @@ -90,7 +89,7 @@ type chachaHeaderProtector struct { var _ headerProtector = &chachaHeaderProtector{} -func newChaChaHeaderProtector(suite *qtls.CipherSuiteTLS13, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { +func newChaChaHeaderProtector(suite *cipherSuite, trafficSecret []byte, isLongHeader bool, hkdfLabel string) headerProtector { hpKey := hkdfExpandLabel(suite.Hash, trafficSecret, []byte{}, hkdfLabel, suite.KeyLen) p := &chachaHeaderProtector{ diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index ea39e7fd58a..b0377c39a81 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -7,7 +7,6 @@ import ( "golang.org/x/crypto/hkdf" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" ) var ( @@ -29,12 +28,7 @@ func getSalt(v protocol.VersionNumber) []byte { return quicSaltV1 } -var initialSuite = &qtls.CipherSuiteTLS13{ - ID: tls.TLS_AES_128_GCM_SHA256, - KeyLen: 16, - AEAD: qtls.AEADAESGCMTLS13, - Hash: crypto.SHA256, -} +var initialSuite = getCipherSuite(tls.TLS_AES_128_GCM_SHA256) // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v protocol.VersionNumber) (LongHeaderSealer, LongHeaderOpener) { @@ -50,8 +44,8 @@ func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective, v p myKey, myIV := computeInitialKeyAndIV(mySecret, v) otherKey, otherIV := computeInitialKeyAndIV(otherSecret, v) - encrypter := qtls.AEADAESGCMTLS13(myKey, myIV) - decrypter := qtls.AEADAESGCMTLS13(otherKey, otherIV) + encrypter := initialSuite.AEAD(myKey, myIV) + decrypter := initialSuite.AEAD(otherKey, otherIV) return newLongHeaderSealer(encrypter, newHeaderProtector(initialSuite, mySecret, true, v)), newLongHeaderOpener(decrypter, newAESHeaderProtector(initialSuite, otherSecret, true, hkdfHeaderProtectionLabel(v))) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index f80b6e0e36f..ab242953e9d 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -1,12 +1,12 @@ package handshake import ( + "crypto/tls" "errors" "io" "time" "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/internal/wire" ) @@ -22,9 +22,6 @@ var ( ErrDecryptionFailed = errors.New("decryption failed") ) -// ConnectionState contains information about the state of the connection. -type ConnectionState = qtls.ConnectionState - type headerDecryptor interface { DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } @@ -56,28 +53,26 @@ type ShortHeaderSealer interface { KeyPhase() protocol.KeyPhaseBit } -// A tlsExtensionHandler sends and received the QUIC TLS extension. -type tlsExtensionHandler interface { - GetExtensions(msgType uint8) []qtls.Extension - ReceivedExtensions(msgType uint8, exts []qtls.Extension) - TransportParameters() <-chan []byte -} - type handshakeRunner interface { OnReceivedParams(*wire.TransportParameters) OnHandshakeComplete() - OnError(error) + OnReceivedReadKeys() DropKeys(protocol.EncryptionLevel) } +type ConnectionState struct { + tls.ConnectionState + Used0RTT bool +} + // CryptoSetup handles the handshake and protecting / unprotecting packets type CryptoSetup interface { - RunHandshake() + StartHandshake() error io.Closer ChangeConnectionID(protocol.ConnectionID) GetSessionTicket() ([]byte, error) - HandleMessage([]byte, protocol.EncryptionLevel) bool + HandleMessage([]byte, protocol.EncryptionLevel) error SetLargest1RTTAcked(protocol.PacketNumber) error SetHandshakeConfirmed() ConnectionState() ConnectionState diff --git a/internal/handshake/mock_handshake_runner_test.go b/internal/handshake/mock_handshake_runner_test.go index fa8decbe498..9d3cfef52ae 100644 --- a/internal/handshake/mock_handshake_runner_test.go +++ b/internal/handshake/mock_handshake_runner_test.go @@ -47,18 +47,6 @@ func (mr *MockHandshakeRunnerMockRecorder) DropKeys(arg0 interface{}) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).DropKeys), arg0) } -// OnError mocks base method. -func (m *MockHandshakeRunner) OnError(arg0 error) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "OnError", arg0) -} - -// OnError indicates an expected call of OnError. -func (mr *MockHandshakeRunnerMockRecorder) OnError(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnError", reflect.TypeOf((*MockHandshakeRunner)(nil).OnError), arg0) -} - // OnHandshakeComplete mocks base method. func (m *MockHandshakeRunner) OnHandshakeComplete() { m.ctrl.T.Helper() @@ -82,3 +70,15 @@ func (mr *MockHandshakeRunnerMockRecorder) OnReceivedParams(arg0 interface{}) *g mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedParams", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedParams), arg0) } + +// OnReceivedReadKeys mocks base method. +func (m *MockHandshakeRunner) OnReceivedReadKeys() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "OnReceivedReadKeys") +} + +// OnReceivedReadKeys indicates an expected call of OnReceivedReadKeys. +func (mr *MockHandshakeRunnerMockRecorder) OnReceivedReadKeys() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnReceivedReadKeys", reflect.TypeOf((*MockHandshakeRunner)(nil).OnReceivedReadKeys)) +} diff --git a/internal/handshake/session_ticket.go b/internal/handshake/session_ticket.go index 56bcbcd5d06..d3efeb2941c 100644 --- a/internal/handshake/session_ticket.go +++ b/internal/handshake/session_ticket.go @@ -10,7 +10,7 @@ import ( "github.com/quic-go/quic-go/quicvarint" ) -const sessionTicketRevision = 2 +const sessionTicketRevision = 3 type sessionTicket struct { Parameters *wire.TransportParameters diff --git a/internal/handshake/tls_extension_handler.go b/internal/handshake/tls_extension_handler.go deleted file mode 100644 index e46a930c73a..00000000000 --- a/internal/handshake/tls_extension_handler.go +++ /dev/null @@ -1,61 +0,0 @@ -package handshake - -import ( - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" -) - -const quicTLSExtensionType = 0x39 - -type extensionHandler struct { - ourParams []byte - paramsChan chan []byte - - extensionType uint16 - - perspective protocol.Perspective -} - -var _ tlsExtensionHandler = &extensionHandler{} - -// newExtensionHandler creates a new extension handler -func newExtensionHandler(params []byte, pers protocol.Perspective) tlsExtensionHandler { - return &extensionHandler{ - ourParams: params, - paramsChan: make(chan []byte), - perspective: pers, - extensionType: quicTLSExtensionType, - } -} - -func (h *extensionHandler) GetExtensions(msgType uint8) []qtls.Extension { - if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeClientHello) || - (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeEncryptedExtensions) { - return nil - } - return []qtls.Extension{{ - Type: h.extensionType, - Data: h.ourParams, - }} -} - -func (h *extensionHandler) ReceivedExtensions(msgType uint8, exts []qtls.Extension) { - if (h.perspective == protocol.PerspectiveClient && messageType(msgType) != typeEncryptedExtensions) || - (h.perspective == protocol.PerspectiveServer && messageType(msgType) != typeClientHello) { - return - } - - var data []byte - for _, ext := range exts { - if ext.Type == h.extensionType { - data = ext.Data - break - } - } - - h.paramsChan <- data -} - -func (h *extensionHandler) TransportParameters() <-chan []byte { - return h.paramsChan -} diff --git a/internal/handshake/tls_extension_handler_test.go b/internal/handshake/tls_extension_handler_test.go deleted file mode 100644 index 4e557c9628e..00000000000 --- a/internal/handshake/tls_extension_handler_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package handshake - -import ( - "github.com/quic-go/quic-go/internal/protocol" - "github.com/quic-go/quic-go/internal/qtls" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("TLS Extension Handler, for the server", func() { - var ( - handlerServer tlsExtensionHandler - handlerClient tlsExtensionHandler - ) - - JustBeforeEach(func() { - handlerServer = newExtensionHandler([]byte("foobar"), protocol.PerspectiveServer) - handlerClient = newExtensionHandler([]byte("raboof"), protocol.PerspectiveClient) - }) - - Context("for the server", func() { - Context("sending", func() { - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerServer.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerServer.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) - - It("adds TransportParameters to the EncryptedExtensions message", func() { - exts := handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType)) - Expect(exts[0].Data).To(Equal([]byte("foobar"))) - }) - }) - - Context("receiving", func() { - var chExts []qtls.Extension - - JustBeforeEach(func() { - chExts = handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(chExts).To(HaveLen(1)) - }) - - It("sends the extension on the channel", func() { - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeClientHello), chExts) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) - Expect(data).To(Equal([]byte("raboof"))) - }) - - It("sends nil on the channel if the extension is missing", func() { - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeClientHello), nil) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive(&data)) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions with different code points", func() { - go func() { - defer GinkgoRecover() - exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} - handlerServer.ReceivedExtensions(uint8(typeClientHello), exts) - }() - - var data []byte - Eventually(handlerServer.TransportParameters()).Should(Receive()) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions that are not sent with the ClientHello", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handlerServer.ReceivedExtensions(uint8(typeFinished), chExts) - close(done) - }() - - Consistently(handlerServer.TransportParameters()).ShouldNot(Receive()) - Eventually(done).Should(BeClosed()) - }) - }) - }) - - Context("for the client", func() { - Context("sending", func() { - It("only adds TransportParameters for the Encrypted Extensions", func() { - // test 2 other handshake types - Expect(handlerClient.GetExtensions(uint8(typeCertificate))).To(BeEmpty()) - Expect(handlerClient.GetExtensions(uint8(typeFinished))).To(BeEmpty()) - }) - - It("adds TransportParameters to the ClientHello message", func() { - exts := handlerClient.GetExtensions(uint8(typeClientHello)) - Expect(exts).To(HaveLen(1)) - Expect(exts[0].Type).To(BeEquivalentTo(quicTLSExtensionType)) - Expect(exts[0].Data).To(Equal([]byte("raboof"))) - }) - }) - - Context("receiving", func() { - var chExts []qtls.Extension - - JustBeforeEach(func() { - chExts = handlerServer.GetExtensions(uint8(typeEncryptedExtensions)) - Expect(chExts).To(HaveLen(1)) - }) - - It("sends the extension on the channel", func() { - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), chExts) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) - Expect(data).To(Equal([]byte("foobar"))) - }) - - It("sends nil on the channel if the extension is missing", func() { - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), nil) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive(&data)) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions with different code points", func() { - go func() { - defer GinkgoRecover() - exts := []qtls.Extension{{Type: 0x1337, Data: []byte("invalid")}} - handlerClient.ReceivedExtensions(uint8(typeEncryptedExtensions), exts) - }() - - var data []byte - Eventually(handlerClient.TransportParameters()).Should(Receive()) - Expect(data).To(BeEmpty()) - }) - - It("ignores extensions that are not sent with the EncryptedExtensions", func() { - done := make(chan struct{}) - go func() { - defer GinkgoRecover() - handlerClient.ReceivedExtensions(uint8(typeFinished), chExts) - close(done) - }() - - Consistently(handlerClient.TransportParameters()).ShouldNot(Receive()) - Eventually(done).Should(BeClosed()) - }) - }) - }) -}) diff --git a/internal/handshake/updatable_aead.go b/internal/handshake/updatable_aead.go index ac01acdb1c1..919b8a5bcf0 100644 --- a/internal/handshake/updatable_aead.go +++ b/internal/handshake/updatable_aead.go @@ -10,7 +10,6 @@ import ( "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/qerr" - "github.com/quic-go/quic-go/internal/qtls" "github.com/quic-go/quic-go/internal/utils" "github.com/quic-go/quic-go/logging" ) @@ -24,7 +23,7 @@ var KeyUpdateInterval uint64 = protocol.KeyUpdateInterval var FirstKeyUpdateInterval uint64 = 100 type updatableAEAD struct { - suite *qtls.CipherSuiteTLS13 + suite *cipherSuite keyPhase protocol.KeyPhase largestAcked protocol.PacketNumber @@ -121,7 +120,7 @@ func (a *updatableAEAD) getNextTrafficSecret(hash crypto.Hash, ts []byte) []byte // SetReadKey sets the read key. // For the client, this function is called before SetWriteKey. // For the server, this function is called after SetWriteKey. -func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { +func (a *updatableAEAD) SetReadKey(suite *cipherSuite, trafficSecret []byte) { a.rcvAEAD = createAEAD(suite, trafficSecret, a.version) a.headerDecrypter = newHeaderProtector(suite, trafficSecret, false, a.version) if a.suite == nil { @@ -135,7 +134,7 @@ func (a *updatableAEAD) SetReadKey(suite *qtls.CipherSuiteTLS13, trafficSecret [ // SetWriteKey sets the write key. // For the client, this function is called after SetReadKey. // For the server, this function is called before SetWriteKey. -func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret []byte) { +func (a *updatableAEAD) SetWriteKey(suite *cipherSuite, trafficSecret []byte) { a.sendAEAD = createAEAD(suite, trafficSecret, a.version) a.headerEncrypter = newHeaderProtector(suite, trafficSecret, false, a.version) if a.suite == nil { @@ -146,7 +145,7 @@ func (a *updatableAEAD) SetWriteKey(suite *qtls.CipherSuiteTLS13, trafficSecret a.nextSendAEAD = createAEAD(suite, a.nextSendTrafficSecret, a.version) } -func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *qtls.CipherSuiteTLS13) { +func (a *updatableAEAD) setAEADParameters(aead cipher.AEAD, suite *cipherSuite) { a.nonceBuf = make([]byte, aead.NonceSize()) a.aeadOverhead = aead.Overhead() a.suite = suite diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index 43830825af1..0c5d528fb8e 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -10,7 +10,6 @@ import ( gomock "github.com/golang/mock/gomock" handshake "github.com/quic-go/quic-go/internal/handshake" protocol "github.com/quic-go/quic-go/internal/protocol" - qtls "github.com/quic-go/quic-go/internal/qtls" ) // MockCryptoSetup is a mock of CryptoSetup interface. @@ -63,10 +62,10 @@ func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call { } // ConnectionState mocks base method. -func (m *MockCryptoSetup) ConnectionState() qtls.ConnectionState { +func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionState") - ret0, _ := ret[0].(qtls.ConnectionState) + ret0, _ := ret[0].(handshake.ConnectionState) return ret0 } @@ -212,10 +211,10 @@ func (mr *MockCryptoSetupMockRecorder) GetSessionTicket() *gomock.Call { } // HandleMessage mocks base method. -func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { +func (m *MockCryptoSetup) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) - ret0, _ := ret[0].(bool) + ret0, _ := ret[0].(error) return ret0 } @@ -225,18 +224,6 @@ func (mr *MockCryptoSetupMockRecorder) HandleMessage(arg0, arg1 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockCryptoSetup)(nil).HandleMessage), arg0, arg1) } -// RunHandshake mocks base method. -func (m *MockCryptoSetup) RunHandshake() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "RunHandshake") -} - -// RunHandshake indicates an expected call of RunHandshake. -func (mr *MockCryptoSetupMockRecorder) RunHandshake() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).RunHandshake)) -} - // SetHandshakeConfirmed mocks base method. func (m *MockCryptoSetup) SetHandshakeConfirmed() { m.ctrl.T.Helper() @@ -262,3 +249,17 @@ func (mr *MockCryptoSetupMockRecorder) SetLargest1RTTAcked(arg0 interface{}) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLargest1RTTAcked", reflect.TypeOf((*MockCryptoSetup)(nil).SetLargest1RTTAcked), arg0) } + +// StartHandshake mocks base method. +func (m *MockCryptoSetup) StartHandshake() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartHandshake") + ret0, _ := ret[0].(error) + return ret0 +} + +// StartHandshake indicates an expected call of StartHandshake. +func (mr *MockCryptoSetupMockRecorder) StartHandshake() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartHandshake", reflect.TypeOf((*MockCryptoSetup)(nil).StartHandshake)) +} diff --git a/internal/qerr/error_codes.go b/internal/qerr/error_codes.go index cc846df6a77..a037acd22e6 100644 --- a/internal/qerr/error_codes.go +++ b/internal/qerr/error_codes.go @@ -40,7 +40,7 @@ func (e TransportErrorCode) Message() string { if !e.IsCryptoError() { return "" } - return qtls.Alert(e - 0x100).Error() + return qtls.AlertError(e - 0x100).Error() } func (e TransportErrorCode) String() string { diff --git a/internal/qtls/cipher_suite_go121.go b/internal/qtls/cipher_suite_go121.go new file mode 100644 index 00000000000..aa8c768fd25 --- /dev/null +++ b/internal/qtls/cipher_suite_go121.go @@ -0,0 +1,66 @@ +//go:build go1.21 + +package qtls + +import ( + "crypto" + "crypto/cipher" + "crypto/tls" + "fmt" + "unsafe" +) + +type cipherSuiteTLS13 struct { + ID uint16 + KeyLen int + AEAD func(key, fixedNonce []byte) cipher.AEAD + Hash crypto.Hash +} + +//go:linkname cipherSuiteTLS13ByID crypto/tls.cipherSuiteTLS13ByID +func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 + +//go:linkname cipherSuitesTLS13 crypto/tls.cipherSuitesTLS13 +var cipherSuitesTLS13 []unsafe.Pointer + +//go:linkname defaultCipherSuitesTLS13 crypto/tls.defaultCipherSuitesTLS13 +var defaultCipherSuitesTLS13 []uint16 + +//go:linkname defaultCipherSuitesTLS13NoAES crypto/tls.defaultCipherSuitesTLS13NoAES +var defaultCipherSuitesTLS13NoAES []uint16 + +var cipherSuitesModified bool + +// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls +// such that it only contains the cipher suite with the chosen id. +// The reset function returned resets them back to the original value. +func SetCipherSuite(id uint16) (reset func()) { + if cipherSuitesModified { + panic("cipher suites modified multiple times without resetting") + } + cipherSuitesModified = true + + origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) + origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) + origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) + // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. + switch id { + case tls.TLS_AES_128_GCM_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[:1] + case tls.TLS_CHACHA20_POLY1305_SHA256: + cipherSuitesTLS13 = cipherSuitesTLS13[1:2] + case tls.TLS_AES_256_GCM_SHA384: + cipherSuitesTLS13 = cipherSuitesTLS13[2:] + default: + panic(fmt.Sprintf("unexpected cipher suite: %d", id)) + } + defaultCipherSuitesTLS13 = []uint16{id} + defaultCipherSuitesTLS13NoAES = []uint16{id} + + return func() { + cipherSuitesTLS13 = origCipherSuitesTLS13 + defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 + defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES + cipherSuitesModified = false + } +} diff --git a/internal/qtls/client_session_cache.go b/internal/qtls/client_session_cache.go new file mode 100644 index 00000000000..e4dc06eecce --- /dev/null +++ b/internal/qtls/client_session_cache.go @@ -0,0 +1,63 @@ +//go:build go1.21 + +package qtls + +import ( + "crypto/tls" +) + +type clientSessionCache struct { + getData func() []byte + setData func([]byte) + wrapped tls.ClientSessionCache +} + +var _ tls.ClientSessionCache = &clientSessionCache{} + +func (c clientSessionCache) Put(key string, cs *tls.ClientSessionState) { + if cs == nil { + c.wrapped.Put(key, nil) + return + } + ticket, state, err := cs.ResumptionState() + if err != nil || !state.EarlyData { + c.wrapped.Put(key, cs) + return + } + state.Extra = append(state.Extra, addExtraPrefix(c.getData())) + newCS, err := tls.NewResumptionState(ticket, state) + if err != nil { + // It's not clear why this would error. Just save the original state. + c.wrapped.Put(key, cs) + return + } + c.wrapped.Put(key, newCS) +} + +func (c clientSessionCache) Get(key string) (*tls.ClientSessionState, bool) { + cs, ok := c.wrapped.Get(key) + if !ok || cs == nil { + return cs, ok + } + ticket, state, err := cs.ResumptionState() + if err != nil { + // It's not clear why this would error. + // Remove the ticket from the session cache, so we don't run into this error over and over again + c.wrapped.Put(key, nil) + return nil, false + } + if state.EarlyData { + // restore QUIC transport parameters stored in state.Extra + if extra := findExtraData(state.Extra); extra != nil { + c.setData(extra) + } + } + session, err := tls.NewResumptionState(ticket, state) + if err != nil { + // It's not clear why this would error. + // Remove the ticket from the session cache, so we don't run into this error over and over again + c.wrapped.Put(key, nil) + return nil, false + } + return session, true +} diff --git a/internal/qtls/go119.go b/internal/qtls/go119.go deleted file mode 100644 index f040b859c6e..00000000000 --- a/internal/qtls/go119.go +++ /dev/null @@ -1,145 +0,0 @@ -//go:build go1.19 && !go1.20 - -package qtls - -import ( - "crypto" - "crypto/cipher" - "crypto/tls" - "fmt" - "net" - "unsafe" - - "github.com/quic-go/qtls-go1-19" -) - -type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains information about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer -) - -const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication -) - -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} - -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) -} - -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) -} - -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() -} - -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState -} - -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash -} - -//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-19.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, - } -} - -//go:linkname cipherSuitesTLS13 github.com/quic-go/qtls-go1-19.cipherSuitesTLS13 -var cipherSuitesTLS13 []unsafe.Pointer - -//go:linkname defaultCipherSuitesTLS13 github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13 -var defaultCipherSuitesTLS13 []uint16 - -//go:linkname defaultCipherSuitesTLS13NoAES github.com/quic-go/qtls-go1-19.defaultCipherSuitesTLS13NoAES -var defaultCipherSuitesTLS13NoAES []uint16 - -var cipherSuitesModified bool - -// SetCipherSuite modifies the cipherSuiteTLS13 slice of cipher suites inside qtls -// such that it only contains the cipher suite with the chosen id. -// The reset function returned resets them back to the original value. -func SetCipherSuite(id uint16) (reset func()) { - if cipherSuitesModified { - panic("cipher suites modified multiple times without resetting") - } - cipherSuitesModified = true - - origCipherSuitesTLS13 := append([]unsafe.Pointer{}, cipherSuitesTLS13...) - origDefaultCipherSuitesTLS13 := append([]uint16{}, defaultCipherSuitesTLS13...) - origDefaultCipherSuitesTLS13NoAES := append([]uint16{}, defaultCipherSuitesTLS13NoAES...) - // The order is given by the order of the slice elements in cipherSuitesTLS13 in qtls. - switch id { - case tls.TLS_AES_128_GCM_SHA256: - cipherSuitesTLS13 = cipherSuitesTLS13[:1] - case tls.TLS_CHACHA20_POLY1305_SHA256: - cipherSuitesTLS13 = cipherSuitesTLS13[1:2] - case tls.TLS_AES_256_GCM_SHA384: - cipherSuitesTLS13 = cipherSuitesTLS13[2:] - default: - panic(fmt.Sprintf("unexpected cipher suite: %d", id)) - } - defaultCipherSuitesTLS13 = []uint16{id} - defaultCipherSuitesTLS13NoAES = []uint16{id} - - return func() { - cipherSuitesTLS13 = origCipherSuitesTLS13 - defaultCipherSuitesTLS13 = origDefaultCipherSuitesTLS13 - defaultCipherSuitesTLS13NoAES = origDefaultCipherSuitesTLS13NoAES - cipherSuitesModified = false - } -} diff --git a/internal/qtls/go120.go b/internal/qtls/go120.go index a40146ab4fa..b92e1ceb577 100644 --- a/internal/qtls/go120.go +++ b/internal/qtls/go120.go @@ -1,101 +1,94 @@ -//go:build go1.20 +//go:build go1.20 && !go1.21 package qtls import ( - "crypto" - "crypto/cipher" "crypto/tls" "fmt" - "net" "unsafe" + "github.com/quic-go/quic-go/internal/protocol" + "github.com/quic-go/qtls-go1-20" ) type ( - // Alert is a TLS alert - Alert = qtls.Alert - // A Certificate is qtls.Certificate. - Certificate = qtls.Certificate - // CertificateRequestInfo contains information about a certificate request. - CertificateRequestInfo = qtls.CertificateRequestInfo - // A CipherSuiteTLS13 is a cipher suite for TLS 1.3 - CipherSuiteTLS13 = qtls.CipherSuiteTLS13 - // ClientHelloInfo contains information about a ClientHello. - ClientHelloInfo = qtls.ClientHelloInfo - // ClientSessionCache is a cache used for session resumption. - ClientSessionCache = qtls.ClientSessionCache - // ClientSessionState is a state needed for session resumption. - ClientSessionState = qtls.ClientSessionState - // A Config is a qtls.Config. - Config = qtls.Config - // A Conn is a qtls.Conn. - Conn = qtls.Conn - // ConnectionState contains information about the state of the connection. - ConnectionState = qtls.ConnectionStateWith0RTT - // EncryptionLevel is the encryption level of a message. - EncryptionLevel = qtls.EncryptionLevel - // Extension is a TLS extension - Extension = qtls.Extension - // ExtraConfig is the qtls.ExtraConfig - ExtraConfig = qtls.ExtraConfig - // RecordLayer is a qtls RecordLayer. - RecordLayer = qtls.RecordLayer + QUICConn = qtls.QUICConn + QUICConfig = qtls.QUICConfig + QUICEvent = qtls.QUICEvent + QUICEventKind = qtls.QUICEventKind + QUICEncryptionLevel = qtls.QUICEncryptionLevel + AlertError = qtls.AlertError ) const ( - // EncryptionHandshake is the Handshake encryption level - EncryptionHandshake = qtls.EncryptionHandshake - // Encryption0RTT is the 0-RTT encryption level - Encryption0RTT = qtls.Encryption0RTT - // EncryptionApplication is the application data encryption level - EncryptionApplication = qtls.EncryptionApplication + QUICEncryptionLevelInitial = qtls.QUICEncryptionLevelInitial + QUICEncryptionLevelEarly = qtls.QUICEncryptionLevelEarly + QUICEncryptionLevelHandshake = qtls.QUICEncryptionLevelHandshake + QUICEncryptionLevelApplication = qtls.QUICEncryptionLevelApplication ) -// AEADAESGCMTLS13 creates a new AES-GCM AEAD for TLS 1.3 -func AEADAESGCMTLS13(key, fixedNonce []byte) cipher.AEAD { - return qtls.AEADAESGCMTLS13(key, fixedNonce) -} +const ( + QUICNoEvent = qtls.QUICNoEvent + QUICSetReadSecret = qtls.QUICSetReadSecret + QUICSetWriteSecret = qtls.QUICSetWriteSecret + QUICWriteData = qtls.QUICWriteData + QUICTransportParameters = qtls.QUICTransportParameters + QUICTransportParametersRequired = qtls.QUICTransportParametersRequired + QUICRejectedEarlyData = qtls.QUICRejectedEarlyData + QUICHandshakeDone = qtls.QUICHandshakeDone +) -// Client returns a new TLS client side connection. -func Client(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Client(conn, config, extraConfig) +func SetupConfigForServer(conf *QUICConfig, enable0RTT bool, getDataForSessionTicket func() []byte, accept0RTT func([]byte) bool) { + conf.ExtraConfig = &qtls.ExtraConfig{ + Enable0RTT: enable0RTT, + Accept0RTT: accept0RTT, + GetAppDataForSessionTicket: getDataForSessionTicket, + } } -// Server returns a new TLS server side connection. -func Server(conn net.Conn, config *Config, extraConfig *ExtraConfig) *Conn { - return qtls.Server(conn, config, extraConfig) +func SetupConfigForClient(conf *QUICConfig, getDataForSessionState func() []byte, setDataFromSessionState func([]byte)) { + conf.ExtraConfig = &qtls.ExtraConfig{ + GetAppDataForSessionState: getDataForSessionState, + SetAppDataFromSessionState: setDataFromSessionState, + } } -func GetConnectionState(conn *Conn) ConnectionState { - return conn.ConnectionStateWith0RTT() +func QUICServer(config *QUICConfig) *QUICConn { + return qtls.QUICServer(config) } -// ToTLSConnectionState extracts the tls.ConnectionState -func ToTLSConnectionState(cs ConnectionState) tls.ConnectionState { - return cs.ConnectionState +func QUICClient(config *QUICConfig) *QUICConn { + return qtls.QUICClient(config) } -type cipherSuiteTLS13 struct { - ID uint16 - KeyLen int - AEAD func(key, fixedNonce []byte) cipher.AEAD - Hash crypto.Hash +func ToTLSEncryptionLevel(e protocol.EncryptionLevel) qtls.QUICEncryptionLevel { + switch e { + case protocol.EncryptionInitial: + return qtls.QUICEncryptionLevelInitial + case protocol.EncryptionHandshake: + return qtls.QUICEncryptionLevelHandshake + case protocol.Encryption1RTT: + return qtls.QUICEncryptionLevelApplication + case protocol.Encryption0RTT: + return qtls.QUICEncryptionLevelEarly + default: + panic(fmt.Sprintf("unexpected encryption level: %s", e)) + } } -//go:linkname cipherSuiteTLS13ByID github.com/quic-go/qtls-go1-20.cipherSuiteTLS13ByID -func cipherSuiteTLS13ByID(id uint16) *cipherSuiteTLS13 - -// CipherSuiteTLS13ByID gets a TLS 1.3 cipher suite. -func CipherSuiteTLS13ByID(id uint16) *CipherSuiteTLS13 { - val := cipherSuiteTLS13ByID(id) - cs := (*cipherSuiteTLS13)(unsafe.Pointer(val)) - return &qtls.CipherSuiteTLS13{ - ID: cs.ID, - KeyLen: cs.KeyLen, - AEAD: cs.AEAD, - Hash: cs.Hash, +func FromTLSEncryptionLevel(e qtls.QUICEncryptionLevel) protocol.EncryptionLevel { + switch e { + case qtls.QUICEncryptionLevelInitial: + return protocol.EncryptionInitial + case qtls.QUICEncryptionLevelHandshake: + return protocol.EncryptionHandshake + case qtls.QUICEncryptionLevelApplication: + return protocol.Encryption1RTT + case qtls.QUICEncryptionLevelEarly: + return protocol.Encryption0RTT + default: + panic(fmt.Sprintf("unexpect encryption level: %s", e)) } } diff --git a/internal/qtls/go121.go b/internal/qtls/go121.go index b33406397b6..780c5267413 100644 --- a/internal/qtls/go121.go +++ b/internal/qtls/go121.go @@ -2,4 +2,147 @@ package qtls -var _ int = "The version of quic-go you're using can't be built on Go 1.21 yet. For more details, please see https://github.com/quic-go/quic-go/wiki/quic-go-and-Go-versions." +import ( + "bytes" + "crypto/tls" + "fmt" + + "github.com/quic-go/quic-go/internal/protocol" +) + +type ( + QUICConn = tls.QUICConn + QUICConfig = tls.QUICConfig + QUICEvent = tls.QUICEvent + QUICEventKind = tls.QUICEventKind + QUICEncryptionLevel = tls.QUICEncryptionLevel + AlertError = tls.AlertError +) + +const ( + QUICEncryptionLevelInitial = tls.QUICEncryptionLevelInitial + QUICEncryptionLevelEarly = tls.QUICEncryptionLevelEarly + QUICEncryptionLevelHandshake = tls.QUICEncryptionLevelHandshake + QUICEncryptionLevelApplication = tls.QUICEncryptionLevelApplication +) + +const ( + QUICNoEvent = tls.QUICNoEvent + QUICSetReadSecret = tls.QUICSetReadSecret + QUICSetWriteSecret = tls.QUICSetWriteSecret + QUICWriteData = tls.QUICWriteData + QUICTransportParameters = tls.QUICTransportParameters + QUICTransportParametersRequired = tls.QUICTransportParametersRequired + QUICRejectedEarlyData = tls.QUICRejectedEarlyData + QUICHandshakeDone = tls.QUICHandshakeDone +) + +func QUICServer(config *QUICConfig) *QUICConn { return tls.QUICServer(config) } +func QUICClient(config *QUICConfig) *QUICConn { return tls.QUICClient(config) } + +func SetupConfigForServer(qconf *QUICConfig, _ bool, getData func() []byte, accept0RTT func([]byte) bool) { + conf := qconf.TLSConfig + conf = conf.Clone() + qconf.TLSConfig = conf + + // add callbacks to save transport parameters into the session ticket + origWrapSession := conf.WrapSession + conf.WrapSession = func(cs tls.ConnectionState, state *tls.SessionState) ([]byte, error) { + // Add QUIC transport parameters if this is a 0-RTT packet. + // TODO(#3853): also save the RTT for non-0-RTT tickets + if state.EarlyData { + state.Extra = append(state.Extra, addExtraPrefix(getData())) + } + if origWrapSession != nil { + return origWrapSession(cs, state) + } + b, err := conf.EncryptTicket(cs, state) + return b, err + } + origUnwrapSession := conf.UnwrapSession + // UnwrapSession might be called multiple times, as the client can use multiple session tickets. + // However, using 0-RTT is only possible with the first session ticket. + // crypto/tls guarantees that this callback is called in the same order as the session ticket in the ClientHello. + var unwrapCount int + conf.UnwrapSession = func(identity []byte, connState tls.ConnectionState) (*tls.SessionState, error) { + unwrapCount++ + var state *tls.SessionState + var err error + if origUnwrapSession != nil { + state, err = origUnwrapSession(identity, connState) + } else { + state, err = conf.DecryptTicket(identity, connState) + } + if err != nil || state == nil { + return nil, err + } + if state.EarlyData { + extra := findExtraData(state.Extra) + if unwrapCount == 1 && extra != nil { // first session ticket + state.EarlyData = accept0RTT(extra) + } else { // subsequent session ticket, can't be used for 0-RTT + state.EarlyData = false + } + } + return state, nil + } +} + +func SetupConfigForClient(qconf *QUICConfig, getData func() []byte, setData func([]byte)) { + conf := qconf.TLSConfig + if conf.ClientSessionCache != nil { + origCache := conf.ClientSessionCache + conf.ClientSessionCache = &clientSessionCache{ + wrapped: origCache, + getData: getData, + setData: setData, + } + } +} + +func ToTLSEncryptionLevel(e protocol.EncryptionLevel) tls.QUICEncryptionLevel { + switch e { + case protocol.EncryptionInitial: + return tls.QUICEncryptionLevelInitial + case protocol.EncryptionHandshake: + return tls.QUICEncryptionLevelHandshake + case protocol.Encryption1RTT: + return tls.QUICEncryptionLevelApplication + case protocol.Encryption0RTT: + return tls.QUICEncryptionLevelEarly + default: + panic(fmt.Sprintf("unexpected encryption level: %s", e)) + } +} + +func FromTLSEncryptionLevel(e tls.QUICEncryptionLevel) protocol.EncryptionLevel { + switch e { + case tls.QUICEncryptionLevelInitial: + return protocol.EncryptionInitial + case tls.QUICEncryptionLevelHandshake: + return protocol.EncryptionHandshake + case tls.QUICEncryptionLevelApplication: + return protocol.Encryption1RTT + case tls.QUICEncryptionLevelEarly: + return protocol.Encryption0RTT + default: + panic(fmt.Sprintf("unexpect encryption level: %s", e)) + } +} + +const extraPrefix = "quic-go1" + +func addExtraPrefix(b []byte) []byte { + return append([]byte(extraPrefix), b...) +} + +func findExtraData(extras [][]byte) []byte { + prefix := []byte(extraPrefix) + for _, extra := range extras { + if len(extra) < len(prefix) || !bytes.Equal(prefix, extra[:len(prefix)]) { + continue + } + return extra[len(prefix):] + } + return nil +} diff --git a/internal/qtls/go_oldversion.go b/internal/qtls/go_oldversion.go index e15f03629a6..0fca80a3881 100644 --- a/internal/qtls/go_oldversion.go +++ b/internal/qtls/go_oldversion.go @@ -1,4 +1,4 @@ -//go:build !go1.19 +//go:build !go1.20 package qtls diff --git a/internal/qtls/qtls_test.go b/internal/qtls/qtls_test.go deleted file mode 100644 index 01aa7170dfc..00000000000 --- a/internal/qtls/qtls_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package qtls - -import ( - "crypto/tls" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" -) - -var _ = Describe("qtls wrapper", func() { - It("gets cipher suites", func() { - for _, id := range []uint16{tls.TLS_AES_128_GCM_SHA256, tls.TLS_AES_256_GCM_SHA384, tls.TLS_CHACHA20_POLY1305_SHA256} { - cs := CipherSuiteTLS13ByID(id) - Expect(cs.ID).To(Equal(id)) - } - }) -}) diff --git a/internal/testdata/cert.go b/internal/testdata/cert.go index 6cc7a091e51..f77a7b2ddbe 100644 --- a/internal/testdata/cert.go +++ b/internal/testdata/cert.go @@ -31,6 +31,7 @@ func GetTLSConfig() *tls.Config { panic(err) } return &tls.Config{ + MinVersion: tls.VersionTLS13, Certificates: []tls.Certificate{cert}, } } diff --git a/mock_crypto_data_handler_test.go b/mock_crypto_data_handler_test.go index 47deb442dd5..d685289180a 100644 --- a/mock_crypto_data_handler_test.go +++ b/mock_crypto_data_handler_test.go @@ -35,10 +35,10 @@ func (m *MockCryptoDataHandler) EXPECT() *MockCryptoDataHandlerMockRecorder { } // HandleMessage mocks base method. -func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) bool { +func (m *MockCryptoDataHandler) HandleMessage(arg0 []byte, arg1 protocol.EncryptionLevel) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) - ret0, _ := ret[0].(bool) + ret0, _ := ret[0].(error) return ret0 } diff --git a/server.go b/server.go index 0f8219e3ae3..6108cc2faed 100644 --- a/server.go +++ b/server.go @@ -228,6 +228,8 @@ func newServer( onClose func(), acceptEarly bool, ) (*baseServer, error) { + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err diff --git a/transport.go b/transport.go index 7590be326c5..c4d3261d2d1 100644 --- a/transport.go +++ b/transport.go @@ -156,6 +156,8 @@ func (t *Transport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config if t.isSingleUse { onClose = func() { t.Close() } } + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, false) } @@ -172,6 +174,8 @@ func (t *Transport) DialEarly(ctx context.Context, addr net.Addr, tlsConf *tls.C if t.isSingleUse { onClose = func() { t.Close() } } + tlsConf = tlsConf.Clone() + tlsConf.MinVersion = tls.VersionTLS13 return dial(ctx, newSendConn(t.conn, addr), t.connIDGenerator, t.handlerMap, tlsConf, conf, onClose, true) }