From fbd35b1f48ea9325ad807236867d1741aa23f12b Mon Sep 17 00:00:00 2001 From: Yutaka Takeda Date: Tue, 12 Mar 2019 14:38:05 -0700 Subject: [PATCH 1/2] Added dropNWrites() to Bridge To be used by testing in sctp Relates to pions/sctp#11 --- test/bridge.go | 43 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/test/bridge.go b/test/bridge.go index b81934c..19bb84f 100644 --- a/test/bridge.go +++ b/test/bridge.go @@ -7,6 +7,15 @@ import ( "time" ) +type bridgeConnAddr int + +func (a bridgeConnAddr) Network() string { + return "udp" +} +func (a bridgeConnAddr) String() string { + return fmt.Sprintf("a%d", a) +} + // bridgeConn is a net.Conn that represents an endpoint of the bridge. type bridgeConn struct { br *Bridge @@ -46,7 +55,9 @@ func (conn *bridgeConn) Close() error { } // LocalAddr is not used -func (conn *bridgeConn) LocalAddr() net.Addr { return nil } +func (conn *bridgeConn) LocalAddr() net.Addr { + return bridgeConnAddr(conn.id) +} // RemoteAddr is not used func (conn *bridgeConn) RemoteAddr() net.Addr { return nil } @@ -68,6 +79,9 @@ type Bridge struct { queue0to1 [][]byte queue1to0 [][]byte + + dropNWrites0 int + dropNWrites1 int } func inverse(s [][]byte) error { @@ -126,9 +140,19 @@ func (br *Bridge) Push(d []byte, fromID int) { defer br.mutex.Unlock() if fromID == 0 { - br.queue0to1 = append(br.queue0to1, d) + if br.dropNWrites0 > 0 { + br.dropNWrites0-- + //fmt.Printf("br: dropped a packet (rem: %d for q0)\n", br.dropNWrites0) + } else { + br.queue0to1 = append(br.queue0to1, d) + } } else { - br.queue1to0 = append(br.queue1to0, d) + if br.dropNWrites1 > 0 { + br.dropNWrites1-- + //fmt.Printf("br: dropped a packet (rem: %d for q1)\n", br.dropNWrites1) + } else { + br.queue1to0 = append(br.queue1to0, d) + } } } @@ -156,6 +180,19 @@ func (br *Bridge) Drop(fromID, offset, n int) { } } +// DropNextNWrites drops the next n packets that will be written +// to the specified queue. +func (br *Bridge) DropNextNWrites(fromID, n int) { + br.mutex.Lock() + defer br.mutex.Unlock() + + if fromID == 0 { + br.dropNWrites0 = n + } else { + br.dropNWrites1 = n + } +} + // Tick attempts to hand a packet from the queue for each directions, to readers, // if there are waiting on the queue. If there's no reader, it will return // immediately. From a1b6637c31b27af8ba11a74c5fcc72b9855daf49 Mon Sep 17 00:00:00 2001 From: Yutaka Takeda Date: Thu, 14 Mar 2019 22:24:02 -0700 Subject: [PATCH 2/2] Added test for bridge.DropNextNWrites Relates to pions/sctp#24 --- test/bridge_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/test/bridge_test.go b/test/bridge_test.go index 32d3399..6f4fb6d 100644 --- a/test/bridge_test.go +++ b/test/bridge_test.go @@ -1,6 +1,8 @@ package test import ( + "fmt" + "net" "testing" ) @@ -17,6 +19,7 @@ func closeBridge(br *Bridge) error { type AsyncResult struct { n int err error + msg string } func TestBridge(t *testing.T) { @@ -29,6 +32,22 @@ func TestBridge(t *testing.T) { conn0 := br.GetConn0() conn1 := br.GetConn1() + if conn0.LocalAddr().String() != "a0" { + t.Error("conn0 local addr name should be a0") + } + + if conn1.LocalAddr().String() != "a1" { + t.Error("conn0 local addr name should be a1") + } + + if conn0.LocalAddr().Network() != "udp" { + t.Error("conn0 local addr name should be a0") + } + + if conn1.LocalAddr().Network() != "udp" { + t.Error("conn0 local addr name should be a1") + } + n, err := conn0.Write([]byte(msg)) if err != nil { t.Error(err.Error()) @@ -359,4 +378,74 @@ func TestBridge(t *testing.T) { t.Error("read should fail as conn is closed") } }) + + t.Run("drop next N packets", func(t *testing.T) { + testFrom := func(t *testing.T, fromID int) { + readRes := make(chan AsyncResult, 5) + br := NewBridge() + conn0 := br.GetConn0() + conn1 := br.GetConn1() + var srcConn, dstConn net.Conn + + if fromID == 0 { + br.DropNextNWrites(0, 3) + srcConn = conn0 + dstConn = conn1 + } else { + br.DropNextNWrites(1, 3) + srcConn = conn1 + dstConn = conn0 + } + + go func() { + for { + nInner, errInner := dstConn.Read(buf) + if errInner != nil { + break + } + readRes <- AsyncResult{ + n: nInner, + err: nil, + msg: string(buf)} + } + }() + msgs := make([]string, 0) + + for i := 0; i < 5; i++ { + msg := fmt.Sprintf("msg%d", i) + msgs = append(msgs, msg) + n, err := srcConn.Write([]byte(msg)) + if err != nil { + t.Errorf("[%d] %s", fromID, err.Error()) + } + if n != len(msg) { + t.Errorf("[%d] unexpected length", fromID) + } + + br.Process() + } + + nResults := len(readRes) + if nResults != 2 { + t.Errorf("[%d] unexpected number of packets", fromID) + } + + for i := 0; i < 2; i++ { + ar := <-readRes + if ar.err != nil { + t.Errorf("[%d] %s", fromID, ar.err.Error()) + } + if ar.n != len(msgs[i+3]) { + t.Errorf("[%d] unexpected length", fromID) + } + } + + if err := closeBridge(br); err != nil { + t.Errorf("[%d] %s", fromID, err.Error()) + } + } + + testFrom(t, 0) + testFrom(t, 1) + }) }