Skip to content

Commit

Permalink
Merge a1b6637 into 0d2a2d9
Browse files Browse the repository at this point in the history
  • Loading branch information
enobufs committed Mar 15, 2019
2 parents 0d2a2d9 + a1b6637 commit 0f78156
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 3 deletions.
43 changes: 40 additions & 3 deletions test/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ import (
"time"
)

type bridgeConnAddr int

func (a bridgeConnAddr) Network() string {
return "udp"
}
func (a bridgeConnAddr) String() string {
return fmt.Sprintf("a%d", a)
}

// bridgeConn is a net.Conn that represents an endpoint of the bridge.
type bridgeConn struct {
br *Bridge
Expand Down Expand Up @@ -46,7 +55,9 @@ func (conn *bridgeConn) Close() error {
}

// LocalAddr is not used
func (conn *bridgeConn) LocalAddr() net.Addr { return nil }
func (conn *bridgeConn) LocalAddr() net.Addr {
return bridgeConnAddr(conn.id)
}

// RemoteAddr is not used
func (conn *bridgeConn) RemoteAddr() net.Addr { return nil }
Expand All @@ -68,6 +79,9 @@ type Bridge struct {

queue0to1 [][]byte
queue1to0 [][]byte

dropNWrites0 int
dropNWrites1 int
}

func inverse(s [][]byte) error {
Expand Down Expand Up @@ -126,9 +140,19 @@ func (br *Bridge) Push(d []byte, fromID int) {
defer br.mutex.Unlock()

if fromID == 0 {
br.queue0to1 = append(br.queue0to1, d)
if br.dropNWrites0 > 0 {
br.dropNWrites0--
//fmt.Printf("br: dropped a packet (rem: %d for q0)\n", br.dropNWrites0)
} else {
br.queue0to1 = append(br.queue0to1, d)
}
} else {
br.queue1to0 = append(br.queue1to0, d)
if br.dropNWrites1 > 0 {
br.dropNWrites1--
//fmt.Printf("br: dropped a packet (rem: %d for q1)\n", br.dropNWrites1)
} else {
br.queue1to0 = append(br.queue1to0, d)
}
}
}

Expand Down Expand Up @@ -156,6 +180,19 @@ func (br *Bridge) Drop(fromID, offset, n int) {
}
}

// DropNextNWrites drops the next n packets that will be written
// to the specified queue.
func (br *Bridge) DropNextNWrites(fromID, n int) {
br.mutex.Lock()
defer br.mutex.Unlock()

if fromID == 0 {
br.dropNWrites0 = n
} else {
br.dropNWrites1 = n
}
}

// Tick attempts to hand a packet from the queue for each directions, to readers,
// if there are waiting on the queue. If there's no reader, it will return
// immediately.
Expand Down
89 changes: 89 additions & 0 deletions test/bridge_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package test

import (
"fmt"
"net"
"testing"
)

Expand All @@ -17,6 +19,7 @@ func closeBridge(br *Bridge) error {
type AsyncResult struct {
n int
err error
msg string
}

func TestBridge(t *testing.T) {
Expand All @@ -29,6 +32,22 @@ func TestBridge(t *testing.T) {
conn0 := br.GetConn0()
conn1 := br.GetConn1()

if conn0.LocalAddr().String() != "a0" {
t.Error("conn0 local addr name should be a0")
}

if conn1.LocalAddr().String() != "a1" {
t.Error("conn0 local addr name should be a1")
}

if conn0.LocalAddr().Network() != "udp" {
t.Error("conn0 local addr name should be a0")
}

if conn1.LocalAddr().Network() != "udp" {
t.Error("conn0 local addr name should be a1")
}

n, err := conn0.Write([]byte(msg))
if err != nil {
t.Error(err.Error())
Expand Down Expand Up @@ -359,4 +378,74 @@ func TestBridge(t *testing.T) {
t.Error("read should fail as conn is closed")
}
})

t.Run("drop next N packets", func(t *testing.T) {
testFrom := func(t *testing.T, fromID int) {
readRes := make(chan AsyncResult, 5)
br := NewBridge()
conn0 := br.GetConn0()
conn1 := br.GetConn1()
var srcConn, dstConn net.Conn

if fromID == 0 {
br.DropNextNWrites(0, 3)
srcConn = conn0
dstConn = conn1
} else {
br.DropNextNWrites(1, 3)
srcConn = conn1
dstConn = conn0
}

go func() {
for {
nInner, errInner := dstConn.Read(buf)
if errInner != nil {
break
}
readRes <- AsyncResult{
n: nInner,
err: nil,
msg: string(buf)}
}
}()
msgs := make([]string, 0)

for i := 0; i < 5; i++ {
msg := fmt.Sprintf("msg%d", i)
msgs = append(msgs, msg)
n, err := srcConn.Write([]byte(msg))
if err != nil {
t.Errorf("[%d] %s", fromID, err.Error())
}
if n != len(msg) {
t.Errorf("[%d] unexpected length", fromID)
}

br.Process()
}

nResults := len(readRes)
if nResults != 2 {
t.Errorf("[%d] unexpected number of packets", fromID)
}

for i := 0; i < 2; i++ {
ar := <-readRes
if ar.err != nil {
t.Errorf("[%d] %s", fromID, ar.err.Error())
}
if ar.n != len(msgs[i+3]) {
t.Errorf("[%d] unexpected length", fromID)
}
}

if err := closeBridge(br); err != nil {
t.Errorf("[%d] %s", fromID, err.Error())
}
}

testFrom(t, 0)
testFrom(t, 1)
})
}

0 comments on commit 0f78156

Please sign in to comment.