Skip to content

Commit

Permalink
DNS caching
Browse files Browse the repository at this point in the history
This commit introduces DNS caching with the -dns-ttl flag.

Supersedes #576

Signed-off-by: Tomás Senart <tsenart@gmail.com>
  • Loading branch information
tsenart committed Jul 16, 2023
1 parent 556bf61 commit 174d804
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 44 deletions.
5 changes: 5 additions & 0 deletions attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func attackCmd() command {
fs.Var(&opts.laddr, "laddr", "Local IP address")
fs.BoolVar(&opts.keepalive, "keepalive", true, "Use persistent connections")
fs.StringVar(&opts.unixSocket, "unix-socket", "", "Connect over a unix socket. This overrides the host address in target URLs")
fs.Var(&dnsTTLFlag{&opts.dnsTTL}, "dns-ttl", "Cache DNS lookups for the given duration [-1 = disabled, 0 = forever]")
systemSpecificFlags(fs, opts)

return command{fs, func(args []string) error {
Expand Down Expand Up @@ -99,6 +100,7 @@ type attackOpts struct {
keepalive bool
resolvers csl
unixSocket string
dnsTTL time.Duration
}

// attack validates the attack arguments, sets up the
Expand All @@ -116,6 +118,8 @@ func attack(opts *attackOpts) (err error) {
net.DefaultResolver = res
}

net.DefaultResolver.PreferGo = true

files := map[string]io.Reader{}
for _, filename := range []string{opts.targetsf, opts.bodyf} {
if filename == "" {
Expand Down Expand Up @@ -188,6 +192,7 @@ func attack(opts *attackOpts) (err error) {
vegeta.UnixSocket(opts.unixSocket),
vegeta.ProxyHeader(proxyHdr),
vegeta.ChunkedBody(opts.chunked),
vegeta.DNSCaching(opts.dnsTTL),
)

res := atk.Attack(tr, opts.rate, opts.duration, opts.name)
Expand Down
21 changes: 21 additions & 0 deletions flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,24 @@ func (f *maxBodyFlag) String() string {
}
return datasize.ByteSize(*(f.n)).String()
}

type dnsTTLFlag struct{ ttl *time.Duration }

func (f *dnsTTLFlag) Set(v string) (err error) {
if v == "-1" {
*(f.ttl) = -1
return nil
}

*(f.ttl), err = time.ParseDuration(v)
return err
}

func (f *dnsTTLFlag) String() string {
if f.ttl == nil {
return ""
} else if *(f.ttl) == -1 {
return "-1"
}
return f.ttl.String()
}
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ require (
require (
github.com/iancoleman/orderedmap v0.3.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 // indirect
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c // indirect
github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.11.0 // indirect
golang.org/x/tools v0.6.0 // indirect
Expand Down
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo=
github.com/miekg/dns v1.1.55/go.mod h1:uInx36IzPl7FYnDcMeVWxj9byh7DutNykX4G9Sj60FY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417 h1:Lt9DzQALzHoDwMBGJ6v8ObDPR0dzr2a6sXTB1Fq7IHs=
github.com/rs/dnscache v0.0.0-20211102005908-e0241e321417/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA=
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c h1:aqg5Vm5dwtvL+YgDpBcK1ITf3o96N/K7/wsRXQnUTEs=
github.com/shurcooL/httpfs v0.0.0-20230704072500-f1e31cf0ba5c/go.mod h1:owqhoLW1qZoYLZzLnBw+QkPP9WZnjlSWihhxAJC1+/M=
github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92 h1:OfRzdxCzDhp+rsKWXuOO2I/quKMJ/+TQwVbIP/gltZg=
github.com/shurcooL/vfsgen v0.0.0-20230704071429-0000e147ea92/go.mod h1:7/OT02F6S6I7v6WXb+IjhMuZEYfH/RJ5RwEWnEo5BMg=
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d h1:X4+kt6zM/OVO6gbJdAfJR60MGPsqCzbtXNnjoGqdfAs=
github.com/streadway/quantile v0.0.0-20220407130108-4246515d968d/go.mod h1:lbP8tGiBjZ5YWIc2fzuRpTaz0b/53vT6PEs3QuAWzuU=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
Expand All @@ -39,7 +45,9 @@ golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
Expand Down
175 changes: 147 additions & 28 deletions lib/attack.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import (
"io"
"io/ioutil"
"math"
"math/rand"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"

"github.com/rs/dnscache"
"golang.org/x/net/http2"
)

Expand All @@ -27,9 +29,6 @@ type Attacker struct {
maxWorkers uint64
maxBody int64
redirects int
seqmu sync.Mutex
seq uint64
began time.Time
chunked bool
}

Expand Down Expand Up @@ -73,7 +72,6 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker {
workers: DefaultWorkers,
maxWorkers: DefaultMaxWorkers,
maxBody: DefaultMaxBody,
began: time.Now(),
}

a.dialer = &net.Dialer{
Expand All @@ -85,7 +83,7 @@ func NewAttacker(opts ...func(*Attacker)) *Attacker {
Timeout: DefaultTimeout,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: a.dialer.Dial,
DialContext: a.dialer.DialContext,
TLSClientConfig: DefaultTLSConfig,
MaxIdleConnsPerHost: DefaultConnections,
MaxConnsPerHost: DefaultMaxConnections,
Expand Down Expand Up @@ -177,7 +175,7 @@ func LocalAddr(addr net.IPAddr) func(*Attacker) {
return func(a *Attacker) {
tr := a.client.Transport.(*http.Transport)
a.dialer.LocalAddr = &net.TCPAddr{IP: addr.IP, Zone: addr.Zone}
tr.Dial = a.dialer.Dial
tr.DialContext = a.dialer.DialContext
}
}

Expand All @@ -189,7 +187,7 @@ func KeepAlive(keepalive bool) func(*Attacker) {
tr.DisableKeepAlives = !keepalive
if !keepalive {
a.dialer.KeepAlive = 0
tr.Dial = a.dialer.Dial
tr.DialContext = a.dialer.DialContext
}
}
}
Expand Down Expand Up @@ -223,8 +221,8 @@ func H2C(enabled bool) func(*Attacker) {
if tr := a.client.Transport.(*http.Transport); enabled {
a.client.Transport = &http2.Transport{
AllowHTTP: true,
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return tr.Dial(network, addr)
DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
return tr.DialContext(ctx, network, addr)
},
}
}
Expand Down Expand Up @@ -263,6 +261,119 @@ func ProxyHeader(h http.Header) func(*Attacker) {
}
}

// DNSCaching returns a functional option that enables DNS caching for
// the given ttl. When ttl is zero cached entries will never expire.
// When ttl is non-zero, this will start a refresh go-routine that updates
// the cache every ttl interval. This go-routine will be stopped when the
// attack is stopped.
// When the ttl is negative, no caching will be performed.
func DNSCaching(ttl time.Duration) func(*Attacker) {
return func(a *Attacker) {
if ttl < 0 {
return
}

if tr, ok := a.client.Transport.(*http.Transport); ok {
dial := tr.DialContext
if dial == nil {
dial = a.dialer.DialContext
}

resolver := &dnscache.Resolver{}

if ttl != 0 {
go func() {
refresh := time.NewTicker(ttl)
defer refresh.Stop()
for {
select {
case <-refresh.C:
resolver.Refresh(true)
case <-a.stopch:
return
}
}
}()
}

rng := rand.New(rand.NewSource(time.Now().UnixNano()))

tr.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
host, port, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}

ips, err := resolver.LookupHost(ctx, host)
if err != nil {
return nil, err
}

if len(ips) == 0 {
return nil, &net.DNSError{Err: "no such host", Name: addr}
}

// Pick a random IP from each IP family and dial each concurrently.
// The first that succeeds wins, the other gets canceled.

rng.Shuffle(len(ips), func(i, j int) { ips[i], ips[j] = ips[j], ips[i] })

// In place filtering of ips to only include the first IPv4 and IPv6.
j := 0
for i := 0; i < len(ips) && j < 2; i++ {
ip := net.ParseIP(ips[i])
switch {
case len(ip.To4()) == net.IPv4len && j == 0:
fallthrough
case len(ip) == net.IPv6len && j == 1:
ips[j] = ips[i]
j++
}
}
ips = ips[:j]

type result struct {
conn net.Conn
err error
}

ch := make(chan result, len(ips))
ctx, cancel := context.WithCancel(ctx)
defer cancel()

for _, ip := range ips {
go func(ip string) {
conn, err := dial(ctx, network, net.JoinHostPort(ip, port))
ch <- result{conn, err}
}(ip)
}

for i := 0; i < cap(ch); i++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-ch:
if err = r.err; err != nil {
continue
}
return r.conn, nil
}
}

return nil, err
}
}
}
}

type attack struct {
name string
began time.Time

seqmu sync.Mutex
seq uint64
}

// Attack reads its Targets from the passed Targeter and attacks them at
// the rate specified by the Pacer. When the duration is zero the attack
// runs until Stop is called. Results are sent to the returned channel as soon
Expand All @@ -275,21 +386,29 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <
workers = a.maxWorkers
}

atk := &attack{
name: name,
began: time.Now(),
}

results := make(chan *Result)
ticks := make(chan struct{})
for i := uint64(0); i < workers; i++ {
wg.Add(1)
go a.attack(tr, name, &wg, ticks, results)
go a.attack(tr, atk, &wg, ticks, results)
}

go func() {
defer close(results)
defer wg.Wait()
defer close(ticks)

began, count := time.Now(), uint64(0)
defer func() {
close(ticks)
wg.Wait()
close(results)
a.Stop()
}()

count := uint64(0)
for {
elapsed := time.Since(began)
elapsed := time.Since(atk.began)
if du > 0 && elapsed > du {
return
}
Expand All @@ -312,7 +431,7 @@ func (a *Attacker) Attack(tr Targeter, p Pacer, du time.Duration, name string) <
// all workers are blocked. start one more and try again
workers++
wg.Add(1)
go a.attack(tr, name, &wg, ticks, results)
go a.attack(tr, atk, &wg, ticks, results)
}
}

Expand Down Expand Up @@ -342,25 +461,25 @@ func (a *Attacker) Stop() bool {
}
}

func (a *Attacker) attack(tr Targeter, name string, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
func (a *Attacker) attack(tr Targeter, atk *attack, workers *sync.WaitGroup, ticks <-chan struct{}, results chan<- *Result) {
defer workers.Done()
for range ticks {
results <- a.hit(tr, name)
results <- a.hit(tr, atk)
}
}

func (a *Attacker) hit(tr Targeter, name string) *Result {
func (a *Attacker) hit(tr Targeter, atk *attack) *Result {
var (
res = Result{Attack: name}
res = Result{Attack: atk.name}
tgt Target
err error
)

a.seqmu.Lock()
res.Timestamp = a.began.Add(time.Since(a.began))
res.Seq = a.seq
a.seq++
a.seqmu.Unlock()
atk.seqmu.Lock()
res.Timestamp = atk.began.Add(time.Since(atk.began))
res.Seq = atk.seq
atk.seq++
atk.seqmu.Unlock()

defer func() {
res.Latency = time.Since(res.Timestamp)
Expand All @@ -382,8 +501,8 @@ func (a *Attacker) hit(tr Targeter, name string) *Result {
return &res
}

if name != "" {
req.Header.Set("X-Vegeta-Attack", name)
if atk.name != "" {
req.Header.Set("X-Vegeta-Attack", atk.name)
}

req.Header.Set("X-Vegeta-Seq", strconv.FormatUint(res.Seq, 10))
Expand Down
Loading

0 comments on commit 174d804

Please sign in to comment.