diff --git a/send_queue_test.go b/send_queue_test.go index 91d49cb3377..b245ec68169 100644 --- a/send_queue_test.go +++ b/send_queue_test.go @@ -1,16 +1,17 @@ package quic import ( + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) var _ = Describe("Send Queue", func() { var q *sendQueue - var c *mockConnection + var c *MockConnection BeforeEach(func() { - c = newMockConnection() + c = NewMockConnection(mockCtrl) q = newSendQueue(c) }) @@ -25,8 +26,11 @@ var _ = Describe("Send Queue", func() { } It("sends a packet", func() { - q.Send(getPacket([]byte("foobar"))) + p := getPacket([]byte("foobar")) + q.Send(p) + written := make(chan struct{}) + c.EXPECT().Write(p.raw).Do(func([]byte) { close(written) }) done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -34,7 +38,7 @@ var _ = Describe("Send Queue", func() { close(done) }() - Eventually(c.written).Should(Receive(Equal([]byte("foobar")))) + Eventually(written).Should(BeClosed()) q.Close() Eventually(done).Should(BeClosed()) }) @@ -42,6 +46,9 @@ var _ = Describe("Send Queue", func() { It("blocks sending when too many packets are queued", func() { q.Send(getPacket([]byte("foobar"))) + written := make(chan []byte, 2) + c.EXPECT().Write(gomock.Any()).Do(func(p []byte) { written <- p }).Times(2) + sent := make(chan struct{}) go func() { defer GinkgoRecover() @@ -58,8 +65,8 @@ var _ = Describe("Send Queue", func() { close(done) }() - Eventually(c.written).Should(Receive(Equal([]byte("foobar")))) - Eventually(c.written).Should(Receive(Equal([]byte("raboof")))) + Eventually(written).Should(Receive(Equal([]byte("foobar")))) + Eventually(written).Should(Receive(Equal([]byte("raboof")))) q.Close() Eventually(done).Should(BeClosed()) }) diff --git a/session_test.go b/session_test.go index 253d3307b06..ff443100682 100644 --- a/session_test.go +++ b/session_test.go @@ -27,38 +27,6 @@ import ( "github.com/lucas-clemente/quic-go/internal/wire" ) -type mockConnection struct { - remoteAddr net.Addr - localAddr net.Addr - written chan []byte -} - -func newMockConnection() *mockConnection { - return &mockConnection{ - remoteAddr: &net.UDPAddr{}, - written: make(chan []byte, 100), - } -} - -func (m *mockConnection) Write(p []byte) error { - b := make([]byte, len(p)) - copy(b, p) - select { - case m.written <- b: - default: - panic("mockConnection channel full") - } - return nil -} -func (m *mockConnection) Read([]byte) (int, net.Addr, error) { panic("not implemented") } - -func (m *mockConnection) SetCurrentRemoteAddr(addr net.Addr) { - m.remoteAddr = addr -} -func (m *mockConnection) LocalAddr() net.Addr { return m.localAddr } -func (m *mockConnection) RemoteAddr() net.Addr { return m.remoteAddr } -func (*mockConnection) Close() error { panic("not implemented") } - func areSessionsRunning() bool { var b bytes.Buffer pprof.Lookup("goroutine").WriteTo(&b, 1)