Skip to content

Commit

Permalink
use the transport tracer in integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Mar 9, 2024
1 parent 55c05ac commit 30e01b9
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 34 deletions.
2 changes: 2 additions & 0 deletions integrationtests/self/conn_id_test.go
Expand Up @@ -50,6 +50,7 @@ var _ = Describe("Connection ID lengths tests", func() {
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
addTracer(tr)
ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
go func() {
Expand Down Expand Up @@ -92,6 +93,7 @@ var _ = Describe("Connection ID lengths tests", func() {
ConnectionIDLength: connIDLen,
ConnectionIDGenerator: connIDGenerator,
}
addTracer(tr)
defer tr.Close()
cl, err := tr.Dial(
context.Background(),
Expand Down
1 change: 1 addition & 0 deletions integrationtests/self/handshake_rtt_test.go
Expand Up @@ -64,6 +64,7 @@ var _ = Describe("Handshake RTT tests", func() {
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
addTracer(tr)
defer tr.Close()
ln, err := tr.Listen(serverTLSConfig, serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand Down
17 changes: 11 additions & 6 deletions integrationtests/self/handshake_test.go
Expand Up @@ -328,7 +328,10 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
pconn, err = net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
dialer = &quic.Transport{Conn: pconn, ConnectionIDLength: 4}
dialer = &quic.Transport{
Conn: pconn,
ConnectionIDLength: 4,
}
})

AfterEach(func() {
Expand Down Expand Up @@ -431,9 +434,8 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
tr := quic.Transport{
Conn: udpConn,
}
tr := &quic.Transport{Conn: udpConn}
addTracer(tr)
defer tr.Close()
tlsConf := &tls.Config{}
done := make(chan struct{})
Expand Down Expand Up @@ -476,10 +478,11 @@ var _ = Describe("Handshake tests", func() {

It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
const limit = 3
tr := quic.Transport{
tr := &quic.Transport{
Conn: conn,
MaxUnvalidatedHandshakes: limit,
}
addTracer(tr)
defer tr.Close()

// Block all handshakes.
Expand Down Expand Up @@ -541,10 +544,11 @@ var _ = Describe("Handshake tests", func() {

It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
const limit = 3
tr := quic.Transport{
tr := &quic.Transport{
Conn: conn,
MaxHandshakes: limit,
}
addTracer(tr)
defer tr.Close()

// Block all handshakes.
Expand Down Expand Up @@ -717,6 +721,7 @@ var _ = Describe("Handshake tests", func() {
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
addTracer(tr)
defer tr.Close()
server, err := tr.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand Down
2 changes: 2 additions & 0 deletions integrationtests/self/mitm_test.go
Expand Up @@ -41,6 +41,7 @@ var _ = Describe("MITM test", func() {
Conn: c,
ConnectionIDLength: connIDLen,
}
addTracer(serverTransport)
if forceAddressValidation {
serverTransport.MaxUnvalidatedHandshakes = -1
}
Expand Down Expand Up @@ -86,6 +87,7 @@ var _ = Describe("MITM test", func() {
Conn: clientUDPConn,
ConnectionIDLength: connIDLen,
}
addTracer(clientTransport)
})

Context("unsuccessful attacks", func() {
Expand Down
7 changes: 7 additions & 0 deletions integrationtests/self/multiplex_test.go
Expand Up @@ -74,6 +74,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)

done1 := make(chan struct{})
done2 := make(chan struct{})
Expand Down Expand Up @@ -109,6 +110,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)

done1 := make(chan struct{})
done2 := make(chan struct{})
Expand Down Expand Up @@ -139,6 +141,7 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn.Close()
tr := &quic.Transport{Conn: conn}
addTracer(tr)
server, err := tr.Listen(
getTLSConfig(),
getQuicConfig(nil),
Expand Down Expand Up @@ -167,13 +170,15 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)

addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)

server1, err := tr1.Listen(
getTLSConfig(),
Expand Down Expand Up @@ -220,13 +225,15 @@ var _ = Describe("Multiplexing", func() {
Expect(err).ToNot(HaveOccurred())
defer conn1.Close()
tr1 := &quic.Transport{Conn: conn1}
addTracer(tr1)

addr2, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn2, err := net.ListenUDP("udp", addr2)
Expect(err).ToNot(HaveOccurred())
defer conn2.Close()
tr2 := &quic.Transport{Conn: conn2}
addTracer(tr2)

server, err := tr1.Listen(getTLSConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
Expand Down
58 changes: 37 additions & 21 deletions integrationtests/self/self_suite_test.go
Expand Up @@ -86,7 +86,6 @@ var (
logBuf *syncedBuffer
versionParam string

qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
enableQlog bool

version quic.Version
Expand Down Expand Up @@ -138,9 +137,6 @@ func init() {
}

var _ = BeforeSuite(func() {
if enableQlog {
qlogTracer = tools.NewQlogger(GinkgoWriter)
}
switch versionParam {
case "1":
version = quic.Version1
Expand Down Expand Up @@ -175,28 +171,48 @@ func getQuicConfig(conf *quic.Config) *quic.Config {
} else {
conf = conf.Clone()
}
if enableQlog {
if conf.Tracer == nil {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{},
)
}
} else if qlogTracer != nil {
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
qlogTracer(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
if !enableQlog {
return conf
}
if conf.Tracer == nil {
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.ConnectionTracer{},
)
}
return conf
}
origTracer := conf.Tracer
conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
return logging.NewMultiplexedConnectionTracer(
tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
origTracer(ctx, p, connID),
)
}
return conf
}

func addTracer(tr *quic.Transport) {
if !enableQlog {
return
}
if tr.Tracer == nil {
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
&logging.Tracer{},
)
return
}
origTracer := tr.Tracer
tr.Tracer = logging.NewMultiplexedTracer(
tools.QlogTracer(GinkgoWriter),
origTracer,
)
}

var _ = BeforeEach(func() {
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)

Expand Down
2 changes: 2 additions & 0 deletions integrationtests/self/zero_rtt_test.go
Expand Up @@ -175,6 +175,7 @@ var _ = Describe("0-RTT", func() {
Conn: udpConn,
ConnectionIDLength: connIDLen,
}
addTracer(tr)
defer tr.Close()
conn, err = tr.DialEarly(
context.Background(),
Expand Down Expand Up @@ -463,6 +464,7 @@ var _ = Describe("0-RTT", func() {
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
}
addTracer(tr)
defer tr.Close()
ln, err := tr.ListenEarly(
tlsConf,
Expand Down
21 changes: 15 additions & 6 deletions integrationtests/tools/qlog.go
Expand Up @@ -7,20 +7,29 @@ import (
"io"
"log"
"os"
"time"

"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/qlog"
)

func NewQlogger(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
func QlogTracer(logger io.Writer) *logging.Tracer {
filename := fmt.Sprintf("log_%s_transport.qlog", time.Now().Format("2006-01-02T15:04:05"))
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {
log.Fatalf("failed to create qlog file: %s", err)
return nil
}
bw := bufio.NewWriter(f)
return qlog.NewTracer(utils.NewBufferedWriteCloser(bw, f))
}

func NewQlogConnectionTracer(logger io.Writer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return func(_ context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
role := "server"
if p == logging.PerspectiveClient {
role = "client"
}
filename := fmt.Sprintf("log_%s_%s.qlog", connID, role)
filename := fmt.Sprintf("log_%s_%s.qlog", connID, p.String())
fmt.Fprintf(logger, "Creating %s.\n", filename)
f, err := os.Create(filename)
if err != nil {
Expand Down
Expand Up @@ -65,7 +65,7 @@ func maybeAddQLOGTracer(c *quic.Config) *quic.Config {
if !enableQlog {
return c
}
qlogger := tools.NewQlogger(GinkgoWriter)
qlogger := tools.NewQlogConnectionTracer(GinkgoWriter)
if c.Tracer == nil {
c.Tracer = qlogger
} else if qlogger != nil {
Expand Down

0 comments on commit 30e01b9

Please sign in to comment.