Skip to content

Commit

Permalink
Fix test and reduce test scope
Browse files Browse the repository at this point in the history
Reduces the amount of ressources required to run the tests by only
testing the TBF itself and not the router setup around it.
  • Loading branch information
mengelbart committed Dec 8, 2021
1 parent dd4cd2f commit 78491f1
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 172 deletions.
27 changes: 22 additions & 5 deletions vnet/tbf.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package vnet

import (
"context"
"sync"
"time"
)
Expand All @@ -26,6 +25,9 @@ type TokenBucketFilter struct {
mutex sync.Mutex
rate int
maxBurst int

wg sync.WaitGroup
done chan struct{}
}

// TBFOption is the option type to configure a TokenBucketFilter
Expand Down Expand Up @@ -73,7 +75,7 @@ func (t *TokenBucketFilter) Set(opts ...TBFOption) (previous TBFOption) {
}

// NewTokenBucketFilter creates and starts a new TokenBucketFilter
func NewTokenBucketFilter(ctx context.Context, n NIC, opts ...TBFOption) (*TokenBucketFilter, error) {
func NewTokenBucketFilter(n NIC, opts ...TBFOption) (*TokenBucketFilter, error) {
tbf := &TokenBucketFilter{
NIC: n,
currentTokensInBucket: 0,
Expand All @@ -83,24 +85,29 @@ func NewTokenBucketFilter(ctx context.Context, n NIC, opts ...TBFOption) (*Token
mutex: sync.Mutex{},
rate: 1 * MBit,
maxBurst: 2 * KBit,
wg: sync.WaitGroup{},
done: make(chan struct{}),
}
tbf.Set(opts...)
tbf.queue = newChunkQueue(0, tbf.queueSize)
go tbf.run(ctx)
tbf.wg.Add(1)
go tbf.run()
return tbf, nil
}

func (t *TokenBucketFilter) onInboundChunk(c Chunk) {
t.c <- c
}

func (t *TokenBucketFilter) run(ctx context.Context) {
func (t *TokenBucketFilter) run() {
defer t.wg.Done()
ticker := time.NewTicker(1 * time.Millisecond)

for {
select {
case <-ctx.Done():
case <-t.done:
ticker.Stop()
t.drainQueue()
return
case <-ticker.C:
t.mutex.Lock()
Expand All @@ -121,6 +128,9 @@ func (t *TokenBucketFilter) run(ctx context.Context) {
func (t *TokenBucketFilter) drainQueue() {
for {
next := t.queue.peek()
if next == nil {
break
}
tokens := len(next.UserData())
if t.currentTokensInBucket < tokens {
break
Expand All @@ -130,3 +140,10 @@ func (t *TokenBucketFilter) drainQueue() {
t.currentTokensInBucket -= tokens
}
}

// Close closes and stops the token bucket filter queue
func (t *TokenBucketFilter) Close() error {
close(t.done)
t.wg.Wait()
return nil
}
235 changes: 68 additions & 167 deletions vnet/tbf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package vnet

import (
"context"
"fmt"
"sync"
"testing"
"time"
Expand All @@ -11,209 +10,111 @@ import (
"github.com/stretchr/testify/assert"
)

func TestRouterBandwidth(t *testing.T) {
loggerFactory := logging.NewDefaultLoggerFactory()
log := loggerFactory.NewLogger("test")
func TestTokenBucketFilter(t *testing.T) {
t.Run("bitrateBelowCapacity", func(t *testing.T) {
mnic := newMockNIC(t)

leftAddr, rightAddr := "1.2.3.4", "1.2.3.5"

subTest := func(t *testing.T, capacity int, duration time.Duration) {
wan, err := NewRouter(&RouterConfig{
CIDR: "1.2.3.0/24",
LoggerFactory: loggerFactory,
})
tbf, err := NewTokenBucketFilter(mnic, TBFRate(10*MBit), TBFMaxBurst(10*MBit))
assert.NoError(t, err)
assert.NotNil(t, wan)

leftNet := NewNet(&NetConfig{
StaticIPs: []string{leftAddr},
})
received := 0
mnic.mockOnInboundChunk = func(Chunk) {
received++
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
time.Sleep(1 * time.Second)

// Configure network bandwidth capacity:
tbf, err := NewTokenBucketFilter(ctx, leftNet, TBFRate(capacity))
assert.NoError(t, err)
sent := 100
for i := 0; i < sent; i++ {
tbf.onInboundChunk(&chunkUDP{
userData: make([]byte, 1200),
})
}

err = wan.AddNet(tbf)
assert.NoError(t, err)
assert.NoError(t, tbf.Close())

rightNet := NewNet(&NetConfig{
StaticIPs: []string{rightAddr},
})
err = wan.AddNet(rightNet)
assert.NoError(t, err)
assert.Equal(t, sent, received)
})

subTest := func(t *testing.T, capacity int, duration time.Duration) {
log := logging.NewDefaultLoggerFactory().NewLogger("test")

mnic := newMockNIC(t)

err = wan.Start()
tbf, err := NewTokenBucketFilter(mnic, TBFRate(capacity))
assert.NoError(t, err)
defer func() {
err = wan.Stop()
assert.NoError(t, err)
}()

done := make(chan struct{})
type metrics struct {
packets int
bytes int
chunkChan := make(chan Chunk)
mnic.mockOnInboundChunk = func(c Chunk) {
chunkChan <- c
}
received := make(chan metrics, 1000)
sent := make(chan metrics, 1000)

var logWG sync.WaitGroup
logWG.Add(1)
var wg sync.WaitGroup
wg.Add(1)

go func() {
defer logWG.Done()
bytesReceived := 0
pktReceived := 0
bytesSent := 0
pktSent := 0
ctx, cancel := context.WithCancel(context.Background())

totalPktSent := 0
totalPktReceived := 0
totalBytesReceived := 0
var start time.Time
go func() {
defer wg.Done()

defer func() {
d := time.Since(start)
lossRate := 100 * (1 - float64(totalPktReceived)/float64(totalPktSent))
bits := float64(totalBytesReceived) * 8.0
rate := bits / d.Seconds()

assert.Less(t, rate, float64(capacity))
assert.Greater(t, rate, float64(0))

mBitPerSecond := rate / float64(MBit)
log.Infof("total packets received: %v / %v, lossrate=%.2f%%, throughput=%.2f Mb/s\n", totalPktReceived, totalPktSent, lossRate, mBitPerSecond)
}()

ticker := time.NewTicker(1 * time.Second)
start = time.Now()
lastLog := start
updateCapacity := time.After(duration / 2)
currentCapacity := capacity
capacityUpdated := false
bytesReceived := 0
packetsReceived := 0
start := time.Now()
for {
select {
case <-updateCapacity:
capacityUpdated = true
tbf.Set(TBFRate(capacity / 2))
case <-done:
for r := range received {
pktReceived += r.packets
bytesReceived += r.bytes
totalPktReceived += r.packets
totalBytesReceived += r.bytes
}
for s := range sent {
pktSent += s.packets
bytesSent += s.bytes
totalPktSent += s.packets
}
return
case r := <-received:
pktReceived += r.packets
bytesReceived += r.bytes
totalPktReceived += r.packets
totalBytesReceived += r.bytes
case s := <-sent:
pktSent += s.packets
bytesSent += s.bytes
totalPktSent += s.packets

case now := <-ticker.C:
d := now.Sub(lastLog)
lastLog = now
bits := float64(bytesReceived) * 8
rate := bits / d.Seconds()
rateInMbit := rate / float64(MBit)
log.Infof("sent: %v B / %v P, received %v B / %v P => %.2f Mb/s\n", bytesSent, pktSent, bytesReceived, pktReceived, rateInMbit)

maxCap := float64(currentCapacity) + 0.1*float64(currentCapacity)
assert.Less(t, rate, maxCap)

if capacityUpdated {
currentCapacity /= 2
capacityUpdated = false
}

pktReceived = 0
bytesReceived = 0
pktSent = 0
bytesSent = 0
}
}
}()
case <-ctx.Done():
bits := float64(bytesReceived) * 8.0
rate := bits / time.Since(start).Seconds()
mBitPerSecond := rate / float64(MBit)

connLeft, err := leftNet.ListenPacket("udp4", fmt.Sprintf("%v:0", leftAddr))
assert.NoError(t, err)
// Allow 5% more than capacity due to max bursts
assert.Less(t, rate, 1.05*float64(capacity))
assert.Greater(t, rate, float64(0))

go func() {
defer close(received)
buf := make([]byte, 1500)
for {
n, _, err1 := connLeft.ReadFrom(buf)
if err1 != nil {
break
}
received <- metrics{
packets: 1,
bytes: n,
log.Infof("duration=%v, bytesReceived=%v, packetsReceived=%v throughput=%.2f Mb/s\n", time.Since(start), bytesReceived, packetsReceived, mBitPerSecond)
return

case c := <-chunkChan:
bytesReceived += len(c.UserData())
packetsReceived++
}
}
}()

connRight, err := rightNet.ListenPacket("udp", fmt.Sprintf("%v:0", rightAddr))
assert.NoError(t, err)

var wg sync.WaitGroup
wg.Add(1)

raddr := connLeft.LocalAddr()
go func() {
defer wg.Done()
defer func() {
err1 := connRight.Close()
assert.NoError(t, err1)
}()
defer close(done)
defer close(sent)
timer := time.NewTicker(duration)
buf := make([]byte, 1500)
for {
select {
case <-timer.C:
return
default:
}
n, err1 := connRight.WriteTo(buf, raddr)
assert.NoError(t, err1)
sent <- metrics{
packets: 1,
bytes: n,
defer cancel()
bytesSent := 0
packetsSent := 0
var start time.Time
for start = time.Now(); time.Since(start) < duration; {
c := &chunkUDP{
userData: make([]byte, 1200),
}
time.Sleep(5 * time.Nanosecond)
tbf.onInboundChunk(c)
bytesSent += len(c.UserData())
packetsSent++
time.Sleep(1 * time.Millisecond)
}
bits := float64(bytesSent) * 8.0
rate := bits / time.Since(start).Seconds()
mBitPerSecond := rate / float64(MBit)
log.Infof("duration=%v, bytesSent=%v, pacetsSent=%v throughput=%.2f Mb/s\n", time.Since(start), bytesSent, packetsSent, mBitPerSecond)

assert.NoError(t, tbf.Close())
}()

wg.Wait()
err = connLeft.Close()
assert.NoError(t, err)
logWG.Wait()
}

t.Run("Router bandwidth 500Kbit", func(t *testing.T) {
t.Run("500Kbit-s", func(t *testing.T) {
subTest(t, 500*KBit, 10*time.Second)
})

time.Sleep(2 * time.Second)
t.Run("Router bandwidth 1Mbit", func(t *testing.T) {
t.Run("1Mbit-s", func(t *testing.T) {
subTest(t, 1*MBit, 10*time.Second)
})

time.Sleep(2 * time.Second)
t.Run("Router bandwidth 2Mbit", func(t *testing.T) {
t.Run("2Mbit-s", func(t *testing.T) {
subTest(t, 2*MBit, 10*time.Second)
})
}

0 comments on commit 78491f1

Please sign in to comment.