Skip to content

Commit

Permalink
Merge pull request #68 from databus23/readerfrom-writerto
Browse files Browse the repository at this point in the history
Implement io.ReaderFrom/WriterTo for Conn
  • Loading branch information
pires committed Feb 23, 2021
2 parents fff0abf + ce59419 commit c4bcea2
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 0 deletions.
18 changes: 18 additions & 0 deletions protocol.go
Expand Up @@ -2,6 +2,7 @@ package proxyproto

import (
"bufio"
"io"
"net"
"sync"
"time"
Expand Down Expand Up @@ -237,3 +238,20 @@ func (p *Conn) readHeader() error {

return err
}

// ReadFrom implements the io.ReaderFrom ReadFrom method
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
if rf, ok := p.conn.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(p.conn, r)
}

// WriteTo implements io.WriterTo
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
p.once.Do(func() { p.readErr = p.readHeader() })
if p.readErr != nil {
return 0, p.readErr
}
return p.bufReader.WriteTo(w)
}
182 changes: 182 additions & 0 deletions protocol_test.go
Expand Up @@ -9,6 +9,8 @@ import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"io/ioutil"
"net"
"testing"
)
Expand Down Expand Up @@ -741,6 +743,186 @@ func Test_MisconfiguredTLSServerRespondsWithUnderlyingError(t *testing.T) {
}
}

type testConn struct {
readFromCalledWith io.Reader
reads int
net.Conn // nil; crash on any unexpected use
}

func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
c.readFromCalledWith = r
b, err := ioutil.ReadAll(r)
return int64(len(b)), err
}
func (c *testConn) Write(p []byte) (int, error) {
return len(p), nil
}
func (c *testConn) Read(p []byte) (int, error) {
if c.reads == 0 {
return 0, io.EOF
}
c.reads--
return 1, nil
}

func TestCopyToWrappedConnection(t *testing.T) {
innerConn := &testConn{}
wrappedConn := NewConn(innerConn)
dummySrc := &testConn{reads: 1}

io.Copy(wrappedConn, dummySrc)
if innerConn.readFromCalledWith != dummySrc {
t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection")
}
}

func TestCopyFromWrappedConnection(t *testing.T) {
wrappedConn := NewConn(&testConn{reads: 1})
dummyDst := &testConn{}

io.Copy(dummyDst, wrappedConn)
if dummyDst.readFromCalledWith != wrappedConn.conn {
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination")
}
}

func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
innerConn1 := &testConn{reads: 1}
wrappedConn1 := NewConn(innerConn1)
innerConn2 := &testConn{}
wrappedConn2 := NewConn(innerConn2)

io.Copy(wrappedConn1, wrappedConn2)
if innerConn1.readFromCalledWith != innerConn2 {
t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection")
}
}

func benchmarkTCPProxy(size int, b *testing.B) {
//create and start the echo backend
backend, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
}
defer backend.Close()
go func() {
for {
conn, err := backend.Accept()
if err != nil {
break
}
_, err = io.Copy(conn, conn)
conn.Close()
if err != nil {
b.Fatalf("Failed to read entire payload: %v", err)
}
}
}()

//start the proxyprotocol enabled tcp proxy
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
b.Fatalf("err: %v", err)
}
defer l.Close()
pl := &Listener{Listener: l}
go func() {
for {
conn, err := pl.Accept()
if err != nil {
break
}
bConn, err := net.Dial("tcp", backend.Addr().String())
if err != nil {
b.Fatalf("failed to dial backend: %v", err)
}
go func() {
_, err = io.Copy(bConn, conn)
if err != nil {
b.Fatalf("Failed to proxy incoming data to backend: %v", err)
}
bConn.(*net.TCPConn).CloseWrite()
}()
_, err = io.Copy(conn, bConn)
if err != nil {
b.Fatalf("Failed to proxy data from backend: %v", err)
}
conn.Close()
bConn.Close()
}
}()

data := make([]byte, size)

header := &Header{
Version: 2,
Command: PROXY,
TransportProtocol: TCPv4,
SourceAddr: &net.TCPAddr{
IP: net.ParseIP("10.1.1.1"),
Port: 1000,
},
DestinationAddr: &net.TCPAddr{
IP: net.ParseIP("20.2.2.2"),
Port: 2000,
},
}

//now for the actual benchmark
b.ResetTimer()
for n := 0; n < b.N; n++ {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
b.Fatalf("err: %v", err)
}
// Write out the header!
header.WriteTo(conn)
//send data
go func() {
_, err = conn.Write(data)
if err != nil {
b.Fatalf("Failed to write data: %v", err)
}
conn.(*net.TCPConn).CloseWrite()

}()
//receive data
n, err := io.Copy(ioutil.Discard, conn)
if n != int64(len(data)) {
b.Fatalf("Expected to receive %d bytes, got %d", len(data), n)
}
if err != nil {
b.Fatalf("Failed to read data: %v", err)
}
conn.Close()
}
}

func BenchmarkTCPProxy16KB(b *testing.B) {
benchmarkTCPProxy(16*1024, b)
}
func BenchmarkTCPProxy32KB(b *testing.B) {
benchmarkTCPProxy(32*1024, b)
}
func BenchmarkTCPProxy64KB(b *testing.B) {
benchmarkTCPProxy(64*1024, b)
}
func BenchmarkTCPProxy128KB(b *testing.B) {
benchmarkTCPProxy(128*1024, b)
}
func BenchmarkTCPProxy256KB(b *testing.B) {
benchmarkTCPProxy(256*1024, b)
}
func BenchmarkTCPProxy512KB(b *testing.B) {
benchmarkTCPProxy(512*1024, b)
}
func BenchmarkTCPProxy1024KB(b *testing.B) {
benchmarkTCPProxy(1024*1024, b)
}
func BenchmarkTCPProxy2048KB(b *testing.B) {
benchmarkTCPProxy(2048*1024, b)
}

// copied from src/net/http/internal/testcert.go

// Copyright 2015 The Go Authors. All rights reserved.
Expand Down

0 comments on commit c4bcea2

Please sign in to comment.