Skip to content

Commit

Permalink
use Transport.VerifySourceAddress to control the Retry Mechanism (#4362)
Browse files Browse the repository at this point in the history
* use Transport.VerifySourceAddress to control the Retry Mechanism

This can be used to rate-limit handshakes originating from unverified
source addresses. Rate-limiting for handshakes can be implemented using
the GetConfigForClient callback on the Config.

* pass the remote address to Transport.VerifySourceAddress
  • Loading branch information
marten-seemann committed Mar 15, 2024
1 parent 497d3f5 commit 9971fed
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 382 deletions.
1 change: 1 addition & 0 deletions go.mod
Expand Up @@ -13,6 +13,7 @@ require (
golang.org/x/net v0.10.0
golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.0
golang.org/x/time v0.5.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Expand Up @@ -175,6 +175,8 @@ golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030000716-a0a13e073c7b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
2 changes: 2 additions & 0 deletions integrationtests/gomodvendor/go.sum
Expand Up @@ -41,6 +41,8 @@ golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw=
Expand Down
5 changes: 2 additions & 3 deletions integrationtests/self/handshake_drop_test.go
Expand Up @@ -11,11 +11,10 @@ import (
"sync/atomic"
"time"

"github.com/quic-go/quic-go/quicvarint"

"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/quicvarint"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -50,7 +49,7 @@ var _ = Describe("Handshake drop tests", func() {
Expect(err).ToNot(HaveOccurred())
tr := &quic.Transport{Conn: conn}
if doRetry {
tr.MaxUnvalidatedHandshakes = -1
tr.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err = tr.Listen(tlsConf, conf)
Expect(err).ToNot(HaveOccurred())
Expand Down
6 changes: 3 additions & 3 deletions integrationtests/self/handshake_rtt_test.go
Expand Up @@ -54,15 +54,15 @@ var _ = Describe("Handshake RTT tests", func() {

// 1 RTT for verifying the source address
// 1 RTT for the TLS handshake
It("is forward-secure after 2 RTTs", func() {
It("is forward-secure after 2 RTTs with Retry", func() {
laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
udpConn, err := net.ListenUDP("udp", laddr)
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
Conn: udpConn,
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
147 changes: 2 additions & 145 deletions integrationtests/self/handshake_test.go
Expand Up @@ -7,15 +7,13 @@ import (
"fmt"
"io"
"net"
"sync/atomic"
"time"

"github.com/quic-go/quic-go"
quicproxy "github.com/quic-go/quic-go/integrationtests/tools/proxy"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/qtls"
"github.com/quic-go/quic-go/logging"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
Expand Down Expand Up @@ -464,147 +462,6 @@ var _ = Describe("Handshake tests", func() {
})
})

Context("limiting handshakes", func() {
var conn *net.UDPConn

BeforeEach(func() {
addr, err := net.ResolveUDPAddr("udp", "localhost:0")
Expect(err).ToNot(HaveOccurred())
conn, err = net.ListenUDP("udp", addr)
Expect(err).ToNot(HaveOccurred())
})

AfterEach(func() { conn.Close() })

It("sends a Retry when the number of handshakes reaches MaxUnvalidatedHandshakes", func() {
const limit = 3
tr := &quic.Transport{
Conn: conn,
MaxUnvalidatedHandshakes: limit,
}
addTracer(tr)
defer tr.Close()

// Block all handshakes.
handshakes := make(chan struct{})
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
handshakes <- struct{}{}
return getTLSConfig(), nil
}
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

const additional = 2
results := make([]struct{ retry, closed atomic.Bool }, limit+additional)
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
// exactly 2 to experience a Retry.
for i := 0; i < limit+additional; i++ {
go func(index int) {
defer GinkgoRecover()
quicConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
ReceivedRetry: func(*logging.Header) { results[index].retry.Store(true) },
ClosedConnection: func(error) { results[index].closed.Store(true) },
}
},
})
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
}(i)
}
numRetries := func() (n int) {
for i := 0; i < limit+additional; i++ {
if results[i].retry.Load() {
n++
}
}
return
}
numClosed := func() (n int) {
for i := 0; i < limit+2; i++ {
if results[i].closed.Load() {
n++
}
}
return
}
Eventually(numRetries).Should(Equal(additional))
// allow the handshakes to complete
for i := 0; i < limit+additional; i++ {
Eventually(handshakes).Should(Receive())
}
Eventually(numClosed).Should(Equal(limit + additional))
Expect(numRetries()).To(Equal(additional)) // just to be on the safe side
})

It("rejects connections when the number of handshakes reaches MaxHandshakes", func() {
const limit = 3
tr := &quic.Transport{
Conn: conn,
MaxHandshakes: limit,
}
addTracer(tr)
defer tr.Close()

// Block all handshakes.
handshakes := make(chan struct{})
var tlsConf tls.Config
tlsConf.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) {
handshakes <- struct{}{}
return getTLSConfig(), nil
}
ln, err := tr.Listen(&tlsConf, getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
defer ln.Close()

const additional = 2
// Dial the server from multiple clients. All handshakes will get blocked on the handshakes channel.
// Since we're dialing limit+2 times, we expect limit handshakes to go through with a Retry, and
// exactly 2 to experience a Retry.
var numSuccessful, numFailed atomic.Int32
for i := 0; i < limit+additional; i++ {
go func() {
defer GinkgoRecover()
quicConf := getQuicConfig(&quic.Config{
Tracer: func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
return &logging.ConnectionTracer{
ReceivedRetry: func(*logging.Header) { Fail("didn't expect any Retry") },
}
},
})
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), quicConf)
if err != nil {
var transportErr *quic.TransportError
if !errors.As(err, &transportErr) || transportErr.ErrorCode != qerr.ConnectionRefused {
Fail(fmt.Sprintf("expected CONNECTION_REFUSED error, got %v", err))
}
numFailed.Add(1)
return
}
numSuccessful.Add(1)
conn.CloseWithError(0, "")
}()
}
Eventually(func() int { return int(numFailed.Load()) }).Should(Equal(additional))
// allow the handshakes to complete
for i := 0; i < limit; i++ {
Eventually(handshakes).Should(Receive())
}
Eventually(func() int { return int(numSuccessful.Load()) }).Should(Equal(limit))

// make sure that the server is reachable again after these handshakes have completed
go func() { <-handshakes }() // allow this handshake to complete immediately
conn, err := quic.DialAddr(context.Background(), ln.Addr().String(), getTLSClientConfig(), getQuicConfig(nil))
Expect(err).ToNot(HaveOccurred())
conn.CloseWithError(0, "")
})
})

Context("ALPN", func() {
It("negotiates an application protocol", func() {
ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig)
Expand Down Expand Up @@ -718,8 +575,8 @@ var _ = Describe("Handshake tests", func() {
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
Conn: udpConn,
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion integrationtests/self/mitm_test.go
Expand Up @@ -43,7 +43,7 @@ var _ = Describe("MITM test", func() {
}
addTracer(serverTransport)
if forceAddressValidation {
serverTransport.MaxUnvalidatedHandshakes = -1
serverTransport.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err := serverTransport.Listen(getTLSConfig(), serverConfig)
Expect(err).ToNot(HaveOccurred())
Expand Down
4 changes: 2 additions & 2 deletions integrationtests/self/zero_rtt_test.go
Expand Up @@ -461,8 +461,8 @@ var _ = Describe("0-RTT", func() {
Expect(err).ToNot(HaveOccurred())
defer udpConn.Close()
tr := &quic.Transport{
Conn: udpConn,
MaxUnvalidatedHandshakes: -1,
Conn: udpConn,
VerifySourceAddress: func(net.Addr) bool { return true },
}
addTracer(tr)
defer tr.Close()
Expand Down
2 changes: 1 addition & 1 deletion interop/http09/server.go
Expand Up @@ -71,7 +71,7 @@ func (s *Server) ListenAndServe() error {
tlsConf.NextProtos = []string{h09alpn}
tr := quic.Transport{Conn: conn}
if s.ForceRetry {
tr.MaxUnvalidatedHandshakes = -1
tr.VerifySourceAddress = func(net.Addr) bool { return true }
}
ln, err := tr.ListenEarly(tlsConf, s.QuicConfig)
if err != nil {
Expand Down

0 comments on commit 9971fed

Please sign in to comment.