diff --git a/integrationtests/self/multiplex_test.go b/integrationtests/self/multiplex_test.go index 9b00bc34358..72c858a10cf 100644 --- a/integrationtests/self/multiplex_test.go +++ b/integrationtests/self/multiplex_test.go @@ -2,9 +2,11 @@ package self_test import ( "context" + "crypto/rand" "io" "net" "runtime" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -210,4 +212,67 @@ var _ = Describe("Multiplexing", func() { }) } }) + + It("sends and receives non-QUIC packets", func() { + addr1, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn1, err := net.ListenUDP("udp", addr1) + Expect(err).ToNot(HaveOccurred()) + defer conn1.Close() + tr1 := &quic.Transport{Conn: conn1} + + addr2, err := net.ResolveUDPAddr("udp", "localhost:0") + Expect(err).ToNot(HaveOccurred()) + conn2, err := net.ListenUDP("udp", addr2) + Expect(err).ToNot(HaveOccurred()) + defer conn2.Close() + tr2 := &quic.Transport{Conn: conn2} + + server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil)) + Expect(err).ToNot(HaveOccurred()) + runServer(server) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var sentPackets, rcvdPackets atomic.Int64 + const packetLen = 128 + // send a non-QUIC packet every 100µs + go func() { + defer GinkgoRecover() + ticker := time.NewTicker(time.Millisecond / 10) + defer ticker.Stop() + for { + select { + case <-ticker.C: + case <-ctx.Done(): + return + } + b := make([]byte, packetLen) + rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet + _, err := tr1.WriteTo(b, tr2.Conn.LocalAddr()) + Expect(err).ToNot(HaveOccurred()) + sentPackets.Add(1) + } + }() + + // receive and count non-QUIC packets + go func() { + defer GinkgoRecover() + for { + b := make([]byte, 1024) + n, addr, err := tr2.ReadNonQUICPacket(ctx, b) + if err != nil { + Expect(err).To(MatchError(context.Canceled)) + return + } + Expect(addr).To(Equal(tr1.Conn.LocalAddr())) + Expect(n).To(Equal(packetLen)) + rcvdPackets.Add(1) + } + }() + dial(tr2, server.Addr()) + Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10)) + Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5)) + }) }) diff --git a/internal/wire/header.go b/internal/wire/header.go index e2dc72e421f..0c60f4dd948 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -74,6 +74,10 @@ func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.Arbitra return destConnID, srcConnID, nil } +func IsPotentialQUICPacket(firstByte byte) bool { + return firstByte&0x40 > 0 +} + // IsLongHeaderPacket says if this is a Long Header packet func IsLongHeaderPacket(firstByte byte) bool { return firstByte&0x80 > 0 diff --git a/transport.go b/transport.go index ae44e3da638..fe6dc1fc38e 100644 --- a/transport.go +++ b/transport.go @@ -7,12 +7,12 @@ import ( "errors" "net" "sync" + "sync/atomic" "time" - "github.com/quic-go/quic-go/internal/wire" - "github.com/quic-go/quic-go/internal/protocol" "github.com/quic-go/quic-go/internal/utils" + "github.com/quic-go/quic-go/internal/wire" "github.com/quic-go/quic-go/logging" ) @@ -85,6 +85,9 @@ type Transport struct { createdConn bool isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial + readingNonQUICPackets atomic.Bool + nonQUICPackets chan receivedPacket + logger utils.Logger } @@ -341,6 +344,13 @@ func (t *Transport) listen(conn rawConn) { } func (t *Transport) handlePacket(p receivedPacket) { + if len(p.data) == 0 { + return + } + if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) { + t.handleNonQUICPacket(p) + return + } connID, err := wire.ParseConnectionID(p.data, t.connIDLen) if err != nil { t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err) @@ -429,3 +439,42 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool { } return false } + +func (t *Transport) handleNonQUICPacket(p receivedPacket) { + // Strictly speaking, this is racy, + // but we only care about receiving packets at some point after ReadNonQUICPacket has been called. + if !t.readingNonQUICPackets.Load() { + return + } + select { + case t.nonQUICPackets <- p: + default: + if t.Tracer != nil { + t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention) + } + } +} + +const maxQueuedNonQUICPackets = 32 + +// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection. +// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0. +// Note that this is stricter than the detection logic defined in RFC 9443. +func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) { + if err := t.init(false); err != nil { + return 0, nil, err + } + if !t.readingNonQUICPackets.Load() { + t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets) + t.readingNonQUICPackets.Store(true) + } + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case p := <-t.nonQUICPackets: + n := copy(b, p.data) + return n, p.remoteAddr, nil + case <-t.listening: + return 0, nil, errors.New("closed") + } +} diff --git a/transport_test.go b/transport_test.go index f46affb3dd9..93e1d32ab82 100644 --- a/transport_test.go +++ b/transport_test.go @@ -2,6 +2,7 @@ package quic import ( "bytes" + "context" "crypto/rand" "crypto/tls" "errors" @@ -122,7 +123,7 @@ var _ = Describe("Transport", func() { tr.Close() }) - It("drops unparseable packets", func() { + It("drops unparseable QUIC packets", func() { addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} packetChan := make(chan packetToRead) tracer := mocklogging.NewMockTracer(mockCtrl) @@ -136,7 +137,7 @@ var _ = Describe("Transport", func() { tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) }) packetChan <- packetToRead{ addr: addr, - data: []byte{0, 1, 2, 3}, + data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3}, } Eventually(dropped).Should(BeClosed()) @@ -323,6 +324,78 @@ var _ = Describe("Transport", func() { conns := getMultiplexer().(*connMultiplexer).conns Expect(len(conns)).To(BeZero()) }) + + It("allows receiving non-QUIC packets", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, + Tracer: tracer, + } + tr.init(true) + receivedPacketChan := make(chan []byte) + go func() { + defer GinkgoRecover() + b := make([]byte, 100) + n, addr, err := tr.ReadNonQUICPacket(context.Background(), b) + Expect(err).ToNot(HaveOccurred()) + Expect(addr).To(Equal(remoteAddr)) + receivedPacketChan <- b[:n] + }() + // Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called. + // Give the Go routine some time to spin up. + time.Sleep(scaleDuration(50 * time.Millisecond)) + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + + Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3}))) + + // shutdown + close(packetChan) + tr.Close() + }) + + It("drops non-QUIC packet if the application doesn't process them quickly enough", func() { + remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234} + packetChan := make(chan packetToRead) + tracer := mocklogging.NewMockTracer(mockCtrl) + tr := &Transport{ + Conn: newMockPacketConn(packetChan), + ConnectionIDLength: 10, + Tracer: tracer, + } + tr.init(true) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10)) + Expect(err).To(MatchError(context.Canceled)) + + for i := 0; i < maxQueuedNonQUICPackets; i++ { + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + } + + done := make(chan struct{}) + tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { + close(done) + }) + packetChan <- packetToRead{ + addr: remoteAddr, + data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3}, + } + Eventually(done).Should(BeClosed()) + + // shutdown + close(packetChan) + tr.Close() + }) }) type mockSyscallConn struct {