Skip to content

Commit

Permalink
add a method to retrieve non-QUIC packets from the Transport (#3992)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Aug 19, 2023
1 parent 6880f88 commit fe3c4f2
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 4 deletions.
65 changes: 65 additions & 0 deletions integrationtests/self/multiplex_test.go
Expand Up @@ -2,9 +2,11 @@ package self_test

import (
"context"
"crypto/rand"
"io"
"net"
"runtime"
"sync/atomic"
"time"

"github.com/quic-go/quic-go"
Expand Down Expand Up @@ -210,4 +212,67 @@ var _ = Describe("Multiplexing", func() {
})
}
})

It("sends and receives non-QUIC packets", func() {
addr1, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn1, err := net.ListenUDP("udp", addr1)
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}

addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}

server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
runServer(server)
defer server.Close()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var sentPackets, rcvdPackets atomic.Int64
const packetLen = 128
// send a non-QUIC packet every 100µs
go func() {
defer GinkgoRecover()
ticker := time.NewTicker(time.Millisecond / 10)
defer ticker.Stop()
for {
select {
case <-ticker.C:
case <-ctx.Done():
return
}
b := make([]byte, packetLen)
rand.Read(b[1:]) // keep the first byte set to 0, so it's not classified as a QUIC packet
_, err := tr1.WriteTo(b, tr2.Conn.LocalAddr())
Expect(err).ToNot(HaveOccurred())
sentPackets.Add(1)
}
}()

// receive and count non-QUIC packets
go func() {
defer GinkgoRecover()
for {
b := make([]byte, 1024)
n, addr, err := tr2.ReadNonQUICPacket(ctx, b)
if err != nil {
Expect(err).To(MatchError(context.Canceled))
return
}
Expect(addr).To(Equal(tr1.Conn.LocalAddr()))
Expect(n).To(Equal(packetLen))
rcvdPackets.Add(1)
}
}()
dial(tr2, server.Addr())
Eventually(func() int64 { return sentPackets.Load() }).Should(BeNumerically(">", 10))
Eventually(func() int64 { return rcvdPackets.Load() }).Should(BeNumerically(">=", sentPackets.Load()*4/5))
})
})
4 changes: 4 additions & 0 deletions internal/wire/header.go
Expand Up @@ -74,6 +74,10 @@ func parseArbitraryLenConnectionIDs(r *bytes.Reader) (dest, src protocol.Arbitra
return destConnID, srcConnID, nil
}

func IsPotentialQUICPacket(firstByte byte) bool {
return firstByte&0x40 > 0
}

// IsLongHeaderPacket says if this is a Long Header packet
func IsLongHeaderPacket(firstByte byte) bool {
return firstByte&0x80 > 0
Expand Down
53 changes: 51 additions & 2 deletions transport.go
Expand Up @@ -7,12 +7,12 @@ import (
"errors"
"net"
"sync"
"sync/atomic"
"time"

"github.com/quic-go/quic-go/internal/wire"

"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)

Expand Down Expand Up @@ -85,6 +85,9 @@ type Transport struct {
createdConn bool
isSingleUse bool // was created for a single server or client, i.e. by calling quic.Listen or quic.Dial

readingNonQUICPackets atomic.Bool
nonQUICPackets chan receivedPacket

logger utils.Logger
}

Expand Down Expand Up @@ -341,6 +344,13 @@ func (t *Transport) listen(conn rawConn) {
}

func (t *Transport) handlePacket(p receivedPacket) {
if len(p.data) == 0 {
return
}
if !wire.IsPotentialQUICPacket(p.data[0]) && !wire.IsLongHeaderPacket(p.data[0]) {
t.handleNonQUICPacket(p)
return
}
connID, err := wire.ParseConnectionID(p.data, t.connIDLen)
if err != nil {
t.logger.Debugf("error parsing connection ID on packet from %s: %s", p.remoteAddr, err)
Expand Down Expand Up @@ -429,3 +439,42 @@ func (t *Transport) maybeHandleStatelessReset(data []byte) bool {
}
return false
}

func (t *Transport) handleNonQUICPacket(p receivedPacket) {
// Strictly speaking, this is racy,
// but we only care about receiving packets at some point after ReadNonQUICPacket has been called.
if !t.readingNonQUICPackets.Load() {
return
}
select {
case t.nonQUICPackets <- p:
default:
if t.Tracer != nil {
t.Tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropDOSPrevention)
}
}
}

const maxQueuedNonQUICPackets = 32

// ReadNonQUICPacket reads non-QUIC packets received on the underlying connection.
// The detection logic is very simple: Any packet that has the first and second bit of the packet set to 0.
// Note that this is stricter than the detection logic defined in RFC 9443.
func (t *Transport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
if err := t.init(false); err != nil {
return 0, nil, err
}
if !t.readingNonQUICPackets.Load() {
t.nonQUICPackets = make(chan receivedPacket, maxQueuedNonQUICPackets)
t.readingNonQUICPackets.Store(true)
}
select {
case <-ctx.Done():
return 0, nil, ctx.Err()
case p := <-t.nonQUICPackets:
n := copy(b, p.data)
return n, p.remoteAddr, nil
case <-t.listening:
return 0, nil, errors.New("closed")
}
}
77 changes: 75 additions & 2 deletions transport_test.go
Expand Up @@ -2,6 +2,7 @@ package quic

import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
Expand Down Expand Up @@ -122,7 +123,7 @@ var _ = Describe("Transport", func() {
tr.Close()
})

It("drops unparseable packets", func() {
It("drops unparseable QUIC packets", func() {
addr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
packetChan := make(chan packetToRead)
tracer := mocklogging.NewMockTracer(mockCtrl)
Expand All @@ -136,7 +137,7 @@ var _ = Describe("Transport", func() {
tracer.EXPECT().DroppedPacket(addr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropHeaderParseError).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) { close(dropped) })
packetChan <- packetToRead{
addr: addr,
data: []byte{0, 1, 2, 3},
data: []byte{0x40 /* set the QUIC bit */, 1, 2, 3},
}
Eventually(dropped).Should(BeClosed())

Expand Down Expand Up @@ -323,6 +324,78 @@ var _ = Describe("Transport", func() {
conns := getMultiplexer().(*connMultiplexer).conns
Expect(len(conns)).To(BeZero())
})

It("allows receiving non-QUIC packets", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
packetChan := make(chan packetToRead)
tracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: 10,
Tracer: tracer,
}
tr.init(true)
receivedPacketChan := make(chan []byte)
go func() {
defer GinkgoRecover()
b := make([]byte, 100)
n, addr, err := tr.ReadNonQUICPacket(context.Background(), b)
Expect(err).ToNot(HaveOccurred())
Expect(addr).To(Equal(remoteAddr))
receivedPacketChan <- b[:n]
}()
// Receiving of non-QUIC packets is enabled when ReadNonQUICPacket is called.
// Give the Go routine some time to spin up.
time.Sleep(scaleDuration(50 * time.Millisecond))
packetChan <- packetToRead{
addr: remoteAddr,
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
}

Eventually(receivedPacketChan).Should(Receive(Equal([]byte{0, 1, 2, 3})))

// shutdown
close(packetChan)
tr.Close()
})

It("drops non-QUIC packet if the application doesn't process them quickly enough", func() {
remoteAddr := &net.UDPAddr{IP: net.IPv4(9, 8, 7, 6), Port: 1234}
packetChan := make(chan packetToRead)
tracer := mocklogging.NewMockTracer(mockCtrl)
tr := &Transport{
Conn: newMockPacketConn(packetChan),
ConnectionIDLength: 10,
Tracer: tracer,
}
tr.init(true)

ctx, cancel := context.WithCancel(context.Background())
cancel()
_, _, err := tr.ReadNonQUICPacket(ctx, make([]byte, 10))
Expect(err).To(MatchError(context.Canceled))

for i := 0; i < maxQueuedNonQUICPackets; i++ {
packetChan <- packetToRead{
addr: remoteAddr,
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
}
}

done := make(chan struct{})
tracer.EXPECT().DroppedPacket(remoteAddr, logging.PacketTypeNotDetermined, protocol.ByteCount(4), logging.PacketDropDOSPrevention).Do(func(net.Addr, logging.PacketType, protocol.ByteCount, logging.PacketDropReason) {
close(done)
})
packetChan <- packetToRead{
addr: remoteAddr,
data: []byte{0 /* don't set the QUIC bit */, 1, 2, 3},
}
Eventually(done).Should(BeClosed())

// shutdown
close(packetChan)
tr.Close()
})
})

type mockSyscallConn struct {
Expand Down

0 comments on commit fe3c4f2

Please sign in to comment.